'''PyTorch backend for the slideflow.model submodule.'''
import inspect
import json
import os
import types
import numpy as np
import multiprocessing as mp
import pandas as pd
import torch
import torchvision
from torch import Tensor
from torch.nn.functional import softmax
from packaging import version
from rich.progress import Progress, TimeElapsedColumn
from collections import defaultdict
from os.path import join
from pandas.api.types import is_float_dtype, is_integer_dtype
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union, Callable)
import slideflow as sf
import slideflow.util.neptune_utils
from slideflow import errors
from slideflow.model import base as _base
from slideflow.model import torch_utils
from slideflow.model.torch_utils import autocast
from slideflow.model.base import log_manifest, BaseFeatureExtractor
from slideflow.util import log, NormFit, ImgBatchSpeedColumn, no_scope
if TYPE_CHECKING:
import pandas as pd
from slideflow.norm import StainNormalizer
class LinearBlock(torch.nn.Module):
    '''Block module that includes a linear layer -> ReLU -> BatchNorm,
    with optional dropout applied last.'''

    def __init__(
        self,
        in_ftrs: int,
        out_ftrs: int,
        dropout: Optional[float] = None
    ) -> None:
        """Build the block.

        Args:
            in_ftrs (int): Number of input features to the linear layer.
            out_ftrs (int): Number of output features.
            dropout (float, optional): Dropout probability. If falsy,
                no dropout layer is created. Defaults to None.
        """
        super().__init__()
        self.in_ftrs = in_ftrs
        self.out_ftrs = out_ftrs
        self.linear = torch.nn.Linear(in_ftrs, out_ftrs)
        self.relu = torch.nn.ReLU(inplace=True)
        self.bn = torch.nn.BatchNorm1d(out_ftrs)
        # Dropout is optional; self.dropout stays None when disabled.
        self.dropout = (torch.nn.Dropout(dropout) if dropout
                        else None)  # type: ignore

    def forward(self, x: Tensor) -> Tensor:
        """Apply linear -> ReLU -> batchnorm (-> dropout) to ``x``."""
        out = self.bn(self.relu(self.linear(x)))
        if self.dropout is not None:
            out = self.dropout(out)
        return out
class ModelWrapper(torch.nn.Module):
    '''Wrapper for PyTorch modules to support multiple outcomes, clinical
    (patient-level) inputs, and additional hidden layers.'''

    def __init__(
        self,
        model: Any,
        n_classes: List[int],
        num_slide_features: int = 0,
        hidden_layers: Optional[List[int]] = None,
        drop_images: bool = False,
        dropout: Optional[float] = None,
        include_top: bool = True
    ) -> None:
        """Wrap a core image model, replacing its final classifier with
        identity so it emits features, then add optional hidden layers and
        one final linear layer ("fc{i}") per outcome.

        Args:
            model (torch.nn.Module): Core image model (torchvision/timm
                style).
            n_classes (list(int)): Number of classes for each outcome.
            num_slide_features (int): Number of slide-level input features
                concatenated to the image features. Defaults to 0.
            hidden_layers (list(int), optional): Widths of hidden layers
                added between features and outcome layers. Defaults to None.
            drop_images (bool): Ignore image input and use only slide-level
                features. Defaults to False.
            dropout (float, optional): Dropout for hidden layers.
                Defaults to None.
            include_top (bool): Keep the model's existing fully-connected
                classifier layers (minus the final linear layer).
                Defaults to True.

        Raises:
            errors.ModelError: If the model's last linear layer cannot
                be located.
        """
        super().__init__()
        self.model = model
        self.n_classes = len(n_classes)
        self.drop_images = drop_images
        self.num_slide_features = num_slide_features
        self.num_hidden_layers = 0 if not hidden_layers else len(hidden_layers)
        self.has_aux = False
        log.debug(f'Model class name: {model.__class__.__name__}')
        if not drop_images:
            # Check for auxillary classifier
            if model.__class__.__name__ in ('Inception3',):
                log.debug("Auxillary classifier detected")
                self.has_aux = True

            # Get the last linear layer prior to the logits layer
            if model.__class__.__name__ in ('Xception', 'NASNetALarge'):
                num_ftrs = self.model.last_linear.in_features
                self.model.last_linear = torch.nn.Identity()
            elif model.__class__.__name__ in ('SqueezeNet',):
                # Bug fix: was `in ('SqueezeNet')`, which is a substring
                # test against the string 'SqueezeNet' rather than tuple
                # membership (e.g. 'Net' would have matched).
                num_ftrs = 1000
            elif hasattr(self.model, 'classifier'):
                children = list(self.model.classifier.named_children())
                if len(children):
                    # VGG, AlexNet
                    if include_top:
                        log.debug("Including existing fully-connected "
                                  "top classifier layers")
                        last_linear_name, last_linear = children[-1]
                        num_ftrs = last_linear.in_features
                        setattr(
                            self.model.classifier,
                            last_linear_name,
                            torch.nn.Identity()
                        )
                    elif model.__class__.__name__ in ('AlexNet',
                                                      'MobileNetV2',
                                                      'MNASNet'):
                        log.debug("Removing fully-connected classifier layers")
                        # children[0] is a dropout layer for these models,
                        # so the first linear layer is at index 1.
                        _, first_classifier = children[1]
                        num_ftrs = first_classifier.in_features
                        self.model.classifier = torch.nn.Identity()
                    elif model.__class__.__name__ in ('VGG', 'MobileNetV3'):
                        log.debug("Removing fully-connected classifier layers")
                        _, first_classifier = children[0]
                        num_ftrs = first_classifier.in_features
                        self.model.classifier = torch.nn.Identity()
                else:
                    num_ftrs = self.model.classifier.in_features
                    self.model.classifier = torch.nn.Identity()
            elif hasattr(self.model, 'fc'):
                num_ftrs = self.model.fc.in_features
                self.model.fc = torch.nn.Identity()
            elif hasattr(self.model, 'out_features'):
                num_ftrs = self.model.out_features
            elif hasattr(self.model, 'head'):
                num_ftrs = self.model.head.out_features
            else:
                # Log the model structure to aid debugging (previously a
                # bare print() to stdout).
                log.debug(str(self.model))
                raise errors.ModelError("Unable to find last linear layer for "
                                        f"model {model.__class__.__name__}")
        else:
            num_ftrs = 0

        # Add slide-level features
        num_ftrs += num_slide_features

        # Add hidden layers
        if hidden_layers:
            hl_ftrs = [num_ftrs] + hidden_layers
            for i in range(len(hidden_layers)):
                setattr(self, f'h{i}', LinearBlock(hl_ftrs[i],
                                                   hl_ftrs[i+1],
                                                   dropout=dropout))
            num_ftrs = hidden_layers[-1]

        # Add the outcome/logits layers for each outcome, if multiple outcomes
        for i, n in enumerate(n_classes):
            setattr(self, f'fc{i}', torch.nn.Linear(num_ftrs, n))

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attribute access to the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError as e:
            # Avoid infinite recursion while self.model is not yet set.
            if name == 'model':
                raise e
            return getattr(self.model, name)

    def forward(
        self,
        img: Tensor,
        slide_features: Optional[Tensor] = None
    ):
        """Forward pass; returns one logits Tensor, or a list of logits
        Tensors when there are multiple outcomes."""
        if slide_features is None and self.num_slide_features:
            raise ValueError("Expected 2 inputs, got 1")

        # Last linear of core convolutional model
        if not self.drop_images:
            x = self.model(img)

        # Discard auxillary classifier
        if self.has_aux and self.training:
            x = x.logits

        # Merging image data with any slide-level input data
        # NOTE(review): if drop_images is True and num_slide_features is 0,
        # `x` is never assigned — assumes such a configuration never occurs.
        if self.num_slide_features and not self.drop_images:
            assert slide_features is not None
            x = torch.cat([x, slide_features], dim=1)
        elif self.num_slide_features:
            x = slide_features

        # Hidden layers
        if self.num_hidden_layers:
            x = self.h0(x)
        if self.num_hidden_layers > 1:
            for h in range(1, self.num_hidden_layers):
                x = getattr(self, f'h{h}')(x)

        # Return a list of outputs if we have multiple outcomes
        if self.n_classes > 1:
            out = [getattr(self, f'fc{i}')(x) for i in range(self.n_classes)]
        # Otherwise, return the single output
        else:
            out = self.fc0(x)

        return out
class ModelParams(_base._ModelParams):
    """Build a set of hyperparameters."""

    # Supported core image architectures, mapped to their loader functions.
    ModelDict = {
        'resnet18': torchvision.models.resnet18,
        'resnet50': torchvision.models.resnet50,
        'alexnet': torchvision.models.alexnet,
        'squeezenet': torchvision.models.squeezenet.squeezenet1_1,
        'densenet': torchvision.models.densenet161,
        'inception': torchvision.models.inception_v3,
        'googlenet': torchvision.models.googlenet,
        'shufflenet': torchvision.models.shufflenet_v2_x1_0,
        'resnext50_32x4d': torchvision.models.resnext50_32x4d,
        'vgg16': torchvision.models.vgg16,  # needs support added
        'mobilenet_v2': torchvision.models.mobilenet_v2,
        'mobilenet_v3_small': torchvision.models.mobilenet_v3_small,
        'mobilenet_v3_large': torchvision.models.mobilenet_v3_large,
        'wide_resnet50_2': torchvision.models.wide_resnet50_2,
        'mnasnet': torchvision.models.mnasnet1_0,
        'xception': torch_utils.xception,
        'nasnet_large': torch_utils.nasnetalarge
    }

    def __init__(self, *, loss: str = 'CrossEntropy', **kwargs) -> None:
        """Initialize hyperparameters, validating model/optimizer/loss names.

        Args:
            loss (str): Loss function name; must be a key of ``AllLossDict``.
                Defaults to 'CrossEntropy'.
            **kwargs: Remaining hyperparameters, forwarded to
                ``_base._ModelParams``.
        """
        # Supported optimizers.
        self.OptDict = {
            'Adadelta': torch.optim.Adadelta,
            'Adagrad': torch.optim.Adagrad,
            'Adam': torch.optim.Adam,
            'AdamW': torch.optim.AdamW,
            'SparseAdam': torch.optim.SparseAdam,
            'Adamax': torch.optim.Adamax,
            'ASGD': torch.optim.ASGD,
            'LBFGS': torch.optim.LBFGS,
            'RMSprop': torch.optim.RMSprop,
            'Rprop': torch.optim.Rprop,
            'SGD': torch.optim.SGD
        }
        # Losses treated as regression losses (see model_type()).
        self.RegressionLossDict = {
            'L1': torch.nn.L1Loss,
            'MSE': torch.nn.MSELoss,
            'NLL': torch.nn.NLLLoss,  # negative log likelihood
            'HingeEmbedding': torch.nn.HingeEmbeddingLoss,
            'SmoothL1': torch.nn.SmoothL1Loss,
            'CosineEmbedding': torch.nn.CosineEmbeddingLoss,
        }
        # All supported losses (classification + regression).
        self.AllLossDict = {
            'CrossEntropy': torch.nn.CrossEntropyLoss,
            'CTC': torch.nn.CTCLoss,
            'PoissonNLL': torch.nn.PoissonNLLLoss,
            'GaussianNLL': torch.nn.GaussianNLLLoss,
            'KLDiv': torch.nn.KLDivLoss,
            'BCE': torch.nn.BCELoss,
            'BCEWithLogits': torch.nn.BCEWithLogitsLoss,
            'MarginRanking': torch.nn.MarginRankingLoss,
            'MultiLabelMargin': torch.nn.MultiLabelMarginLoss,
            'Huber': torch.nn.HuberLoss,
            'SoftMargin': torch.nn.SoftMarginLoss,
            'MultiLabelSoftMargin': torch.nn.MultiLabelSoftMarginLoss,
            'MultiMargin': torch.nn.MultiMarginLoss,
            'TripletMargin': torch.nn.TripletMarginLoss,
            'TripletMarginWithDistance': torch.nn.TripletMarginWithDistanceLoss,
            'L1': torch.nn.L1Loss,
            'MSE': torch.nn.MSELoss,
            'NLL': torch.nn.NLLLoss,  # negative log likelihood
            'HingeEmbedding': torch.nn.HingeEmbeddingLoss,
            'SmoothL1': torch.nn.SmoothL1Loss,
            'CosineEmbedding': torch.nn.CosineEmbeddingLoss,
        }
        super().__init__(loss=loss, **kwargs)
        # NOTE(review): asserts are stripped under `python -O`; kept for
        # backward compatibility (callers may expect AssertionError).
        assert self.model in self.ModelDict or self.model.startswith('timm_')
        assert self.optimizer in self.OptDict
        assert self.loss in self.AllLossDict
        if self.model == 'inception':
            # Fixed: log.warn is a deprecated alias of log.warning.
            log.warning("Model 'inception' has an auxillary classifier, which "
                        "is currently ignored during training. Auxillary "
                        "classifier loss will be included during training "
                        "starting in version 1.3")

    def get_opt(self, params_to_update: Iterable) -> torch.optim.Optimizer:
        """Return the configured optimizer over the given parameters."""
        return self.OptDict[self.optimizer](
            params_to_update,
            lr=self.learning_rate,
            weight_decay=self.l2
        )

    def get_loss(self) -> torch.nn.modules.loss._Loss:
        """Instantiate and return the configured loss function."""
        return self.AllLossDict[self.loss]()

    def get_model_loader(self, model: str) -> Callable:
        """Return a callable that builds the named model.

        Args:
            model (str): Model name; either a key of ``ModelDict`` or
                'timm_<name>' for a timm model.

        Raises:
            ValueError: If the model name is not recognized.
        """
        if model in self.ModelDict:
            return self.ModelDict[model]
        elif model.startswith('timm_'):
            # Deferred timm import; only required when a timm model is used.
            def loader(**kwargs):
                try:
                    import timm
                except ImportError:
                    raise ImportError(f"Unable to load model {model}; "
                                      "timm package not installed.")
                return timm.create_model(model[5:], **kwargs)
            return loader
        else:
            raise ValueError(f"Model {model} not found.")

    def build_model(
        self,
        labels: Optional[Dict] = None,
        num_classes: Optional[Union[int, Dict[Any, int]]] = None,
        num_slide_features: int = 0,
        pretrain: Optional[str] = None,
        checkpoint: Optional[str] = None
    ) -> torch.nn.Module:
        """Build a model from these hyperparameters.

        Args:
            labels (dict, optional): Outcome labels, used to detect the
                number of classes. Required if num_classes is None.
            num_classes (int or dict, optional): Number of classes
                (per outcome, if dict). Required if labels is None.
            num_slide_features (int): Number of slide-level input features.
                Defaults to 0.
            pretrain (str, optional): Pretraining to use ('imagenet', or a
                path to a saved model zip for custom pretraining).
            checkpoint (str, optional): Path to a state_dict checkpoint
                to load after building. Defaults to None.

        Returns:
            torch.nn.Module: The wrapped model.
        """
        assert num_classes is not None or labels is not None
        if num_classes is None:
            assert labels is not None
            num_classes = self._detect_classes_from_labels(labels)
        if not isinstance(num_classes, dict):
            num_classes = {'out-0': num_classes}

        # Prepare custom model pretraining
        if pretrain:
            log.debug(f"Using pretraining: [green]{pretrain}")
        if (isinstance(pretrain, str)
           and sf.util.path_to_ext(pretrain).lower() == 'zip'):
            # A zip path indicates a saved model to lazily load weights from.
            _pretrained = pretrain
            pretrain = None
        else:
            _pretrained = None

        # Build base model
        if self.model in ('xception', 'nasnet_large'):
            _model = self.get_model_loader(self.model)(
                num_classes=1000,
                pretrained=pretrain
            )
        else:
            # Compatibility logic for prior versions of PyTorch
            model_fn = self.get_model_loader(self.model)
            model_fn_sig = inspect.signature(model_fn)
            model_kw = [
                param.name
                for param in model_fn_sig.parameters.values()
                if param.kind == param.POSITIONAL_OR_KEYWORD
            ]
            call_kw = {}
            if 'image_size' in model_kw:
                call_kw.update(dict(image_size=self.tile_px))
            if (version.parse(torchvision.__version__) >= version.parse("0.13")
               and not self.model.startswith('timm_')):
                # New Torchvision API (weights= instead of pretrained=)
                w = 'DEFAULT' if pretrain == 'imagenet' else pretrain
                call_kw.update(dict(weights=w))  # type: ignore
            else:
                call_kw.update(dict(pretrained=pretrain))  # type: ignore
            _model = model_fn(**call_kw)

        # Add final layers to models
        hidden_layers = [
            self.hidden_layer_width
            for _ in range(self.hidden_layers)
        ]
        model = ModelWrapper(
            _model,
            list(num_classes.values()),
            num_slide_features,
            hidden_layers,
            self.drop_images,
            dropout=self.dropout,
            include_top=self.include_top
        )
        if _pretrained is not None:
            lazy_load_pretrained(model, _pretrained)
        if checkpoint is not None:
            model.load_state_dict(torch.load(checkpoint))
        return model

    def model_type(self) -> str:
        """Returns 'regression', 'classification', or 'survival', reflecting the loss."""
        # Custom losses are named 'custom_<type>'; return the <type> suffix.
        if self.loss.startswith('custom'):
            return self.loss[7:]
        elif self.loss == 'NLL':
            return 'survival'
        elif self.loss in self.RegressionLossDict:
            return 'regression'
        else:
            return 'classification'
