
Source code for slideflow.simclr.simclr.tf2

# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main training pipeline."""

import json
import math
import os

from tqdm import tqdm
from slideflow import log as logging
from . import data as data_lib
from . import metrics
from . import model as model_lib
from . import objective as obj_lib
from . import utils as utils_lib

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# -----------------------------------------------------------------------------

def build_saved_model(
    model,
    include_projection_head=True,
    include_supervised_head=True
):
  """Returns a tf.Module for saving to SavedModel."""

  class SimCLRModel(tf.Module):
    """Saved model for exporting to hub."""

    def __init__(self, model):
      self.model = model
      # This can't be called `trainable_variables` because `tf.Module` has
      # a getter with the same name.
      self.trainable_variables_list = model.trainable_variables

    @tf.function
    def __call__(self, inputs, trainable):
      self.model(inputs, training=trainable)
      return utils_lib.get_salient_tensors_dict(
        include_projection_head, include_supervised_head
      )

  module = SimCLRModel(model)
  input_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
  module.__call__.get_concrete_function(input_spec, trainable=True)
  module.__call__.get_concrete_function(input_spec, trainable=False)
  return module


def save(model, destination, simclr_args, global_step=None, named_by_step=False):
  """Export as SavedModel for finetuning and inference."""
  is_supervised = ((simclr_args.train_mode == 'finetune'
                    or simclr_args.lineareval_while_pretraining)
                   and simclr_args.num_classes > 0)
  saved_model = build_saved_model(model, include_supervised_head=is_supervised)
  if named_by_step:
    checkpoint_export_dir = destination + f'_step{global_step}'
  else:
    checkpoint_export_dir = destination
  if tf.io.gfile.exists(checkpoint_export_dir):
    tf.io.gfile.rmtree(checkpoint_export_dir)
  tf.saved_model.save(saved_model, checkpoint_export_dir)
  with open(os.path.join(checkpoint_export_dir, 'args.json'), "w") as data_file:
    json.dump(simclr_args.to_dict(), data_file, indent=1)


def load(path, as_pretrained: bool = False):
  """Load a SavedModel or checkpoint for inference.

  Args:
    path (str): Path to saved model.
    as_pretrained (bool): Load the model in 'pretrain' mode. Defaults to False.

  Returns:
    Tensorflow SimCLR model.
  """
  args = utils_lib.load_model_args(path)
  if as_pretrained:
    args.train_mode = 'pretrain'
  model = model_lib.SimCLR(**args.model_kwargs)
  step = tf.Variable(0, dtype=tf.int64)
  checkpoint = tf.train.Checkpoint(model=model, global_step=step)
  if path.endswith('.ckpt'):
    path = path.split('.ckpt')[0]
  checkpoint.restore(path).expect_partial()
  return model
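
# A minimal usage sketch (not part of the module): restoring a trained SimCLR
# model for inference with `load`. Paths are hypothetical, and the example
# assumes `load` is re-exported as `slideflow.simclr.load`.
#
#   from slideflow import simclr
#
#   # From a SavedModel directory exported by `save()` above:
#   model = simclr.load('/path/to/saved_model')
#
#   # Or from a training checkpoint, forcing 'pretrain' mode:
#   model = simclr.load('/path/to/ckpt-10000.ckpt', as_pretrained=True)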


def try_restore_from_checkpoint(
    model,
    global_step,
    optimizer,
    model_dir,
    checkpoint_path,
    keep_checkpoint_max=5,
    zero_init_logits_layer=False,
):
  """Restores the latest ckpt if it exists, otherwise checks checkpoint_path."""
  checkpoint = tf.train.Checkpoint(
      model=model, global_step=global_step, optimizer=optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=model_dir, max_to_keep=keep_checkpoint_max)
  latest_ckpt = checkpoint_manager.latest_checkpoint
  if latest_ckpt:
    # Restore model weights, global step, optimizer states
    logging.info('Restoring from latest checkpoint: %s', latest_ckpt)
    checkpoint_manager.checkpoint.restore(latest_ckpt).expect_partial()
  elif checkpoint_path:
    # Restore model weights only, but not global step and optimizer states
    logging.info('Restoring from given checkpoint: %s', checkpoint_path)
    checkpoint_manager2 = tf.train.CheckpointManager(
        tf.train.Checkpoint(model=model),
        directory=model_dir,
        max_to_keep=keep_checkpoint_max)
    checkpoint_manager2.checkpoint.restore(checkpoint_path).expect_partial()
    if zero_init_logits_layer:
      model = checkpoint_manager2.checkpoint.model
      output_layer_parameters = model.supervised_head.trainable_weights
      logging.info('Initializing output layer parameters %s to zero',
                   [x.op.name for x in output_layer_parameters])
      for x in output_layer_parameters:
        x.assign(tf.zeros_like(x))
  return checkpoint_manager


def checkpoint_to_saved_model(ckpt, args, dest, global_step=0):
  """Convert a training checkpoint to a SavedModel export."""
  model = model_lib.SimCLR(**args.model_kwargs)
  checkpoint = tf.train.Checkpoint(
      model=model, global_step=tf.Variable(0, dtype=tf.int64)
  )
  checkpoint.restore(ckpt).expect_partial()
  save(model, dest, args, global_step=global_step)

# -----------------------------------------------------------------------------

def perform_evaluation(
    model,
    builder,
    eval_steps,
    ckpt,
    strategy,
    model_dir,
    cache_dataset,
    args,
):
  """Perform evaluation."""
  if args.train_mode == 'pretrain' and not args.lineareval_while_pretraining:
    logging.info('Skipping eval during pretraining without linear eval.')
    return
  elif not builder.num_classes:
    logging.info('Skipping eval during pretraining; no labels supplied.')
    return

  # Build input pipeline.
  ds = data_lib.build_distributed_dataset(
      builder, args.eval_batch_size, False, args, strategy,
      cache_dataset=cache_dataset
  )
  summary_writer = tf.summary.create_file_writer(model_dir)

  # Build metrics.
  with strategy.scope():
    regularization_loss = tf.keras.metrics.Mean('eval/regularization_loss')
    label_top_1_accuracy = tf.keras.metrics.Accuracy(
        'eval/label_top_1_accuracy')
    label_top_5_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(
        5, 'eval/label_top_5_accuracy')
    all_metrics = [
        regularization_loss, label_top_1_accuracy, label_top_5_accuracy
    ]

  # Restore checkpoint.
  logging.info('Restoring from %s', ckpt)
  checkpoint = tf.train.Checkpoint(
      model=model, global_step=tf.Variable(0, dtype=tf.int64))
  checkpoint.restore(ckpt).expect_partial()
  global_step = checkpoint.global_step
  logging.info('Performing eval at step %d', global_step.numpy())

  def single_step(features, labels):
    _, supervised_head_outputs = model(features, training=False)
    assert supervised_head_outputs is not None
    outputs = supervised_head_outputs
    l = labels['labels']
    metrics.update_finetune_metrics_eval(label_top_1_accuracy,
                                         label_top_5_accuracy, outputs, l)
    reg_loss = model_lib.add_weight_decay(
        model, args.optimizer, args.weight_decay, adjust_per_optimizer=True
    )
    regularization_loss.update_state(reg_loss)

  with strategy.scope():

    @tf.function
    def run_single_step(iterator):
      images, labels = next(iterator)
      features, labels = images, {'labels': labels}
      strategy.run(single_step, (features, labels))

    iterator = iter(ds)
    for i in range(eval_steps):
      run_single_step(iterator)
      logging.info('Completed eval for %d / %d steps', i + 1, eval_steps)
    logging.info('Finished eval for %s', ckpt)

  # Write summaries
  cur_step = global_step.numpy()
  logging.info('Writing summaries for %d step', cur_step)
  with summary_writer.as_default():
    metrics.log_and_write_metrics_to_summary(all_metrics, cur_step)
    summary_writer.flush()

  # Record results as JSON.
  result_json_path = os.path.join(model_dir, 'result.json')
  result = {metric.name: metric.result().numpy() for metric in all_metrics}
  result['global_step'] = global_step.numpy()
  logging.info(result)
  with tf.io.gfile.GFile(result_json_path, 'w') as f:
    json.dump({k: float(v) for k, v in result.items()}, f)
  result_json_path = os.path.join(
      model_dir, 'result_%d.json' % result['global_step'])
  with tf.io.gfile.GFile(result_json_path, 'w') as f:
    json.dump({k: float(v) for k, v in result.items()}, f)
  flag_json_path = os.path.join(model_dir, 'args.json')
  with tf.io.gfile.GFile(flag_json_path, 'w') as f:
    serializable_flags = {}
    for key, val in vars(args).items():
      # Some flag value types, e.g. datetime.timedelta, are not json
      # serializable; filter those out.
      if utils_lib.json_serializable(val):
        serializable_flags[key] = val
    json.dump(serializable_flags, f, indent=1)

  # Export as SavedModel for finetuning and inference.
  save(
      model,
      os.path.join(model_dir, 'saved_model'),
      simclr_args=args,
      global_step=result['global_step'],
      named_by_step=True
  )

  return result


def run_simclr(
    args,
    builder=None,
    model_dir=None,
    cache_dataset=False,
    checkpoint_path=None,
    use_tpu=False,
    tpu_name=None,
    tpu_zone=None,
    gcp_project=None,
):
  """Train a SimCLR model.

  Args:
    args (SimpleNamespace): SimCLR arguments, as provided by
      :func:`slideflow.simclr.get_args`.
    builder (DatasetBuilder, optional): Builder for preparing SimCLR input
      pipelines. If None, will build using TensorflowDatasets and
      `args.dataset`.
    model_dir (str): Model directory for training.
    cache_dataset (bool): Whether to cache the entire dataset in memory. If
      the dataset is ImageNet, this is a very bad idea, but for smaller
      datasets it can improve performance.
    checkpoint_path (str): Load from the given checkpoint for fine-tuning if
      a fine-tuning checkpoint does not already exist in model_dir.
    use_tpu (bool): Whether to run on TPU.
    tpu_name (str): The Cloud TPU to use for training. This should be either
      the name used when creating the Cloud TPU, or a
      grpc://ip.address.of.tpu:8470 url.
    tpu_zone (str): GCE zone where the Cloud TPU is located. If not
      specified, we will attempt to automatically detect it from metadata.
    gcp_project (str): Project name for the Cloud TPU-enabled project. If not
      specified, we will attempt to automatically detect the GCE project
      from metadata.
  """
  logging.debug("Building SimCLR dataset")
  if builder is None:
    builder = tfds.builder(args.dataset, data_dir=args.data_dir)
    builder.download_and_prepare()
  num_train_examples = builder.info.splits[args.train_split].num_examples
  num_eval_examples = builder.info.splits[args.eval_split].num_examples
  args.num_classes = builder.info.features['label'].num_classes

  train_steps = model_lib.get_train_steps(num_train_examples, args.train_steps,
                                          args.train_epochs,
                                          args.train_batch_size)
  eval_steps = args.eval_steps or int(
      math.ceil(num_eval_examples / args.eval_batch_size))
  epoch_steps = int(round(num_train_examples / args.train_batch_size))

  logging.info(f"SimCLR Args: {json.dumps(args.to_dict(), indent=1)}")
  logging.info('# train examples: %d', num_train_examples)
  logging.info('# train_steps: %d', train_steps)
  logging.info('# eval examples: %d', num_eval_examples)
  logging.info('# eval steps: %d', eval_steps)

  checkpoint_steps = (
      args.checkpoint_steps or (args.checkpoint_epochs * epoch_steps))

  topology = None
  if use_tpu:
    logging.debug("Configuring TPUs")
    if tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          tpu_name, zone=tpu_zone, project=gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(args.master)
    tf.config.experimental_connect_to_cluster(cluster)
    topology = tf.tpu.experimental.initialize_tpu_system(cluster)
    logging.info('Topology:')
    logging.info('num_tasks: %d', topology.num_tasks)
    logging.info('num_tpus_per_task: %d', topology.num_tpus_per_task)
    strategy = tf.distribute.TPUStrategy(cluster)
  else:
    # For (multiple) GPUs.
    logging.debug("Configuring distributed dataset with MirroredStrategy")
    strategy = tf.distribute.MirroredStrategy()
    logging.info('Running using MirroredStrategy on %d replicas',
                 strategy.num_replicas_in_sync)

  with strategy.scope():
    model = model_lib.SimCLR(**args.model_kwargs)

  if args.mode == 'eval':
    logging.debug("Performing evaluation")
    for ckpt in tf.train.checkpoints_iterator(
        model_dir, min_interval_secs=15):
      result = perform_evaluation(
          model, builder, eval_steps, ckpt, strategy, model_dir,
          cache_dataset, args
      )
      if result['global_step'] >= train_steps:
        logging.info('Eval complete. Exiting...')
        return
  else:
    logging.debug("Setting up file writer for logs")
    summary_writer = tf.summary.create_file_writer(model_dir)
    if not os.path.exists(model_dir):
      os.makedirs(model_dir)
    with open(os.path.join(model_dir, 'args.json'), "w") as data_file:
      json.dump(args.to_dict(), data_file, indent=1)

    with strategy.scope():
      # Build input pipeline.
      logging.debug("Setting up distributed dataset")
      ds = data_lib.build_distributed_dataset(builder, args.train_batch_size,
                                              True, args, strategy)

      # Build LR schedule and optimizer.
      learning_rate = model_lib.WarmUpAndCosineDecay(
          learning_rate=args.learning_rate,
          num_examples=num_train_examples,
          warmup_epochs=args.warmup_epochs,
          train_batch_size=args.train_batch_size,
          learning_rate_scaling=args.learning_rate_scaling,
          train_steps=args.train_steps,
          train_epochs=args.train_epochs
      )
      optimizer = model_lib.build_optimizer(
          learning_rate=learning_rate,
          optimizer=args.optimizer,
          momentum=args.momentum,
          weight_decay=args.weight_decay
      )

      # Build metrics.
      all_metrics = []  # For summaries.
      weight_decay_metric = tf.keras.metrics.Mean('train/weight_decay')
      total_loss_metric = tf.keras.metrics.Mean('train/total_loss')
      all_metrics.extend([weight_decay_metric, total_loss_metric])
      if args.train_mode == 'pretrain':
        contrast_loss_metric = tf.keras.metrics.Mean('train/contrast_loss')
        contrast_acc_metric = tf.keras.metrics.Mean('train/contrast_acc')
        contrast_entropy_metric = tf.keras.metrics.Mean(
            'train/contrast_entropy')
        all_metrics.extend([
            contrast_loss_metric, contrast_acc_metric, contrast_entropy_metric
        ])
      if args.train_mode == 'finetune' or args.lineareval_while_pretraining:
        supervised_loss_metric = tf.keras.metrics.Mean('train/supervised_loss')
        supervised_acc_metric = tf.keras.metrics.Mean('train/supervised_acc')
        all_metrics.extend([supervised_loss_metric, supervised_acc_metric])

      # Restore checkpoint if available.
      logging.debug("Attempting to restore from checkpoint")
      checkpoint_manager = try_restore_from_checkpoint(
          model, optimizer.iterations, optimizer, model_dir, checkpoint_path,
          keep_checkpoint_max=args.keep_checkpoint_max,
          zero_init_logits_layer=args.zero_init_logits_layer)

    steps_per_loop = min(checkpoint_steps, train_steps)

    def single_step(features, labels):
      with tf.GradientTape() as tape:
        # Log summaries on the last step of the training loop to match
        # logging frequency of other scalar summaries.
        #
        # Notes:
        # 1. Summary ops on TPUs get outside compiled so they do not affect
        #    performance.
        # 2. Summaries are recorded only on replica 0. So effectively this
        #    summary would be written once per host when should_record == True.
        # 3. optimizer.iterations is incremented in the call to
        #    apply_gradients. So we use `iterations + 1` here so that the step
        #    number matches those of scalar summaries.
        # 4. We intentionally run the summary op before the actual model
        #    training so that it can run in parallel.
        should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop,
                                 0)
        with tf.summary.record_if(should_record):
          # Only log augmented images for the first tower.
          tf.summary.image(
              'image', features[:, :, :, :3], step=optimizer.iterations + 1)

        projection_head_outputs, supervised_head_outputs = model(
            features, training=True)
        loss = None
        if projection_head_outputs is not None:
          outputs = projection_head_outputs
          con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
              outputs,
              hidden_norm=args.hidden_norm,
              temperature=args.temperature,
              strategy=strategy)
          if loss is None:
            loss = con_loss
          else:
            loss += con_loss
          metrics.update_pretrain_metrics_train(contrast_loss_metric,
                                                contrast_acc_metric,
                                                contrast_entropy_metric,
                                                con_loss, logits_con,
                                                labels_con)
        if supervised_head_outputs is not None:
          outputs = supervised_head_outputs
          l = labels['labels']
          if (args.train_mode == 'pretrain'
              and args.lineareval_while_pretraining
              and args.num_classes):
            l = tf.concat([l, l], 0)
          sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
          if loss is None:
            loss = sup_loss
          else:
            loss += sup_loss
          metrics.update_finetune_metrics_train(supervised_loss_metric,
                                                supervised_acc_metric,
                                                sup_loss, l, outputs)
        weight_decay = model_lib.add_weight_decay(
            model, args.optimizer, args.weight_decay, adjust_per_optimizer=True
        )
        weight_decay_metric.update_state(weight_decay)
        loss += weight_decay
        total_loss_metric.update_state(loss)
        # The default behavior of `apply_gradients` is to sum gradients from
        # all replicas so we divide the loss by the number of replicas so that
        # the mean gradient is applied.
        loss = loss / strategy.num_replicas_in_sync
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    with strategy.scope():

      @tf.function
      def train_single_step(iterator):
        # Drop the "while" prefix created by tf.while_loop which otherwise
        # gets prefixed to every variable name. This does not affect training
        # but does affect the checkpoint conversion script.
        # TODO(b/161712658): Remove this.
        with tf.name_scope(''):
          images, labels = next(iterator)
          features, labels = images, {'labels': labels}
          strategy.run(single_step, (features, labels))

      def train_multiple_steps(iterator):
        for _ in tqdm(range(steps_per_loop)):
          train_single_step(iterator)

      global_step = optimizer.iterations
      cur_step = global_step.numpy()
      iterator = iter(ds)
      logging.debug("Beginning training")
      while cur_step < train_steps:
        # Calls to tf.summary.xyz lookup the summary writer resource which is
        # set by the summary writer's context manager.
        with summary_writer.as_default():
          train_multiple_steps(iterator)
          cur_step = global_step.numpy()
          checkpoint_manager.save(cur_step)
          logging.info('Completed: %d / %d steps', cur_step, train_steps)
          metrics.log_and_write_metrics_to_summary(all_metrics, cur_step)
          tf.summary.scalar(
              'learning_rate',
              learning_rate(tf.cast(global_step, dtype=tf.float32)),
              global_step)
          summary_writer.flush()
        for metric in all_metrics:
          metric.reset_states()
      logging.info('Training complete...')

    if args.mode == 'train_then_eval':
      perform_evaluation(model, builder, eval_steps,
                         checkpoint_manager.latest_checkpoint, strategy,
                         model_dir, cache_dataset, args)
    else:
      # Export as SavedModel for finetuning and inference.
      save(
          model,
          os.path.join(model_dir, 'saved_model'),
          args,
          global_step=global_step)
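
# A minimal usage sketch (not part of the module): pretraining with
# `run_simclr`. The keyword arguments passed to `get_args` below are
# hypothetical and only illustrate the intended flow; `get_args` itself is
# referenced in the `run_simclr` docstring above.
#
#   from slideflow import simclr
#
#   args = simclr.get_args(
#       dataset='cifar10',
#       train_mode='pretrain',
#       train_batch_size=256,
#       train_epochs=100,
#   )
#   simclr.run_simclr(args, model_dir='/path/to/model_dir')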