"""H&E stain normalization and augmentation tools."""
from __future__ import absolute_import
import os
import sys
import multiprocessing as mp
from io import BytesIO
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import cv2
import numpy as np
import slideflow as sf
from PIL import Image
from contextlib import contextmanager
from rich.progress import Progress
from slideflow import errors
from slideflow.dataset import Dataset
from slideflow.util import detuple, log, cleanup_progress, _as_list
from slideflow.norm import (augment, macenko, reinhard, vahadane)
if TYPE_CHECKING:
import tensorflow as tf
import torch
class StainNormalizer:
    """H&E stain normalizer supporting various normalization methods.

    Wraps a backend-specific normalizer implementation (``self.n``) selected
    by method name, and exposes uniform conversion/transform entry points
    for numpy arrays, JPEG/PNG bytes, Tensorflow tensors, and PyTorch
    tensors.
    """

    # This base class uses non-vectorized (numpy/OpenCV) implementations;
    # backend-native subclasses may override this flag.
    vectorized = False

    # Maps a normalization method name to its implementation class.
    # Note 'vahadane' is an alias for the SPAMS-based implementation.
    normalizers = {
        'macenko': macenko.MacenkoNormalizer,
        'macenko_fast': macenko.MacenkoFastNormalizer,
        'reinhard': reinhard.ReinhardNormalizer,
        'reinhard_fast': reinhard.ReinhardFastNormalizer,
        'reinhard_mask': reinhard.ReinhardMaskNormalizer,
        'reinhard_fast_mask': reinhard.ReinhardFastMaskNormalizer,
        'vahadane': vahadane.VahadaneSpamsNormalizer,
        'vahadane_sklearn': vahadane.VahadaneSklearnNormalizer,
        'vahadane_spams': vahadane.VahadaneSpamsNormalizer,
        'augment': augment.AugmentNormalizer
    }  # type: Dict[str, Any]
