Shortcuts

Source code for slideflow.mil.train

"""Training functions for various multi-instance learning (MIL) models."""

import os
import numpy as np
import slideflow as sf
import pandas as pd
from os.path import join, exists
from typing import Union, List, Optional, TYPE_CHECKING
from slideflow import Dataset, log
from slideflow.util import path_to_name
from os.path import join, isdir

from .. import utils
from ..eval import predict_mil, predict_multimodal_mil, generate_attention_heatmaps
from .._params import TrainerConfig

if TYPE_CHECKING:
    from fastai.learner import Learner


# -----------------------------------------------------------------------------

def train_mil(
    config: TrainerConfig,
    train_dataset: Dataset,
    val_dataset: Optional[Dataset],
    outcomes: Union[str, List[str]],
    bags: Union[str, List[str]],
    *,
    outdir: str = 'mil',
    exp_label: Optional[str] = None,
    **kwargs
) -> "Learner":
    r"""Train a multiple-instance learning (MIL) model.

    This high-level trainer facilitates training from a given MIL
    configuration, using Datasets as input and with input features taken from
    a given directory of bags. It validates the configuration type and then
    delegates to ``config.train()``, forwarding all arguments unchanged.

    Args:
        config (:class:`slideflow.mil.TrainerConfig`): Trainer and model
            configuration.
        train_dataset (:class:`slideflow.Dataset`): Training dataset.
        val_dataset (:class:`slideflow.Dataset`): Validation dataset.
        outcomes (str or list(str)): Outcome column(s) (annotation header)
            from which to derive category labels.
        bags (str or list(str)): Either a path to directory with \*.pt files,
            or a list of paths to individual \*.pt files. Each file should
            contain exported feature vectors, with each file containing all
            tile features for one patient.

    Keyword args:
        outdir (str): Directory in which to save model and results.
        exp_label (str): Experiment label, used for naming the subdirectory
            in the ``{project root}/mil`` folder, where training history
            and the model will be saved.
        attention_heatmaps (bool): Generate attention heatmaps for slides.
            Not available for multi-modal MIL models. Defaults to False.
        interpolation (str, optional): Interpolation strategy for smoothing
            attention heatmaps. Defaults to 'bicubic'.
        cmap (str, optional): Matplotlib colormap for heatmap. Can be any
            valid matplotlib colormap. Defaults to 'inferno'.
        norm (str, optional): Normalization strategy for assigning heatmap
            values to colors. Either 'two_slope', or any other valid value
            for the ``norm`` argument of ``matplotlib.pyplot.imshow``.
            If 'two_slope', normalizes values less than 0 and greater than 0
            separately. Defaults to None.

    Returns:
        The value returned by ``config.train()`` (annotated as a
        ``fastai.learner.Learner``).

    Raises:
        ValueError: If ``config`` is not a ``TrainerConfig`` instance.
    """
    if not isinstance(config, TrainerConfig):
        raise ValueError(f"Unrecognized training configuration of type {type(config)}")
    return config.train(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        outcomes=outcomes,
        bags=bags,
        outdir=outdir,
        exp_label=exp_label,
        **kwargs
    )
