# Standard library
import logging
import multiprocessing as mp
import multiprocessing.dummy  # required for mp.dummy.Pool (not auto-imported by multiprocessing)
import threading
import time
from functools import partial
from queue import Queue
from typing import Tuple, Union, Callable, Optional, Iterable, TYPE_CHECKING, List

# Third-party
import cellpose
import cellpose.models
import cv2
import numpy as np
import rasterio
import rasterio.features  # required for rasterio.features.rasterize (not auto-imported)
import shapely.affinity as sa
import torch
import zarr
from cellpose import transforms, plot, dynamics
from cellpose.models import Cellpose
from cellpose.utils import outlines_list
from matplotlib.colors import to_rgb
from numcodecs import Blosc
from PIL import Image, ImageDraw
from tqdm import tqdm

# Local
import slideflow as sf
from slideflow.model import torch_utils
from slideflow.slide.utils import draw_roi
from slideflow.util import batch_generator, log
from . import seg_utils

if TYPE_CHECKING:
    import shapely.geometry
    from rich.progress import Progress, TaskID
# -----------------------------------------------------------------------------
class Segmentation:
    """A collection of cell segmentation masks for a whole-slide image.

    Wraps a labeled mask array (0 = background; each cell is a unique
    positive integer) together with optional flows and the alignment
    metadata (``wsi_dim``, ``wsi_offset``) needed to map masks and
    centroids back into slide (base magnification) coordinates.
    """

    def __init__(
        self,
        masks: np.ndarray,
        *,
        slide: Optional[sf.WSI] = None,
        flows: Optional[np.ndarray] = None,
        styles: Optional[np.ndarray] = None,
        diams: Optional[np.ndarray] = None,
        wsi_dim: Optional[Tuple[int, int]] = None,
        wsi_offset: Optional[Tuple[int, int]] = None
    ):
        """Organizes a collection of cell segmentation masks for a slide.

        Args:
            masks (np.ndarray): Array of masks, dtype int32, where 0 represents
                non-segmented background, and each segmented mask is represented
                by unique increasing integers.

        Keyword args:
            slide (slideflow.WSI): If provided, ``Segmentation`` can coordinate
                extracting tiles at mask centroids. Defaults to None.
            flows (np.ndarray): Array of flows, dtype float32. Defaults to None.
            wsi_dim (tuple(int, int)): Size of ``masks`` in the slide
                pixel space (highest magnification). Used to align the mask
                array to a corresponding slide. Required for calculating
                centroids. Defaults to None.
            wsi_offset (tuple(int, int)): Top-left starting location for
                ``masks``, in slide pixel space (highest magnification).
                Used to align the mask array to a corresponding slide.
                Required for calculating centroids. Defaults to None.
            styles (np.ndarray): Array of styles, currently ignored.
            diams (np.ndarray): Array of diameters, currently ignored.

        Raises:
            ValueError: If ``masks`` is not a numpy array.
        """
        if not isinstance(masks, np.ndarray):
            raise ValueError("First argument (masks) must be a numpy array.")
        self.slide = slide
        self.masks = masks
        self.flows = flows
        # Outlines and centroids are calculated lazily and buffered; see
        # calculate_outlines() and calculate_centroids().
        self._outlines = None
        self._centroids = None
        self.wsi_dim = wsi_dim
        self.wsi_offset = wsi_offset

    @classmethod
    def load(cls, path) -> "Segmentation":
        """Alternate class initializer; load a saved Segmentation from *.zip.

        Args:
            path (str): Path to *.zip containing saved Segmentation, as created
                through :meth:`slideflow.cellseg.Segmentation.save`.

        Raises:
            TypeError: If the archive does not contain a 'masks' array.
        """
        loaded = zarr.load(path)
        if 'masks' not in loaded:
            raise TypeError(f"Unable to load '{path}'; 'masks' index not found.")
        flows = None if 'flows' not in loaded else loaded['flows']
        obj = cls(slide=None, masks=loaded['masks'], flows=flows)
        # Alignment metadata and centroids are optional in saved archives
        # (save() only writes them when set), so guard each lookup.
        obj.wsi_dim = loaded['wsi_dim'] if 'wsi_dim' in loaded else None
        obj.wsi_offset = loaded['wsi_offset'] if 'wsi_offset' in loaded else None
        if 'centroids' in loaded:
            obj._centroids = loaded['centroids']
        return obj

    @property
    def outlines(self) -> np.ndarray:
        """Calculate and return mask outlines as ``np.ndarray``."""
        if self._outlines is None:
            self.calculate_outlines()
        return self._outlines

    @property
    def wsi_ratio(self) -> Optional[float]:
        """Ratio of WSI base dimension to the mask shape.

        Returns `None` if ``wsi_dim`` was not set.
        """
        if self.wsi_dim is not None:
            return self.wsi_dim[1] / self.masks.shape[0]
        else:
            return None

    def apply_rois(
        self,
        scale: float,
        annpolys: List["shapely.geometry.Polygon"]
    ) -> None:
        """Apply regions of interest (ROIs), excluding masks outside ROIs.

        Args:
            scale (float): ROI scale (roi size / WSI base dimension size).
            annpolys (list(``shapely.geometry.Polygon``)): List of ROI
                polygons, as available in ``slideflow.WSI.rois``.
        """
        if self.wsi_ratio is not None and len(annpolys):
            # Rescale ROI polygons from ROI space into mask space.
            roi_seg_scale = scale / self.wsi_ratio
            scaled_polys = [
                sa.scale(
                    poly,
                    xfact=roi_seg_scale,
                    yfact=roi_seg_scale,
                    origin=(0, 0)
                ) for poly in annpolys
            ]
            # Rasterize the ROIs to a boolean mask and zero out any cell
            # masks falling outside all ROIs.
            roi_seg_mask = rasterio.features.rasterize(
                scaled_polys,
                out_shape=self.masks.shape,
                all_touched=False
            ).astype(bool)
            self.masks *= roi_seg_mask
            # Centroids are stale after masking; force recalculation.
            self.calculate_centroids(force=True)
        elif self.wsi_ratio is None:
            log.warning("Unable to apply ROIs; WSI dimensions not set.")
        # Otherwise there are no ROIs to apply; nothing to do.

    def centroids(self, wsi_dim: bool = False) -> np.ndarray:
        """Calculate and return mask centroids.

        Args:
            wsi_dim (bool): Convert centroids from mask space to WSI space.
                Requires that ``wsi_dim`` was provided during initialization.

        Returns:
            A ``np.ndarray`` with shape ``(2, num_masks)``.

        Raises:
            ValueError: If ``wsi_dim=True`` but ``wsi_dim`` was not set.
        """
        if self._centroids is None:
            self.calculate_centroids()
        if not wsi_dim:
            return self._centroids
        if self.wsi_dim is None:
            raise ValueError("Unable to calculate wsi_dim for centroids - "
                             "wsi_dim is not set.")
        # Scale from mask space to slide space, reverse coordinate order
        # (presumably (y, x) -> (x, y); set by seg_utils - TODO confirm),
        # and shift by the mask's top-left offset.
        return ((self._centroids * self.wsi_ratio)[:, ::-1]
                + self.wsi_offset).astype(np.int32)

    def _draw_centroid(self, img, color='green'):
        """Return a copy of ``img`` with a dot drawn at each centroid."""
        pil_img = Image.fromarray(img)
        draw = ImageDraw.Draw(pil_img)
        for c in self.centroids():
            x, y = np.int32(c[1]), np.int32(c[0])
            draw.ellipse((x-3, y-3, x+3, y+3), fill=color)
        return np.asarray(pil_img)

    def calculate_centroids(self, force: bool = False) -> None:
        """Calculate centroids.

        Centroid values are buffered into ``Segmentation._centroids`` to
        reduce unnecessary recalculations.

        Args:
            force (bool): Recalculate centroids, even if calculated before.
        """
        if self._centroids is not None and not force:
            return
        mask_s = seg_utils.sparse_mask(self.masks)
        self._centroids = seg_utils.get_sparse_centroid(self.masks, mask_s)

    def calculate_outlines(self, force: bool = False) -> None:
        """Calculate mask outlines.

        Mask outlines are buffered into ``Segmentation._outlines`` to
        reduce unnecessary recalculations.

        Args:
            force (bool): Recalculate outlines, even if calculated before.
        """
        if self._outlines is not None and not force:
            return
        self._outlines = outlines_list(self.masks)

    def centroid_to_image(self, color: str = 'green') -> np.ndarray:
        """Render an image with the location of all centroids as a point.

        Args:
            color (str): Centroid color. Defaults to 'green'.
        """
        img = np.zeros((self.masks.shape[0], self.masks.shape[1], 3), dtype=np.uint8)
        return self._draw_centroid(img, color=color)

    def extract_centroids(
        self,
        slide: str,
        tile_px: int = 128,
    ) -> Callable:
        """Return a generator which extracts tiles from a slide at mask centroids.

        Args:
            slide (str): Path to a slide.
            tile_px (int): Height/width of tile to extract at centroids.
                Defaults to 128.

        Returns:
            A generator which yields a numpy array, with shape
            ``(tile_px, tile_px, 3)``, at each mask centroid.
        """
        reader = sf.slide.wsi_reader(slide)
        # Scale factor from mask space to slide base dimensions.
        factor = reader.dimensions[1] / self.masks.shape[0]
        # Ensure centroids are available before the generator is consumed
        # (previously this would fail if centroids were never calculated).
        self.calculate_centroids()

        def generator():
            for c in self._centroids:
                cf = c * factor + self.wsi_offset
                yield reader.read_from_pyramid(
                    (cf[1]-(tile_px/2), cf[0]-(tile_px/2)),
                    (tile_px, tile_px),
                    (tile_px, tile_px),
                    convert='numpy',
                    flatten=True
                )
        return generator

    def mask_to_image(self, centroid=False, color='cyan', centroid_color='green'):
        """Render an image of all masks.

        Masks are rendered on a black background.

        Args:
            centroid (bool): Include centroids as points on the image.
                Defaults to False.
            color (str): Color of the masks. Defaults to 'cyan'.
            centroid_color (str): Color of centroid points. Defaults to 'green'.

        Returns:
            np.ndarray
        """
        if isinstance(color, str):
            color = [int(c * 255) for c in to_rgb(color)]
        else:
            assert len(color) == 3
        img = np.zeros((self.masks.shape[0], self.masks.shape[1], 3), dtype=np.uint8)
        img[self.masks > 0] = color
        if centroid:
            return self._draw_centroid(img, color=centroid_color)
        else:
            return img

    def outline_to_image(self, centroid=False, color='red', centroid_color='green'):
        """Render an image with the outlines of all masks.

        Args:
            centroid (bool): Include centroids as points on the image.
                Defaults to False.
            color (str): Color of the mask outlines. Defaults to 'red'.
            centroid_color (str): Color of centroid points. Defaults to 'green'.

        Returns:
            np.ndarray
        """
        empty = np.zeros((self.masks.shape[0], self.masks.shape[1], 3), dtype=np.uint8)
        img = draw_roi(empty, self.outlines, color=color)
        if centroid:
            return self._draw_centroid(img, color=centroid_color)
        else:
            return img

    def save(
        self,
        filename: str,
        centroids: bool = True,
        flows: bool = True
    ) -> None:
        r"""Save segmentation masks and metadata to \*.zip.

        A :class:`slideflow.cellseg.Segmentation` object can be loaded from
        this file with ``.load()``.

        Args:
            filename (str): Destination filename (ends with \*.zip)
            centroids (bool): Save centroid locations.
            flows (bool): Save flows.
        """
        if not filename.endswith('zip'):
            filename += '.zip'
        save_dict = dict(
            masks=self.masks,
            compressor=Blosc(
                cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE
            )
        )
        if centroids:
            self.calculate_centroids()
        if self._centroids is not None and centroids:
            save_dict['centroids'] = self._centroids
        if self.flows is not None and flows:
            save_dict['flows'] = self.flows
        # Optional alignment metadata; only saved when set (see load()).
        if self.wsi_dim is not None:
            save_dict['wsi_dim'] = self.wsi_dim
        if self.wsi_offset is not None:
            save_dict['wsi_offset'] = self.wsi_offset
        seg_utils.save_zarr_compressed(filename, **save_dict)