class Trainer:
"""Base trainer class containing functionality for model building, input
processing, training, and evaluation.
This base class requires categorical outcome(s). Additional outcome types
are supported by :class:`slideflow.model.RegressionTrainer` and
:class:`slideflow.model.SurvivalTrainer`.
Slide-level (e.g. clinical) features can be used as additional model input
by providing slide labels in the slide annotations dictionary, under
the key 'input'.
"""
_model_type = 'classification'
    def __init__(
        self,
        hp: ModelParams,
        outdir: str,
        labels: Dict[str, Any],
        *,
        slide_input: Optional[Dict[str, Any]] = None,
        name: str = 'Trainer',
        feature_sizes: Optional[List[int]] = None,
        feature_names: Optional[List[str]] = None,
        outcome_names: Optional[List[str]] = None,
        mixed_precision: bool = True,
        allow_tf32: bool = False,
        config: Optional[Dict[str, Any]] = None,
        use_neptune: bool = False,
        neptune_api: Optional[str] = None,
        neptune_workspace: Optional[str] = None,
        load_method: str = 'weights',
        custom_objects: Optional[Dict[str, Any]] = None,
        device: Optional[str] = None,
        transform: Optional[Union[Callable, Dict[str, Callable]]] = None,
        pin_memory: bool = True,
        num_workers: int = 4,
        chunk_size: int = 8
    ):
        """Sets base configuration, preparing model inputs and outputs.

        Args:
            hp (:class:`slideflow.ModelParams`): ModelParams object.
            outdir (str): Destination for event logs and checkpoints.
            labels (dict): Dict mapping slide names to outcome labels (int or
                float format).
            slide_input (dict): Dict mapping slide names to additional
                slide-level input, concatenated after post-conv.
            name (str, optional): Optional name describing the model, used for
                model saving. Defaults to None.
            feature_sizes (list, optional): List of sizes of input features.
                Required if providing additional input features as model input.
            feature_names (list, optional): List of names for input features.
                Used when permuting feature importance.
            outcome_names (list, optional): Name of each outcome. Defaults to
                "Outcome {X}" for each outcome.
            mixed_precision (bool, optional): Use FP16 mixed precision (rather
                than FP32). Defaults to True.
            allow_tf32 (bool): Allow internal use of Tensorfloat-32 format.
                Defaults to False.
            config (dict, optional): Training configuration dictionary, used
                for logging and image format verification. Defaults to None.
            use_neptune (bool, optional): Use Neptune API logging.
                Defaults to False
            neptune_api (str, optional): Neptune API token, used for logging.
                Defaults to None.
            neptune_workspace (str, optional): Neptune workspace.
                Defaults to None.
            load_method (str): Loading method to use when reading model.
                This argument is ignored in the PyTorch backend, as all models
                are loaded by first building the model with hyperparameters
                detected in ``params.json``, then loading weights with
                ``torch.nn.Module.load_state_dict()``. Defaults to
                'full' (ignored).
            custom_objects (dict, optional): Ignored in the PyTorch backend;
                a warning is logged if provided. Defaults to None.
            device (str, optional): Device for model execution; resolved
                via ``torch_utils.get_device()``. Defaults to None.
            transform (callable or dict, optional): Optional transform to
                apply to input images. If dict, must have the keys 'train'
                and/or 'val', mapping to callables that takes a single
                image Tensor as input and returns a single image Tensor.
                If None, no transform is applied. If a single callable is
                provided, it will be applied to both training and validation
                data. If a dict is provided, the 'train' transform will be
                applied to training data and the 'val' transform will be
                applied to validation data. If a dict is provided and either
                'train' or 'val' is None, no transform will be applied to
                that data. Defaults to None.
            pin_memory (bool): Set the ``pin_memory`` attribute for dataloaders.
                Defaults to True.
            num_workers (int): Set the number of workers for dataloaders.
                Defaults to 4.
            chunk_size (int): Set the chunk size for TFRecord reading.
                Defaults to 8.
        """
        self.hp = hp
        self.outdir = outdir
        self.labels = labels
        self.patients = dict()  # type: Dict[str, str]
        self.name = name
        self.model = None  # type: Optional[torch.nn.Module]
        self.inference_model = None  # type: Optional[torch.nn.Module]
        self.mixed_precision = mixed_precision
        self.device = torch_utils.get_device(device)
        self.mid_train_val_dts: Optional[Iterable] = None
        self.loss_fn: torch.nn.modules.loss._Loss
        self.use_tensorboard: bool
        self.writer = None  # type: Optional[torch.utils.tensorboard.SummaryWriter]
        self.pin_memory = pin_memory
        self.num_workers = num_workers
        self.chunk_size = chunk_size
        self._reset_training_params()
        if custom_objects is not None:
            # custom_objects is a Tensorflow-backend concept.
            log.warn("custom_objects argument ignored in PyTorch backend.")

        # Enable or disable Tensorflow-32
        # Allows PyTorch to internally use tf32 for matmul and convolutions
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
        torch.backends.cudnn.allow_tf32 = allow_tf32  # type: ignore
        self._allow_tf32 = allow_tf32

        # Slide-level input args; values are coerced to float.
        if slide_input:
            self.slide_input = {
                k: [float(vi) for vi in v]
                for k, v in slide_input.items()
            }
        else:
            self.slide_input = None  # type: ignore
        self.feature_names = feature_names
        self.feature_sizes = feature_sizes
        self.num_slide_features = 0 if not feature_sizes else sum(feature_sizes)

        self.normalizer = self.hp.get_normalizer()
        if self.normalizer:
            log.info(f'Using realtime {self.hp.normalizer} normalization')
        if not os.path.exists(outdir):
            os.makedirs(outdir)

        self._process_transforms(transform)
        self._process_outcome_labels(outcome_names)
        if isinstance(labels, pd.DataFrame):
            # Experimental tile-level labels; may remap labels in-place.
            cat_assign = self._process_category_assignments()

        # Log parameters
        if config is None:
            config = {
                'slideflow_version': sf.__version__,
                'backend': sf.backend(),
                'git_commit': sf.__gitcommit__,
                'model_name': self.name,
                'full_model_name': self.name,
                'outcomes': self.outcome_names,
                'model_type': self.hp.model_type(),
                'img_format': None,
                'tile_px': self.hp.tile_px,
                'tile_um': self.hp.tile_um,
                'input_features': None,
                'input_feature_sizes': None,
                'input_feature_labels': None,
                'hp': self.hp.to_dict(),
            }
        if isinstance(labels, pd.DataFrame):
            config['outcome_labels'] = {str(k): v for k,v in cat_assign.items()}
        sf.util.write_json(config, join(self.outdir, 'params.json'))

        # Neptune logging
        self.config = config
        self.img_format = config['img_format'] if 'img_format' in config else None
        self.use_neptune = use_neptune
        self.neptune_run = None
        if self.use_neptune:
            if neptune_api is None or neptune_workspace is None:
                raise ValueError("If using Neptune, must supply neptune_api"
                                 " and neptune_workspace.")
            self.neptune_logger = sf.util.neptune_utils.NeptuneLog(
                neptune_api,
                neptune_workspace
            )
