Shortcuts

Source code for slideflow.heatmap

import os
from collections import namedtuple
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Union)

import numpy as np
import shapely.geometry as sg
from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes
from threading import Thread

import slideflow as sf
from slideflow import errors
from slideflow.slide import WSI
from slideflow.util import log

if TYPE_CHECKING:
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes
    from PIL import Image
    try:
        import tensorflow as tf
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        pass

# Zoom-inset specification stored by ``Heatmap.add_inset()``. ``x``/``y`` are
# (min, max) limits in thumbnail coordinates, ``zoom``/``loc`` are passed to
# ``zoomed_inset_axes``, ``mark1``/``mark2`` to ``mark_inset``, and ``axes``
# controls whether inset tick marks are kept.
Inset = namedtuple(
    "Inset",
    ["x", "y", "zoom", "loc", "mark1", "mark2", "axes"],
)

# -----------------------------------------------------------------------------

class Heatmap:
    """Generate a heatmap of predictions across a whole-slide image.

    This interface is designed to be used with tile-based models, and does
    not support multiple-instance learning models. Attention heatmaps of
    multiple-instance learning models can be generated using
    :func:`slideflow.mil.predict_slide`.
    """

    def __init__(
        self,
        slide: Union[str, WSI],
        model: str,
        stride_div: Optional[int] = None,
        batch_size: int = 32,
        num_threads: Optional[int] = None,
        num_processes: Optional[int] = None,
        img_format: str = 'auto',
        generate: bool = True,
        generator_kwargs: Optional[Dict[str, Any]] = None,
        device: Optional["torch.device"] = None,
        load_method: Optional[str] = None,
        **wsi_kwargs
    ) -> None:
        """Initialize a heatmap from a path to a slide or a :class:`slideflow.WSI`.

        Examples
            Create a heatmap from a path to a slide.

                .. code-block:: python

                    model_path = 'path/to/saved_model'
                    heatmap = sf.Heatmap('slide.svs', model_path)

            Create a heatmap, with grayspace filtering disabled.

                .. code-block:: python

                    heatmap = sf.Heatmap(..., grayspace_fraction=1)

            Create a heatmap from a ``sf.WSI`` object.

                .. code-block:: python

                    # Load a slide
                    wsi = sf.WSI('slide.svs', tile_px=299, tile_um=302)

                    # Apply Otsu's thresholding to the slide,
                    # so heatmap is only generated on areas with tissue.
                    wsi.qc('otsu')

                    # Generate the heatmap
                    heatmap = sf.Heatmap(wsi, model_path)

        Args:
            slide (str or :class:`slideflow.WSI`): Path to slide, or an
                already-loaded WSI whose ``tile_px`` and ``tile_um`` match
                the model.
            model (str): Path to Tensorflow or PyTorch model.
            stride_div (int, optional): Divisor for stride when convoluting
                across slide. Only used when ``slide`` is a path.
                Defaults to 2.
            batch_size (int, optional): Batch size for calculating
                predictions. Defaults to 32.
            num_threads (int, optional): Number of tile worker threads.
                Cannot supply both ``num_threads`` (uses thread pool) and
                ``num_processes`` (uses multiprocessing pool). Defaults to
                CPU core count.
            num_processes (int, optional): Number of child processes to spawn
                for multiprocessing pool. Defaults to None (does not use
                multiprocessing).
            img_format (str, optional): Image format (png, jpg) to use when
                extracting tiles from slide. Must match the image format
                the model was trained on. If 'auto', will use the format
                logged in the model params.json. Defaults to 'auto'.
            generate (bool): Generate the heatmap after initialization.
                If False, heatmap will need to be manually generated by
                calling :meth:`Heatmap.generate()`.
            generator_kwargs (dict, optional): Keyword arguments passed to
                :meth:`Heatmap.generate()` (and from there to the model
                interface).
            device (torch.device, optional): PyTorch device. Only used for
                PyTorch models. Defaults to None.
            load_method (str, optional): Model loading method passed through
                to the model interface (e.g. 'full' or 'weights' for
                Tensorflow models). Defaults to None.

        Keyword args:
            Any keyword argument accepted by :class:`slideflow.WSI`, such as:

            roi_dir (str, optional): Directory in which slide ROI is
                contained. Defaults to None.
            rois (list, optional): List of paths to slide ROIs. Alternative
                to providing roi_dir. Defaults to None.
            roi_method (str): Either 'inside', 'outside', 'auto', or
                'ignore'. Determines how ROIs are used to extract tiles.
                If 'inside' or 'outside', will extract tiles in/out of an ROI,
                and raise errors.MissingROIError if an ROI is not available.
                If 'auto', will extract tiles inside an ROI if available,
                and across the whole-slide if no ROI is found.
                If 'ignore', will extract tiles across the whole-slide
                regardless of whether an ROI is available.
                Defaults to 'auto'.
            enable_downsample (bool, optional): Enable the use of downsampled
                slide image layers. Defaults to True.

            These are ignored (with a warning) if ``slide`` is a WSI.

        Raises:
            ValueError: If both ``num_threads`` and ``num_processes`` are
                given, if a WSI's tile size does not match the model, or if
                ``slide`` is neither a path nor a WSI.
            errors.HeatmapError: If the image format cannot be auto-detected
                or the slide cannot be loaded.
        """
        if num_processes is not None and num_threads is not None:
            raise ValueError("Invalid argument: cannot supply both "
                             "num_processes and num_threads")
        self.insets = []  # type: List[Inset]
        model_config = sf.util.get_model_config(model)
        self.uq = model_config['hp']['uq']

        # Resolve the image format, preferring the format logged with the model.
        if img_format == 'auto' and 'img_format' not in model_config:
            raise errors.HeatmapError(
                f"Unable to auto-detect image format from model at {model}. "
                "Manually set to png or jpg with Heatmap(img_format=...)")
        elif img_format == 'auto':
            self.img_format = model_config['img_format']
        else:
            self.img_format = img_format

        # Build the model interface; ``device`` only applies to PyTorch.
        is_torch = sf.util.is_torch_model_path(model)
        int_kw = {'device': device} if is_torch else {}  # type: Dict[str, Any]
        if load_method is not None:
            int_kw['load_method'] = load_method
        if self.uq:
            if is_torch:
                import slideflow.model.torch
                interface_fn = sf.model.torch.UncertaintyInterface
            else:
                import slideflow.model.tensorflow
                interface_fn = sf.model.tensorflow.UncertaintyInterface  # type: ignore
            self.interface = interface_fn(model, **int_kw)
        else:
            if is_torch:
                import slideflow.model.torch
                interface_fn = sf.model.torch.Features
            else:
                import slideflow.model.tensorflow
                interface_fn = sf.model.tensorflow.Features  # type: ignore
            self.interface = interface_fn(  # type: ignore
                model, layers=None, include_preds=True, **int_kw)

        self.model_path = model
        self.num_threads = num_threads
        self.num_processes = num_processes
        self.batch_size = batch_size
        self.device = device
        self.tile_px = model_config['tile_px']
        self.tile_um = model_config['tile_um']
        self.num_classes = self.interface.num_classes
        self.num_features = self.interface.num_features
        self.num_uncertainty = self.interface.num_uncertainty
        self.predictions = None
        self.uncertainty = None
        self._thumb = None

        if isinstance(slide, str):
            if stride_div is None:
                stride_div = 2
            self.slide_path = slide
            self.stride_div = stride_div
            try:
                self.slide = WSI(
                    self.slide_path,
                    self.tile_px,
                    self.tile_um,
                    self.stride_div,
                    **wsi_kwargs  # type: ignore
                )
            except errors.SlideLoadError as e:
                # ``self.slide`` is not assigned when WSI() fails, so the
                # error message must use the path instead.
                raise errors.HeatmapError(
                    f'Error loading slide {self.slide_path} for heatmap'
                ) from e
        elif isinstance(slide, WSI):
            if slide.tile_px != self.tile_px:
                raise ValueError(
                    "Slide tile_px ({}) does not match model ({})".format(
                        slide.tile_px, self.tile_px))
            if slide.tile_um != self.tile_um:
                raise ValueError(
                    "Slide tile_um ({}) does not match model ({})".format(
                        slide.tile_um, self.tile_um))
            if stride_div is not None:
                log.warning("slide is a WSI; ignoring supplied stride_div.")
            if wsi_kwargs:
                log.warning("WSI provided; ignoring keyword arguments: "
                            + ", ".join(wsi_kwargs))
            self.slide_path = slide.path
            self.slide = slide
            self.stride_div = slide.stride_div
        else:
            raise ValueError(f"Unrecognized value {slide} for argument slide")

        if generate:
            if generator_kwargs is None:
                generator_kwargs = {}
            self.generate(**generator_kwargs)
        elif generator_kwargs:
            log.warning("Heatmap generate=False, ignoring generator_kwargs ("
                        f"{generator_kwargs})")

    def _check_generated(self, action: str = 'plot') -> None:
        """Raise ``errors.HeatmapError`` if no predictions are available.

        Args:
            action (str): Verb used in the error message. Defaults to 'plot'.
        """
        if self.predictions is None:
            raise errors.HeatmapError(
                f"Cannot {action} Heatmap which is not yet generated; generate "
                "with either heatmap.generate() or Heatmap(..., generate=True)"
            )

    @staticmethod
    def _prepare_ax(ax: Optional["Axes"] = None) -> "Axes":
        """Create a figure and axis if none is supplied, else clear the axis.

        Args:
            ax (matplotlib.axes.Axes, optional): Figure axis. If not supplied,
                creates a new 18x16 figure with a single subplot. Otherwise,
                clears the axis contents. Defaults to None.

        Returns:
            matplotlib.axes.Axes: Figure axes.
        """
        import matplotlib.pyplot as plt
        if ax is None:
            fig = plt.figure(figsize=(18, 16))
            ax = fig.add_subplot(111)
            fig.subplots_adjust(bottom=0.25, top=0.95)
        else:
            ax.clear()
        return ax

    def generate(
        self,
        asynchronous: bool = False,
        **kwargs
    ) -> Optional[Tuple[np.ndarray, Thread]]:
        """Manually generate the heatmap.

        This function is automatically called when creating the heatmap if
        the heatmap was initialized with ``generate=True`` (default behavior).

        Args:
            asynchronous (bool, optional): Generate heatmap in a separate
                thread, returning the numpy array which is updated in
                realtime with heatmap predictions and the heatmap thread.
                Defaults to False, returning None.

        Keyword args:
            Passed through to the model interface call. For example,
            ``callback`` (Callable, optional): Callback function to call each
            time the heatmap grid updated. The callback function should
            accept a single argument: a list of nested (x_idx, y_idx) lists,
            indicating the grid indices updated. Defaults to None.

        Returns:
            ``None`` if ``asynchronous=False``, otherwise returns a tuple
            containing

                **grid**: Numpy (masked) array updated in realtime with
                heatmap predictions as they are calculated.

                **Thread**: Thread in which heatmap is generated.
        """
        def _generate(grid=None):
            # Run the model across the slide.
            out = self.interface(
                self.slide,
                num_threads=self.num_threads,
                num_processes=self.num_processes,
                batch_size=self.batch_size,
                img_format=self.img_format,
                dtype=np.float32,
                grid=grid,
                **kwargs
            )
            # With UQ, the last ``num_uncertainty`` channels hold uncertainty.
            if self.uq:
                self.predictions = out[:, :, :-(self.num_uncertainty)]
                self.uncertainty = out[:, :, -(self.num_uncertainty):]
            else:
                self.predictions = out
                self.uncertainty = None
            log.info(f"Heatmap complete for [green]{self.slide.name}")

        if asynchronous:
            it = self.interface
            # Grid is (y, x, channels): slide.grid is indexed (x, y).
            grid = np.ma.ones((
                self.slide.grid.shape[1],
                self.slide.grid.shape[0],
                it.num_features + it.num_classes + it.num_uncertainty),
                dtype=np.float32)
            heatmap_thread = Thread(target=_generate, args=(grid,))
            heatmap_thread.start()
            return grid, heatmap_thread
        else:
            _generate()
            return None

    def _format_ax(
        self,
        ax: "Axes",
        thumb_size: Tuple[int, int],
        show_roi: bool = True,
        **kwargs
    ) -> None:
        """Format a matplotlib axis in preparation for heatmap plotting.

        Moves x-axis ticks/labels to the top and optionally draws ROI
        outlines (and their holes) scaled to the thumbnail.

        Args:
            ax (matplotlib.axes.Axes): Figure axis.
            thumb_size (tuple(int, int)): (width, height) of the thumbnail
                the ROIs are drawn over.
            show_roi (bool, optional): Include ROI on heatmap.
                Defaults to True.

        Keyword args:
            Passed to ``ax.plot()`` when drawing ROI outlines
            (e.g. ``color``, ``linewidth``).
        """
        ax.tick_params(
            axis='x',
            top=True,
            labeltop=True,
            bottom=False,
            labelbottom=False
        )
        if not show_roi:
            return

        # Scale ROI coordinates from slide space to thumbnail space.
        roi_scale = self.slide.dimensions[0] / thumb_size[0]
        ann_polys = [
            sg.Polygon(annotation.scaled_coords(roi_scale))
            for annotation in self.slide.rois
        ]
        for roi in self.slide.rois:
            for hole in roi.holes.values():
                ann_polys.append(sg.Polygon(hole.scaled_coords(roi_scale)))
        for i, poly in enumerate(ann_polys):
            if poly.geom_type == 'Polygon':
                x, y = poly.exterior.xy
                ax.plot(x, y, zorder=20, **kwargs)
            elif poly.geom_type in ('MultiPolygon', 'GeometryCollection'):
                for p in poly.geoms:
                    if p.geom_type == 'Polygon':
                        x, y = p.exterior.xy
                        ax.plot(x, y, zorder=20, **kwargs)
            else:
                log.warning("Unable to plot ROI {} (geometry={})".format(
                    i, poly.geom_type
                ))

    def add_inset(
        self,
        x: Tuple[int, int],
        y: Tuple[int, int],
        zoom: int = 5,
        loc: int = 1,
        mark1: int = 2,
        mark2: int = 4,
        axes: bool = True
    ) -> Inset:
        """Add a zoom inset to the heatmap.

        Insets are drawn by :meth:`plot_thumbnail` (and therefore by every
        plotting method).

        Args:
            x (tuple(int, int)): (min, max) x-limits of the inset, in
                thumbnail coordinates.
            y (tuple(int, int)): (min, max) y-limits of the inset, in
                thumbnail coordinates.
            zoom (int): Zoom factor passed to ``zoomed_inset_axes``.
                Defaults to 5.
            loc (int): Inset location passed to ``zoomed_inset_axes``.
                Defaults to 1.
            mark1 (int): ``loc1`` passed to ``mark_inset``. Defaults to 2.
            mark2 (int): ``loc2`` passed to ``mark_inset``. Defaults to 4.
            axes (bool): Keep inset axis ticks. Defaults to True.

        Returns:
            Inset: The stored inset specification.
        """
        _inset = Inset(
            x=x,
            y=y,
            zoom=zoom,
            loc=loc,
            mark1=mark1,
            mark2=mark2,
            axes=axes
        )
        self.insets.append(_inset)
        return _inset

    def clear_insets(self) -> None:
        """Remove all zoom insets."""
        self.insets = []

    def load(self, path: str) -> None:
        """Load heatmap predictions and uncertainty from .npz file.

        This function is an alias for :meth:`slideflow.Heatmap.load_npz()`.

        Args:
            path (str): Source .npz file. Must have 'predictions' key and
                optionally 'uncertainty'.

        Returns:
            None
        """
        self.load_npz(path)

    def load_npz(self, path: str) -> None:
        """Load heatmap predictions and uncertainty from .npz file.

        Loads predictions from ``'predictions'`` in .npz file, and uncertainty
        from ``'uncertainty'`` if present, as generated from
        :meth:`slideflow.Heatmap.save_npz()`. For older files, predictions
        are read from ``'logits'`` if ``'predictions'`` is absent. This
        function is the same as calling ``heatmap.load()``.

        Args:
            path (str): Source .npz file. Must have 'predictions' key and
                optionally 'uncertainty'.

        Returns:
            None
        """
        # Use a context manager so the underlying file handle is closed;
        # indexing an NpzFile reads the array fully into memory.
        with np.load(path) as npzfile:
            if ('predictions' not in npzfile) and ('logits' in npzfile):
                log.warning("Loading predictions from 'logits' key.")
                self.predictions = npzfile['logits']
            else:
                self.predictions = npzfile['predictions']
            if 'uncertainty' in npzfile:
                self.uncertainty = npzfile['uncertainty']

    def plot_thumbnail(
        self,
        show_roi: bool = False,
        roi_color: str = 'k',
        linewidth: int = 5,
        width: Optional[int] = None,
        mpp: Optional[float] = None,
        ax: Optional["Axes"] = None,
    ) -> "plt.image.AxesImage":
        """Plot a thumbnail of the slide, with or without ROI.

        The thumbnail is cached on ``self._thumb`` and is used by the other
        plotting methods to position the heatmap overlay. Any zoom insets
        added with :meth:`add_inset` are drawn as well.

        Args:
            show_roi (bool, optional): Overlay ROIs onto the thumbnail.
                Defaults to False.
            roi_color (str): ROI line color. Defaults to 'k' (black).
            linewidth (int): Width of ROI line. Defaults to 5.
            width (int, optional): Thumbnail width in pixels. If neither
                ``width`` nor ``mpp`` is given, defaults to 2048.
            mpp (float, optional): Thumbnail resolution in microns per pixel.
                Defaults to None.
            ax (matplotlib.axes.Axes, optional): Figure axis. If not supplied,
                will prepare a new figure axis.

        Returns:
            plt.image.AxesImage: Result from ax.imshow().
        """
        ax = self._prepare_ax(ax)
        if width is None and mpp is None:
            width = 2048
        self._thumb = self.slide.thumb(width=width, mpp=mpp)
        self._format_ax(
            ax,
            thumb_size=self._thumb.size,
            show_roi=show_roi,
            color=roi_color,
            linewidth=linewidth,
        )
        imshow_thumb = ax.imshow(self._thumb, zorder=0)

        for inset in self.insets:
            axins = zoomed_inset_axes(ax, inset.zoom, loc=inset.loc)
            axins.imshow(self._thumb)
            axins.set_xlim(inset.x[0], inset.x[1])
            axins.set_ylim(inset.y[0], inset.y[1])
            mark_inset(
                ax,
                axins,
                loc1=inset.mark1,
                loc2=inset.mark2,
                fc='none',
                ec='0',
                zorder=100
            )
            if not inset.axes:
                axins.get_xaxis().set_ticks([])
                axins.get_yaxis().set_ticks([])

        return imshow_thumb

    def plot_with_logit_cmap(
        self,
        logit_cmap: Union[Callable, Dict],
        interpolation: str = 'none',
        ax: Optional["Axes"] = None,
        **thumb_kwargs,
    ) -> None:
        """Plot a heatmap using a specified logit colormap.

        Args:
            logit_cmap (obj, optional): Either function or a dictionary use
                to create heatmap colormap. Each image tile will generate a
                list of predictions of length O, where O is the number of
                outcomes. If logit_cmap is a function, then the logit
                prediction list will be passed to the function, and the
                function is expected to return [R, G, B] values for display.
                If logit_cmap is a dictionary, it should map 'r', 'g', and
                'b' to indices; the prediction for these outcome indices will
                be mapped to the RGB colors. Thus, the corresponding color
                will only reflect up to three outcomes. Example mapping
                prediction for outcome 0 to the red colorspace, 3 to green,
                etc: {'r': 0, 'g': 3, 'b': 1}
            interpolation (str, optional): Interpolation strategy to use for
                smoothing heatmap. Defaults to 'none'.
            ax (matplotlib.axes.Axes, optional): Figure axis. If not supplied,
                will prepare a new figure axis.

        Keyword args:
            Passed to :meth:`plot_thumbnail`, e.g.
            show_roi (bool, optional): Overlay ROIs onto heatmap image.
                Defaults to False.
            roi_color (str): ROI line color. Defaults to 'k' (black).
            linewidth (int): Width of ROI line. Defaults to 5.

        Raises:
            errors.HeatmapError: If the heatmap has not been generated.
        """
        self._check_generated()
        ax = self._prepare_ax(ax)
        self.plot_thumbnail(ax=ax, **thumb_kwargs)
        ax.set_facecolor("black")
        if callable(logit_cmap):
            map_logit = logit_cmap
        else:
            # Make heatmap with specific logit predictions mapped
            # to r, g, and b
            def map_logit(logit):
                return (logit[logit_cmap['r']],
                        logit[logit_cmap['g']],
                        logit[logit_cmap['b']])
        extent = calculate_heatmap_extent(
            self.slide, self._thumb, self.predictions
        )
        ax.imshow(
            [[map_logit(logit) for logit in row] for row in self.predictions],
            extent=extent,
            interpolation=interpolation,
            zorder=10
        )
        ax.set_xlim(0, self._thumb.size[0])
        ax.set_ylim(self._thumb.size[1], 0)

    def plot_uncertainty(
        self,
        heatmap_alpha: float = 0.6,
        cmap: str = 'coolwarm',
        interpolation: str = 'none',
        ax: Optional["Axes"] = None,
        **thumb_kwargs
    ):
        """Plot heatmap of uncertainty.

        Args:
            heatmap_alpha (float, optional): Alpha of heatmap overlay.
                Defaults to 0.6.
            cmap (str, optional): Matplotlib heatmap colormap.
                Defaults to 'coolwarm'.
            interpolation (str, optional): Interpolation strategy to use for
                smoothing heatmap. Defaults to 'none'.
            ax (matplotlib.axes.Axes, optional): Figure axis. If not supplied,
                will prepare a new figure axis.

        Keyword args:
            Passed to :meth:`plot_thumbnail`, e.g.
            show_roi (bool, optional): Overlay ROIs onto heatmap image.
                Defaults to False.
            roi_color (str): ROI line color. Defaults to 'k' (black).
            linewidth (int): Width of ROI line. Defaults to 5.

        Raises:
            errors.HeatmapError: If no uncertainty is available (heatmap not
                generated, or model without uncertainty quantification).
        """
        import matplotlib.colors as mcol

        if self.uncertainty is None:
            raise errors.HeatmapError(
                "Cannot plot uncertainty: no uncertainty values available. "
                "Uncertainty requires a model trained with UQ and a "
                "generated (or loaded) heatmap."
            )
        ax = self._prepare_ax(ax)
        implot = self.plot_thumbnail(ax=ax, **thumb_kwargs)
        if heatmap_alpha == 1:
            implot.set_alpha(0)
        # Color scale spans [0, max uncertainty], centered at the midpoint.
        uq_max = self.uncertainty.max()
        uqnorm = mcol.TwoSlopeNorm(
            vmin=0,
            vcenter=uq_max / 2,
            vmax=uq_max
        )
        extent = calculate_heatmap_extent(
            self.slide, self._thumb, self.predictions
        )
        ax.imshow(
            self.uncertainty,
            norm=uqnorm,
            extent=extent,
            cmap=cmap,
            alpha=heatmap_alpha,
            interpolation=interpolation,
            zorder=10
        )
        ax.set_xlim(0, self._thumb.size[0])
        ax.set_ylim(self._thumb.size[1], 0)

    def plot(
        self,
        class_idx: int,
        heatmap_alpha: float = 0.6,
        cmap: str = 'coolwarm',
        interpolation: str = 'none',
        vmin: float = 0,
        vmax: float = 1,
        vcenter: float = 0.5,
        ax: Optional["Axes"] = None,
        **thumb_kwargs
    ) -> None:
        """Plot a predictive heatmap.

        If in a Jupyter notebook, the heatmap will be displayed in the cell
        output. If running via script or shell, the heatmap can then be
        shown on screen using matplotlib ``plt.show()``:

        .. code-block::

            import slideflow as sf
            import matplotlib.pyplot as plt

            heatmap = sf.Heatmap(...)
            heatmap.plot(0)
            plt.show()

        Args:
            class_idx (int): Class index to plot.
            heatmap_alpha (float, optional): Alpha of heatmap overlay.
                Defaults to 0.6.
            cmap (str, optional): Matplotlib heatmap colormap.
                Defaults to 'coolwarm'.
            interpolation (str, optional): Interpolation strategy to use for
                smoothing heatmap. Defaults to 'none'.
            vmin (float): Minimum value to display on heatmap.
                Defaults to 0.
            vcenter (float): Center value for color display on heatmap.
                Defaults to 0.5.
            vmax (float): Maximum value to display on heatmap.
                Defaults to 1.
            ax (matplotlib.axes.Axes, optional): Figure axis. If not supplied,
                will prepare a new figure axis.

        Keyword args:
            Passed to :meth:`plot_thumbnail`, e.g.
            show_roi (bool, optional): Overlay ROIs onto heatmap image.
                Defaults to False.
            roi_color (str): ROI line color. Defaults to 'k' (black).
            linewidth (int): Width of ROI line. Defaults to 5.

        Raises:
            errors.HeatmapError: If the heatmap has not been generated.
        """
        import matplotlib.colors as mcol

        self._check_generated()
        ax = self._prepare_ax(ax)
        implot = self.plot_thumbnail(ax=ax, **thumb_kwargs)
        # A fully opaque overlay hides the thumbnail completely.
        if heatmap_alpha == 1:
            implot.set_alpha(0)
        ax.set_facecolor("black")
        divnorm = mcol.TwoSlopeNorm(
            vmin=vmin,
            vcenter=vcenter,
            vmax=vmax
        )
        extent = calculate_heatmap_extent(
            self.slide, self._thumb, self.predictions
        )
        ax.imshow(
            self.predictions[:, :, class_idx],
            norm=divnorm,
            extent=extent,
            cmap=cmap,
            alpha=heatmap_alpha,
            interpolation=interpolation,
            zorder=10
        )
        ax.set_xlim(0, self._thumb.size[0])
        ax.set_ylim(self._thumb.size[1], 0)

    def save_npz(self, path: Optional[str] = None) -> str:
        """Save heatmap predictions and uncertainty in .npz format.

        Saves heatmap predictions to ``'predictions'`` in the .npz file. If
        uncertainty was calculated, this is saved to ``'uncertainty'``.

        A Heatmap instance can load a saved .npz file with
        :meth:`slideflow.Heatmap.load()`.

        Args:
            path (str, optional): Destination filename for .npz file.
                Defaults to {slidename}.npz

        Returns:
            str: Path to .npz file.

        Raises:
            errors.HeatmapError: If the heatmap has not been generated.
        """
        # Saving ``None`` would write an object array that np.load cannot
        # read back without allow_pickle=True.
        self._check_generated(action='save')
        if path is None:
            path = f'{self.slide.name}.npz'
        np_kwargs = dict(predictions=self.predictions)
        if self.uq and self.uncertainty is not None:
            np_kwargs['uncertainty'] = self.uncertainty
        np.savez(path, **np_kwargs)
        return path

    def save(
        self,
        outdir: str,
        show_roi: bool = True,
        interpolation: str = 'none',
        logit_cmap: Optional[Union[Callable, Dict]] = None,
        roi_color: str = 'k',
        linewidth: int = 5,
        **kwargs
    ) -> None:
        """Save calculated predictions as heatmap overlays.

        Writes ``{slide}.npz`` plus PNG figures into ``outdir``: ``-raw``,
        ``-raw+roi``, and either ``-custom`` (when ``logit_cmap`` is given)
        or ``-{i}`` / ``-{i}-solid`` per class and, for UQ models, ``-UQ`` /
        ``-UQ-solid``.

        Args:
            outdir (str): Path to directory in which to save heatmap images.
            show_roi (bool, optional): Overlay ROIs onto heatmap image.
                Defaults to True.
            interpolation (str, optional): Interpolation strategy to use for
                smoothing heatmap. Defaults to 'none'.
            logit_cmap (obj, optional): Either function or a dictionary use
                to create heatmap colormap. Each image tile will generate a
                list of predictions of length O, where O is the number of
                outcomes. If logit_cmap is a function, then the logit
                prediction list will be passed to the function, and the
                function is expected to return [R, G, B] values for display.
                If logit_cmap is a dictionary, it should map 'r', 'g', and
                'b' to indices; the prediction for these outcome indices will
                be mapped to the RGB colors. Thus, the corresponding color
                will only reflect up to three outcomes. Example mapping
                prediction for outcome 0 to the red colorspace, 3 to green,
                etc: {'r': 0, 'g': 3, 'b': 1}
            roi_color (str): ROI line color. Defaults to 'k' (black).
            linewidth (int): Width of ROI line. Defaults to 5.

        Keyword args:
            cmap (str, optional): Matplotlib heatmap colormap.
                Defaults to 'coolwarm'. Also used for the uncertainty map.
            vmin (float): Minimum value to display on heatmap.
                Defaults to 0.
            vcenter (float): Center value for color display on heatmap.
                Defaults to 0.5.
            vmax (float): Maximum value to display on heatmap.
                Defaults to 1.

        Raises:
            errors.HeatmapError: If the heatmap has not been generated.
        """
        self._check_generated()

        with sf.util.matplotlib_backend('Agg'):
            import matplotlib.pyplot as plt

            # Save heatmaps in .npz format
            self.save_npz(os.path.join(outdir, f'{self.slide.name}.npz'))

            def _savefig(label, bbox_inches='tight', **kwargs):
                plt.savefig(
                    os.path.join(outdir, f'{self.slide.name}-{label}.png'),
                    bbox_inches=bbox_inches,
                    **kwargs
                )

            log.info('Saving base figures...')

            # Prepare matplotlib figure
            ax = self._prepare_ax()
            try:
                thumb_kwargs = dict(roi_color=roi_color, linewidth=linewidth)

                # Save base thumbnail as separate figure
                self.plot_thumbnail(show_roi=False, ax=ax, **thumb_kwargs)  # type: ignore
                _savefig('raw')

                # Save thumbnail + ROI as separate figure
                self.plot_thumbnail(show_roi=True, ax=ax, **thumb_kwargs)  # type: ignore
                _savefig('raw+roi')

                if logit_cmap:
                    self.plot_with_logit_cmap(
                        logit_cmap,
                        interpolation=interpolation,
                        ax=ax,
                        show_roi=show_roi,
                        **thumb_kwargs
                    )
                    _savefig('custom')
                else:
                    heatmap_kwargs = dict(
                        show_roi=show_roi,
                        interpolation=interpolation,
                        **thumb_kwargs,
                        **kwargs
                    )
                    save_kwargs = dict(
                        bbox_inches='tight',
                        facecolor=ax.get_facecolor(),
                        edgecolor='none'
                    )
                    # Make heatmap plots for each outcome category
                    for i in range(self.num_classes):
                        log.info(f'Making {i+1}/{self.num_classes}...')
                        self.plot(i, heatmap_alpha=0.6, ax=ax, **heatmap_kwargs)
                        _savefig(str(i), **save_kwargs)
                        self.plot(i, heatmap_alpha=1, ax=ax, **heatmap_kwargs)
                        _savefig(f'{i}-solid', **save_kwargs)

                    # Uncertainty map. plot_uncertainty() has no
                    # vmin/vmax/vcenter parameters; forwarding them would
                    # reach plot_thumbnail() and raise a TypeError.
                    if self.uq:
                        log.info('Making uncertainty heatmap...')
                        uq_kwargs = {
                            k: v for k, v in heatmap_kwargs.items()
                            if k not in ('vmin', 'vmax', 'vcenter')
                        }
                        self.plot_uncertainty(heatmap_alpha=0.6, ax=ax, **uq_kwargs)
                        _savefig('UQ', **save_kwargs)
                        self.plot_uncertainty(heatmap_alpha=1, ax=ax, **uq_kwargs)
                        _savefig('UQ-solid', **save_kwargs)
            finally:
                # Release the figure even if plotting fails.
                plt.close(ax.figure)
            log.info(f'Saved heatmaps for [green]{self.slide.name}')

    def view(self):
        """Load the Heatmap into Slideflow Studio for interactive view.

        See :ref:`studio` for more information.
        """
        from slideflow.studio import Studio

        studio = Studio()
        studio.load_slide(self.slide.path, stride=self.stride_div)
        studio.load_model(self.model_path)
        studio.load_heatmap(self)
        studio.run()
