Source code for slideflow_gpl.clam.config

# Slideflow-GPL - Add-ons for the deep learning library Slideflow
# Copyright (C) 2024 James Dolezal
#
# This file is part of Slideflow-GPL.
#
# Slideflow-GPL is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Slideflow-GPL is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Slideflow-GPL. If not, see <https://www.gnu.org/licenses/>.

import slideflow as sf
from typing import Union, List, Tuple, Optional, TYPE_CHECKING
from slideflow import log, errors, Dataset
from slideflow.mil import MILModelConfig, TrainerConfig

if TYPE_CHECKING:
    import torch

# -----------------------------------------------------------------------------

class CLAMModelConfig(MILModelConfig):

    valid_models = ['clam_sb', 'clam_mb', 'mil_fc_mc', 'mil_fc']

    def __init__(
        self,
        model: str = 'clam_sb',
        *,
        model_size: str = 'small',
        bag_loss: str = 'ce',
        bag_weight: float = 0.7,
        dropout: bool = False,
        opt: str = 'adam',
        inst_loss: str = 'ce',
        no_inst_cluster: bool = False,
        B: int = 8,
        model_kwargs: Optional[dict] = None,
        validate: bool = True,
        **kwargs
    ):
        """Model configuration for CLAM models.

        These configuration options are identical to the options in the
        `original CLAM paper <https://arxiv.org/abs/2004.09666>`_.

        Keyword args:
            model (str): Model. Either ``'clam_sb'``, ``'clam_mb'``,
                ``'mil_fc'``, or ``'mil_fc_mc'``. Defaults to ``'clam_sb'``.
            model_size (str): Size of the model. Available sizes include:

                ``clam_sb``

                .. list-table::
                    :header-rows: 0

                    * - small
                      - [1024, 512, 256]
                    * - big
                      - [1024, 512, 384]
                    * - multiscale
                      - [2048, 512, 256]
                    * - xception
                      - [2048, 256, 128]
                    * - xception_multi
                      - [1880, 128, 64]
                    * - xception_3800
                      - [3800, 512, 256]

                ``clam_mb``

                .. list-table::
                    :header-rows: 0

                    * - small
                      - [1024, 512, 256]
                    * - big
                      - [1024, 512, 384]
                    * - multiscale
                      - [2048, 512, 256]

                ``mil_fc``

                .. list-table::
                    :header-rows: 0

                    * - small
                      - [1024, 512]

                ``mil_fc_mc``

                .. list-table::
                    :header-rows: 0

                    * - small
                      - [1024, 512]

            bag_loss (str): Primary loss function. Either 'ce' or 'svm'.
                If 'ce', the model loss function is a cross entropy loss.
                If 'svm', the model loss is topk.SmoothTop1SVM.
                Defaults to 'ce'.
            bag_weight (float): Weight of the bag loss. The total loss is
                defined as ``W * loss + (1 - W) * instance_loss``, where
                ``W`` is the bag weight. Defaults to 0.7.
            dropout (bool): Add dropout (p=0.25) after the attention layers.
                Defaults to False.
            opt (str): Optimizer. Either 'adam' (Adam optimizer) or 'sgd'
                (Stochastic Gradient Descent). Defaults to 'adam'.
            inst_loss (str): Instance loss function. Either 'ce' or 'svm'.
                If 'ce', the instance loss is a cross entropy loss. If 'svm',
                the loss is topk.SmoothTop1SVM. Defaults to 'ce'.
            no_inst_cluster (bool): Disable instance-level clustering.
                Defaults to False.
            B (int): Number of positive/negative patches to sample for
                instance-level training. Defaults to 8.
            validate (bool): Validate the hyperparameter configuration.
                Defaults to True.

        """
        for argname, argval in dict(locals()).items():
            # Skip 'self' so the config does not store a self-reference.
            if argname not in ('self', 'kwargs', 'validate'):
                setattr(self, argname, argval)
        if kwargs and validate:
            raise errors.UnrecognizedHyperparameterError(
                "Unrecognized parameters: {}".format(
                    ', '.join(list(kwargs.keys()))
                )
            )
        elif kwargs:
            log.warning(
                "Ignoring unrecognized parameters: {}".format(
                    ', '.join(list(kwargs.keys()))
                )
            )

    @property
    def model_fn(self):
        from .model import CLAM_MB, CLAM_SB, MIL_fc_mc, MIL_fc
        model_dict = {
            'clam_sb': CLAM_SB,
            'clam_mb': CLAM_MB,
            'mil_fc_mc': MIL_fc_mc,
            'mil_fc': MIL_fc
        }
        return model_dict[self.model]

    @property
    def loss_fn(self):
        from .legacy.utils import loss_utils
        if self.bag_loss == 'ce':
            if self.model.startswith('clam'):
                return loss_utils.CrossEntropyWithInstanceLoss
            else:
                return loss_utils.CrossEntropyLoss
        else:
            raise ValueError("Unrecognized bag loss: {}".format(self.bag_loss))

    @property
    def model_type(self):
        return 'classification'

    def get_metrics(self):
        from .legacy.utils import loss_utils
        return [loss_utils.RocAuc()]

    def build_model(self, n_in, n_out, **kwargs):
        if isinstance(self.model_size, str):
            config_size = self.model_fn.sizes[self.model_size]
        else:
            config_size = self.model_size
        # Replace the first layer size with the input feature dimension.
        model_size = [n_in] + config_size[1:]
        return self.model_fn(size=model_size, n_classes=n_out, **kwargs)

    def verify_trainer(self, trainer):
        if hasattr(trainer, 'batch_size') and trainer.batch_size > 1:
            log.info(
                "CLAM models do not support batch sizes > 1; "
                "setting batch_size to 1."
            )
            trainer.batch_size = 1

    def inspect_batch(self, batch) -> Tuple[int, int]:
        """Inspect a batch to determine the input and output dimensions."""
        bags, targets, _ = batch[0]
        n_in = bags.shape[-1]
        n_out = targets.shape[-1]
        return n_in, n_out

    def _verify_eval_params(self, **kwargs):
        """Verify evaluation parameters."""
        super()._verify_eval_params(**kwargs)
        if kwargs.get('uq'):
            raise ValueError(
                "Cannot calculate uncertainty quantification using "
                "CLAM models."
            )

    def _build_dataloader(
        self,
        bags,
        targets,
        encoder,
        *,
        dataset_kwargs=None,
        dataloader_kwargs=None
    ) -> "torch.utils.data.DataLoader":
        from fastai.vision.all import DataLoader
        from .data import build_clam_dataset

        dataset_kwargs = dataset_kwargs or dict()
        dataloader_kwargs = dataloader_kwargs or dict()
        dataset = build_clam_dataset(
            bags, targets, encoder=encoder, **dataset_kwargs
        )
        dataloader = DataLoader(dataset, **dataloader_kwargs)
        return dataloader

    def predict(self, model, bags, attention=False, device=None, **kwargs):
        """Generate CLAM predictions for a list of bags."""
        from .inference import run_inference
        self._verify_eval_params(**kwargs)
        return run_inference(model, bags, attention=attention)

    def batched_predict(self, *args, **kwargs):
        """CLAM models do not support batched predictions with batch_size > 1.

        Thus, this method is equivalent to :meth:`predict`, which generates
        predictions for each bag individually.

        """
        return self.predict(*args, **kwargs)

# -----------------------------------------------------------------------------

class LegacyCLAMTrainerConfig(TrainerConfig):

    tag = 'legacy_clam'

    def __init__(
        self,
        *,
        num_splits: int = 1,  # Unused; kept for backwards compatibility
        k: int = 3,
        k_start: int = -1,
        k_end: int = -1,
        max_epochs: int = 20,
        lr: float = 1e-4,
        reg: float = 1e-5,
        label_frac: float = 1,
        weighted_sample: bool = False,
        log_data: bool = False,
        testing: bool = False,
        early_stopping: bool = False,
        subtyping: bool = False,
        seed: int = 1,
        results_dir: Optional[str] = None,  # Unused; kept for compatibility
        n_classes: Optional[int] = None,
        split_dir=None,
        data_root_dir=None,
        micro_average=False,
        **kwargs
    ):
        """Training configuration for the legacy CLAM trainer.

        This configures the legacy CLAM trainer. The FastAI trainer is
        preferred for all models, including CLAM.

        The configuration options for the legacy CLAM trainer are identical
        to the options in the `original CLAM paper
        <https://arxiv.org/abs/2004.09666>`_.

        Keyword args:
            k (int): Number of cross-fold splits. Defaults to 3.
            k_start (int): Starting cross-fold. Defaults to the first
                cross-fold.
            k_end (int): Ending cross-fold. Defaults to ending after the
                last cross-fold is done.
            max_epochs (int): Number of epochs to train. Defaults to 20.
            lr (float): Learning rate. Defaults to 1e-4.
            reg (float): Weight decay. Defaults to 1e-5.
            weighted_sample (bool): Equally sample from all outcome classes.
                Defaults to False.
            log_data (bool): Log to tensorboard. Defaults to False.
            early_stopping (bool): Stop training if the validation loss does
                not improve after 5 epochs. Early stopping will not trigger
                until epoch 50. Defaults to False.
            subtyping (bool): Whether this is a subtyping problem.
                Defaults to False.
            seed (int): Set the random seed. Defaults to 1.
            n_classes (int): Number of outcome classes. Defaults to None.
            micro_average (bool): Use micro averaging when calculating AUROC.
            **kwargs: All additional keyword arguments are passed to
                :class:`slideflow.mil.CLAMModelConfig`.

        """
        for argname, argval in dict(locals()).items():
            # Skip 'self' so the config does not store a self-reference.
            if argname not in ('self', 'kwargs'):
                setattr(self, argname, argval)
        self.model_config = CLAMModelConfig(**kwargs)

    def _to_clam_args(self):
        """Convert into CLAM_Args format (legacy support)."""
        from .legacy import CLAM_Args
        all_kw = self.to_dict()
        all_kw.update(self.model_config.to_dict())
        all_kw['model_type'] = all_kw['model']
        all_kw['drop_out'] = all_kw['dropout']
        del all_kw['model']
        del all_kw['dropout']
        del all_kw['model_kwargs']
        return CLAM_Args(**all_kw)

    def train(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset],
        outcomes: Union[str, List[str]],
        bags: Union[str, List[str]],
        *,
        outdir: str = 'mil',
        exp_label: Optional[str] = None,
        **kwargs
    ):
        from .legacy.trainer import train_clam

        # Prepare the output directory.
        outdir = self.prepare_training(outcomes, exp_label, outdir)

        # Use training data as validation if no validation set is provided.
        if val_dataset is None:
            sf.log.info(
                "Training without validation; metrics will be calculated "
                "on training data."
            )
            val_dataset = train_dataset

        return train_clam(
            self,
            train_dataset,
            val_dataset,
            outcomes,
            bags,
            outdir=outdir,
            **kwargs
        )

    def _verify_eval_params(self, **kwargs):
        """Verify evaluation parameters."""
        super()._verify_eval_params(**kwargs)
        if kwargs.get('aggregation_level') == 'patient':
            raise ValueError(
                "Cannot aggregate bags by patient using the legacy "
                "CLAM trainer."
            )
        if kwargs.get('uq'):
            raise ValueError(
                "Cannot calculate uncertainty quantification using the "
                "legacy CLAM trainer."
            )
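
# Example usage for ``LegacyCLAMTrainerConfig`` (a minimal sketch; the
# project path, outcome label, bags directory, and tile sizes are
# illustrative assumptions). Unrecognized keyword arguments such as ``B``
# are forwarded to :class:`CLAMModelConfig`:
#
#     >>> import slideflow as sf
#     >>> P = sf.Project('/path/to/project')
#     >>> dataset = P.dataset(tile_px=299, tile_um=302)
#     >>> config = LegacyCLAMTrainerConfig(max_epochs=20, lr=1e-4, B=8)
#     >>> config.train(
#     ...     dataset,               # training dataset
#     ...     None,                  # no validation; metrics use training data
#     ...     outcomes='tumor_type',
#     ...     bags='/path/to/bags',
#     ...     outdir='mil'
#     ... )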