@property
def num_outcomes(self) -> int:
if self.hp.model_type() == 'classification':
assert self.outcome_names is not None
return len(self.outcome_names)
else:
return 1
@property
def multi_outcome(self) -> bool:
return (self.num_outcomes > 1)
def _process_category_assignments(self) -> Dict[int, str]:
"""Get category assignments for categorical outcome labels.
Dataframes with integer labels are assumed to be categorical if
if hp.model_type is 'classification'.
Dataframes with float labels are assumed to be continuous.
Dataframes with other labels are assumed to be categorical, and will
be assigned an integer label based on the order of unique values.
"""
if not isinstance(self.labels, pd.DataFrame):
raise ValueError("Expected DataFrame with 'label' column.")
if 'label' not in self.labels.columns:
raise ValueError("Expected DataFrame with 'label' column.")
if self.hp.model_type() == 'classification':
if is_integer_dtype(self.labels['label']) or is_float_dtype(self.labels['label']):
return {i: str(i) for i in sorted(self.labels['label'].unique())}
else:
int_to_str = dict(enumerate(sorted(self.labels['label'].unique())))
str_to_int = {v: k for k, v in int_to_str.items()}
log.info("Assigned integer labels to categories:")
log.info(str_to_int)
self.labels['label'] = self.labels['label'].map(str_to_int)
return int_to_str
else:
return {}
def _process_transforms(
self,
transform: Optional[Union[Callable, Dict[str, Callable]]] = None
) -> None:
"""Process custom transformations for training and/or validation."""
if not isinstance(transform, dict):
transform = {'train': transform, 'val': transform}
if any([t not in ('train', 'val') for t in transform]):
raise ValueError("transform must be a callable or dict with keys "
"'train' and/or 'val'")
if 'train' not in transform:
transform['train'] = None
if 'val' not in transform:
transform['val'] = None
self.transform = transform
def _process_outcome_labels(
self,
outcome_names: Optional[List[str]] = None,
) -> None:
"""Process outcome labels to determine number of outcomes and names.
Supports experimental tile-level labels provided via pandas DataFrame.
Args:
labels (dict): Dict mapping slide names to outcome labels (int or
float format). Experimental funtionality: if labels is a
pandas DataFrame, the 'label' column will be used as the
outcome labels.
outcome_names (list, optional): Name of each outcome. Defaults to
"Outcome {X}" for each outcome.
"""
# Process DataFrame tile-level labels
if isinstance(self.labels, pd.DataFrame):
if 'label' not in self.labels.columns:
raise errors.ModelError("Expected DataFrame with 'label' "
"column.")
if outcome_names and len(outcome_names) > 1:
raise errors.ModelError(
"Expected single outcome name for labels from a pandas dataframe."
)
self.outcome_names = outcome_names or ['Outcome 0']
return
# Process dictionary slide-level labels
outcome_labels = np.array(list(self.labels.values()))
if len(outcome_labels.shape) == 1:
outcome_labels = np.expand_dims(outcome_labels, axis=1)
if not outcome_names:
self.outcome_names = [
f'Outcome {i}'
for i in range(outcome_labels.shape[1])
]
else:
self.outcome_names = outcome_names
if not len(self.outcome_names) == outcome_labels.shape[1]:
n_names = len(self.outcome_names)
n_out = outcome_labels.shape[1]
raise errors.ModelError(f"Number of outcome names ({n_names}) does"
f" not match number of outcomes ({n_out})")
def _reset_training_params(self) -> None:
self.global_step = 0
self.epoch = 0 # type: int
self.step = 0 # type: int
self.log_frequency = 0 # type: int
self.early_stop = False # type: bool
self.moving_average = [] # type: List
self.dataloaders = {} # type: Dict[str, Any]
self.validation_batch_size = None # type: Optional[int]
self.validate_on_batch = 0
self.validation_steps = 0
self.ema_observations = 0 # type: int
self.ema_smoothing = 0
self.last_ema = -1 # type: float
self.ema_one_check_prior = -1 # type: float
self.ema_two_checks_prior = -1 # type: float
self.epoch_records = 0 # type: int
self.running_loss = 0.0
self.running_corrects = {} # type: Union[Tensor, Dict[str, Tensor]]
def _accuracy_as_numpy(
self,
acc: Union[Tensor, float, List[Tensor], List[float]]
) -> Union[float, List[float]]:
if isinstance(acc, list):
return [t.item() if isinstance(t, Tensor) else t for t in acc]
else:
return (acc.item() if isinstance(acc, Tensor) else acc)
def _build_model(
self,
checkpoint: Optional[str] = None,
pretrain: Optional[str] = None
) -> None:
if checkpoint:
log.info(f"Loading checkpoint at [green]{checkpoint}")
self.load(checkpoint)
else:
self.model = self.hp.build_model(
labels=self.labels,
pretrain=pretrain,
num_slide_features=self.num_slide_features
)
# Create an inference model before any multi-GPU parallelization
# is applied to the self.model parameter
self.inference_model = self.model
def _calculate_accuracy(
self,
running_corrects: Union[Tensor, Dict[Any, Tensor]],
num_records: int = 1
) -> Tuple[Union[Tensor, List[Tensor]], str]:
'''Reports accuracy of each outcome.'''
assert self.hp.model_type() == 'classification'
if self.num_outcomes > 1:
if not isinstance(running_corrects, dict):
raise ValueError("Expected running_corrects to be a dict:"
" num_outcomes is > 1")
acc_desc = ''
acc_list = [running_corrects[r] / num_records
for r in running_corrects]
for o in range(len(running_corrects)):
_acc = running_corrects[f'out-{o}'] / num_records
acc_desc += f"out-{o} acc: {_acc:.4f} "
return acc_list, acc_desc
else:
assert not isinstance(running_corrects, dict)
_acc = running_corrects / num_records
return _acc, f'acc: {_acc:.4f}'
def _calculate_loss(
self,
outputs: Union[Tensor, List[Tensor]],
labels: Union[Tensor, Dict[Any, Tensor]],
loss_fn: torch.nn.modules.loss._Loss
) -> Tensor:
'''Calculates loss in a manner compatible with multiple outcomes.'''
if self.num_outcomes > 1:
if not isinstance(labels, dict):
raise ValueError("Expected labels to be a dict: num_outcomes"
" is > 1")
loss = sum([
loss_fn(out, labels[f'out-{o}'])
for o, out in enumerate(outputs)
])
else:
loss = loss_fn(outputs, labels)
return loss # type: ignore
    def _check_early_stopping(
        self,
        val_acc: Optional[Union[float, List[float]]] = None,
        val_loss: Optional[float] = None
    ) -> str:
        """Check, and possibly trigger, early stopping.

        With no validation metrics supplied, only the 'manual' early stop
        condition (epoch/batch thresholds from hyperparameters) is checked.
        Otherwise, a moving average of the validation metric is tracked and
        early stopping triggers when the average worsens across checks.

        Args:
            val_acc (float or list, optional): Validation accuracy; used
                when early_stop_method is 'accuracy'.
            val_loss (float, optional): Validation loss; used otherwise.

        Returns:
            str: Log message describing the SMA/EMA value when early
            stopping triggers during an EMA check; '' otherwise.
        """
        if val_acc is None and val_loss is None:
            # Manual early stopping: stop once both the configured epoch
            # and batch thresholds have been reached.
            if (self.hp.early_stop
               and self.hp.early_stop_method == 'manual'
               and self.hp.manual_early_stop_epoch <= self.epoch  # type: ignore
               and self.hp.manual_early_stop_batch <= self.step):  # type: ignore
                log.info(f'Manual early stop triggered: epoch {self.epoch}, '
                         f'batch {self.step}')
                # Ensure the current epoch is saved/evaluated.
                if self.epoch not in self.hp.epochs:
                    self.hp.epochs += [self.epoch]
                self.early_stop = True
        else:
            if self.hp.early_stop_method == 'accuracy':
                if self.num_outcomes > 1:
                    raise errors.ModelError(
                        "Early stopping method 'accuracy' not supported with"
                        " multiple outcomes; use 'loss'.")
                early_stop_val = val_acc
            else:
                early_stop_val = val_loss
            assert early_stop_val is not None
            assert isinstance(early_stop_val, float)
            self.moving_average += [early_stop_val]
            # Only evaluate early stopping once enough observations accrue.
            if len(self.moving_average) >= self.ema_observations:
                # Only keep track of the last [ema_observations]
                self.moving_average.pop(0)
                if self.last_ema == -1:
                    # Simple moving average (first check only)
                    self.last_ema = (sum(self.moving_average)
                                     / len(self.moving_average))  # type: ignore
                    log_msg = f' (SMA: {self.last_ema:.3f})'
                else:
                    # Exponential moving average on subsequent checks.
                    alpha = (self.ema_smoothing / (1 + self.ema_observations))
                    self.last_ema = (early_stop_val * alpha
                                     + (self.last_ema * (1 - alpha)))
                    log_msg = f' (EMA: {self.last_ema:.3f})'
                if self.neptune_run and self.last_ema != -1:
                    neptune_dest = "metrics/val/batch/exp_moving_avg"
                    self.neptune_run[neptune_dest].log(self.last_ema)
                # Trigger early stop when the EMA worsened relative to the
                # value from two checks prior (direction depends on metric).
                if (self.hp.early_stop
                   and self.ema_two_checks_prior != -1
                   and self.epoch > self.hp.early_stop_patience):
                    if ((self.hp.early_stop_method == 'accuracy'
                         and self.last_ema <= self.ema_two_checks_prior)
                       or (self.hp.early_stop_method == 'loss'
                           and self.last_ema >= self.ema_two_checks_prior)):
                        log.info(f'Early stop triggered: epoch {self.epoch}, '
                                 f'step {self.step}')
                        self._log_early_stop_to_neptune()
                        if self.epoch not in self.hp.epochs:
                            self.hp.epochs += [self.epoch]
                        self.early_stop = True
                        return log_msg
                # Shift EMA history for the next check.
                self.ema_two_checks_prior = self.ema_one_check_prior
                self.ema_one_check_prior = self.last_ema
        return ''
def _detect_patients(self, *args):
self.patients = dict()
for dataset in args:
if dataset is None:
continue
dataset_patients = dataset.patients()
if not dataset_patients:
self.patients.update({s: s for s in self.slides})
else:
self.patients.update(dataset_patients)
def _empty_corrects(self) -> Union[int, Dict[str, int]]:
if self.multi_outcome:
return {
f'out-{o}': 0
for o in range(self.num_outcomes)
}
else:
return 0
def _epoch_metrics(
self,
acc: Union[float, List[float]],
loss: float,
label: str
) -> Dict[str, Dict[str, Union[float, List[float]]]]:
epoch_metrics = {'loss': loss} # type: Dict
if self.hp.model_type() == 'classification':
epoch_metrics.update({'accuracy': acc})
return {f'{label}_metrics': epoch_metrics}
    def _val_metrics(self, **kwargs) -> Dict[str, Dict[str, float]]:
        """Evaluate model and calculate metrics.

        Runs evaluation over the 'val' dataloader, logs the results to
        results_log.csv and (optionally) Neptune, and returns the metrics.

        Keyword args are passed through to
        ``sf.stats.metrics_from_dataset()``.

        Returns:
            Dict[str, Dict[str, float]]: Dict with validation metrics.
            Returns metrics in the form:
            ```
            {
                'val_metrics': {
                    'loss': ...,
                    'accuracy': ...,
                },
                'tile_auc': ...,
                'slide_auc': ...,
                ...
            }
            ```
        """
        # Clear any accumulated gradients before switching to eval mode.
        if hasattr(self, 'optimizer'):
            self.optimizer.zero_grad()
        assert self.model is not None
        self.model.eval()
        results_log = os.path.join(self.outdir, 'results_log.csv')
        epoch_results = {}

        # Preparations for calculating accuracy/loss in metrics_from_dataset()
        def update_corrects(pred, labels, running_corrects=None):
            # Accumulate per-batch correct predictions (classification only).
            if running_corrects is None:
                running_corrects = self._empty_corrects()
            if self.hp.model_type() == 'classification':
                labels = self._labels_to_device(labels, self.device)
                return self._update_corrects(pred, labels, running_corrects)
            else:
                return 0

        def update_loss(pred, labels, running_loss, size):
            # Accumulate loss weighted by batch size.
            labels = self._labels_to_device(labels, self.device)
            loss = self._calculate_loss(pred, labels, self.loss_fn)
            return running_loss + (loss.item() * size)

        torch_args = types.SimpleNamespace(
            update_corrects=update_corrects,
            update_loss=update_loss,
            num_slide_features=self.num_slide_features,
            slide_input=self.slide_input,
            # GPU-based normalizers are applied inside the metrics loop;
            # CPU normalizers are assumed already applied by the dataloader.
            normalizer=(self.normalizer if self._has_gpu_normalizer() else None),
        )

        # Calculate patient/slide/tile metrics (AUC, R-squared, C-index, etc)
        metrics, acc, loss = sf.stats.metrics_from_dataset(
            self.inference_model,
            model_type=self.hp.model_type(),
            patients=self.patients,
            dataset=self.dataloaders['val'],
            data_dir=self.outdir,
            outcome_names=self.outcome_names,
            neptune_run=self.neptune_run,
            torch_args=torch_args,
            uq=bool(self.hp.uq),
            **kwargs
        )
        loss_and_acc = {'loss': loss}
        if self.hp.model_type() == 'classification':
            loss_and_acc.update({'accuracy': acc})
            self._log_epoch(
                'val',
                self.epoch,
                loss,
                self._calculate_accuracy(acc)[1]  # type: ignore
            )
        epoch_metrics = {'val_metrics': loss_and_acc}

        # Flatten tile/slide/patient-level metrics into epoch results.
        for metric in metrics:
            if metrics[metric]['tile'] is None:
                continue
            epoch_results[f'tile_{metric}'] = metrics[metric]['tile']
            epoch_results[f'slide_{metric}'] = metrics[metric]['slide']
            epoch_results[f'patient_{metric}'] = metrics[metric]['patient']
        epoch_metrics.update(epoch_results)
        sf.util.update_results_log(
            results_log,
            'trained_model',
            {f'epoch{self.epoch}': epoch_metrics}
        )
        self._log_eval_to_neptune(loss, acc, metrics, epoch_metrics)
        return epoch_metrics
def _fit_normalizer(self, norm_fit: Optional[NormFit]) -> None:
"""Fit the Trainer normalizer using the specified fit, if applicable.
Args:
norm_fit (Optional[Dict[str, np.ndarray]]): Normalizer fit.
"""
if norm_fit is not None and not self.normalizer:
raise ValueError("norm_fit supplied, but model params do not"
"specify a normalizer.")
if self.normalizer and norm_fit is not None:
self.normalizer.set_fit(**norm_fit) # type: ignore
elif (self.normalizer
and 'norm_fit' in self.config
and self.config['norm_fit'] is not None):
log.debug("Detecting normalizer fit from model config")
self.normalizer.set_fit(**self.config['norm_fit'])
    def _has_gpu_normalizer(self) -> bool:
        """Return True if the stain normalizer is a Torch normalizer
        configured on a non-CPU device."""
        # Deferred import to avoid a circular import at module load.
        import slideflow.norm.torch
        return (isinstance(self.normalizer, sf.norm.torch.TorchStainNormalizer)
                and self.normalizer.device != "cpu")