# -----------------------------------------------------------------------------
def build_fastai_learner(
    config: TrainerConfig,
    train_dataset: Dataset,
    val_dataset: Dataset,
    outcomes: Union[str, List[str]],
    bags: Union[str, np.ndarray, List[str]],
    *,
    outdir: str = 'mil',
    return_shape: bool = False,
    **kwargs
) -> "Learner":
    r"""Build a FastAI Learner for training an MIL model.

    Does not execute training. Useful for customizing a Learner object
    prior to training.

    Args:
        config (:class:`slideflow.mil.TrainerConfig`): Trainer and model
            configuration. ``config.aggregation_level`` must be either
            ``'slide'`` or ``'patient'``.
        train_dataset (:class:`slideflow.Dataset`): Training dataset.
        val_dataset (:class:`slideflow.Dataset`): Validation dataset. May be
            the same object as ``train_dataset``.
        outcomes (str or list(str)): Outcome column(s) (annotation header)
            from which to derive category labels.
        bags (str, list(str), or np.ndarray): Either a path to a directory of
            \*.pt files (or a list whose first entry is such a directory),
            in which case bags are looked up via ``Dataset.get_bags()``; or a
            list/array of paths to individual \*.pt files. Each file should
            contain exported feature vectors, with each file containing all
            tile features for one patient.

    Keyword args:
        outdir (str): Directory in which to save model and results. When
            truthy, a ``slide_manifest.csv`` is logged here.
        return_shape (bool): Return the input and output shapes of the model.
            Defaults to False.
        **kwargs: Additional keyword arguments passed through to the FastAI
            learner builder (``_fastai.build_learner``), e.g. ``device``.

    Returns:
        fastai.learner.Learner, and optionally a tuple of input and output
        shapes if ``return_shape=True``.

    Raises:
        ValueError: If ``bags`` is an empty list, or if
            ``config.aggregation_level`` is not ``'slide'`` or ``'patient'``.
    """
    from . import _fastai

    # Fail fast on an unsupported aggregation level; otherwise the targets
    # and train/val indices below would never be assigned.
    if config.aggregation_level not in ('slide', 'patient'):
        raise ValueError(
            "Unrecognized aggregation level {!r}; expected 'slide' or "
            "'patient'.".format(config.aggregation_level)
        )
    if isinstance(bags, list) and not bags:
        raise ValueError("No bags provided; expected at least one bag path.")

    labels, unique = utils.get_labels((train_dataset, val_dataset), outcomes, config.is_classification())

    # Prepare bags. A string, or a list whose first entry is a directory, is
    # treated as bag directories to search; otherwise ``bags`` is already a
    # collection of individual bag file paths.
    if isinstance(bags, str) or (isinstance(bags, list) and isdir(bags[0])):
        train_bags = train_dataset.get_bags(bags)
        if val_dataset is train_dataset:
            # Avoid loading (and duplicating) the same bags twice.
            bags = train_bags
        else:
            val_bags = val_dataset.get_bags(bags)
            bags = np.concatenate((train_bags, val_bags))
    else:
        bags = np.array(bags)

    train_slides = train_dataset.slides()
    val_slides = val_dataset.slides()
    manifest_path = join(outdir, 'slide_manifest.csv') if outdir else None

    if config.aggregation_level == 'slide':
        # Aggregate feature bags across slides.
        bags, targets, train_idx, val_idx = utils.aggregate_trainval_bags_by_slide(
            bags,  # type: ignore
            labels,
            train_slides,
            val_slides,
            log_manifest=manifest_path
        )
    else:  # 'patient' (validated above)
        # Associate patients and their slides.
        # This is a dictionary where each key is a slide name and each value
        # is a patient code. Multiple slides can match to the same patient.
        slide_to_patient = {
            **train_dataset.patients(),
            **val_dataset.patients()
        }
        # Aggregate feature bags across patients.
        n_slide_bags = len(bags)
        bags, targets, train_idx, val_idx = utils.aggregate_trainval_bags_by_patient(
            bags,  # type: ignore
            labels,
            train_slides,
            val_slides,
            slide_to_patient=slide_to_patient,
            log_manifest=manifest_path
        )
        log.info(f"Aggregated {n_slide_bags} slide bags to {len(bags)} patient bags.")

    log.info("Training dataset: {} merged bags (from {} possible slides)".format(
        len(train_idx), len(train_slides)))
    log.info("Validation dataset: {} merged bags (from {} possible slides)".format(
        len(val_idx), len(val_slides)))

    # Build FastAI Learner
    learner, (n_in, n_out) = _fastai.build_learner(
        config,
        bags=bags,
        targets=targets,
        train_idx=train_idx,
        val_idx=val_idx,
        unique_categories=unique,
        outdir=outdir,
        **kwargs
    )
    if return_shape:
        return learner, (n_in, n_out)
    return learner