# -----------------------------------------------------------------------------
def follow_flows(dP_and_cellprob, cp_thresh, gpus=(0,), **kwargs):
    """Run Cellpose flow-following for a single tile.

    Args:
        dP_and_cellprob: Tuple of (flow field ``dP``, cell probability map).
        cp_thresh: Cell probability threshold; pixels above it are followed.
        gpus: GPU indices for round-robin device assignment across worker
            processes, or None to run without setting a CUDA device.

    Returns:
        The result of ``dynamics.follow_flows``, or ``(None, None)`` when
        no pixel exceeds ``cp_thresh``.
    """
    dP, cellprob = dP_and_cellprob
    if gpus is not None:
        # Round-robin GPU assignment keyed on the worker process identity
        # (the main process, with empty identity, maps to worker 0).
        identity = mp.current_process()._identity
        worker = identity[0] if len(identity) else 0
        kwargs['device'] = torch.device(f'cuda:{gpus[worker % len(gpus)]}')
    above = cellprob > cp_thresh
    if not np.any(above):
        # Nothing to segment in this tile.
        return (None, None)
    return dynamics.follow_flows(
        dP * above / 5.,
        use_gpu=(gpus is not None),
        **kwargs
    )
def remove_bad_flow(mask_and_dP, flow_threshold, gpus=(0,), **kwargs):
    """Discard masks whose flow error exceeds ``flow_threshold``.

    Args:
        mask_and_dP: Tuple of (label mask, flow field ``dP``).
        flow_threshold: Maximum allowed flow error; None or 0 disables
            filtering.
        gpus: GPU indices for round-robin device assignment across worker
            processes, or None to run without setting a CUDA device.

    Returns:
        The (possibly filtered) label mask.
    """
    mask, dP = mask_and_dP
    if gpus is not None:
        # Round-robin GPU assignment keyed on the worker process identity
        # (the main process, with empty identity, maps to worker 0).
        identity = mp.current_process()._identity
        worker = identity[0] if len(identity) else 0
        kwargs['device'] = torch.device(f'cuda:{gpus[worker % len(gpus)]}')
    filtering_enabled = flow_threshold is not None and flow_threshold > 0
    if filtering_enabled and mask.max() > 0:
        mask = dynamics.remove_bad_flow_masks(
            mask,
            dP,
            threshold=flow_threshold,
            use_gpu=(gpus is not None),
            **kwargs
        )
    return mask