def _labels_to_device(
self,
labels: Union[Dict[Any, Tensor], Tensor],
device: torch.device
) -> Union[Dict[Any, Tensor], Tensor]:
'''Moves a set of outcome labels to the given device.'''
if self.num_outcomes > 1:
if not isinstance(labels, dict):
raise ValueError("Expected labels to be a dict: num_outcomes"
" is > 1")
labels = {
k: la.to(device, non_blocking=True) for k, la in labels.items()
}
elif isinstance(labels, dict):
labels = torch.stack(list(labels.values()), dim=1)
return labels.to(device, non_blocking=True)
else:
labels = labels.to(device, non_blocking=True)
return labels
def _log_epoch(
self,
phase: str,
epoch: int,
loss: float,
accuracy_desc: str,
) -> None:
"""Logs epoch description."""
log.info(f'[bold blue]{phase}[/] Epoch {epoch} | loss:'
f' {loss:.4f} {accuracy_desc}')
def _log_manifest(
self,
train_dts: Optional["sf.Dataset"],
val_dts: Optional["sf.Dataset"],
labels: Optional[Union[str, Dict]] = 'auto'
) -> None:
"""Log the tfrecord and label manifest to slide_manifest.csv
Args:
train_dts (sf.Dataset): Training dataset. May be None.
val_dts (sf.Dataset): Validation dataset. May be None.
labels (dict, optional): Labels dictionary. May be None.
Defaults to 'auto' (read from self.labels).
"""
if labels == 'auto':
_labels = self.labels
elif labels is None:
_labels = None
else:
assert isinstance(labels, dict)
_labels = labels
log_manifest(
(train_dts.tfrecords() if train_dts else None),
(val_dts.tfrecords() if val_dts else None),
labels=_labels,
filename=join(self.outdir, 'slide_manifest.csv')
)
def _log_to_tensorboard(
self,
loss: float,
acc: Union[float, List[float]],
label: str
) -> None:
self.writer.add_scalar(f'Loss/{label}', loss, self.global_step)
if self.hp.model_type() == 'classification':
if self.num_outcomes > 1:
assert isinstance(acc, list)
for o, _acc in enumerate(acc):
self.writer.add_scalar(
f'Accuracy-{o}/{label}', _acc, self.global_step
)
else:
self.writer.add_scalar(
f'Accuracy/{label}', acc, self.global_step
)
def _log_to_neptune(
self,
loss: float,
acc: Union[Tensor, List[Tensor]],
label: str,
phase: str
) -> None:
"""Logs epoch loss/accuracy to Neptune."""
assert phase in ('batch', 'epoch')
step = self.epoch if phase == 'epoch' else self.global_step
if self.neptune_run:
self.neptune_run[f"metrics/{label}/{phase}/loss"].log(loss,
step=step)
acc = self._accuracy_as_numpy(acc)
if isinstance(acc, list):
for a, _acc in enumerate(acc):
sf.util.neptune_utils.list_log(
run=self.neptune_run,
label=f'metrics/{label}/{phase}/accuracy-{a}',
val=_acc,
step=step
)
else:
sf.util.neptune_utils.list_log(
run=self.neptune_run,
label=f'metrics/{label}/{phase}/accuracy',
val=acc,
step=step
)
def _log_early_stop_to_neptune(self) -> None:
# Log early stop to neptune
if self.neptune_run:
self.neptune_run["early_stop/early_stop_epoch"] = self.epoch
self.neptune_run["early_stop/early_stop_batch"] = self.step
self.neptune_run["early_stop/method"] = self.hp.early_stop_method
self.neptune_run["sys/tags"].add("early_stopped")
def _log_eval_to_neptune(
self,
loss: float,
acc: float,
metrics: Dict[str, Any],
epoch_results: Dict[str, Any]
) -> None:
if self.use_neptune:
assert self.neptune_run is not None
self.neptune_run['results'] = epoch_results
# Validation epoch metrics
self.neptune_run['metrics/val/epoch/loss'].log(loss,
step=self.epoch)
sf.util.neptune_utils.list_log(
self.neptune_run,
'metrics/val/epoch/accuracy',
acc,
step=self.epoch
)
for metric in metrics:
if metrics[metric]['tile'] is None:
continue
for outcome in metrics[metric]['tile']:
# If only one outcome,
# log to metrics/val/epoch/[metric].
# If more than one outcome,
# log to metrics/val/epoch/[metric]/[outcome_name]
def metric_label(s):
if len(metrics[metric]['tile']) == 1:
return f'metrics/val/epoch/{s}_{metric}'
else:
return f'metrics/val/epoch/{s}_{metric}/{outcome}'
tile_metric = metrics[metric]['tile'][outcome]
slide_metric = metrics[metric]['slide'][outcome]
patient_metric = metrics[metric]['patient'][outcome]
# If only one value for a metric, log to .../[metric]
# If more than one value for a metric
# (e.g. AUC for each category),
# log to .../[metric]/[i]
sf.util.neptune_utils.list_log(
self.neptune_run,
metric_label('tile'),
tile_metric,
step=self.epoch
)
sf.util.neptune_utils.list_log(
self.neptune_run,
metric_label('slide'),
slide_metric,
step=self.epoch
)
sf.util.neptune_utils.list_log(
self.neptune_run,
metric_label('patient'),
patient_metric,
step=self.epoch
)
def _mid_training_validation(self) -> None:
    """Perform mid-epoch validation, if appropriate."""
    # Only validate when validate_on_batch is enabled, a validation
    # dataloader exists, and the current step is a multiple of the interval.
    if not self.validate_on_batch:
        return
    elif not (
        'val' in self.dataloaders
        and self.step > 0
        and self.step % self.validate_on_batch == 0
    ):
        return
    if self.model is None or self.inference_model is None:
        raise errors.ModelError("Model not yet initialized.")
    self.model.eval()
    running_val_loss = 0
    num_val = 0
    running_val_correct = self._empty_corrects()
    for _ in range(self.validation_steps):
        # mid_train_val_dts is an infinite cycling iterator over the
        # validation dataloader; extra fields (e.g. locations) are ignored.
        val_img, val_label, slides, *_ = next(self.mid_train_val_dts)  # type:ignore
        val_img = val_img.to(self.device)
        # channels-last memory format for better convolution throughput
        val_img = val_img.to(memory_format=torch.channels_last)
        with torch.inference_mode():
            # Autocast mixed precision is only used on CUDA/CPU devices.
            _mp = (self.mixed_precision and self.device.type in ('cuda', 'cpu'))
            with autocast(self.device.type, mixed_precision=_mp):  # type: ignore
                # GPU normalization, if specified.
                if self._has_gpu_normalizer():
                    val_img = self.normalizer.preprocess(val_img)
                if self.num_slide_features:
                    # Clinical (slide-level) input features, keyed by slide.
                    _slide_in = [self.slide_input[s] for s in slides]
                    inp = (val_img, Tensor(_slide_in).to(self.device))
                else:
                    inp = (val_img,)  # type: ignore
                val_outputs = self.inference_model(*inp)
                val_label = self._labels_to_device(val_label, self.device)
                val_batch_loss = self._calculate_loss(
                    val_outputs, val_label, self.loss_fn
                )
        # Weight running loss by batch size (batches may vary in size).
        running_val_loss += val_batch_loss.item() * val_img.size(0)
        if self.hp.model_type() == 'classification':
            running_val_correct = self._update_corrects(
                val_outputs, val_label, running_val_correct  # type: ignore
            )
        num_val += val_img.size(0)
    val_loss = running_val_loss / num_val
    if self.hp.model_type() == 'classification':
        val_acc, val_acc_desc = self._calculate_accuracy(
            running_val_correct, num_val  # type: ignore
        )
    else:
        val_acc, val_acc_desc = 0, ''  # type: ignore
    log_msg = f'Batch {self.step}: val loss: {val_loss:.4f} {val_acc_desc}'
    # Log validation metrics to neptune & check early stopping
    self._log_to_neptune(val_loss, val_acc, 'val', phase='batch')
    log_msg += self._check_early_stopping(
        self._accuracy_as_numpy(val_acc),
        val_loss
    )
    log.info(log_msg)
    # Log to tensorboard
    if self.use_tensorboard:
        if self.num_outcomes > 1:
            assert isinstance(running_val_correct, dict)
            _val_acc = [
                running_val_correct[f'out-{o}'] / num_val
                for o in range(len(val_outputs))
            ]
        else:
            assert not isinstance(running_val_correct, dict)
            _val_acc = running_val_correct / num_val  # type: ignore
        self._log_to_tensorboard(
            val_loss,
            self._accuracy_as_numpy(_val_acc),
            'test'
        )  # type: ignore
    # Put model back in training mode
    self.model.train()
def _prepare_optimizers_and_loss(self) -> None:
if self.model is None:
raise ValueError("Model has not yet been initialized.")
self.optimizer = self.hp.get_opt(self.model.parameters())
if self.hp.learning_rate_decay:
self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
self.optimizer,
gamma=self.hp.learning_rate_decay
)
log.debug("Using exponentially decaying learning rate")
else:
self.scheduler = None # type: ignore
self.loss_fn = self.hp.get_loss()
if self.mixed_precision and self.device.type == 'cuda':
self.scaler = torch.cuda.amp.GradScaler()
def _prepare_neptune_run(self, dataset: "sf.Dataset", label: str) -> None:
if self.use_neptune:
tags = [label]
if 'k-fold' in self.config['validation_strategy']:
tags += [f'k-fold{self.config["k_fold_i"]}']
self.neptune_run = self.neptune_logger.start_run(
self.name,
self.config['project'],
dataset,
tags=tags
)
assert self.neptune_run is not None
self.neptune_logger.log_config(self.config, label)
self.neptune_run['data/slide_manifest'].upload(
os.path.join(self.outdir, 'slide_manifest.csv')
)
try:
config_path = join(self.outdir, 'params.json')
config = sf.util.load_json(config_path)
config['neptune_id'] = self.neptune_run['sys/id'].fetch()
except Exception:
log.info("Unable to log params (params.json) with Neptune.")
def _print_model_summary(self, train_dts) -> None:
"""Prints model summary and logs to neptune."""
if self.model is None:
raise ValueError("Model has not yet been initialized.")
empty_inp = [torch.empty(
[self.hp.batch_size, 3, train_dts.tile_px, train_dts.tile_px]
)]
if self.num_slide_features:
empty_inp += [
torch.empty([self.hp.batch_size, self.num_slide_features])
]
if sf.getLoggingLevel() <= 20:
model_summary = torch_utils.print_module_summary(
self.model, empty_inp
)
if self.neptune_run:
self.neptune_run['summary'] = model_summary
def _save_model(self) -> None:
assert self.model is not None
name = self.name if self.name else 'trained_model'
save_path = os.path.join(self.outdir, f'{name}_epoch{self.epoch}.zip')
torch.save(self.model.state_dict(), save_path)
log.info(f"Model saved to [green]{save_path}")
def _close_dataloaders(self):
    """Close dataloaders, ensuring threads have joined."""
    # Guard the delete: predict()/evaluate() call _setup_dataloaders()
    # without mid_train_val, so the iterator may never have been created.
    if hasattr(self, 'mid_train_val_dts'):
        del self.mid_train_val_dts
    for name, d in self.dataloaders.items():
        # hasattr() is the idiomatic (and cheaper) check vs. `in dir()`.
        if hasattr(d, '_dataset'):
            log.debug(f"Closing dataloader {name} via _dataset.close()")
            d._dataset.close()
        elif hasattr(d, 'dataset'):
            log.debug(f"Closing dataloader {name} via dataset.close()")
            d.dataset.close()
def _setup_dataloaders(
    self,
    train_dts: Optional["sf.Dataset"],
    val_dts: Optional["sf.Dataset"],
    mid_train_val: bool = False,
    incl_labels: bool = True,
    from_wsi: bool = False,
    **kwargs
) -> None:
    """Prepare dataloaders from training and validation datasets.

    Args:
        train_dts (sf.Dataset, optional): Training dataset. May be None.
        val_dts (sf.Dataset, optional): Validation dataset. May be None.
        mid_train_val (bool): Also build an infinite cycling iterator over
            the validation dataloader for mid-epoch validation checks.
        incl_labels (bool): Supply labels to the dataloaders.
        from_wsi (bool): Read tiles directly from whole-slide images rather
            than TFRecords (disables dataloader worker processes).
    """
    interleave_args = types.SimpleNamespace(
        rank=0,
        num_replicas=1,
        labels=(self.labels if incl_labels else None),
        chunk_size=self.chunk_size,
        pin_memory=self.pin_memory,
        num_workers=self.num_workers if not from_wsi else 0,
        onehot=False,
        incl_slidenames=True,
        from_wsi=from_wsi,
        **kwargs
    )
    # Use GPU stain normalization for PyTorch normalizers, if supported.
    _augment_str = self.hp.augment
    if self._has_gpu_normalizer():
        log.info("Using GPU for stain normalization")
        interleave_args.standardize = False
        # Stain augmentation ('n') is applied on-GPU during
        # preprocessing, so strip it from the CPU augmentation pipeline.
        if isinstance(_augment_str, str):
            _augment_str = _augment_str.replace('n', '')
    else:
        interleave_args.normalizer = self.normalizer
    if train_dts is not None:
        self.dataloaders = {
            'train': iter(train_dts.torch(
                infinite=True,
                batch_size=self.hp.batch_size,
                augment=_augment_str,
                transform=self.transform['train'],
                drop_last=True,
                **vars(interleave_args)
            ))
        }
    else:
        self.dataloaders = {}
    if val_dts is not None:
        # Fix: when self.validation_batch_size was set, the local
        # `validation_batch_size` was never assigned (missing else
        # branch), raising a NameError below. Fall back to the training
        # batch size only when no validation batch size is configured.
        validation_batch_size = (
            self.validation_batch_size or self.hp.batch_size
        )
        self.dataloaders['val'] = val_dts.torch(
            infinite=False,
            batch_size=validation_batch_size,
            augment=False,
            transform=self.transform['val'],
            incl_loc=True,
            **vars(interleave_args)
        )
        # Mid-training validation dataset
        if mid_train_val:
            self.mid_train_val_dts = torch_utils.cycle(
                self.dataloaders['val']
            )
            if not self.validate_on_batch:
                val_log_msg = ''
            else:
                val_log_msg = f'every {str(self.validate_on_batch)} steps and '
            log.debug(f'Validation during training: {val_log_msg}at epoch end')
            if self.validation_steps:
                num_samples = self.validation_steps * self.hp.batch_size
                log.debug(
                    f'Using {self.validation_steps} batches ({num_samples} '
                    'samples) each validation check'
                )
            else:
                log.debug('Using entire validation set each validation check')
        else:
            log.debug('Validation during training: None')