def build_multimodal_learner(
    config: TrainerConfig,
    train_dataset: Dataset,
    val_dataset: Dataset,
    outcomes: Union[str, List[str]],
    bags: Union[np.ndarray, List[str]],
    *,
    outdir: str = 'mil',
    return_shape: bool = False,
    **kwargs
) -> "Learner":
    r"""Build a multi-magnification FastAI Learner for training an MIL model.

    Does not execute training. Useful for customizing a Learner object
    prior to training. Each multi-modal bag corresponds to one slide;
    patient-level aggregation is not performed.

    Args:
        config (:class:`slideflow.mil.TrainerConfig`): Trainer and model
            configuration.
        train_dataset (:class:`slideflow.Dataset`): Training dataset.
        val_dataset (:class:`slideflow.Dataset`): Validation dataset.
        outcomes (str or list(str)): Outcome column(s) (annotation header)
            from which to derive category labels.
        bags (list(str)): List of bag directories containing \*.pt files,
            one directory for each mode.

    Keyword args:
        outdir (str): Directory in which to save model and results. When
            truthy, a ``slide_manifest.csv`` is written here.
        return_shape (bool): Return the input and output shapes of the model.
            Defaults to False.
        **kwargs: Additional keyword arguments passed through to the FastAI
            learner builder (``_fastai.build_learner``).

    Returns:
        fastai.learner.Learner, and optionally a tuple of input and output
        shapes if ``return_shape=True``.

    Raises:
        ValueError: If ``bags`` is not a non-empty list/tuple of existing
            directory paths.
    """
    from . import _fastai

    # Verify bags are in the correct format: one existing directory per mode.
    if not isinstance(bags, (tuple, list)):
        raise ValueError("Expected bags to be a list of paths, got {}".format(type(bags)))
    if not bags:
        raise ValueError("Expected bags to be a non-empty list of paths, one per mode.")
    invalid = [b for b in bags if not (isinstance(b, str) and isdir(b))]
    if invalid:
        raise ValueError(
            "Expected bags to be a list of paths to existing directories; "
            "invalid entries: {}".format(invalid)
        )
    num_modes = len(bags)

    # Prepare labels and slides
    labels, unique = utils.get_labels((train_dataset, val_dataset), outcomes, config.is_classification())

    # --- Prepare bags --------------------------------------------------------
    train_bags, train_slides = utils._get_nested_bags(train_dataset, bags)
    val_bags, val_slides = utils._get_nested_bags(val_dataset, bags)

    # --- Process bags and targets for training -------------------------------
    # Note: we are skipping patient-level bag aggregation for now.
    # TODO: implement patient-level bag aggregation for multi-modal MIL.

    # Concatenate training and validation bags.
    all_bags = np.concatenate((train_bags, val_bags))  # shape: (num_slides, num_modes)
    # Internal invariant: one multi-modal bag per (train + val) slide.
    assert all_bags.shape[0] == len(train_slides) + len(val_slides)
    all_slides = train_slides + val_slides
    targets = np.array([labels[s] for s in all_slides])
    train_idx = np.arange(len(train_slides))
    val_idx = np.arange(len(train_slides), len(all_slides))

    # Write the slide manifest
    if outdir:
        sf.util.log_manifest(
            train_slides,
            val_slides,
            labels=labels,
            filename=join(outdir, 'slide_manifest.csv'),
            remove_extension=False
        )

    # Print a multi-modal dataset summary.
    log.info(
        "[bold]Multi-modal MIL training summary:[/]"
        + "\n - [blue]Modes[/]: {}".format(num_modes)
        + "\n - [blue]Slides with bags[/]: {}".format(len(np.unique(all_slides)))
        + "\n - [blue]Multi-modal bags[/]: {}".format(all_bags.shape[0])
        + "\n - [blue]Unique categories[/]: {}".format(len(unique))
        + "\n - [blue]Training multi-modal bags[/]: {}".format(len(train_idx))
        + "\n - [blue]Training slides[/]: {}".format(len(np.unique(train_slides)))
        + "\n - [blue]Validation multi-modal bags[/]: {}".format(len(val_idx))
        + "\n - [blue]Validation slides[/]: {}".format(len(np.unique(val_slides)))
    )

    # Print a detailed summary of each mode. This is purely informational, so
    # a missing or malformed bags_config.json (unreadable file, missing keys)
    # falls back to a minimal summary instead of aborting learner construction.
    for i, mode in enumerate(bags):
        try:
            bags_config = sf.util.load_json(join(mode, 'bags_config.json'))
            mode_summary = (
                f"[bold]Mode {i+1}[/]: [green]{mode}[/]"
                + "\n - Feature extractor: [purple]{}[/]".format(bags_config['extractor']['class'].split('.')[-1])
                + "\n - Tile size (px): {}".format(bags_config['tile_px'])
                + "\n - Tile size (um): {}".format(bags_config['tile_um'])
                + "\n - Normalizer: {}".format(bags_config['normalizer'])
            )
        except Exception:
            # Use the same 1-based mode numbering as the detailed summary.
            mode_summary = f"Mode {i+1}: " + "\n - Bags: {}".format(mode)
        log.info(mode_summary)

    # --- Build FastAI Learner ------------------------------------------------
    learner, (n_in, n_out) = _fastai.build_learner(
        config,
        all_bags,
        targets,
        train_idx,
        val_idx,
        unique_categories=unique,
        outdir=outdir,
        **kwargs
    )
    if return_shape:
        return learner, (n_in, n_out)
    return learner
