slideflow.mil¶
This submodule contains tools for multiple-instance learning (MIL) model training and evaluation. See Multiple-Instance Learning (MIL) for more information. A summary of the API is given below.
- Training:
  - train_mil(): Train an MIL model, using an MIL configuration, Datasets, and a directory of bags.
  - build_fastai_learner(): Build and return the FastAI Learner, but do not execute training. Useful for customizing training.
  - build_multimodal_learner(): Build and return a FastAI Learner designed for multi-modal/multi-magnification input.
- Evaluation/Inference:
  - eval_mil(): Evaluate an MIL model using a path to a saved model, a Dataset, and a path to bags. Generates metrics.
  - predict_mil(): Generate predictions from an MIL model and saved bags. Returns a pandas dataframe.
  - predict_multimodal_mil(): Generate predictions from a multimodal MIL model. Returns a dataframe.
  - predict_slide(): Generate MIL predictions for a single slide. Returns a 2D array of predictions and attention.
  - predict_from_bags(): Low-level interface for generating predictions from a loaded MIL model and pre-loaded bag Tensors.
  - predict_from_multimodal_bags(): Low-level interface for generating multimodal predictions from a loaded MIL model and bag Tensors.
  - get_mil_tile_predictions(): Get tile-level predictions and attention from a saved MIL model for a given Dataset and saved bags.
  - generate_attention_heatmaps(): Generate and save attention heatmaps.
  - generate_mil_features(): Get last-layer activations from an MIL model. Returns an MILFeatures object.
Main functions¶
- mil_config(model: str | Callable, trainer: str = 'fastai', **kwargs)[source]¶
Create a multiple-instance learning (MIL) training configuration.
All models by default are trained with the FastAI trainer. Additional trainers and additional models can be installed with slideflow-extras.
- Parameters:
model (str, Callable) – Either the name of a model, or a custom torch module. Valid model names include "attention_mil", "transmil", and "bistro.transformer".
trainer (str) – Type of MIL trainer to use. Only ‘fastai’ is available, unless additional trainers are installed.
**kwargs – All additional keyword arguments are passed to slideflow.mil.TrainerConfig.
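For example, a basic configuration might be created as follows (a minimal sketch; the hyperparameter values shown are illustrative only):

```python
import slideflow as sf

# Illustrative hyperparameters; any unrecognized keyword arguments are
# forwarded to slideflow.mil.TrainerConfig.
config = sf.mil.mil_config(
    'attention_mil',
    lr=1e-4,        # maximum learning rate (fit_one_cycle is True by default)
    bag_size=512,
    epochs=20,
    batch_size=32,
)
```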
- train_mil(config: TrainerConfig, train_dataset: Dataset, val_dataset: Dataset | None, outcomes: str | List[str], bags: str | List[str], *, outdir: str = 'mil', exp_label: str | None = None, **kwargs) Learner [source]¶
Train a multiple-instance learning (MIL) model.
This high-level trainer facilitates training from a given MIL configuration, using Datasets as input and with input features taken from a given directory of bags.
- Parameters:
config (slideflow.mil.TrainerConfig) – Trainer and model configuration.
train_dataset (slideflow.Dataset) – Training dataset.
val_dataset (slideflow.Dataset) – Validation dataset.
outcomes (str) – Outcome column (annotation header) from which to derive category labels.
bags (str) – Either a path to a directory with *.pt files, or a list of paths to individual *.pt files. Each file should contain exported feature vectors, with each file containing all tile features for one patient.
- Keyword Arguments:
outdir (str) – Directory in which to save model and results.
exp_label (str) – Experiment label, used for naming the subdirectory in the {project root}/mil folder, where training history and the model will be saved.
attention_heatmaps (bool) – Generate attention heatmaps for slides. Not available for multi-modal MIL models. Defaults to False.
interpolation (str, optional) – Interpolation strategy for smoothing attention heatmaps. Defaults to ‘bicubic’.
cmap (str, optional) – Matplotlib colormap for heatmap. Can be any valid matplotlib colormap. Defaults to ‘inferno’.
norm (str, optional) – Normalization strategy for assigning heatmap values to colors. Either ‘two_slope’, or any other valid value for the norm argument of matplotlib.pyplot.imshow. If ‘two_slope’, normalizes values less than 0 and greater than 0 separately. Defaults to None.
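A typical training run is sketched below. It assumes an existing project, a previously exported directory of feature bags, and an annotation column named 'tumor_type'; all paths and names are hypothetical, and the train/validation split shown is only one possible approach.

```python
import slideflow as sf

# Hypothetical project path, bag directory, and outcome column.
P = sf.load_project('/path/to/project')
dataset = P.dataset(tile_px=299, tile_um=302)
train, val = dataset.split(labels='tumor_type', val_fraction=0.3)

config = sf.mil.mil_config('attention_mil', lr=1e-4)
learner = sf.mil.train_mil(
    config,
    train_dataset=train,
    val_dataset=val,
    outcomes='tumor_type',
    bags='/path/to/bags',
    outdir='mil',
)
```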
- build_fastai_learner(config: TrainerConfig, train_dataset: Dataset, val_dataset: Dataset, outcomes: str | List[str], bags: str | ndarray | List[str], *, outdir: str = 'mil', return_shape: bool = False, **kwargs) Learner [source]¶
Build a FastAI Learner for training an MIL model.
Does not execute training. Useful for customizing a Learner object prior to training.
- Parameters:
train_dataset (slideflow.Dataset) – Training dataset.
val_dataset (slideflow.Dataset) – Validation dataset.
outcomes (str) – Outcome column (annotation header) from which to derive category labels.
bags (str) – List of paths to individual *.pt files. Each file should contain exported feature vectors, with each file containing all tile features for one patient.
- Keyword Arguments:
outdir (str) – Directory in which to save model and results.
return_shape (bool) – Return the input and output shapes of the model. Defaults to False.
exp_label (str) – Experiment label, used for naming the subdirectory in the outdir folder, where training history and the model will be saved.
lr (float) – Learning rate, or maximum learning rate if fit_one_cycle=True.
epochs (int) – Maximum epochs.
**kwargs – Additional keyword arguments to pass to the FastAI learner.
- Returns:
fastai.learner.Learner, and optionally a tuple of input and output shapes if return_shape=True.
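For example, the Learner can be built without training and then fit manually; a sketch reusing the config, datasets, and bag path from the training example above:

```python
from slideflow.mil import build_fastai_learner

# Build the Learner only; training is then run manually with fastai.
learner = build_fastai_learner(
    config,
    train_dataset=train,
    val_dataset=val,
    outcomes='tumor_type',
    bags='/path/to/bags',
    outdir='mil',
)
learner.fit_one_cycle(n_epoch=10, lr_max=1e-4)  # standard fastai training loop
```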
- build_multimodal_learner(config: TrainerConfig, train_dataset: Dataset, val_dataset: Dataset, outcomes: str | List[str], bags: ndarray | List[str], *, outdir: str = 'mil', return_shape: bool = False) Learner [source]¶
Build a multi-magnification FastAI Learner for training an MIL model.
Does not execute training. Useful for customizing a Learner object prior to training.
- Parameters:
train_dataset (slideflow.Dataset) – Training dataset.
val_dataset (slideflow.Dataset) – Validation dataset.
outcomes (str) – Outcome column (annotation header) from which to derive category labels.
bags (list(str)) – List of bag directories containing *.pt files, one directory for each mode.
- Keyword Arguments:
outdir (str) – Directory in which to save model and results.
return_shape (bool) – Return the input and output shapes of the model. Defaults to False.
exp_label (str) – Experiment label, used for naming the subdirectory in the outdir folder, where training history and the model will be saved.
lr (float) – Learning rate, or maximum learning rate if fit_one_cycle=True.
epochs (int) – Maximum epochs.
**kwargs – Additional keyword arguments to pass to the FastAI learner.
- Returns:
fastai.learner.Learner, and optionally a tuple of input and output shapes if return_shape=True.
- eval_mil(weights: str, dataset: Dataset, outcomes: str | List[str], bags: str | List[str], config: TrainerConfig | None = None, *, outdir: str = 'mil', attention_heatmaps: bool = False, uq: bool = False, aggregation_level: str | None = None, **heatmap_kwargs) DataFrame [source]¶
Evaluate a multiple-instance learning model.
Saves results for the evaluation in the target folder, including predictions (parquet format), attention (Numpy format for each slide), and attention heatmaps (if attention_heatmaps=True). Logs classifier metrics (AUROC and AP) to the console.
- Parameters:
weights (str) – Path to model weights to load.
dataset (sf.Dataset) – Dataset to evaluate.
outcomes (str) – Outcome column (annotation header) from which to derive category labels.
bags (str, list(str)) – Path to bags, or list of bag file paths. Each bag should contain a PyTorch array of features from all tiles in a slide, with the shape (n_tiles, n_features).
config (slideflow.mil.TrainerConfig) – Configuration for building model. If weights is a path to a model directory, will attempt to read mil_params.json from this location and load the saved configuration. Defaults to None.
- Keyword Arguments:
outdir (str) – Path at which to save results.
attention_heatmaps (bool) – Generate attention heatmaps for slides. Not available for multi-modal MIL models. Defaults to False.
interpolation (str, optional) – Interpolation strategy for smoothing attention heatmaps. Defaults to ‘bicubic’.
aggregation_level (str, optional) – Aggregation level for predictions. Either ‘slide’ or ‘patient’. Defaults to None (uses the model configuration).
cmap (str, optional) – Matplotlib colormap for heatmap. Can be any valid matplotlib colormap. Defaults to ‘inferno’.
norm (str, optional) – Normalization strategy for assigning heatmap values to colors. Either ‘two_slope’, or any other valid value for the norm argument of matplotlib.pyplot.imshow. If ‘two_slope’, normalizes values less than 0 and greater than 0 separately. Defaults to None.
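A sketch of evaluating a saved model on a held-out dataset (the paths, filter values, and outcome column below are hypothetical):

```python
import slideflow as sf

P = sf.load_project('/path/to/project')
test = P.dataset(tile_px=299, tile_um=302, filters={'dataset': ['test']})

df = sf.mil.eval_mil(
    weights='/path/to/mil_model',   # model directory containing mil_params.json
    dataset=test,
    outcomes='tumor_type',
    bags='/path/to/bags',
    outdir='mil_eval',
)
```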
- predict_mil(model: str | Callable, dataset: Dataset, outcomes: str | List[str], bags: str | ndarray | List[str], *, config: TrainerConfig | None = None, attention: bool = False, aggregation_level: str | None = None, **kwargs) DataFrame | Tuple[DataFrame, List[ndarray]] [source]¶
Generate predictions for a dataset from a saved MIL model.
- Parameters:
model (torch.nn.Module) – Model from which to generate predictions.
dataset (sf.Dataset) – Dataset from which to generate predictions.
bags (str, list(str)) – Path to bags, or list of bag file paths. Each bag should contain a PyTorch array of features from all tiles in a slide, with the shape (n_tiles, n_features).
- Keyword Arguments:
config (slideflow.mil.TrainerConfig) – Configuration for the MIL model. Required if model is a loaded torch.nn.Module. Defaults to None.
attention (bool) – Whether to calculate attention scores. Defaults to False.
uq (bool) – Whether to generate uncertainty estimates. Experimental. Defaults to False.
aggregation_level (str) – Aggregation level for predictions. Either ‘slide’ or ‘patient’. Defaults to None.
attention_pooling (str) – Attention pooling strategy. Either ‘avg’ or ‘max’. Defaults to None.
- Returns:
Dataframe of predictions.
list(np.ndarray): Attention scores (if attention=True)
- Return type:
pd.DataFrame
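For example (a sketch; dataset is a slideflow.Dataset and the paths and outcome column are hypothetical):

```python
from slideflow.mil import predict_mil

# With attention=True, a list of per-slide attention arrays is also returned.
df, attention = predict_mil(
    model='/path/to/mil_model',
    dataset=dataset,
    outcomes='tumor_type',
    bags='/path/to/bags',
    attention=True,
)
print(df.head())
```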
- predict_multimodal_mil(model: str | Callable, dataset: Dataset, outcomes: str | List[str], bags: ndarray | List[List[str]], *, config: TrainerConfig | None = None, attention: bool = False, aggregation_level: str | None = None, **kwargs) DataFrame | Tuple[DataFrame, List[ndarray]] [source]¶
Generate predictions for a dataset from a saved multimodal MIL model.
- Parameters:
model (torch.nn.Module) – Model from which to generate predictions.
dataset (sf.Dataset) – Dataset from which to generate predictions.
bags (str, list(str)) – Path to bags, or list of bag file paths. Each bag should contain a PyTorch array of features from all tiles in a slide, with the shape (n_tiles, n_features).
- Keyword Arguments:
config (slideflow.mil.TrainerConfig) – Configuration for the MIL model. Required if model is a loaded torch.nn.Module. Defaults to None.
attention (bool) – Whether to calculate attention scores. Defaults to False.
uq (bool) – Whether to generate uncertainty estimates. Defaults to False.
aggregation_level (str) – Aggregation level for predictions. Either ‘slide’ or ‘patient’. Defaults to None.
attention_pooling (str) – Attention pooling strategy. Either ‘avg’ or ‘max’. Defaults to None.
- Returns:
Dataframe of predictions.
list(np.ndarray): Attention scores (if attention=True)
- Return type:
pd.DataFrame
- predict_from_bags(model: torch.nn.Module, bags: ndarray | List[str], *, attention: bool = False, attention_pooling: str | None = None, use_lens: bool = False, device: Any | None = None, apply_softmax: bool | None = None, uq: bool = False) Tuple[ndarray, List[ndarray]] [source]¶
Generate MIL predictions for a list of bags.
Predictions are generated for each bag in the list one at a time, and not batched.
- Parameters:
- Keyword Arguments:
attention (bool) – Whether to calculate attention scores. Defaults to False.
attention_pooling (str, optional) – Pooling strategy for attention scores. Can be ‘avg’, ‘max’, or None. Defaults to None.
use_lens (bool) – Whether to use the length of each bag as an additional input to the model. Defaults to False.
device (str, optional) – Device on which to run inference. Defaults to None.
apply_softmax (bool) – Whether to apply softmax to the model output. Defaults to True for categorical outcomes, False for continuous outcomes.
uq (bool) – Whether to generate uncertainty estimates. Defaults to False.
- Returns:
Predictions and attention scores.
- Return type:
Tuple[np.ndarray, List[np.ndarray]]
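A low-level sketch, assuming an MIL model has already been loaded as a torch.nn.Module (how it is loaded is not shown) and that the bag paths are hypothetical:

```python
from slideflow.mil import predict_from_bags

bag_paths = ['/path/to/bags/slide1.pt', '/path/to/bags/slide2.pt']

# Predictions are generated one bag at a time (not batched).
preds, att = predict_from_bags(
    mil_model,          # a loaded torch.nn.Module MIL model
    bag_paths,
    attention=True,
    apply_softmax=True,
)
```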
- predict_from_multimodal_bags(model: torch.nn.Module, bags: List[ndarray] | List[List[str]], *, attention: bool = True, attention_pooling: str | None = None, use_lens: bool = True, device: Any | None = None, apply_softmax: bool | None = None) Tuple[ndarray, List[List[ndarray]]] [source]¶
Generate multi-mag MIL predictions for a nested list of bags.
- Parameters:
- Keyword Arguments:
attention (bool) – Whether to calculate attention scores. Defaults to True.
attention_pooling (str, optional) – Pooling strategy for attention scores. Can be ‘avg’, ‘max’, or None. Defaults to None.
use_lens (bool) – Whether to use the length of each bag as an additional input to the model. Defaults to True.
device (str, optional) – Device on which to run inference. Defaults to None.
apply_softmax (bool) – Whether to apply softmax to the model output. Defaults to True for categorical outcomes, False for continuous outcomes.
- Returns:
Predictions and attention scores.
- Return type:
Tuple[np.ndarray, List[List[np.ndarray]]]
- predict_slide(model: str, slide: str | WSI, extractor: BaseFeatureExtractor | None = None, *, normalizer: StainNormalizer | None = None, config: TrainerConfig | None = None, attention: bool = False, native_normalizer: bool | None = True, extractor_kwargs: dict | None = None, **kwargs) Tuple[ndarray, ndarray | None] [source]¶
Generate predictions (and attention) for a single slide.
- Parameters:
model (str) – Path to MIL model.
slide (str) – Path to slide.
extractor (slideflow.mil.BaseFeatureExtractor, optional) – Feature extractor. If not provided, will attempt to auto-detect extractor from model.
Note: If the extractor has a stain normalizer, this will be used to normalize the slide before extracting features.
- Keyword Arguments:
normalizer (slideflow.stain.StainNormalizer, optional) – Stain normalizer. If not provided, will attempt to use the stain normalizer from the extractor.
config (slideflow.mil.TrainerConfig) – Configuration for building model. If None, will attempt to read mil_params.json from the model directory and load the saved configuration. Defaults to None.
attention (bool) – Whether to return attention scores. Defaults to False.
attention_pooling (str) – Attention pooling strategy. Either ‘avg’ or ‘max’. Defaults to None.
native_normalizer (bool, optional) – Whether to use PyTorch/Tensorflow-native stain normalization, if applicable. If False, will use the OpenCV/Numpy implementations. Defaults to None, which auto-detects based on the slide backend (False if libvips, True if cucim). This behavior is due to performance issues when using native stain normalization with libvips-compatible multiprocessing.
- Returns:
Predictions and attention scores. Attention scores are None if attention is False. For single-channel attention, this is a masked 2D array with the same shape as the slide grid (arranged as a heatmap, with unused tiles masked). For multi-channel attention, this is a masked 3D array with shape (n_channels, X, Y).
- Return type:
Tuple[np.ndarray, Optional[np.ndarray]]
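A sketch of single-slide inference (the model and slide paths are hypothetical); the feature extractor is auto-detected from the saved model configuration when possible:

```python
from slideflow.mil import predict_slide

preds, att = predict_slide(
    model='/path/to/mil_model',
    slide='/path/to/slide.svs',
    attention=True,
)
# `att` is a masked 2D array aligned with the slide grid (or a 3D array
# for multi-channel attention), or None if attention=False.
```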
- get_mil_tile_predictions(weights: str, dataset: Dataset, bags: str | ndarray | List[str], *, config: TrainerConfig | None = None, outcomes: str | List[str] | None = None, dest: str | None = None, uq: bool = False, device: Any | None = None, tile_batch_size: int = 512, **kwargs) DataFrame [source]¶
Generate tile-level predictions for an MIL model.
- Parameters:
weights (str) – Path to model weights to load.
dataset (slideflow.Dataset) – Dataset.
bags (str, list(str)) – Path to bags, or list of bag file paths. Each bag should contain a PyTorch array of features from all tiles in a slide, with the shape (n_tiles, n_features).
- Keyword Arguments:
config (slideflow.mil.TrainerConfig) – Configuration for building model. If weights is a path to a model directory, will attempt to read mil_params.json from this location and load the saved configuration. Defaults to None.
dest (str) – Path at which to save tile predictions.
uq (bool) – Whether to generate uncertainty estimates. Defaults to False.
device (str, optional) – Device on which to run inference. Defaults to None.
tile_batch_size (int) – Batch size for tile-level predictions. Defaults to 512.
attention_pooling (str) – Attention pooling strategy. Either ‘avg’ or ‘max’. Defaults to None.
- Returns:
Dataframe of tile predictions.
- Return type:
pd.DataFrame
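For example (a sketch; dataset is a slideflow.Dataset and the paths are hypothetical):

```python
from slideflow.mil import get_mil_tile_predictions

tile_df = get_mil_tile_predictions(
    weights='/path/to/mil_model',
    dataset=dataset,
    bags='/path/to/bags',
)
print(tile_df.head())
```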
- generate_attention_heatmaps(outdir: str, dataset: Dataset, bags: List[str] | ndarray, attention: ndarray | List[ndarray], **kwargs) None [source]¶
Generate and save attention heatmaps for a dataset.
- Parameters:
outdir (str) – Path at which to save heatmap images.
dataset (sf.Dataset) – Dataset.
bags (str, list(str)) – List of bag file paths. Each bag should contain a PyTorch array of features from all tiles in a slide, with the shape (n_tiles, n_features).
attention (list(np.ndarray)) – Attention scores for each slide. Length of attention should equal the length of bags.
- Keyword Arguments:
interpolation (str, optional) – Interpolation strategy for smoothing heatmap. Defaults to ‘bicubic’.
cmap (str, optional) – Matplotlib colormap for heatmap. Can be any valid matplotlib colormap. Defaults to ‘inferno’.
norm (str, optional) – Normalization strategy for assigning heatmap values to colors. Either ‘two_slope’, or any other valid value for the norm argument of matplotlib.pyplot.imshow. If ‘two_slope’, normalizes values less than 0 and greater than 0 separately. Defaults to None.
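Attention scores returned by predict_mil() can be passed to this function. A sketch is shown below; the paths and outcome column are hypothetical, and this sketch assumes the returned attention list is ordered to match the provided bag paths:

```python
from slideflow.mil import predict_mil, generate_attention_heatmaps

bag_paths = ['/path/to/bags/slide1.pt', '/path/to/bags/slide2.pt']
df, attention = predict_mil(
    model='/path/to/mil_model',
    dataset=dataset,
    outcomes='tumor_type',
    bags=bag_paths,
    attention=True,
)
generate_attention_heatmaps(
    outdir='heatmaps',
    dataset=dataset,
    bags=bag_paths,
    attention=attention,
    cmap='inferno',
)
```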
- generate_mil_features(weights: str, dataset: sf.Dataset, bags: str | ndarray | List[str], *, config: TrainerConfig | None = None) MILFeatures [source]¶
Generate activation weights from the last layer of an MIL model.
Returns an MILFeatures object.
- Parameters:
weights (str) – Path to model weights to load.
config (slideflow.mil.TrainerConfig) – Configuration for building model. If weights is a path to a model directory, will attempt to read mil_params.json from this location and load the saved configuration. Defaults to None.
dataset (slideflow.Dataset) – Dataset.
bags (str, list(str)) – Path to bags, or list of bag file paths. Each bag should contain a PyTorch array of features from all tiles in a slide, with the shape (n_tiles, n_features).
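For example (a sketch; dataset is a slideflow.Dataset and the paths are hypothetical):

```python
from slideflow.mil import generate_mil_features

# Extract last-layer MIL activations for downstream analysis.
mil_features = generate_mil_features(
    weights='/path/to/mil_model',
    dataset=dataset,
    bags='/path/to/bags',
)
```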
TrainerConfig¶
- class TrainerConfig(model: str | Callable = 'attention_mil', *, aggregation_level: str = 'slide', lr: float | None = None, wd: float = 1e-05, bag_size: int = 512, max_val_bag_size: int | None = None, fit_one_cycle: bool = True, epochs: int = 32, batch_size: int = 64, drop_last: bool = True, save_monitor: str = 'valid_loss', weighted_loss: bool = True, **kwargs)[source]¶
Training configuration for FastAI MIL models.
This configuration should not be created directly, but rather should be created through slideflow.mil.mil_config(), which will create and prepare an appropriate trainer configuration.
- Parameters:
model (str, Callable) – Either the name of a model, or a custom torch module. Valid model names include "attention_mil", "transmil", and "bistro.transformer".
- Keyword Arguments:
aggregation_level (str) – When equal to 'slide', each bag contains tiles from a single slide. When equal to 'patient', tiles from all slides of a patient are grouped together.
lr (float, optional) – Learning rate. If fit_one_cycle=True, this is the maximum learning rate. If None, uses the Leslie Smith LR Range test to find an optimal learning rate. Defaults to None.
wd (float) – Weight decay. Only used if fit_one_cycle=False. Defaults to 1e-5.
bag_size (int) – Bag size. Defaults to 512.
max_val_bag_size (int, optional) – Maximum validation bag size. If None, all validation bags will be unclipped and unpadded (full size). Defaults to None.
fit_one_cycle (bool) – Use 1cycle learning rate schedule. Defaults to True.
epochs (int) – Maximum number of epochs. Defaults to 32.
batch_size (int) – Batch size. Defaults to 64.
**kwargs – All additional keyword arguments are passed to slideflow.mil.MILModelConfig.
Attributes include the MIL model architecture (class/module), the MIL loss function, whether the model is multimodal, and the type of model (classification or regression).
- to_dict(self)¶
Converts this training configuration to a dictionary.
- json_dump(self)¶
Converts this training configuration to a JSON-compatible dict.
- is_classification(self)¶
Whether the model is a classification model.
- get_metrics(self)¶
Get model metrics.
- Returns:
List of metrics to use for model evaluation. Defaults to RocAuc for classification models, and mse and Pearson correlation coefficient for regression models.
- Return type:
List[Callable]
- prepare_training(self, outcomes: str | List[str], exp_label: str | None, outdir: str | None) str ¶
Prepare for training.
Sets up the output directory for the model.
- predict(self, model, bags, attention=False, **kwargs)¶
Generate model prediction from bags.
- Parameters:
model (torch.nn.Module) – Loaded PyTorch MIL model.
bags (torch.Tensor) – Bags, with shape (n_bags, n_tiles, n_features).
- Keyword Arguments:
attention (bool) – Whether to return attention maps.
- Returns:
Predictions and attention.
- Return type:
Tuple[np.ndarray, List[np.ndarray]]
- batched_predict(self, model: Module, loaded_bags: Tensor, **kwargs) Tuple[ndarray, List[ndarray]] ¶
Generate predictions from a batch of bags.
- Parameters:
model (torch.nn.Module) – Loaded PyTorch MIL model.
loaded_bags (torch.Tensor) – Loaded bags, with shape (n_bags, n_tiles, n_features).
- Keyword Arguments:
device (torch.device, optional) – Device on which to run the model. If None, uses the default device.
forward_kwargs (dict, optional) – Additional keyword arguments to pass to the model’s forward function.
attention (bool) – Whether to return attention maps.
attention_pooling (str) – Attention pooling strategy. Either ‘avg’ or ‘max’. Defaults to ‘avg’.
uq (bool) – Whether to return uncertainty quantification.
- Returns:
Predictions and attention.
- Return type:
Tuple[np.ndarray, List[np.ndarray]]
- train(self, train_dataset: Dataset, val_dataset: Dataset | None, outcomes: str | List[str], bags: str | List[str], *, outdir: str = 'mil', exp_label: str | None = None, **kwargs) Learner ¶
Train a multiple-instance learning (MIL) model.
- Parameters:
config (slideflow.mil.TrainerConfig) – Trainer and model configuration.
train_dataset (slideflow.Dataset) – Training dataset.
val_dataset (slideflow.Dataset) – Validation dataset.
outcomes (str) – Outcome column (annotation header) from which to derive category labels.
bags (str) – Either a path to a directory with *.pt files, or a list of paths to individual *.pt files. Each file should contain exported feature vectors, with each file containing all tile features for one patient.
- Keyword Arguments:
outdir (str) – Directory in which to save model and results.
exp_label (str) – Experiment label, used for naming the subdirectory in the {project root}/mil folder, where training history and the model will be saved.
attention_heatmaps (bool) – Generate attention heatmaps for slides. Not available for multi-modal MIL models. Defaults to False.
interpolation (str, optional) – Interpolation strategy for smoothing attention heatmaps. Defaults to ‘bicubic’.
cmap (str, optional) – Matplotlib colormap for heatmap. Can be any valid matplotlib colormap. Defaults to ‘inferno’.
norm (str, optional) – Normalization strategy for assigning heatmap values to colors. Either ‘two_slope’, or any other valid value for the norm argument of matplotlib.pyplot.imshow. If ‘two_slope’, normalizes values less than 0 and greater than 0 separately. Defaults to None.
- eval(self, model: Module, dataset: Dataset, outcomes: str | List[str], bags: str | List[str], *, outdir: str = 'mil', attention_heatmaps: bool = False, uq: bool = False, aggregation_level: str | None = None, params: dict | None = None, **heatmap_kwargs) DataFrame ¶
Evaluate a multiple-instance learning model.
Saves results for the evaluation in the target folder, including predictions (parquet format), attention (Numpy format for each slide), and attention heatmaps (if attention_heatmaps=True). Logs classifier metrics (AUROC and AP) to the console.
- Parameters:
- Keyword Arguments:
outdir (str) – Path at which to save results.
attention_heatmaps (bool) – Generate attention heatmaps for slides. Not available for multi-modal MIL models. Defaults to False.
interpolation (str, optional) – Interpolation strategy for smoothing attention heatmaps. Defaults to ‘bicubic’.
cmap (str, optional) – Matplotlib colormap for heatmap. Can be any valid matplotlib colormap. Defaults to ‘inferno’.
norm (str, optional) – Normalization strategy for assigning heatmap values to colors. Either ‘two_slope’, or any other valid value for the norm argument of matplotlib.pyplot.imshow. If ‘two_slope’, normalizes values less than 0 and greater than 0 separately. Defaults to None.
- Returns:
Dataframe of predictions.
- Return type:
pd.DataFrame
- build_train_dataloader(self, bags, targets, encoder, *, dataset_kwargs=None, dataloader_kwargs=None) torch.utils.DataLoader ¶
Build a training dataloader.
- Parameters:
- Keyword Arguments:
- Returns:
Training dataloader.
- Return type:
torch.utils.DataLoader
- build_val_dataloader(self, bags, targets, encoder, *, dataset_kwargs=None, dataloader_kwargs=None) torch.utils.DataLoader ¶
Build a validation dataloader.
- Parameters:
- Keyword Arguments:
- Returns:
Validation dataloader.
- Return type:
torch.utils.DataLoader
MILModelConfig¶
- class MILModelConfig(model: str | Callable = 'attention_mil', *, use_lens: bool | None = None, apply_softmax: bool = True, model_kwargs: dict | None = None, validate: bool = True, loss: str | Callable = 'cross_entropy', **kwargs)[source]¶
Model configuration for an MIL model.
- Parameters:
model (str, Callable) – Either the name of a model, or a custom torch module. Valid model names include "attention_mil" and "transmil". Defaults to ‘attention_mil’.
- Keyword Arguments:
use_lens (bool, optional) – Whether the model expects a second argument to its .forward() function, an array with the bag size for each slide. If None, will default to True for 'attention_mil' models and False otherwise. Defaults to None.
apply_softmax (bool) – Whether to apply softmax to model outputs. Defaults to True. Ignored if the model is not a classification model.
model_kwargs (dict, optional) – Additional keyword arguments to pass to the model constructor. Defaults to None.
validate (bool) – Whether to validate the keyword arguments. If True, will raise an error if any unrecognized keyword arguments are passed. Defaults to True.
loss (str, Callable) – Loss function. Defaults to ‘cross_entropy’.
Attributes include whether softmax will be applied to model outputs, the MIL loss function, the MIL model architecture (class/module), the type of model (classification or regression), and whether the model is multimodal.
- is_classification(self)¶
Whether the model is a classification model.
- to_dict(self)¶
Converts this model configuration to a dictionary.
- predict(self, model, bags, attention=False, apply_softmax=None, **kwargs)¶
Generate model prediction from bags.
- Parameters:
model (torch.nn.Module) – Loaded PyTorch MIL model.
bags (torch.Tensor) – Bags, with shape (n_bags, n_tiles, n_features).
- Keyword Arguments:
- Returns:
Predictions and attention.
- Return type:
Tuple[np.ndarray, List[np.ndarray]]
- batched_predict(self, model: Module, loaded_bags: Tensor, *, device: Any | None = None, forward_kwargs: dict | None = None, attention: bool = False, attention_pooling: str | None = None, uq: bool = False, apply_softmax: bool | None = None) Tuple[ndarray, List[ndarray]] ¶
Generate predictions from a batch of bags.
More efficient than calling predict() multiple times.
- Parameters:
model (torch.nn.Module) – Loaded PyTorch MIL model.
loaded_bags (torch.Tensor) – Loaded bags, with shape (n_bags, n_tiles, n_features).
- Keyword Arguments:
device (torch.device, optional) – Device on which to run the model. If None, uses the default device.
forward_kwargs (dict, optional) – Additional keyword arguments to pass to the model’s forward function.
attention (bool) – Whether to return attention maps.
attention_pooling (str) – Attention pooling strategy. Either ‘avg’ or ‘max’. Defaults to None.
uq (bool) – Whether to return uncertainty quantification.
- Returns:
Predictions and attention.
- Return type:
Tuple[np.ndarray, List[np.ndarray]]
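Model-level options are normally supplied through mil_config(), which forwards unrecognized keyword arguments (via TrainerConfig) to this model configuration. A sketch with illustrative values:

```python
import slideflow as sf

# loss and apply_softmax are MILModelConfig options forwarded through
# mil_config(); the values shown are illustrative.
config = sf.mil.mil_config(
    'attention_mil',
    lr=1e-4,
    loss='cross_entropy',
    apply_softmax=True,
)
```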
CLAMModelConfig¶
The CLAM model configuration class requires slideflow-gpl, which can be installed with:
pip install slideflow-gpl
Once installed, the class is available at slideflow.clam.CLAMModelConfig.
- class CLAMModelConfig(model: str = 'clam_sb', *, model_size: str = 'small', bag_loss: str = 'ce', bag_weight: float = 0.7, dropout: bool = False, opt: str = 'adam', inst_loss: str = 'ce', no_inst_cluster: bool = False, B: int = 8, model_kwargs: dict | None = None, validate: bool = True, **kwargs)[source]¶
Model configuration for CLAM models.
These configuration options are identical to the options in the original CLAM paper.
- Keyword Arguments:
model (str) – Model. Either 'clam_sb', 'clam_mb', 'mil_fc', or 'mil_fc_mc'. Defaults to 'clam_sb'.
model_size (str) – Size of the model. Available sizes include:
  clam_sb: small [1024, 512, 256], big [1024, 512, 384], multiscale [2048, 512, 256], xception [2048, 256, 128], xception_multi [1880, 128, 64], xception_3800 [3800, 512, 256]
  clam_mb: small [1024, 512, 256], big [1024, 512, 384], multiscale [2048, 512, 256]
  mil_fc: small [1024, 512]
  mil_fc_mc: small [1024, 512]
bag_loss (str) – Primary loss function. Either ‘ce’ or ‘svm’. If ‘ce’, the model loss function is a cross entropy loss. If ‘svm’, the model loss is topk.SmoothTop1SVM. Defaults to ‘ce’.
bag_weight (float) – Weight of the bag loss. The total loss is defined as W * loss + (1 - W) * instance_loss, where W is the bag weight. Defaults to 0.7.
dropout (bool) – Add dropout (p=0.25) after the attention layers. Defaults to False.
opt (str) – Optimizer. Either ‘adam’ (Adam optimizer) or ‘sgd’ (Stochastic Gradient Descent). Defaults to ‘adam’.
inst_loss (str) – Instance loss function. Either ‘ce’ or ‘svm’. If ‘ce’, the instance loss is a cross entropy loss. If ‘svm’, the loss is topk.SmoothTop1SVM. Defaults to ‘ce’.
no_inst_cluster (bool) – Disable instance-level clustering. Defaults to False.
B (int) – Number of positive/negative patches to sample for instance-level training. Defaults to 8.
validate (bool) – Validate the hyperparameter configuration. Defaults to True.
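With slideflow-gpl installed, a CLAM configuration can be constructed directly; a sketch using only the options documented above (the hyperparameter values are illustrative):

```python
# Requires: pip install slideflow-gpl
from slideflow.clam import CLAMModelConfig

clam_config = CLAMModelConfig(
    model='clam_sb',
    model_size='small',
    bag_loss='ce',
    bag_weight=0.7,
    B=8,
)
```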