Source code for slideflow.slide.qc.strided_dl

import numpy as np

from tqdm import tqdm
from contextlib import contextmanager
from typing import Callable, Union, Optional, TYPE_CHECKING
from .strided_qc import _StridedQC, _StridedQC_V2

if TYPE_CHECKING:
    import slideflow as sf


class StridedDL(_StridedQC):

    def __init__(
        self,
        model: Callable,
        pred_idx: int,
        tile_px: int,
        tile_um: Union[str, int],
        *,
        buffer: int = 8,
        verbose: bool = False,
        pred_threshold: float = 0.5,
        **wsi_kwargs
    ):
        """QC function which uses a deep learning model to generate a QC mask.

        When this QC method is applied to a slide, the given deep learning
        model generates predictions across the whole-slide image (using the
        class index specified by ``pred_idx``). Areas with a prediction above
        ``pred_threshold`` are masked, to be discarded.

        Examples
            Create a DeepFocus module that filters out-of-focus tiles.

                .. code-block:: python

                    import slideflow as sf
                    from slideflow.slide.qc import strided_dl
                    from deepfocus import deepfocus_v3

                    deepfocus = strided_dl.StridedDL(
                        model=deepfocus_v3(),
                        pred_idx=1,
                        tile_px=64,
                        tile_um='40x'
                    )
                    wsi = sf.WSI(...)
                    wsi.qc(deepfocus)

            Do the same, but using class inheritance.

                .. code-block:: python

                    import slideflow as sf
                    from slideflow.slide.qc import strided_dl
                    from deepfocus import deepfocus_v3

                    class DeepFocus(strided_dl.StridedDL):

                        def __init__(self):
                            model = deepfocus_v3()
                            checkpoint = '/path/to/checkpoint-ver5'
                            load_checkpoint(model, checkpoint)
                            super().__init__(
                                model=model,
                                pred_idx=1,
                                tile_px=64,
                                tile_um='40x'
                            )

                    wsi = sf.WSI(...)
                    deepfocus = DeepFocus()
                    wsi.qc(deepfocus)

        Args:
            model (callable): Deep learning model.
            pred_idx (int): Index of the model output to interpret as the
                final prediction.
            tile_px (int): Tile size.
            tile_um (str or int): Tile size, in microns (int) or
                magnification (str).

        Keyword arguments:
            verbose (bool): Show a progress bar during calculation.
            buffer (int): Number of tiles (width and height) to extract and
                process simultaneously. Extracted tile size (width/height)
                will be ``tile_px * buffer``. Defaults to 8.
            grayspace_fraction (float): Grayspace fraction when extracting
                tiles from slides. Defaults to 1 (disables).
            pred_threshold (float): Predictions above this value are masked.
                Defaults to 0.5.
            kwargs (Any): All remaining keyword arguments are passed to
                :meth:`slideflow.WSI.build_generator()`.

        """
        super().__init__(
            tile_px=tile_px,
            tile_um=tile_um,
            buffer=buffer,
            verbose=verbose,
            lazy_iter=True,
            deterministic=False,
            **wsi_kwargs
        )
        self.model = model
        self.pred_idx = pred_idx
        self.pred_threshold = pred_threshold

    def build_mask(self, x, y) -> np.ndarray:
        """Build the base, empty QC mask."""
        return np.ones((x, y), dtype=np.float32)

    def apply(self, image: np.ndarray) -> np.ndarray:
        """Predict the focus value of an image tile using the model."""
        y_pred = self.model(image, training=False)[:, self.pred_idx].numpy()
        return y_pred.reshape(self.buffer, self.buffer)

    def collate_mask(self, mask: np.ndarray):
        """Convert the mask from predictions to bool using a threshold."""
        if self.pred_threshold is not None:
            return mask > self.pred_threshold
        else:
            return mask

    def preprocess(self, image: np.ndarray):
        """Apply preprocessing to an image."""
        return np.clip(image.astype(np.float32) / 255, 0, 1)

    @contextmanager
    def _set_threshold(self, threshold: Optional[Union[bool, float]]):
        """Temporarily set or disable the prediction threshold."""
        _orig_threshold = self.pred_threshold
        if isinstance(threshold, float):
            # Set the threshold to the given value
            self.pred_threshold = threshold
        elif threshold is False:
            # Disable thresholding (return raw values)
            self.pred_threshold = None
        yield
        # Return the threshold to its original value
        self.pred_threshold = _orig_threshold

    def __call__(
        self,
        wsi: "sf.WSI",
        threshold: Optional[Union[bool, float]] = None
    ) -> Optional[np.ndarray]:
        with self._set_threshold(threshold):
            return super().__call__(wsi)
# -----------------------------------------------------------------------------

def _taper_mask(ly=224, lx=224, sig=7.5):
    bsize = max(224, max(ly, lx))
    xm = np.arange(bsize)
    xm = np.abs(xm - xm.mean())
    mask = 1 / (1 + np.exp((xm - (bsize/2 - 20)) / sig))
    mask = mask * mask[:, np.newaxis]
    mask = mask[bsize//2 - ly//2: bsize//2 + ly//2 + ly % 2,
                bsize//2 - lx//2: bsize//2 + lx//2 + lx % 2]
    return mask

# -----------------------------------------------------------------------------

class StridedDL_V2(_StridedQC_V2):
    """Implementation of a strided deep learning QC algorithm.

    The _StridedQC_V2 base class collates tiled QC masks into a single mask
    by cropping out the overlap regions. This approach is suitable for
    algorithms that generate artifacts at the edges of tiles, but is not
    adequate for stitching together deep learning predictions. This class is
    a subclass of _StridedQC_V2, and is designed to stitch together output
    from a deep learning QC model for tiles using a tapered mask.

    """

    def __init__(
        self,
        *args,
        out_classes: int = 0,
        **kwargs
    ):
        """Create a new StridedDL_V2 object.

        Args:
            *args (Any): Arguments to pass to the parent class.
            out_classes (int): Number of output classes from the deep
                learning model. If provided, the shape of the QC mask will
                be (out_classes, h, w). If 0 or not provided, the shape
                will be (h, w).
            **kwargs (Any): Keyword arguments to pass to the parent class.

        """
        super().__init__(*args, **kwargs)
        self.out_classes = out_classes

    def _calc_mask(self, item):
        """Calculate a QC mask from a given tile."""
        grid_i = item['grid'][0]
        grid_j = item['grid'][1]
        image = item['image']
        mask = self.apply(image)
        return mask, (grid_i, grid_j)

    def build_masks(self, wsi: "sf.WSI"):
        """Return empty arrays for storing the QC mask and the average (taper) mask."""
        dim = (wsi.dimensions[1], wsi.dimensions[0])
        px_ratio = wsi.tile_px / wsi.full_extract_px
        target_dim = tuple((np.array(dim) * px_ratio).astype(int))
        if self.out_classes:
            qc_dim = (self.out_classes, target_dim[0], target_dim[1])
        else:
            qc_dim = target_dim
        qc_mask = np.zeros(qc_dim, np.float32)
        avg_mask = np.zeros(target_dim, np.float32)
        return qc_mask, avg_mask

    def get_tile_bounds(self, wsi: "sf.WSI", i: int, j: int):
        """Return the bounds of a tile."""
        fy, fx = wsi.grid_to_coord(i, j, anchor="topleft")
        px_ratio = wsi.tile_px / wsi.full_extract_px
        x0 = int(fx * px_ratio)
        y0 = int(fy * px_ratio)
        x1 = x0 + wsi.tile_px
        y1 = y0 + wsi.tile_px
        return x0, x1, y0, y1

    def __call__(
        self,
        wsi: "sf.WSI",
    ) -> Optional[np.ndarray]:
        """Apply QC filtering to a slide."""
        qc_wsi, mpp = self.get_slide_and_mpp(wsi)
        qc_mask, avg_mask = self.build_masks(qc_wsi)
        dts = self.build_tile_generator(qc_wsi)

        # Get the base taper mask
        taper_mask = _taper_mask(ly=self.tile_px, lx=self.tile_px, sig=7.5)

        # Progress bar tracking
        if self.verbose:
            pb = tqdm(dts, desc="Generating...", total=qc_wsi.estimated_num_tiles)
        else:
            pb = dts

        # Apply QC filter to each tile
        if self.filter_pool is not None:
            map_fn = self.filter_pool.imap_unordered
        else:
            map_fn = map
        for (tile_mask, (i, j)) in map_fn(self._calc_mask, pb):
            x0, x1, y0, y1 = self.get_tile_bounds(qc_wsi, i, j)
            if self.out_classes:
                x1 = min(x1, qc_mask.shape[1])
                y1 = min(y1, qc_mask.shape[2])
                qc_mask[:, x0:x1, y0:y1] += tile_mask[:, 0: x1-x0, 0: y1-y0] * taper_mask[0: x1-x0, 0: y1-y0]
            else:
                x1 = min(x1, qc_mask.shape[0])
                y1 = min(y1, qc_mask.shape[1])
                qc_mask[x0:x1, y0:y1] += tile_mask[0: x1-x0, 0: y1-y0] * taper_mask[0: x1-x0, 0: y1-y0]
            avg_mask[x0:x1, y0:y1] += taper_mask[0: x1-x0, 0: y1-y0]

        # Normalize the mask
        qc_mask = qc_mask / avg_mask

        # Close pools
        if not self.persistent_threads:
            self.close_pools()

        return qc_mask
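
# Illustration (a minimal sketch, not part of the original module): the
# stitching in ``StridedDL_V2.__call__`` is a weighted average. Each tile's
# predictions are accumulated into ``qc_mask`` weighted by the sigmoid taper
# window from ``_taper_mask``, the weights themselves accumulate in
# ``avg_mask``, and the final division blends overlapping tiles smoothly.
# The shapes and prediction values below are made up for demonstration.
#
#     import numpy as np
#     taper = _taper_mask()                        # 224 x 224 taper window
#     qc = np.zeros((336, 224), np.float32)
#     avg = np.zeros((336, 224), np.float32)
#     for y0, pred in [(0, 0.2), (112, 0.8)]:      # two tiles, 112 px overlap
#         qc[y0:y0 + 224] += pred * taper
#         avg[y0:y0 + 224] += taper
#     blended = qc / avg                           # smooth 0.2 -> 0.8 transition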