import csv
import os
import pickle
import queue
import sys
import threading
import time
import warnings
import multiprocessing as mp
from collections import defaultdict
from math import isnan
from os.path import exists, join
from typing import (
TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Iterable, Callable
)
import numpy as np
import pandas as pd
import scipy.stats as stats
import slideflow as sf
from rich.progress import track, Progress
from slideflow import errors
from slideflow.util import log, Labels, ImgBatchSpeedColumn, tfrecord2idx
from .base import BaseFeatureExtractor
if TYPE_CHECKING:
import tensorflow as tf
import torch
# -----------------------------------------------------------------------------
[docs]class DatasetFeatures:
"""Loads annotations, saved layer activations / features, and prepares
output saving directories. Will also read/write processed features to a
PKL cache file to save time in future iterations.
Note:
Storing predictions along with layer features is optional, to offer the user
reduced memory footprint. For example, saving predictions for a 10,000 slide
dataset with 1000 categorical outcomes would require:
4 bytes/float32-logit
* 1000 predictions/slide
* 3000 tiles/slide
* 10000 slides
~= 112 GB
"""
def __init__(
    self,
    model: Union[str, "tf.keras.models.Model", "torch.nn.Module"],
    dataset: "sf.Dataset",
    *,
    labels: Optional[Labels] = None,
    cache: Optional[str] = None,
    annotations: Optional[Labels] = None,
    **kwargs: Any
) -> None:
    """Calculate features / layer activations for a dataset.

    Results are stored in ``self.activations``, ``self.predictions``,
    ``self.uncertainty`` and ``self.locations``: dictionaries mapping slide
    names to arrays of activations, predictions, uncertainty, and tile
    locations for each slide's constituent tiles.

    If ``cache`` points to an existing file, features are loaded from it;
    otherwise, if ``model`` is not None, features are generated (and
    written to ``cache`` when a path is given).

    Args:
        model (str or model): Path to a model, or a loaded model / feature
            extractor, from which to calculate activations. May be None
            only when building an empty object (as ``from_df()`` does).
        dataset (:class:`slideflow.Dataset`): Dataset from which to
            generate activations. May be None only when building an
            empty object (as ``from_df()`` does).
        labels (dict, optional): Dict mapping slide names to outcome
            categories.
        cache (str, optional): File for PKL cache.
        annotations (dict, optional): Deprecated alias of ``labels``.
            Cannot be combined with ``labels``.

    Keyword Args:
        augment (bool, str, optional): Whether to use data augmentation
            during feature extraction. If True, will use default
            augmentation. If str, will use augmentation specified by the
            string. Defaults to None.
        batch_size (int): Batch size for activations calculations.
            Defaults to 32.
        device (str, optional): Device to use for feature extraction.
            Only used for PyTorch feature extractors. Defaults to None.
        include_preds (bool): Calculate and store predictions.
            Defaults to True.
        include_uncertainty (bool, optional): Whether to include model
            uncertainty in the output. Only used if the feature generator
            is a UQ-enabled model. Defaults to True.
        layers (str, list(str)): Layers to extract features from. May be
            the name of a single layer (str) or a list of layers (list).
            Only used if model is a str. Defaults to 'postconv'.
        normalizer ((str or :class:`slideflow.norm.StainNormalizer`), optional):
            Stain normalization strategy to use on image tiles prior to
            feature extraction. This argument is invalid if ``model`` is a
            feature extractor built from a trained model, as stain
            normalization will be specified by the model configuration.
            Defaults to None.
        normalizer_source (str, optional): Stain normalization preset
            or path to a source image. Valid presets include 'v1', 'v2',
            and 'v3'. If None, will use the default preset ('v3').
            This argument is invalid if ``model`` is a feature extractor
            built from a trained model. Defaults to None.
        num_workers (int, optional): Number of workers to use for feature
            extraction. Only used for PyTorch feature extractors. Defaults
            to None.
        pool_sort (bool): Use multiprocessing pools to perform final
            sorting. Defaults to True.
        progress (bool): Show a progress bar during feature calculation.
            Defaults to True.
        transform (Callable, optional): Custom transform to apply to
            images. Applied before standardization. If the feature
            extractor is a PyTorch model, the transform should be a
            torchvision transform.
        verbose (bool): Show verbose logging output. Defaults to True.

    Raises:
        DeprecationWarning: If both ``labels`` and ``annotations`` are
            supplied.
        KeyError: If labels are supplied but a slide in the dataset has
            no label.

    Examples
        Calculate features using a feature extractor.

        .. code-block:: python

            import slideflow as sf

            # Create a feature extractor
            retccl = sf.build_feature_extractor('retccl', resize=True)

            # Load a dataset
            P = sf.load_project(...)
            dataset = P.dataset(...)

            # Calculate features
            dts_ftrs = sf.DatasetFeatures(retccl, dataset)

        Calculate features using a trained model (preferred).

        .. code-block:: python

            import slideflow as sf

            # Create a feature extractor from the saved model.
            extractor = sf.build_feature_extractor(
                '/path/to/trained_model.zip',
                layers=['postconv']
            )

            # Calculate features across the dataset
            dts_ftrs = sf.DatasetFeatures(extractor, dataset)

        Calculate features using a trained model (legacy).

        .. code-block:: python

            # This method is deprecated, and will be removed in a
            # future release. Please use the method above instead.
            dts_ftrs = sf.DatasetFeatures(
                '/path/to/trained_model.zip',
                dataset=dataset,
                layers=['postconv']
            )

        Calculate features from a loaded model.

        .. code-block:: python

            import tensorflow as tf
            import slideflow as sf

            # Load a model
            model = tf.keras.models.load_model('/path/to/model.h5')

            # Calculate features
            dts_ftrs = sf.DatasetFeatures(
                model,
                dataset,
                layers=['postconv']
            )

    """
    self.activations = defaultdict(list)  # type: Dict[str, Any]
    self.predictions = defaultdict(list)  # type: Dict[str, Any]
    self.uncertainty = defaultdict(list)  # type: Dict[str, Any]
    self.locations = defaultdict(list)  # type: Dict[str, Any]
    self.num_features = 0
    self.num_classes = 0
    self.model = model
    self.dataset = dataset
    self.feature_generator = None
    if dataset is not None:
        self.tile_px = dataset.tile_px
        self.manifest = dataset.manifest()
        self.tfrecords = np.array(dataset.tfrecords())
    else:
        # Used when creating via DatasetFeatures.from_df(),
        # otherwise dataset should not be None.
        self.tile_px = None
        self.manifest = dict()
        self.tfrecords = []
    self.slides = sorted([sf.util.path_to_name(t) for t in self.tfrecords])

    # Resolve the (deprecated) "annotations" alias of "labels".
    if labels is not None and annotations is not None:
        raise DeprecationWarning(
            'Cannot supply both "labels" and "annotations" to sf.DatasetFeatures. '
            '"annotations" is deprecated and has been replaced with "labels".'
        )
    elif annotations is not None:
        warnings.warn(
            'The "annotations" argument to sf.DatasetFeatures is deprecated. '
            'Please use the argument "labels" instead.',
            DeprecationWarning
        )
        self.labels = annotations
    else:
        self.labels = labels

    # All categories present in the labels. The categories actually used
    # by the loaded slides are computed below, once features are loaded.
    self.used_categories = []  # type: List[Union[str, int, List[float]]]
    if self.labels:
        self.categories = list(set(self.labels.values()))
    else:
        self.categories = []

    # Load from PKL (cache) if present
    if cache and exists(cache):
        self.load_cache(cache)

    # Otherwise will need to generate new activations from a given model
    elif model is not None:
        self._generate_features(cache=cache, **kwargs)

    # Now delete slides not included in our filtered TFRecord list
    for loaded_slide in list(self.activations.keys()):
        if loaded_slide not in self.slides:
            log.debug(
                f'Removing activations from slide {loaded_slide} '
                'slide not in the filtered tfrecords list'
            )
            self.remove_slide(loaded_slide)

    # Now screen for missing slides in activations
    missing = [
        slide for slide in self.slides
        if slide not in self.activations or not len(self.activations[slide])
    ]
    num_loaded = len(self.slides) - len(missing)
    log.debug(
        f'Loaded activations from {num_loaded}/{len(self.slides)} '
        f'slides ({len(missing)} missing)'
    )
    if missing:
        log.warning(f'Activations missing for {len(missing)} slides')

    # Record which categories have been included in the specified tfrecords
    if self.categories and self.labels:
        used = set()
        for slide in self.slides:
            if slide not in self.labels:
                raise KeyError(f"Slide {slide} not in labels.")
            used.add(self.labels[slide])
        self.used_categories = sorted(used)
        total = len(self.used_categories)
        cat_list = ", ".join([str(c) for c in self.used_categories])
        log.debug(f'Observed categories (total: {total}): {cat_list}')

    # Show total number of features. If neither the cache nor the feature
    # generator reported a feature count, infer it from the first slide
    # that has activations.
    if not self.num_features:
        for slide in self.slides:
            act = self.activations.get(slide)
            if act is not None and len(act):
                self.num_features = np.shape(act)[-1]
                break
    log.debug(f'Number of activation features: {self.num_features}')