class ModelHeatmap(Heatmap):
    """Heatmap generated from a model with explicitly supplied tile settings.

    Unlike :class:`Heatmap`, which reads tile size, image format and the
    uncertainty setting from a saved model's configuration, this class takes
    them as arguments. ``model`` may be either a path or an already-loaded
    model object.
    """

    def __init__(
        self,
        slide: Union[str, WSI],
        model: Union[str, "torch.nn.Module", "tf.keras.Model"],
        *,
        img_format: str,
        tile_px: Optional[int] = None,
        tile_um: Optional[Union[int, str]] = None,
        stride_div: Optional[int] = None,
        normalizer: Optional[sf.norm.StainNormalizer] = None,
        batch_size: int = 32,
        num_threads: Optional[int] = None,
        num_processes: Optional[int] = None,
        generate: bool = True,
        uq: bool = False,
        load_method: Optional[str] = None,
        apply_softmax: Optional[bool] = None,
        generator_kwargs: Optional[Dict[str, Any]] = None,
        **wsi_kwargs
    ):
        """Convolute across a whole slide, calculating predictions and
        saving them internally for later use.

        Args:
            slide (str or :class:`slideflow.WSI`): Path to slide, or an
                already-loaded WSI.
            model (str, torch.nn.Module, or tf.keras.Model): Path to a
                Tensorflow or PyTorch model, or a loaded model object.

        Keyword args:
            img_format (str): Image format (png, jpg) to use when
                extracting tiles from slide. Must match the image format
                the model was trained on.
            tile_px (int): Tile width in pixels. Required if ``slide``
                is a path. Defaults to None.
            tile_um (int or str): Tile width in microns (int) or
                magnification (str, e.g. "20x"). Required if ``slide``
                is a path. Defaults to None.
            stride_div (int, optional): Divisor for stride when convoluting
                across slide. Only used when ``slide`` is a path.
                Defaults to 2.
            normalizer (:class:`slideflow.norm.StainNormalizer`): Stain
                normalizer to use when preprocessing image tiles. Only
                passed to the interface when ``model`` is a loaded model
                object. Defaults to None.
            batch_size (int, optional): Batch size for calculating
                predictions. Defaults to 32.
            num_threads (int, optional): Number of tile worker threads.
                Cannot supply both ``num_threads`` (uses thread pool) and
                ``num_processes`` (uses multiprocessing pool). Defaults to
                CPU core count.
            num_processes (int, optional): Number of child processes to spawn
                for multiprocessing pool. Defaults to None (does not use
                multiprocessing).
            generate (bool): Generate the heatmap after initialization.
                If False, heatmap will need to be manually generated by
                calling :meth:`Heatmap.generate()`.
            uq (bool): Calculate uncertainty via dropout (requires model
                with dropout layers). Defaults to False.
            load_method (str): Either 'full' or 'weights'. Method to use
                when loading a Tensorflow model. If 'full', loads the model
                with ``tf.keras.models.load_model()``. If 'weights', will
                read the ``params.json`` configuration file, build the model
                architecture, and then load weights from the given model with
                ``Model.load_weights()``. Loading with 'full' may improve
                compatibility across Slideflow versions. Loading with
                'weights' may improve compatibility across hardware &
                environments.
            apply_softmax (bool): Apply softmax transformation to logits.
                Only used for PyTorch models (raises an error if this
                argument is specified for a Tensorflow model). If None,
                the interface's own default is used. Defaults to None.
            generator_kwargs (dict, optional): Keyword arguments passed to
                :meth:`Heatmap.generate()`.
            roi_dir (str, optional): Directory in which slide ROI is
                contained. Defaults to None.
            rois (list, optional): List of paths to slide ROIs. Alternative
                to providing roi_dir. Defaults to None.
            roi_method (str): Either 'inside', 'outside', 'auto', or
                'ignore'. Determines how ROIs are used to extract tiles.
                If 'inside' or 'outside', will extract tiles in/out of an
                ROI, and raise errors.MissingROIError if an ROI is not
                available. If 'auto', will extract tiles inside an ROI if
                available, and across the whole-slide if no ROI is found.
                If 'ignore', will extract tiles across the whole-slide
                regardless of whether an ROI is available.
                Defaults to 'auto'.

            Remaining keyword arguments are passed to :class:`slideflow.WSI`
            when ``slide`` is a path, and ignored (with a warning) otherwise.

        Raises:
            ValueError: On conflicting/missing arguments, a tensorflow model
                with ``apply_softmax``, an unrecognized ``slide``, or a model
                whose backend cannot be determined.
            errors.HeatmapError: If the slide cannot be loaded.
        """
        if num_processes is not None and num_threads is not None:
            raise ValueError("Invalid argument: cannot supply both "
                             "num_processes and num_threads")
        backend = sf.util.model_backend(model)
        if apply_softmax is not None and backend == 'tensorflow':
            raise ValueError("Keyword argument 'apply_softmax' is invalid "
                             "for Tensorflow models.")

        self.uq = uq
        self.img_format = img_format
        self.num_threads = num_threads
        self.num_processes = num_processes
        self.batch_size = batch_size
        self.device = None  # Mirrors Heatmap; no explicit device here.
        self.insets = []  # type: List[Inset]
        self._thumb = None
        if generator_kwargs is None:
            generator_kwargs = {}

        if isinstance(slide, str):
            if tile_px is None:
                raise ValueError("If slide is a path, must supply tile_px.")
            if tile_um is None:
                raise ValueError("If slide is a path, must supply tile_um.")
            if stride_div is None:
                stride_div = 2
            self.slide_path = slide
            self.tile_px = tile_px
            self.tile_um = tile_um
            self.stride_div = stride_div
            try:
                self.slide = WSI(
                    self.slide_path,
                    self.tile_px,
                    self.tile_um,
                    self.stride_div,
                    **wsi_kwargs  # type: ignore
                )
            except errors.SlideLoadError as e:
                # ``self.slide`` is not assigned when WSI() fails, so the
                # error message must use the path instead.
                raise errors.HeatmapError(
                    f'Error loading slide {self.slide_path} for heatmap'
                ) from e
        elif isinstance(slide, WSI):
            if tile_px is not None:
                log.warning("slide is a WSI; ignoring supplied tile_px.")
            if tile_um is not None:
                log.warning("slide is a WSI; ignoring supplied tile_um.")
            if stride_div is not None:
                log.warning("slide is a WSI; ignoring supplied stride_div.")
            if wsi_kwargs:
                log.warning("WSI provided; ignoring keyword arguments: "
                            + ", ".join(wsi_kwargs))
            self.slide_path = slide.path
            self.slide = slide
            self.tile_px = slide.tile_px
            self.tile_um = slide.tile_um
            self.stride_div = slide.stride_div
        else:
            raise ValueError(f"Unrecognized value {slide} for argument slide")

        # Select the interface class and its backend-specific kwargs.
        if backend == 'tensorflow':
            import slideflow.model.tensorflow
            if uq:
                interface_class = sf.model.tensorflow.UncertaintyInterface  # type: ignore
                interface_kw = {}  # type: Dict[str, Any]
            else:
                interface_class = sf.model.tensorflow.Features  # type: ignore
                interface_kw = dict(include_preds=True)
        elif backend == 'torch':
            import slideflow.model.torch
            interface_kw = dict(
                tile_px=self.tile_px,
                apply_softmax=apply_softmax
            )
            if uq:
                interface_class = sf.model.torch.UncertaintyInterface  # type: ignore
            else:
                interface_class = sf.model.torch.Features  # type: ignore
                interface_kw['include_preds'] = True
        else:
            raise ValueError(f"Unable to interpret model {model}")

        if load_method is not None:
            interface_kw['load_method'] = load_method
        if isinstance(model, str):
            self.interface = interface_class(
                model, layers=None, **interface_kw)
        else:
            self.interface = interface_class.from_model(
                model, layers=None, wsi_normalizer=normalizer, **interface_kw)

        self.num_classes = self.interface.num_classes
        self.num_features = self.interface.num_features
        self.num_uncertainty = self.interface.num_uncertainty
        self.predictions = None
        self.uncertainty = None
        # No saved-model path is tracked, which is why view() is unsupported.
        self.model_path = None
        if generate:
            self.generate(**generator_kwargs)
        elif generator_kwargs:
            log.warning("Heatmap generate=False, ignoring generator_kwargs ("
                        f"{generator_kwargs})")

    def view(self):
        """Not supported: Studio needs a model path, which is not tracked."""
        raise NotImplementedError(
            "ModelHeatmap does not support view(); Slideflow Studio requires "
            "a saved model path."
        )

