'''Tensorflow backend for the slideflow.model submodule.'''
from __future__ import absolute_import, division, print_function
import atexit
import inspect
import json
import logging
import os
import shutil
import numpy as np
import multiprocessing as mp
import tensorflow as tf
from packaging import version
from os.path import dirname, exists, join
from types import SimpleNamespace
from typing import (
TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Callable, Iterable
)
from tensorflow.keras import applications as kapps
import slideflow as sf
import slideflow.model.base as _base
import slideflow.util.neptune_utils
from slideflow import errors
from slideflow.util import log, NormFit, no_scope
from . import tensorflow_utils as tf_utils
from .base import log_manifest, BaseFeatureExtractor
from .tensorflow_utils import unwrap, flatten, eval_from_model, build_uq_model # type: ignore
# Set the tensorflow logger
if sf.getLoggingLevel() == logging.DEBUG:
logging.getLogger('tensorflow').setLevel(logging.DEBUG)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
else:
logging.getLogger('tensorflow').setLevel(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
sf.util.allow_gpu_memory_growth()
if TYPE_CHECKING:
import pandas as pd
from slideflow.norm import StainNormalizer
class StaticDropout(tf.keras.layers.Dropout):
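    """Dropout layer that remains active during inference.

    Keeping dropout active at inference time allows Monte Carlo sampling of
    predictions, which is how this backend estimates uncertainty (UQ) when
    ``hp.uq`` is enabled.
    """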
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def call(self, inputs, **kwargs):
return super().call(inputs, training=True)
class ModelParams(_base._ModelParams):
"""Build a set of hyperparameters."""
ModelDict = {
'xception': kapps.Xception,
'vgg16': kapps.VGG16,
'vgg19': kapps.VGG19,
'resnet50': kapps.ResNet50,
'resnet101': kapps.ResNet101,
'resnet152': kapps.ResNet152,
'resnet50_v2': kapps.ResNet50V2,
'resnet101_v2': kapps.ResNet101V2,
'resnet152_v2': kapps.ResNet152V2,
'inception': kapps.InceptionV3,
'nasnet_large': kapps.NASNetLarge,
'inception_resnet_v2': kapps.InceptionResNetV2,
'mobilenet': kapps.MobileNet,
'mobilenet_v2': kapps.MobileNetV2,
'densenet_121': kapps.DenseNet121,
'densenet_169': kapps.DenseNet169,
'densenet_201': kapps.DenseNet201,
# 'ResNeXt50': kapps.ResNeXt50,
# 'ResNeXt101': kapps.ResNeXt101,
# 'NASNet': kapps.NASNet
}
OptDict = {
'Adam': tf.keras.optimizers.Adam,
'SGD': tf.keras.optimizers.SGD,
'RMSprop': tf.keras.optimizers.RMSprop,
'Adagrad': tf.keras.optimizers.Adagrad,
'Adadelta': tf.keras.optimizers.Adadelta,
'Adamax': tf.keras.optimizers.Adamax,
'Nadam': tf.keras.optimizers.Nadam
}
if hasattr(kapps, 'EfficientNetV2B0'):
ModelDict.update({'efficientnet_v2b0': kapps.EfficientNetV2B0})
if hasattr(kapps, 'EfficientNetV2B1'):
ModelDict.update({'efficientnet_v2b1': kapps.EfficientNetV2B1})
if hasattr(kapps, 'EfficientNetV2B2'):
ModelDict.update({'efficientnet_v2b2': kapps.EfficientNetV2B2})
if hasattr(kapps, 'EfficientNetV2B3'):
ModelDict.update({'efficientnet_v2b3': kapps.EfficientNetV2B3})
if hasattr(kapps, 'EfficientNetV2S'):
ModelDict.update({'efficientnet_v2s': kapps.EfficientNetV2S})
if hasattr(kapps, 'EfficientNetV2M'):
ModelDict.update({'efficientnet_v2m': kapps.EfficientNetV2M})
if hasattr(kapps, 'EfficientNetV2L'):
ModelDict.update({'efficientnet_v2l': kapps.EfficientNetV2L})
RegressionLossDict = {
loss: getattr(tf.keras.losses, loss)
for loss in [
'mean_squared_error',
'mean_absolute_error',
'mean_absolute_percentage_error',
'mean_squared_logarithmic_error',
'squared_hinge',
'hinge',
'logcosh'
]
}
RegressionLossDict.update({
'negative_log_likelihood': tf_utils.negative_log_likelihood
})
AllLossDict = {
loss: getattr(tf.keras.losses, loss)
for loss in [
'mean_squared_error',
'mean_absolute_error',
'mean_absolute_percentage_error',
'mean_squared_logarithmic_error',
'squared_hinge',
'hinge',
'categorical_hinge',
'logcosh',
'huber',
'categorical_crossentropy',
'sparse_categorical_crossentropy',
'binary_crossentropy',
'kullback_leibler_divergence',
'poisson'
]
}
AllLossDict.update({
'batch_loss_crossentropy': tf_utils.batch_loss_crossentropy,
'negative_log_likelihood': tf_utils.negative_log_likelihood
})
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.model in self.ModelDict.keys()
assert self.optimizer in self.OptDict.keys()
assert self.loss in self.AllLossDict.keys()
def _add_hidden_layers(
self,
model: tf.keras.Model,
regularizer: tf.keras.layers.Layer
) -> Tuple[tf.keras.Model, tf.keras.layers.Layer]:
"""Adds hidden layer(s) to a model.
Args:
model (tf.keras.Model): Tensorflow model.
regularizer (tf.keras.layers.Layer): Regularization for hidden layers.
Returns:
A tuple containing
tf.keras.Model: Model with hidden layers added.
tf.keras.layers.Layer: Last linear layer.
"""
log.debug("Using Batch normalization")
last_linear = None
for i in range(self.hidden_layers):
model = tf.keras.layers.Dense(self.hidden_layer_width,
name=f'hidden_{i}',
activation='relu',
kernel_regularizer=regularizer)(model)
model = tf.keras.layers.BatchNormalization()(model)
last_linear = model
if self.uq:
model = StaticDropout(self.dropout)(model)
elif self.dropout:
model = tf.keras.layers.Dropout(self.dropout)(model)
return model, last_linear
def _get_dense_regularizer(self) -> Optional[tf.keras.layers.Layer]:
"""Return regularizer for dense (hidden) layers."""
if self.l2_dense and not self.l1_dense:
log.debug(f"Using L2 regularization for dense layers (weight={self.l2_dense})")
return tf.keras.regularizers.l2(self.l2_dense)
elif self.l1_dense and not self.l2_dense:
log.debug(f"Using L1 regularization for dense layers (weight={self.l1_dense})")
return tf.keras.regularizers.l1(self.l1_dense)
elif self.l1_dense and self.l2_dense:
log.debug(f"Using L1 (weight={self.l1_dense}) and L2 (weight={self.l2_dense}) reg for dense layers")
return tf.keras.regularizers.l1_l2(l1=self.l1_dense, l2=self.l2_dense)
else:
log.debug("Not using regularization for dense layers")
return None
def _add_regularization(self, model: tf.keras.Model) -> tf.keras.Model:
"""Add non-hidden layer regularization.
Args:
model (tf.keras.Model): Tensorflow model.
Returns:
tf.keras.Model: Tensorflow model with regularization added.
"""
if self.l2 and not self.l1:
log.debug(f"Using L2 regularization for base model (weight={self.l2})")
regularizer = tf.keras.regularizers.l2(self.l2)
elif self.l1 and not self.l2:
log.debug(f"Using L1 regularization for base model (weight={self.l1})")
regularizer = tf.keras.regularizers.l1(self.l1)
elif self.l1 and self.l2:
log.debug(f"Using L1 (weight={self.l1}) and L2 (weight={self.l2}) regularization for base model")
regularizer = tf.keras.regularizers.l1_l2(l1=self.l1, l2=self.l2)
else:
log.debug("Not using regularization for base model")
regularizer = None
if regularizer is not None:
model = tf_utils.add_regularization(model, regularizer)
return model
def _freeze_layers(self, model: tf.keras.Model) -> tf.keras.Model:
"""Freeze last X layers, where X = self.trainable_layers.
Args:
model (tf.keras.Model): Tensorflow model.
Returns:
tf.keras.Model: Tensorflow model with frozen layers.
"""
freezeIndex = int(len(model.layers) - (self.trainable_layers - 1)) # - self.hp.hidden_layers - 1))
log.info(f'Only training on last {self.trainable_layers} layers (of {len(model.layers)} total)')
for layer in model.layers[:freezeIndex]:
layer.trainable = False
return model
def _get_core(self, weights: Optional[str] = None) -> tf.keras.Model:
"""Returns a Keras model of the appropriate architecture, input shape,
pooling, and initial weights.
Args:
weights (Optional[str], optional): Pretrained weights to use.
Defaults to None.
Returns:
tf.keras.Model: Core model.
"""
input_shape = (self.tile_px, self.tile_px, 3)
model_fn = self.ModelDict[self.model]
model_kwargs = {
'input_shape': input_shape,
'include_top': self.include_top,
'pooling': self.pooling,
'weights': weights
}
# Only pass kwargs accepted by model function
model_fn_sig = inspect.signature(model_fn)
model_kw = [
param.name
for param in model_fn_sig.parameters.values()
if param.kind == param.POSITIONAL_OR_KEYWORD
]
model_kwargs = {key: model_kwargs[key] for key in model_kw if key in model_kwargs}
return model_fn(**model_kwargs)
def _build_base(
self,
pretrain: Optional[str] = 'imagenet',
load_method: str = 'weights'
    ) -> Tuple[tf.keras.Model, List[tf.Tensor]]:
        """Builds the base image model from a Keras model core, with the
appropriate input tensors and identity layers.
Args:
pretrain (str, optional): Pretrained weights to load.
Defaults to 'imagenet'.
load_method (str): Either 'full' or 'weights'. Method to use
when loading a Tensorflow model. If 'full', loads the model with
``tf.keras.models.load_model()``. If 'weights', will read the
``params.json`` configuration file, build the model architecture,
and then load weights from the given model with
``Model.load_weights()``. Loading with 'full' may improve
compatibility across Slideflow versions. Loading with 'weights'
may improve compatibility across hardware & environments.
        Returns:
            A tuple containing the base image model (tf.keras.Model) and a
            list of its model input tensors.
"""
image_shape = (self.tile_px, self.tile_px, 3)
tile_input_tensor = tf.keras.Input(shape=image_shape, name='tile_image')
if pretrain:
log.debug(f'Using pretraining from [magenta]{pretrain}')
if pretrain and pretrain != 'imagenet':
pretrained_model = load(pretrain, method=load_method, training=True)
try:
# This is the tile_image input
pretrained_input = pretrained_model.get_layer(name='tile_image').input
# Name of the pretrained model core, which should be at layer 1
pretrained_name = pretrained_model.get_layer(index=1).name
# This is the post-convolution layer
pretrained_output = pretrained_model.get_layer(name='post_convolution').output
base_model = tf.keras.Model(inputs=pretrained_input,
outputs=pretrained_output,
name=f'pretrained_{pretrained_name}').layers[1]
except ValueError:
log.warning('Unable to automatically read pretrained model, will try legacy format')
base_model = pretrained_model.get_layer(index=0)
else:
base_model = self._get_core(weights=pretrain)
if self.include_top:
base_model = tf.keras.Model(
inputs=base_model.input,
outputs=base_model.layers[-2].output,
name=base_model.name
)
# Add regularization
base_model = self._add_regularization(base_model)
# Allow only a subset of layers in the base model to be trainable
if self.trainable_layers != 0:
base_model = self._freeze_layers(base_model)
# This is an identity layer that simply returns the last layer, allowing us to name and access this layer later
post_convolution_identity_layer = tf.keras.layers.Activation('linear', name='post_convolution')
layers = [tile_input_tensor, base_model]
if not self.pooling:
layers += [tf.keras.layers.Flatten()]
layers += [post_convolution_identity_layer]
if self.uq:
layers += [StaticDropout(self.dropout)]
elif self.dropout:
layers += [tf.keras.layers.Dropout(self.dropout)]
tile_image_model = tf.keras.Sequential(layers)
model_inputs = [tile_image_model.input]
return tile_image_model, model_inputs
def _build_classification_or_regression_model(
self,
num_classes: Union[int, Dict[Any, int]],
num_slide_features: int = 0,
activation: str = 'softmax',
pretrain: str = 'imagenet',
checkpoint: Optional[str] = None,
load_method: str = 'weights'
) -> tf.keras.Model:
"""Assembles classification or regression model, using pretraining (imagenet)
or the base layers of a supplied model.
Args:
num_classes (int or dict): Either int (single categorical outcome,
indicating number of classes) or dict (dict mapping categorical
outcome names to number of unique categories in each outcome).
num_slide_features (int): Number of slide-level features separate
from image input. Defaults to 0.
activation (str): Type of final layer activation to use.
Defaults to softmax.
pretrain (str): Either 'imagenet' or path to model to use as
pretraining. Defaults to 'imagenet'.
checkpoint (str): Path to checkpoint from which to resume model
training. Defaults to None.
load_method (str): Either 'full' or 'weights'. Method to use
when loading a Tensorflow model. If 'full', loads the model with
``tf.keras.models.load_model()``. If 'weights', will read the
``params.json`` configuration file, build the model architecture,
and then load weights from the given model with
``Model.load_weights()``. Loading with 'full' may improve
compatibility across Slideflow versions. Loading with 'weights'
may improve compatibility across hardware & environments.
"""
tile_image_model, model_inputs = self._build_base(pretrain, load_method)
if num_slide_features:
log.debug(f'Model has {num_slide_features} slide input features')
slide_feature_input_tensor = tf.keras.Input(
shape=(num_slide_features),
name='slide_feature_input'
)
else:
log.debug('Not using any slide-level input features.')
# Merge layers
if num_slide_features and ((self.tile_px == 0) or self.drop_images):
log.info('Generating model with only slide-level input - no images')
merged_model = slide_feature_input_tensor
model_inputs += [slide_feature_input_tensor]
elif num_slide_features:
# Add slide feature input tensors
merged_model = tf.keras.layers.Concatenate(name='input_merge')(
[slide_feature_input_tensor, tile_image_model.output]
)
model_inputs += [slide_feature_input_tensor]
else:
merged_model = tile_image_model.output
# Add hidden layers
regularizer = self._get_dense_regularizer()
merged_model, last_linear = self._add_hidden_layers(
merged_model, regularizer
)
# Multi-categorical outcomes
if isinstance(num_classes, dict):
outputs = []
for c in num_classes:
final_dense_layer = tf.keras.layers.Dense(
num_classes[c],
kernel_regularizer=regularizer,
name=f'logits-{c}'
)(merged_model)
outputs += [
tf.keras.layers.Activation(
activation,
dtype='float32',
name=f'out-{c}'
)(final_dense_layer)
]
else:
final_dense_layer = tf.keras.layers.Dense(
num_classes,
kernel_regularizer=regularizer,
name='logits'
)(merged_model)
outputs = [
tf.keras.layers.Activation(
activation,
dtype='float32',
name='output'
)(final_dense_layer)
]
# Assemble final model
log.debug(f'Using {activation} activation')
model = tf.keras.Model(inputs=model_inputs, outputs=outputs)
# Disable experimental batch loss
if False:
model.add_loss(tf_utils.batch_loss_crossentropy(last_linear))
if checkpoint:
log.info(f'Loading checkpoint weights from [green]{checkpoint}')
model.load_weights(checkpoint)
return model
def _build_survival_model(
self,
num_classes: Union[int, Dict[Any, int]],
num_slide_features: int = 1,
pretrain: Optional[str] = None,
checkpoint: Optional[str] = None,
load_method: str = 'weights',
training: bool = True
) -> tf.keras.Model:
"""Assembles a survival model, using pretraining (imagenet)
or the base layers of a supplied model.
Args:
            num_classes (int or dict): Either int (single categorical outcome,
                indicating number of classes) or dict (dict mapping categorical
                outcome names to number of unique categories in each outcome).
            num_slide_features (int): Number of slide-level features separate
                from image input, including the survival event input tensor.
                Defaults to 1 (event input only).
            pretrain (str, optional): Either 'imagenet' or path to model to use
                as pretraining. Defaults to None.
            checkpoint (str, optional): Path to checkpoint from which to resume
                model training. Defaults to None.
            load_method (str): Either 'full' or 'weights'. Method to use
                when loading a Tensorflow model. If 'full', loads the model with
                ``tf.keras.models.load_model()``. If 'weights', will read the
                ``params.json`` configuration file, build the model architecture,
                and then load weights from the given model with
                ``Model.load_weights()``. Loading with 'full' may improve
                compatibility across Slideflow versions. Loading with 'weights'
                may improve compatibility across hardware & environments.
            training (bool): Build the model for training, adding an event
                input tensor and concatenating the event input into the model
                output for use by the survival loss. Defaults to True.
"""
activation = 'linear'
tile_image_model, model_inputs = self._build_base(pretrain, load_method)
# Add slide feature input tensors, if there are more slide features
# than just the event input tensor for survival models
if training:
event_input_tensor = tf.keras.Input(shape=(1), name='event_input')
if not (num_slide_features == 1):
slide_feature_input_tensor = tf.keras.Input(
shape=(num_slide_features - 1),
name='slide_feature_input'
)
# Merge layers
if num_slide_features and ((self.tile_px == 0) or self.drop_images):
            # Slide-level input only (no images)
log.info('Generating model with only slide-level input - no images')
merged_model = slide_feature_input_tensor
model_inputs += [slide_feature_input_tensor]
if training:
model_inputs += [event_input_tensor]
elif num_slide_features and num_slide_features > 1:
# Add slide feature input tensors, if there are more slide features
# than just the event input tensor for survival models
merged_model = tf.keras.layers.Concatenate(name='input_merge')(
[slide_feature_input_tensor, tile_image_model.output]
)
model_inputs += [slide_feature_input_tensor]
if training:
model_inputs += [event_input_tensor]
else:
merged_model = tile_image_model.output
if training:
model_inputs += [event_input_tensor]
# Add hidden layers
regularizer = self._get_dense_regularizer()
merged_model, last_linear = self._add_hidden_layers(
merged_model, regularizer
)
log.debug(f'Using {activation} activation')
# Multi-categorical outcomes
        if isinstance(num_classes, dict):
outputs = []
for c in num_classes:
final_dense_layer = tf.keras.layers.Dense(
num_classes[c],
kernel_regularizer=regularizer,
name=f'logits-{c}'
)(merged_model)
outputs += [tf.keras.layers.Activation(
activation,
dtype='float32',
name=f'out-{c}'
)(final_dense_layer)]
else:
final_dense_layer = tf.keras.layers.Dense(
num_classes,
kernel_regularizer=regularizer,
name='logits'
)(merged_model)
outputs = [tf.keras.layers.Activation(
activation,
dtype='float32',
name='output'
)(final_dense_layer)]
if training:
outputs[0] = tf.keras.layers.Concatenate(
name='output_merge_survival',
dtype='float32'
)([outputs[0], event_input_tensor])
# Assemble final model
model = tf.keras.Model(inputs=model_inputs, outputs=outputs)
if checkpoint:
log.info(f'Loading checkpoint weights from [green]{checkpoint}')
model.load_weights(checkpoint)
return model
def build_model(
self,
labels: Optional[Dict] = None,
num_classes: Optional[Union[int, Dict[Any, int]]] = None,
**kwargs
) -> tf.keras.Model:
"""Auto-detects model type (classification, regression, survival) from parameters
and builds, using pretraining or the base layers of a supplied model.
Args:
labels (dict, optional): Dict mapping slide names to outcomes.
Used to detect number of outcome categories.
            num_classes (int or dict, optional): Either int (single categorical
                outcome, indicating number of classes) or dict (dict mapping
                categorical outcome names to number of unique categories in
                each outcome). Must supply either ``num_classes`` or ``labels``
                (the number of classes can be detected from labels).
            num_slide_features (int, optional): Number of slide-level features
                separate from image input. Defaults to 0.
            activation (str, optional): Type of final layer activation to use.
                Defaults to 'softmax' (classification models) or 'linear'
                (regression and survival models).
pretrain (str, optional): Either 'imagenet' or path to model to use
as pretraining. Defaults to 'imagenet'.
checkpoint (str, optional): Path to checkpoint from which to resume
model training. Defaults to None.
load_method (str): Either 'full' or 'weights'. Method to use
when loading a Tensorflow model. If 'full', loads the model with
``tf.keras.models.load_model()``. If 'weights', will read the
``params.json`` configuration file, build the model architecture,
and then load weights from the given model with
``Model.load_weights()``. Loading with 'full' may improve
compatibility across Slideflow versions. Loading with 'weights'
may improve compatibility across hardware & environments.
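
        Example:
            A minimal sketch of building a classification model directly from
            hyperparameters; ``labels`` is assumed to be a dict mapping slide
            names to outcome categories, and the argument values shown are
            illustrative only::

                hp = ModelParams(tile_px=299, tile_um=302, model='xception')
                model = hp.build_model(labels=labels)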
"""
assert num_classes is not None or labels is not None
if num_classes is None:
num_classes = self._detect_classes_from_labels(labels) # type: ignore
if self.model_type() == 'classification':
return self._build_classification_or_regression_model(
num_classes, **kwargs, activation='softmax'
)
elif self.model_type() == 'regression':
return self._build_classification_or_regression_model(
num_classes, **kwargs, activation='linear'
)
elif self.model_type() == 'survival':
return self._build_survival_model(num_classes, **kwargs)
else:
raise errors.ModelError(f'Unknown model type: {self.model_type()}')
def get_loss(self) -> tf.keras.losses.Loss:
return self.AllLossDict[self.loss]
def get_opt(self) -> tf.keras.optimizers.Optimizer:
"""Returns optimizer with appropriate learning rate."""
if self.learning_rate_decay not in (0, 1):
initial_learning_rate = self.learning_rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate,
decay_steps=self.learning_rate_decay_steps,
decay_rate=self.learning_rate_decay,
staircase=True
)
return self.OptDict[self.optimizer](learning_rate=lr_schedule)
else:
return self.OptDict[self.optimizer](learning_rate=self.learning_rate)
def model_type(self) -> str:
"""Returns 'regression', 'classification', or 'survival', reflecting the loss."""
        # If the loss is named 'custom_[type]', return the type.
if self.loss.startswith('custom'):
return self.loss[7:]
elif self.loss == 'negative_log_likelihood':
return 'survival'
elif self.loss in self.RegressionLossDict:
return 'regression'
else:
return 'classification'
class _PredictionAndEvaluationCallback(tf.keras.callbacks.Callback):
"""Prediction and Evaluation Callback used during model training."""
def __init__(self, parent: "Trainer", cb_args: SimpleNamespace) -> None:
super(_PredictionAndEvaluationCallback, self).__init__()
self.parent = parent
self.hp = parent.hp
self.cb_args = cb_args
self.early_stop = False
self.early_stop_batch = 0
self.early_stop_epoch = 0
self.last_ema = -1 # type: float
self.moving_average = [] # type: List
self.ema_two_checks_prior = -1 # type: float
self.ema_one_check_prior = -1 # type: float
self.epoch_count = cb_args.starting_epoch
self.model_type = self.hp.model_type()
self.results = {'epochs': {}} # type: Dict[str, Dict]
self.neptune_run = self.parent.neptune_run
self.global_step = 0
self.train_summary_writer = tf.summary.create_file_writer(
join(self.parent.outdir, 'train'))
self.val_summary_writer = tf.summary.create_file_writer(
join(self.parent.outdir, 'validation'))
# Circumvents buffer overflow error with Python 3.10.
# Without this, a buffer overflow error will be encountered when
# attempting to make a matplotlib figure (with the tkagg backend)
# during model evaluation. I have not yet been able to track down
# the root cause.
if self.cb_args.using_validation:
import matplotlib.pyplot as plt
plt.figure()
plt.close()
def _log_training_metrics(self, logs):
"""Log training metrics to Tensorboard/Neptune."""
# Log to Tensorboard.
with self.train_summary_writer.as_default():
for _log in logs:
tf.summary.scalar(
f'batch_{_log}',
data=logs[_log],
step=self.global_step)
# Log to neptune.
if self.neptune_run:
self.neptune_run['metrics/train/batch/loss'].log(
logs['loss'],
step=self.global_step)
sf.util.neptune_utils.list_log(
self.neptune_run,
'metrics/train/batch/accuracy',
logs['accuracy'],
step=self.global_step)
def _log_validation_metrics(self, metrics):
"""Log validation metrics to Tensorboard/Neptune."""
# Tensorboard logging for validation metrics
with self.val_summary_writer.as_default():
for _log in metrics:
tf.summary.scalar(
f'batch_{_log}',
data=metrics[_log],
step=self.global_step)
# Log to neptune
if self.neptune_run:
for v in metrics:
self.neptune_run[f"metrics/val/batch/{v}"].log(
round(metrics[v], 3),
step=self.global_step
)
if self.last_ema != -1:
self.neptune_run["metrics/val/batch/exp_moving_avg"].log(
round(self.last_ema, 3),
step=self.global_step
)
self.neptune_run["early_stop/stopped_early"] = False
def _log_epoch_evaluation(self, epoch_results, metrics, accuracy, loss, logs={}):
"""Log the end-of-epoch evaluation to CSV, Tensorboard, and Neptune."""
epoch = self.epoch_count
run = self.neptune_run
sf.util.update_results_log(
self.cb_args.results_log,
'trained_model',
{f'epoch{epoch}': epoch_results}
)
with self.val_summary_writer.as_default():
# Note: Tensorboard epoch logging starts with index=0,
# whereas all other logging starts with index=1
if isinstance(accuracy, (list, tuple, np.ndarray)):
for i in range(len(accuracy)):
tf.summary.scalar(f'epoch_accuracy-{i}', data=accuracy[i], step=epoch-1)
elif accuracy is not None:
tf.summary.scalar(f'epoch_accuracy', data=accuracy, step=epoch-1)
if isinstance(loss, (list, tuple, np.ndarray)):
for i in range(len(loss)):
tf.summary.scalar(f'epoch_loss-{i}', data=loss[i], step=epoch-1)
else:
tf.summary.scalar(f'epoch_loss', data=loss, step=epoch-1)
# Log epoch results to Neptune
if run:
# Training epoch metrics
run['metrics/train/epoch/loss'].log(logs['loss'], step=epoch)
sf.util.neptune_utils.list_log(
run,
'metrics/train/epoch/accuracy',
logs['accuracy'],
step=epoch
)
# Validation epoch metrics
run['metrics/val/epoch/loss'].log(loss, step=epoch)
sf.util.neptune_utils.list_log(
run,
'metrics/val/epoch/accuracy',
accuracy,
step=epoch
)
for metric in metrics:
if metrics[metric]['tile'] is None:
continue
for outcome in metrics[metric]['tile']:
# If only one outcome, log to metrics/val/epoch/[metric].
# If more than one outcome, log to
# metrics/val/epoch/[metric]/[outcome_name]
def metric_label(s):
if len(metrics[metric]['tile']) == 1:
return f'metrics/val/epoch/{s}_{metric}'
else:
return f'metrics/val/epoch/{s}_{metric}/{outcome}'
tile_metric = metrics[metric]['tile'][outcome]
slide_metric = metrics[metric]['slide'][outcome]
patient_metric = metrics[metric]['patient'][outcome]
# If only one value for a metric, log to .../[metric]
# If more than one value for a metric (e.g. AUC for each
# category), log to .../[metric]/[i]
sf.util.neptune_utils.list_log(
run,
metric_label('tile'),
tile_metric,
step=epoch
)
sf.util.neptune_utils.list_log(
run,
metric_label('slide'),
slide_metric,
step=epoch
)
sf.util.neptune_utils.list_log(
run,
metric_label('patient'),
patient_metric,
step=epoch
)
def _metrics_from_dataset(
self,
epoch_label: str,
) -> Tuple[Dict, float, float]:
return sf.stats.metrics_from_dataset(
self.model,
model_type=self.hp.model_type(),
patients=self.parent.patients,
dataset=self.cb_args.validation_data,
outcome_names=self.parent.outcome_names,
label=epoch_label,
data_dir=self.parent.outdir,
num_tiles=self.cb_args.num_val_tiles,
save_predictions=self.cb_args.save_predictions,
reduce_method=self.cb_args.reduce_method,
loss=self.hp.get_loss(),
uq=bool(self.hp.uq),
)
def on_epoch_end(self, epoch: int, logs={}) -> None:
if sf.getLoggingLevel() <= 20:
print('\r\033[K', end='')
self.epoch_count += 1
if (self.epoch_count in [e for e in self.hp.epochs]
or self.early_stop):
if self.parent.name:
model_name = self.parent.name
else:
model_name = 'trained_model'
model_path = os.path.join(
self.parent.outdir,
f'{model_name}_epoch{self.epoch_count}'
)
if self.cb_args.save_model:
self.model.save(model_path)
log.info(f'Trained model saved to [green]{model_path}')
# Try to copy model settings/hyperparameters file
# into the model folder
params_dest = join(model_path, 'params.json')
if not exists(params_dest):
try:
config_path = join(dirname(model_path), 'params.json')
if self.neptune_run:
config = sf.util.load_json(config_path)
config['neptune_id'] = self.neptune_run['sys/id'].fetch()
sf.util.write_json(config, config_path)
shutil.copy(config_path, params_dest)
shutil.copy(
join(dirname(model_path), 'slide_manifest.csv'),
join(model_path, 'slide_manifest.csv')
)
except Exception as e:
log.warning(e)
log.warning('Unable to copy params.json/slide_manifest'
'.csv files into model folder.')
if self.cb_args.using_validation:
self.evaluate_model(logs)
elif self.early_stop:
self.evaluate_model(logs)
self.model.stop_training = self.early_stop
def on_train_batch_end(self, batch: int, logs={}) -> None:
# Tensorboard logging for training metrics
if batch > 0 and batch % self.cb_args.log_frequency == 0:
#with self.train_summary_writer.as_default():
self._log_training_metrics(logs)
# Check if manual early stopping has been triggered
if (self.hp.early_stop
and self.hp.early_stop_method == 'manual'):
assert self.hp.manual_early_stop_batch is not None
assert self.hp.manual_early_stop_epoch is not None
if (self.hp.manual_early_stop_epoch <= (self.epoch_count+1)
and self.hp.manual_early_stop_batch <= batch):
log.info('Manual early stop triggered: epoch '
f'{self.epoch_count+1}, batch {batch}')
self.model.stop_training = True
self.early_stop = True
self.early_stop_batch = batch
self.early_stop_epoch = self.epoch_count + 1
# Validation metrics
if (self.cb_args.using_validation and self.cb_args.validate_on_batch
and (batch > 0)
and (batch % self.cb_args.validate_on_batch == 0)):
_, acc, loss = eval_from_model(
self.model,
self.cb_args.mid_train_validation_data,
model_type=self.hp.model_type(),
uq=False,
loss=self.hp.get_loss(),
steps=self.cb_args.validation_steps,
verbosity='quiet',
)
val_metrics = {'loss': loss}
val_log_metrics = {'loss': loss}
if isinstance(acc, float):
val_metrics['accuracy'] = acc
val_log_metrics['accuracy'] = acc
elif acc is not None:
val_metrics.update({f'accuracy-{i+1}': acc[i] for i in range(len(acc))})
val_log_metrics.update({f'out-{i}_accuracy': acc[i] for i in range(len(acc))})
val_loss = val_metrics['loss']
self.model.stop_training = False
if (self.hp.early_stop_method == 'accuracy'
and 'accuracy' in val_metrics):
early_stop_value = val_metrics['accuracy']
val_acc = f"{val_metrics['accuracy']:.3f}"
else:
early_stop_value = val_loss
val_acc = ', '.join([
f'{val_metrics[v]:.3f}'
for v in val_metrics
if 'accuracy' in v
])
if 'accuracy' in logs:
train_acc = f"{logs['accuracy']:.3f}"
else:
train_acc = ', '.join([
f'{logs[v]:.3f}'
for v in logs
if 'accuracy' in v
])
if sf.getLoggingLevel() <= 20:
print('\r\033[K', end='')
self.moving_average += [early_stop_value]
self._log_validation_metrics(val_log_metrics)
# Log training metrics if not already logged this batch
if batch % self.cb_args.log_frequency > 0:
self._log_training_metrics(logs)
# Base logging message
batch_msg = f'[blue]Batch {batch:<5}[/]'
loss_msg = f"[green]loss[/]: {logs['loss']:.3f}"
val_loss_msg = f"[magenta]val_loss[/]: {val_loss:.3f}"
if self.model_type == 'classification':
acc_msg = f"[green]acc[/]: {train_acc}"
val_acc_msg = f"[magenta]val_acc[/]: {val_acc}"
log_message = f"{batch_msg} {loss_msg}, {acc_msg} | "
log_message += f"{val_loss_msg}, {val_acc_msg}"
else:
log_message = f"{batch_msg} {loss_msg} | {val_loss_msg}"
# Calculate exponential moving average of validation accuracy
if len(self.moving_average) <= self.cb_args.ema_observations:
log.info(log_message)
else:
# Only keep track of the last [ema_observations] val accuracies
self.moving_average.pop(0)
if self.last_ema == -1:
# Calculate simple moving average
self.last_ema = (sum(self.moving_average)
/ len(self.moving_average))
log.info(log_message + f' (SMA: {self.last_ema:.3f})')
else:
# Update exponential moving average
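                    # With s = ema_smoothing and n = ema_observations, the
                    # update below uses weight k = s / (1 + n), so that
                    # EMA_new = value * k + EMA_prior * (1 - k).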
sm = self.cb_args.ema_smoothing
obs = self.cb_args.ema_observations
self.last_ema = ((early_stop_value * (sm / (1 + obs)))
+ (self.last_ema * (1 - (sm / (1 + obs)))))
log.info(log_message + f' (EMA: {self.last_ema:.3f})')
# If early stopping and our patience criteria has been met,
# check if validation accuracy is still improving
steps_per_epoch = self.cb_args.steps_per_epoch
if (self.hp.early_stop
and self.hp.early_stop_method in ('loss', 'accuracy')
and self.last_ema != -1
and ((float(batch) / steps_per_epoch) + self.epoch_count)
> self.hp.early_stop_patience):
if (self.ema_two_checks_prior != -1
and ((self.hp.early_stop_method == 'accuracy'
and self.last_ema <= self.ema_two_checks_prior)
or (self.hp.early_stop_method == 'loss'
and self.last_ema >= self.ema_two_checks_prior))):
log.info(f'Early stop: epoch {self.epoch_count+1}, batch '
f'{batch}')
self.model.stop_training = True
self.early_stop = True
self.early_stop_batch = batch
self.early_stop_epoch = self.epoch_count + 1
# Log early stop to neptune
if self.neptune_run:
self.neptune_run["early_stop/early_stop_epoch"] = self.epoch_count
self.neptune_run["early_stop/early_stop_batch"] = batch
self.neptune_run["early_stop/method"] = self.hp.early_stop_method
self.neptune_run["early_stop/stopped_early"] = self.early_stop
self.neptune_run["sys/tags"].add("early_stopped")
else:
self.ema_two_checks_prior = self.ema_one_check_prior
self.ema_one_check_prior = self.last_ema
# Update global step (for tracking metrics across epochs)
self.global_step += 1
def on_train_end(self, logs={}) -> None:
if sf.getLoggingLevel() <= 20:
print('\r\033[K')
if self.neptune_run:
self.neptune_run['sys/tags'].add('training_complete')
def evaluate_model(self, logs={}) -> None:
log.debug("Evaluating model from evaluation callback")
epoch = self.epoch_count
metrics, acc, loss = self._metrics_from_dataset(f'val_epoch{epoch}')
# Note that Keras loss during training includes regularization losses,
# so this loss will not match validation loss calculated during training
val_metrics = {'accuracy': acc, 'loss': loss}
log.info('Validation metrics: ' + json.dumps(val_metrics, indent=4))
self.results['epochs'][f'epoch{epoch}'] = {
'train_metrics': {k: v for k, v in logs.items() if k[:3] != 'val'},
'val_metrics': val_metrics
}
if self.early_stop:
self.results['epochs'][f'epoch{epoch}'].update({
'early_stop_epoch': self.early_stop_epoch,
'early_stop_batch': self.early_stop_batch,
})
for m in metrics:
if metrics[m]['tile'] is None:
continue
self.results['epochs'][f'epoch{epoch}'][f'tile_{m}'] = metrics[m]['tile']
self.results['epochs'][f'epoch{epoch}'][f'slide_{m}'] = metrics[m]['slide']
self.results['epochs'][f'epoch{epoch}'][f'patient_{m}'] = metrics[m]['patient']
epoch_results = self.results['epochs'][f'epoch{epoch}']
self._log_epoch_evaluation(
epoch_results, metrics=metrics, accuracy=acc, loss=loss, logs=logs
)
class Trainer:
"""Base trainer class containing functionality for model building, input
processing, training, and evaluation.
This base class requires categorical outcome(s). Additional outcome types
are supported by :class:`slideflow.model.RegressionTrainer` and
:class:`slideflow.model.SurvivalTrainer`.
Slide-level (e.g. clinical) features can be used as additional model input
by providing slide labels in the slide annotations dictionary, under the
key 'input'.
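
    Example:
        A minimal sketch of direct use (most users will train via
        ``slideflow.Project.train()``); the paths, ``labels`` dict, and
        hyperparameter values here are illustrative assumptions::

            import slideflow as sf

            hp = sf.ModelParams(tile_px=299, tile_um=302, model='xception',
                                batch_size=32, epochs=[3])
            trainer = sf.model.Trainer(
                hp=hp,
                outdir='/path/to/output',
                labels=labels  # dict mapping slide names to outcome labels
            )
            results = trainer.train(train_dts, val_dts)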
"""
_model_type = 'classification'
def __init__(
self,
hp: ModelParams,
outdir: str,
labels: Dict[str, Any],
*,
slide_input: Optional[Dict[str, Any]] = None,
name: str = 'Trainer',
feature_sizes: Optional[List[int]] = None,
feature_names: Optional[List[str]] = None,
outcome_names: Optional[List[str]] = None,
mixed_precision: bool = True,
allow_tf32: bool = False,
        config: Optional[Dict[str, Any]] = None,
use_neptune: bool = False,
neptune_api: Optional[str] = None,
neptune_workspace: Optional[str] = None,
load_method: str = 'weights',
custom_objects: Optional[Dict[str, Any]] = None,
transform: Optional[Union[Callable, Dict[str, Callable]]] = None,
) -> None:
"""Sets base configuration, preparing model inputs and outputs.
Args:
hp (:class:`slideflow.ModelParams`): ModelParams object.
outdir (str): Path for event logs and checkpoints.
labels (dict): Dict mapping slide names to outcome labels (int or
float format).
slide_input (dict): Dict mapping slide names to additional
slide-level input, concatenated after post-conv.
name (str, optional): Optional name describing the model, used for
model saving. Defaults to 'Trainer'.
feature_sizes (list, optional): List of sizes of input features.
Required if providing additional input features as input to
the model.
feature_names (list, optional): List of names for input features.
Used when permuting feature importance.
outcome_names (list, optional): Name of each outcome. Defaults to
"Outcome {X}" for each outcome.
mixed_precision (bool, optional): Use FP16 mixed precision (rather
than FP32). Defaults to True.
allow_tf32 (bool): Allow internal use of Tensorfloat-32 format.
Defaults to False.
load_method (str): Either 'full' or 'weights'. Method to use
when loading a Tensorflow model. If 'full', loads the model with
``tf.keras.models.load_model()``. If 'weights', will read the
``params.json`` configuration file, build the model architecture,
and then load weights from the given model with
``Model.load_weights()``. Loading with 'full' may improve
compatibility across Slideflow versions. Loading with 'weights'
may improve compatibility across hardware & environments.
config (dict, optional): Training configuration dictionary, used
for logging and image format verification. Defaults to None.
use_neptune (bool, optional): Use Neptune API logging.
Defaults to False
neptune_api (str, optional): Neptune API token, used for logging.
Defaults to None.
neptune_workspace (str, optional): Neptune workspace.
Defaults to None.
custom_objects (dict, Optional): Dictionary mapping names
(strings) to custom classes or functions. Defaults to None.
transform (callable or dict, optional): Optional transform to
apply to input images. If dict, must have the keys 'train'
and/or 'val', mapping to callables that takes a single
image Tensor as input and returns a single image Tensor.
If None, no transform is applied. If a single callable is
provided, it will be applied to both training and validation
data. If a dict is provided, the 'train' transform will be
applied to training data and the 'val' transform will be
applied to validation data. If a dict is provided and either
'train' or 'val' is None, no transform will be applied to
that data. Defaults to None.
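
        Example:
            A sketch of a custom training-only transform (the augmentation
            chosen here is arbitrary)::

                import tensorflow as tf

                transform = {
                    'train': tf.image.random_flip_left_right,
                    'val': None
                }
                trainer = Trainer(hp, outdir, labels, transform=transform)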
"""
if load_method not in ('full', 'weights'):
raise ValueError("Unrecognized value for load_method, must be "
"either 'full' or 'weights'.")
self.outdir = outdir
self.tile_px = hp.tile_px
self.labels = labels
self.hp = hp
self.slides = list(labels.keys())
self.slide_input = slide_input
self.feature_names = feature_names
self.feature_sizes = feature_sizes
self.num_slide_features = 0 if not feature_sizes else sum(feature_sizes)
self.mixed_precision = mixed_precision
self._allow_tf32 = allow_tf32
self.name = name
self.neptune_run = None
self.annotations_tables = []
self.eval_callback = _PredictionAndEvaluationCallback # type: tf.keras.callbacks.Callback
self.load_method = load_method
self.custom_objects = custom_objects
self.patients = dict()
if not os.path.exists(outdir):
os.makedirs(outdir)
# Format outcome labels (ensures compatibility with single
# and multi-outcome models)
outcome_labels = np.array(list(labels.values()))
if len(outcome_labels.shape) == 1:
outcome_labels = np.expand_dims(outcome_labels, axis=1)
if not outcome_names:
outcome_names = [
f'Outcome {i}'
for i in range(outcome_labels.shape[1])
]
outcome_names = sf.util.as_list(outcome_names)
if labels and (len(outcome_names) != outcome_labels.shape[1]):
num_names = len(outcome_names)
num_outcomes = outcome_labels.shape[1]
raise errors.ModelError(f'Size of outcome_names ({num_names}) != '
f'number of outcomes {num_outcomes}')
self.outcome_names = outcome_names
self._setup_inputs()
if labels:
self.num_classes = self.hp._detect_classes_from_labels(labels)
with tf.device('/cpu'):
for oi in range(outcome_labels.shape[1]):
self.annotations_tables += [tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(
self.slides,
outcome_labels[:, oi]
), -1
)]
else:
self.num_classes = None # type: ignore
# Normalization setup
self.normalizer = self.hp.get_normalizer()
if self.normalizer:
log.info(f'Using realtime {self.hp.normalizer} normalization')
# Mixed precision and Tensorfloat-32
if self.mixed_precision:
_policy = 'mixed_float16'
log.debug(f'Enabling mixed precision ({_policy})')
if version.parse(tf.__version__) > version.parse("2.8"):
tf.keras.mixed_precision.set_global_policy(_policy)
else:
policy = tf.keras.mixed_precision.experimental.Policy(_policy)
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.config.experimental.enable_tensor_float_32_execution(allow_tf32)
# Custom transforms
self._process_transforms(transform)
# Log parameters
if config is None:
config = {
'slideflow_version': sf.__version__,
'backend': sf.backend(),
'git_commit': sf.__gitcommit__,
'model_name': self.name,
'full_model_name': self.name,
'outcomes': self.outcome_names,
'model_type': self.hp.model_type(),
'img_format': None,
'tile_px': self.hp.tile_px,
'tile_um': self.hp.tile_um,
'input_features': None,
'input_feature_sizes': None,
'input_feature_labels': None,
'hp': self.hp.to_dict()
}
sf.util.write_json(config, join(self.outdir, 'params.json'))
self.config = config
self.img_format = config['img_format'] if 'img_format' in config else None
# Initialize Neptune
self.use_neptune = use_neptune
if self.use_neptune:
if neptune_api is None or neptune_workspace is None:
raise ValueError("If using Neptune, must supply values "
"neptune_api and neptune_workspace.")
self.neptune_logger = sf.util.neptune_utils.NeptuneLog(
neptune_api,
neptune_workspace
)
def _process_transforms(
self,
transform: Optional[Union[Callable, Dict[str, Callable]]] = None
) -> None:
"""Process custom transformations for training and/or validation."""
if not isinstance(transform, dict):
transform = {'train': transform, 'val': transform}
if any([t not in ('train', 'val') for t in transform]):
raise ValueError("transform must be a callable or dict with keys "
"'train' and/or 'val'")
if 'train' not in transform:
transform['train'] = None
if 'val' not in transform:
transform['val'] = None
self.transform = transform
def _setup_inputs(self) -> None:
"""Setup slide-level input."""
if self.num_slide_features:
assert self.slide_input is not None
try:
                if self.num_slide_features:
                    log.info(f'Training with both images and '
                             f'{self.num_slide_features} slide-level input '
                             'features')
except KeyError:
raise errors.ModelError("Unable to find slide-level input at "
"'input' key in annotations")
for slide in self.slides:
if len(self.slide_input[slide]) != self.num_slide_features:
num_in_feature_table = len(self.slide_input[slide])
raise errors.ModelError(
f'Length of input for slide {slide} does not match '
f'feature_sizes; expected {self.num_slide_features}, '
f'got {num_in_feature_table}'
)
def _compile_model(self) -> None:
"""Compile keras model."""
self.model.compile(
optimizer=self.hp.get_opt(),
loss=self.hp.get_loss(),
metrics=['accuracy']
)
def _fit_normalizer(self, norm_fit: Optional[NormFit]) -> None:
"""Fit the Trainer normalizer using the specified fit, if applicable.
Args:
norm_fit (Optional[Dict[str, np.ndarray]]): Normalizer fit.
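
        Example:
            A sketch of manually applying a Reinhard-style fit; the
            ``target_means`` / ``target_stds`` parameter names follow the
            convention referenced elsewhere in this module, and the values
            are illustrative only::

                trainer._fit_normalizer({
                    'target_means': np.array([65.2, 28.6, -5.3]),
                    'target_stds': np.array([14.9, 7.1, 5.4])
                })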
"""
if norm_fit is not None and not self.normalizer:
raise ValueError("norm_fit supplied, but model params do not"
"specify a normalizer.")
if self.normalizer and norm_fit is not None:
self.normalizer.set_fit(**norm_fit) # type: ignore
elif (self.normalizer
and 'norm_fit' in self.config
and self.config['norm_fit'] is not None):
log.debug("Detecting normalizer fit from model config")
self.normalizer.set_fit(**self.config['norm_fit'])
def _parse_tfrecord_labels(
self,
image: tf.Tensor,
slide: tf.Tensor
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses raw entry read from TFRecord."""
image_dict = {'tile_image': image}
if self.num_classes is None:
label = None
elif len(self.num_classes) > 1: # type: ignore
label = {
f'out-{oi}': self.annotations_tables[oi].lookup(slide)
for oi in range(len(self.num_classes)) # type: ignore
}
else:
label = self.annotations_tables[0].lookup(slide)
# Add additional non-image feature inputs if indicated,
# excluding the event feature used for survival models
if self.num_slide_features:
def slide_lookup(s):
return self.slide_input[s.numpy().decode('utf-8')]
num_features = self.num_slide_features
slide_feature_input_val = tf.py_function(
func=slide_lookup,
inp=[slide],
Tout=[tf.float32] * num_features
)
image_dict.update({'slide_feature_input': slide_feature_input_val})
return image_dict, label
def _retrain_top_layers(
self,
train_data: tf.data.Dataset,
steps_per_epoch: int,
callbacks: tf.keras.callbacks.Callback = None,
epochs: int = 1
) -> Dict:
"""Retrain only the top layer, leaving all other layers frozen."""
log.info('Retraining top layer')
# Freeze the base layer
self.model.layers[0].trainable = False
#val_steps = 200 if validation_data else None
self._compile_model()
toplayer_model = self.model.fit(
train_data,
epochs=epochs,
verbose=(sf.getLoggingLevel() <= 20),
steps_per_epoch=steps_per_epoch,
callbacks=callbacks
)
# Unfreeze the base layer
self.model.layers[0].trainable = True
return toplayer_model.history
def _detect_patients(self, *args):
self.patients = dict()
for dataset in args:
if dataset is None:
continue
dataset_patients = dataset.patients()
if not dataset_patients:
self.patients.update({s: s for s in self.slides})
else:
self.patients.update(dataset_patients)
def _interleave_kwargs(self, **kwargs) -> Dict[str, Any]:
args = SimpleNamespace(
labels=self._parse_tfrecord_labels,
normalizer=self.normalizer,
**kwargs
)
return vars(args)
def _interleave_kwargs_val(self, **kwargs) -> Dict[str, Any]:
return self._interleave_kwargs(**kwargs)
def _metric_kwargs(self, **kwargs) -> Dict[str, Any]:
args = SimpleNamespace(
model=self.model,
model_type=self._model_type,
patients=self.patients,
outcome_names=self.outcome_names,
data_dir=self.outdir,
neptune_run=self.neptune_run,
**kwargs
)
return vars(args)
    def _verify_img_format(self, dataset, *datasets: Optional["sf.Dataset"]) -> Optional[str]:
"""Verify that the image format of the dataset matches the model config.
Args:
dataset (sf.Dataset): Dataset to check.
*datasets (sf.Dataset): Additional datasets to check. May be None.
Returns:
str: Image format, either 'png' or 'jpg', if a consistent image
format was found, otherwise None.
"""
# First, verify all datasets have the same image format
img_formats = set([d.img_format for d in datasets if d])
if len(img_formats) > 1:
log.error("Multiple image formats detected: {}.".format(
', '.join(img_formats)
))
return None
elif self.img_format and not dataset.img_format:
log.warning("Unable to verify image format (PNG/JPG) of dataset.")
return None
elif self.img_format and dataset.img_format != self.img_format:
log.error(
"Mismatched image formats. Expected '{}' per model config, "
"but dataset has format '{}'.".format(
self.img_format,
dataset.img_format))
return None
else:
return dataset.img_format
def load(self, model: str, **kwargs) -> tf.keras.Model:
self.model = load(
model,
method=self.load_method,
custom_objects=self.custom_objects,
**kwargs
)
def predict(
self,
dataset: "sf.Dataset",
batch_size: Optional[int] = None,
norm_fit: Optional[NormFit] = None,
format: str = 'parquet',
from_wsi: bool = False,
roi_method: str = 'auto',
reduce_method: Union[str, Callable] = 'average',
) -> Dict[str, "pd.DataFrame"]:
"""Perform inference on a model, saving tile-level predictions.
Args:
dataset (:class:`slideflow.dataset.Dataset`): Dataset containing
TFRecords to evaluate.
batch_size (int, optional): Evaluation batch size. Defaults to the
same as training (per self.hp)
norm_fit (Dict[str, np.ndarray]): Normalizer fit, mapping fit
parameters (e.g. target_means, target_stds) to values
(np.ndarray). If not provided, will fit normalizer using
model params (if applicable). Defaults to None.
format (str, optional): Format in which to save predictions. Either
'csv', 'feather', or 'parquet'. Defaults to 'parquet'.
from_wsi (bool): Generate predictions from tiles dynamically
extracted from whole-slide images, rather than TFRecords.
Defaults to False (use TFRecords).
roi_method (str): ROI method to use if from_wsi=True (ignored if
from_wsi=False). Either 'inside', 'outside', 'auto', 'ignore'.
If 'inside' or 'outside', will extract tiles in/out of an ROI,
and raise errors.MissingROIError if an ROI is not available.
If 'auto', will extract tiles inside an ROI if available,
and across the whole-slide if no ROI is found.
If 'ignore', will extract tiles across the whole-slide
regardless of whether an ROI is available.
Defaults to 'auto'.
reduce_method (str, optional): Reduction method for calculating
slide-level and patient-level predictions for categorical
outcomes. Options include 'average', 'mean', 'proportion',
'median', 'sum', 'min', 'max', or a callable function.
'average' and 'mean' are synonymous, with both options kept
for backwards compatibility. If 'average' or 'mean', will
reduce with average of each logit across tiles. If
'proportion', will convert tile predictions into onehot encoding
then reduce by averaging these onehot values. For all other
values, will reduce with the specified function, applied via
the pandas ``DataFrame.agg()`` function. Defaults to 'average'.
Returns:
Dict[str, pd.DataFrame]: Dictionary with keys 'tile', 'slide', and
'patient', and values containing DataFrames with tile-, slide-,
and patient-level predictions.
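
        Example:
            A minimal sketch; the project path and dataset parameters are
            illustrative assumptions::

                P = sf.Project('/path/to/project')
                dataset = P.dataset(tile_px=299, tile_um=302)
                dfs = trainer.predict(dataset, format='parquet')
                tile_predictions = dfs['tile']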
"""
if format not in ('csv', 'feather', 'parquet'):
raise ValueError(f"Unrecognized format {format}")
self._detect_patients(dataset)
# Verify image format
self._verify_img_format(dataset)
# Fit normalizer
self._fit_normalizer(norm_fit)
# Load and initialize model
if not self.model:
raise errors.ModelNotLoadedError
log_manifest(
None,
dataset.tfrecords(),
labels=self.labels,
filename=join(self.outdir, 'slide_manifest.csv')
)
if not batch_size:
batch_size = self.hp.batch_size
with tf.name_scope('input'):
interleave_kwargs = self._interleave_kwargs_val(
batch_size=batch_size,
infinite=False,
transform=self.transform['val'],
augment=False
)
tf_dts_w_slidenames = dataset.tensorflow(
incl_loc=True,
incl_slidenames=True,
from_wsi=from_wsi,
roi_method=roi_method,
**interleave_kwargs
)
# Generate predictions
log.info('Generating predictions...')
dfs = sf.stats.predict_dataset(
model=self.model,
dataset=tf_dts_w_slidenames,
model_type=self._model_type,
uq=bool(self.hp.uq),
num_tiles=dataset.num_tiles,
outcome_names=self.outcome_names,
patients=self.patients,
reduce_method=reduce_method,
)
# Save predictions
sf.stats.metrics.save_dfs(dfs, format=format, outdir=self.outdir)
return dfs
def evaluate(
self,
dataset: "sf.Dataset",
batch_size: Optional[int] = None,
save_predictions: Union[bool, str] = 'parquet',
reduce_method: Union[str, Callable] = 'average',
norm_fit: Optional[NormFit] = None,
uq: Union[bool, str] = 'auto',
from_wsi: bool = False,
roi_method: str = 'auto',
) -> Dict[str, Any]:
"""Evaluate model, saving metrics and predictions.
Args:
dataset (:class:`slideflow.dataset.Dataset`): Dataset containing
TFRecords to evaluate.
batch_size (int, optional): Evaluation batch size. Defaults to the
same as training (per self.hp)
save_predictions (bool or str, optional): Save tile, slide, and
patient-level predictions at each evaluation. May be 'csv',
'feather', or 'parquet'. If False, will not save predictions.
Defaults to 'parquet'.
reduce_method (str, optional): Reduction method for calculating
slide-level and patient-level predictions for categorical
outcomes. Options include 'average', 'mean', 'proportion',
'median', 'sum', 'min', 'max', or a callable function.
'average' and 'mean' are synonymous, with both options kept
for backwards compatibility. If 'average' or 'mean', will
reduce with average of each logit across tiles. If
'proportion', will convert tile predictions into onehot encoding
then reduce by averaging these onehot values. For all other
values, will reduce with the specified function, applied via
the pandas ``DataFrame.agg()`` function. Defaults to 'average'.
norm_fit (Dict[str, np.ndarray]): Normalizer fit, mapping fit
parameters (e.g. target_means, target_stds) to values
(np.ndarray). If not provided, will fit normalizer using
model params (if applicable). Defaults to None.
uq (bool or str, optional): Enable UQ estimation (for
applicable models). Defaults to 'auto'.
Returns:
Dictionary of evaluation metrics.
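
        Example:
            A minimal sketch, assuming a trained model has already been
            loaded with ``trainer.load()`` and ``dataset`` is a
            :class:`slideflow.dataset.Dataset`::

                results = trainer.evaluate(dataset, batch_size=64)
                print(results['eval']['accuracy'], results['eval']['loss'])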
"""
if uq != 'auto':
if not isinstance(uq, bool):
raise ValueError(f"Unrecognized value {uq} for uq")
self.hp.uq = uq
self._detect_patients(dataset)
# Verify image format
self._verify_img_format(dataset)
# Perform evaluation
_unit_type = 'slides' if from_wsi else 'tfrecords'
log.info(f'Evaluating {len(dataset.tfrecords())} {_unit_type}')
# Fit normalizer
self._fit_normalizer(norm_fit)
# Load and initialize model
if not self.model:
raise errors.ModelNotLoadedError
log_manifest(
None,
dataset.tfrecords(),
labels=self.labels,
filename=join(self.outdir, 'slide_manifest.csv')
)
# Neptune logging
if self.use_neptune:
assert self.neptune_run is not None
self.neptune_run = self.neptune_logger.start_run(
self.name,
self.config['project'],
dataset,
tags=['eval']
)
self.neptune_logger.log_config(self.config, 'eval')
self.neptune_run['data/slide_manifest'].upload(
join(self.outdir, 'slide_manifest.csv')
)
if not batch_size:
batch_size = self.hp.batch_size
with tf.name_scope('input'):
interleave_kwargs = self._interleave_kwargs_val(
batch_size=batch_size,
infinite=False,
transform=self.transform['val'],
augment=False
)
tf_dts_w_slidenames = dataset.tensorflow(
incl_slidenames=True,
incl_loc=True,
from_wsi=from_wsi,
roi_method=roi_method,
**interleave_kwargs
)
# Generate performance metrics
log.info('Calculating performance metrics...')
metric_kwargs = self._metric_kwargs(
dataset=tf_dts_w_slidenames,
num_tiles=dataset.num_tiles,
label='eval'
)
metrics, acc, loss = sf.stats.metrics_from_dataset(
save_predictions=save_predictions,
reduce_method=reduce_method,
loss=self.hp.get_loss(),
uq=bool(self.hp.uq),
**metric_kwargs
)
results = {'eval': {}} # type: Dict[str, Dict[str, float]]
for metric in metrics:
if metrics[metric]:
log.info(f"Tile {metric}: {metrics[metric]['tile']}")
log.info(f"Slide {metric}: {metrics[metric]['slide']}")
log.info(f"Patient {metric}: {metrics[metric]['patient']}")
results['eval'].update({
f'tile_{metric}': metrics[metric]['tile'],
f'slide_{metric}': metrics[metric]['slide'],
f'patient_{metric}': metrics[metric]['patient']
})
# Note that Keras loss during training includes regularization losses,
# so this loss will not match validation loss calculated during training
val_metrics = {'accuracy': acc, 'loss': loss}
results_log = os.path.join(self.outdir, 'results_log.csv')
log.info('Evaluation metrics:')
for m in val_metrics:
log.info(f'{m}: {val_metrics[m]}')
results['eval'].update(val_metrics)
sf.util.update_results_log(results_log, 'eval_model', results)
# Update neptune log
if self.neptune_run:
self.neptune_run['eval/results'] = val_metrics
self.neptune_run.stop()
return results
def train(
self,
train_dts: "sf.Dataset",
val_dts: Optional["sf.Dataset"],
log_frequency: int = 100,
validate_on_batch: int = 0,
        validation_batch_size: Optional[int] = None,
validation_steps: int = 200,
starting_epoch: int = 0,
ema_observations: int = 20,
ema_smoothing: int = 2,
use_tensorboard: bool = True,
steps_per_epoch_override: int = 0,
save_predictions: Union[bool, str] = 'parquet',
save_model: bool = True,
resume_training: Optional[str] = None,
pretrain: Optional[str] = 'imagenet',
checkpoint: Optional[str] = None,
save_checkpoints: bool = True,
multi_gpu: bool = False,
reduce_method: Union[str, Callable] = 'average',
norm_fit: Optional[NormFit] = None,
from_wsi: bool = False,
roi_method: str = 'auto',
) -> Dict[str, Any]:
"""Builds and trains a model from hyperparameters.
Args:
train_dts (:class:`slideflow.Dataset`): Training dataset. Will call
the `.tensorflow()` method to retrieve the tf.data.Dataset
used for model fitting.
val_dts (:class:`slideflow.Dataset`): Validation dataset. Will call
the `.tensorflow()` method to retrieve the tf.data.Dataset
used for model fitting.
log_frequency (int, optional): How frequent to update Tensorboard
logs, in batches. Defaults to 100.
validate_on_batch (int, optional): Validation will also be performed
every N batches. Defaults to 0.
validation_batch_size (int, optional): Validation batch size.
Defaults to same as training (per self.hp).
validation_steps (int, optional): Number of batches to use for each
instance of validation. Defaults to 200.
starting_epoch (int, optional): Starts training at the specified
epoch. Defaults to 0.
ema_observations (int, optional): Number of observations over which
to perform exponential moving average smoothing. Defaults to 20.
ema_smoothing (int, optional): Exponential average smoothing value.
Defaults to 2.
            use_tensorboard (bool, optional): Enable Tensorboard callbacks.
                Defaults to True.
steps_per_epoch_override (int, optional): Manually set the number
of steps per epoch. Defaults to 0 (automatic).
save_predictions (bool or str, optional): Save tile, slide, and
patient-level predictions at each evaluation. May be 'csv',
'feather', or 'parquet'. If False, will not save predictions.
Defaults to 'parquet'.
save_model (bool, optional): Save models when evaluating at
specified epochs. Defaults to True.
resume_training (str, optional): Path to model to continue training.
Only valid in Tensorflow backend. Defaults to None.
pretrain (str, optional): Either 'imagenet' or path to Tensorflow
model from which to load weights. Defaults to 'imagenet'.
checkpoint (str, optional): Path to cp.ckpt from which to load
weights. Defaults to None.
            save_checkpoints (bool, optional): Save checkpoints at each epoch.
                Defaults to True.
multi_gpu (bool, optional): Enable multi-GPU training using
Tensorflow/Keras MirroredStrategy.
reduce_method (str or callable, optional): Reduction method for calculating
slide-level and patient-level predictions for categorical
outcomes. Options include 'average', 'mean', 'proportion',
'median', 'sum', 'min', 'max', or a callable function.
'average' and 'mean' are synonymous, with both options kept
for backwards compatibility. If 'average' or 'mean', will
reduce with average of each logit across tiles. If
'proportion', will convert tile predictions into onehot encoding
then reduce by averaging these onehot values. For all other
values, will reduce with the specified function, applied via
the pandas ``DataFrame.agg()`` function. Defaults to 'average'.
norm_fit (Dict[str, np.ndarray], optional): Normalizer fit, mapping fit
parameters (e.g. target_means, target_stds) to values
(np.ndarray). If not provided, will fit normalizer using
model params (if applicable). Defaults to None.
from_wsi (bool, optional): Generate training and validation data
dynamically from whole-slide images, rather than TFRecords.
Defaults to False.
roi_method (str, optional): ROI method to use when ``from_wsi=True``
(ignored otherwise). Defaults to 'auto'.
Returns:
dict: Nested results dict with metrics for each evaluated epoch.
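Examples
*Training a model (a minimal sketch; assumes ``trainer`` is an
already-built :class:`Trainer` and that ``train_dts`` and ``val_dts``
are :class:`slideflow.Dataset` objects):*
.. code-block:: python
results = trainer.train(
train_dts,
val_dts,
validate_on_batch=512,  # also validate every 512 batches
save_predictions='parquet'  # save tile/slide/patient predictions
)
print(results['epochs'])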
"""
if self.hp.model_type() != self._model_type:
hp_model = self.hp.model_type()
raise errors.ModelError(f"Incompatible models: {hp_model} (hp) and "
f"{self._model_type} (model)")
self._detect_patients(train_dts, val_dts)
# Verify image format across datasets.
img_format = self._verify_img_format(train_dts, val_dts)
if img_format and self.config['img_format'] is None:
self.config['img_format'] = img_format
sf.util.write_json(self.config, join(self.outdir, 'params.json'))
# Clear prior Tensorflow graph to free memory
tf.keras.backend.clear_session()
results_log = os.path.join(self.outdir, 'results_log.csv')
# Fit the normalizer to the training data and log the source mean/stddev
if self.normalizer and self.hp.normalizer_source == 'dataset':
self.normalizer.fit(train_dts)
else:
self._fit_normalizer(norm_fit)
if self.normalizer:
config_path = join(self.outdir, 'params.json')
if not exists(config_path):
config = {
'slideflow_version': sf.__version__,
'hp': self.hp.to_dict(),
'backend': sf.backend()
}
else:
config = sf.util.load_json(config_path)
config['norm_fit'] = self.normalizer.get_fit(as_list=True)
sf.util.write_json(config, config_path)
# Prepare multiprocessing pool if from_wsi=True
if from_wsi:
pool = mp.Pool(
sf.util.num_cpu(default=8),
initializer=sf.util.set_ignore_sigint
)
else:
pool = None
# Save training / validation manifest
if val_dts is None:
val_paths = None
elif from_wsi:
val_paths = val_dts.slide_paths()
else:
val_paths = val_dts.tfrecords()
log_manifest(
train_dts.tfrecords(),
val_paths,
labels=self.labels,
filename=join(self.outdir, 'slide_manifest.csv')
)
# Neptune logging
if self.use_neptune:
tags = ['train']
if 'k-fold' in self.config['validation_strategy']:
tags += [f'k-fold{self.config["k_fold_i"]}']
self.neptune_run = self.neptune_logger.start_run(
self.name,
self.config['project'],
train_dts,
tags=tags
)
self.neptune_logger.log_config(self.config, 'train')
self.neptune_run['data/slide_manifest'].upload( # type: ignore
os.path.join(self.outdir, 'slide_manifest.csv')
)
# Set up multi-GPU strategy
if multi_gpu:
strategy = tf.distribute.MirroredStrategy()
log.info('Multi-GPU training with '
f'{strategy.num_replicas_in_sync} devices')
# Fixes "OSError: [Errno 9] Bad file descriptor" after training
atexit.register(strategy._extended._collective_ops._pool.close)
else:
strategy = None
with strategy.scope() if strategy else no_scope():
# Build model from ModelParams
if resume_training:
self.model = load(resume_training, method='weights', training=True)
else:
model = self.hp.build_model(
labels=self.labels,
num_slide_features=self.num_slide_features,
pretrain=pretrain,
checkpoint=checkpoint,
load_method=self.load_method
)
self.model = model
tf_utils.log_summary(model, self.neptune_run)
with tf.name_scope('input'):
t_kwargs = self._interleave_kwargs(
batch_size=self.hp.batch_size,
infinite=True,
augment=self.hp.augment,
transform=self.transform['train'],
from_wsi=from_wsi,
pool=pool,
roi_method=roi_method
)
train_data = train_dts.tensorflow(drop_last=True, **t_kwargs)
log.debug(f"Training: {train_dts.num_tiles} total tiles.")
# Set up validation data
using_validation = (val_dts
and (len(val_dts.tfrecords()) if not from_wsi
else len(val_dts.slide_paths())))
if using_validation:
assert val_dts is not None
with tf.name_scope('input'):
if not validation_batch_size:
validation_batch_size = self.hp.batch_size
v_kwargs = self._interleave_kwargs_val(
batch_size=validation_batch_size,
infinite=False,
augment=False,
transform=self.transform['val'],
from_wsi=from_wsi,
pool=pool,
roi_method=roi_method
)
validation_data = val_dts.tensorflow(
incl_slidenames=True,
incl_loc=True,
drop_last=True,
**v_kwargs
)
log.debug(f"Validation: {val_dts.num_tiles} total tiles.")
if validate_on_batch:
log.debug('Validation during training: every '
f'{validate_on_batch} steps and at epoch end')
mid_v_kwargs = v_kwargs.copy()
mid_v_kwargs['infinite'] = True
mid_train_validation_data = iter(val_dts.tensorflow(
incl_slidenames=True,
incl_loc=True,
drop_last=True,
**mid_v_kwargs
))
else:
log.debug('Validation during training: at epoch end')
mid_train_validation_data = None
if validation_steps:
num_samples = validation_steps * self.hp.batch_size
log.debug(f'Using {validation_steps} batches ({num_samples}'
' samples) each validation check')
else:
log.debug('Using entire validation set each val check')
else:
log.debug('Validation during training: None')
validation_data = None
mid_train_validation_data = None
validation_steps = 0
# Calculate parameters
if from_wsi:
train_tiles = train_data.est_num_tiles
val_tiles = (validation_data.est_num_tiles
if validation_data is not None else 0)
else:
train_tiles = train_dts.num_tiles
val_tiles = 0 if val_dts is None else val_dts.num_tiles
if max(self.hp.epochs) <= starting_epoch:
max_epoch = max(self.hp.epochs)
log.error(f'Starting epoch ({starting_epoch}) must be less than'
f' the max target epoch ({max_epoch})')
if (self.hp.early_stop and self.hp.early_stop_method == 'accuracy'
and self._model_type != 'classification'):
log.error("Unable to use 'accuracy' early stopping with model "
f"type '{self.hp.model_type()}'")
if starting_epoch != 0:
log.info(f'Starting training at epoch {starting_epoch}')
if steps_per_epoch_override:
steps_per_epoch = steps_per_epoch_override
else:
steps_per_epoch = round(train_tiles / self.hp.batch_size)
cb_args = SimpleNamespace(
starting_epoch=starting_epoch,
using_validation=using_validation,
validate_on_batch=validate_on_batch,
validation_steps=validation_steps,
ema_observations=ema_observations,
ema_smoothing=ema_smoothing,
steps_per_epoch=steps_per_epoch,
validation_data=validation_data,
mid_train_validation_data=mid_train_validation_data,
num_val_tiles=val_tiles,
save_predictions=save_predictions,
save_model=save_model,
results_log=results_log,
reduce_method=reduce_method,
log_frequency=log_frequency
)
# Create callbacks for early stopping, checkpoint saving,
# summaries, and history
val_callback = self.eval_callback(self, cb_args)
callbacks = [tf.keras.callbacks.History(), val_callback]
if save_checkpoints:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(self.outdir, 'cp.ckpt'),
save_weights_only=True,
verbose=(sf.getLoggingLevel() <= 20)
)
callbacks += [cp_callback]
if use_tensorboard:
log.debug(
"Logging with Tensorboard to {} every {} batches.".format(
self.outdir, log_frequency
))
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=self.outdir,
histogram_freq=0,
write_graph=False,
update_freq='batch'
)
callbacks += [tensorboard_callback]
# Retrain top layer only, if using transfer learning and
# not resuming training
total_epochs = (self.hp.toplayer_epochs
+ (max(self.hp.epochs) - starting_epoch))
if self.hp.toplayer_epochs:
self._retrain_top_layers(
train_data,
steps_per_epoch,
callbacks=None,
epochs=self.hp.toplayer_epochs
)
self._compile_model()
# Train the model
log.info('Beginning training')
try:
self.model.fit(
train_data,
steps_per_epoch=steps_per_epoch,
epochs=total_epochs,
verbose=(sf.getLoggingLevel() <= 20),
initial_epoch=self.hp.toplayer_epochs,
callbacks=callbacks
)
except tf.errors.ResourceExhaustedError as e:
log.error(f"Training failed for [bold]{self.name}[/]. "
f"Error: \n {e}")
results = val_callback.results
if self.use_neptune and self.neptune_run is not None:
self.neptune_run['results'] = results['epochs']
self.neptune_run.stop()
# Cleanup
if pool is not None:
pool.close()
del mid_train_validation_data
return results
class RegressionTrainer(Trainer):
"""Extends the base :class:`slideflow.model.Trainer` class to add support
for regression models with continuous outcomes. Requires that all outcomes be continuous,
with appropriate regression loss function. Uses R-squared as the evaluation
metric, rather than AUROC."""
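# A hedged sketch of typical configuration: a regression trainer is used
# when the hyperparameters specify a continuous loss from RegressionLossDict,
# e.g. hp = sf.ModelParams(tile_px=299, tile_um=302, loss='mean_squared_error'),
# and all outcome labels supplied to the trainer are continuous values.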
_model_type = 'regression'
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def _compile_model(self) -> None:
self.model.compile(optimizer=self.hp.get_opt(),
loss=self.hp.get_loss(),
metrics=[self.hp.get_loss()])
def _parse_tfrecord_labels(
self,
image: Union[Dict[str, tf.Tensor], tf.Tensor],
slide: tf.Tensor
) -> Tuple[Union[Dict[str, tf.Tensor], tf.Tensor], tf.Tensor]:
image_dict = {'tile_image': image}
if self.num_classes is None:
label = None
else:
label = [
self.annotations_tables[oi].lookup(slide)
for oi in range(self.num_classes) # type: ignore
]
# Add additional non-image feature inputs if indicated,
# excluding the event feature used for survival models
if self.num_slide_features:
def slide_lookup(s):
return self.slide_input[s.numpy().decode('utf-8')]
num_features = self.num_slide_features
slide_feature_input_val = tf.py_function(
func=slide_lookup,
inp=[slide],
Tout=[tf.float32] * num_features
)
image_dict.update({'slide_feature_input': slide_feature_input_val})
return image_dict, label
class SurvivalTrainer(RegressionTrainer):
"""Cox Proportional Hazards model. Requires that the user provide event
data as the first input feature, and time to outcome as the continuous outcome.
Uses concordance index as the evaluation metric."""
_model_type = 'survival'
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if not self.num_slide_features:
raise errors.ModelError('Model error - survival models must '
'include event input')
def _setup_inputs(self) -> None:
# Setup slide-level input
try:
num_features = self.num_slide_features - 1
if num_features:
log.info(f'Training with both images and {num_features} '
'categories of slide-level input')
log.info('Interpreting first feature as event for survival model')
else:
log.info('Training with images alone. Interpreting first '
'feature as event for survival model')
except KeyError:
raise errors.ModelError("Unable to find slide-level input at "
"'input' key in annotations")
assert self.slide_input is not None
for slide in self.slides:
if len(self.slide_input[slide]) != self.num_slide_features:
num_in_feature_table = len(self.slide_input[slide])
raise errors.ModelError(
f'Length of input for slide {slide} does not match '
f'feature_sizes; expected {self.num_slide_features}, got '
f'{num_in_feature_table}'
)
def load(self, model: str, **kwargs) -> tf.keras.Model:
if self.load_method == 'full':
custom_objects = {
'negative_log_likelihood': tf_utils.negative_log_likelihood,
'concordance_index': tf_utils.concordance_index
}
self.model = tf.keras.models.load_model(
model,
custom_objects=custom_objects
)
self.model.compile(
loss=tf_utils.negative_log_likelihood,
metrics=tf_utils.concordance_index
)
else:
self.model = load(model, method=self.load_method, **kwargs)
def _compile_model(self) -> None:
self.model.compile(optimizer=self.hp.get_opt(),
loss=tf_utils.negative_log_likelihood,
metrics=tf_utils.concordance_index)
def _parse_tfrecord_labels(
self,
image: Union[Dict[str, tf.Tensor], tf.Tensor],
slide: tf.Tensor
) -> Tuple[Union[Dict[str, tf.Tensor], tf.Tensor], tf.Tensor]:
image_dict = {'tile_image': image}
if self.num_classes is None:
label = None
else:
label = [
self.annotations_tables[oi].lookup(slide)
for oi in range(self.num_classes) # type: ignore
]
# Add additional non-image feature inputs if indicated,
# excluding the event feature used for survival models
if self.num_slide_features:
# Time-to-event data must be added as a separate feature
def slide_lookup(s):
return self.slide_input[s.numpy().decode('utf-8')][1:]
def event_lookup(s):
return self.slide_input[s.numpy().decode('utf-8')][0]
num_features = self.num_slide_features - 1
event_input_val = tf.py_function(
func=event_lookup,
inp=[slide],
Tout=[tf.float32]
)
image_dict.update({'event_input': event_input_val})
slide_feature_input_val = tf.py_function(
func=slide_lookup,
inp=[slide],
Tout=[tf.float32] * num_features
)
# Add slide input features, excluding the event feature
# used for survival models
if self.num_slide_features != 1:
image_dict.update(
{'slide_feature_input': slide_feature_input_val}
)
return image_dict, label
class Features(BaseFeatureExtractor):
"""Interface for obtaining predictions and features from intermediate layer
activations from Slideflow models.
Use by calling on either a batch of images (returning outputs for a single
batch), or by calling on a :class:`slideflow.WSI` object, which will
generate an array of spatially-mapped activations matching the slide.
Examples
*Calling on batch of images:*
.. code-block:: python
interface = Features('/model/path', layers='postconv')
for image_batch in train_data:
# Return shape: (batch_size, num_features)
batch_features = interface(image_batch)
*Calling on a slide:*
.. code-block:: python
slide = sf.WSI(...)
interface = Features('/model/path', layers='postconv')
# Returns shape:
# (slide.grid.shape[0], slide.grid.shape[1], num_features)
activations_grid = interface(slide)
Note:
When this interface is called on a batch of images, no image processing
or stain normalization will be performed, as it is assumed that
normalization will occur during data loader image processing. When the
interface is called on a `slideflow.WSI`, the normalization strategy
will be read from the model configuration file, and normalization will
be performed on image tiles extracted from the WSI. If this interface
was created from an existing model and there is no model configuration
file to read, a :class:`slideflow.norm.StainNormalizer` object may be
passed via the ``wsi_normalizer`` argument of :meth:`Features.from_model`.
"""
def __init__(
self,
path: Optional[str],
layers: Optional[Union[str, List[str]]] = 'postconv',
include_preds: bool = False,
load_method: str = 'weights',
pooling: Optional[Any] = None,
device: Optional[str] = None,
) -> None:
"""Creates a features interface from a saved slideflow model which
outputs feature activations at the designated layers.
Intermediate layers are returned in the order of ``layers``.
Predictions are returned last.
Args:
path (str): Path to saved Slideflow model.
layers (list(str), optional): Layers from which to generate
activations. The post-convolution activation layer is accessed
via 'postconv'. Defaults to 'postconv'.
include_preds (bool, optional): Include predictions in output. Will be
returned last. Defaults to False.
load_method (str): Either 'full' or 'weights'. Method to use
when loading a Tensorflow model. If 'full', loads the model with
``tf.keras.models.load_model()``. If 'weights', will read the
``params.json`` configuration file, build the model architecture,
and then load weights from the given model with
``Model.load_weights()``. Loading with 'full' may improve
compatibility across Slideflow versions. Loading with 'weights'
may improve compatibility across hardware & environments.
pooling (str or Keras layer, optional): Pooling to apply to
three-dimensional (spatial) layer activations before returning
features. May be 'avg', 'max', or a Keras global pooling layer.
Defaults to None.
device (str, optional): Device on which to run inference,
e.g. '/gpu:0'. Defaults to None.
"""
super().__init__('tensorflow', include_preds=include_preds)
if layers and isinstance(layers, str):
layers = [layers]
self.layers = layers
self.path = path
self.device = device
if isinstance(device, str):
self.device = device.replace('cuda', 'gpu')
self._pooling = None
self._include_preds = None
if path is not None:
self._model = load(self.path, method=load_method) # type: ignore
config = sf.util.get_model_config(path)
if 'img_format' in config:
self.img_format = config['img_format']
self.hp = sf.ModelParams()
self.hp.load_dict(config['hp'])
self.wsi_normalizer = self.hp.get_normalizer()
if 'norm_fit' in config and config['norm_fit'] is not None:
if self.wsi_normalizer is None:
log.warning('norm_fit found in model config file, but model '
'params does not use a normalizer. Ignoring.')
else:
self.wsi_normalizer.set_fit(**config['norm_fit'])
self._build(
layers=layers, include_preds=include_preds, pooling=pooling # type: ignore
)
@classmethod
def from_model(
cls,
model: tf.keras.Model,
layers: Optional[Union[str, List[str]]] = 'postconv',
include_preds: bool = False,
wsi_normalizer: Optional["StainNormalizer"] = None,
pooling: Optional[Any] = None,
device: Optional[str] = None
):
"""Creates a features interface from a loaded slideflow model which
outputs feature activations at the designated layers.
Intermediate layers are returned in the order of ``layers``.
Predictions are returned last.
Args:
model (:class:`tensorflow.keras.models.Model`): Loaded model.
layers (list(str), optional): Layers from which to generate
activations. The post-convolution activation layer is accessed
via 'postconv'. Defaults to 'postconv'.
include_preds (bool, optional): Include predictions in output. Will be
returned last. Defaults to False.
wsi_normalizer (:class:`slideflow.norm.StainNormalizer`): Stain
normalizer to use on whole-slide images. Not used on
individual tile datasets via __call__. Defaults to None.
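Examples
*Wrapping an in-memory model (a minimal sketch; assumes ``model`` is a
Slideflow-built ``tf.keras.Model`` and ``image_batch`` is a batch of
processed tile images):*
.. code-block:: python
interface = Features.from_model(model, layers=['postconv'])
# Return shape: (batch_size, num_features)
batch_features = interface(image_batch)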
"""
obj = cls(None, layers, include_preds, device=device)
if isinstance(model, tf.keras.models.Model):
obj._model = model
else:
raise errors.ModelError(f"Model {model} is not a valid Tensorflow "
"model.")
obj._build(
layers=layers, include_preds=include_preds, pooling=pooling # type: ignore
)
obj.wsi_normalizer = wsi_normalizer
return obj
def __repr__(self):
return ("{}(\n".format(self.__class__.__name__) +
" path={!r},\n".format(self.path) +
" layers={!r},\n".format(self.layers) +
" include_preds={!r},\n".format(self._include_preds) +
" pooling={!r},\n".format(self._pooling) +
")")
def __call__(
self,
inp: Union[tf.Tensor, "sf.WSI"],
**kwargs
) -> Optional[Union[np.ndarray, tf.Tensor]]:
"""Process a given input and return features and/or predictions.
Expects either a batch of images or a :class:`slideflow.WSI`.
When calling on a `WSI` object, keyword arguments are passed to
:meth:`slideflow.WSI.build_generator()`.
"""
if isinstance(inp, sf.WSI):
return self._predict_slide(inp, **kwargs)
else:
return self._predict(inp)
def _predict_slide(
self,
slide: "sf.WSI",
*,
img_format: str = 'auto',
batch_size: int = 32,
dtype: type = np.float16,
grid: Optional[np.ndarray] = None,
shuffle: bool = False,
show_progress: bool = True,
callback: Optional[Callable] = None,
normalizer: Optional[Union[str, "sf.norm.StainNormalizer"]] = None,
normalizer_source: Optional[str] = None,
**kwargs
) -> Optional[np.ndarray]:
"""Generate activations from slide => activation grid array."""
# Check image format
if img_format == 'auto' and self.img_format is None:
raise ValueError(
'Unable to auto-detect image format (png or jpg). Set the '
'format explicitly by passing the img_format argument.'
)
elif img_format == 'auto':
assert self.img_format is not None
img_format = self.img_format
return sf.model.extractors.features_from_slide(
self,
slide,
img_format=img_format,
batch_size=batch_size,
dtype=dtype,
grid=grid,
shuffle=shuffle,
show_progress=show_progress,
callback=callback,
normalizer=(normalizer if normalizer else self.wsi_normalizer),
normalizer_source=normalizer_source,
**kwargs
)
@tf.function
def _predict(self, inp: tf.Tensor) -> tf.Tensor:
"""Return activations for a single batch of images."""
with tf.device(self.device) if self.device else no_scope():
return self.model(inp, training=False)
def _build(
self,
layers: Optional[Union[str, List[str]]],
include_preds: bool = True,
pooling: Optional[Any] = None
) -> None:
"""Builds the interface model that outputs feature activations at the
designated layers and/or predictions. Intermediate layers are returned in
the order of ``layers``. Predictions are returned last."""
self._pooling = pooling
self._include_preds = include_preds
if isinstance(pooling, str):
if pooling == 'avg':
pooling = tf.keras.layers.GlobalAveragePooling2D
elif pooling == 'max':
pooling = tf.keras.layers.GlobalMaxPool2D
else:
raise ValueError(f"Unrecognized pooling value {pooling}. "
"Expected 'avg', 'max', or Keras layer.")
if layers and not isinstance(layers, list):
layers = [layers]
if layers:
if 'postconv' in layers:
layers[layers.index('postconv')] = 'post_convolution' # type: ignore
log.debug(f"Setting up interface to return activations from layers "
f"{', '.join(layers)}")
else:
layers = []
def pool_if_3d(tensor):
if pooling is not None and len(tensor.shape) == 4:
return pooling()(tensor)
else:
return tensor
# Find the desired layers
outputs = {}
outer_layer_outputs = {
self._model.layers[i].name: self._model.layers[i].output
for i in range(len(self._model.layers))
}
core_layer_outputs = {}
inner_layers = [la for la in layers if la not in outer_layer_outputs]
if inner_layers:
intermediate_core = tf.keras.models.Model(
inputs=self._model.layers[1].input,
outputs=[
pool_if_3d(self._model.layers[1].get_layer(il).output)
for il in inner_layers
]
)
if len(inner_layers) > 1:
int_out = intermediate_core(self._model.input)
for la, layer in enumerate(inner_layers):
core_layer_outputs[layer] = int_out[la]
else:
outputs[inner_layers[0]] = intermediate_core(self._model.input)
for layer in layers:
if layer in outer_layer_outputs:
outputs[layer] = outer_layer_outputs[layer]
elif layer in core_layer_outputs:
outputs[layer] = core_layer_outputs[layer]
# Build a model that outputs the given layers
outputs_list = [] if not layers else [outputs[la] for la in layers]
if include_preds:
outputs_list += [self._model.output]
self.model = tf.keras.models.Model(
inputs=self._model.input,
outputs=outputs_list
)
self.num_features = sum([outputs[o].shape[1] for o in outputs])
self.num_outputs = len(outputs_list)
if isinstance(self._model.output, list) and include_preds:
log.warning("Multi-categorical outcomes is experimental "
"for this interface.")
self.num_classes = sum(o.shape[1] for o in self._model.output)
elif include_preds:
self.num_classes = self._model.output.shape[1]
else:
self.num_classes = 0
if include_preds:
log.debug(f'Number of classes: {self.num_classes}')
log.debug(f'Number of activation features: {self.num_features}')
def dump_config(self):
return {
'class': 'slideflow.model.tensorflow.Features',
'kwargs': {
'path': self.path,
'layers': self.layers,
'include_preds': self._include_preds,
'pooling': self._pooling
}
}
class UncertaintyInterface(Features):
def __init__(
self,
path: Optional[str],
layers: Optional[Union[str, List[str]]] = 'postconv',
load_method: str = 'weights',
pooling: Optional[Any] = None
) -> None:
super().__init__(
path,
layers=layers,
include_preds=True,
load_method=load_method,
pooling=pooling
)
# TODO: As the below to-do suggests, this should be updated
# for multi-class
self.num_uncertainty = 1
if self.num_classes > 2:
log.warn("UncertaintyInterface not yet implemented for multi-class"
" models")
@classmethod
def from_model( # type: ignore
cls,
model: tf.keras.Model,
layers: Optional[Union[str, List[str]]] = None,
wsi_normalizer: Optional["StainNormalizer"] = None,
pooling: Optional[Any] = None
):
obj = cls(None, layers)
if isinstance(model, tf.keras.models.Model):
obj._model = model
else:
raise errors.ModelError(f"Model {model} is not a valid Tensorflow "
"model.")
obj._build(
layers=layers, include_preds=True, pooling=pooling # type: ignore
)
obj.wsi_normalizer = wsi_normalizer
return obj
def __repr__(self):
return ("{}(\n".format(self.__class__.__name__) +
" path={!r},\n".format(self.path) +
" layers={!r},\n".format(self.layers) +
" pooling={!r},\n".format(self._pooling) +
")")
@tf.function
def _predict(self, inp):
"""Return activations (mean), predictions (mean), and uncertainty
(stdev) for a single batch of images."""
out_drop = [[] for _ in range(self.num_outputs)]
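# Run 30 stochastic forward passes. For UQ-enabled models, dropout layers
# remain active at inference (e.g. via StaticDropout), so the variability
# across passes provides an MC-dropout style estimate of uncertainty.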
for _ in range(30):
yp = self.model(inp, training=False)
for n in range(self.num_outputs):
out_drop[n] += [(yp[n] if self.num_outputs > 1 else yp)]
for n in range(self.num_outputs):
out_drop[n] = tf.stack(out_drop[n], axis=0)
predictions = tf.math.reduce_mean(out_drop[-1], axis=0)
# TODO: Only takes STDEV from first outcome category which works for
# outcomes with 2 categories, but a better solution is needed
# for num_categories > 2
uncertainty = tf.math.reduce_std(out_drop[-1], axis=0)[:, 0]
uncertainty = tf.expand_dims(uncertainty, axis=-1)
if self.num_outputs > 1:
out = [
tf.math.reduce_mean(out_drop[n], axis=0)
for n in range(self.num_outputs-1)
]
return out + [predictions, uncertainty]
else:
return predictions, uncertainty
def dump_config(self):
return {
'class': 'slideflow.model.tensorflow.UncertaintyInterface',
'kwargs': {
'path': self.path,
'layers': self.layers,
'pooling': self._pooling
}
}
def load(
path: str,
method: str = 'weights',
custom_objects: Optional[Dict[str, Any]] = None,
training: bool = False
) -> tf.keras.models.Model:
"""Load a model trained with Slideflow.
Args:
path (str): Path to saved model. Must be a model trained in Slideflow.
method (str): Method to use when loading the model; either 'full' or
'weights'. If 'full', will load the saved model with
``tf.keras.models.load_model()``. If 'weights', will read the
``params.json`` configuration file, build the model architecture,
and then load weights from the given model with
``Model.load_weights()``. Loading with 'full' may improve
compatibility across Slideflow versions. Loading with 'weights'
may improve compatibility across hardware & environments.
custom_objects (dict, optional): Dictionary mapping names
(strings) to custom classes or functions. Defaults to None.
training (bool, optional): Whether to build the model in training mode
when loading with method='weights'; currently only used for
survival models. Defaults to False.
Returns:
tf.keras.models.Model: Loaded model.
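Examples
*Loading a saved model (a minimal sketch; assumes '/path/to/model' is a
directory containing a model trained with Slideflow):*
.. code-block:: python
from slideflow.model.tensorflow import load
model = load('/path/to/model', method='weights')
# The returned object is a standard tf.keras.Model;
# batched_dataset below is assumed to be a batched tf.data.Dataset
preds = model.predict(batched_dataset)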
"""
if method not in ('full', 'weights'):
raise ValueError(f"Unrecognized method {method}, expected "
"either 'full' or 'weights'")
log.debug(f"Loading model with method='{method}'")
if method == 'full':
return tf.keras.models.load_model(path, custom_objects=custom_objects)
else:
config = sf.util.get_model_config(path)
hp = ModelParams.from_dict(config['hp'])
if len(config['outcomes']) == 1 or config['model_type'] == 'regression':
num_classes = len(list(config['outcome_labels'].keys()))
else:
num_classes = {
outcome: len(list(config['outcome_labels'][outcome].keys()))
for outcome in config['outcomes']
} # type: ignore
if config['model_type'] == 'survival':
survival_kw = dict(training=training)
else:
survival_kw = dict()
model = hp.build_model( # type: ignore
num_classes=num_classes,
num_slide_features=0 if not config['input_feature_sizes'] else sum(config['input_feature_sizes']),
pretrain=None,
**survival_kw
)
model.load_weights(join(path, 'variables/variables'))
return model