def _training_step(self, pb: Progress) -> None:
    """Run one training step: forward, loss, backward, and step logging."""
    assert self.model is not None
    images, labels, slides = next(self.dataloaders['train'])
    images = images.to(self.device, non_blocking=True)
    # channels-last memory format for better convolution throughput
    images = images.to(memory_format=torch.channels_last)
    labels = self._labels_to_device(labels, self.device)
    self.optimizer.zero_grad()
    with torch.set_grad_enabled(True):
        # Autocast mixed precision is only used on CUDA/CPU devices.
        _mp = (self.mixed_precision and self.device.type in ('cuda', 'cpu'))
        with autocast(self.device.type, mixed_precision=_mp):  # type: ignore
            # GPU normalization, if specified.
            if self._has_gpu_normalizer():
                # Stain augmentation ('n') is applied here, on-GPU.
                images = self.normalizer.preprocess(
                    images,
                    augment=(isinstance(self.hp.augment, str)
                             and 'n' in self.hp.augment)
                )
            # Slide-level features
            if self.num_slide_features:
                _slide_in = [self.slide_input[s] for s in slides]
                inp = (images, Tensor(_slide_in).to(self.device))
            else:
                inp = (images,)  # type: ignore
            outputs = self.model(*inp)
            loss = self._calculate_loss(outputs, labels, self.loss_fn)
        # Update weights
        if self.mixed_precision and self.device.type == 'cuda':
            # Scaled backward/step for CUDA mixed-precision training.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()
    # Update learning rate if using a scheduler
    _lr_decay_steps = self.hp.learning_rate_decay_steps
    if self.scheduler and (self.global_step+1) % _lr_decay_steps == 0:
        log.debug("Stepping learning rate decay")
        self.scheduler.step()
    # Record accuracy and loss
    self.epoch_records += images.size(0)
    if self.hp.model_type() == 'classification':
        self.running_corrects = self._update_corrects(
            outputs, labels, self.running_corrects
        )
        train_acc, acc_desc = self._calculate_accuracy(
            self.running_corrects, self.epoch_records
        )
    else:
        train_acc, acc_desc = 0, ''  # type: ignore
    # Weight running loss by batch size.
    self.running_loss += loss.item() * images.size(0)
    _loss = self.running_loss / self.epoch_records
    pb.update(task_id=0,  # type: ignore
              description=(f'[bold blue]train[/] '
                           f'loss: {_loss:.4f} {acc_desc}'))
    pb.advance(task_id=0)  # type: ignore
    # Log to tensorboard
    if self.use_tensorboard and self.global_step % self.log_frequency == 0:
        if self.num_outcomes > 1:
            _train_acc = [
                (self.running_corrects[f'out-{o}']  # type: ignore
                 / self.epoch_records)
                for o in range(len(outputs))
            ]
        else:
            _train_acc = (self.running_corrects  # type: ignore
                          / self.epoch_records)
        self._log_to_tensorboard(
            loss.item(),
            self._accuracy_as_numpy(_train_acc),
            'train'
        )
    # Log to neptune & check early stopping
    self._log_to_neptune(loss.item(), train_acc, 'train', phase='batch')
    self._check_early_stopping(None, None)
def _update_corrects(
self,
outputs: Union[Tensor, Dict[Any, Tensor]],
labels: Union[Tensor, Dict[str, Tensor]],
running_corrects: Union[Tensor, Dict[str, Tensor]],
) -> Union[Tensor, Dict[str, Tensor]]:
'''Updates running accuracy in a manner compatible with >1 outcomes.'''
assert self.hp.model_type() == 'classification'
if self.num_outcomes > 1:
for o, out in enumerate(outputs):
_, preds = torch.max(out, 1)
running_corrects[f'out-{o}'] += torch.sum( # type: ignore
preds == labels[f'out-{o}'].data # type: ignore
)
else:
_, preds = torch.max(outputs, 1) # type: ignore
running_corrects += torch.sum(preds == labels.data) # type: ignore
return running_corrects
def _validate_early_stop(self) -> None:
"""Validates early stopping parameters."""
if (self.hp.early_stop and self.hp.early_stop_method == 'accuracy' and
self.hp.model_type() == 'classification' and self.num_outcomes > 1):
raise errors.ModelError("Cannot combine 'accuracy' early stopping "
"with multiple categorical outcomes.")
if (self.hp.early_stop_method == 'manual'
and (self.hp.manual_early_stop_epoch is None
or self.hp.manual_early_stop_batch is None)):
raise errors.ModelError(
"Early stopping method 'manual' requires that both "
"manual_early_stop_epoch and manual_early_stop_batch are set "
"in model params."
)
def _verify_img_format(self, dataset, *datasets: Optional["sf.Dataset"]) -> str:
"""Verify that the image format of the dataset matches the model config.
Args:
dataset (sf.Dataset): Dataset to check.
*datasets (sf.Dataset): Additional datasets to check. May be None.
Returns:
str: Image format, either 'png' or 'jpg', if a consistent image
format was found, otherwise None.
"""
# First, verify all datasets have the same image format
img_formats = set([d.img_format for d in datasets if d])
if len(img_formats) > 1:
log.error("Multiple image formats detected: {}.".format(
', '.join(img_formats)
))
return None
elif self.img_format and not dataset.img_format:
log.warning("Unable to verify image format (PNG/JPG) of dataset.")
return None
elif self.img_format and dataset.img_format != self.img_format:
log.error(
"Mismatched image formats. Expected '{}' per model config, "
"but dataset has format '{}'.".format(
self.img_format,
dataset.img_format))
return None
else:
return dataset.img_format
def load(self, model: str, training=True) -> None:
    """Loads a state dict at the given model location. Requires that the
    Trainer's hyperparameters (Trainer.hp)
    match the hyperparameters of the model to be loaded."""
    # Rebuild the architecture from hyperparameters, then load weights.
    build_kwargs = dict(num_slide_features=self.num_slide_features)
    if self.labels is not None:
        build_kwargs['labels'] = self.labels
    else:
        build_kwargs['num_classes'] = len(self.outcome_names)
    self.model = self.hp.build_model(**build_kwargs)
    self.model.load_state_dict(torch.load(model))
    self.inference_model = self.model
def predict(
    self,
    dataset: "sf.Dataset",
    batch_size: Optional[int] = None,
    norm_fit: Optional["NormFit"] = None,
    format: str = 'parquet',
    from_wsi: bool = False,
    roi_method: str = 'auto',
    reduce_method: Union[str, Callable] = 'average',
) -> Dict[str, "pd.DataFrame"]:
    """Perform inference on a model, saving predictions.

    Args:
        dataset (:class:`slideflow.dataset.Dataset`): Dataset containing
            TFRecords to evaluate.
        batch_size (int, optional): Evaluation batch size. Defaults to the
            same as training (per self.hp)
        norm_fit (Dict[str, np.ndarray]): Normalizer fit, mapping fit
            parameters (e.g. target_means, target_stds) to values
            (np.ndarray). If not provided, will fit normalizer using
            model params (if applicable). Defaults to None.
        format (str, optional): Format in which to save predictions. Either
            'csv', 'feather', or 'parquet'. Defaults to 'parquet'.
        from_wsi (bool): Generate predictions from tiles dynamically
            extracted from whole-slide images, rather than TFRecords.
            Defaults to False (use TFRecords).
        roi_method (str): ROI method to use if from_wsi=True (ignored if
            from_wsi=False). Either 'inside', 'outside', 'auto', 'ignore'.
            If 'inside' or 'outside', will extract tiles in/out of an ROI,
            and raise errors.MissingROIError if an ROI is not available.
            If 'auto', will extract tiles inside an ROI if available,
            and across the whole-slide if no ROI is found.
            If 'ignore', will extract tiles across the whole-slide
            regardless of whether an ROI is available.
            Defaults to 'auto'.
        reduce_method (str, optional): Reduction method for calculating
            slide-level and patient-level predictions for categorical
            outcomes. Options include 'average', 'mean', 'proportion',
            'median', 'sum', 'min', 'max', or a callable function.
            'average' and 'mean' are synonymous. If 'proportion', tile
            predictions are converted to onehot encoding then averaged.
            Other values are applied via ``DataFrame.agg()``.
            Defaults to 'average'.

    Returns:
        Dict[str, pd.DataFrame]: Dictionary with keys 'tile', 'slide', and
        'patient', and values containing DataFrames with tile-, slide-,
        and patient-level predictions.
    """
    if format not in ('csv', 'feather', 'parquet'):
        raise ValueError(f"Unrecognized format {format}")
    self._detect_patients(dataset)
    # Verify image format
    self._verify_img_format(dataset)
    # Fit normalizer
    self._fit_normalizer(norm_fit)
    # Load and initialize model
    if not self.model:
        raise errors.ModelNotLoadedError
    self.model.to(self.device)
    self.model.eval()
    self._log_manifest(None, dataset, labels=None)
    if from_wsi and sf.slide_backend() == 'libvips':
        pool = mp.Pool(
            sf.util.num_cpu(default=8),
            initializer=sf.util.set_ignore_sigint
        )
    elif from_wsi:
        # Fix: `import multiprocessing as mp` does not import the
        # `multiprocessing.dummy` submodule, so `mp.dummy.Pool` raises
        # AttributeError; import the thread pool explicitly.
        from multiprocessing.dummy import Pool as ThreadPool
        pool = ThreadPool(sf.util.num_cpu(default=8))
    else:
        pool = None
    if not batch_size:
        batch_size = self.hp.batch_size
    # NOTE(review): `batch_size` is not forwarded below — the validation
    # dataloader reads self.validation_batch_size; confirm this is intended.
    self._setup_dataloaders(
        train_dts=None,
        val_dts=dataset,
        incl_labels=False,
        from_wsi=from_wsi,
        roi_method=roi_method,
        pool=pool)
    log.info('Generating predictions...')
    torch_args = types.SimpleNamespace(
        num_slide_features=self.num_slide_features,
        slide_input=self.slide_input,
        normalizer=(self.normalizer if self._has_gpu_normalizer() else None),
    )
    dfs = sf.stats.predict_dataset(
        model=self.model,
        dataset=self.dataloaders['val'],
        model_type=self._model_type,
        torch_args=torch_args,
        outcome_names=self.outcome_names,
        uq=bool(self.hp.uq),
        patients=self.patients,
        reduce_method=reduce_method
    )
    # Save predictions
    sf.stats.metrics.save_dfs(dfs, format=format, outdir=self.outdir)
    self._close_dataloaders()
    if pool is not None:
        pool.close()
    return dfs
def evaluate(
    self,
    dataset: "sf.Dataset",
    batch_size: Optional[int] = None,
    save_predictions: Union[bool, str] = 'parquet',
    reduce_method: Union[str, Callable] = 'average',
    norm_fit: Optional["NormFit"] = None,
    uq: Union[bool, str] = 'auto',
    from_wsi: bool = False,
    roi_method: str = 'auto',
):
    """Evaluate model, saving metrics and predictions.

    Args:
        dataset (:class:`slideflow.dataset.Dataset`): Dataset to evaluate.
        batch_size (int, optional): Evaluation batch size. Defaults to the
            same as training (per self.hp)
        save_predictions (bool or str, optional): Save tile, slide, and
            patient-level predictions at each evaluation. May be 'csv',
            'feather', or 'parquet'. If False, will not save predictions.
            Defaults to 'parquet'.
        reduce_method (str, optional): Reduction method for calculating
            slide-level and patient-level predictions for categorical
            outcomes. Options include 'average', 'mean', 'proportion',
            'median', 'sum', 'min', 'max', or a callable function.
            'average' and 'mean' are synonymous. If 'proportion', tile
            predictions are converted to onehot encoding then averaged.
            Other values are applied via ``DataFrame.agg()``.
            Defaults to 'average'.
        norm_fit (Dict[str, np.ndarray]): Normalizer fit, mapping fit
            parameters (e.g. target_means, target_stds) to values
            (np.ndarray). If not provided, will fit normalizer using
            model params (if applicable). Defaults to None.
        uq (bool or str, optional): Enable UQ estimation (for
            applicable models). Defaults to 'auto'.
        from_wsi (bool): Generate predictions from tiles dynamically
            extracted from whole-slide images, rather than TFRecords.
            Defaults to False (use TFRecords).
        roi_method (str): ROI method to use if from_wsi=True (ignored if
            from_wsi=False). Either 'inside', 'outside', 'auto', 'ignore'.
            Defaults to 'auto'.

    Returns:
        Dictionary of evaluation metrics.
    """
    if uq != 'auto':
        if not isinstance(uq, bool):
            raise ValueError(f"Unrecognized value {uq} for uq")
        self.hp.uq = uq
    if batch_size:
        self.validation_batch_size = batch_size
    if not self.model:
        raise errors.ModelNotLoadedError
    if from_wsi and sf.slide_backend() == 'libvips':
        pool = mp.Pool(
            sf.util.num_cpu(default=8),
            initializer=sf.util.set_ignore_sigint
        )
    elif from_wsi:
        # Fix: `import multiprocessing as mp` does not import the
        # `multiprocessing.dummy` submodule, so `mp.dummy.Pool` raises
        # AttributeError; import the thread pool explicitly.
        from multiprocessing.dummy import Pool as ThreadPool
        pool = ThreadPool(sf.util.num_cpu(default=8))
    else:
        pool = None
    self._detect_patients(dataset)
    self._verify_img_format(dataset)
    self._fit_normalizer(norm_fit)
    self.model.to(self.device)
    self.model.eval()
    self.loss_fn = self.hp.get_loss()
    self._log_manifest(None, dataset)
    self._prepare_neptune_run(dataset, 'eval')
    self._setup_dataloaders(
        train_dts=None,
        val_dts=dataset,
        from_wsi=from_wsi,
        roi_method=roi_method,
        pool=pool)
    # Generate performance metrics
    log.info('Performing evaluation...')
    metrics = self._val_metrics(
        label='eval',
        reduce_method=reduce_method,
        save_predictions=save_predictions
    )
    results = {'eval': {
        k: v for k, v in metrics.items() if k != 'val_metrics'
    }}
    results['eval'].update(metrics['val_metrics'])  # type: ignore
    results_str = json.dumps(results['eval'], indent=2, sort_keys=True)
    log.info(f"Evaluation metrics: {results_str}")
    results_log = os.path.join(self.outdir, 'results_log.csv')
    sf.util.update_results_log(results_log, 'eval_model', results)
    if self.neptune_run:
        self.neptune_run['eval/results'] = results['eval']
        self.neptune_run.stop()
    self._close_dataloaders()
    if pool is not None:
        pool.close()
    return results