# -----------------------------------------------------------------------------

def calculate_heatmap_extent(
    wsi: "sf.WSI",
    thumbnail: "Image",
    grid: np.ndarray
) -> Tuple[float, float, float, float]:
    """Calculate the ``imshow`` extent for a heatmap grid over a thumbnail.

    Each grid cell is drawn ``stride`` slide-pixels wide, where
    ``stride = int(tile_um / mpp) / stride_div``. Cells are shifted by
    ``(full_extract - stride) / 2`` so that each cell is centered on the
    center of its tile. This assumes tile ``i`` starts at ``i * stride``
    slide pixels from the origin (TODO confirm against WSI grid layout).
    Both the near and far edges are shifted so the overlay is translated,
    not rescaled.

    Args:
        wsi (sf.WSI): Slide; reads ``tile_um``, ``mpp``, ``stride_div``
            and ``dimensions``.
        thumbnail (PIL.Image): Thumbnail the heatmap is drawn over; reads
            ``size``.
        grid (np.ndarray): Heatmap grid indexed (row/y, column/x, ...).

    Returns:
        Tuple[float, float, float, float]: (left, right, bottom, top) in
        thumbnail pixel coordinates, suitable for ``ax.imshow(extent=...)``
        with a y-axis that runs from the thumbnail height down to 0.
    """
    full_extract = int(wsi.tile_um / wsi.mpp)
    wsi_stride = int(full_extract / wsi.stride_div)
    offset = full_extract / 2 - wsi_stride / 2
    overlay_w = wsi_stride * grid.shape[1]
    overlay_h = wsi_stride * grid.shape[0]
    x_ratio = wsi.dimensions[0] / thumbnail.size[0]
    y_ratio = wsi.dimensions[1] / thumbnail.size[1]
    return (
        offset / x_ratio,
        (offset + overlay_w) / x_ratio,
        (offset + overlay_h) / y_ratio,
        offset / y_ratio
    )

# -----------------------------------------------------------------------------