# ------------------------------------------------------------------------------
# Internal training functions.

def _get_unique_categories(learner: "Learner") -> Optional[List]:
    """Return outcome categories from the learner's training-set encoder.

    Returns None when the training dataset has no ``encoder`` attribute or
    the encoder is None; otherwise returns ``encoder.categories_[0]`` as a
    list.
    """
    encoder = getattr(learner.dls.train_ds, 'encoder', None)
    if encoder is None:
        return None
    return encoder.categories_[0].tolist()


def _train_mil(
    config: TrainerConfig,
    train_dataset: Dataset,
    val_dataset: Dataset,
    outcomes: Union[str, List[str]],
    bags: Union[str, List[str]],
    *,
    outdir: str = 'mil',
    exp_label: Optional[str] = None,
    attention_heatmaps: bool = False,
    uq: bool = False,
    device: Optional[str] = None,
    **heatmap_kwargs
) -> "Learner":
    r"""Train an MIL model using FastAI.

    Builds the learner, logs MIL parameters, trains, then generates and
    (optionally) exports validation predictions, attention, and heatmaps.

    Args:
        config (:class:`slideflow.mil.TrainerConfig`): Trainer and model
            configuration.
        train_dataset (:class:`slideflow.Dataset`): Training dataset.
        val_dataset (:class:`slideflow.Dataset`): Validation dataset.
        outcomes (str or list(str)): Outcome column(s) (annotation header)
            from which to derive category labels.
        bags (str or list(str)): Either a path to directory with \*.pt files,
            or a list of paths to individual \*.pt files. Each file should
            contain exported feature vectors, with each file containing all
            tile features for one patient.

    Keyword args:
        outdir (str): Directory in which to save model and results.
        exp_label (str): Experiment label. Accepted for signature parity with
            the multi-modal trainer (and so it is not forwarded to the
            heatmap generator); it is not used by this function.
        attention_heatmaps (bool): Generate attention heatmaps for slides.
            Defaults to False.
        uq (bool): Passed to ``predict_mil`` as ``uq``. Defaults to False.
        device (str, optional): Passed to the learner builder.
        interpolation (str, optional): Interpolation strategy for smoothing
            attention heatmaps. Defaults to 'bicubic'.
        cmap (str, optional): Matplotlib colormap for heatmap. Can be any
            valid matplotlib colormap. Defaults to 'inferno'.
        norm (str, optional): Normalization strategy for assigning heatmap
            values to colors. Either 'two_slope', or any other valid value
            for the ``norm`` argument of ``matplotlib.pyplot.imshow``.
            If 'two_slope', normalizes values less than 0 and greater than 0
            separately. Defaults to None.

    Returns:
        fastai.learner.Learner
    """
    from . import _fastai

    # Prepare validation bags. A directory (or a list whose first entry is a
    # directory) is searched via the dataset; otherwise ``bags`` is a list of
    # individual bag files, filtered to those named after a validation slide.
    if isinstance(bags, str) or (isinstance(bags, list) and bags and isdir(bags[0])):
        val_bags = val_dataset.get_bags(bags)
    else:
        val_slide_names = set(val_dataset.slides())  # computed once, O(1) lookups
        val_bags = np.array([b for b in bags if path_to_name(b) in val_slide_names])

    # Build learner.
    learner, (n_in, n_out) = build_fastai_learner(
        config,
        train_dataset,
        val_dataset,
        outcomes,
        bags=bags,
        outdir=outdir,
        device=device,
        return_shape=True
    )

    # Save MIL settings, including the unique categories (if any) read from
    # the learner's encoder.
    unique = _get_unique_categories(learner)
    _log_mil_params(config, outcomes, unique, bags, n_in, n_out, outdir)

    # Train.
    _fastai.train(learner, config)

    # Generate validation predictions.
    df, attention = predict_mil(
        learner.model,
        dataset=val_dataset,
        config=config,
        outcomes=outcomes,
        bags=val_bags,
        attention=True,
        uq=uq,
    )
    # NOTE(review): predictions are saved *before* column renaming here, but
    # *after* renaming in _train_multimodal_mil. Kept as-is to preserve the
    # on-disk format; confirm which column naming consumers expect.
    if outdir:
        pred_out = join(outdir, 'predictions.parquet')
        df.to_parquet(pred_out)
        log.info(f"Predictions saved to [green]{pred_out}[/]")

    # Print classification metrics, including per-category accuracy
    utils.rename_df_cols(df, outcomes, categorical=config.is_classification(), inplace=True)
    config.run_metrics(df, level='slide', outdir=outdir)

    # Export attention to numpy arrays
    if attention and outdir:
        utils._export_attention(
            join(outdir, 'attention'),
            attention,
            [path_to_name(b) for b in val_bags]
        )

    # Attention heatmaps.
    if attention and attention_heatmaps and outdir:
        generate_attention_heatmaps(
            outdir=join(outdir, 'heatmaps'),
            dataset=val_dataset,
            bags=val_bags,
            attention=attention,
            **heatmap_kwargs
        )

    return learner


