from __future__ import absolute_import, division, print_function
import csv
import os
import sys
import time
import warnings
from functools import partial
from multiprocessing.dummy import Pool as DPool
from os.path import join
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
from rich.progress import track
import slideflow as sf
from slideflow import errors
from slideflow.stats import SlideMap, get_centroid_index
from slideflow.util import log
from slideflow.stats import get_centroid_index
if TYPE_CHECKING:
from slideflow.norm import StainNormalizer
# -----------------------------------------------------------------------------
def process_tile_image(args, decode_kwargs):
if args is None:
return None, None, None, None
point_index, x, y, display_size, alpha, image = args
if not point_index:
return None, None, None, None
if isinstance(image, tuple):
tfr, tfr_idx = image
image = sf.io.get_tfrecord_by_index(tfr, tfr_idx)['image_raw']
if image is None:
return point_index, None, None, None
if sf.model.is_tensorflow_tensor(image):
image = image.numpy()
image = decode_image(image, **decode_kwargs)
extent = [
x - display_size/2,
x + display_size/2,
y - display_size/2,
y + display_size/2
]
return point_index, image, extent, alpha
def decode_image(
image: Union[str, np.ndarray],
normalizer: Optional["StainNormalizer"],
img_format: str
) -> np.ndarray:
"""Internal method to convert an image string (as stored in TFRecords)
to an RGB array."""
if normalizer:
try:
if isinstance(image, np.ndarray):
return normalizer.rgb_to_rgb(image)
elif img_format in ('jpg', 'jpeg'):
return normalizer.jpeg_to_rgb(image)
elif img_format == 'png':
return normalizer.png_to_rgb(image)
else:
return normalizer.transform(image)
except Exception as e:
log.error("Error encountered during image normalization, "
f"displaying image tile non-normalized. {e}")
if isinstance(image, np.ndarray):
return image
else:
image_arr = np.fromstring(image, np.uint8)
tile_image_bgr = cv2.imdecode(image_arr, cv2.IMREAD_COLOR)
return cv2.cvtColor(tile_image_bgr, cv2.COLOR_BGR2RGB)
def find_corresponding_points(row, points):
return points.loc[((points.grid_x == row.x) & (points.grid_y == row.y))].index
# -----------------------------------------------------------------------------
[docs]class Mosaic:
"""Visualization of plotted image tiles."""
def __init__(
self,
images: Union[SlideMap, List[np.ndarray], np.ndarray, List[Tuple[str, int]]],
coords: Optional[Union[Tuple[int, int], np.ndarray]] = None,
*,
tfrecords: List[str] = None,
normalizer: Optional[Union[str, "StainNormalizer"]] = None,
normalizer_source: Optional[str] = None,
**grid_kwargs
) -> None:
"""Generate a mosaic map, which visualizes plotted image tiles.
Creating a mosaic map requires two components: a set of images and
corresponding coordinates. Images and coordinates can either be manually
provided, or the mosaic can dynamically read images from TFRecords as
needed, reducing memory requirements.
The first argument provides the images, and may be any of the following:
- A list or array of images (np.ndarray, HxWxC)
- A list of tuples, containing ``(slide_name, tfrecord_index)``
- A ``slideflow.SlideMap`` object
The second argument provides the coordinates, and may be any of:
- A list or array of (x, y) coordinates for each image
- None (if the first argument is a ``SlideMap``, which has coordinates)
If images are to be read dynamically from tfrecords (with a ``SlideMap``,
or by providing tfrecord indices directly), the keyword argument
``tfrecords`` must be specified with paths to tfrecords.
Published examples:
- Figure 4: https://doi.org/10.1038/s41379-020-00724-3
- Figure 6: https://doi.org/10.1038/s41467-022-34025-x
Examples
Generate a mosaic map from a list of images and coordinates.
.. code-block:: python
# Example data (images are HxWxC, np.ndarray)
images = [np.ndarray(...), ...]
coords = [(0.2, 0.9), ...]
# Generate the mosaic
mosaic = Mosaic(images, coordinates)
Generate a mosaic map from tuples of TFRecord paths and indices.
.. code-block:: python
# Example data
paths = ['/path/to/tfrecord.tfrecords', ...]
idx = [253, 112, ...]
coords = [(0.2, 0.9), ...]
tuples = [(tfr, idx) for tfr, i in zip(paths, idx)]
# Generate mosaic map
mosaic = sf.Mosaic(tuples, coords)
Generate a mosaic map from a SlideMap and list of TFRecord paths.
.. code-block:: python
# Prepare a SlideMap from a project
P = sf.Project('/project/path')
ftrs = P.generate_features('/path/to/model')
slide_map = sf.SlideMap.from_features(ftrs)
# Generate mosaic
mosaic = Mosaic(slide_map, tfrecords=ftrs.tfrecords)
Args:
images (list(np.ndarray), tuple, :class:`slideflow.SlideMap`):
Images from which to generate the mosaic. May be a list or
array of images (np.ndarray, HxWxC), a list of tuples,
containing ``(slide_name, tfrecord_index)``, or a
``slideflow.SlideMap`` object.
coords (list(str)): Coordinates for images. May be a list or array
of (x, y) coordinates for each image (of same length
as ``images``), or None (if ``images`` is a ``SlideMap`` object).
Keyword args:
tfrecords (list(str), optional): TFRecord paths. Required if
``images`` is either a ``SlideMap`` object or a list of tuples
containing ``(slide_name, tfrecord_index)``. Defaults to None.
num_tiles_x (int, optional): Mosaic map grid size. Defaults to 50.
tile_select (str, optional): 'first', 'nearest', or 'centroid'.
Determines how to choose a tile for display on each grid space.
If 'first', will display the first valid tile in a grid space
(fastest; recommended). If 'nearest', will display tile nearest
to center of grid space. If 'centroid', for each grid, will
calculate which tile is nearest to centroid tile_meta.
Defaults to 'nearest'.
tile_meta (dict, optional): Tile metadata, used for tile_select.
Dictionary should have slide names as keys, mapped to list of
metadata (length of list = number of tiles in slide).
Defaults to None.
normalizer ((str or :class:`slideflow.norm.StainNormalizer`), optional):
Normalization strategy to use on image tiles. Defaults to None.
normalizer_source (str, optional): Stain normalization preset or
path to a source image. Valid presets include 'v1', 'v2', and
'v3'. If None, will use the default present ('v3').
Defaults to None.
"""
self.tile_point_distances = [] # type: List[Dict]
self.slide_map = None
self.tfrecords = tfrecords
self.grid_images = {}
self.grid_coords = [] # type: np.ndarray
self.grid_idx = [] # type: np.ndarray
if isinstance(images, SlideMap):
if tfrecords is None:
raise ValueError("If building a Mosaic from a SlideMap, must "
"provide paths to tfrecords via keyword arg "
"tfrecords=...")
elif isinstance(tfrecords, list) and not len(tfrecords):
raise errors.TFRecordsNotFoundError()
self._prepare_from_slidemap(images)
elif isinstance(images[0], (tuple, list)) and isinstance(images[0][0], str):
self._prepare_from_tuples(images, coords) # type: ignore
else:
assert coords is not None
assert len(images) == len(coords)
self._prepare_from_coords(images, coords) # type: ignore
# ---------------------------------------------------------------------
# Detect tfrecord image format
if self.tfrecords is not None:
_, self.img_format = sf.io.detect_tfrecord_format(self.tfrecords[0])
else:
self.img_format = 'numpy'
# Setup normalization
if isinstance(normalizer, str):
log.info(f'Using realtime {normalizer} normalization')
self.normalizer = sf.norm.autoselect(
method=normalizer,
source=normalizer_source
) # type: Optional[StainNormalizer]
elif normalizer is not None:
self.normalizer = normalizer
else:
self.normalizer = None
self.generate_grid(**grid_kwargs)
def _prepare_from_coords(
self,
images: Union[List[np.ndarray], np.ndarray],
coords: List[Union[Tuple[int, int], np.ndarray]]
) -> None:
"""Prepare the Mosaic map from a set of images and coordinates."""
log.info('Loading coordinates and plotting points...')
self.images = images
self.mapped_tiles = [] # type: List[int]
self.points = [{
'coord': coords[i],
'global_index': i,
'category': 'none',
'has_paired_tile': False,
} for i in range(len(coords))]
def _prepare_from_slidemap(
self,
slide_map: SlideMap,
*,
tile_meta: Optional[Dict] = None,
) -> None:
"""Prepare the Mosaic map from a ``SlideMap`` object."""
log.info('Loading coordinates from SlideMap and plotting points...')
self.slide_map = slide_map
self.mapped_tiles = {} # type: Dict[str, List[int]]
self.points = slide_map.data.copy()
self.points['has_paired_tile'] = False
self.points['points_index'] = self.points.index
self.points['alpha'] = 1.
if tile_meta:
self.points['meta'] = self.points.apply(lambda row: tile_meta[row.slide][row.tfr_index], axis=1)
log.debug("Loading complete.")
def _prepare_from_tuples(
self,
images: List[Tuple[str, int]],
coords: List[Union[Tuple[int, int], np.ndarray]],
) -> None:
"""Prepare from a list of tuples with TFRecord names/indices."""
log.info('Loading coordinates from SlideMap and plotting points...')
self.mapped_tiles = {} # type: Dict[str, List[int]]
self.points = []
for i, (tfr, idx) in enumerate(images):
self.points.append({
'coord': np.array(coords[i]),
'global_index': i,
'category': 'none',
'slide': (tfr if self.tfrecords is not None
else sf.util.path_to_name(tfr)),
'tfrecord': (tfr if self.tfrecords is None
else self._get_tfrecords_from_slide(tfr)),
'tfrecord_index': idx,
'has_paired_tile': None,
})
def _get_image_from_point(self, index):
point = self.points.loc[index]
if 'tfr_index' in point:
tfr = self._get_tfrecords_from_slide(point.slide)
tfr_idx = point.tfr_index
if not tfr:
log.error(f"TFRecord {tfr} not found in slide_map")
return None
image = sf.io.get_tfrecord_by_index(tfr, tfr_idx)['image_raw']
else:
image = self.images[index]
return image
def _get_tfrecords_from_slide(self, slide: str) -> Optional[str]:
"""Using the internal list of TFRecord paths, returns the path to a
TFRecord for a given corresponding slide."""
for tfr in self.tfrecords:
if sf.util.path_to_name(tfr) == slide:
return tfr
log.error(f'Unable to find TFRecord path for slide [green]{slide}')
return None
def _initialize_figure(self, figsize, background):
import matplotlib.pyplot as plt
fig = plt.figure(figsize=figsize)
self.ax = fig.add_subplot(111, aspect='equal')
self.ax.set_facecolor(background)
fig.tight_layout()
plt.subplots_adjust(
left=0.02,
bottom=0,
right=0.98,
top=1,
wspace=0.1,
hspace=0
)
self.ax.set_aspect('equal', 'box')
self.ax.set_xticklabels([])
self.ax.set_yticklabels([])
def _plot_tile_image(self, image, extent, alpha=1):
return self.ax.imshow(
image,
aspect='equal',
origin='lower',
extent=extent,
zorder=99,
alpha=alpha
)
def _finalize_figure(self):
self.ax.autoscale(enable=True, tight=None)
def _record_point(self, index):
point = self.points.loc[index]
if 'tfr_index' in point:
tfr = self._get_tfrecords_from_slide(point.slide)
if tfr is None:
return
if tfr in self.mapped_tiles:
self.mapped_tiles[tfr] += [point.tfr_index]
else:
self.mapped_tiles[tfr] = [point.tfr_index]
else:
self.mapped_tiles += [index]
@property
def decode_kwargs(self):
return dict(normalizer=self.normalizer, img_format=self.img_format)
def points_at_grid_index(self, x, y):
return self.points.loc[((self.points.grid_x == x) & (self.points.grid_y == y))]
def selected_points(self):
return self.points.loc[self.points.selected]
def generate_grid(
self,
num_tiles_x: int = 50,
tile_meta: Optional[Dict] = None,
tile_select: str = 'first',
max_dist: Optional[float] = None,
):
"""Generate the mosaic map grid.
Args:
num_tiles_x (int, optional): Mosaic map grid size. Defaults to 50.
tile_meta (dict, optional): Tile metadata, used for tile_select.
Dictionary should have slide names as keys, mapped to list of
metadata (length of list = number of tiles in slide).
Defaults to None.
tile_select (str, optional): 'first', 'nearest', or 'centroid'.
Determines how to choose a tile for display on each grid space.
If 'first', will display the first valid tile in a grid space
(fastest; recommended). If 'nearest', will display tile nearest
to center of grid space. If 'centroid', for each grid, will
calculate which tile is nearest to centroid tile_meta.
Defaults to 'nearest'.
"""
# Initial validation checks
if tile_select not in ('nearest', 'centroid', 'first'):
raise TypeError(f'Unknown tile selection method {tile_select}')
else:
log.debug(f'Tile selection method: {tile_select}')
self.num_tiles_x = num_tiles_x
self.grid_images = {}
# Build the grid
x_points = self.points.x.values
y_points = self.points.y.values
max_x = x_points.max()
min_x = x_points.min()
max_y = y_points.max()
min_y = y_points.min()
log.debug(f'Loaded {len(self.points)} points.')
self.tile_size = (max_x - min_x) / self.num_tiles_x
self.num_tiles_y = int((max_y - min_y) / self.tile_size)
self.grid_idx = np.reshape(np.dstack(np.indices((self.num_tiles_x, self.num_tiles_y))), (self.num_tiles_x * self.num_tiles_y, 2))
_grid_offset = np.array([(self.tile_size/2) + min_x, (self.tile_size/2) + min_y])
self.grid_coords = (self.grid_idx * self.tile_size) + _grid_offset
points_added = 0
x_bins = np.arange(min_x, max_x, ((max_x - min_x) / self.num_tiles_x))[1:]
y_bins = np.arange(min_y, max_y, ((max_y - min_y) / self.num_tiles_y))[1:]
self.points['grid_x'] = np.digitize(self.points.x.values, x_bins, right=False)
self.points['grid_y'] = np.digitize(self.points.y.values, y_bins, right=False)
self.points['selected'] = False
log.debug(f'{points_added} points added to grid')
# Then, calculate distances from each point to each spot on the grid
def select_nearest_points(idx):
grid_x, grid_y = self.grid_idx[idx][0], self.grid_idx[idx][1]
grid_coords = self.grid_coords[idx]
# Calculate distance for each point within the grid tile from
# center of the grid tile
_points = self.points_at_grid_index(grid_x, grid_y)
if not _points.empty:
if tile_select == 'nearest':
point_coords = np.stack([_points.x.values, _points.y.values], axis=-1)
dist = np.linalg.norm(
point_coords - grid_coords,
ord=2,
axis=1.
)
if max_dist is not None:
masked_dist = np.ma.masked_array(dist, (dist >= (max_dist * self.tile_size)))
if masked_dist.count():
self.points.loc[_points.index[np.argmin(masked_dist)], 'selected'] = True
else:
self.points.loc[_points.index[np.argmin(dist)], 'selected'] = True
elif not tile_meta:
raise errors.MosaicError(
'Mosaic centroid option requires tile_meta.'
)
else:
centroid_index = get_centroid_index(_points.meta.values)
self.points.loc[_points.index[centroid_index], 'selected'] = True
start = time.time()
if tile_select == 'first':
grid_group = self.points.groupby(['grid_x', 'grid_y'])
first_indices = grid_group.nth(0).points_index.values
self.points.loc[first_indices, 'selected'] = True
elif tile_select in ('nearest', 'centroid'):
self.points['selected'] = False
dist_fn = partial(select_nearest_points)
pool = DPool(sf.util.num_cpu())
for i, _ in track(enumerate(pool.imap_unordered(dist_fn, range(len(self.grid_idx))), 1), total=len(self.grid_idx)):
pass
pool.close()
pool.join()
else:
raise ValueError(
f'Unrecognized value for tile_select: "{tile_select}"'
)
end = time.time()
if sf.getLoggingLevel() <= 20:
sys.stdout.write('\r\033[K')
log.debug(f'Tile image selection complete ({end - start:.1f} sec)')
def export(self, path: str) -> None:
"""Export SlideMap and configuration for later loading.
Args:
path (str): Directory in which to save configuration.
"""
if self.slide_map is None:
raise ValueError(
"Mosaic.export() requires a Mosaic built from a SlideMap."
)
self.slide_map.save(path)
if isinstance(self.tfrecords, list):
tfr = self.tfrecords
else:
tfr = list(self.tfrecords)
sf.util.write_json(tfr, join(path, 'tfrecords.json'))
log.info(f"Mosaic configuration exported to {path}")
def plot(
self,
figsize: Tuple[int, int] = (200, 200),
focus: Optional[List[str]] = None,
focus_slide: Optional[str] = None,
background: str = '#dfdfdf',
pool: Optional[Any] = None,
) -> None:
"""Initializes figures and places image tiles.
If in a Jupyter notebook, the heatmap will be displayed in the cell
output. If running via script or shell, the heatmap can then be
shown on screen using matplotlib ``plt.show()``:
.. code-block::
import slideflow as sf
import matplotlib.pyplot as plt
heatmap = sf.Heatmap(...)
heatmap.plot()
plt.show()
Args:
figsize (Tuple[int, int], optional): Figure size. Defaults to
(200, 200).
focus (list, optional): List of tfrecords (paths) to highlight
on the mosaic. Defaults to None.
focus_slide (str, optional): Highlight tiles from this slide.
Defaults to None.
"""
if (focus is not None or focus_slide is not None) and self.tfrecords is None:
raise ValueError("Unable to plot with focus; slides/tfrecords not configured.")
log.debug("Initializing figure...")
self._initialize_figure(figsize=figsize, background=background)
# Reset alpha and display size
if focus_slide:
self.points['alpha'] = 1.
self.points['display_size'] = self.tile_size
if focus_slide:
for idx in self.grid_idx:
_points = self.points_at_grid_index(x=idx[0], y=idx[1])
if not _points.empty and focus_slide:
n_matching = len(_points.loc[_points.slide == focus_slide])
self.points.loc[_points.index, 'alpha'] = n_matching / len(_points)
# Then, pair grid tiles and points according to their distances
log.info('Placing image tiles...')
placed = 0
start = time.time()
to_map = []
should_close_pool = False
has_tfr = 'tfr_index' in self.points.columns
selected_points = self.selected_points()
for idx, point in selected_points.iterrows():
if has_tfr:
tfr = self._get_tfrecords_from_slide(point.slide)
tfr_idx = point.tfr_index
if tfr:
image = (tfr, tfr_idx)
else:
log.error(f"TFRecord {tfr} not found in slide_map")
image = None
else:
image = self.images[idx]
to_map.append((idx, point.grid_x * self.tile_size, point.grid_y * self.tile_size, point.display_size, point.alpha, image))
if pool is None:
pool = DPool(sf.util.num_cpu())
should_close_pool = True
for i, (point_idx, image, extent, alpha) in track(enumerate(pool.imap(partial(process_tile_image, decode_kwargs=self.decode_kwargs), to_map)), total=len(selected_points)):
if point_idx is not None:
self._record_point(point_idx)
self._plot_tile_image(image, extent, alpha)
point = self.points.loc[point_idx]
self.grid_images[(point.grid_x, point.grid_y)] = image
placed += 1
if should_close_pool:
pool.close()
pool.join()
log.debug(f'Tile images placed: {placed} ({time.time()-start:.2f}s)')
if focus:
self.focus(focus)
self._finalize_figure()
def save(self, filename: str, **kwargs: Any) -> None:
"""Saves the mosaic map figure to the given filename.
Args:
filename (str): Path at which to save the mosiac image.
Keyword args:
figsize (Tuple[int, int], optional): Figure size. Defaults to
(200, 200).
focus (list, optional): List of tfrecords (paths) to highlight on
the mosaic.
"""
with sf.util.matplotlib_backend('Agg'):
import matplotlib.pyplot as plt
self.plot(**kwargs)
log.info('Exporting figure...')
try:
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
except FileNotFoundError:
pass
plt.savefig(filename, bbox_inches='tight')
log.info(f'Saved figure to [green]{filename}')
plt.close()
def save_report(self, filename: str) -> None:
"""Saves a report of which tiles (and their corresponding slide)
were displayed on the Mosaic map, in CSV format."""
with open(filename, 'w') as f:
writer = csv.writer(f)
writer.writerow(['slide', 'index'])
if isinstance(self.mapped_tiles, dict):
for tfr in self.mapped_tiles:
for idx in self.mapped_tiles[tfr]:
writer.writerow([tfr, idx])
else:
for idx in self.mapped_tiles:
writer.writerow([idx])
log.info(f'Mosaic report saved to [green]{filename}')
def view(self, slides: List[str] = None) -> None:
"""Open Mosaic in Slideflow Studio.
See :ref:`studio` for more information.
Args:
slides (list(str), optional): Path to whole-slide images. Used for
displaying image tile context when hovering over a mosaic grid.
Defaults to None.
"""
from slideflow.studio.widgets import MosaicWidget
from slideflow.studio import Studio
studio = Studio(widgets=[MosaicWidget])
mosaic = studio.get_widget('MosaicWidget')
mosaic.load(
self.slide_map,
tfrecords=self.tfrecords,
slides=slides,
normalizer=self.normalizer
)
studio.run()