def train(
    self,
    train_dts: "sf.Dataset",
    val_dts: "sf.Dataset",
    log_frequency: int = 20,
    validate_on_batch: int = 0,
    validation_batch_size: Optional[int] = None,
    validation_steps: int = 50,
    starting_epoch: int = 0,
    ema_observations: int = 20,
    ema_smoothing: int = 2,
    use_tensorboard: bool = True,
    steps_per_epoch_override: int = 0,
    save_predictions: Union[bool, str] = 'parquet',
    save_model: bool = True,
    resume_training: Optional[str] = None,
    pretrain: Optional[str] = 'imagenet',
    checkpoint: Optional[str] = None,
    save_checkpoints: bool = False,
    multi_gpu: bool = False,
    norm_fit: Optional["NormFit"] = None,
    reduce_method: Union[str, Callable] = 'average',
    seed: int = 0,
    from_wsi: bool = False,
    roi_method: str = 'auto',
) -> Dict[str, Any]:
    """Builds and trains a model from hyperparameters.

    Args:
        train_dts (:class:`slideflow.dataset.Dataset`): Training dataset.
        val_dts (:class:`slideflow.dataset.Dataset`): Validation dataset.
        log_frequency (int, optional): How frequent to update Tensorboard
            logs, in batches. Defaults to 20.
        validate_on_batch (int, optional): Validation will be performed
            every N batches. Defaults to 0 (epoch-end only).
        validation_batch_size (int, optional): Validation batch size.
            Defaults to same as training (per self.hp).
        validation_steps (int, optional): Number of batches to use for each
            instance of validation. Defaults to 50.
        starting_epoch (int, optional): Starts training at this epoch.
            Defaults to 0.
        ema_observations (int, optional): Number of observations over which
            to perform exponential moving average smoothing.
            Defaults to 20.
        ema_smoothing (int, optional): Exponential average smoothing value.
            Defaults to 2.
        use_tensorboard (bool, optional): Enable tensorboard callbacks.
            Defaults to True.
        steps_per_epoch_override (int, optional): Manually set the number
            of steps per epoch. Defaults to 0 (calculate from dataset).
        save_predictions (bool or str, optional): Save tile, slide, and
            patient-level predictions at each evaluation. May be 'csv',
            'feather', or 'parquet'. If False, will not save predictions.
            Defaults to 'parquet'.
        save_model (bool, optional): Save models when evaluating at
            specified epochs. Defaults to True.
        resume_training (str, optional): Not applicable to PyTorch backend.
            Included as argument for compatibility with Tensorflow backend.
            Will raise NotImplementedError if supplied.
        pretrain (str, optional): Either 'imagenet' or path to Tensorflow
            model from which to load weights. Defaults to 'imagenet'.
        checkpoint (str, optional): Path to cp.ckpt from which to load
            weights. Defaults to None.
        norm_fit (Dict[str, np.ndarray]): Normalizer fit, mapping fit
            parameters (e.g. target_means, target_stds) to values
            (np.ndarray). If not provided, will fit normalizer using
            model params (if applicable). Defaults to None.
        reduce_method (str, optional): Reduction method for calculating
            slide-level and patient-level predictions for categorical
            outcomes. Options include 'average', 'mean', 'proportion',
            'median', 'sum', 'min', 'max', or a callable function.
            'average' and 'mean' are synonymous. If 'proportion', tile
            predictions are converted to onehot encoding then averaged.
            Other values are applied via ``DataFrame.agg()``.
            Defaults to 'average'.
        seed (int): Set numpy random seed. Defaults to 0.
        from_wsi (bool): Generate predictions from tiles dynamically
            extracted from whole-slide images, rather than TFRecords.
            Defaults to False (use TFRecords).
        roi_method (str): ROI method to use if from_wsi=True (ignored if
            from_wsi=False). Either 'inside', 'outside', 'auto', 'ignore'.
            Defaults to 'auto'.

    Returns:
        Dict: Nested dict containing metrics for each evaluated epoch.
    """
    if resume_training is not None:
        raise NotImplementedError(
            "PyTorch backend does not support `resume_training`; "
            "please use `checkpoint`"
        )
    if save_checkpoints:
        log.warning(
            "The argument save_checkpoints is ignored when training models "
            "in the PyTorch backend. To save a model throughout training, "
            "use the `epochs` hyperparameter."
        )
    results = {'epochs': defaultdict(dict)}  # type: Dict[str, Any]
    starting_epoch = max(starting_epoch, 1)
    self._detect_patients(train_dts, val_dts)
    self._reset_training_params()
    self.validation_batch_size = validation_batch_size
    self.validate_on_batch = validate_on_batch
    self.validation_steps = validation_steps
    self.ema_observations = ema_observations
    self.ema_smoothing = ema_smoothing
    self.log_frequency = log_frequency
    self.use_tensorboard = use_tensorboard
    # Verify image format across datasets.
    img_format = self._verify_img_format(train_dts, val_dts)
    if img_format and self.config['img_format'] is None:
        self.config['img_format'] = img_format
        sf.util.write_json(self.config, join(self.outdir, 'params.json'))
    if self.use_tensorboard:
        from google.protobuf import __version__ as protobuf_version
        if version.parse(protobuf_version) >= version.parse('3.21'):
            # Fix: added the missing space between sentences in this
            # warning (previously rendered "...3.21.Downgrade...").
            log.warning(
                "Tensorboard is incompatible with protobuf >= 3.21. "
                "Downgrade protobuf to enable tensorboard logging."
            )
            self.use_tensorboard = False
    if from_wsi and sf.slide_backend() == 'libvips':
        pool = mp.Pool(
            sf.util.num_cpu(default=8),
            initializer=sf.util.set_ignore_sigint
        )
    elif from_wsi:
        # Fix: `import multiprocessing as mp` does not import the
        # `multiprocessing.dummy` submodule, so `mp.dummy.Pool` raises
        # AttributeError; import the thread pool explicitly.
        from multiprocessing.dummy import Pool as ThreadPool
        pool = ThreadPool(sf.util.num_cpu(default=8))
    else:
        pool = None
    # Validate early stopping parameters
    self._validate_early_stop()
    # Fit normalizer to dataset, if applicable
    self._fit_normalizer(norm_fit)
    if self.normalizer and self.hp.normalizer_source == 'dataset':
        self.normalizer.fit(train_dts)
    if self.normalizer:
        # Persist the normalizer fit alongside model params.
        config_path = join(self.outdir, 'params.json')
        if not os.path.exists(config_path):
            config = {
                'slideflow_version': sf.__version__,
                'hp': self.hp.to_dict(),
                'backend': sf.backend()
            }
        else:
            config = sf.util.load_json(config_path)
        config['norm_fit'] = self.normalizer.get_fit(as_list=True)
        sf.util.write_json(config, config_path)
    # Training preparation
    if steps_per_epoch_override:
        self.steps_per_epoch = steps_per_epoch_override
        log.info(f"Setting steps per epoch = {steps_per_epoch_override}")
    else:
        self.steps_per_epoch = train_dts.num_tiles // self.hp.batch_size
        log.info(f"Steps per epoch = {self.steps_per_epoch}")
    if self.use_tensorboard:
        # Delayed import due to protobuf version conflicts.
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(self.outdir, flush_secs=60)
    self._log_manifest(train_dts, val_dts)
    # Prepare neptune run
    self._prepare_neptune_run(train_dts, 'train')
    # Build model
    self._build_model(checkpoint, pretrain)
    assert self.model is not None
    # Print model summary
    self._print_model_summary(train_dts)
    # Multi-GPU
    if multi_gpu:
        self.model = torch.nn.DataParallel(self.model)
    self.model = self.model.to(self.device)
    # Setup dataloaders
    self._setup_dataloaders(
        train_dts=train_dts,
        val_dts=val_dts,
        mid_train_val=True,
        roi_method=roi_method,
        from_wsi=from_wsi,
        pool=pool)
    # Model parameters and optimizer
    self._prepare_optimizers_and_loss()
    # === Epoch loop ======================================================
    for self.epoch in range(starting_epoch, max(self.hp.epochs)+1):
        np.random.seed(seed+self.epoch)
        log.info(f'[bold]Epoch {self.epoch}/{max(self.hp.epochs)}')
        # Training loop ---------------------------------------------------
        self.epoch_records = 0
        self.running_loss = 0.0
        self.step = 1
        self.running_corrects = self._empty_corrects()  # type: ignore
        self.model.train()
        pb = Progress(
            *Progress.get_default_columns(),
            TimeElapsedColumn(),
            ImgBatchSpeedColumn(self.hp.batch_size),
            transient=sf.getLoggingLevel()>20
        )
        # Task id 0 is referenced by _training_step's pb.update/advance.
        pb.add_task("Training...", total=self.steps_per_epoch)
        pb.start()
        with sf.util.cleanup_progress(pb):
            while self.step <= self.steps_per_epoch:
                self._training_step(pb)
                if self.early_stop:
                    break
                self._mid_training_validation()
                self.step += 1
                self.global_step += 1
        # Update and log epoch metrics ------------------------------------
        loss = self.running_loss / self.epoch_records
        epoch_metrics = {'train_metrics': {'loss': loss}}
        if self.hp.model_type() == 'classification':
            acc, acc_desc = self._calculate_accuracy(
                self.running_corrects, self.epoch_records
            )
            epoch_metrics['train_metrics'].update({
                'accuracy': self._accuracy_as_numpy(acc)  # type: ignore
            })
        else:
            acc, acc_desc = 0, ''  # type: ignore
        results['epochs'][f'epoch{self.epoch}'].update(epoch_metrics)
        self._log_epoch('train', self.epoch, loss, acc_desc)
        self._log_to_neptune(loss, acc, 'train', 'epoch')
        if save_model and (self.epoch in self.hp.epochs or self.early_stop):
            self._save_model()
        # Full evaluation -------------------------------------------------
        # Perform full evaluation if the epoch is one of the
        # predetermined epochs at which to save/eval a model
        if 'val' in self.dataloaders and self.epoch in self.hp.epochs:
            epoch_res = self._val_metrics(
                save_predictions=save_predictions,
                reduce_method=reduce_method,
                label=f'val_epoch{self.epoch}',
            )
            results['epochs'][f'epoch{self.epoch}'].update(epoch_res)
        # Early stopping --------------------------------------------------
        if self.early_stop:
            break
    # === [end epoch loop] ================================================
    if self.neptune_run:
        self.neptune_run['sys/tags'].add('training_complete')
        self.neptune_run.stop()
    self._close_dataloaders()
    if pool is not None:
        pool.close()
    return results
class RegressionTrainer(Trainer):
    """Trainer subclass for continuous (regression) outcomes.

    Extends :class:`slideflow.model.Trainer` for models whose outcomes are
    all continuous, trained with an appropriate regression loss and
    evaluated with R-squared rather than AUROC.

    In the PyTorch backend, continuous-outcome support is already built
    into the base Trainer, so this subclass adds no behavior; it exists to
    mirror the Tensorflow backend's class hierarchy.
    """

    _model_type = 'regression'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
class SurvivalTrainer(Trainer):
    """Placeholder trainer for Cox proportional hazards (CPH) models.

    CPH (survival) models are not yet implemented in the PyTorch backend,
    but are planned for a future update.

    Raises:
        NotImplementedError: Always, upon instantiation.
    """

    def __init__(self, *args, **kwargs):
        # Raise with an explicit message so the failure is self-explanatory
        # to users who reach this via the backend-agnostic trainer factory.
        raise NotImplementedError(
            "CPH (survival) models are not yet implemented in the "
            "PyTorch backend."
        )