def _train_multimodal_mil(
    config: TrainerConfig,
    train_dataset: Dataset,
    val_dataset: Optional[Dataset],
    outcomes: Union[str, List[str]],
    bags: List[str],
    *,
    outdir: str = 'mil',
    exp_label: Optional[str] = None,
    attention_heatmaps: bool = False,
) -> "Learner":
    """Train a multi-modal (e.g. multi-magnification) MIL model.

    Builds a multi-modal learner, logs MIL parameters, trains, then generates
    validation predictions and (if ``outdir`` is set) exports predictions and
    attention.

    Args:
        config (:class:`slideflow.mil.TrainerConfig`): Trainer and model
            configuration.
        train_dataset (:class:`slideflow.Dataset`): Training dataset.
        val_dataset (:class:`slideflow.Dataset`): Validation dataset.
        outcomes (str or list(str)): Outcome column(s).
        bags (list(str)): One bag directory per mode.

    Keyword args:
        outdir (str): Directory in which to save model and results.
        exp_label (str): Experiment label. Not used by this function.
        attention_heatmaps (bool): Must be False; heatmaps are not yet
            supported for multi-modal models.

    Returns:
        fastai.learner.Learner

    Raises:
        ValueError: If ``attention_heatmaps`` is True.
    """
    from . import _fastai

    # Attention heatmaps are unsupported for multi-modal models; reject the
    # request before doing any (expensive) training work.
    if attention_heatmaps:
        raise ValueError(
            "Attention heatmaps cannot yet be exported for multi-modal "
            "models. Please use Slideflow Studio for visualization of "
            "multi-modal attention."
        )

    # Build learner.
    learner, (n_in, n_out) = build_multimodal_learner(
        config,
        train_dataset,
        val_dataset,
        outcomes,
        bags=bags,
        outdir=outdir,
        return_shape=True
    )

    # Save MIL settings, including the unique categories (if any) read from
    # the learner's encoder.
    unique = _get_unique_categories(learner)
    _log_mil_params(config, outcomes, unique, bags, n_in, n_out, outdir)

    # Execute training.
    _fastai.train(learner, config)

    df, attention = predict_multimodal_mil(
        learner.model,
        dataset=val_dataset,
        config=config,
        outcomes=outcomes,
        bags=bags,
        attention=True
    )

    # Print classification metrics, including per-category accuracy
    utils.rename_df_cols(df, outcomes, categorical=config.is_classification(), inplace=True)
    config.run_metrics(df, level='slide', outdir=outdir)

    # Export predictions.
    if outdir:
        pred_out = join(outdir, 'predictions.parquet')
        df.to_parquet(pred_out)
        log.info(f"Predictions saved to [green]{pred_out}[/]")

    # Export attention.
    if attention and outdir:
        utils._export_attention(join(outdir, 'attention'), attention, df.slide.values)

    return learner