def resize_and_clean_mask(mask, target_size=None):
    """Optionally resize a label mask, then fill holes and drop small masks.

    Args:
        mask: Label mask (0 = background, positive ints = cells).
        target_size: If given, resize the mask to (target_size, target_size)
            with nearest-neighbor interpolation.

    Returns:
        The cleaned mask (uint16, or uint32 when labels exceed uint16 range).
    """
    # Labels at/above the uint16 limit cannot be narrowed without clobbering.
    needs_wide_dtype = mask.max() >= 2**16 - 1
    if target_size:
        # Nearest-neighbor resize preserves label values; use float32 only
        # when labels would overflow uint16.
        interp_dtype = np.float32 if needs_wide_dtype else np.uint16
        mask = cv2.resize(
            mask.astype(interp_dtype),
            (target_size, target_size),
            interpolation=cv2.INTER_NEAREST
        ).astype(np.uint32)
    elif not needs_wide_dtype:
        mask = mask.astype(np.uint16)
    mask = dynamics.utils.fill_holes_and_remove_small_masks(mask, min_size=15)
    if mask.dtype == np.uint32 and mask.max() == 65535:
        log.warn('more than 65535 masks in image, masks returned as np.uint32')
    return mask
def get_empty_mask(shape):
    """Return an all-zero label mask and flow-position array for ``shape``.

    The position array has one leading axis per dimension of ``shape``,
    matching the layout produced by flow-following.
    """
    empty_mask = np.zeros(shape, np.uint16)
    positions = np.zeros((len(shape),) + tuple(shape), np.uint16)
    return empty_mask, positions