@classmethod
def from_df(cls, df: "pd.core.frame.DataFrame") -> "DatasetFeatures":
    """Rebuild a DatasetFeatures object from an exported DataFrame.

    Args:
        df (:class:`pandas.DataFrame`): DataFrame of features, as exported
            by :meth:`DatasetFeatures.to_df()`. Must contain a ``slide``
            column; the ``activations``, ``locations``, ``uncertainty``
            and ``predictions`` columns are each loaded if present.

    Returns:
        :class:`DatasetFeatures`: DatasetFeatures object

    Examples
        Recreate DatasetFeatures after export to a DataFrame.

            >>> df = features.to_df()
            >>> new_features = DatasetFeatures.from_df(df)

    """
    features = cls(None, None)  # type: ignore
    features.slides = df.slide.unique().tolist()

    def _per_slide(column):
        # Stack the per-tile values of one column into one array per slide.
        return {
            name: np.stack(df.loc[df.slide == name][column].values)
            for name in features.slides
        }

    def _first_row_width(column):
        # Length of the vector stored in the first row of the column.
        _, first_row = next(df.iterrows())
        return first_row[column].shape[0]

    present = df.columns
    if 'activations' in present:
        features.activations = _per_slide('activations')
        features.num_features = _first_row_width('activations')
    if 'locations' in present:
        features.locations = _per_slide('locations')
    if 'uncertainty' in present:
        features.uncertainty = _per_slide('uncertainty')
    if 'predictions' in present:
        features.predictions = _per_slide('predictions')
        features.num_classes = _first_row_width('predictions')
    return features
@classmethod
def from_bags(cls, bags: str) -> "DatasetFeatures":
    """Load a DatasetFeatures object from a directory of bags.

    Every ``<slide>.pt`` file in the directory is loaded as that slide's
    activations, and the matching ``<slide>.index`` file as its tile
    locations.

    Args:
        bags (str): Path to bags, as exported by
            :meth:`DatasetFeatures.to_torch()`

    Returns:
        :class:`DatasetFeatures`: DatasetFeatures object

    """
    import torch

    # Sort so the slide order does not depend on os.listdir() ordering.
    slides = sorted(
        sf.util.path_to_name(b)
        for b in os.listdir(bags)
        if b.endswith('.pt')
    )
    obj = cls(None, None)  # type: ignore
    obj.slides = slides
    for slide in slides:
        # NOTE: torch.load() unpickles data; only load bags you trust.
        activations = torch.load(join(bags, f'{slide}.pt'))
        obj.activations[slide] = activations.numpy()
        obj.locations[slide] = tfrecord2idx.load_index(
            join(bags, f'{slide}.index')
        )
    # Record the feature dimension, as from_df() does.
    if slides:
        obj.num_features = obj.activations[slides[0]].shape[-1]
    return obj
@classmethod
def concat(
    cls,
    args: Iterable["DatasetFeatures"],
) -> "DatasetFeatures":
    """Concatenate activations from multiple DatasetFeatures together.

    For example, if ``df1`` is a DatasetFeatures object with 2048 features
    and ``df2`` is a DatasetFeatures object with 1024 features,
    then ``sf.DatasetFeatures.concat([df1, df2])`` would return an object
    with 3072.

    Vectors from DatasetFeatures objects are concatenated in the given
    order. Any number (two or more) of objects may be supplied.

    During concatenation, predictions and uncertainty are dropped.

    All objects must export the same number of tiles. Tiles are matched on
    slide, location and TFRecord index; tiles not present in every object
    are dropped.

    Args:
        args (Iterable[:class:`DatasetFeatures`]): DatasetFeatures objects
            to concatenate.

    Returns:
        :class:`DatasetFeatures`: DatasetFeatures object with concatenated
        features.

    Raises:
        ValueError: If the objects export different numbers of tiles.

    Examples
        Concatenate two DatasetFeatures objects.

            >>> df1 = DatasetFeatures(model, dataset, layers='postconv')
            >>> df2 = DatasetFeatures(model, dataset, layers='sepconv_3')
            >>> df = DatasetFeatures.concat([df1, df2])

    """
    # Materialize first so that any iterable (e.g. a generator) works.
    features = list(args)
    assert len(features) > 1, "At least two DatasetFeatures are required."

    merge_keys = ['slide', 'locations', 'tfr_index']
    dfs = []
    for i, ftrs in enumerate(features):
        log.debug(f"Creating dataframe {i} from features...")
        df = ftrs.to_df()
        # Predictions and uncertainty are not carried over.
        to_drop = [c for c in df.columns
                   if ('predictions' in c or 'uncertainty' in c)]
        df = df.drop(columns=to_drop)
        # Locations must be hashable to be used as a merge key.
        df['locations'] = df['locations'].map(tuple)
        # Give each input its own activations column, so that any number
        # of inputs can be merged without column-name collisions.
        df = df.rename(columns={'activations': f'activations_{i}'})
        dfs.append(df)
    if not all(len(df) == len(dfs[0]) for df in dfs):
        raise ValueError(
            "Unable to concatenate DatasetFeatures of different lengths "
            f"(got: {', '.join([str(len(_df)) for _df in dfs])})"
        )
    log.debug(f"Created {len(dfs)} dataframes")

    merged = dfs[0]
    for i, df in enumerate(dfs[1:], start=1):
        log.debug(f"Merging dataframe {i}")
        merged = pd.merge(merged, df, how='inner', on=merge_keys)

    log.debug("Concatenating activations")
    act_cols = [f'activations_{i}' for i in range(len(dfs))]
    stacked = [np.stack(merged[c].values) for c in act_cols]
    for i, act in enumerate(stacked, start=1):
        log.debug(f"Act {i} shape: {act.shape}")
    concatenated = np.concatenate(stacked, axis=1)
    merged['activations'] = list(concatenated)

    log.debug("Dropping old columns")
    merged = merged.drop(columns=act_cols)
    log.debug("Sorting by TFRecord index")
    merged = merged.sort_values('tfr_index')
    log.debug("Creating DatasetFeatures object")
    return cls.from_df(merged)
@property
def uq(self) -> bool:
if self.feature_generator is None:
return None
else:
return self.feature_generator.uq
@property
def normalizer(self):
if self.feature_generator is None:
return None
else:
return self.feature_generator.normalizer
def _generate_features(
    self,
    cache: Optional[str] = None,
    progress: bool = True,
    verbose: bool = True,
    pool_sort: bool = True,
    pb: Optional[Progress] = None,
    **kwargs
) -> None:
    """Calculate activations from ``self.model``, saving to self.activations

    Also populates ``self.predictions``, ``self.locations`` and
    ``self.uncertainty``. If the TFRecords store tile locations, all
    per-slide arrays are re-ordered to follow the TFRecord tile order.

    Args:
        cache (str, optional): File in which to store PKL cache.
        progress (bool): Show a progress bar during feature calculation.
            Defaults to True.
        verbose (bool): Show verbose logging output. Defaults to True.
        pool_sort (bool): Use multiprocessing pools to perform final
            sorting. Defaults to True.
        pb (rich.progress.Progress, optional): Existing progress bar,
            passed to the feature generator. When given, no separate
            "Sorting..." progress bar is shown.

    Keyword Args:
        Remaining keyword arguments (e.g. ``layers``, ``include_preds``,
        ``include_uncertainty``, ``batch_size``) are passed to the
        feature generator.

    """
    fg = self.feature_generator = _FeatureGenerator(
        self.model,
        self.dataset,
        **kwargs
    )
    self.num_features = fg.num_features
    self.num_classes = fg.num_classes

    # Calculate final layer activations for each tfrecord
    fla_start_time = time.time()
    activations, predictions, locations, uncertainty = fg.generate(
        progress=progress, pb=pb, verbose=verbose
    )
    self.activations = {s: np.stack(v) for s, v in activations.items()}
    self.predictions = {s: np.stack(v) for s, v in predictions.items()}
    self.locations = {s: np.stack(v) for s, v in locations.items()}
    self.uncertainty = {s: np.stack(v) for s, v in uncertainty.items()}

    # Sort using TFRecord location information,
    # to ensure dictionary indices reflect TFRecord indices
    if fg.tfrecords_have_loc:
        # Sorting is driven by the stored locations, so only slides that
        # actually have locations can (and need to) be sorted.
        slides_to_sort = [
            s for s in self.slides
            if s in self.locations and self.locations[s].size
        ]
        if pool_sort and len(slides_to_sort) > 1:
            pool = mp.Pool(sf.util.num_cpu())
            imap_iterable = pool.imap(
                self.dataset.get_tfrecord_locations, slides_to_sort
            )
        else:
            pool = None
            imap_iterable = map(
                self.dataset.get_tfrecord_locations, slides_to_sort
            )
        if progress and not pb:
            iterable = track(
                imap_iterable,
                transient=False,
                total=len(slides_to_sort),
                description="Sorting...")
        else:
            iterable = imap_iterable

        try:
            for slide, true_locs in zip(slides_to_sort, iterable):
                # Get the order of locations stored in TFRecords,
                # and the corresponding indices for sorting
                cur_locs = self.locations[slide]
                idx = [true_locs.index(tuple(loc)) for loc in cur_locs]

                # Make sure that the TFRecord indices are continuous,
                # otherwise our sorted indices will be inaccurate
                assert max(idx) + 1 == len(idx)

                # Final sorting
                sorted_idx = np.argsort(idx)
                for store in (self.activations,
                              self.predictions,
                              self.uncertainty):
                    if slide in store:
                        store[slide] = store[slide][sorted_idx]
                self.locations[slide] = cur_locs[sorted_idx]
        finally:
            # Release worker processes even if sorting fails.
            if pool is not None:
                pool.close()

    fla_calc_time = time.time()
    log.debug(f'Calculation time: {fla_calc_time-fla_start_time:.0f} sec')
    log.debug(f'Number of activation features: {self.num_features}')

    if cache:
        self.save_cache(cache)
