This module contains utility functions for training a SimCLR model. Please see Self-Supervised Learning (SSL) for more information on the high-level API and recommended use.


get_args(**kwargs)

Configure a SimCLR_Args object for training SimCLR.

Keyword Arguments:

**kwargs – Please see the slideflow.simclr.SimCLR_Args documentation for information on available parameters.



load(path, as_pretrained: bool = False)

Load a SavedModel or checkpoint for inference.

Parameters:
path (str) – Path to saved model.

Returns:
TensorFlow SimCLR model.

load_model_args(model_path, ignore_missing=False)

Load the args.json file associated with a given SimCLR model or checkpoint.

Parameters:
  • model_path (str) – Path to SimCLR model or checkpoint.

  • ignore_missing (bool) – Whether to return None, rather than raise an OSError, if args.json cannot be found. Defaults to False.

Returns:
Dictionary of the contents of the args.json file, or None if the file is not found and ignore_missing is True.

Raises:
OSError – If args.json cannot be found and ignore_missing is False.
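The ignore_missing behavior of load_model_args above can be illustrated with a small, self-contained sketch. read_args_json is a hypothetical helper, not part of slideflow, and it assumes (for illustration only) that args.json sits directly inside model_path; the real function's lookup logic may differ.

```python
import json
import os
from typing import Optional


def read_args_json(model_path: str, ignore_missing: bool = False) -> Optional[dict]:
    """Hypothetical stand-in for load_model_args(); not the slideflow implementation.

    Assumes (for illustration only) that args.json sits directly in model_path.
    """
    args_path = os.path.join(model_path, "args.json")
    if not os.path.exists(args_path):
        if ignore_missing:
            # A missing file is tolerated and signalled with None.
            return None
        # Otherwise a missing file is an error, matching the documented OSError.
        raise OSError(f"Unable to find args.json at {args_path}")
    with open(args_path, "r") as f:
        return json.load(f)
```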

run_simclr(args, builder=None, model_dir=None, cache_dataset=False, checkpoint_path=None, use_tpu=False, tpu_name=None, tpu_zone=None, gcp_project=None)

Train a SimCLR model.

Parameters:
  • args (SimCLR_Args) – SimCLR arguments, as provided by slideflow.simclr.get_args().

  • builder (DatasetBuilder, optional) – Builder for preparing SimCLR input pipelines. If None, the input pipeline is built with TensorFlow Datasets using args.dataset.

  • model_dir (str) – Model directory for training.

  • cache_dataset (bool) – Whether to cache the entire dataset in memory. This can improve performance for small datasets, but should be avoided for very large datasets such as ImageNet.

  • checkpoint_path (str) – Checkpoint to load for fine-tuning. Used only if a fine-tuning checkpoint does not already exist in model_dir.

  • use_tpu (bool) – Whether to run on TPU.

  • tpu_name (str) – Cloud TPU to use for training. Either the name used when creating the Cloud TPU or a grpc://ip.address.of.tpu:8470 URL.

  • tpu_zone (str) – GCE zone in which the Cloud TPU is located. If not specified, the zone is detected automatically from metadata, if possible.

  • gcp_project (str) – Project name for the Cloud TPU-enabled project. If not specified, the GCE project is detected automatically from metadata, if possible.

class SimCLR(*args, **kwargs)

ResNet model with a projection head or supervised head.

class SimCLR_Args(learning_rate=0.075, learning_rate_scaling='sqrt', warmup_epochs=10, weight_decay=0.0001, batch_norm_decay=0.9, train_batch_size=512, train_split='train', train_epochs=100, train_steps=0, eval_steps=0, eval_batch_size=256, checkpoint_epochs=1, checkpoint_steps=0, eval_split='validation', dataset='imagenet2012', mode='train', train_mode='pretrain', lineareval_while_pretraining=True, zero_init_logits_layer=False, fine_tune_after_block=-1, master=None, data_dir=None, optimizer='lars', momentum=0.9, keep_checkpoint_max=5, temperature=0.1, hidden_norm=True, proj_head_mode='nonlinear', proj_out_dim=128, num_proj_layers=3, ft_proj_selector=0, global_bn=True, width_multiplier=1, resnet_depth=50, sk_ratio=0.0, se_ratio=0.0, image_size=224, color_jitter_strength=1.0, use_blur=True, num_classes=None, stain_augment=True)

SimCLR arguments.

A class containing all SimCLR arguments. Any argument not overridden at initialization uses its default value.

Keyword Arguments:
  • learning_rate (float) – Initial learning rate per batch size of 256.

  • learning_rate_scaling (str) – How to scale the learning rate as a function of batch size. ‘linear’ or ‘sqrt’.

  • warmup_epochs (int) – Number of epochs of warmup.

  • weight_decay (float) – Amount of weight decay to use.

  • batch_norm_decay (float) – Batch norm decay parameter.

  • train_batch_size (int) – Batch size for training.

  • train_split (str) – Split for training.

  • train_epochs (int) – Number of epochs to train for.

  • train_steps (int) – Number of steps to train for. If provided (nonzero), overrides train_epochs.

  • eval_steps (int) – Number of steps to evaluate for. If not provided (0), evaluates over the entire dataset.

  • eval_batch_size (int) – Batch size for eval.

  • checkpoint_epochs (int) – Number of epochs between checkpoints/summaries.

  • checkpoint_steps (int) – Number of steps between checkpoints/summaries. If provided, overrides checkpoint_epochs.

  • eval_split (str) – Split for evaluation.

  • dataset (str) – Name of a dataset.

  • mode (str) – Whether to perform training or evaluation: ‘train’, ‘eval’, or ‘train_then_eval’.

  • train_mode (str) – The train mode controls different objectives and trainable components: ‘pretrain’ or ‘finetune’.

  • lineareval_while_pretraining (bool) – Whether to fine-tune the supervised head while pretraining.

  • zero_init_logits_layer (bool) – If True, zero-initialize layers after avg_pool for supervised learning.

  • fine_tune_after_block (int) – Block after which layers are fine-tuned. -1 fine-tunes everything, 0 fine-tunes everything after the stem block, and 4 fine-tunes only the linear head.

  • master (str) – Address/name of the TensorFlow master to use. By default, use an in-process master.

  • data_dir (str) – Directory where dataset is stored.

  • optimizer (str) – Optimizer to use. ‘momentum’, ‘adam’, ‘lars’

  • momentum (float) – Momentum parameter.

  • keep_checkpoint_max (int) – Maximum number of checkpoints to keep.

  • temperature (float) – Temperature parameter for contrastive loss.

  • hidden_norm (bool) – Whether to L2-normalize the projection head outputs before computing the contrastive loss.

  • proj_head_mode (str) – How the head projection is done. ‘none’, ‘linear’, ‘nonlinear’

  • proj_out_dim (int) – Output dimension of the projection head.

  • num_proj_layers (int) – Number of layers in the nonlinear projection head.

  • ft_proj_selector (int) – Which layer of the projection head to use during fine-tuning. 0 means no projection head, and -1 means the final layer.

  • global_bn (bool) – Whether to aggregate BN statistics across distributed cores.

  • width_multiplier (int) – Multiplier to change width of network.

  • resnet_depth (int) – Depth of ResNet.

  • sk_ratio (float) – If greater than 0, enables Selective Kernel (SK) convolutions. Recommended value: 0.0625.

  • se_ratio (float) – If greater than 0, enables Squeeze-and-Excitation (SE) blocks.

  • image_size (int) – Input image size.

  • color_jitter_strength (float) – The strength of color jittering.

  • use_blur (bool) – Whether or not to use Gaussian blur for augmentation during pretraining.

  • num_classes (int) – Number of classes for the supervised head.
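The step-count and learning-rate arguments above interact in a way that is easiest to see in code. This plain-Python sketch assumes slideflow uses the same formulas as the reference Google Research SimCLR (TF2) implementation, whose argument names these match: a nonzero train_steps or checkpoint_steps overrides its epoch-based counterpart, learning_rate is scaled by batch size according to learning_rate_scaling, and the rate warms up linearly over warmup_epochs before following a cosine decay. compute_steps, scaled_learning_rate, and learning_rate_at are hypothetical helpers, not slideflow functions.

```python
import math


def compute_steps(num_train_examples, train_batch_size=512, train_epochs=100,
                  train_steps=0, warmup_epochs=10, checkpoint_epochs=1,
                  checkpoint_steps=0):
    """Return (total_steps, warmup_steps, checkpoint_steps) for a training run."""
    epoch_steps = int(round(num_train_examples / train_batch_size))
    # A nonzero train_steps takes precedence over train_epochs.
    total = train_steps or (num_train_examples * train_epochs // train_batch_size + 1)
    warmup = int(round(warmup_epochs * num_train_examples // train_batch_size))
    # A nonzero checkpoint_steps takes precedence over checkpoint_epochs.
    ckpt = checkpoint_steps or checkpoint_epochs * epoch_steps
    return total, warmup, ckpt


def scaled_learning_rate(learning_rate, batch_size, scaling="sqrt"):
    """Scale the base learning rate according to learning_rate_scaling."""
    if scaling == "linear":
        return learning_rate * batch_size / 256.0
    if scaling == "sqrt":
        return learning_rate * math.sqrt(batch_size)
    raise ValueError(f"Unknown learning_rate_scaling: {scaling!r}")


def learning_rate_at(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup to peak_lr, then cosine decay to 0 at total_steps."""
    if warmup_steps and step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Example: 10,000 tiles with the default SimCLR_Args values.
total, warmup, ckpt = compute_steps(10_000)
peak = scaled_learning_rate(0.075, 512, "sqrt")
print(total, warmup, ckpt, round(peak, 3))  # 1954 195 20 1.697
```

Under these assumptions, the default sqrt scaling grows the peak rate with the square root of the batch size, whereas linear scaling treats learning_rate as the rate for a batch of 256.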
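The temperature and hidden_norm arguments both enter the contrastive (NT-Xent) loss. This NumPy sketch follows the loss used in the reference Google Research SimCLR code and assumes slideflow's version behaves the same way: projections are optionally L2-normalized, pairwise similarities are divided by temperature, and each view must pick out its partner among the other 2N - 1 views in the batch. nt_xent_loss is a hypothetical helper.

```python
import numpy as np


def nt_xent_loss(hidden, temperature=0.1, hidden_norm=True):
    """NT-Xent loss for a (2N, D) array whose rows [:N] and [N:] are paired views."""
    if hidden_norm:
        # L2-normalize each projection so dot products become cosine similarities.
        hidden = hidden / np.linalg.norm(hidden, axis=1, keepdims=True)
    n = hidden.shape[0] // 2
    h1, h2 = hidden[:n], hidden[n:]
    # Exclude each view's similarity with itself by pushing it to a large negative value.
    self_mask = np.eye(n) * 1e9
    logits_aa = h1 @ h1.T / temperature - self_mask
    logits_bb = h2 @ h2.T / temperature - self_mask
    logits_ab = h1 @ h2.T / temperature
    logits_ba = h2 @ h1.T / temperature

    def cross_entropy(logits):
        # The positive for row i is column i (the cross-view block comes first).
        logits = logits - logits.max(axis=1, keepdims=True)
        log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
        return -log_probs[np.arange(n), np.arange(n)].mean()

    loss_a = cross_entropy(np.concatenate([logits_ab, logits_aa], axis=1))
    loss_b = cross_entropy(np.concatenate([logits_ba, logits_bb], axis=1))
    return loss_a + loss_b


rng = np.random.default_rng(0)
z = rng.normal(size=(4, 8))
print(nt_xent_loss(np.concatenate([z, z])))   # identical views: low loss
print(nt_xent_loss(rng.normal(size=(8, 8))))  # unrelated views: higher loss
```

A lower temperature sharpens the softmax over similarities, so hard negatives are penalized more strongly.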
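The projection-head arguments (proj_head_mode, num_proj_layers, proj_out_dim, ft_proj_selector) can be read as follows. This NumPy sketch mirrors the head structure of the reference Google Research SimCLR code, which these argument names appear to follow, but it omits batch normalization and bias terms and uses random weights, so it only illustrates output shapes and the role of ft_proj_selector. projection_head is a hypothetical helper.

```python
import numpy as np


def projection_head(features, proj_head_mode="nonlinear", num_proj_layers=3,
                    proj_out_dim=128, ft_proj_selector=0, seed=0):
    """Return (projection_output, fine_tune_input) for a (batch, dim) feature array."""
    if proj_head_mode == "none":
        # No projection head: the encoder features are used directly.
        return features, features
    rng = np.random.default_rng(seed)
    hiddens_list = [features]  # index 0 is the head input, i.e. "no projection head"
    n_layers = 1 if proj_head_mode == "linear" else num_proj_layers
    for j in range(n_layers):
        is_last = j == n_layers - 1
        in_dim = hiddens_list[-1].shape[-1]
        # Intermediate layers keep the input width; the last outputs proj_out_dim.
        out_dim = proj_out_dim if is_last else in_dim
        weights = rng.normal(scale=0.01, size=(in_dim, out_dim))
        h = hiddens_list[-1] @ weights
        if not is_last:
            h = np.maximum(h, 0.0)  # ReLU between nonlinear layers
        hiddens_list.append(h)
    # The last layer feeds the contrastive loss; ft_proj_selector picks the
    # representation handed to the supervised head (-1 means the final layer).
    return hiddens_list[-1], hiddens_list[ft_proj_selector]


feats = np.ones((2, 2048))  # e.g. pooled ResNet-50 features
out, ft_input = projection_head(feats, ft_proj_selector=1)
print(out.shape, ft_input.shape)  # (2, 128) (2, 2048)
```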
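fine_tune_after_block selects which parts of the ResNet are updated during fine-tuning. This sketch assumes the rule used in the reference Google Research SimCLR code: components are numbered 0 (the stem) through 4 (the four residual block groups), component i is trained only if fine_tune_after_block < i, the linear head is always trained, and the restriction applies only when train_mode is ‘finetune’. trainable_components and the component names are hypothetical.

```python
def trainable_components(fine_tune_after_block=-1, train_mode="finetune"):
    """Map each (hypothetically named) model component to whether it is trained."""
    names = ["stem", "block_group1", "block_group2", "block_group3", "block_group4"]
    trainable = {}
    for index, name in enumerate(names):
        # Freezing only applies when fine-tuning; pretraining updates everything.
        trainable[name] = train_mode != "finetune" or fine_tune_after_block < index
    trainable["linear_head"] = True  # the supervised head is always trained
    return trainable


# Stem frozen; all four block groups and the linear head are trained.
print(trainable_components(0))
```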

class DatasetBuilder(train_dts=None, val_dts=None, test_dts=None, *, labels=None, val_kwargs=None, steps_per_epoch_override=None, normalizer=None, normalizer_source=None, dataset_kwargs=None)

Build a training/validation dataset pipeline for SimCLR.

Parameters:
  • train_dts (sf.Dataset, optional) – Training dataset.

  • val_dts (sf.Dataset, optional) – Optional validation dataset.

  • test_dts (sf.Dataset, optional) – Optional held-out test set.

Keyword Arguments:
  • labels (str or dict) – Labels for training the supervised head. Can be a name of an outcome (str) or a dict mapping slide names to labels.

  • val_kwargs (dict, optional) – Optional keyword arguments for generating a validation dataset from train_dts via train_dts.split(). Incompatible with val_dts.

  • steps_per_epoch_override (int, optional) – Override the number of steps per epoch.

  • dataset_kwargs (dict, optional) – Keyword arguments passed to the slideflow.Dataset.tensorflow() method when creating the input pipeline.