
# Source code for slideflow.gan.interpolate

"""Tool to assist with embedding interpolation for a class-conditional GAN."""

from typing import (Generator, List, Optional, Tuple, Union,
                    TYPE_CHECKING, Any, Iterable)
import json
import warnings
from functools import partial
from os.path import join, dirname, exists

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import slideflow as sf
from slideflow.gan.utils import crop, noise_tensor
from slideflow import errors

if TYPE_CHECKING:
    import torch
    import tensorflow as tf

class StyleGAN2Interpolator:
    """Class and embedding interpolation for a class-conditional StyleGAN2.

    Loads an embedding-aware generator from a network pickle and stores the
    class embeddings of a starting and an ending class. It provides helpers
    to generate, crop/resize and preprocess images along the interpolation
    path, and optionally to run them through a classifier.
    """

    def __init__(
        self,
        gan_pkl: str,
        start: int,
        end: int,
        *,
        device: Optional["torch.device"] = None,
        target_um: Optional[int] = None,
        target_px: Optional[int] = None,
        gan_um: Optional[int] = None,
        gan_px: Optional[int] = None,
        noise_mode: str = 'const',
        truncation_psi: float = 1,
        **gan_kwargs
    ) -> None:
        """Coordinates class and embedding interpolation for a trained
        class-conditional StyleGAN2.

        Args:
            gan_pkl (str): Path to saved network pkl.
            start (int): Starting class index.
            end (int): Ending class index.

        Keyword Args:
            device (torch.device, optional): Torch device. If None, will
                automatically select a GPU if available. Defaults to None.
            target_um (int, optional): Target size in microns for the
                interpolated images. GAN output will be cropped/resized to
                match this target. If None, will match GAN output.
                Defaults to None.
            target_px (int, optional): Target size in pixels for the
                interpolated images. GAN output will be cropped/resized to
                match this target. If None, will match GAN output.
                Defaults to None.
            gan_um (int, optional): Size in microns of the GAN output. If None,
                will attempt to auto-detect from training_options.json.
                Defaults to None.
            gan_px (int, optional): Size in pixels of the GAN output. If None,
                will attempt to auto-detect from training_options.json.
                Defaults to None.
            noise_mode (str, optional): Noise mode for GAN. Defaults to 'const'.
            truncation_psi (float, optional): Truncation psi for GAN.
                Defaults to 1.
            **gan_kwargs: Additional keyword arguments for GAN inference.

        Raises:
            ValueError: If ``gan_px`` or ``gan_um`` is not provided and cannot
                be auto-detected from ``training_options.json``.
        """
        from slideflow.model.torch_utils import get_device
        from slideflow.gan.stylegan2.stylegan2 import embedding

        gan_px, gan_um = self._resolve_gan_tile_size(gan_pkl, gan_px, gan_um)

        # Default the output size to the native GAN tile size (no resize).
        if target_px is None:
            target_px = gan_px
        if target_um is None:
            target_um = gan_um
        if device is None:
            device = get_device()

        self.E_G, self.G = embedding.load_embedding_gan(gan_pkl, device)
        self.device = device
        self.gan_kwargs = dict(
            noise_mode=noise_mode,
            truncation_psi=truncation_psi,
            **gan_kwargs)
        self.embeddings = embedding.get_embeddings(self.G, device=device)
        self.embed0 = self.embeddings[start]
        self.embed1 = self.embeddings[end]
        self.features = None  # type: Optional[sf.model.Features]
        self.normalizer = None
        self.target_px = target_px
        self.crop_kw = dict(
            gan_um=gan_um,
            gan_px=gan_px,
            target_um=target_um,
        )
        # Until a classifier is configured, follow slideflow's active backend.
        self._classifier_backend = sf.backend()

    @staticmethod
    def _resolve_gan_tile_size(
        gan_pkl: str,
        gan_px: Optional[int],
        gan_um: Optional[int],
    ) -> Tuple[int, int]:
        """Determine the GAN output tile size as ``(pixels, microns)``.

        Explicitly provided values take precedence. Any value left as None
        is read from the ``slideflow_kwargs`` section of the
        ``training_options.json`` file that sits next to ``gan_pkl``.

        Args:
            gan_pkl (str): Path to saved network pkl.
            gan_px (int, optional): User-provided GAN output size in pixels.
            gan_um (int, optional): User-provided GAN output size in microns.

        Returns:
            Tuple[int, int]: Resolved ``(gan_px, gan_um)``.

        Raises:
            ValueError: If either size is not provided and cannot be
                auto-detected.
        """
        training_options = join(dirname(gan_pkl), 'training_options.json')
        if exists(training_options):
            with open(training_options, 'r') as f:
                opt = json.load(f)
            if 'slideflow_kwargs' in opt:
                _gan_px = opt['slideflow_kwargs']['tile_px']
                _gan_um = opt['slideflow_kwargs']['tile_um']
                # Compare user-supplied values with the recorded ones. A None
                # argument means "auto-detect" and is not a mismatch.
                px_mismatch = gan_px is not None and gan_px != _gan_px
                um_mismatch = gan_um is not None and gan_um != _gan_um
                if px_mismatch or um_mismatch:
                    sf.log.warning(
                        "Provided GAN tile size (gan_px={}, gan_um={}) does "
                        "not match training_options.json (gan_px={}, "
                        "gan_um={})".format(gan_px, gan_um, _gan_px, _gan_um))
                if gan_px is None:
                    gan_px = _gan_px
                if gan_um is None:
                    gan_um = _gan_um
        if gan_px is None or gan_um is None:
            raise ValueError("Unable to auto-detect gan_px/gan_um from "
                             "training_options.json. Must be set with gan_um "
                             "and gan_px.")
        return gan_px, gan_um

    def _crop_and_convert_to_uint8(self, img: "torch.Tensor") -> Any:
        """Convert a batch of GAN images to a resized/cropped uint8 tensor.

        Args:
            img (torch.Tensor): Raw GAN output images (torch.float32).

        Returns:
            Any: Cropped and resized GAN images as uint8, in the tensor type
            of the active classifier backend.

        Raises:
            UnrecognizedBackendError: If the classifier backend is neither
                'tensorflow' nor 'torch'.
        """
        import torch
        import slideflow.io.torch  # makes sf.io.torch available

        if self._classifier_backend == 'tensorflow':
            import tensorflow as tf
            dtype = tf.uint8
        elif self._classifier_backend == 'torch':
            dtype = torch.uint8
        else:
            raise errors.UnrecognizedBackendError
        img = crop(img, **self.crop_kw)  # type: ignore
        # Map GAN output to the uint8 range, clamping out-of-range values.
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        img = sf.io.torch.preprocess_uint8(
            img, standardize=False, resize_px=self.target_px)
        return sf.io.convert_dtype(img, dtype)

    def _preprocess_from_uint8(
        self,
        img: Any,
        normalize: bool,
        standardize: bool,
    ) -> Any:
        """Convert and resize a batch of uint8 tensors to
        standardized/normalized tensors ready for input to the
        classifier/feature model.

        Args:
            img (Any): GAN images (uint8).
            normalize (bool): Apply the classifier's stain normalizer, if set.
            standardize (bool): Standardize the images.

        Returns:
            Any: Resized GAN images (uint8, or float32 if standardize=True).

        Raises:
            UnrecognizedBackendError: If the classifier backend is neither
                'tensorflow' nor 'torch'.
        """
        normalizer = self.normalizer if normalize else None
        if self._classifier_backend == 'tensorflow':
            import slideflow.io.tensorflow  # makes sf.io.tensorflow available
            return sf.io.tensorflow.preprocess_uint8(
                img,
                normalizer=normalizer,
                standardize=standardize)['tile_image']
        elif self._classifier_backend == 'torch':
            import slideflow.io.torch  # makes sf.io.torch available
            return sf.io.torch.preprocess_uint8(
                img,
                normalizer=normalizer,
                standardize=standardize)
        else:
            raise errors.UnrecognizedBackendError

    def _standardize(self, img: Any) -> Any:
        """Standardize image from uint8 to float.

        Args:
            img (Any): uint8 image tensor.

        Returns:
            Any: Standardized float image tensor.

        Raises:
            UnrecognizedBackendError: If the classifier backend is neither
                'tensorflow' nor 'torch'.
        """
        if self._classifier_backend == 'tensorflow':
            import tensorflow as tf
            return sf.io.convert_dtype(img, tf.float32)
        elif self._classifier_backend == 'torch':
            import torch
            return sf.io.convert_dtype(img, torch.float32)
        else:
            raise errors.UnrecognizedBackendError

    def _build_gan_dataset(self, generator) -> Iterable:
        """Build a dataset from a given GAN generator.

        Args:
            generator (Generator): Python generator which yields cropped (but
                not resized) uint8 tensors.

        Returns:
            Iterable: Iterable dataset which yields processed (resized and
            normalized) images.

        Raises:
            UnrecognizedBackendError: If the classifier backend is neither
                'tensorflow' nor 'torch'.
        """
        if self._classifier_backend == 'tensorflow':
            import tensorflow as tf
            import slideflow.io.tensorflow  # makes sf.io.tensorflow available
            sig = tf.TensorSpec(
                shape=(None, self.target_px, self.target_px, 3),
                dtype=tf.uint8)
            dts = tf.data.Dataset.from_generator(
                generator, output_signature=sig)
            return dts.map(
                partial(sf.io.tensorflow.preprocess_uint8,
                        normalizer=self.normalizer),
                num_parallel_calls=tf.data.AUTOTUNE,
                deterministic=True
            )
        elif self._classifier_backend == 'torch':
            import slideflow.io.torch  # makes sf.io.torch available
            return map(
                partial(sf.io.torch.preprocess_uint8,
                        normalizer=self.normalizer),
                generator())
        else:
            raise errors.UnrecognizedBackendError

    def z(self, seed: Union[int, List[int]]) -> "torch.Tensor":
        """Returns a noise tensor for a given seed.

        Args:
            seed (int or list(int)): Seed, or list of seeds. A list produces
                one noise vector per seed, stacked along a new first
                dimension.

        Returns:
            torch.Tensor: Noise tensor for the corresponding seed(s), on
            ``self.device``.

        Raises:
            ValueError: If ``seed`` is neither an int nor a list.
        """
        import torch

        if isinstance(seed, int):
            return noise_tensor(seed, self.E_G.z_dim).to(self.device)  # type: ignore
        elif isinstance(seed, list):
            return torch.stack(
                [noise_tensor(s, self.E_G.z_dim).to(self.device)
                 for s in seed],
                dim=0)
        else:
            raise ValueError(f"Unrecognized seed: {seed}")

    def set_feature_model(self, *args, **kwargs):
        """Deprecated alias for :meth:`set_classifier`."""
        warnings.warn(
            "StyleGAN2Interpolator.set_feature_model() is deprecated. "
            "Please use .set_classifier() instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return self.set_classifier(*args, **kwargs)

    def set_classifier(
        self,
        path: str,
        layers: Optional[Union[str, List[str]]] = None,
        **kwargs
    ) -> None:
        """Configures a classifier model to be used for generating features
        and predictions during interpolation.

        The classifier's stain normalizer is stored on the interpolator.
        Image preprocessing then switches to the classifier's backend.

        Args:
            path (str): Path to trained model.
            layers (Union[str, List[str]], optional): Layers from which to
                calculate activations for interpolated images.
                Defaults to None.
            **kwargs: Forwarded to the backend ``Features`` constructor.

        Raises:
            ValueError: If the model backend cannot be determined from
                ``path``.
        """
        if sf.util.is_tensorflow_model_path(path):
            from slideflow.model.tensorflow import Features
            import slideflow.io.tensorflow
            self.features = Features(
                path,
                layers=layers,
                include_preds=True,
                **kwargs)
            self.normalizer = self.features.wsi_normalizer  # type: ignore
            self._classifier_backend = 'tensorflow'
        elif sf.util.is_torch_model_path(path):
            from slideflow.model.torch import Features
            import slideflow.io.torch
            self.features = Features(
                path,
                layers=layers,
                include_preds=True,
                **kwargs)
            self.normalizer = self.features.wsi_normalizer  # type: ignore
            self._classifier_backend = 'torch'
        else:
            raise ValueError(f"Unrecognized backend for model {path}")

    def plot_comparison(
        self,
        seeds: Union[int, List[int]],
        titles: Optional[List[str]] = None
    ) -> None:
        """Plots side-by-side comparison of images from the starting and
        ending interpolation classes.

        Args:
            seeds (int or list(int)): Seeds to display, one row per seed.
            titles (list(str), optional): Column titles for the start and
                end images. Defaults to ['Start', 'End'].

        Raises:
            ValueError: If ``titles`` does not contain exactly two entries.
        """
        import matplotlib.pyplot as plt

        if not isinstance(seeds, list):
            seeds = [seeds]
        if titles is None:
            titles = ['Start', 'End']
        if len(titles) != 2:
            raise ValueError(
                f"Expected exactly two titles, got {len(titles)}.")

        def _process_to_pil(_img):
            # Raw GAN output -> cropped/resized uint8 -> first image as PIL.
            _img = self._crop_and_convert_to_uint8(_img)
            _img = self._preprocess_from_uint8(
                _img, standardize=False, normalize=False)
            if self._classifier_backend == 'torch':
                # Torch images are channels-first; PIL needs channels-last.
                _img = sf.io.torch.cwh_to_whc(_img)
            return Image.fromarray(sf.io.convert_dtype(_img[0], np.uint8))

        scale = 5
        fig, ax = plt.subplots(len(seeds), 2,
                               figsize=(2 * scale, len(seeds) * scale))
        for s, seed in enumerate(seeds):
            img0 = _process_to_pil(self.generate_start(seed))
            img1 = _process_to_pil(self.generate_end(seed))

            # plt.subplots returns a 1-D axes array when there is one row.
            if len(seeds) == 1:
                _ax0 = ax[0]
                _ax1 = ax[1]
            else:
                _ax0 = ax[s, 0]
                _ax1 = ax[s, 1]
            if s == 0:
                _ax0.set_title(titles[0])
                _ax1.set_title(titles[1])
            _ax0.imshow(img0)
            _ax1.imshow(img1)
            _ax0.axis('off')
            _ax1.axis('off')

        fig.subplots_adjust(wspace=0.05, hspace=0)

    def generate(
        self,
        seed: Union[int, List[int]],
        embedding: "torch.Tensor"
    ) -> "torch.Tensor":
        """Generate an image from a given embedding.

        If one of the noise batch or the embedding batch has size 1 and the
        other is larger, the size-1 input is repeated to match.

        Args:
            seed (int or list(int)): Seed(s) for noise vector.
            embedding (torch.Tensor): Class embedding.

        Returns:
            torch.Tensor: Image (float32, shape=(batch, 3, height, width)).
        """
        z = self.z(seed)
        if z.shape[0] == 1 and embedding.shape[0] > 1:
            z = z.repeat(embedding.shape[0], 1)
        elif z.shape[0] > 1 and embedding.shape[0] == 1:
            embedding = embedding.repeat(z.shape[0], 1)
        return self.E_G(z, embedding, **self.gan_kwargs)

    def generate_start(self, seed: int) -> "torch.Tensor":
        """Generate an image from the starting class.

        Args:
            seed (int): Seed for noise vector.

        Returns:
            torch.Tensor: Image (float32, shape=(1, 3, height, width)).
        """
        return self.generate(seed, self.embed0)

    def generate_end(self, seed: int) -> "torch.Tensor":
        """Generate an image from the ending class.

        Args:
            seed (int): Seed for noise vector.

        Returns:
            torch.Tensor: Image (float32, shape=(1, 3, height, width)).
        """
        return self.generate(seed, self.embed1)

    def generate_np_from_embedding(
        self,
        seed: int,
        embedding: "torch.Tensor"
    ) -> np.ndarray:
        """Generate a numpy image from a given embedding.

        Args:
            seed (int): Seed for noise vector.
            embedding (torch.Tensor): Class embedding.

        Returns:
            np.ndarray: Image (uint8, shape=(height, width, 3)).
        """
        import torch

        img = self.generate(seed, embedding)
        # First image of the batch, scaled to uint8 and made channels-last.
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0]
        img = img.permute(1, 2, 0)
        return sf.io.convert_dtype(img, np.uint8)

    def generate_np_start(self, seed: int) -> np.ndarray:
        """Generate a numpy image from the starting class.

        Args:
            seed (int): Seed for noise vector.

        Returns:
            np.ndarray: Image (uint8, shape=(height, width, 3)).
        """
        return self.generate_np_from_embedding(seed, self.embed0)

    def generate_np_end(self, seed: int) -> np.ndarray:
        """Generate a numpy image from the ending class.

        Args:
            seed (int): Seed for noise vector.

        Returns:
            np.ndarray: Image (uint8, shape=(height, width, 3)).
        """
        return self.generate_np_from_embedding(seed, self.embed1)

    def generate_tf_from_embedding(
        self,
        seed: Union[int, List[int]],
        embedding: "torch.Tensor"
    ) -> Tuple["tf.Tensor", "tf.Tensor"]:
        """Create a processed image from the GAN output for a given seed
        and embedding, in the classifier backend's tensor type.

        Args:
            seed (int or list(int)): Seed(s) for noise vector.
            embedding (torch.Tensor): Class embedding.

        Returns:
            A tuple containing

                tf.Tensor: Unprocessed resized image, uint8.

                tf.Tensor: Processed resized image, standardized and
                normalized.

            The batch dimension is dropped when a single seed and a single
            embedding are given.
        """
        gan_out = self.generate(seed, embedding)
        gan_out = self._crop_and_convert_to_uint8(gan_out)
        gan_out = self._preprocess_from_uint8(
            gan_out, standardize=False, normalize=True)
        standardized = self._standardize(gan_out)
        is_batch = (isinstance(seed, list)
                    or (len(embedding.shape) > 1 and embedding.shape[0] > 1))
        if is_batch:
            return gan_out, standardized
        else:
            return gan_out[0], standardized[0]

    def generate_tf_start(self, seed: int) -> Tuple["tf.Tensor", "tf.Tensor"]:
        """Create a processed image from the GAN output of a given seed and
        the starting class embedding.

        Args:
            seed (int): Seed for noise vector.

        Returns:
            A tuple containing

                tf.Tensor: Unprocessed image, uint8.

                tf.Tensor: Processed image, standardized and normalized.
        """
        return self.generate_tf_from_embedding(seed, self.embed0)

    def generate_tf_end(self, seed: int) -> Tuple["tf.Tensor", "tf.Tensor"]:
        """Create a processed image from the GAN output of a given seed and
        the ending class embedding.

        Args:
            seed (int): Seed for noise vector.

        Returns:
            A tuple containing

                tf.Tensor: Unprocessed resized image, uint8.

                tf.Tensor: Processed resized image, standardized and
                normalized.
        """
        return self.generate_tf_from_embedding(seed, self.embed1)

    def class_interpolate(self, seed: int, steps: int = 100) -> Generator:
        """Sets up a generator that returns images during class embedding
        interpolation.

        Args:
            seed (int): Seed for random noise vector.
            steps (int, optional): Number of steps for interpolation.
                Defaults to 100.

        Returns:
            Generator: Generator which yields images during interpolation.

        Yields:
            np.ndarray: Images in channels-last (height, width, 3) layout
            with values in [0, 255] (as consumed by
            ``interpolate_and_predict``; presumably uint8, confirm against
            ``embedding.class_interpolate``).
        """
        from slideflow.gan.stylegan2.stylegan2 import embedding

        return embedding.class_interpolate(
            self.E_G,
            self.z(seed),
            self.embed0,
            self.embed1,
            device=self.device,
            steps=steps,
            **self.gan_kwargs
        )

    def linear_interpolate(self, seed: int, steps: int = 100) -> Generator:
        """Sets up a generator that returns images during linear label
        interpolation.

        Args:
            seed (int): Seed for random noise vector.
            steps (int, optional): Number of steps for interpolation.
                Defaults to 100.

        Returns:
            Generator: Generator which yields images during interpolation.
            The image format is determined by
            ``embedding.linear_interpolate`` (not visible here).
        """
        from slideflow.gan.stylegan2.stylegan2 import embedding

        return embedding.linear_interpolate(
            self.G,
            self.z(seed),
            device=self.device,
            steps=steps,
            **self.gan_kwargs
        )

    def interpolate_and_predict(
        self,
        seed: int,
        steps: int = 100,
        outcome_idx: int = 0,
    ) -> Tuple[List, ...]:
        """Interpolates between starting and ending classes for a seed,
        recording raw images, processed images, and predictions.

        Predictions are only collected when a classifier has been configured
        with :meth:`set_classifier`. The prediction trace (possibly empty) is
        drawn on the current matplotlib axes.

        Args:
            seed (int): Seed for random noise vector.
            steps (int, optional): Number of steps during interpolation.
                Defaults to 100.
            outcome_idx (int, optional): Index of the prediction output to
                record at each step. Defaults to 0.

        Returns:
            Tuple[List, ...]: Raw images, processed images, and predictions.

        Raises:
            ValueError: If ``seed`` is not an integer.
        """
        if not isinstance(seed, int):
            raise ValueError("Seed must be an integer.")

        import torch
        import matplotlib.pyplot as plt
        import seaborn as sns

        imgs = []
        proc_imgs = []
        preds = []

        for img in tqdm(self.class_interpolate(seed, steps),
                        total=steps,
                        desc=f"Working on seed {seed}..."):
            # Channels-last [0, 255] array -> channels-first [-1, 1] batch,
            # matching raw GAN output so the usual crop/resize path applies.
            img = torch.from_numpy(
                np.expand_dims(img, axis=0)).permute(0, 3, 1, 2)
            img = (img / 127.5) - 1
            img = self._crop_and_convert_to_uint8(img)
            img = self._preprocess_from_uint8(
                img, standardize=False, normalize=True)
            processed_img = self._standardize(img)
            img = sf.io.convert_dtype(img, np.float32)[0]

            if self.features is not None:
                pred = self.features(processed_img)[-1]
                if self._classifier_backend == 'torch':
                    pred = pred.cpu()
                pred = pred.numpy()
                preds += [pred[0][outcome_idx]]
            imgs += [img]
            proc_imgs += [processed_img[0]]

        sns.lineplot(x=range(len(preds)), y=preds, label=f"Seed {seed}")
        plt.axhline(y=0, color='black', linestyle='--')
        plt.title("Prediction during interpolation")
        plt.xlabel("Interpolation Step (Start -> End)")
        plt.ylabel("Prediction")

        return imgs, proc_imgs, preds