def normalize_img(X):
    """Rescale ``X`` so its 1st/99th percentiles map to 0/1.

    Values outside that range fall outside [0, 1]; no clipping is applied.
    NOTE(review): a constant image yields a zero denominator (NaN/inf
    output) — confirm inputs always have dynamic range.
    """
    X = X.float()
    hi = torch.quantile(X, 0.99)
    lo = torch.quantile(X, 0.01)
    return (X - lo) / (hi - lo)
def process_image(img, nchan):
    """Convert a raw tile image into Cellpose's expected channel format.

    Uses grayscale channel selection ([[0, 0]]) with normalization and
    inversion disabled; normalization is performed later on the GPU batch.
    """
    convert_kwargs = dict(
        channels=[[0, 0]],
        channel_axis=None,
        z_axis=None,
        do_3D=False,
        normalize=False,
        invert=False,
        nchan=nchan,
    )
    return transforms.convert_image(img, **convert_kwargs)
def process_batch(img_batch):
    """Normalize a batch of tiles and reorder axes from NHWC to NCHW."""
    # Height and width must each be divisible by 16 for the Cellpose network.
    assert img_batch.shape[1] % 16 == 0 and img_batch.shape[2] % 16 == 0
    normalized = normalize_img(img_batch)
    return normalized.permute(0, 3, 1, 2)
