Source code for slideflow.io.torch.data_utils
"""Data utilities for Torch datasets."""
import pandas as pd
import numpy as np
from slideflow import errors
from slideflow.util import tfrecord2idx, to_onehot
from slideflow.io.io_utils import detect_tfrecord_format
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable,
Optional, Tuple, Union)
from .augment import compose_augmentations
from .img_utils import decode_image
if TYPE_CHECKING:
from slideflow.norm import StainNormalizer
# -------------------------------------------------------------------------
FEATURE_DESCRIPTION = {
'image_raw': 'byte',
'slide': 'byte',
'loc_x': 'int',
'loc_y': 'int'
}
# -------------------------------------------------------------------------
def process_labels(
labels: Optional[Dict[str, Any]] = None,
onehot: bool = False
) -> Tuple[Optional[Union[Dict[str, Any], pd.DataFrame]],
Optional[np.ndarray],
Optional[np.ndarray],
int]:
"""Analyze labels to determine unique labels, label probabilities, and
number of outcomes.
Args:
labels (dict): Dict mapping slide names to labels.
onehot (bool, optional): Onehot encode outcomes. Defaults to False.
Returns:
labels (dict): Dict mapping slide names to labels.
unique_labels (np.ndarray): Unique labels.
label_prob (np.ndarray): Label probabilities.
num_outcomes (int): Number of outcomes.
"""
# Weakly supervised labels from slides.
if labels is not None and not isinstance(labels, (str, pd.DataFrame)):
if onehot:
_all_labels_raw = np.array(list(labels.values()))
_unique_raw = np.unique(_all_labels_raw)
max_label = np.max(_unique_raw)
labels = {
k: to_onehot(v, max_label+1) # type: ignore
for k, v in labels.items()
}
num_outcomes = 1
else:
first_label = list(labels.values())[0]
if not isinstance(first_label, list):
num_outcomes = 1
else:
num_outcomes = len(first_label)
_all_labels = np.array(list(labels.values()))
unique_labels = np.unique(_all_labels, axis=0)
_lbls = np.array([
np.sum(_all_labels == i)
for i in unique_labels
])
label_prob = _lbls / len(_all_labels)
# Strongly supervised tile labels from a dataframe.
elif isinstance(labels, (pd.DataFrame, str)):
if isinstance(labels, str):
df = pd.read_parquet(labels)
else:
df = labels
if 'label' not in df.columns:
raise ValueError('Could not find column "label" in the '
f'tile labels dataframe at {labels}.')
labels = df
unique_labels = None
label_prob = None
num_outcomes = 1
else:
unique_labels = None
label_prob = None # type: ignore
num_outcomes = 1
return labels, unique_labels, label_prob, num_outcomes
# -------------------------------------------------------------------------
def load_index(tfr):
if isinstance(tfr, bytes):
tfr = tfr.decode('utf-8')
try:
index = tfrecord2idx.load_index(tfr)
except OSError:
raise errors.TFRecordsError(
f"Could not find index path for TFRecord {tfr}"
)
return index
[docs]def read_and_return_record(
record: bytes,
parser: Callable,
assign_slide: Optional[str] = None
) -> Dict:
"""Process raw TFRecord bytes into a format that can be written with
``tf.io.TFRecordWriter``.
Args:
record (bytes): Raw TFRecord bytes (unparsed)
parser (Callable): TFRecord parser, as returned by
:func:`sf.io.get_tfrecord_parser()`
assign_slide (str, optional): Slide name to override the record with.
Defaults to None.
Returns:
Dictionary mapping record key to a tuple containing (bytes, dtype).
"""
parsed = parser(record)
if assign_slide:
parsed['slide'] = assign_slide
parsed['slide'] = parsed['slide'].encode('utf-8')
return {k: (v, FEATURE_DESCRIPTION[k]) for k, v in parsed.items()}
[docs]def serialized_record(
slide: bytes,
image_raw: bytes,
loc_x: int = 0,
loc_y: int = 0
):
"""Returns a serialized example for TFRecord storage, ready to be written
by a TFRecordWriter."""
example = {
'image_raw': (image_raw, FEATURE_DESCRIPTION['image_raw']),
'slide': (slide, FEATURE_DESCRIPTION['slide']),
'loc_x': (loc_x, FEATURE_DESCRIPTION['loc_x']),
'loc_y': (loc_y, FEATURE_DESCRIPTION['loc_y']),
}
return example
[docs]def get_tfrecord_parser(
tfrecord_path: str,
features_to_return: Iterable[str] = None,
decode_images: bool = True,
standardize: bool = False,
normalizer: Optional["StainNormalizer"] = None,
augment: bool = False,
**kwargs
) -> Callable:
"""Gets tfrecord parser using dareblopy reader. Torch implementation;
different than sf.io.tensorflow
Args:
tfrecord_path (str): Path to tfrecord to parse.
features_to_return (list or dict, optional): Designates format for how
features should be returned from parser. If a list of feature names
is provided, the parsing function will return tfrecord features as
a list in the order provided. If a dictionary of labels (keys)
mapping to feature names (values) is provided, features will be
returned from the parser as a dictionary matching the same format.
If None, will return all features as a list.
decode_images (bool, optional): Decode raw image strings into image
arrays. Defaults to True.
standardize (bool, optional): Standardize images into the range (0,1).
Defaults to False.
normalizer (:class:`slideflow.norm.StainNormalizer`): Stain normalizer
to use on images. Defaults to None.
augment (str or bool): Image augmentations to perform. Augmentations include:
* ``'x'``: Random horizontal flip
* ``'y'``: Random vertical flip
* ``'r'``: Random 90-degree rotation
* ``'j'``: Random JPEG compression (50% chance to compress with quality between 50-100)
* ``'b'``: Random Gaussian blur (10% chance to blur with sigma between 0.5-2.0)
Combine letters to define augmentations, such as ``'xyrjn'``.
A value of True will use ``'xyrjb'``.
Note: this function does not support stain augmentation.
Returns:
A tuple containing
func: Parsing function
dict: Detected feature description for the tfrecord
"""
features, img_type = detect_tfrecord_format(tfrecord_path)
if features is None or img_type is None:
raise errors.TFRecordsError(f"Unable to read TFRecord {tfrecord_path}")
if features_to_return is None:
features_to_return = {k: k for k in features}
elif not all(f in features for f in features_to_return):
detected = ",".join(features)
_ftrs = list(features_to_return.keys()) # type: ignore
raise errors.TFRecordsError(
f'Not all features {",".join(_ftrs)} '
f'were found in the tfrecord {detected}'
)
# Build the transformations / augmentations.
transform = compose_augmentations(
augment=augment,
standardize=standardize,
normalizer=normalizer,
whc=True
)
parser = TFRecordParser(
features_to_return,
decode_images,
img_type,
transform
)
return parser
# -------------------------------------------------------------------------
class TFRecordParser:
def __init__(self, features_to_return, decode_images, img_type, transform=None):
self.features_to_return = features_to_return
self.decode_images = decode_images
self.img_type = img_type
self.transform = transform
def __call__(self, record):
"""Each item in args is an array with one item, as the dareblopy reader
returns items in batches and we have set our batch_size = 1 for
interleaving.
"""
features = {}
if ('slide' in self.features_to_return):
slide = bytes(record['slide']).decode('utf-8')
features['slide'] = slide
if ('image_raw' in self.features_to_return):
img = bytes(record['image_raw'])
if self.decode_images:
features['image_raw'] = decode_image(
img,
img_type=self.img_type,
transform=self.transform
)
else:
features['image_raw'] = img
if ('loc_x' in self.features_to_return):
features['loc_x'] = record['loc_x'][0]
if ('loc_y' in self.features_to_return):
features['loc_y'] = record['loc_y'][0]
if type(self.features_to_return) == dict:
return {
label: features[f]
for label, f in self.features_to_return.items()
}
else:
return [features[f] for f in self.features_to_return]