slideflow.simclr
This module contains utility functions for training a SimCLR model. Please see Self-Supervised Learning (SSL) for more information on the high-level API and recommended use.
- get_args(**kwargs)
Configure a SimCLR_Args object for training SimCLR.
- Keyword Arguments:
**kwargs – Please see the slideflow.simclr.SimCLR_Args documentation for information on available parameters.
- Returns:
slideflow.simclr.SimCLR_Args
- load(path, as_pretrained: bool = False)
Load a SavedModel or checkpoint for inference.
- Parameters:
path (str) – Path to saved model.
- Returns:
TensorFlow SimCLR model.
- load_model_args(model_path, ignore_missing=False)
Load args.json associated with a given SimCLR model or checkpoint.
- Parameters:
model_path (str) – Path to SimCLR model or checkpoint.
- Returns:
Dictionary of contents of the args.json file. If the file is not found and ignore_missing is True, returns None. If the file is not found and ignore_missing is False, raises an OSError.
- Raises:
OSError – If args.json cannot be found and ignore_missing is False.
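The return and error behavior described above can be sketched in plain Python. This is a minimal illustration rather than the library's implementation: the function name read_args_json is hypothetical, and the assumption that args.json sits inside the model directory (or next to a checkpoint file) is ours.

```python
import json
import os
from typing import Optional


def read_args_json(model_path: str, ignore_missing: bool = False) -> Optional[dict]:
    """Hypothetical stand-in mirroring the documented lookup behavior."""
    # Assumption: args.json is stored in the model directory, or in the
    # same folder as a checkpoint file. The real search logic may differ.
    folder = model_path if os.path.isdir(model_path) else os.path.dirname(model_path)
    args_path = os.path.join(folder, "args.json")
    if not os.path.exists(args_path):
        if ignore_missing:
            # Missing file is tolerated: signal it with None.
            return None
        raise OSError(f"Unable to find args.json for model {model_path}")
    with open(args_path) as f:
        return json.load(f)
```

Callers that can proceed without saved arguments would pass ignore_missing=True and check for None; otherwise the OSError surfaces the missing file immediately.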
- run_simclr(args, builder=None, model_dir=None, cache_dataset=False, checkpoint_path=None, use_tpu=False, tpu_name=None, tpu_zone=None, gcp_project=None)
Train a SimCLR model.
- Parameters:
args (SimpleNamespace) – SimCLR arguments, as provided by slideflow.simclr.get_args().
builder (DatasetBuilder, optional) – Builder for preparing SimCLR input pipelines. If None, will build using TensorFlow Datasets and args.dataset.
model_dir (str) – Model directory for training.
cache_dataset (bool) – Whether to cache the entire dataset in memory. If the dataset is ImageNet, this is a very bad idea, but for smaller datasets it can improve performance.
checkpoint_path (str) – Checkpoint to load for fine-tuning if a fine-tuning checkpoint does not already exist in model_dir.
use_tpu (bool) – Whether to run on TPU.
tpu_name (str) – The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 URL.
tpu_zone (str) – GCE zone where the Cloud TPU is located. If not specified, we will attempt to automatically detect the GCE zone from metadata.
gcp_project (str) – Project name for the Cloud TPU-enabled project. If not specified, we will attempt to automatically detect the GCE project from metadata.
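As a rough sketch of how the functions on this page fit together, the example below configures arguments with get_args(), wraps two datasets in a DatasetBuilder, and passes both to run_simclr(). The wrapper name pretrain_simclr and all argument values are illustrative assumptions; train_dts and val_dts are assumed to be sf.Dataset objects prepared elsewhere.

```python
def pretrain_simclr(train_dts, val_dts, labels, model_dir):
    """Sketch: pretrain SimCLR on Slideflow datasets.

    train_dts and val_dts are assumed to be sf.Dataset objects, labels is
    an outcome name or a dict mapping slide names to labels, and model_dir
    is the output directory.
    """
    # Imported lazily so this function can be defined without loading TensorFlow.
    from slideflow import simclr

    # Any SimCLR_Args keyword can be overridden here; these values are
    # illustrative, not recommendations.
    args = simclr.get_args(
        mode="train",
        train_mode="pretrain",
        train_batch_size=256,
        train_epochs=100,
        image_size=299,
        checkpoint_epochs=10,
    )

    # The builder prepares the input pipelines; labels feed the supervised head.
    builder = simclr.DatasetBuilder(
        train_dts=train_dts,
        val_dts=val_dts,
        labels=labels,
    )

    simclr.run_simclr(args, builder, model_dir=model_dir)
    return args
```

The TPU-related parameters are left at their defaults here, which corresponds to training without a TPU.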
- class SimCLR_Args(learning_rate=0.075, learning_rate_scaling='sqrt', warmup_epochs=10, weight_decay=0.0001, batch_norm_decay=0.9, train_batch_size=512, train_split='train', train_epochs=100, train_steps=0, eval_steps=0, eval_batch_size=256, checkpoint_epochs=1, checkpoint_steps=0, eval_split='validation', dataset='imagenet2012', mode='train', train_mode='pretrain', lineareval_while_pretraining=True, zero_init_logits_layer=False, fine_tune_after_block=-1, master=None, data_dir=None, optimizer='lars', momentum=0.9, keep_checkpoint_max=5, temperature=0.1, hidden_norm=True, proj_head_mode='nonlinear', proj_out_dim=128, num_proj_layers=3, ft_proj_selector=0, global_bn=True, width_multiplier=1, resnet_depth=50, sk_ratio=0.0, se_ratio=0.0, image_size=224, color_jitter_strength=1.0, use_blur=True, num_classes=None, stain_augment=True)
SimCLR arguments.
A class containing the default SimCLR arguments, each of which can be overridden at initialization.
- Keyword Arguments:
learning_rate (float) – Initial learning rate per batch size of 256.
learning_rate_scaling (str) – How to scale the learning rate as a function of batch size. ‘linear’ or ‘sqrt’.
warmup_epochs (int) – Number of epochs of warmup.
weight_decay (float) – Amount of weight decay to use.
batch_norm_decay (float) – Batch norm decay parameter.
train_batch_size (int) – Batch size for training.
train_split (str) – Split for training.
train_epochs (int) – Number of epochs to train for.
train_steps (int) – Number of steps to train for. If provided, overrides train_epochs.
eval_steps (int) – Number of steps to eval for. If not provided, evals over entire dataset.
eval_batch_size (int) – Batch size for eval.
checkpoint_epochs (int) – Number of epochs between checkpoints/summaries.
checkpoint_steps (int) – Number of steps between checkpoints/summaries. If provided, overrides checkpoint_epochs.
eval_split (str) – Split for evaluation.
dataset (str) – Name of a dataset.
mode (str) – Whether to perform training or evaluation. ‘train’, ‘eval’, or ‘train_then_eval’.
train_mode (str) – The train mode controls different objectives and trainable components. ‘pretrain’ or ‘finetune’.
lineareval_while_pretraining (bool) – Whether to finetune the supervised head while pretraining.
zero_init_logits_layer (bool) – If True, zero initialize layers after avg_pool for supervised learning.
fine_tune_after_block (int) – The block after which layers are fine-tuned. -1 means fine-tuning everything, 0 means fine-tuning after the stem block, and 4 means fine-tuning just the linear head.
master (str) – Address/name of the TensorFlow master to use. By default, use an in-process master.
data_dir (str) – Directory where dataset is stored.
optimizer (str) – Optimizer to use. ‘momentum’, ‘adam’, ‘lars’
momentum (float) – Momentum parameter.
keep_checkpoint_max (int) – Maximum number of checkpoints to keep.
temperature (float) – Temperature parameter for contrastive loss.
hidden_norm (bool) – Whether to normalize (L2) the hidden vectors before computing the contrastive loss.
proj_head_mode (str) – How the head projection is done. ‘none’, ‘linear’, ‘nonlinear’
proj_out_dim (int) – Output dimension of the projection head.
num_proj_layers (int) – Number of non-linear head layers.
ft_proj_selector (int) – Which layer of the projection head to use during fine-tuning. 0 means no projection head, and -1 means the final layer.
global_bn (bool) – Whether to aggregate BN statistics across distributed cores.
width_multiplier (int) – Multiplier to change width of network.
resnet_depth (int) – Depth of ResNet.
sk_ratio (float) – If greater than 0, enables SK (selective kernels). Recommendation: 0.0625.
se_ratio (float) – If greater than 0, enables SE (squeeze-and-excitation).
image_size (int) – Input image size.
color_jitter_strength (float) – The strength of color jittering.
use_blur (bool) – Whether or not to use Gaussian blur for augmentation during pretraining.
num_classes (int) – Number of classes for the supervised head.
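To make the interaction between learning_rate, learning_rate_scaling, train_batch_size, train_epochs, and train_steps concrete, here is a small sketch. The formulas are an assumption borrowed from the reference SimCLR implementation and may not match this module exactly; both function names are hypothetical.

```python
import math


def scaled_learning_rate(base_lr: float, batch_size: int, scaling: str = "sqrt") -> float:
    """Scale the base learning rate by batch size.

    Assumption: same rules as the reference SimCLR implementation.
    """
    if scaling == "linear":
        # The base rate is defined per batch size of 256.
        return base_lr * batch_size / 256.0
    if scaling == "sqrt":
        return base_lr * math.sqrt(batch_size)
    raise ValueError(f"Unknown learning_rate_scaling: {scaling!r}")


def resolve_train_steps(num_examples: int, train_epochs: int,
                        train_batch_size: int, train_steps: int = 0) -> int:
    """Return the total number of training steps.

    A non-zero train_steps overrides train_epochs, as documented above.
    The epoch-to-step conversion is an assumption.
    """
    return train_steps or (num_examples * train_epochs // train_batch_size + 1)
```

Under these assumptions, the defaults (learning_rate=0.075, train_batch_size=512, 'sqrt' scaling) give an effective rate of 0.075 * sqrt(512), while 'linear' scaling gives 0.075 * 512 / 256.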
- class DatasetBuilder(train_dts=None, val_dts=None, test_dts=None, *, labels=None, val_kwargs=None, steps_per_epoch_override=None, normalizer=None, normalizer_source=None, dataset_kwargs=None)
Build a training/validation dataset pipeline for SimCLR.
- Parameters:
train_dts (sf.Dataset, optional) – Training dataset.
val_dts (sf.Dataset, optional) – Optional validation dataset.
test_dts (sf.Dataset, optional) – Optional held-out test set.
- Keyword Arguments:
labels (str or dict) – Labels for training the supervised head. Can be a name of an outcome (str) or a dict mapping slide names to labels.
val_kwargs (dict, optional) – Optional keyword arguments for generating a validation dataset from train_dts via train_dts.split(). Incompatible with val_dts.
steps_per_epoch_override (int, optional) – Override the number of steps per epoch.
dataset_kwargs (dict, optional) – Keyword arguments passed to the slideflow.Dataset.tensorflow() method when creating the input pipeline.
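Because val_kwargs and val_dts are documented as incompatible, it can help to validate the combination before constructing a builder. The helper below is a hypothetical sketch (builder_kwargs is not part of slideflow), and whether DatasetBuilder itself raises on this conflict is not stated here.

```python
def builder_kwargs(train_dts, val_dts=None, val_kwargs=None, **extra):
    """Hypothetical helper: assemble keyword arguments for DatasetBuilder.

    Raises ValueError when both val_dts and val_kwargs are supplied,
    mirroring the documented incompatibility.
    """
    if val_dts is not None and val_kwargs is not None:
        raise ValueError("Pass either val_dts or val_kwargs, not both.")
    kwargs = dict(train_dts=train_dts, **extra)
    if val_dts is not None:
        # Use an explicitly supplied validation dataset.
        kwargs["val_dts"] = val_dts
    if val_kwargs is not None:
        # Let the builder split train_dts using these split() arguments.
        kwargs["val_kwargs"] = val_kwargs
    return kwargs
```

The resulting dictionary can be unpacked into the constructor, for example DatasetBuilder(**builder_kwargs(dataset, val_dts=val_dataset, labels=...)), with any extra keywords such as labels or dataset_kwargs passed through unchanged.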