Tutorial 8: Multiple-Instance Learning

In contrast with tutorials 1-4, which focused on training and evaluating traditional tile-based models, this tutorial provides an example of training a multiple-instance learning (MIL) model. MIL models are particularly useful for heterogeneous tumors, where only parts of a whole-slide image may carry a distinctive histological signature. In this tutorial, we’ll train an MIL model to predict the estrogen receptor (ER) status of breast cancer patients from whole-slide images. Note: MIL models require PyTorch.

We’ll start the same way as Tutorial 1: Model training (simple), loading a project and preparing a dataset.

>>> import slideflow as sf
>>> P = sf.load_project('/home/er_project')
>>> dataset = P.dataset(
...   tile_px=256,
...   tile_um=128,
...   filters={
...     'er_status_by_ihc': ['Positive', 'Negative']
... })

If tiles have not yet been extracted for this dataset, do that now.

>>> dataset.extract_tiles(qc='otsu')

Once a dataset has been prepared, the next step in training an MIL model is converting images into features. For this example, we’ll use the pretrained HistoSSL feature extractor, a vision transformer pretrained on 40 million histology images. HistoSSL was trained on tiles of size 224x224, so our images will be center-cropped to match.

>>> from slideflow.model import build_feature_extractor
>>> histossl = build_feature_extractor('histossl', tile_px=256)

This model is developed and licensed by Owkin, Inc. The license for use is
provided in the LICENSE file in the same directory as this source file
(slideflow/model/extractors/histossl/LICENSE), and is also available
at https://github.com/owkin/HistoSSLscaling. By using this feature extractor,
you agree to the terms of the license.
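
The center-cropping mentioned above (256 x 256 tiles reduced to the 224 x 224 input size) can be illustrated with a short NumPy sketch. This is not Slideflow's own preprocessing code; the helper name center_crop and the random stand-in tile are assumptions made purely for illustration.

```python
import numpy as np


def center_crop(tile: np.ndarray, size: int) -> np.ndarray:
    """Return the central size x size window of an (H, W, C) image array."""
    height, width = tile.shape[:2]
    if size > height or size > width:
        raise ValueError("crop size exceeds tile dimensions")
    top = (height - size) // 2
    left = (width - size) // 2
    return tile[top:top + size, left:left + size]


# Stand-in for one 256 x 256 RGB tile (random pixels, illustration only).
rng = np.random.default_rng(0)
tile = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

cropped = center_crop(tile, 224)
print(cropped.shape)  # (224, 224, 3)
```

With a 256-pixel tile and a 224-pixel target, 16 pixels are trimmed from each edge, so the extractor sees only the central region of every tile.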

>>> histossl.cite()

    author       = {Alexandre Filiot and Ridouane Ghermi and Antoine Olivier and Paul Jacob and Lucas Fidon and Alice Mac Kain and Charlie Saillard and Jean-Baptiste Schiratti},
    title        = {Scaling Self-Supervised Learning for Histopathology with Masked Image Modeling},
    elocation-id = {2023.07.21.23292757},
    year         = {2023},
    doi          = {10.1101/2023.07.21.23292757},
    publisher    = {Cold Spring Harbor Laboratory Press},
    url          = {https://www.medrxiv.org/content/early/2023/07/26/2023.07.21.23292757},
    eprint       = {https://www.medrxiv.org/content/early/2023/07/26/2023.07.21.23292757.full.pdf},
    journal      = {medRxiv}
>>> histossl.num_features
768

The HistoSSL feature extractor produces a 768-dimensional vector for each tile. We can generate and export bags of these features for all slides in our dataset using slideflow.Project.generate_feature_bags().

>>> P.generate_feature_bags(
...     histossl,
...     dataset,
...     outdir='/bags/path'
... )

The output directory, /bags/path, should look like:

├── slide1.pt
├── slide1.index.npz
├── slide2.pt
├── slide2.index.npz
├── ...
└── bags_config.json

The *.pt files contain the feature vectors for tiles in each slide, and the *.index.npz files contain the corresponding X, Y coordinates for each tile. The bags_config.json file contains the feature extractor configuration.
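
To make the bag layout more concrete, here is a minimal sketch of reading tile coordinates back from an index file. It writes a stand-in slide1.index.npz to a temporary directory rather than using real Slideflow output, and because the tutorial does not state the array name stored inside real index files, the code discovers it from the archive instead of assuming one. The *.pt files would presumably be read with torch.load(), which is also an assumption and is left out so the sketch needs only NumPy.

```python
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    index_path = Path(tmp) / "slide1.index.npz"

    # Stand-in data: (x, y) coordinates for four tiles. Real index files
    # are written by generate_feature_bags(), not by hand like this.
    coords = np.array([[0, 0], [256, 0], [0, 256], [256, 256]])
    np.savez(index_path, coords)

    with np.load(index_path) as index:
        array_names = index.files             # names of arrays in the archive
        locations = index[array_names[0]]     # read the first (only) array

print(array_names)
print(locations.shape)  # one (x, y) row per tile
```

Row i of the coordinate array is meant to correspond to row i of the feature matrix for the same slide, which is what allows per-tile values to be mapped back onto the slide.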

The next step is to create an MIL model configuration using slideflow.mil.mil_config(), specifying the architecture and relevant hyperparameters. For the architecture, we’ll use slideflow.mil.models.Attention_MIL. For the hyperparameters, we’ll use a learning rate of 1e-4, a batch size of 32, 1cycle learning rate scheduling, and train for 10 epochs.

>>> from slideflow.mil import mil_config
>>> config = mil_config(
...     model='attention_mil',
...     lr=1e-4,
...     batch_size=32,
...     epochs=10,
...     fit_one_cycle=True
... )
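
For readers unfamiliar with 1cycle scheduling, the sketch below computes a cosine-annealed one-cycle learning rate curve that peaks at the configured lr of 1e-4. The warm-up fraction and the two divisors (pct_start, div, div_final) are assumed values chosen for illustration; they are not read from Slideflow, and the schedule actually applied during training may differ.

```python
import math


def one_cycle_lr(step, total_steps, lr_max=1e-4,
                 pct_start=0.25, div=25.0, div_final=1e5):
    """Learning rate at `step` for a cosine-annealed one-cycle schedule."""

    def cos_anneal(start, end, frac):
        # Moves smoothly from `start` (frac=0) to `end` (frac=1).
        return end + (start - end) / 2 * (1 + math.cos(math.pi * frac))

    warmup_steps = pct_start * total_steps
    if step < warmup_steps:
        # Warm-up phase: rise from lr_max / div to lr_max.
        return cos_anneal(lr_max / div, lr_max, step / warmup_steps)
    # Annealing phase: decay from lr_max to lr_max / div_final.
    frac = (step - warmup_steps) / (total_steps - warmup_steps)
    return cos_anneal(lr_max, lr_max / div_final, frac)


total = 100
schedule = [one_cycle_lr(s, total) for s in range(total + 1)]
peak_step = max(range(len(schedule)), key=schedule.__getitem__)
print(peak_step, schedule[0], schedule[peak_step], schedule[-1])
```

The learning rate starts low, climbs to its maximum a quarter of the way through training, then decays towards zero for the remaining steps.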

Finally, we can train the model using slideflow.mil.train_mil(). We’ll split our dataset into 70% training and 30% validation, training to the outcome “er_status_by_ihc” and saving the model to /model/path.

>>> from slideflow.mil import train_mil
>>> train, val = dataset.split(labels='er_status_by_ihc', val_fraction=0.3)
>>> train_mil(
...     config,
...     train_dataset=train,
...     val_dataset=val,
...     outcomes='er_status_by_ihc',
...     bags='/bags/path',
...     outdir='/model/path'
... )
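
The idea behind a 70/30 split that takes labels into account can be sketched in plain Python. The function below is a hypothetical helper, not dataset.split() itself; it assumes the goal is to hold out the same fraction of slides from each outcome category, which the tutorial does not spell out. The slide names and the 70/30 class balance are made up.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction, seed=0):
    """Split slide names into (train, val), holding out val_fraction per category."""
    by_category = defaultdict(list)
    for slide, category in labels.items():
        by_category[category].append(slide)

    rng = random.Random(seed)
    train, val = [], []
    for category in sorted(by_category):
        slides = sorted(by_category[category])
        rng.shuffle(slides)
        n_val = round(len(slides) * val_fraction)
        val.extend(slides[:n_val])
        train.extend(slides[n_val:])
    return train, val


# Made-up cohort: 70 ER-positive and 30 ER-negative slides.
labels = {f"pos_{i}": "Positive" for i in range(70)}
labels.update({f"neg_{i}": "Negative" for i in range(30)})

train, val = stratified_split(labels, val_fraction=0.3)
print(len(train), len(val))  # 70 30
```

Splitting within each category keeps the positive-to-negative ratio the same in both partitions, which makes validation metrics easier to compare with training behavior.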

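During training, you’ll see the training/validation loss and validation AUROC for each epoch. At the end of training, you’ll see the validation metrics for each outcome. The example log below appears to come from a different run, so its input size (in=1024) and outcome name (brs_class) do not match the values used in this tutorial.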

[18:51:01] INFO     Training FastAI MIL model with config:
           INFO     TrainerConfigFastAI(
[18:51:02] INFO     Training dataset: 272 merged bags (from 272 possible slides)
           INFO     Validation dataset: 116 merged bags (from 116 possible slides)
[18:51:04] INFO     Training model Attention_MIL (in=1024, out=2, loss=CrossEntropyLoss)
epoch     train_loss  valid_loss  roc_auc_score  time
0         0.328032    0.285096    0.580233       00:01
Better model found at epoch 0 with valid_loss value: 0.2850962281227112.
1         0.319219    0.266496    0.733721       00:01
Better model found at epoch 1 with valid_loss value: 0.266496479511261.
2         0.293969    0.230561    0.859690       00:01
Better model found at epoch 2 with valid_loss value: 0.23056122660636902.
3         0.266627    0.190546    0.927519       00:01
Better model found at epoch 3 with valid_loss value: 0.1905461698770523.
4         0.236985    0.165320    0.939147       00:01
Better model found at epoch 4 with valid_loss value: 0.16532012820243835.
5         0.215019    0.153572    0.946512       00:01
Better model found at epoch 5 with valid_loss value: 0.153572216629982.
6         0.199093    0.144464    0.948837       00:01
Better model found at epoch 6 with valid_loss value: 0.1444639265537262.
7         0.185597    0.141776    0.952326       00:01
Better model found at epoch 7 with valid_loss value: 0.14177580177783966.
8         0.173794    0.141409    0.951938       00:01
Better model found at epoch 8 with valid_loss value: 0.14140936732292175.
9         0.167547    0.140791    0.952713       00:01
Better model found at epoch 9 with valid_loss value: 0.14079126715660095.
[18:51:18] INFO     Predictions saved to {...}/predictions.parquet
           INFO     Validation metrics for outcome brs_class:
[18:51:18] INFO     slide-level AUC (cat # 0): 0.953 AP: 0.984 (opt. threshold: 0.544)
           INFO     slide-level AUC (cat # 1): 0.953 AP: 0.874 (opt. threshold: 0.458)
           INFO     Category 0 acc: 88.4% (76/86)
           INFO     Category 1 acc: 83.3% (25/30)
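
As a rough guide to reading these metrics, the sketch below computes AUROC and per-category accuracy from scratch with NumPy. It uses made-up labels and scores and a fixed 0.5 threshold; it is not the code Slideflow uses to produce the log, and the function names are hypothetical.

```python
import numpy as np


def auroc(y_true, scores):
    """AUROC as the probability that a positive outscores a negative (ties count half)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def per_category_accuracy(y_true, y_pred):
    """Fraction of correctly predicted slides within each true category."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return {int(c): float((y_pred[y_true == c] == c).mean())
            for c in np.unique(y_true)}


# Made-up example: true categories and predicted probability of category 1.
y_true = [0, 0, 0, 1, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.3, 0.9]
y_pred = [int(s >= 0.5) for s in scores]

auc = auroc(y_true, scores)
acc = per_category_accuracy(y_true, y_pred)
print(round(auc, 3), acc)
```

The per-category accuracies in the log follow the same arithmetic: 76/86 is about 88.4% and 25/30 is about 83.3%.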

After training has completed, the output directory, /model/path, should look like:

├── attention
│   ├── slide1_att.npz
│   └── ...
├── models
│   └── best_valid.pth
├── history.csv
├── mil_params.json
├── predictions.parquet
└── slide_manifest.csv

The weights of the best-performing model (the epoch with the lowest validation loss) are saved in models/best_valid.pth. Validation dataset predictions are saved in the predictions.parquet file. A manifest of training/validation data is saved in the slide_manifest.csv file, and training history is saved in the history.csv file. Attention values for all tiles in each slide are saved in the attention/ directory.
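
To give a sense of where per-tile attention values come from, here is a generic attention-pooling sketch in NumPy. It is not the Attention_MIL implementation; the hidden size of 64, the random weights, and the random 5-tile bag are assumptions for illustration only.

```python
import numpy as np


def attention_pool(bag, V, w):
    """Pool an (n_tiles, n_features) bag into one vector using softmax attention."""
    scores = np.tanh(bag @ V) @ w            # one raw score per tile
    scores = scores - scores.max()           # subtract max for numerical stability
    weights = np.exp(scores)
    attention = weights / weights.sum()      # softmax over tiles
    slide_vector = attention @ bag           # attention-weighted average of tiles
    return slide_vector, attention


rng = np.random.default_rng(0)
bag = rng.normal(size=(5, 768))              # stand-in bag: 5 tiles, 768 features
V = rng.normal(scale=0.05, size=(768, 64))   # untrained projection weights
w = rng.normal(size=64)                      # untrained scoring vector

slide_vector, attention = attention_pool(bag, V, w)
print(slide_vector.shape, attention.shape)   # (768,) (5,)
```

Each slide yields one attention weight per tile, and the weights sum to 1, so higher values mark tiles that contribute more to the slide-level representation.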

The final saved model can be used for evaluation (slideflow.mil.eval_mil) or inference (slideflow.mil.predict_slide or Slideflow Studio). The saved model path should be referenced by the parent directory (in this case, “/model/path”) rather than the model file itself. For more information on MIL models, see Multiple-Instance Learning (MIL).
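
If you only have the path to the weights file, the parent directory expected by the evaluation and inference functions can be derived with pathlib. The helper below is hypothetical (it is not part of Slideflow) and assumes the models/ layout shown in the directory listing above.

```python
from pathlib import PurePosixPath


def model_dir_from_weights(weights_path):
    """Return the training output directory that contains models/<weights file>."""
    path = PurePosixPath(weights_path)
    if path.parent.name != "models":
        raise ValueError(f"expected weights inside a 'models' folder: {path}")
    return path.parent.parent


model_dir = model_dir_from_weights("/model/path/models/best_valid.pth")
print(model_dir)  # /model/path
```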