def get_masks(args, cp_thresh):
    """Compute label masks from followed flow positions for a single tile.

    ``args`` is a tuple of (positions ``p``, indices ``inds``, cell
    probability map). Returns (mask, p).
    """
    p, inds, cellprob = args
    if inds is None:
        # follow_flows found nothing above threshold; emit an empty mask.
        return get_empty_mask(cellprob.shape)
    mask = dynamics.get_masks(p, iscell=(cellprob > cp_thresh))
    return mask, p
def tile_processor(slide, q, batch_size, nchan):
    """Stream batches of processed tiles (with grid locations) into ``q``.

    Each queue item is a tuple of (image batch array, list of (x, y) grid
    locations). A ``None`` sentinel is enqueued when the slide is exhausted.
    """
    tile_iter = slide.torch(
        incl_loc='grid',
        num_threads=4,
        to_tensor=False,
        grayspace_fraction=1,
        lazy_iter=True
    )
    for batch in batch_generator(tile_iter, batch_size):
        images = np.array([process_image(t['image_raw'], nchan) for t in batch])
        locations = [(t['loc_x'], t['loc_y']) for t in batch]
        q.put((images, locations))
    q.put(None)  # Signal end-of-slide to the consumer.
def segment_slide(
    slide: Union[sf.WSI, str],
    model: Union["cellpose.models.Cellpose", str] = 'cyto2',
    *,
    diam_um: Optional[float] = None,
    diam_mean: Optional[int] = None,
    window_size: Optional[int] = None,
    downscale: Optional[float] = None,
    batch_size: int = 8,
    gpus: Optional[Union[int, Iterable[int]]] = (0,),
    spawn_workers: bool = True,
    pb: Optional["Progress"] = None,
    pb_tasks: Optional[List["TaskID"]] = None,
    show_progress: bool = True,
    save_flow: bool = True,
    cp_thresh: float = 0.0,
    flow_threshold: float = 0.4,
    interp: bool = True,
    tile: bool = True,
    verbose: bool = True,
    device: Optional[str] = None,
) -> Segmentation:
    """Segment cells in a whole-slide image, returning masks and centroids.

    Args:
        slide (str, :class:`slideflow.WSI`): Whole-slide image. May be a path
            (str) or WSI object (`slideflow.WSI`).

    Keyword arguments:
        model (str, :class:`cellpose.models.Cellpose`): Cellpose model to use
            for cell segmentation. May be any valid cellpose model. Defaults
            to 'cyto2'. NOTE(review): currently only consulted for the
            default ``diam_mean``; the default Cellpose network is always
            loaded (see TODO below).
        diam_um (float, optional): Cell diameter to detect, in microns.
            Determines tile extraction microns-per-pixel resolution to match
            the given pixel diameter specified by `diam_mean`. Not used if
            `slide` is a `sf.WSI` object.
        diam_mean (int, optional): Cell diameter to detect, in pixels (without
            image resizing). If None, uses Cellpose defaults (17 for the
            'nuclei' model, 30 for all others).
        window_size (int): Window size, in pixels, at which to segment cells.
            Not used if slide is a `sf.WSI` object.
        downscale (float): Factor by which to downscale generated masks after
            calculation. Defaults to None (keep masks at original size).
        batch_size (int): Batch size for cell segmentation. Defaults to 8.
        gpus (int, list(int)): GPUs to use for cell segmentation.
            Defaults to 0 (first GPU).
        spawn_workers (bool): Enable spawn-based multiprocessing. Increases
            cell segmentation speed at the cost of higher memory utilization.
        pb (:class:`rich.progress.Progress`, optional): Progress bar instance.
            Used for external progress bar tracking. Defaults to None.
        pb_tasks (list(:class:`rich.progress.TaskID`)): Progress bar tasks.
            Used for external progress bar tracking. Defaults to None.
        show_progress (bool): Show a tqdm progress bar. Defaults to True.
        save_flow (bool): Save flow values for the whole-slide image.
            Increases memory utilization. Defaults to True.
        cp_thresh (float): Cell probability threshold. All pixels with value
            above threshold kept for masks, decrease to find more and larger
            masks. Defaults to 0.
        flow_threshold (float): Flow error threshold (all cells with errors
            below threshold are kept). Defaults to 0.4.
        interp (bool): Interpolate during 2D dynamics. Defaults to True.
        tile (bool): Tiles image to decrease GPU/CPU memory usage.
            Defaults to True.
        verbose (bool): Verbose log output at the INFO level. Defaults to True.

    Returns:
        :class:`slideflow.cellseg.Segmentation`
    """
    # Quiet the logger to suppress warnings of empty masks.
    logging.getLogger('cellpose').setLevel(40)
    if diam_mean is None:
        # Cellpose defaults: 17 px for the 'nuclei' model, 30 px otherwise.
        diam_mean = 30 if model != 'nuclei' else 17

    # Initial validation checks. ----------------------------------------------
    if isinstance(slide, str):
        assert diam_um is not None, "Must supply diam_um if slide is a path to a slide"
        assert window_size is not None, "Must supply window_size if slide is a path to a slide"
        # Choose tile_um so that cells of diam_um appear ~diam_mean px wide.
        tile_um = int(window_size * (diam_um / diam_mean))
        slide = sf.WSI(slide, tile_px=window_size, tile_um=tile_um, verbose=False)
    elif window_size is not None or diam_um is not None:
        raise ValueError("Invalid argument: cannot provide window_size or diam_um "
                         "when slide is a sf.WSI object")
    else:
        window_size = slide.tile_px
        diam_um = diam_mean * (slide.tile_um / slide.tile_px)
    if window_size % 16:
        raise ValueError("Window size (tile_px) must be a multiple of 16.")
    if downscale is None:
        target_size = window_size
    else:
        target_size = int(window_size / downscale)
    if slide.stride_div != 1:
        log.warning("Whole-slide cell segmentation not configured for strides "
                    f"other than 1 (got: {slide.stride_div}).")

    # Set up model and parameters. --------------------------------------------
    start_time = time.time()
    device = torch_utils.get_device(device)
    if device.type == 'cpu':
        # Run from CPU if CUDA is not available.
        model = Cellpose(gpu=False, device=device)
        gpus = None
        log.info("No GPU detected - running from CPU")
    else:
        model = Cellpose(gpu=True, device=device)
    cp = model.cp
    cp.batch_size = batch_size
    cp.net.load_model(cp.pretrained_model[0], cpu=(not cp.gpu))  # TODO: modify to accept different models
    cp.net.eval()
    rescale = 1  # No rescaling, as we are manually setting diameter = diam_mean
    mask_dim = (slide.stride * (slide.shape[0]-1) + slide.tile_px,
                slide.stride * (slide.shape[1]-1) + slide.tile_px)
    # Full-slide mask buffer; each tile writes a target_size x target_size
    # region, indexed by its grid location.
    all_masks = np.zeros((slide.shape[1] * target_size,
                          slide.shape[0] * target_size),
                         dtype=np.uint32)
    if save_flow:
        all_flows = np.zeros((slide.shape[1] * target_size,
                              slide.shape[0] * target_size, 3),
                             dtype=np.uint8)
    log_fn = log.info if verbose else log.debug
    log_fn("=== Segmentation parameters ===")
    log_fn(f"Diameter (px): {diam_mean}")
    log_fn(f"Diameter (um): {diam_um}")
    log_fn(f"Window size: {window_size}")
    log_fn(f"Target size: {target_size}")
    log_fn(f"Perform tiled: {tile}")
    log_fn(f"Slide dimensions: {slide.dimensions}")
    log_fn(f"Slide shape: {slide.shape}")
    log_fn(f"Slide stride (px): {slide.stride}")
    log_fn(f"Est. tiles: {slide.estimated_num_tiles}")
    log_fn(f"Save flow: {save_flow}")
    log_fn(f"Mask dimensions: {mask_dim}")
    log_fn(f"Mask size: {all_masks.shape}")
    log_fn("===============================")

    # Processes and pools. ----------------------------------------------------
    tile_q = mp.Queue(4)
    y_q = Queue(2)
    ctx = mp.get_context('spawn')
    fork_pool = mp.Pool(
        batch_size,
        initializer=sf.util.set_ignore_sigint
    )
    if spawn_workers:
        spawn_pool = ctx.Pool(
            4,
            initializer=sf.util.set_ignore_sigint
        )
    else:
        # Thread-based stand-in with the same Pool interface
        # (requires `import multiprocessing.dummy` at module level).
        spawn_pool = mp.dummy.Pool(4)
    # Use a thread (not a process) for tile extraction when the backend is
    # libvips — NOTE(review): presumably because the libvips reader is not
    # safe to use across processes; confirm.
    proc_fn = mp.Process if sf.slide_backend() != 'libvips' else threading.Thread
    tile_process = proc_fn(
        target=tile_processor,
        args=(slide, tile_q, batch_size, cp.nchan)
    )
    tile_process.start()

    def net_runner():
        """Consume tile batches, run the Cellpose network, enqueue outputs."""
        while True:
            item = tile_q.get()
            if item is None:
                # Propagate the end-of-slide sentinel to the main loop.
                y_q.put(None)
                break
            imgs, c = item
            torch_batch = cp._to_device(imgs)
            torch_batch = process_batch(torch_batch)
            if tile:
                y, style = cp._run_tiled(
                    torch_batch.cpu().numpy(),
                    augment=False,
                    bsize=224,
                    return_conv=False
                )
            else:
                y, style = cp.network(torch_batch)
            y_q.put((y, style, c))

    runner = threading.Thread(target=net_runner)
    runner.start()

    # Main loop. --------------------------------------------------------------
    running_max = 0  # Highest mask label assigned so far (for re-labeling).
    if show_progress:
        tqdm_pb = tqdm(total=slide.estimated_num_tiles)
    while True:
        item = y_q.get()
        if item is None:
            break
        y, style, c = item
        # Split network output into cell probability and flow components.
        y = np.transpose(y, (0, 2, 3, 1))
        cellprob = y[:, :, :, 2].astype(np.float32)
        dP = y[:, :, :, :2].transpose((3, 0, 1, 2))
        del y, style
        # Calculate flows.
        batch_p, batch_ind = zip(*spawn_pool.map(
            partial(follow_flows,
                    niter=(1 / rescale * 200),
                    interp=interp,
                    cp_thresh=cp_thresh,
                    gpus=gpus),
            zip([dP[:, i] for i in range(len(c))], cellprob)
        ))
        # Calculate masks.
        batch_masks, batch_p = zip(*fork_pool.map(
            partial(get_masks, cp_thresh=cp_thresh),
            zip(batch_p, batch_ind, cellprob)))
        # Remove masks with bad flow.
        batch_masks = spawn_pool.map(
            partial(remove_bad_flow, flow_threshold=flow_threshold, gpus=gpus),
            zip(batch_masks, [dP[:, i] for i in range(len(c))]))
        # Resize masks and clean (remove small masks/holes).
        batch_masks = fork_pool.map(
            partial(resize_and_clean_mask,
                    target_size=(None if target_size == window_size
                                 else target_size)),
            batch_masks)
        dP = dP.squeeze()
        cellprob = cellprob.squeeze()
        # Stitch each tile's masks into the whole-slide buffer, shifting
        # labels by running_max so they stay globally unique.
        for i in range(len(c)):
            x, y = c[i][0], c[i][1]
            img_masks = batch_masks[i].astype(np.uint32)
            max_in_mask = img_masks.max()
            img_masks[np.nonzero(img_masks)] += running_max
            running_max += max_in_mask
            all_masks[y * target_size: (y+1)*target_size,
                      x * target_size: (x+1)*target_size] = img_masks
            if save_flow:
                flow_plot = plot.dx_to_circ(dP[:, i])
                if target_size != window_size:
                    flow_plot = cv2.resize(flow_plot, (target_size, target_size))
                all_flows[y * target_size: (y+1)*target_size,
                          x * target_size: (x+1)*target_size, :] = flow_plot
        # Final cleanup.
        del dP, cellprob
        # Update progress bars.
        if show_progress:
            tqdm_pb.update(batch_size)
        if pb is not None and pb_tasks:
            for task in pb_tasks:
                pb.advance(task, batch_size)

    # Close pools/processes and log time.
    if show_progress:
        tqdm_pb.close()
    spawn_pool.close()
    spawn_pool.join()
    fork_pool.close()
    fork_pool.join()
    runner.join()
    tile_process.join()
    ttime = time.time() - start_time
    log.info(f"Segmented {running_max} cells for {slide.name} ({ttime:.0f} s)")

    # Calculate WSI dimensions and return final segmentation.
    wsi_dim = (slide.shape[0] * slide.full_extract_px,
               slide.shape[1] * slide.full_extract_px)
    wsi_offset = (0, 0)
    return Segmentation(
        slide=slide,
        masks=all_masks,
        flows=None if not save_flow else all_flows,
        wsi_dim=wsi_dim,
        wsi_offset=wsi_offset)