# -----------------------------------------------------------------------------
class Features(BaseFeatureExtractor):
    """Interface for obtaining predictions and features from intermediate layer
    activations from Slideflow models.

    Use by calling on either a batch of images (returning outputs for a single
    batch), or by calling on a :class:`slideflow.WSI` object, which will
    generate an array of spatially-mapped activations matching the slide.

    Examples
        *Calling on batch of images:*

        .. code-block:: python

            interface = Features('/model/path', layers='postconv')
            for image_batch in train_data:
                # Return shape: (batch_size, num_features)
                batch_features = interface(image_batch)

        *Calling on a slide:*

        .. code-block:: python

            slide = sf.slide.WSI(...)
            interface = Features('/model/path', layers='postconv')
            # Return shape:
            # (slide.grid.shape[0], slide.grid.shape[1], num_features)
            activations_grid = interface(slide)

    Note:
        When this interface is called on a batch of images, no image processing
        or stain normalization will be performed, as it is assumed that
        normalization will occur during data loader image processing. When the
        interface is called on a `slideflow.WSI`, the normalization strategy
        will be read from the model configuration file, and normalization will
        be performed on image tiles extracted from the WSI. If this interface
        was created from an existing model and there is no model configuration
        file to read, a slideflow.norm.StainNormalizer object may be passed
        during initialization via the argument `wsi_normalizer`.
    """

    def __init__(
        self,
        path: Optional[str],
        layers: Optional[Union[str, List[str]]] = 'postconv',
        *,
        include_preds: bool = False,
        mixed_precision: bool = True,
        channels_last: bool = True,
        device: Optional[torch.device] = None,
        apply_softmax: Optional[bool] = None,
        pooling: Optional[Any] = None,
        load_method: str = 'weights',
    ):
        """Creates an activations interface from a saved slideflow model which
        outputs feature activations at the designated layers.

        Intermediate layers are returned in the order of layers.
        predictions are returned last.

        Args:
            path (str): Path to saved Slideflow model.
            layers (list(str), optional): Layers from which to generate
                activations. The post-convolution activation layer is accessed
                via 'postconv'. Defaults to 'postconv'.
            include_preds (bool, optional): Include predictions in output. Will be
                returned last. Defaults to False.
            mixed_precision (bool, optional): Use mixed precision.
                Defaults to True.
            channels_last (bool, optional): Use channels-last memory format
                for model input. Defaults to True.
            device (:class:`torch.device`, optional): Device for model.
                Defaults to torch.device('cuda')
            apply_softmax (bool): Apply softmax transformation to model output.
                Defaults to True for classification models, False for regression models.
            pooling (Callable or str, optional): PyTorch pooling function to use
                on feature layers. May be a string ('avg' or 'max') or a
                callable PyTorch function.
            load_method (str): Loading method to use when reading model.
                This argument is ignored in the PyTorch backend, as all models
                are loaded by first building the model with hyperparameters
                detected in ``params.json``, then loading weights with
                ``torch.nn.Module.load_state_dict()``. Defaults to
                'weights' (ignored).
        """
        super().__init__('torch', include_preds=include_preds)
        # A single layer name is shorthand for a one-element list.
        if layers and isinstance(layers, str):
            layers = [layers]
        self.layers = layers
        self.path = path
        self.apply_softmax = apply_softmax
        self.mixed_precision = mixed_precision
        self.channels_last = channels_last
        self._model = None
        self._pooling = None
        self._include_preds = None
        # Transformation for standardizing uint8 images to float32
        self.transform = torchvision.transforms.Lambda(lambda x: x / 127.5 - 1)
        # Hook for storing layer activations during model inference
        self.activation = {}  # type: Dict[Any, Tensor]
        # Configure device
        self.device = torch_utils.get_device(device)
        if path is not None:
            # Rebuild the model from the saved hyperparameters, then load
            # weights (see `load_method` note above).
            config = sf.util.get_model_config(path)
            if 'img_format' in config:
                self.img_format = config['img_format']
            self.hp = ModelParams()  # type: Optional[ModelParams]
            self.hp.load_dict(config['hp'])
            self.wsi_normalizer = self.hp.get_normalizer()
            if 'norm_fit' in config and config['norm_fit'] is not None:
                self.wsi_normalizer.set_fit(**config['norm_fit'])  # type: ignore
            self.tile_px = self.hp.tile_px
            self._model = self.hp.build_model(
                num_classes=len(config['outcome_labels'])
            )
            if apply_softmax is None:
                # Softmax only makes sense for classification outputs.
                self.apply_softmax = (config['model_type'] == 'classification')
                log.debug(f"Using apply_softmax={self.apply_softmax}")
            # Map weights onto the configured device so checkpoints saved
            # on GPU can also be loaded on CPU-only machines.
            self._model.load_state_dict(
                torch.load(path, map_location=self.device)
            )
            self._model.to(self.device)
            self._model.eval()
            # Record the architecture name, unwrapping ModelWrapper if used.
            if self._model.__class__.__name__ == 'ModelWrapper':
                self.model_type = self._model.model.__class__.__name__
            else:
                self.model_type = self._model.__class__.__name__
            self._build(pooling=pooling)
@classmethod
def from_model(
    cls,
    model: torch.nn.Module,
    tile_px: int,
    layers: Optional[Union[str, List[str]]] = 'postconv',
    *,
    include_preds: Optional[bool] = False,
    mixed_precision: bool = True,
    channels_last: bool = True,
    wsi_normalizer: Optional["StainNormalizer"] = None,
    apply_softmax: bool = True,
    pooling: Optional[Any] = None
):
    """Creates an activations interface from a loaded slideflow model which
    outputs feature activations at the designated layers.

    Intermediate layers are returned in the order of layers.
    predictions are returned last.

    Args:
        model (:class:`torch.nn.Module`): Loaded PyTorch model.
        tile_px (int): Width/height of input image size.
        layers (list(str), optional): Layers from which to generate
            activations. The post-convolution activation layer is accessed
            via 'postconv'. Defaults to 'postconv'.
        include_preds (bool, optional): Include predictions in output. Will be
            returned last. ``None`` is an internal sentinel meaning "do not
            forward this argument to ``__init__``" (used by subclasses such
            as UncertaintyInterface). Defaults to False.
        mixed_precision (bool, optional): Use mixed precision.
            Defaults to True.
        channels_last (bool, optional): Use channels-last memory format.
            Defaults to True.
        wsi_normalizer (:class:`slideflow.norm.StainNormalizer`): Stain
            normalizer to use on whole-slide images. Is not used on
            individual tile datasets via __call__. Defaults to None.
        apply_softmax (bool): Apply softmax transformation to model output.
            Defaults to True.
        pooling (Callable or str, optional): PyTorch pooling function to use
            on feature layers. May be a string ('avg' or 'max') or a
            callable PyTorch function.

    Raises:
        errors.ModelError: If ``model`` is not a ``torch.nn.Module``.
    """
    # Run inference on whichever device the model already lives on.
    device = next(model.parameters()).device
    # `include_preds=None` sentinel: omit the kwarg entirely so that a
    # subclass __init__ without an `include_preds` parameter still works.
    if include_preds is not None:
        kw = dict(include_preds=include_preds)
    else:
        kw = dict()
    # Path=None skips the checkpoint-loading branch of __init__.
    obj = cls(
        None,
        layers,
        mixed_precision=mixed_precision,
        channels_last=channels_last,
        device=device,
        **kw
    )
    if isinstance(model, torch.nn.Module):
        obj._model = model
        obj._model.eval()
    else:
        raise errors.ModelError("Model is not a valid PyTorch model.")
    obj.hp = None
    # Record the architecture name, unwrapping ModelWrapper if used.
    if obj._model.__class__.__name__ == 'ModelWrapper':
        obj.model_type = obj._model.model.__class__.__name__
    else:
        obj.model_type = obj._model.__class__.__name__
    obj.tile_px = tile_px
    obj.wsi_normalizer = wsi_normalizer
    obj.apply_softmax = apply_softmax
    obj._build(pooling=pooling)
    return obj
def __call__(
    self,
    inp: Union[Tensor, "sf.WSI"],
    **kwargs
) -> Optional[Union[List[Tensor], np.ndarray]]:
    """Compute activations and/or predictions for the given input.

    Accepts either a batch of images or a :class:`slideflow.slide.WSI`
    object. When called on a `WSI`, keyword arguments are forwarded to
    :meth:`slideflow.WSI.build_generator()`.
    """
    # Dispatch on input type: whole-slide images get the spatial-grid
    # path; everything else is treated as an image batch.
    if isinstance(inp, sf.slide.WSI):
        handler = self._predict_slide
    else:
        handler = self._predict
    return handler(inp, **kwargs)
def __repr__(self):
    """Readable, multi-line summary of the interface configuration."""
    parts = [
        "{}(\n".format(self.__class__.__name__),
        " path={!r},\n".format(self.path),
        " layers={!r},\n".format(self.layers),
        " include_preds={!r},\n".format(self.include_preds),
        " apply_softmax={!r},\n".format(self.apply_softmax),
        " pooling={!r},\n".format(self._pooling),
        ")",
    ]
    return "".join(parts)
def _predict_slide(
    self,
    slide: "sf.WSI",
    *,
    img_format: str = 'auto',
    batch_size: int = 32,
    dtype: type = np.float16,
    grid: Optional[np.ndarray] = None,
    shuffle: bool = False,
    show_progress: bool = True,
    callback: Optional[Callable] = None,
    normalizer: Optional[Union[str, "StainNormalizer"]] = None,
    normalizer_source: Optional[str] = None,
    **kwargs
) -> Optional[np.ndarray]:
    """Generate activations from slide => activation grid array.

    Args:
        slide (sf.WSI): Whole-slide image to process.

    Keyword Args:
        img_format (str): Tile image format ('png'/'jpg'). 'auto' reads
            the format from the model configuration, raising ValueError
            if none was recorded.
        normalizer: Stain normalizer; if not given, the normalizer loaded
            from the model configuration (``self.wsi_normalizer``) is used.
        All other keyword arguments are forwarded to
        :func:`sf.model.extractors.features_from_slide`.

    Returns:
        np.ndarray: Spatially-mapped activation grid, or None.
    """
    # Check image format
    if img_format == 'auto' and self.img_format is None:
        raise ValueError(
            'Unable to auto-detect image format (png or jpg). Set the '
            'format by passing img_format=... to the call function.'
        )
    elif img_format == 'auto':
        assert self.img_format is not None
        img_format = self.img_format
    # Delegate tile extraction and batched inference to the shared helper,
    # with this interface as the extractor and self.transform
    # (uint8 -> [-1, 1] float) as tile preprocessing.
    return sf.model.extractors.features_from_slide(
        self,
        slide,
        img_format=img_format,
        batch_size=batch_size,
        dtype=dtype,
        grid=grid,
        shuffle=shuffle,
        show_progress=show_progress,
        callback=callback,
        normalizer=(normalizer if normalizer else self.wsi_normalizer),
        normalizer_source=normalizer_source,
        preprocess_fn=self.transform,
        **kwargs
    )
def _predict(self, inp: Tensor, no_grad: bool = True) -> List[Tensor]:
    """Return activations for a single batch of images.

    Args:
        inp (Tensor): Float image batch (already preprocessed; no
            normalization is applied here).
        no_grad (bool): Run the forward pass under
            ``torch.inference_mode()``. Defaults to True.

    Returns:
        List of activation tensors, one per layer in ``self.layers``
        (in order), with predictions appended last when
        ``self.include_preds`` is True.
    """
    assert torch.is_floating_point(inp), "Input tensor must be float"
    # Mixed precision (autocast) is only used on CUDA and CPU devices.
    _mp = (self.mixed_precision and self.device.type in ('cuda', 'cpu'))
    with autocast(self.device.type, mixed_precision=_mp):  # type: ignore
        with torch.inference_mode() if no_grad else no_scope():
            inp = inp.to(self.device)
            if self.channels_last:
                # Channels-last memory format to match the model's layout.
                inp = inp.to(memory_format=torch.channels_last)
            logits = self._model(inp)
            # Multi-outcome models return a tuple/list of logits; apply
            # softmax to each output head individually.
            if isinstance(logits, (tuple, list)) and self.apply_softmax:
                logits = [softmax(l, dim=1) for l in logits]
            elif self.apply_softmax:
                logits = softmax(logits, dim=1)
            # Layer activations were captured during the forward pass by
            # the hooks registered in _build(), keyed by layer name.
            layer_activations = []
            if self.layers:
                for la in self.layers:
                    act = self.activation[la]
                    if la == 'postconv':
                        act = self._postconv_processing(act)
                    layer_activations.append(act)
            if self.include_preds:
                layer_activations += [logits]
            # Clear stored activations so stale values are never reused.
            self.activation = {}
            return layer_activations
def _get_postconv(self):
    """Return the module treated as the post-convolutional layer for the
    wrapped architecture, used as the 'postconv' hook target."""
    m = self._model
    mtype = self.model_type
    if mtype == 'ViT':
        return m.to_latent
    if mtype in ('ResNet', 'Inception3', 'GoogLeNet'):
        return m.avgpool
    if mtype in ('AlexNet', 'SqueezeNet', 'VGG', 'MobileNetV2',
                 'MobileNetV3', 'MNASNet'):
        # If the classifier has been replaced with an Identity, hook it
        # directly; otherwise hook its first submodule.
        clf = m.classifier
        if clf.__class__.__name__ == 'Identity':
            return clf
        return next(clf.children())
    if mtype == 'DenseNet':
        return m.features.norm5
    if mtype == 'ShuffleNetV2':
        return list(m.conv5.children())[1]
    if mtype == 'Xception':
        return m.bn4
    raise errors.FeaturesError(f"'postconv' layer not configured for "
                               f"model type {mtype}")