def activations_by_category(
self,
idx: int
) -> Dict[Union[str, int, List[float]], np.ndarray]:
"""For each outcome category, calculates activations of a given
feature across all tiles in the category. Requires annotations to
have been provided.
Args:
idx (int): Index of activations layer to return, stratified by
outcome category.
Returns:
dict: Dict mapping categories to feature activations for all
tiles in the category.
"""
if not self.categories:
raise errors.FeaturesError(
'Unable to calculate by category; annotations not provided.'
)
def act_by_cat(c):
return np.concatenate([
self.activations[pt][:, idx]
for pt in self.slides
if self.labels[pt] == c
])
return {c: act_by_cat(c) for c in self.used_categories}
def box_plots(self, features: List[int], outdir: str) -> None:
    """Save box plots of feature activations by category.

    For each requested feature, two PNG files are written to ``outdir``:
    a tile-level plot and a slide-level plot (using the per-category
    slide summaries returned by ``self.stats()``).

    Args:
        features (list(int)): List of feature indices for which to
            generate box plots.
        outdir (str): Path to directory in which to save box plots.

    Raises:
        ValueError: If ``features`` is not a list.

    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if not isinstance(features, list):
        raise ValueError("'features' must be a list of int.")
    if not self.categories:
        log.warning('Unable to generate box plots; no annotations loaded.')
        return
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    category_stats = self.stats()[2]
    log.info('Generating box plots...')

    def _render(data, title, ylabel):
        # Draw one box plot on a cleared figure and save it as a PNG.
        plt.clf()
        axes = sns.boxplot(data=data)
        axes.set_title(title)
        axes.set(xlabel='Category', ylabel=ylabel)
        plt.xticks(plt.xticks()[0], self.used_categories)
        destination = join(outdir, f'boxplot_{title}.png')
        plt.gcf().canvas.start_event_loop(sys.float_info.min)
        plt.savefig(destination, bbox_inches='tight')

    for f in features:
        # Tile-level box plot
        _render(
            list(self.activations_by_category(f).values()),
            f'{f} (tile-level)',
            'Activation'
        )
        # Slide-level box plot
        _render(
            [c[:, f] for c in category_stats],
            f'{f} (slide-level)',
            'Average tile activation'
        )
def dump_config(self):
"""Return a dictionary of the feature extraction configuration."""
if self.normalizer:
norm_dict = dict(
method=self.normalizer.method,
fit=self.normalizer.get_fit(as_list=True),
)
else:
norm_dict = None
config = dict(
extractor=self.feature_generator.generator.dump_config(),
normalizer=norm_dict,
num_features=self.num_features,
tile_px=self.dataset.tile_px,
tile_um=self.dataset.tile_um
)
return config
def export_to_torch(self, *args, **kwargs):
"""Deprecated function; please use `.to_torch()`"""
warnings.warn(
"Deprecation warning: DatasetFeatures.export_to_torch() will"
" be removed in a future version. Use .to_torch() instead.",
DeprecationWarning
)
self.to_torch(*args, **kwargs)
def save_cache(self, path: str):
    """Write calculated features to a pickle cache.

    The cache holds a list of four dictionaries, in this order:
    activations, predictions, uncertainty, locations.

    Args:
        path (str): Path to pkl.

    """
    payload = [
        self.activations,
        self.predictions,
        self.uncertainty,
        self.locations,
    ]
    with open(path, 'wb') as cache_file:
        pickle.dump(payload, cache_file)
    log.info(f'Data cached to [green]{path}')
def to_csv(
    self,
    filename: str,
    level: str = 'tile',
    method: str = 'mean',
    slides: Optional[List[str]] = None
):
    """Exports calculated activations to csv.

    Each row holds the slide name, then the predictions (when the slide
    has any and ``self.num_classes`` is non-zero), then the activations.

    Args:
        filename (str): Path to CSV file for export.
        level (str): 'tile' or 'slide'. Indicates whether tile or
            slide-level activations are saved. Defaults to 'tile'.
        method (str): Method of summarizing slide-level results. Either
            'mean' or 'median'. Defaults to 'mean'.
        slides (list(str)): Slides to export. If None, exports all slides.
            Defaults to None.

    Raises:
        errors.FeaturesError: If ``level`` is not 'tile' or 'slide'.

    """
    if level not in ('tile', 'slide'):
        raise errors.FeaturesError(f"Export error: unknown level {level}")

    meth_fn = {'mean': np.mean, 'median': np.median}
    slides = self.slides if not slides else slides

    def _predictions_for(slide):
        # Return the slide's predictions, or None if there are none.
        # Uses len() rather than comparing an array against [], which is
        # ambiguous (or an error) for NumPy arrays.
        if not self.num_classes or slide not in self.predictions:
            return None
        preds = self.predictions[slide]
        return preds if len(preds) else None

    # newline='' is required by the csv module to avoid blank rows on
    # platforms that translate line endings.
    with open(filename, 'w', newline='') as outfile:
        csvwriter = csv.writer(outfile)
        logit_header = [f'Class_{c}' for c in range(self.num_classes)]
        feature_header = [f'Feature_{f}' for f in range(self.num_features)]
        header = ['Slide'] + logit_header + feature_header
        csvwriter.writerow(header)
        for slide in track(slides):
            preds = _predictions_for(slide)
            if level == 'tile':
                for i, tile_act in enumerate(self.activations[slide]):
                    row = [slide]
                    if preds is not None:
                        row += preds[i].tolist()
                    row += tile_act.tolist()
                    csvwriter.writerow(row)
            else:
                act = meth_fn[method](
                    self.activations[slide],
                    axis=0
                ).tolist()
                row = [slide]
                if preds is not None:
                    row += meth_fn[method](preds, axis=0).tolist()
                csvwriter.writerow(row + act)
    log.debug(f'Activations saved to [green]{filename}')
def to_torch(
    self,
    outdir: str,
    slides: Optional[List[str]] = None,
    verbose: bool = True
) -> None:
    """Export activations in torch format to .pt files in the directory.

    Used for training MIL models. For each exported slide, a
    ``<slide>.pt`` file (float32 activations) and a ``<slide>.index``
    file (tile locations) are written. Slides without activations are
    skipped. The feature extraction configuration is written to
    ``bags_config.json`` unless that file already exists.

    Args:
        outdir (str): Path to directory in which to save .pt files.
        slides (list(str), optional): Slides to export. If None or empty,
            exports all slides. Defaults to None.
        verbose (bool): Verbose logging output. Defaults to True.

    """
    import torch

    if not exists(outdir):
        os.makedirs(outdir)
    slides = self.slides if not slides else slides
    iterator = track(slides) if verbose else slides
    for slide in iterator:
        slide_acts = self.activations[slide]
        if not len(slide_acts):
            log.info(f'Skipping empty slide [green]{slide}')
            continue
        as_tensor = torch.from_numpy(slide_acts.astype(np.float32))
        torch.save(as_tensor, join(outdir, f'{slide}.pt'))
        tfrecord2idx.save_index(
            self.locations[slide],
            join(outdir, f'{slide}.index')
        )

    # Log the feature extraction configuration
    config = self.dump_config()
    config_path = join(outdir, 'bags_config.json')
    if not exists(config_path):
        sf.util.write_json(config, config_path)
    else:
        # Never overwrite an existing configuration; only warn on mismatch.
        previous = sf.util.load_json(config_path)
        if previous != config:
            log.warning(
                "Feature extraction configuration does not match the "
                "configuration used to generate the existing bags at "
                f"{outdir}. Current configuration will not be saved."
            )

    log_fn = log.info if verbose else log.debug
    log_fn(f'Activations exported in Torch format to {outdir}')
def to_df(
self
) -> pd.core.frame.DataFrame:
"""Export activations, predictions, uncertainty, and locations to
a pandas DataFrame.
Returns:
pd.core.frame.DataFrame: Dataframe with columns 'activations',
'predictions', 'uncertainty', and 'locations'.
"""
index = [s for s in self.slides
for _ in range(len(self.locations[s]))]
df_dict = dict()
df_dict.update({
'locations': pd.Series([
self.locations[s][i]
for s in self.slides
for i in range(len(self.locations[s]))], index=index)
})
df_dict.update({
'tfr_index': pd.Series([
i
for s in self.slides
for i in range(len(self.locations[s]))], index=index)
})
if self.activations:
df_dict.update({
'activations': pd.Series([
self.activations[s][i]
for s in self.slides
for i in range(len(self.activations[s]))], index=index)
})
if self.predictions:
df_dict.update({
'predictions': pd.Series([
self.predictions[s][i]
for s in self.slides
for i in range(len(self.predictions[s]))], index=index)
})
if self.uncertainty:
df_dict.update({
'uncertainty': pd.Series([
self.uncertainty[s][i]
for s in self.slides
for i in range(len(self.uncertainty[s]))], index=index)
})
df = pd.DataFrame(df_dict)
df['slide'] = df.index
return df
def load_cache(self, path: str):
    """Load cached activations from PKL.

    Replaces ``self.activations``, ``self.predictions``,
    ``self.uncertainty`` and ``self.locations`` with the cached
    dictionaries, and updates ``self.num_features`` /
    ``self.num_classes`` from the cached array shapes.

    Args:
        path (str): Path to pkl cache.

    """
    log.info(f'Loading from cache [green]{path}...')
    # NOTE: pickle can execute arbitrary code on load; only load cache
    # files from a trusted source.
    with open(path, 'rb') as pt_pkl_file:
        loaded_pkl = pickle.load(pt_pkl_file)
    self.activations = loaded_pkl[0]
    self.predictions = loaded_pkl[1]
    self.uncertainty = loaded_pkl[2]
    self.locations = loaded_pkl[3]

    def _last_dim(store):
        # Size of the last axis of the first non-empty cached array.
        # Slides of this dataset are preferred, but the cache may hold a
        # different set of slides, so fall back to any cached slide.
        for slide in list(self.slides) + list(store):
            arr = store.get(slide)
            if arr is not None and len(arr):
                return np.shape(arr)[-1]
        return None

    if self.activations:
        num_features = _last_dim(self.activations)
        if num_features is not None:
            self.num_features = num_features
    if self.predictions:
        num_classes = _last_dim(self.predictions)
        if num_classes is not None:
            self.num_classes = num_classes
def stats(
    self,
    outdir: Optional[str] = None,
    method: str = 'mean',
    threshold: float = 0.5
) -> Tuple[Dict[int, Dict[str, float]],
           Dict[int, Dict[str, float]],
           List[np.ndarray]]:
    """Calculates activation averages across categories, as well as
    tile-level and patient-level statistics, using ANOVA, exporting to
    CSV if desired.

    Args:
        outdir (str, optional): Path to directory in which CSV file will
            be saved. Defaults to None.
        method (str, optional): Indicates method of aggregating tile-level
            data into slide-level data. Either 'mean' (default) or
            'threshold'. If mean, slide-level feature data is calculated by
            averaging feature activations across all tiles. If threshold,
            slide-level feature data is calculated by counting the number
            of tiles with feature activations > threshold and dividing by
            the total number of tiles. Defaults to 'mean'.
        threshold (float, optional): Threshold if using 'threshold' method.

    Returns:
        A tuple containing

            dict: Dict mapping features to tile-level statistics ('p', 'f');

            dict: Dict mapping features to slide-level statistics ('p', 'f');

            list: One array per used category, holding the slide-level
            feature summaries (rows are slides, columns are features).

    Raises:
        errors.FeaturesError: If no annotations are loaded, or ``method``
            is unknown.

    """
    if not self.categories:
        raise errors.FeaturesError('No annotations loaded')
    if method not in ('mean', 'threshold'):
        raise errors.FeaturesError(f"Stats method {method} unknown")
    if not self.labels:
        raise errors.FeaturesError("No annotations provided, unable "
                                   "to calculate feature stats.")

    log.info('Calculating activation averages & stats across features...')

    tile_stats = {}
    pt_stats = {}
    category_stats = []
    activation_stats = {}
    for slide in self.slides:
        slide_act = self.activations[slide]
        if method == 'mean':
            # Mean of each feature across tiles
            summarized = np.mean(slide_act, axis=0)
        else:
            # For each feature, count number of tiles with value above
            # threshold, divided by number of tiles (axis 0)
            act_sum = np.sum((slide_act > threshold), axis=0)
            summarized = act_sum / slide_act.shape[0]
        activation_stats[slide] = summarized
    for c in self.used_categories:
        category_stats += [np.array([
            activation_stats[slide]
            for slide in self.slides
            if self.labels[slide] == c
        ])]

    def _anova(groups):
        # One-way ANOVA; undefined results are reported as f=-1, p=1.
        fvalue, pvalue = stats.f_oneway(*groups)
        if isnan(fvalue) or isnan(pvalue):
            return {'f': -1, 'p': 1}
        return {'f': fvalue, 'p': pvalue}

    with warnings.catch_warnings():
        # Silence SciPy's constant-input warning (its name depends on the
        # SciPy version).
        if hasattr(stats, "F_onewayConstantInputWarning"):
            warnings.simplefilter(
                "ignore",
                category=stats.F_onewayConstantInputWarning)
        elif hasattr(stats, "ConstantInputWarning"):
            warnings.simplefilter(
                "ignore",
                category=stats.ConstantInputWarning)
        for f in range(self.num_features):
            # Tile-level ANOVA
            tile_stats[f] = _anova(
                list(self.activations_by_category(f).values())
            )
            # Patient-level ANOVA
            pt_stats[f] = _anova([c[:, f] for c in category_stats])

    try:
        pt_sorted_ft = sorted(
            range(self.num_features),
            key=lambda f: pt_stats[f]['p']
        )
    except Exception:
        log.warning('No stats calculated; unable to sort features.')
        # Fall back to the natural feature order so export still works.
        pt_sorted_ft = list(range(self.num_features))
    for f in range(self.num_features):
        try:
            log.debug(f"Tile-level P-value ({f}): {tile_stats[f]['p']}")
            log.debug(f"Patient-level P-value: ({f}): {pt_stats[f]['p']}")
        except Exception:
            log.warning(f'No stats calculated for feature {f}')

    # Export results
    if outdir:
        if not exists(outdir):
            os.makedirs(outdir)
        filename = join(outdir, 'slide_level_summary.csv')
        log.info(f'Writing results to [green]{filename}[/]...')
        # newline='' is required by the csv module.
        with open(filename, 'w', newline='') as outfile:
            csv_writer = csv.writer(outfile)
            header = (['slide', 'category']
                      + [f'Feature_{n}' for n in pt_sorted_ft])
            csv_writer.writerow(header)
            for slide in self.slides:
                category = self.labels[slide]
                row = ([slide, category]
                       + list(activation_stats[slide][pt_sorted_ft]))
                csv_writer.writerow(row)
            if tile_stats:
                csv_writer.writerow(
                    ['Tile statistic', 'ANOVA P-value']
                    + [tile_stats[n]['p'] for n in pt_sorted_ft]
                )
                csv_writer.writerow(
                    ['Tile statistic', 'ANOVA F-value']
                    + [tile_stats[n]['f'] for n in pt_sorted_ft]
                )
            if pt_stats:
                csv_writer.writerow(
                    ['Slide statistic', 'ANOVA P-value']
                    + [pt_stats[n]['p'] for n in pt_sorted_ft]
                )
                csv_writer.writerow(
                    ['Slide statistic', 'ANOVA F-value']
                    + [pt_stats[n]['f'] for n in pt_sorted_ft]
                )
    return tile_stats, pt_stats, category_stats
def softmax_mean(self) -> Dict[str, np.ndarray]:
"""Calculates the mean prediction vector (post-softmax) across
all tiles in each slide.
Returns:
dict: This is a dictionary mapping slides to the mean logits
array for all tiles in each slide.
"""
return {s: np.mean(v, axis=0) for s, v in self.predictions.items()}
def softmax_percent(
self,
prediction_filter: Optional[List[int]] = None
) -> Dict[str, np.ndarray]:
"""Returns dictionary mapping slides to a vector of length num_classes
with the percent of tiles in each slide predicted to be each outcome.
Args:
prediction_filter: (optional) List of int. If provided, will
restrict predictions to only these categories, with final
prediction being based based on highest logit among these
categories.
Returns:
dict: This is a dictionary mapping slides to an array of
percentages for each logit, of length num_classes
"""
if prediction_filter:
assert isinstance(prediction_filter, list) and all([
isinstance(i, int)
for i in prediction_filter
])
assert max(prediction_filter) <= self.num_classes
else:
prediction_filter = list(range(self.num_classes))
slide_percentages = {}
for slide in self.predictions:
# Find the index of the highest prediction for each tile, only for
# logits within prediction_filter
tile_pred = np.argmax(
self.predictions[slide][:, prediction_filter],
axis=1
)
slide_perc = np.array([
np.count_nonzero(tile_pred == logit) / len(tile_pred)
for logit in range(self.num_classes)
])
slide_percentages.update({slide: slide_perc})
return slide_percentages
def softmax_predict(
self,
prediction_filter: Optional[List[int]] = None
) -> Dict[str, int]:
"""Returns slide-level predictions, assuming the model is predicting a
categorical outcome, by generating a prediction for each individual
tile, and making a slide-level prediction by finding the most
frequently predicted outcome among its constituent tiles.
Args:
prediction_filter: (optional) List of int. If provided, will
restrict predictions to only these categories, with final
prediction based based on highest logit among these categories.
Returns:
dict: Dictionary mapping slide names to slide-level predictions.
"""
if prediction_filter:
assert isinstance(prediction_filter, list)
assert all([isinstance(i, int) for i in prediction_filter])
assert max(prediction_filter) <= self.num_classes
else:
prediction_filter = list(range(self.num_classes))
slide_predictions = {}
for slide in self.predictions:
# Find the index of the highest prediction for each tile, only for
# logits within prediction_filter
tile_pred = np.argmax(
self.predictions[slide][:, prediction_filter],
axis=1
)
slide_perc = np.array([
np.count_nonzero(tile_pred == logit) / len(tile_pred)
for logit in range(self.num_classes)
])
slide_predictions.update({slide: int(np.argmax(slide_perc))})
return slide_predictions
def map_activations(self, **kwargs) -> "sf.SlideMap":
    """Build a SlideMap from these features.

    Keyword args:
        All keyword arguments are forwarded to
        ``sf.SlideMap.from_features()``.

    Returns:
        sf.SlideMap

    """
    slide_map = sf.SlideMap.from_features(self, **kwargs)
    return slide_map
def map_predictions(
    self,
    x: int = 0,
    y: int = 0,
    **kwargs
) -> "sf.SlideMap":
    """Map tile predictions onto x/y coordinate space.

    Args:
        x (int, optional): Outcome category id for which predictions will
            be mapped to the X-axis. Defaults to 0.
        y (int, optional): Outcome category id for which predictions will
            be mapped to the Y-axis. Defaults to 0.

    Keyword args:
        cache (str, optional): Path to parquet file to cache coordinates.
            Defaults to None (caching disabled).
        **kwargs: All keyword arguments are passed to
            ``sf.SlideMap.from_xy()``.

    Returns:
        sf.SlideMap
    """
    all_x, all_y, all_slides, all_tfr_idx = [], [], [], []
    for slide in self.slides:
        # np.asarray accepts both plain arrays and DataFrames, so this
        # works regardless of how per-slide predictions are stored
        # (other methods of this class index them as arrays).
        slide_preds = np.asarray(self.predictions[slide])
        n_tiles = slide_preds.shape[0]
        all_x.append(slide_preds[:, x])
        all_y.append(slide_preds[:, y])
        all_slides.append([slide for _ in range(n_tiles)])
        # Tile index within the slide, in stored order.
        all_tfr_idx.append(np.arange(n_tiles))
    all_x = np.concatenate(all_x)
    all_y = np.concatenate(all_y)
    all_slides = np.concatenate(all_slides)
    all_tfr_idx = np.concatenate(all_tfr_idx)
    return sf.SlideMap.from_xy(
        x=all_x,
        y=all_y,
        slides=all_slides,
        tfr_index=all_tfr_idx,
        **kwargs
    )
def merge(self, df: "DatasetFeatures") -> None:
    """Merge another DatasetFeatures into this one, in place.

    Per-slide entries from ``df`` overwrite entries already present for
    the same slide. TFRecord paths are concatenated, and the slide list
    is rebuilt from the merged activations.

    Args:
        df (slideflow.DatasetFeatures): DatasetFeatures to merge into
            this object.

    Returns:
        None
    """
    for store in ('activations', 'predictions', 'uncertainty', 'locations'):
        getattr(self, store).update(getattr(df, store))
    self.tfrecords = np.concatenate((self.tfrecords, df.tfrecords))
    self.slides = list(self.activations)
def remove_slide(self, slide: str) -> None:
    """Drop every stored entry for ``slide``.

    Removes the slide from the activations, predictions, uncertainty and
    locations mappings, filters out TFRecord paths whose name matches the
    slide, and removes it from the slide list. Missing entries are
    ignored.

    Args:
        slide (str): Name of the slide to remove.
    """
    per_slide_stores = (
        self.activations,
        self.predictions,
        self.uncertainty,
        self.locations,
    )
    for store in per_slide_stores:
        store.pop(slide, None)
    remaining = [
        path for path in self.tfrecords
        if sf.util.path_to_name(path) != slide
    ]
    self.tfrecords = np.array(remaining)
    if slide in self.slides:
        self.slides.remove(slide)
def save_example_tiles(
    self,
    features: List[int],
    outdir: str,
    slides: Optional[List[str]] = None,
    tiles_per_feature: int = 100
) -> None:
    """For a set of activation features, saves image tiles named according
    to their corresponding activations.

    Duplicate image tiles will be saved for each feature, organized into
    subfolders named according to feature.

    Args:
        features (list(int)): Features to evaluate.
        outdir (str): Path to folder in which to save example tiles.
        slides (list, optional): List of slide names. If provided, will
            only include tiles from these slides. Defaults to None.
        tiles_per_feature (int, optional): Number of tiles to include as
            examples for each feature. Defaults to 100. Will evenly sample
            this many tiles across the activation gradient.

    Raises:
        ValueError: If ``features`` is not a list.
    """
    if not isinstance(features, list):
        raise ValueError("'features' must be a list of int.")
    if not slides:
        slides = self.slides

    # Resolve slide name -> TFRecord path once, rather than scanning all
    # TFRecords for every exported tile. If several paths share a name,
    # the last one wins.
    tfr_by_slide = {
        sf.util.path_to_name(tfr): tfr for tfr in self.tfrecords
    }

    for f in features:
        feature_dir = join(outdir, str(f))
        os.makedirs(feature_dir, exist_ok=True)

        # One record per tile, so tiles can be ordered by activation while
        # remembering where each tile came from.
        gradient_list = []
        for slide in slides:
            gradient_list.extend(
                {'val': val, 'slide': slide, 'index': i}
                for i, val in enumerate(self.activations[slide][:, f])
            )
        if not gradient_list:
            log.warning(f"No tiles available for feature {f}; skipping")
            continue
        gradient = np.array(sorted(gradient_list, key=lambda k: k['val']))

        # Evenly spaced positions along the sorted activation gradient.
        sample_idx = np.linspace(
            0,
            gradient.shape[0]-1,
            num=tiles_per_feature,
            dtype=int
        )
        for i, g in track(enumerate(gradient[sample_idx]),
                          total=tiles_per_feature,
                          description=f"Feature {f}"):
            tfr_dir = tfr_by_slide.get(g['slide'])
            if not tfr_dir:
                # Without a TFRecord there is no image to read; skip this
                # tile instead of reusing a path from a previous tile.
                log.warning("TFRecord location not found for "
                            f"slide {g['slide']}")
                continue
            _, image = sf.io.get_tfrecord_by_index(tfr_dir, g['index'])
            tile_filename = (f"{i}-tfrecord{g['slide']}-{g['index']}"
                             + f"-{g['val']:.2f}.jpg")
            with open(join(feature_dir, tile_filename), 'wb') as image_file:
                image_file.write(image.numpy())
# --- Deprecated functions ----------------------------------------------------
def logits_mean(self):
    """Deprecated alias for ``softmax_mean()``.

    Emits a DeprecationWarning attributed to the caller, then returns the
    result of ``self.softmax_mean()``.
    """
    warnings.warn(
        "DatasetFeatures.logits_mean() is deprecated. Please use "
        "DatasetFeatures.softmax_mean()", DeprecationWarning,
        # Point the warning at the calling code, not at this wrapper.
        stacklevel=2
    )
    return self.softmax_mean()
def logits_percent(self, *args, **kwargs):
    """Deprecated alias for ``softmax_percent()``.

    Emits a DeprecationWarning attributed to the caller, then forwards all
    positional and keyword arguments to ``self.softmax_percent()`` and
    returns its result.
    """
    warnings.warn(
        "DatasetFeatures.logits_percent() is deprecated. Please use "
        "DatasetFeatures.softmax_percent()", DeprecationWarning,
        # Point the warning at the calling code, not at this wrapper.
        stacklevel=2
    )
    return self.softmax_percent(*args, **kwargs)
def logits_predict(self, *args, **kwargs):
    """Deprecated alias for ``softmax_predict()``.

    Emits a DeprecationWarning attributed to the caller, then forwards all
    positional and keyword arguments to ``self.softmax_predict()`` and
    returns its result.
    """
    warnings.warn(
        "DatasetFeatures.logits_predict() is deprecated. Please use "
        "DatasetFeatures.softmax_predict()", DeprecationWarning,
        # Point the warning at the calling code, not at this wrapper.
        stacklevel=2
    )
    return self.softmax_predict(*args, **kwargs)
# -----------------------------------------------------------------------------
class _FeatureGenerator:
"""Provides common API for feature generator interfaces."""
def __init__(
    self,
    model: Union[str, "BaseFeatureExtractor", "tf.keras.models.Model", "torch.nn.Module"],
    dataset: "sf.Dataset",
    *,
    layers: Union[str, List[str]] = 'postconv',
    include_preds: Optional[bool] = None,
    include_uncertainty: bool = True,
    batch_size: int = 32,
    device: Optional[str] = None,
    num_workers: Optional[int] = None,
    augment: Optional[Union[bool, str]] = None,
    transform: Optional[Callable] = None,
    **kwargs
) -> None:
    """Initializes FeatureGenerator.

    Args:
        model (str, BaseFeatureExtractor, tf.keras.models.Model, torch.nn.Module):
            Model to use for feature extraction. If str, must be a path to
            a saved model.
        dataset (sf.Dataset): Dataset to use for feature extraction.

    Keyword Args:
        augment (bool, str, optional): Whether to use data augmentation
            during feature extraction. If True, will use default
            augmentation. If str, will use augmentation specified by the
            string. Defaults to None.
        batch_size (int, optional): Batch size to use for feature
            extraction. Defaults to 32.
        device (str, optional): Device to use for feature extraction.
            Only used for PyTorch feature extractors. Defaults to None.
        include_preds (bool, optional): Whether to include model
            predictions. If None and ``model`` is a feature extractor,
            will be True only if the extractor's ``num_classes`` is
            greater than zero; if None otherwise, will be True.
            Defaults to None.
        include_uncertainty (bool, optional): Whether to include model
            uncertainty in the output. Only used if the feature generator
            is a UQ-enabled model. Defaults to True.
        layers (str, list(str)): Layers to extract features from. May be
            the name of a single layer (str) or a list of layers (list).
            Only used if model is a str. Defaults to 'postconv'.
        normalizer ((str or :class:`slideflow.norm.StainNormalizer`), optional):
            Stain normalization strategy to use on image tiles prior to
            feature extraction. This argument is invalid if ``model`` is a
            feature extractor built from a trained model, as stain
            normalization will be specified by the model configuration.
            Defaults to None.
        normalizer_source (str, optional): Stain normalization preset
            or path to a source image. Valid presets include 'v1', 'v2',
            and 'v3'. If None, will use the default preset ('v3').
            This argument is invalid if ``model`` is a feature extractor
            built from a trained model. Defaults to None.
        num_workers (int, optional): Number of workers to use for feature
            extraction. Only used for PyTorch feature extractors. Defaults
            to None.
        transform (Callable, optional): Custom transform to apply to
            images. Applied before standardization. If the feature extractor
            is a PyTorch model, the transform should be a torchvision
            transform.
        **kwargs: Remaining keyword arguments are forwarded to
            ``self._prepare_generator()``.
    """
    self.model = model
    self.dataset = dataset
    self.layers = sf.util.as_list(layers)
    self.batch_size = batch_size
    self.simclr_args = None
    self.num_workers = num_workers
    self.augment = augment
    self.transform = transform

    # Check if location information is stored in TFRecords
    self.tfrecords_have_loc = self.dataset.tfrecords_have_locations()
    if not self.tfrecords_have_loc:
        log.warning(
            "Some TFRecords do not have tile location information; "
            "dataset iteration speed may be affected."
        )

    if include_preds is None:
        if self.is_extractor():
            include_preds = self.model.num_classes > 0  # type: ignore
        else:
            include_preds = True
    self.include_preds = include_preds
    self.include_uncertainty = include_uncertainty

    # Determine UQ and stain normalization.
    # If the `model` is a feature extractor, stain normalization
    # will be determined via keyword arguments by self._prepare_generator()
    self._determine_uq_and_normalizer()

    # Resolve the device *before* building the generator:
    # self._prepare_generator() reads self.device when the model is a
    # loaded torch module, so it must already be assigned by then.
    if self.is_torch():
        from slideflow.model import torch_utils
        # Prefer the model's own device if it declares a truthy one.
        self.device = (
            getattr(self.model, 'device', None)
            or torch_utils.get_device(device)
        )
    else:
        self.device = None

    self.generator = self._prepare_generator(**kwargs)
    self.num_features = self.generator.num_features
    self.num_classes = 0 if not include_preds else self.generator.num_classes

    self._prepare_dataset_kwargs()

    # Move the normalizer to the appropriate device, if this is
    # a pytorch GPU normalizer.
    if self.has_torch_gpu_normalizer():
        log.debug("Moving normalizer to device: {}".format(self.device))
        self.normalizer.device = self.device
def _calculate_feature_batch(self, batch_img):
"""Calculate features from a batch of images."""
# If a PyTorch generator, wrap in inference_mode() and perform on CUDA
if self.is_torch():
import torch
with torch.inference_mode():
batch_img = batch_img.to(self.device)
if self.has_torch_gpu_normalizer():
batch_img = self.normalizer.preprocess(
batch_img.to(self.normalizer.device),
standardize=self.standardize
).to(self.device)
return self.generator(batch_img)
else:
if self.has_torch_gpu_normalizer():
import torch
import tensorflow as tf
batch_img = batch_img.numpy()
batch_img = torch.from_numpy(batch_img)
batch_img = self.normalizer.transform(
batch_img.to(self.normalizer.device)
)
batch_img = batch_img.cpu().numpy()
batch_img = tf.convert_to_tensor(batch_img)
if self.standardize:
batch_img = tf.image.per_image_standardization(batch_img)
return self.generator(batch_img)
def _process_out(self, model_out, batch_slides, batch_loc):
    """Convert raw generator output for one batch into numpy arrays.

    Args:
        model_out: Output of the feature generator for the batch. A single
            output or a list of outputs, ordered as features (one entry
            per layer), then predictions, then uncertainty.
        batch_slides: Slide names for the batch, as yielded by the dataset
            iterator.
        batch_loc: Two-item sequence with the X and Y tile locations for
            the batch. If the first item is None, no locations are
            returned.

    Returns:
        Tuple of ``(features, predictions, uncertainty, slides, loc)``.
        ``features`` is the concatenation (along axis 1) of all feature
        outputs, or None if there are no feature outputs. ``predictions``
        is None unless ``self.include_preds`` is set, and ``uncertainty``
        is None unless both ``self.uq`` and ``self.include_uncertainty``
        are set. ``loc`` is None if no location data was provided.

    Raises:
        ValueError: If the model is neither Tensorflow nor PyTorch.
    """
    model_out = sf.util.as_list(model_out)

    # Process data if the output is Tensorflow (SimCLR or Tensorflow model)
    if self.is_tf():
        slides = [
            bs.decode('utf-8')
            for bs in batch_slides.numpy()
        ]
        model_out = [
            m.numpy() if not isinstance(m, (list, tuple)) else m
            for m in model_out
        ]
        if batch_loc[0] is not None:
            loc = np.stack([
                batch_loc[0].numpy(),
                batch_loc[1].numpy()
            ], axis=1)
        else:
            loc = None

    # Process data if the output is PyTorch
    elif self.is_torch():
        slides = batch_slides
        # Lists/tuples are passed through untouched, matching the
        # Tensorflow branch above.
        model_out = [
            m.cpu().numpy() if not isinstance(m, (list, tuple)) else m
            for m in model_out
        ]
        if batch_loc[0] is not None:
            loc = np.stack([batch_loc[0], batch_loc[1]], axis=1)
        else:
            loc = None

    # Fail with a clear message rather than an unbound-variable error.
    else:
        raise ValueError(f"Unrecognized model type: {type(self.model)}")

    # Final processing.
    # Order of return is features, predictions, uncertainty.
    if self.uq and self.include_uncertainty:
        uncertainty = model_out[-1]
        model_out = model_out[:-1]
    else:
        uncertainty = None
    if self.include_preds:
        predictions = model_out[-1]
        features = model_out[:-1]
    else:
        predictions = None
        features = model_out

    # Concatenate features if we have features from >1 layer.
    # With no feature outputs there is nothing to concatenate.
    if isinstance(features, list):
        features = np.concatenate(features, axis=1) if features else None

    return features, predictions, uncertainty, slides, loc
def _prepare_dataset_kwargs(self):
"""Prepare keyword arguments for Dataset.tensorflow() or .torch()."""
dts_kw = {
'infinite': False,
'batch_size': self.batch_size,
'augment': self.augment,
'transform': self.transform,
'incl_slidenames': True,
'incl_loc': True,
}
# If this is a Feature Extractor, update the dataset kwargs
# with any preprocessing instructions specified by the extractor
if self.is_extractor():
dts_kw.update(self.model.preprocess_kwargs)
# Establish standardization.
self.standardize = ('standardize' not in dts_kw or dts_kw['standardize'])
# Check if normalization is happening on GPU with PyTorch.
# If so, we will handle normalization and standardization
# in the feature generation loop.
if self.has_torch_gpu_normalizer():
log.debug("Using GPU for stain normalization")
dts_kw['standardize'] = False
else:
# Otherwise, let the dataset handle normalization/standardization.
dts_kw['normalizer'] = self.normalizer
# This is not used by SimCLR feature extractors.
self.dts_kw = dts_kw
def _determine_uq_and_normalizer(self):
    """Set ``self.uq`` and ``self.normalizer`` based on the model type."""
    model = self.model

    if isinstance(model, BaseFeatureExtractor):
        self.uq = model.num_uncertainty > 0
        # If the feature extractor has a normalizer, use it.
        # This will be overridden by keyword arguments if the
        # feature extractor is not an instance of slideflow.model.Features.
        self.normalizer = model.normalizer
        return

    if not isinstance(model, str):
        # Neither an extractor nor a saved-model path: no configuration
        # is available to read UQ or normalization settings from.
        self.normalizer = None
        self.uq = False
        return

    # Model is a path to a saved model; load its configuration.
    model_config = sf.util.get_model_config(model)
    hp = sf.ModelParams.from_dict(model_config['hp'])
    self.uq = hp.uq
    self.normalizer = hp.get_normalizer()
    if not self.normalizer:
        return
    log.debug(f'Using realtime {self.normalizer.method} normalization')
    if 'norm_fit' in model_config:
        self.normalizer.set_fit(**model_config['norm_fit'])
def _norm_from_kwargs(self, kwargs):
"""Parse the stain normalizer from keyword arguments."""
if 'normalizer' in kwargs and kwargs['normalizer'] is not None:
norm = kwargs['normalizer']
del kwargs['normalizer']
if 'normalizer_source' in kwargs:
norm_src = kwargs['normalizer_source']
del kwargs['normalizer_source']
else:
norm_src = None
if isinstance(norm, str):
normalizer = sf.norm.autoselect(
norm,
source=norm_src,
backend='tensorflow' if self.is_tf() else 'torch'
)
else:
normalizer = norm
log.debug(f"Normalizing with {normalizer.method}")
return normalizer, kwargs
if 'normalizer' in kwargs:
del kwargs['normalizer']
if 'normalizer_source' in kwargs:
del kwargs['normalizer_source']
return None, kwargs
def _prepare_generator(self, **kwargs) -> Callable:
"""Prepare the feature generator."""
# Generator is a Feature Extractor
if self.is_extractor():
# Handle the case where the extractor is built from a trained model
if self.is_tf():
from slideflow.model.tensorflow import Features as TFFeatures
is_tf_model_extractor = isinstance(self.model, TFFeatures)
is_torch_model_extractor = False
elif self.is_torch():
from slideflow.model.torch import Features as TorchFeatures
is_torch_model_extractor = isinstance(self.model, TorchFeatures)
is_tf_model_extractor = False
else:
is_tf_model_extractor = False
is_torch_model_extractor = False
if (is_tf_model_extractor or is_torch_model_extractor) and 'normalizer' in kwargs:
raise ValueError(
"Cannot specify a normalizer when using a feature extractor "
"created from a trained model. Stain normalization is auto-detected "
"from the model configuration."
)
elif (is_tf_model_extractor or is_torch_model_extractor) and kwargs:
raise ValueError(
f"Invalid keyword arguments: {', '.join(list(kwargs.keys()))}"
)
elif (is_tf_model_extractor or is_torch_model_extractor):
# Stain normalization has already been determined
# from the model configuration.
return self.model
# For all other feature extractors, stain normalization
# is determined from keyword arguments.
self.normalizer, kwargs = self._norm_from_kwargs(kwargs)
if kwargs:
raise ValueError(
f"Invalid keyword arguments: {', '.join(list(kwargs.keys()))}"
)
return self.model
# Generator is a path to a trained model, and we're using UQ
elif self.is_model_path() and (self.uq and self.include_uncertainty):
if self.include_preds is False:
raise ValueError(
"include_preds must be True if include_uncertainty is True"
)
return sf.model.UncertaintyInterface(
self.model,
layers=self.layers,
**kwargs
)
# Generator is a path to a trained Slideflow model
elif self.is_model_path():
return sf.model.Features(
self.model,
layers=self.layers,
include_preds=self.include_preds,
**kwargs
)
# Generator is a loaded Tensorflow model
elif self.is_tf():
return sf.model.Features.from_model(
self.model,
layers=self.layers,
include_preds=self.include_preds,
**kwargs
)
# Generator is a loaded torch.nn.Module
elif self.is_torch():
return sf.model.Features.from_model(
self.model.to(self.device),
tile_px=self.tile_px,
layers=self.layers,
include_preds=self.include_preds,
**kwargs
)
# Unrecognized feature extractor
else:
raise ValueError(f'Unrecognized feature extractor {self.model}')
def is_model_path(self):
    """Return whether the model is a string recognized as a TF/torch model."""
    if not isinstance(self.model, str):
        return False
    return self.is_tf() or self.is_torch()
def is_extractor(self):
    """Return whether the model is a BaseFeatureExtractor instance."""
    model = self.model
    return isinstance(model, BaseFeatureExtractor)
def is_torch(self):
    """Return whether the model is a PyTorch model or extractor."""
    if not self.is_extractor():
        return sf.model.is_torch_model(self.model)
    # Feature extractors report their own backend.
    return self.model.is_torch()
def is_tf(self):
    """Return whether the model is a Tensorflow model or extractor."""
    if not self.is_extractor():
        return sf.model.is_tensorflow_model(self.model)
    # Feature extractors report their own backend.
    return self.model.is_tensorflow()
def has_torch_gpu_normalizer(self):
    """Return whether the normalizer is a TorchStainNormalizer off the CPU.

    True only if ``self.normalizer`` is a ``sf.norm.StainNormalizer`` whose
    class is named 'TorchStainNormalizer' and whose ``device`` is not
    'cpu'.
    """
    norm = self.normalizer
    if not isinstance(norm, sf.norm.StainNormalizer):
        return False
    if norm.__class__.__name__ != 'TorchStainNormalizer':
        return False
    return norm.device != 'cpu'
def build_dataset(self):
    """Build the dataset iterator matching the model's backend.

    Returns:
        The result of ``self.dataset.tensorflow()`` for Tensorflow models,
        or of ``self.dataset.torch()`` for PyTorch models.

    Raises:
        ValueError: If the model is neither Tensorflow nor PyTorch.
    """
    # Generator is a Tensorflow model.
    if self.is_tf():
        log.debug(
            "Setting up Tensorflow dataset iterator (num_parallel_reads="
            f"None, deterministic={not self.tfrecords_have_loc})"
        )
        # Disable parallel reads if we're using tfrecords without location
        # information, as we would need to read and receive data in order.
        par_kw = (
            dict(num_parallel_reads=None)
            if not self.tfrecords_have_loc
            else dict()
        )
        return self.dataset.tensorflow(
            None,
            deterministic=(not self.tfrecords_have_loc),
            **par_kw,
            **self.dts_kw  # type: ignore
        )

    # Generator is a PyTorch model.
    if self.is_torch():
        n_workers = self.num_workers
        if n_workers is None:
            # Fall back to a single worker when TFRecords lack locations.
            n_workers = 4 if self.tfrecords_have_loc else 1
        log.debug(
            "Setting up PyTorch dataset iterator (num_workers="
            f"{n_workers}, chunk_size=8)"
        )
        return self.dataset.torch(
            None,
            num_workers=n_workers,
            chunk_size=8,
            **self.dts_kw  # type: ignore
        )

    # Unrecognized feature generator.
    raise ValueError(f"Unrecognized model type: {type(self.model)}")
def generate(
    self,
    *,
    verbose: bool = True,
    progress: bool = True,
    pb: Optional["Progress"] = None,
):
    """Calculate features for every tile in the dataset.

    Batches are run through the feature generator on the calling thread,
    while a background thread converts the outputs and groups them by
    slide.

    Keyword Args:
        verbose (bool): Log status messages at INFO level rather than
            DEBUG. Defaults to True.
        progress (bool): Show a progress bar. Ignored if ``pb`` is
            provided. Defaults to True.
        pb (rich.progress.Progress, optional): Existing progress bar to
            advance (task id 0). If provided, it is neither started nor
            cleaned up by this method. Defaults to None.

    Returns:
        Tuple of ``(activations, predictions, locations, uncertainty)``,
        each a dict mapping slide names to a list with one entry per tile.

    Raises:
        Exception: Any exception raised while post-processing a batch in
            the background thread is re-raised on the calling thread.
    """
    # Get the dataloader for iterating through tfrecords
    dataset = self.build_dataset()

    log_fn = log.info if verbose else log.debug
    log_fn(f'Calculating activations for {len(self.dataset.tfrecords())} '
           'tfrecords')
    log_fn(f'Generating from [green]{self.model}')

    estimated_tiles = self.dataset.num_tiles

    activations = defaultdict(list)  # type: Dict[str, Any]
    predictions = defaultdict(list)  # type: Dict[str, Any]
    uncertainty = defaultdict(list)  # type: Dict[str, Any]
    locations = defaultdict(list)  # type: Dict[str, Any]

    # Worker to process activations/predictions, for more efficient throughput
    q = queue.Queue()  # type: queue.Queue
    # Exceptions raised in the worker would otherwise be lost with the
    # thread; collect them so they can be re-raised by the caller.
    worker_errors = []  # type: List[BaseException]

    def batch_worker():
        while True:
            model_out, batch_slides, batch_loc = q.get()
            # (None, None, None) is the shutdown sentinel.
            if model_out is None:
                return
            try:
                features, preds, unc, slides, loc = self._process_out(
                    model_out, batch_slides, batch_loc
                )
                for d, slide in enumerate(slides):
                    if self.layers:
                        activations[slide].append(features[d])
                    if self.include_preds and preds is not None:
                        predictions[slide].append(preds[d])
                    if self.uq and self.include_uncertainty:
                        uncertainty[slide].append(unc[d])
                    if loc is not None:
                        locations[slide].append(loc[d])
            except Exception as exc:
                worker_errors.append(exc)
                return

    batch_proc_thread = threading.Thread(target=batch_worker, daemon=True)
    batch_proc_thread.start()

    if progress and not pb:
        pb = Progress(*Progress.get_default_columns(),
                      ImgBatchSpeedColumn(),
                      transient=sf.getLoggingLevel()>20)
        task = pb.add_task("Generating...", total=estimated_tiles)
        pb.start()
    elif pb:
        # Externally managed progress bar: advance it, but leave its
        # lifecycle to the caller.
        task = 0
        progress = False
    else:
        pb = None

    try:
        with sf.util.cleanup_progress((pb if progress else None)):
            for batch_img, _, batch_slides, batch_loc_x, batch_loc_y in dataset:
                # Stop early if post-processing has already failed.
                if worker_errors:
                    break
                model_output = self._calculate_feature_batch(batch_img)
                q.put((model_output, batch_slides, (batch_loc_x, batch_loc_y)))
                if pb:
                    pb.advance(task, self.batch_size)
    finally:
        # Always stop the worker and release the dataset, even if
        # feature calculation raised.
        q.put((None, None, None))
        batch_proc_thread.join()
        if hasattr(dataset, 'close'):
            dataset.close()

    if worker_errors:
        raise worker_errors[0]

    return activations, predictions, locations, uncertainty
# -----------------------------------------------------------------------------
def _export_bags(
    model: Union[Callable, Dict],
    dataset: "sf.Dataset",
    slides: List[str],
    slide_batch_size: int,
    pb: Any,
    outdir: str,
    slide_task: int = 0,
    **dts_kwargs
) -> None:
    """Export feature bags for a list of slides, one slide batch at a time.

    Args:
        model: Feature extractor passed to ``sf.DatasetFeatures``.
        dataset (sf.Dataset): Dataset containing the slides.
        slides (list(str)): Names of slides to export.
        slide_batch_size (int): Number of slides processed per batch.
        pb: Progress bar; advanced by the size of each exported batch.
        outdir (str): Directory passed to ``DatasetFeatures.to_torch()``.
        slide_task (int): Progress bar task id to advance. Defaults to 0.
        **dts_kwargs: Forwarded to ``sf.DatasetFeatures``.
    """
    for slide_batch in sf.util.batch(slides, slide_batch_size):
        # Clear any existing slide filter before applying this batch's.
        try:
            unfiltered = dataset.remove_filter(filters='slide')
        except errors.DatasetFilterError:
            unfiltered = dataset
        batch_dataset = unfiltered.filter(filters={'slide': slide_batch})

        # Nothing to export if no TFRecords match this batch.
        if len(batch_dataset.tfrecords()) == 0:
            continue

        batch_features = sf.DatasetFeatures(
            model, batch_dataset, pb=pb, **dts_kwargs
        )
        batch_features.to_torch(outdir, verbose=False)
        pb.advance(slide_task, len(slide_batch))
def _distributed_export(
    device: int,
    model_cfg: Dict,
    dataset: "sf.Dataset",
    slides: List[List[str]],
    slide_batch_size: int,
    pb: Any,
    outdir: str,
    slide_task: int = 0,
    dts_kwargs: Any = None,
    mixed_precision: Optional[bool] = None,
    channels_last: Optional[bool] = None
) -> None:
    """Export feature bags for the slides assigned to one CUDA device.

    Builds the extractor described by ``model_cfg`` on ``cuda:{device}``
    and exports bags for ``slides[device]`` via ``_export_bags()``.

    Args:
        device (int): CUDA device index; also selects the entry of
            ``slides`` to process.
        model_cfg (dict): Extractor configuration.
        dataset (sf.Dataset): Dataset containing the slides.
        slides (list(list(str))): Slide names, indexed by device.
        slide_batch_size (int): Number of slides processed per batch.
        pb: Progress bar forwarded to ``_export_bags()``.
        outdir (str): Output directory.
        slide_task (int): Progress bar task id. Defaults to 0.
        dts_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``_export_bags()``. Defaults to None.
        mixed_precision (bool, optional): If not None, assigned to the
            extractor's ``mixed_precision`` attribute.
        channels_last (bool, optional): If not None, assigned to the
            extractor's ``channels_last`` attribute.
    """
    extractor = sf.model.extractors.build_extractor_from_cfg(model_cfg, device=f'cuda:{device}')

    # Only override extractor settings that were explicitly provided.
    overrides = (
        ('mixed_precision', mixed_precision),
        ('channels_last', channels_last),
    )
    for attr_name, value in overrides:
        if value is not None:
            setattr(extractor, attr_name, value)

    export_kwargs = dts_kwargs or {}
    device_slides = list(slides[device])
    return _export_bags(
        extractor,
        dataset,
        device_slides,
        slide_batch_size,
        pb,
        outdir,
        slide_task,
        **export_kwargs
    )