# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model specification for SimCLR."""
import math
import tensorflow.compat.v2 as tf
from . import data_util
from . import lars_optimizer
from . import resnet
def build_optimizer(learning_rate, optimizer, momentum, weight_decay):
"""Returns the optimizer."""
if optimizer == 'momentum':
return tf.keras.optimizers.SGD(learning_rate, momentum, nesterov=True)
elif optimizer == 'adam':
return tf.keras.optimizers.Adam(learning_rate)
elif optimizer == 'lars':
return lars_optimizer.LARSOptimizer(
learning_rate,
momentum=momentum,
weight_decay=weight_decay,
exclude_from_weight_decay=[
'batch_normalization', 'bias', 'head_supervised'
])
else:
raise ValueError('Unknown optimizer {}'.format(optimizer))
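
# Example usage (a minimal sketch; the hyperparameter values below are
# illustrative, not defaults defined in this module):
#
#   optimizer = build_optimizer(
#       learning_rate=0.3, optimizer='lars', momentum=0.9, weight_decay=1e-6)
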
def add_weight_decay(model, optimizer, weight_decay, adjust_per_optimizer=True):
"""Compute weight decay."""
if adjust_per_optimizer and 'lars' in optimizer:
    # Weight decay is taken care of by the optimizer in these cases,
    # except for the supervised head, which is added here.
l2_losses = [
tf.nn.l2_loss(v)
for v in model.trainable_variables
if 'head_supervised' in v.name and 'bias' not in v.name
]
if l2_losses:
return weight_decay * tf.add_n(l2_losses)
else:
return 0
# TODO(srbs): Think of a way to avoid name-based filtering here.
l2_losses = [
tf.nn.l2_loss(v)
for v in model.trainable_weights
if 'batch_normalization' not in v.name
]
loss = weight_decay * tf.add_n(l2_losses)
return loss
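
# Example usage (a sketch; `model`, `contrastive_loss` and the flag values are
# placeholders for whatever the training loop defines):
#
#   reg_loss = add_weight_decay(model, optimizer='lars', weight_decay=1e-6)
#   total_loss = contrastive_loss + reg_loss
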
def get_train_steps(num_examples, train_steps, train_epochs, train_batch_size):
"""Determine the number of training steps."""
return train_steps or (
num_examples * train_epochs // train_batch_size + 1)
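
# Worked example (assuming the ImageNet train split of 1,281,167 examples):
# with train_steps=0 (unset), train_epochs=100 and train_batch_size=4096,
# this returns 1281167 * 100 // 4096 + 1 = 31279 steps.
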
class WarmUpAndCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applies a warmup schedule on a given learning rate decay schedule."""
def __init__(
self,
learning_rate,
num_examples,
*,
warmup_epochs=10,
train_batch_size=512,
learning_rate_scaling='linear',
train_steps=0,
train_epochs=100,
name=None
):
super(WarmUpAndCosineDecay, self).__init__()
self.base_learning_rate = learning_rate
self.num_examples = num_examples
self._name = name
self.warmup_epochs = warmup_epochs
self.train_batch_size = train_batch_size
self.learning_rate_scaling = learning_rate_scaling
self.train_steps = train_steps
self.train_epochs = train_epochs
def __call__(self, step):
with tf.name_scope(self._name or 'WarmUpAndCosineDecay'):
warmup_steps = int(
round(self.warmup_epochs * self.num_examples //
self.train_batch_size))
if self.learning_rate_scaling == 'linear':
scaled_lr = self.base_learning_rate * self.train_batch_size / 256.
elif self.learning_rate_scaling == 'sqrt':
scaled_lr = self.base_learning_rate * math.sqrt(self.train_batch_size)
else:
raise ValueError('Unknown learning rate scaling {}'.format(
self.learning_rate_scaling))
learning_rate = (
step / float(warmup_steps) * scaled_lr if warmup_steps else scaled_lr)
# Cosine decay learning rate schedule
total_steps = get_train_steps(self.num_examples, self.train_steps,
self.train_epochs, self.train_batch_size)
# TODO(srbs): Cache this object.
cosine_decay = tf.keras.experimental.CosineDecay(
scaled_lr, total_steps - warmup_steps)
learning_rate = tf.where(step < warmup_steps, learning_rate,
cosine_decay(step - warmup_steps))
return learning_rate
  def get_config(self):
    # Keys mirror the constructor arguments so the schedule can be rebuilt
    # with `from_config`.
    return {
        'learning_rate': self.base_learning_rate,
        'num_examples': self.num_examples,
        'warmup_epochs': self.warmup_epochs,
        'train_batch_size': self.train_batch_size,
        'learning_rate_scaling': self.learning_rate_scaling,
        'train_steps': self.train_steps,
        'train_epochs': self.train_epochs,
        'name': self._name,
    }
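
# Example usage (a sketch; the values are illustrative): the schedule warms up
# linearly for `warmup_epochs`, then follows a cosine decay to zero over the
# remaining training steps.
#
#   lr_schedule = WarmUpAndCosineDecay(
#       learning_rate=0.3, num_examples=1281167, warmup_epochs=10,
#       train_batch_size=4096, learning_rate_scaling='sqrt', train_epochs=100)
#   optimizer = build_optimizer(lr_schedule, 'lars', momentum=0.9,
#                               weight_decay=1e-6)
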
class LinearLayer(tf.keras.layers.Layer):
def __init__(
self,
num_classes,
use_bias=True,
use_bn=False,
name='linear_layer',
**kwargs
):
# Note: use_bias is ignored for the dense layer when use_bn=True.
# However, it is still used for batch norm.
super(LinearLayer, self).__init__(**kwargs)
self.num_classes = num_classes
self.use_bias = use_bias
self.use_bn = use_bn
self._name = name
if self.use_bn:
self.bn_relu = resnet.BatchNormRelu(relu=False, center=use_bias)
def build(self, input_shape):
# TODO(srbs): Add a new SquareDense layer.
if callable(self.num_classes):
num_classes = self.num_classes(input_shape)
else:
num_classes = self.num_classes
self.dense = tf.keras.layers.Dense(
num_classes,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
use_bias=self.use_bias and not self.use_bn)
super(LinearLayer, self).build(input_shape)
def call(self, inputs, training):
assert inputs.shape.ndims == 2, inputs.shape
inputs = self.dense(inputs)
if self.use_bn:
inputs = self.bn_relu(inputs, training=training)
return inputs
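
# Note: `num_classes` may also be a callable on the input shape; ProjectionHead
# below uses this to keep the middle projection layers as wide as their input.
# A sketch:
#
#   layer = LinearLayer(num_classes=lambda input_shape: int(input_shape[-1]),
#                       use_bias=True, use_bn=True)
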
class ProjectionHead(tf.keras.layers.Layer):
def __init__(
self,
proj_out_dim,
proj_head_mode='nonlinear',
num_proj_layers=3,
ft_proj_selector=0,
**kwargs
):
self.linear_layers = []
if proj_head_mode == 'none':
pass # directly use the output hiddens as hiddens
elif proj_head_mode == 'linear':
self.linear_layers = [
LinearLayer(
num_classes=proj_out_dim, use_bias=False, use_bn=True, name='l_0')
]
elif proj_head_mode == 'nonlinear':
for j in range(num_proj_layers):
if j != num_proj_layers - 1:
# for the middle layers, use bias and relu for the output.
self.linear_layers.append(
LinearLayer(
num_classes=lambda input_shape: int(input_shape[-1]),
use_bias=True,
use_bn=True,
name='nl_%d' % j))
else:
# for the final layer, neither bias nor relu is used.
self.linear_layers.append(
LinearLayer(
num_classes=proj_out_dim,
use_bias=False,
use_bn=True,
name='nl_%d' % j))
else:
raise ValueError('Unknown head projection mode {}'.format(
proj_head_mode))
super(ProjectionHead, self).__init__(**kwargs)
self.proj_head_mode = proj_head_mode
self.num_proj_layers = num_proj_layers
self.ft_proj_selector = ft_proj_selector
def call(self, inputs, training):
if self.proj_head_mode == 'none':
return inputs # directly use the output hiddens as hiddens
hiddens_list = [tf.identity(inputs, 'proj_head_input')]
    if self.proj_head_mode == 'linear':
      assert len(self.linear_layers) == 1, len(self.linear_layers)
      # Append rather than return here: the (projection output, finetune input)
      # pair returned at the end of this method applies to this mode too.
      hiddens_list.append(self.linear_layers[0](hiddens_list[-1], training))
elif self.proj_head_mode == 'nonlinear':
for j in range(self.num_proj_layers):
hiddens = self.linear_layers[j](hiddens_list[-1], training)
if j != self.num_proj_layers - 1:
# for the middle layers, use bias and relu for the output.
hiddens = tf.nn.relu(hiddens)
hiddens_list.append(hiddens)
else:
raise ValueError('Unknown head projection mode {}'.format(
self.proj_head_mode))
# The first element is the output of the projection head.
# The second element is the input of the finetune head.
proj_head_output = tf.identity(hiddens_list[-1], 'proj_head_output')
return proj_head_output, hiddens_list[self.ft_proj_selector]
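
# Example usage (a sketch; shapes are illustrative): for (batch, 2048)
# ResNet-50 features and the default nonlinear head, `proj_out` has shape
# (batch, proj_out_dim) and `finetune_in` is the hidden selected by
# `ft_proj_selector` (index 0 is the projection head input itself):
#
#   head = ProjectionHead(proj_out_dim=128)
#   proj_out, finetune_in = head(hiddens, training=True)
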
class SupervisedHead(tf.keras.layers.Layer):
def __init__(self, num_classes, name='head_supervised', **kwargs):
super(SupervisedHead, self).__init__(name=name, **kwargs)
self.linear_layer = LinearLayer(num_classes)
def call(self, inputs, training):
inputs = self.linear_layer(inputs, training)
inputs = tf.identity(inputs, name='logits_sup')
return inputs
class SimCLR(tf.keras.models.Model):
"""Resnet model with projection or supervised layer."""
def __init__(
self,
num_classes,
resnet_depth=50,
width_multiplier=1,
sk_ratio=0.,
se_ratio=0.,
image_size=224,
batch_norm_decay=0.9,
train_mode='pretrain',
lineareval_while_pretraining=True,
fine_tune_after_block=-1,
use_blur=True,
proj_out_dim=128,
proj_head_mode='nonlinear',
num_proj_layers=3,
ft_proj_selector=0,
**kwargs
):
super(SimCLR, self).__init__(**kwargs)
self.resnet_model = resnet.resnet(
train_mode=train_mode,
width_multiplier=width_multiplier,
resnet_depth=resnet_depth,
cifar_stem=image_size <= 32,
sk_ratio=sk_ratio,
se_ratio=se_ratio,
batch_norm_decay=batch_norm_decay,
fine_tune_after_block=fine_tune_after_block
)
self._projection_head = ProjectionHead(
proj_out_dim,
proj_head_mode=proj_head_mode,
num_proj_layers=num_proj_layers,
ft_proj_selector=ft_proj_selector
)
if ((train_mode == 'finetune' or lineareval_while_pretraining) and num_classes):
self.supervised_head = SupervisedHead(num_classes)
self.train_mode = train_mode
self.fine_tune_after_block = fine_tune_after_block
self.use_blur = use_blur
self.image_size = image_size
self.lineareval_while_pretraining = lineareval_while_pretraining
self.num_classes = num_classes
def __call__(self, inputs, training):
features = inputs
if training and self.train_mode == 'pretrain':
if self.fine_tune_after_block > -1:
        raise ValueError('Layer freezing during pretraining is not supported; '
                         'set fine_tune_after_block <= -1 for safety.')
if inputs.shape[3] is None:
raise ValueError('The input channels dimension must be statically known '
f'(got input shape {inputs.shape})')
num_transforms = inputs.shape[3] // 3
num_transforms = tf.repeat(3, num_transforms)
# Split channels, and optionally apply extra batched augmentation.
features_list = tf.split(
features, num_or_size_splits=num_transforms, axis=-1)
if self.use_blur and training and self.train_mode == 'pretrain':
features_list = data_util.batch_random_blur(features_list,
self.image_size,
self.image_size)
features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c)
# Base network forward pass.
hiddens = self.resnet_model(features, training=training)
# Add heads.
projection_head_outputs, supervised_head_inputs = self._projection_head(
hiddens, training)
if self.train_mode == 'finetune':
supervised_head_outputs = self.supervised_head(supervised_head_inputs,
training)
return None, supervised_head_outputs
elif (self.train_mode == 'pretrain'
and self.lineareval_while_pretraining
and self.num_classes):
      # When pretraining and linear evaluation run together, we do not want
      # gradients from the linear eval head flowing back into the pretraining
      # network, so we apply a stop_gradient.
supervised_head_outputs = self.supervised_head(
tf.stop_gradient(supervised_head_inputs), training)
return projection_head_outputs, supervised_head_outputs
else:
return projection_head_outputs, None
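
# Example usage (a minimal sketch; it assumes the input pipeline concatenates
# the two augmented views along the channel axis, so a pretraining batch has
# shape (bsz, image_size, image_size, 6)):
#
#   model = SimCLR(num_classes=1000, resnet_depth=50, image_size=224)
#   projection_outputs, supervised_outputs = model(images, training=True)
#   # `projection_outputs` feeds the contrastive loss; `supervised_outputs`
#   # are logits from the optional linear-eval head.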