Generative Networks (GANs)


Slideflow includes tools for interfacing with the PyTorch implementations of StyleGAN2 and StyleGAN3, allowing you to train these Generative Adversarial Networks (GANs). Slideflow also includes tools for image generation, interpolation between class labels, and interactive visualization of GAN-generated images and their predictions. See our manuscript on the use of GANs to generate synthetic histology for an example of how these networks might be used.

Note

StyleGAN requires PyTorch <1.13 and Slideflow-NonCommercial, which can be installed with:

pip install slideflow-noncommercial

Training StyleGAN

The easiest way to train StyleGAN2/StyleGAN3 is with slideflow.Project.gan_train(). Both standard and class-conditional GANs are supported. To train a GAN, pass a slideflow.Dataset, experiment label, and StyleGAN keyword arguments to this function:

import slideflow as sf

P = sf.Project('/project/path')
dataset = P.dataset(tile_px=512, tile_um=400)

P.gan_train(
  dataset=dataset,
  model='stylegan3',
  cfg='stylegan3-r',
  exp_label="ExperimentLabel",
  gpus=4,
  batch=32,
  ...
)

The trained networks will be saved in the gan/ subfolder in the project directory.

StyleGAN2/3 can only be trained on images with sizes that are powers of 2. You can crop and/or resize images from a Dataset to match this requirement by using the crop and/or resize arguments:

dataset = P.dataset(tile_px=299, ...)

# Train a GAN on images resized to 256x256
P.gan_train(
  ...,
  resize=256,
)

See the slideflow.Project.gan_train() documentation for additional keyword arguments to customize training.
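As a quick check of the power-of-2 requirement described above, the following sketch tests a tile size and suggests a value for the resize argument. The helper names are hypothetical and are not part of Slideflow. Rounding down to the nearest power of 2 is only one reasonable choice.

```python
def is_power_of_two(n: int) -> bool:
    """Return True if n is a positive power of two (1, 2, 4, 8, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def largest_power_of_two_at_most(n: int) -> int:
    """Return the largest power of two that is <= n (n must be >= 1)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return 1 << (n.bit_length() - 1)


# Example: tiles extracted at 299 px cannot be used directly.
tile_px = 299
if not is_power_of_two(tile_px):
    resize_to = largest_power_of_two_at_most(tile_px)
    print(f"tile_px={tile_px} is not a power of 2; consider resize={resize_to}")
```

For tile_px=299 this suggests resize=256, matching the example above.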

Class conditioning

GANs can also be trained with class conditioning. To train a class-conditional GAN, provide the name of a categorical outcome label (or a list of such names) to the outcomes argument of slideflow.Project.gan_train(). For example, to train a GAN with class conditioning on ER status:

P.gan_train(
  ...,
  outcomes='er_status'
)

Tile-level labels

In addition to class conditioning with slide-level labels, StyleGAN2/StyleGAN3 can be trained with tile-level class conditioning. Tile-level labels can be generated through ROI annotations, as described in Strong Supervision with Tile Labels.

Prepare a pandas dataframe indexed with the format {slide}-{x}-{y}, where slide is the name of the slide (without extension), x is the tile x-coordinate, and y is the tile y-coordinate. The dataframe should have a single column, label, containing one-hot encoded category labels. For example:

import pandas as pd

df = pd.DataFrame(
  index=[
    'slide1-251-425',
    'slide1-560-241',
    'slide1-321-502',
    ...
  ],
  data={
    'label': [
      [1, 0, 0],
      [1, 0, 0],
      [0, 1, 0],
      ...
    ]
  }
)

This dataframe can be generated with the slideflow.Dataset.get_tile_dataframe() function, as described in Strong Supervision with Tile Labels. For GAN conditioning, the label column should be one-hot encoded.
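If your tile annotations start out as category names instead of one-hot vectors, the index keys and label lists can be built with a few lines of plain Python. This is a minimal sketch. The category names, the tile tuples, and the helper functions tile_key and one_hot are all hypothetical and are not part of Slideflow.

```python
from typing import List, Sequence


def tile_key(slide: str, x: int, y: int) -> str:
    """Build an index key in the {slide}-{x}-{y} format."""
    return f"{slide}-{x}-{y}"


def one_hot(category: str, categories: Sequence[str]) -> List[int]:
    """Return a one-hot list for `category` given the ordered `categories`."""
    if category not in categories:
        raise ValueError(f"Unknown category: {category!r}")
    return [1 if c == category else 0 for c in categories]


# Hypothetical category names and tile annotations (slide, x, y, category).
categories = ["tumor", "stroma", "necrosis"]
tiles = [
    ("slide1", 251, 425, "tumor"),
    ("slide1", 560, 241, "tumor"),
    ("slide1", 321, 502, "stroma"),
]

index = [tile_key(s, x, y) for s, x, y, _ in tiles]
labels = [one_hot(cat, categories) for *_, cat in tiles]

print(index)
print(labels)
```

The resulting index and labels lists can then be passed to pd.DataFrame(index=index, data={'label': labels}), as in the example above.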

Once the dataframe is complete, save it in parquet format:

df.to_parquet('tile_labels.parquet')

Then supply this file to the tile_labels argument of slideflow.Project.gan_train():

P.gan_train(
  ...,
  tile_labels='tile_labels.parquet'
)

Generating images

Images can be generated from a trained GAN and exported either as loose images in PNG or JPG format or as TFRecords. Images are generated from a list of seeds (list of int). Use the slideflow.Project.gan_generate() function to generate images, with out set to a directory path to export loose images, or to a filename ending in .tfrecords to save images in TFRecord format:

network_pkl = '/path/to/trained/gan.pkl'
P.gan_generate(
  network_pkl,
  out='target.tfrecords',
  seeds=range(100),
  ...
)

The image format is set with the format argument:

P.gan_generate(
  ...,
  format='jpg',
)

Class index (for class-conditional GANs) is set with class_idx:

P.gan_generate(
  ...,
  class_idx=1,
)
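When comparing classes, it can be convenient to generate images for every class from the same seeds. The sketch below only assembles the keyword arguments for each call, using the out, seeds, and class_idx arguments described above. The helper per_class_jobs and the output folder layout are hypothetical, and the call to P.gan_generate() is left as a comment because it requires a project and a trained network.

```python
from typing import Dict, Iterable, List


def per_class_jobs(n_classes: int, seeds: Iterable[int], out_root: str) -> List[Dict]:
    """Build one set of gan_generate keyword arguments per class index."""
    seed_list = list(seeds)  # the same seeds are reused for every class
    return [
        {"out": f"{out_root}/class_{idx}", "seeds": seed_list, "class_idx": idx}
        for idx in range(n_classes)
    ]


jobs = per_class_jobs(n_classes=2, seeds=range(100), out_root="generated")
for job in jobs:
    print(job["out"], job["class_idx"], len(job["seeds"]))
    # With a project P and a trained network available, each job could be run as:
    # P.gan_generate(network_pkl, **job)
```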

Finally, images can be resized after generation to match a target tile size:

P.gan_generate(
  ...,
  gan_px=512,
  gan_um=400,
  target_px=299,
  target_um=302,
)
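To see what these four values imply, the arithmetic can be worked through by hand. The sketch below assumes that matching a smaller target field of view is done by center-cropping the generated image and then resizing it. This is an assumption about the approach, not a description of Slideflow's internal implementation, and the helper crop_width_px is hypothetical.

```python
def crop_width_px(gan_px: int, gan_um: float, target_um: float) -> int:
    """Pixels of the generated image that span `target_um` microns.

    Assumes the target field of view is no larger than the generated one.
    """
    if target_um > gan_um:
        raise ValueError("target_um must not exceed gan_um")
    return round(gan_px * target_um / gan_um)


gan_px, gan_um = 512, 400
target_px, target_um = 299, 302

crop = crop_width_px(gan_px, gan_um, target_um)
print(f"Generated image resolution: {gan_um / gan_px:.4f} um/px")
print(f"Target image resolution:    {target_um / target_px:.4f} um/px")
print(f"Crop {crop} px from the {gan_px} px image, then resize to {target_px} px")
```

Under this assumption, a 512 px image covering 400 um would be cropped to roughly 387 px (302 um) before being resized to 299 px.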

Interactive visualization

Slideflow Studio can be used to interactively visualize GAN-generated images (see Slideflow Studio: Live Visualization). Images can be exported directly from this interface. This tool also lets you view real-time predictions for GAN-generated images when they are used as inputs to a trained classifier.

For more examples of using Slideflow to work with GAN-generated images, see our GitHub repository for code accompanying the previously referenced manuscript.