def _postconv_processing(self, output: Tensor) -> Tensor:
    """Reshape a raw 'postconv' activation to (batch_size, num_features).

    Depending on the architecture, the postconv output may already be
    flat, may only need flattening, or may need adaptive average pooling
    followed by flattening.
    """
    def flatten(t):
        return t.view(t.size(0), -1)

    if self.model_type in ('ResNet', 'Inception3', 'GoogLeNet'):
        return flatten(output)
    if self.model_type in ('SqueezeNet', 'DenseNet', 'ShuffleNetV2',
                           'Xception'):
        pooled = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        return flatten(pooled)
    # ViT, AlexNet, VGG, MobileNetV2/V3 and MNASNet (and any unlisted
    # architecture) are returned unchanged.
    return output
def _build(self, pooling: Optional[Any] = None) -> None:
    """Builds the interface model that outputs feature activations at the
    designated layers and/or predictions. Intermediate layers are returned in
    the order of layers. predictions are returned last.

    Args:
        pooling (Callable or str, optional): PyTorch pooling function to use
            on feature layers. May be a string ('avg' or 'max') or a
            callable PyTorch function.

    Raises:
        ValueError: If ``pooling`` is a string other than 'avg' or 'max'.
        errors.FeaturesError: If ``self.layers`` is neither a list nor None.
    """
    self._pooling = pooling
    # Resolve string shortcuts into callable adaptive 2D pools.
    if isinstance(pooling, str):
        if pooling == 'avg':
            pooling = lambda x: torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
        elif pooling == 'max':
            pooling = lambda x: torch.nn.functional.adaptive_max_pool2d(x, (1, 1))
        else:
            raise ValueError(f"Unrecognized pooling value {pooling}. "
                             "Expected 'avg', 'max', or custom Tensor op.")
    self.activation = {}

    def squeeze(x):
        # Flatten to (batch, -1).
        return x.view(x.size(0), -1)

    def get_activation(name):
        # Build a forward hook that stores the layer's output under `name`.
        def hook(model, input, output):
            # 4-D (B, C, H, W) outputs are pooled + flattened when a
            # pooling function was supplied; otherwise stored as-is.
            if len(output.shape) == 4 and pooling is not None:
                self.activation[name] = squeeze(pooling(output)).detach()
            else:
                self.activation[name] = output.detach()
        return hook

    # Register an activation-capturing hook on each requested layer.
    if isinstance(self.layers, list):
        for la in self.layers:
            if la == 'postconv':
                self._get_postconv().register_forward_hook(
                    get_activation('postconv')
                )
            else:
                la_out = torch_utils.get_module_by_name(self._model, la)
                la_out.register_forward_hook(
                    get_activation(la)
                )
    elif self.layers is not None:
        raise errors.FeaturesError(f"Unrecognized type {type(self.layers)}"
                                   " for self.layers")
    # Calculate output and layer sizes with a single dry-run forward pass;
    # the hooks above populate self.activation as a side effect.
    rand_data = torch.rand(1, 3, self.tile_px, self.tile_px)
    output = self._model(rand_data.to(self.device))
    if isinstance(output, (tuple, list)) and self.include_preds:
        log.warning("Multi-categorical outcomes is experimental "
                    "for this interface.")
        self.num_classes = sum(o.shape[1] for o in output)
        self.num_outputs = len(output)
    elif self.include_preds:
        self.num_classes = output.shape[1]
        self.num_outputs = 1
    else:
        self.num_classes = 0
        self.num_outputs = 0
    # NOTE(review): assumes each captured activation is 2-D
    # (batch, features) — true for pooled/flattened hooks; confirm for
    # custom layers hooked without pooling.
    self.num_features = sum([f.shape[1] for f in self.activation.values()])
    if self.include_preds:
        log.debug(f'Number of classes: {self.num_classes}')
    log.debug(f'Number of activation features: {self.num_features}')
def dump_config(self):
    """Return a JSON-serializable dict describing how to re-create
    this interface."""
    kwargs = dict(
        path=self.path,
        layers=self.layers,
        include_preds=self.include_preds,
        apply_softmax=self.apply_softmax,
        pooling=self._pooling,
    )
    return {'class': 'slideflow.model.torch.Features', 'kwargs': kwargs}
class UncertaintyInterface(Features):
    """Features interface that also estimates prediction uncertainty.

    Keeps dropout layers active at inference (MC dropout) so repeated
    forward passes sample from the dropout distribution; predictions are
    averaged across passes and their standard deviation is reported as
    uncertainty. Predictions are always included in the output.
    """

    def __init__(
        self,
        path: Optional[str],
        layers: Optional[Union[str, List[str]]] = 'postconv',
        *,
        mixed_precision: bool = True,
        channels_last: bool = True,
        device: Optional[torch.device] = None,
        apply_softmax: Optional[bool] = None,
        pooling: Optional[Any] = None,
        load_method: str = 'weights',
    ) -> None:
        """Create an uncertainty-estimating interface from a saved model.

        Accepts the same arguments as :class:`Features`, except that
        ``include_preds`` is forced to True (uncertainty estimation
        requires predictions).
        """
        super().__init__(
            path,
            layers=layers,
            mixed_precision=mixed_precision,
            channels_last=channels_last,
            device=device,
            apply_softmax=apply_softmax,
            pooling=pooling,
            load_method=load_method,
            include_preds=True
        )
        # Keep dropout active during inference for MC-dropout sampling.
        if self._model is not None:
            torch_utils.enable_dropout(self._model)
        # TODO: As the below to-do suggests, this should be updated
        # for multi-class
        self.num_uncertainty = 1
        if self.num_classes > 2:
            # `log.warn` is a deprecated alias; use `log.warning`.
            log.warning(
                "UncertaintyInterface not yet implemented for multi-class"
                " models")
@classmethod
def from_model(cls, *args, **kwargs):
    """Create an UncertaintyInterface from a loaded PyTorch model.

    Accepts the same arguments as :meth:`Features.from_model`, except
    that ``include_preds`` must not be False, as uncertainty estimation
    requires predictions.

    Raises:
        ValueError: If ``include_preds`` is explicitly set to a falsy value.
    """
    if 'include_preds' in kwargs and not kwargs['include_preds']:
        raise ValueError("UncertaintyInterface requires include_preds=True")
    # None is a sentinel understood by Features.from_model: it omits the
    # include_preds kwarg entirely, since this class's __init__ does not
    # accept it (it hard-codes include_preds=True).
    kwargs['include_preds'] = None
    obj = super().from_model(*args, **kwargs)
    # Keep dropout active at inference for MC-dropout uncertainty.
    torch_utils.enable_dropout(obj._model)
    return obj
def __repr__(self):
    """Readable, multi-line summary of the interface configuration."""
    parts = [
        "{}(\n".format(self.__class__.__name__),
        " path={!r},\n".format(self.path),
        " layers={!r},\n".format(self.layers),
        " apply_softmax={!r},\n".format(self.apply_softmax),
        " pooling={!r},\n".format(self._pooling),
        ")",
    ]
    return "".join(parts)
def _predict(self, inp: Tensor, no_grad: bool = True) -> List[Tensor]:
    """Return activations (mean), predictions (mean), and uncertainty
    (stdev) for a single batch of images, via MC dropout.

    Args:
        inp (Tensor): Float image batch.
        no_grad (bool): Run forward passes under
            ``torch.inference_mode()``. Defaults to True.

    Returns:
        List of mean layer activations (in the order of ``self.layers``),
        followed by mean predictions and uncertainty (stdev of the first
        outcome category across dropout passes).
    """
    assert torch.is_floating_point(inp), "Input tensor must be float"
    # Mixed precision (autocast) is only used on CUDA and CPU devices.
    _mp = (self.mixed_precision and self.device.type in ('cuda', 'cpu'))
    n_passes = 30  # number of MC-dropout forward passes
    out_pred_drop = [[] for _ in range(self.num_outputs)]
    if self.layers:
        out_act_drop = [[] for _ in range(len(self.layers))]
    for _ in range(n_passes):
        with autocast(self.device.type, mixed_precision=_mp):  # type: ignore
            with torch.inference_mode() if no_grad else no_scope():
                inp = inp.to(self.device)
                if self.channels_last:
                    inp = inp.to(memory_format=torch.channels_last)
                logits = self._model(inp)
                if isinstance(logits, (tuple, list)) and self.apply_softmax:
                    logits = [softmax(l, dim=1) for l in logits]
                elif self.apply_softmax:
                    logits = softmax(logits, dim=1)
                # Accumulate this pass's predictions per output head.
                for n in range(self.num_outputs):
                    out_pred_drop[n].append(
                        logits[n] if self.num_outputs > 1 else logits
                    )
                # Accumulate this pass's layer activations (captured by
                # the forward hooks registered in _build()).
                layer_activations = []
                if self.layers:
                    for la in self.layers:
                        act = self.activation[la]
                        if la == 'postconv':
                            act = self._postconv_processing(act)
                        layer_activations.append(act)
                    for n in range(len(self.layers)):
                        out_act_drop[n].append(layer_activations[n])
                self.activation = {}
    # Stack per-pass predictions: (n_passes, batch, classes) per head.
    for n in range(self.num_outputs):
        out_pred_drop[n] = torch.stack(out_pred_drop[n], dim=0)
    predictions = torch.mean(torch.cat(out_pred_drop), dim=0)
    # TODO: Only takes STDEV from first outcome category which works for
    # outcomes with 2 categories, but a better solution is needed
    # for num_categories > 2
    uncertainty = torch.std(torch.cat(out_pred_drop), dim=0)[:, 0]
    uncertainty = torch.unsqueeze(uncertainty, dim=-1)
    if self.layers:
        # Bug fix: was `range(self.layers)`, which raises TypeError when
        # range() receives a list; iterate over layer indices instead.
        for n in range(len(self.layers)):
            out_act_drop[n] = torch.stack(out_act_drop[n], dim=0)
        reduced_activations = [
            torch.mean(out_act_drop[n], dim=0)
            for n in range(len(self.layers))
        ]
        return reduced_activations + [predictions, uncertainty]
    else:
        # Return a list (not a tuple) for consistency with the annotated
        # return type and the branch above.
        return [predictions, uncertainty]
def dump_config(self):
    """Return a JSON-serializable dict describing how to re-create
    this interface."""
    kwargs = dict(
        path=self.path,
        layers=self.layers,
        apply_softmax=self.apply_softmax,
        pooling=self._pooling,
    )
    return {
        'class': 'slideflow.model.torch.UncertaintyInterface',
        'kwargs': kwargs,
    }
# -----------------------------------------------------------------------------
def load(path: str) -> torch.nn.Module:
    """Load a model trained with Slideflow.

    Args:
        path (str): Path to saved model. Must be a model trained in Slideflow.

    Returns:
        torch.nn.Module: Loaded model (on CPU), with weights restored.
    """
    config = sf.util.get_model_config(path)
    hp = ModelParams.from_dict(config['hp'])
    # Single-outcome (or regression) models take a scalar class count;
    # multi-outcome classification models take a dict of counts per outcome.
    if len(config['outcomes']) == 1 or config['model_type'] == 'regression':
        num_classes = len(list(config['outcome_labels'].keys()))
    else:
        num_classes = {
            outcome: len(list(config['outcome_labels'][outcome].keys()))
            for outcome in config['outcomes']
        }
    model = hp.build_model(
        num_classes=num_classes,
        num_slide_features=0 if not config['input_feature_sizes'] else sum(config['input_feature_sizes']),
        pretrain=None
    )
    # Always deserialize weights to CPU: the freshly built model lives on
    # CPU, and this also allows GPU-saved checkpoints to load on machines
    # without CUDA (previously map_location was only set when CUDA was
    # unavailable, needlessly staging tensors on GPU otherwise).
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return model
def lazy_load_pretrained(
    module: torch.nn.Module,
    to_load: str
) -> None:
    """Loads pretrained model weights into an existing module, ignoring
    incompatible Tensors.

    Pretrained Tensors are matched to destination parameters by name; a
    Tensor is only copied when the destination has a parameter of the same
    name and shape. All other parameters keep their current values.

    Args:
        module (torch.nn.Module): Destination module for weights.
        to_load (str, torch.nn.Module): Module with weights to load. Either
            path to PyTorch Slideflow model, or an existing PyTorch module.

    Returns:
        None
    """
    # Get state dictionaries
    current_model_dict = module.state_dict()
    if isinstance(to_load, str):
        loaded_state_dict = torch.load(to_load)
    else:
        loaded_state_dict = to_load.state_dict()
    # Only transfer states whose name and shape both match. (Matching by
    # name fixes the previous positional zip of current keys against
    # loaded values, which silently mispaired tensors whenever the two
    # state dicts differed in ordering or length.)
    new_state_dict = {}
    n_states = 0
    for key, current_tensor in current_model_dict.items():
        loaded_tensor = loaded_state_dict.get(key)
        if (loaded_tensor is not None
                and loaded_tensor.size() == current_tensor.size()):
            new_state_dict[key] = loaded_tensor
            n_states += 1  # count only tensors actually transferred
        else:
            new_state_dict[key] = current_tensor
    log.info(f"Loaded {n_states} Tensor states from "
             f"pretrained model [green] {to_load}")
    module.load_state_dict(new_state_dict, strict=False)