def __init__(self, method: str, **kwargs) -> None:
"""H&E Stain normalizer supporting various normalization methods.
The stain normalizer supports numpy images, PNG or JPG strings,
Tensorflow tensors, and PyTorch tensors. The default ``.transform()``
method will attempt to preserve the original image type while
minimizing conversions to and from Tensors.
Alternatively, you can manually specify the image conversion type
by using the appropriate function. For example, to convert a Tensor
to a normalized numpy RGB image, use ``.tf_to_rgb()``.
Args:
method (str): Normalization method. Options include 'macenko',
'reinhard', 'reinhard_fast', 'reinhard_mask',
'reinhard_fast_mask', 'vahadane', 'vahadane_spams',
'vahadane_sklearn', and 'augment'.
Keyword args:
stain_matrix_target (np.ndarray, optional): Set the stain matrix
target for the normalizer. May raise an error if the normalizer
does not have a stain_matrix_target fit attribute.
target_concentrations (np.ndarray, optional): Set the target
concentrations for the normalizer. May raise an error if the
normalizer does not have a target_concentrations fit attribute.
target_means (np.ndarray, optional): Set the target means for the
normalizer. May raise an error if the normalizer does not have
a target_means fit attribute.
target_stds (np.ndarray, optional): Set the target standard
deviations for the normalizer. May raise an error if the
normalizer does not have a target_stds fit attribute.
Raises:
ValueError: If the specified normalizer method is not available.
Examples
Normalize a numpy image using the default fit.
>>> import slideflow as sf
>>> macenko = sf.norm.StainNormalizer('macenko')
>>> macenko.transform(image)
Fit the normalizer to a target image (numpy or path).
>>> macenko.fit(target_image)
Fit the normalizer using a preset configuration.
>>> macenko.fit('v2')
Fit the normalizer to all images in a Dataset.
>>> dataset = sf.Dataset(...)
>>> macenko.fit(dataset)
Normalize an image and convert from Tensor to numpy array (RGB).
>>> macenko.tf_to_rgb(image)
Normalize images during DataLoader pre-processing.
>>> dataset = sf.Dataset(...)
>>> dataloader = dataset.torch(..., normalizer=macenko)
>>> dts = dataset.tensorflow(..., normalizer=macenko)
"""
if method not in self.normalizers:
raise ValueError(f"Unrecognized normalizer method {method}")
self.method = method
self.n = self.normalizers[method]()
if kwargs:
self.n.fit(**kwargs)
def __repr__(self):
base = "{}(\n".format(self.__class__.__name__)
base += " method = {!r},\n".format(self.method)
for fit_param, fit_val in self.get_fit().items():
base += " {} = {!r},\n".format(fit_param, fit_val)
base += ")"
return base
@property
def device(self) -> str:
return 'cpu'
def _torch_transform(
self,
inp: "torch.Tensor",
*,
augment: bool = False
) -> "torch.Tensor":
"""Normalize a torch uint8 image (CWH).
Normalization ocurs via intermediate conversion to WHC.
Args:
inp (torch.Tensor): Image, uint8. Images are normalized in
W x H x C space. Images provided as C x W x H will be
auto-converted and permuted back after normalization.
Returns:
torch.Tensor: Image, uint8.
"""
import torch
from slideflow.io.torch import cwh_to_whc, whc_to_cwh, is_cwh
if len(inp.shape) == 4:
return torch.stack([self._torch_transform(img) for img in inp])
elif is_cwh(inp):
# Convert from CWH -> WHC (normalize) -> CWH
return whc_to_cwh(
torch.from_numpy(
self.rgb_to_rgb(
cwh_to_whc(inp).cpu().numpy(),
augment=augment
)
)
)
else:
return torch.from_numpy(
self.rgb_to_rgb(inp.cpu().numpy(), augment=augment)
)
def _torch_augment(self, inp: "torch.Tensor") -> "torch.Tensor":
"""Augment a torch uint8 image (CWH).
Augmentation ocurs via intermediate conversion to WHC.
Args:
inp (torch.Tensor): Image, uint8. Images are normalized in
W x H x C space. Images provided as C x W x H will be
auto-converted and permuted back after normalization.
Returns:
torch.Tensor: Image, uint8.
"""
import torch
from slideflow.io.torch import cwh_to_whc, whc_to_cwh, is_cwh
if len(inp.shape) == 4:
return torch.stack([self._torch_augment(img) for img in inp])
elif is_cwh(inp):
# Convert from CWH -> WHC (normalize) -> CWH
return whc_to_cwh(
torch.from_numpy(
self.augment_rgb(cwh_to_whc(inp).cpu().numpy())
)
)
else:
return torch.from_numpy(self.augment_rgb(inp.cpu().numpy()))
def fit(
self,
arg1: Optional[Union[Dataset, np.ndarray, str]],
batch_size: int = 64,
num_threads: Union[str, int] = 'auto',
**kwargs,
) -> "StainNormalizer":
"""Fit the normalizer to a target image or dataset of images.
Args:
arg1: (Dataset, np.ndarray, str): Target to fit. May be a str,
numpy image array (uint8), path to an image, or a Slideflow
Dataset. If this is a string, will fit to the corresponding
preset fit (either 'v1', 'v2', or 'v3').
If a Dataset is provided, will average fit values across
all images in the dataset.
batch_size (int, optional): Batch size during fitting, if fitting
to dataset. Defaults to 64.
num_threads (Union[str, int], optional): Number of threads to use
during fitting, if fitting to a dataset. Defaults to 'auto'.
"""
# Fit to a dataset
if isinstance(arg1, Dataset):
# Set up thread pool
if num_threads == 'auto':
num_threads = sf.util.num_cpu(default=8) # type: ignore
log.debug(f"Setting up pool (size={num_threads}) for norm fitting")
log.debug(f"Using normalizer batch size of {batch_size}")
pool = mp.dummy.Pool(num_threads) # type: ignore
dataset = arg1
if sf.backend() == 'tensorflow':
dts = dataset.tensorflow(
None,
batch_size,
standardize=False,
infinite=False
)
elif sf.backend() == 'torch':
dts = dataset.torch(
None,
batch_size,
standardize=False,
infinite=False,
num_workers=8
)
all_fit_vals = [] # type: ignore
pb = Progress(transient=True)
task = pb.add_task('Fitting normalizer...', total=dataset.num_tiles)
pb.start()
with cleanup_progress(pb):
for img_batch, slide in dts:
if sf.model.is_torch_tensor(img_batch):
img_batch = img_batch.permute(0, 2, 3, 1) # BCWH -> BWHC
mapped = pool.imap(lambda x: self.n.fit(x.numpy()), img_batch)
for fit_vals in mapped:
if all_fit_vals == []:
all_fit_vals = [[] for _ in range(len(fit_vals))]
for v, val in enumerate(fit_vals):
all_fit_vals[v] += [np.squeeze(val)]
pb.advance(task, batch_size)
self.n.set_fit(*[np.array(v).mean(axis=0) for v in all_fit_vals])
pool.close()
# Fit to numpy image
elif isinstance(arg1, np.ndarray):
self.n.fit(arg1, **kwargs)
# Fit to a preset
elif (isinstance(arg1, str)
and arg1 in sf.norm.utils.fit_presets[self.n.preset_tag]):
self.n.fit_preset(arg1, **kwargs)
# Fit to a path to an image
elif isinstance(arg1, str):
self.src_img = cv2.cvtColor(cv2.imread(arg1), cv2.COLOR_BGR2RGB)
self.n.fit(self.src_img, **kwargs)
elif arg1 is None and kwargs:
self.set_fit(**kwargs)
else:
raise ValueError(f'Unrecognized args for fit: {arg1}')
log.debug('Fit normalizer: {}'.format(
', '.join([f"{fit_key} = {fit_val}"
for fit_key, fit_val in self.get_fit().items()])
))
return self
def get_fit(self, as_list: bool = False):
"""Get the current normalizer fit.
Args:
as_list (bool). Convert the fit values (numpy arrays) to list
format. Defaults to False.
Returns:
Dict[str, np.ndarray]: Dictionary mapping fit parameters (e.g.
'target_concentrations') to their respective fit values.
"""
_fit = self.n.get_fit()
if as_list:
return {k: _as_list(v) for k, v in _fit.items()}
else:
return _fit
def set_fit(self, **kwargs) -> None:
"""Set the normalizer fit to the given values.
Keyword args:
stain_matrix_target (np.ndarray, optional): Set the stain matrix
target for the normalizer. May raise an error if the normalizer
does not have a stain_matrix_target fit attribute.
target_concentrations (np.ndarray, optional): Set the target
concentrations for the normalizer. May raise an error if the
normalizer does not have a target_concentrations fit attribute.
target_means (np.ndarray, optional): Set the target means for the
normalizer. May raise an error if the normalizer does not have
a target_means fit attribute.
target_stds (np.ndarray, optional): Set the target standard
deviations for the normalizer. May raise an error if the
normalizer does not have a target_stds fit attribute.
"""
self.n.set_fit(**{k:v for k, v in kwargs.items() if v is not None})
def set_augment(self, preset: Optional[str] = None, **kwargs) -> None:
"""Set the normalizer augmentation space.
Args:
preset (str, optional): Augmentation preset. Defaults to None.
Keyword args:
matrix_stdev (np.ndarray): Standard deviation
of the stain matrix target. Must have the shape (3, 2).
Used for Macenko normalizers.
Defaults to None (will not augment stain matrix).
concentrations_stdev (np.ndarray): Standard deviation
of the target concentrations. Must have the shape (2,).
Used for Macenko normalizers.
Defaults to None (will not augment target concentrations).
means_stdev (np.ndarray): Standard deviation
of the target means. Must have the shape (3,).
Used for Reinhard normalizers.
Defaults to None (will not augment target means).
stds_stdev (np.ndarray): Standard deviation
of the target stds. Must have the shape (3,).
Used for Reinhard normalizers.
Defaults to None (will not augment target stds).
"""
if preset is not None:
return self.n.augment_preset(preset)
if kwargs:
self.n.set_augment(**{k:v for k, v in kwargs.items() if v is not None})
def transform(
self,
image: Union[np.ndarray, "tf.Tensor", "torch.Tensor"],
*,
augment: bool = False
) -> Union[np.ndarray, "tf.Tensor", "torch.Tensor"]:
"""Normalize a target image, attempting to preserve the original type.
Args:
image (np.ndarray, tf.Tensor, or torch.Tensor): Image as a uint8
array. Numpy and Tensorflow images are normalized in W x H x C
space. PyTorch images provided as C x W x H will be
auto-converted and permuted back after normalization.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
Normalized image of the original type (uint8).
"""
if isinstance(image, (str, bytes)):
raise ValueError("Unable to auto-transform bytes or str; please "
"use .png_to_png() or .jpeg_to_jpeg().")
if 'tensorflow' in sys.modules:
import tensorflow as tf
if isinstance(image, tf.Tensor):
return self.tf_to_tf(image, augment=augment)
if 'torch' in sys.modules:
import torch
if isinstance(image, torch.Tensor):
return self.torch_to_torch(image, augment=augment)
if isinstance(image, np.ndarray):
return self.rgb_to_rgb(image, augment=augment)
raise ValueError(f"Unrecognized image type {type(image)}; expected "
"np.ndarray, tf.Tensor, or torch.Tensor")
def augment(
self,
image: Union[np.ndarray, "tf.Tensor", "torch.Tensor"]
) -> Union[np.ndarray, "tf.Tensor", "torch.Tensor"]:
"""Augment a target image, attempting to preserve the original type.
Args:
image (np.ndarray, tf.Tensor, or torch.Tensor): Image as a uint8
array. Numpy and Tensorflow images are normalized in W x H x C
space. PyTorch images provided as C x W x H will be
auto-converted and permuted back after normalization.
Returns:
Augmented image of the original type (uint8).
"""
if not hasattr(self.n, 'augment'):
raise errors.AugmentationNotSupportedError(
f"Normalizer {self.method} does not support augmentation.")
if isinstance(image, (str, bytes)):
raise ValueError("Unable to augment bytes or str; image "
"must first be converted to an array or Tensor.")
if 'tensorflow' in sys.modules:
import tensorflow as tf
if isinstance(image, tf.Tensor):
if isinstance(image, dict):
image['tile_image'] = tf.py_function(
self.augment_rgb,
[image['tile_image']],
tf.uint8
)
elif len(image.shape) == 4:
image = tf.stack([self.augment_rgb(_i) for _i in image])
else:
image = tf.py_function(
self.augment_rgb,
[image],
tf.uint8
)
return image
if 'torch' in sys.modules:
import torch
if isinstance(image, torch.Tensor):
if isinstance(image, dict):
to_return = {
k: v for k, v in image.items()
if k != 'tile_image'
}
to_return['tile_image'] = self._torch_augment(
image['tile_image']
)
return to_return
else:
return self._torch_augment(image)
if isinstance(image, np.ndarray):
return self.augment_rgb(image)
raise ValueError(f"Unrecognized image type {type(image)}; expected "
"np.ndarray, tf.Tensor, or torch.Tensor")
def augment_rgb(self, image: np.ndarray) -> np.ndarray:
"""Augment a numpy array (uint8), returning a numpy array (uint8).
Args:
image (np.ndarray): Image (uint8).
Returns:
np.ndarray: Augmented image, uint8, W x H x C.
"""
return self.n.augment(image)
def jpeg_to_jpeg(
self,
jpeg_string: Union[str, bytes],
*,
quality: int = 100,
augment: bool = False
) -> bytes:
"""Normalize a JPEG image, returning a JPEG image.
Args:
jpeg_string (str, bytes): JPEG image data.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
quality (int, optional): Quality level for creating the resulting
normalized JPEG image. Defaults to 100.
Returns:
bytes: Normalized JPEG image.
"""
cv_image = self.jpeg_to_rgb(jpeg_string, augment=augment)
with BytesIO() as output:
Image.fromarray(cv_image).save(
output,
format="JPEG",
quality=quality
)
return output.getvalue()
def jpeg_to_rgb(
self,
jpeg_string: Union[str, bytes],
*,
augment: bool = False
) -> np.ndarray:
"""Normalize a JPEG image, returning a numpy uint8 array.
Args:
jpeg_string (str, bytes): JPEG image data.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
np.ndarray: Normalized image, uint8, W x H x C.
"""
cv_image = cv2.imdecode(
np.fromstring(jpeg_string, dtype=np.uint8),
cv2.IMREAD_COLOR
)
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
return self.rgb_to_rgb(cv_image, augment=augment)
def png_to_png(
self,
png_string: Union[str, bytes],
*,
augment: bool = False
) -> bytes:
"""Normalize a PNG image, returning a PNG image.
Args:
png_string (str, bytes): PNG image data.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
bytes: Normalized PNG image.
"""
cv_image = self.png_to_rgb(png_string, augment=augment)
with BytesIO() as output:
Image.fromarray(cv_image).save(output, format="PNG")
return output.getvalue()
def png_to_rgb(
self,
png_string: Union[str, bytes],
*,
augment: bool = False
) -> np.ndarray:
"""Normalize a PNG image, returning a numpy uint8 array.
Args:
png_string (str, bytes): PNG image data.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
np.ndarray: Normalized image, uint8, W x H x C.
"""
return self.jpeg_to_rgb(png_string, augment=augment) # It should auto-detect format
def rgb_to_rgb(
self,
image: np.ndarray,
*,
augment: bool = False
) -> np.ndarray:
"""Normalize a numpy array (uint8), returning a numpy array (uint8).
Args:
image (np.ndarray): Image (uint8).
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
np.ndarray: Normalized image, uint8, W x H x C.
"""
return self.n.transform(image, augment=augment)
def tf_to_rgb(
self,
image: "tf.Tensor",
*,
augment: bool = False
) -> np.ndarray:
"""Normalize a tf.Tensor (uint8), returning a numpy array (uint8).
Args:
image (tf.Tensor): Image (uint8).
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
np.ndarray: Normalized image, uint8, W x H x C.
"""
return self.rgb_to_rgb(np.array(image), augment=augment)
def tf_to_tf(
self,
image: Union[Dict, "tf.Tensor"],
*args: Any,
augment: bool = False
) -> Tuple[Union[Dict, "tf.Tensor"], ...]:
"""Normalize a tf.Tensor (uint8), returning a numpy array (uint8).
Args:
image (tf.Tensor, Dict): Image (uint8) either as a raw Tensor,
or a Dictionary with the image under the key 'tile_image'.
args (Any, optional): Any additional arguments, which will be passed
and returned unmodified.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
A tuple containing the normalized tf.Tensor image (uint8,
W x H x C) and any additional arguments provided.
"""
import tensorflow as tf
if isinstance(image, dict):
image['tile_image'] = tf.py_function(
partial(self.tf_to_rgb, augment=augment),
[image['tile_image']],
tf.uint8
)
elif len(image.shape) == 4:
image = tf.stack([self.tf_to_tf(_i, augment=augment) for _i in image])
else:
image = tf.py_function(
partial(self.tf_to_rgb, augment=augment),
[image],
tf.uint8
)
return detuple(image, args)
def torch_to_torch(
self,
image: Union[Dict, "torch.Tensor"],
*args,
augment: bool = False
) -> Tuple[Union[Dict, "torch.Tensor"], ...]:
"""Normalize a torch.Tensor (uint8), returning a numpy array (uint8).
Args:
image (torch.Tensor, Dict): Image (uint8) either as a raw Tensor,
or a Dictionary with the image under the key 'tile_image'.
args (Any, optional): Any additional arguments, which will be passed
and returned unmodified.
Keyword args:
augment (bool): Transform using stain aumentation.
Defaults to False.
Returns:
A tuple containing
np.ndarray: Normalized torch.Tensor image, uint8 (channel dimension matching the input image)
args (Any, optional): Any additional arguments provided, unmodified.
"""
if isinstance(image, dict):
to_return = {
k: v for k, v in image.items()
if k != 'tile_image'
}
to_return['tile_image'] = self._torch_transform(
image['tile_image'],
augment=augment
)
return detuple(to_return, args)
else:
return detuple(self._torch_transform(image, augment=augment), args)
# --- Context management --------------------------------------------------
@contextmanager
def context(
self,
context: Union[str, "sf.WSI", np.ndarray, "tf.Tensor", "torch.Tensor"]
):
"""Set the whole-slide context for the stain normalizer.
With contextual normalization, max concentrations are determined
from the context (whole-slide image) rather than the image being
normalized. This may improve stain normalization for sections of
a slide that are predominantly eosin (e.g. necrosis or low cellularity).
When calculating max concentrations from the image context,
white pixels (255) will be masked.
This function is a context manager used for temporarily setting the
image context. For example:
.. code-block:: python
with normalizer.context(slide):
normalizer.transform(target)
If a slide (``sf.WSI``) is used for context, any existing QC filters
and regions of interest will be used to mask out background as white
pixels, and the masked thumbnail will be used for creating the
normalizer context. If no QC has been applied to the slide and the
slide does not have any Regions of Interest, then both otsu's
thresholding and Gaussian blur filtering will be applied
to the thumbnail for masking.
Args:
I (np.ndarray, sf.WSI): Context to use for normalization, e.g.
a whole-slide image thumbnail, optionally masked with masked
areas set to (255, 255, 255).
"""
self.set_context(context)
yield
self.clear_context()
def set_context(
self,
context: Union[str, "sf.WSI", np.ndarray, "tf.Tensor", "torch.Tensor"]
) -> bool:
"""Set the whole-slide context for the stain normalizer.
With contextual normalization, max concentrations are determined
from the context (whole-slide image) rather than the image being
normalized. This may improve stain normalization for sections of
a slide that are predominantly eosin (e.g. necrosis or low cellularity).
When calculating max concentrations from the image context,
white pixels (255) will be masked.
If a slide (``sf.WSI``) is used for context, any existing QC filters
and regions of interest will be used to mask out background as white
pixels, and the masked thumbnail will be used for creating the
normalizer context. If no QC has been applied to the slide and the
slide does not have any Regions of Interest, then both otsu's
thresholding and Gaussian blur filtering will be applied
to the thumbnail for masking.
Args:
I (np.ndarray, sf.WSI): Context to use for normalization, e.g.
a whole-slide image thumbnail, optionally masked with masked
areas set to (255, 255, 255).
"""
if hasattr(self.n, 'set_context'):
if isinstance(context, str):
image = np.asarray(sf.WSI(context, 500, 500).thumb(mpp=4))
elif isinstance(context, sf.WSI):
image = context.masked_thumb(mpp=4, background='white')
else:
image = context # type: ignore
self.n.set_context(image)
return True
else:
return False
def clear_context(self) -> None:
"""Remove any previously set stain normalizer context."""
if hasattr(self.n, 'clear_context'):
self.n.clear_context()
def autoselect(
    method: str,
    source: Optional[str] = None,
    backend: Optional[str] = None,
    **kwargs
) -> StainNormalizer:
    """Select the best normalizer for a given method, and fit to a given source.

    If a normalizer method has a native implementation in the current backend
    (Tensorflow or PyTorch), the native normalizer will be used.
    If not, the default numpy implementation will be used.

    Currently, the PyTorch-native normalizers are NOT used by default, as they
    are slower than the numpy implementations. Thus, with the PyTorch backend,
    all normalizers will be the default numpy implementations.

    Args:
        method (str): Normalization method. Options include 'macenko',
            'reinhard', 'reinhard_fast', 'reinhard_mask', 'reinhard_fast_mask',
            'vahadane', 'vahadane_spams', 'vahadane_sklearn', and 'augment'.
        source (str, optional): Stain normalization preset or path to a source
            image. Valid presets include 'v1', 'v2', and 'v3'. If None, will
            use the default present ('v3'). Defaults to None.
        backend (str): Backend to use for native normalizers. Options include
            'tensorflow', 'torch', and 'opencv'. If None, will use the current
            backend, falling back to opencv/numpy if a native normalizer is
            not available. Defaults to None.

    Returns:
        StainNormalizer: Initialized StainNormalizer.

    Raises:
        errors.UnrecognizedBackendError: If ``backend`` is not one of
            'tensorflow', 'torch', or 'opencv'.
    """
    if backend is None:
        backend = sf.backend()
    # Resolve the backend-native normalizer class.
    if backend == 'tensorflow':
        import slideflow.norm.tensorflow
        native_cls = sf.norm.tensorflow.TensorflowStainNormalizer
    elif backend == 'torch':
        import slideflow.norm.torch
        native_cls = sf.norm.torch.TorchStainNormalizer  # type: ignore
    elif backend == 'opencv':
        native_cls = StainNormalizer
    else:
        raise errors.UnrecognizedBackendError
    # Fall back to the numpy implementation when the native backend does
    # not implement the requested method.
    if method in native_cls.normalizers:
        normalizer = native_cls(method, **kwargs)
    else:
        normalizer = StainNormalizer(method, **kwargs)  # type: ignore
    if source is not None and source != 'dataset':
        normalizer.fit(source)
    return normalizer