Shortcuts

Source code for slideflow.grad

"""Submodule for calculating/displaying pixel attribution (saliency maps)."""

from typing import Any, Callable, Dict, Optional

import slideflow as sf
import numpy as np
import saliency.core as saliency
from functools import partial
from slideflow import errors
from slideflow.grad.plot_utils import (comparison_plot, inferno, multi_plot,
                                       oranges, overlay,
                                       saliency_map_comparison)

# Integer identifiers for the supported saliency methods. These are the keys
# of ``SaliencyMap._fn_map`` and the values accepted by ``SaliencyMap.get()``.
# The order below fixes the numeric values (0 through 9).
(
    VANILLA,
    VANILLA_SMOOTH,
    INTEGRATED_GRADIENTS,
    INTEGRATED_GRADIENTS_SMOOTH,
    GUIDED_INTEGRATED_GRADIENTS,
    GUIDED_INTEGRATED_GRADIENTS_SMOOTH,
    BLUR_INTEGRATED_GRADIENTS,
    BLUR_INTEGRATED_GRADIENTS_SMOOTH,
    XRAI,
    XRAI_FAST,
) = range(10)

class SaliencyMap:
    """Helper for calculating saliency maps from a differentiable model."""

    def __init__(self, model: Callable, class_idx: int) -> None:
        """Class to assist with calculation and display of saliency maps.

        Args:
            model (Callable): Differentiable model from which saliency is
                calculated.
            class_idx (int): Index of class for backpropagating gradients.

        Raises:
            ValueError: If ``model`` is not callable.
        """
        if not callable(model):
            raise ValueError("'model' must be a differentiable model.")
        self.model = model
        self.class_idx = class_idx

        # Grad-CAM state, populated lazily by ``_update_feature_model()``.
        self.feature_model = None
        self.feature_model_layer = None
        # Torch only: values captured by the layer hooks, and the hook
        # handles so that the hooks can be removed when the layer changes.
        self._torch_conv_layer_outputs = {}  # type: Dict[str, Any]
        self._torch_hook_handles = []  # type: list

        # Saliency method implementations.
        self.gradients = saliency.GradientSaliency()
        self.gradcam_grads = saliency.GradCam()
        self.ig = saliency.IntegratedGradients()
        self.guided_ig = saliency.GuidedIG()
        self.blur_ig = saliency.BlurIG()
        self.xrai_grads = saliency.XRAI()
        self.fast_xrai_params = saliency.XRAIParameters()
        self.fast_xrai_params.algorithm = 'fast'

        # Dispatch table used by ``get()``.
        self._fn_map = {
            VANILLA: self.vanilla,
            VANILLA_SMOOTH: partial(self.vanilla, smooth=True),
            INTEGRATED_GRADIENTS: self.integrated_gradients,
            INTEGRATED_GRADIENTS_SMOOTH: partial(
                self.integrated_gradients, smooth=True
            ),
            GUIDED_INTEGRATED_GRADIENTS: self.guided_integrated_gradients,
            GUIDED_INTEGRATED_GRADIENTS_SMOOTH: partial(
                self.guided_integrated_gradients, smooth=True
            ),
            BLUR_INTEGRATED_GRADIENTS: self.blur_integrated_gradients,
            # Previously registered without ``smooth=True``, which made the
            # "smooth" variant identical to the plain one.
            BLUR_INTEGRATED_GRADIENTS_SMOOTH: partial(
                self.blur_integrated_gradients, smooth=True
            ),
            XRAI: self.xrai,
            XRAI_FAST: self.xrai_fast,
        }

    @property
    def model_backend(self):
        """Backend of the model, as reported by ``sf.util.model_backend``."""
        return sf.util.model_backend(self.model)

    @property
    def device(self):
        """Device of the model's first parameter (``None`` for Tensorflow)."""
        if self.model_backend == 'tensorflow':
            return None
        return next(self.model.parameters()).device

    @staticmethod
    def _wants_input_gradients(expected_keys: Optional[list]) -> bool:
        """Return True if input/output gradients were requested."""
        return (bool(expected_keys)
                and saliency.base.INPUT_OUTPUT_GRADIENTS in expected_keys)

    @staticmethod
    def _rescale(mask: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
        """Scale ``mask`` to [0, 1] given a shared minimum and maximum."""
        span = vmax - vmin
        if span == 0:
            # Constant mask: avoid a 0/0 division producing NaNs.
            return np.zeros_like(mask, dtype=float)
        return np.clip((mask - vmin) / span, 0, 1)

    def _remove_torch_hooks(self) -> None:
        """Remove any hooks registered for a previous Grad-CAM layer."""
        for handle in self._torch_hook_handles:
            handle.remove()
        self._torch_hook_handles = []

    def _update_feature_model(self, layer):
        """Prepare the model for Grad-CAM at the given layer.

        For Tensorflow, builds a model returning both the layer output and
        the final output. For PyTorch, registers forward/backward hooks on
        the layer that capture its values and output gradients.

        Args:
            layer (str): Name of the layer to use for Grad-CAM.
        """
        if self.feature_model_layer == layer:
            return

        # Clean up anything left over from the previous layer. The layer
        # name is only recorded once setup has succeeded, so that a failed
        # setup is retried on the next call rather than silently skipped.
        self._remove_torch_hooks()
        self.feature_model = None
        self.feature_model_layer = None

        if self.model_backend == 'tensorflow':
            import tensorflow as tf
            flattened = sf.model.tensorflow_utils.flatten(self.model)
            conv_layer = flattened.get_layer(layer)
            self.feature_model = tf.keras.models.Model(
                [flattened.inputs],
                [conv_layer.output, flattened.output]
            )
        else:
            import torch
            from slideflow.model import torch_utils
            conv_layer = torch_utils.get_module_by_name(self.model, layer)
            self._torch_conv_layer_outputs = {}

            def conv_layer_forward(m, i, o):
                # Move dimension 1 to the last position. ``.cpu()`` is
                # required before ``.numpy()`` when the model is on a GPU.
                values = torch.movedim(o, 1, 3).detach().cpu().numpy()
                key = saliency.base.CONVOLUTION_LAYER_VALUES
                self._torch_conv_layer_outputs[key] = values

            def conv_layer_backward(m, i, o):
                # ``o`` holds the gradients w.r.t. the layer output.
                grads = torch.movedim(o[0], 1, 3).detach().cpu().numpy()
                key = saliency.base.CONVOLUTION_OUTPUT_GRADIENTS
                self._torch_conv_layer_outputs[key] = grads

            self._torch_hook_handles = [
                conv_layer.register_forward_hook(conv_layer_forward),
                conv_layer.register_full_backward_hook(conv_layer_backward),
            ]

        self.feature_model_layer = layer

    def _grad_fn_torch(
        self,
        image: np.ndarray,
        call_model_args: Any = None,
        expected_keys: Optional[list] = None
    ) -> Any:
        """Calculate gradient attribution with PyTorch backend.

        Images are expected to be in W, H, C format.

        Args:
            image (np.ndarray): Batch of images.
            call_model_args (Any): Unused.
            expected_keys (list): Keys requested by the saliency method.

        Returns:
            Dict mapping each requested key to its values.
        """
        import torch
        from slideflow.io.torch import whc_to_cwh

        # Create the float tensor directly on the model's device so that it
        # is a leaf tensor, and so that integer inputs do not fail on
        # ``requires_grad=True``.
        image = torch.tensor(  # type: ignore
            image,
            dtype=torch.float32,
            device=self.device,
            requires_grad=True
        )
        output = self.model(whc_to_cwh(image))

        if self._wants_input_gradients(expected_keys):
            # For vanilla gradient, Integrated Gradients, XRAI
            outputs = output[:, self.class_idx]
            grads = torch.autograd.grad(  # type: ignore
                outputs,
                image,
                grad_outputs=torch.ones_like(outputs)
            )
            gradients = grads[0].cpu().detach().numpy()
            return {saliency.base.INPUT_OUTPUT_GRADIENTS: gradients}
        else:
            # For Grad-CAM. The hooks registered in
            # ``_update_feature_model()`` fill in the returned dictionary
            # during the forward and backward passes.
            one_hot = torch.zeros_like(output)
            one_hot[:, self.class_idx] = 1
            self.model.zero_grad()  # type: ignore
            output.backward(gradient=one_hot, retain_graph=True)
            return self._torch_conv_layer_outputs

    def _grad_fn_tf(
        self,
        image: np.ndarray,
        call_model_args: Any = None,
        expected_keys: Optional[list] = None
    ) -> Any:
        """Calculate gradient attribution with Tensorflow backend.

        Args:
            image (np.ndarray): Batch of images.
            call_model_args (Any): Unused.
            expected_keys (list): Keys requested by the saliency method.

        Returns:
            Dict mapping each requested key to its values.
        """
        import tensorflow as tf

        image = tf.convert_to_tensor(image)
        with tf.GradientTape() as tape:
            if self._wants_input_gradients(expected_keys):
                # For vanilla gradient, Integrated Gradients, XRAI
                tape.watch(image)
                output = self.model(image)[:, self.class_idx]
                gradients = tape.gradient(output, image)
                return {saliency.base.INPUT_OUTPUT_GRADIENTS: gradients}
            else:
                # For Grad-CAM
                conv_layer, output_layer = self.feature_model(image)
                gradients = np.array(tape.gradient(output_layer, conv_layer))
                return {
                    saliency.base.CONVOLUTION_LAYER_VALUES: conv_layer,
                    saliency.base.CONVOLUTION_OUTPUT_GRADIENTS: gradients
                }

    def _grad_fn(
        self,
        image: np.ndarray,
        call_model_args: Any = None,
        expected_keys: Optional[list] = None
    ) -> Any:
        """Calculate gradient attribution.

        Dispatches to the backend-specific implementation.

        Raises:
            errors.UnrecognizedBackendError: If the model backend is neither
                'tensorflow' nor 'torch'.
        """
        backend = self.model_backend
        if backend == 'tensorflow':
            return self._grad_fn_tf(image, call_model_args, expected_keys)
        elif backend == 'torch':
            return self._grad_fn_torch(image, call_model_args, expected_keys)
        else:
            raise errors.UnrecognizedBackendError

    def _apply_mask_fn(
        self,
        img: np.ndarray,
        grads: saliency.CoreSaliency,
        baseline: bool = False,
        smooth: bool = False,
        **kwargs
    ) -> np.ndarray:
        """Applies a saliency masking function to a gradients map.

        Args:
            img (np.ndarray or list(np.ndarray)): Image or list of images.
            grads (saliency.CoreSaliency): Gradients for saliency.
            baseline (bool): The method takes an ``x_baseline`` argument. If
                one is not supplied in ``kwargs``, an all-zero baseline with
                the shape of the image is used.
            smooth (bool): Use a smoothed mask.

        Returns:
            np.ndarray: Saliency map. If a list of images is given, returns
            a list of saliency maps normalized together.
        """
        mask_fn = grads.GetSmoothedMask if smooth else grads.GetMask

        def _get_mask(_img):
            # Copy so the shared kwargs are never mutated, and so that a
            # caller-supplied baseline is respected.
            call_kwargs = dict(kwargs)
            if baseline and 'x_baseline' not in call_kwargs:
                call_kwargs['x_baseline'] = np.zeros(_img.shape)
            return mask_fn(_img, self._grad_fn, **call_kwargs)

        if isinstance(img, list):
            if not img:
                return []
            # Normalize together
            image_3d = [_get_mask(_img) for _img in img]
            v_maxes, v_mins = zip(*[max_min(img3d) for img3d in image_3d])
            vmax = max(v_maxes)
            vmin = min(v_mins)
            return [grayscale(img3d, vmax=vmax, vmin=vmin)
                    for img3d in image_3d]
        else:
            return grayscale(_get_mask(img))

    def _xrai_mask(self, img: np.ndarray, **kwargs) -> np.ndarray:
        """Calculate and normalize XRAI mask(s).

        Args:
            img (np.ndarray or list(np.ndarray)): Image or list of images.
            **kwargs: Passed to ``XRAI.GetMask()``.

        Returns:
            np.ndarray: Normalized mask. If a list of images is given,
            returns a list of masks normalized together.
        """
        if isinstance(img, list):
            # Calculate one mask per image, then normalize together using
            # the largest 99th percentile and the smallest minimum.
            masks = [self.xrai_grads.GetMask(_img, self._grad_fn, **kwargs)
                     for _img in img]
            if not masks:
                return []
            vmax = max(np.percentile(mask, 99) for mask in masks)
            vmin = min(np.min(mask) for mask in masks)
            return [self._rescale(mask, vmin, vmax) for mask in masks]
        mask = self.xrai_grads.GetMask(img, self._grad_fn, **kwargs)
        return normalize_xrai(mask)

    def all(self, img: np.ndarray) -> Dict:
        """Calculate all saliency map methods.

        Includes vanilla, integrated, guided integrated and blur integrated
        gradients (each with and without smoothing). XRAI and Grad-CAM are
        not included.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.

        Returns:
            Dict: Dictionary mapping name of saliency method to saliency map.
        """
        return {
            'Vanilla': self.vanilla(img),
            'Vanilla (Smoothed)': self.vanilla(img, smooth=True),
            'Integrated Gradients': self.integrated_gradients(img),
            'Integrated Gradients (Smooth)': self.integrated_gradients(
                img, smooth=True
            ),
            'Guided Integrated Gradients': self.guided_integrated_gradients(
                img
            ),
            'Guided Integrated Gradients (Smooth)':
                self.guided_integrated_gradients(img, smooth=True),
            'Blur Integrated Gradients': self.blur_integrated_gradients(img),
            'Blur Integrated Gradients (Smooth)':
                self.blur_integrated_gradients(img, smooth=True),
        }

    def get(self, img: np.ndarray, method: int) -> np.ndarray:
        """Calculate a saliency map using the given method.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            method (int): Saliency method identifier, one of the
                module-level constants (e.g. ``VANILLA``, ``XRAI``).

        Returns:
            np.ndarray: Saliency map.

        Raises:
            KeyError: If ``method`` is not a recognized identifier.
        """
        try:
            fn = self._fn_map[method]
        except KeyError:
            raise KeyError(
                "Unrecognized saliency method: {!r}".format(method)
            ) from None
        return fn(img)

    def vanilla(
        self,
        img: np.ndarray,
        smooth: bool = False,
        **kwargs
    ) -> np.ndarray:
        """Calculate gradient-based saliency map.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            smooth (bool, optional): Smooth gradients. Defaults to False.

        Returns:
            np.ndarray: Saliency map.
        """
        return self._apply_mask_fn(
            img,
            self.gradients,
            smooth=smooth,
            **kwargs
        )

    def gradcam(
        self,
        img: np.ndarray,
        layer: str,
        smooth: bool = False,
        **kwargs
    ) -> np.ndarray:
        """Calculate saliency map using Grad-CAM.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            layer (str): Name of the model layer to use for Grad-CAM.
            smooth (bool, optional): Smooth gradients. Defaults to False.

        Returns:
            np.ndarray: Saliency map.
        """
        self._update_feature_model(layer)
        return self._apply_mask_fn(
            img,
            self.gradcam_grads,
            smooth=smooth,
            **kwargs
        )

    def integrated_gradients(
        self,
        img: np.ndarray,
        x_steps: int = 25,
        batch_size: int = 20,
        smooth: bool = False,
        **kwargs
    ) -> np.ndarray:
        """Calculate saliency map using integrated gradients.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            x_steps (int, optional): Steps for gradient calculation.
                Defaults to 25.
            batch_size (int, optional): Batch size. Defaults to 20.
            smooth (bool, optional): Smooth gradients. Defaults to False.

        Returns:
            np.ndarray: Saliency map.
        """
        return self._apply_mask_fn(
            img,
            self.ig,
            smooth=smooth,
            x_steps=x_steps,
            batch_size=batch_size,
            baseline=True,
            **kwargs
        )

    def guided_integrated_gradients(
        self,
        img: np.ndarray,
        x_steps: int = 25,
        max_dist: float = 1.0,
        fraction: float = 0.5,
        smooth: bool = False,
        **kwargs
    ) -> np.ndarray:
        """Calculate saliency map using guided integrated gradients.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            x_steps (int, optional): Steps for gradient calculation.
                Defaults to 25.
            max_dist (float, optional): Maximum distance for gradient
                calculation. Defaults to 1.0.
            fraction (float, optional): Fraction for gradient calculation.
                Defaults to 0.5.
            smooth (bool, optional): Smooth gradients. Defaults to False.

        Returns:
            np.ndarray: Saliency map.
        """
        return self._apply_mask_fn(
            img,
            self.guided_ig,
            x_steps=x_steps,
            max_dist=max_dist,
            fraction=fraction,
            smooth=smooth,
            baseline=True,
            **kwargs
        )

    def blur_integrated_gradients(
        self,
        img: np.ndarray,
        batch_size: int = 20,
        smooth: bool = False,
        **kwargs
    ) -> np.ndarray:
        """Calculate saliency map using blur integrated gradients.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            batch_size (int, optional): Batch size. Defaults to 20.
            smooth (bool, optional): Smooth gradients. Defaults to False.

        Returns:
            np.ndarray: Saliency map.
        """
        return self._apply_mask_fn(
            img,
            self.blur_ig,
            smooth=smooth,
            batch_size=batch_size,
            **kwargs
        )

    def xrai(
        self,
        img: np.ndarray,
        batch_size: int = 20,
        **kwargs
    ) -> np.ndarray:
        """Calculate saliency map using XRAI.

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            batch_size (int, optional): Batch size. Defaults to 20.

        Returns:
            np.ndarray: Saliency map.
        """
        return self._xrai_mask(img, batch_size=batch_size, **kwargs)

    def xrai_fast(
        self,
        img: np.ndarray,
        batch_size: int = 20,
        **kwargs
    ) -> np.ndarray:
        """Calculate saliency map using XRAI (fast implementation).

        Args:
            img (np.ndarray): Pre-processed input image in W, H, C format.
            batch_size (int, optional): Batch size. Defaults to 20.

        Returns:
            np.ndarray: Saliency map.
        """
        return self._xrai_mask(
            img,
            batch_size=batch_size,
            extra_parameters=self.fast_xrai_params,
            **kwargs
        )
def grayscale(image_3d, vmax=None, vmin=None, percentile=99):
    """Returns a 3D tensor as a grayscale 2D tensor.

    This method sums a 3D tensor across the absolute value of axis=2, and
    then clips values at a given percentile.

    Args:
        image_3d (np.ndarray): 3D array to convert.
        vmax (float, optional): Value mapped to 1. If None, uses the given
            percentile of the summed image.
        vmin (float, optional): Value mapped to 0. If None, uses the minimum
            of the summed image.
        percentile (int, optional): Percentile used when ``vmax`` is None.
            Defaults to 99.

    Returns:
        np.ndarray: 2D array with values in the range [0, 1].
    """
    image_2d = np.sum(np.abs(image_3d), axis=2)
    # Fill in each bound independently, so that supplying only one of
    # vmax / vmin does not leave the other as None.
    if vmax is None:
        vmax = np.percentile(image_2d, percentile)
    if vmin is None:
        vmin = np.min(image_2d)
    span = vmax - vmin
    if span == 0:
        # Constant image: avoid a 0/0 division producing NaNs.
        return np.zeros_like(image_2d, dtype=float)
    return np.clip((image_2d - vmin) / span, 0, 1)
def normalize_xrai(mask, percentile=99, vmax=None, vmin=None):
    """Normalize an XRAI mask to the range [0, 1].

    Args:
        mask (np.ndarray): Mask to normalize.
        percentile (int, optional): Percentile of the mask used as the
            maximum when ``vmax`` is None. Defaults to 99.
        vmax (float, optional): Value mapped to 1. Supply this (together
            with ``vmin``) to normalize several masks on a shared scale.
        vmin (float, optional): Value mapped to 0. If None, uses the minimum
            of the mask.

    Returns:
        np.ndarray: Mask with values clipped to [0, 1].
    """
    if vmax is None:
        vmax = np.percentile(mask, percentile)
    if vmin is None:
        vmin = np.min(mask)
    span = vmax - vmin
    if span == 0:
        # Constant mask: avoid a 0/0 division producing NaNs.
        return np.zeros_like(np.asarray(mask), dtype=float)
    return np.clip((mask - vmin) / span, 0, 1)


def max_min(image_3d, percentile=99):
    """Return the percentile maximum and the minimum of an attribution map.

    Arrays with three or more dimensions are first reduced by summing
    absolute values across axis=2. Two-dimensional arrays are used as-is.

    Args:
        image_3d (np.ndarray): Attribution map.
        percentile (int, optional): Percentile used as the maximum.
            Defaults to 99.

    Returns:
        Tuple of (vmax, vmin).
    """
    values = np.asarray(image_3d)
    if values.ndim == 2:
        # Already 2D: there is no axis=2 to sum over.
        image_2d = values
    else:
        image_2d = np.sum(np.abs(values), axis=2)
    vmax = np.percentile(image_2d, percentile)
    vmin = np.min(image_2d)
    return vmax, vmin