# ------------------------------------------------------------------------------

def _log_mil_params(config, outcomes, unique, bags, n_in, n_out, outdir=None):
    """Collect MIL training parameters and optionally write them to JSON.

    Starts from ``config.json_dump()`` and adds the outcomes, an index->label
    mapping (or None), the bag source(s), the model input/output shapes, and
    any ``bags_config.json`` found in the bag directory/directories.

    Args:
        config: Trainer configuration providing ``json_dump()``.
        outcomes (str or list(str)): Outcome column(s).
        unique (list or None): Ordered outcome categories, if any.
        bags (str, list, tuple, or np.ndarray): Bag directory, or collection
            of bag directories/files.
        n_in: Model input shape.
        n_out: Model output shape.
        outdir (str, optional): If truthy, parameters are written to
            ``{outdir}/mil_params.json``.

    Returns:
        dict: The collected parameters.
    """
    # Normalize array-like bag collections to plain lists so they are JSON
    # serializable and receive per-entry extractor metadata below.
    if isinstance(bags, np.ndarray):
        bags = bags.tolist()
    elif isinstance(bags, tuple):
        bags = list(bags)

    mil_params = config.json_dump()
    mil_params['outcomes'] = outcomes
    mil_params['outcome_labels'] = None if unique is None else dict(enumerate(unique))
    mil_params['bags'] = bags
    mil_params['input_shape'] = n_in
    mil_params['output_shape'] = n_out

    if isinstance(bags, str) and exists(join(bags, 'bags_config.json')):
        mil_params['bags_extractor'] = sf.util.load_json(
            join(bags, 'bags_config.json')
        )
    elif isinstance(bags, list):
        # One entry per bag source; None when it is not a directory with a
        # bags_config.json (e.g. an individual .pt file).
        mil_params['bags_extractor'] = {
            b: (
                sf.util.load_json(join(b, 'bags_config.json'))
                if isdir(b) and exists(join(b, 'bags_config.json'))
                else None
            )
            for b in bags
        }
    else:
        mil_params['bags_extractor'] = None

    if outdir:
        sf.util.write_json(mil_params, join(outdir, 'mil_params.json'))
    return mil_params