Source code for slideflow.dataset

"""Module for the ``Dataset`` class and its associated functions.

The ``Dataset`` class manages collections of patients, clinical
annotations, slides, and extracted tiles, and assembles images into
torch DataLoader and tensorflow Dataset objects. A high-level
overview of the structure of ``Dataset`` is as follows:


 ──────────── Information Methods ───────────────────────────────
   Annotations      Slides        Settings         TFRecords
  ┌──────────────┐ ┌─────────┐   ┌──────────────┐ ┌──────────────┐
  │Patient       │ │Paths to │   │Tile size (px)│ │ *.tfrecords  │
  │Slide         │ │ slides  │   │Tile size (um)│ │  (generated) │
  │Label(s)      │ └─────────┘   └──────────────┘ └──────────────┘
  │ - Categorical│  .slides()     .tile_px         .tfrecords()
  │ - Continuous │  .rois()       .tile_um         .manifest()
  │ - Time Series│  .slide_paths()                 .num_tiles
  └──────────────┘  .thumbnails()                  .img_format
    .patients()
    .rois()
    .labels()
    .harmonize_labels()
    .is_float()


 ─────── Filtering and Splitting Methods ──────────────────────
  ┌────────────────────────────┐
  │                            │
  │ ┌─────────┐                │ .filter()
  │ │Filtered │                │ .remove_filter()
  │ │ Dataset │                │ .clear_filters()
  │ └─────────┘                │ .split()
  │               Full Dataset │
  └────────────────────────────┘


 ───────── Summary of Image Data Flow ──────────────────────────
  ┌──────┐
  │Slides├─────────────┐
  └──┬───┘             │
     │                 │
     ▼                 │
  ┌─────────┐          │
  │TFRecords├──────────┤
  └──┬──────┘          │
     │                 │
     ▼                 ▼
  ┌────────────────┐ ┌─────────────┐
  │torch DataLoader│ │Loose images │
  │ / tf Dataset   │ │ (.png, .jpg)│
  └────────────────┘ └─────────────┘

 ──────── Slide Processing Methods ─────────────────────────────
  ┌──────┐
  │Slides├───────────────┐
  └──┬───┘               │
     │.extract_tiles()   │.extract_tiles(
     ▼                   │    save_tiles=True
  ┌─────────┐            │  )
  │TFRecords├────────────┤
  └─────────┘            │ .extract_tiles
                         │  _from_tfrecords()

                       ┌─────────────┐
                       │Loose images │
                       │ (.png, .jpg)│
                       └─────────────┘


 ─────────────── TFRecords Operations ─────────────────────────
                      ┌─────────┐
   ┌────────────┬─────┤TFRecords├──────────┐
   │            │     └─────┬───┘          │
   │.tfrecord   │.tfrecord  │ .balance()   │.resize_tfrecords()
   │  _heatmap()│  _report()│ .clip()      │.split_tfrecords
   │            │           │ .torch()     │  _by_roi()
   │            │           │ .tensorflow()│
   ▼            ▼           ▼              ▼
  ┌───────┐ ┌───────┐ ┌────────────────┐┌─────────┐
  │Heatmap│ │PDF    │ │torch DataLoader││TFRecords│
  └───────┘ │ Report│ │ / tf Dataset   │└─────────┘
            └───────┘ └────────────────┘
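
 ─────────────── Example Usage ──────────────────────────────────

 An illustrative sketch of the flow diagrammed above. The configuration
 path, source name, annotation file, and annotation header/values are
 hypothetical placeholders:

 .. code-block:: python

     import slideflow as sf

     # Load one dataset source plus clinical annotations.
     dataset = sf.Dataset(
         config='datasets.json',        # hypothetical config path
         sources=['MySource'],          # hypothetical source name
         tile_px=299,
         tile_um=302,
         annotations='annotations.csv'  # hypothetical annotations path
     )

     # Filter to slides with a given annotation value.
     hpv_positive = dataset.filter({'HPV_status': ['positive']})

     # Extract tiles into TFRecords, then inspect the result.
     hpv_positive.extract_tiles(qc='otsu')
     print(hpv_positive.num_tiles)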
"""

import copy
import csv
import multiprocessing as mp
import os
import shutil
import threading
import time
import types
import tempfile
import warnings
from contextlib import contextmanager
from collections import defaultdict
from datetime import datetime
from glob import glob
from multiprocessing.dummy import Pool as DPool
from os.path import basename, dirname, exists, isdir, join
from queue import Queue
from random import shuffle
from tabulate import tabulate  # type: ignore[import]
from pprint import pformat
from functools import partial
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple,
                    Union, Callable)
import numpy as np
import pandas as pd
import shapely.geometry as sg
from rich.progress import track, Progress
from tqdm import tqdm

import slideflow as sf
from slideflow import errors
from slideflow.slide import WSI, ExtractionReport, SlideReport
from slideflow.util import (log, Labels, _shortname, path_to_name,
                            tfrecord2idx, TileExtractionProgress, MultiprocessProgress)

if TYPE_CHECKING:
    import tensorflow as tf
    import cellpose
    from slideflow.model import BaseFeatureExtractor
    from slideflow.model import ModelParams
    from torch.utils.data import DataLoader
    from slideflow.norm import StainNormalizer

# -----------------------------------------------------------------------------


def _prepare_slide(
    path: str,
    report_dir: Optional[str],
    wsi_kwargs: Dict,
    qc: Optional[str],
    qc_kwargs: Dict,
) -> Optional["sf.WSI"]:

    try:
        slide = sf.WSI(path, **wsi_kwargs)
        if qc:
            slide.qc(method=qc, **qc_kwargs)
        return slide
    except errors.MissingROIError:
        log.debug(f'Missing ROI for slide {path}; skipping')
        return None
    except errors.IncompatibleBackendError:
        log.error('Slide {} has type {}, which is incompatible with the active '
                  'slide reading backend, {}. Consider using a different '
                  'backend, which can be set with the environment variable '
                  'SF_SLIDE_BACKEND. See https://slideflow.dev/installation/#cucim-vs-libvips '
                  'for more information.'.format(
                    path,
                    sf.util.path_to_ext(path).upper(),
                    sf.slide_backend()
                  ))
    except errors.SlideLoadError as e:
        log.error(f'Error loading slide {path}: {e}. Skipping')
        return None
    except errors.QCError as e:
        log.error(e)
        return None
    except errors.TileCorruptionError:
        log.error(f'{path} corrupt; skipping')
        return None
    except (KeyboardInterrupt, SystemExit) as e:
        print('Exiting...')
        raise e
    except Exception as e:
        log.error(f'Error processing slide {path}: {e}. Skipping')
        return None


@contextmanager
def _handle_slide_errors(path: str):
    try:
        yield
    except errors.MissingROIError:
        log.info(f'Missing ROI for slide {path}; skipping')
    except errors.SlideLoadError as e:
        log.error(f'Error loading slide {path}: {e}. Skipping')
    except errors.QCError as e:
        log.error(e)
    except errors.TileCorruptionError:
        log.error(f'{path} corrupt; skipping')
    except (KeyboardInterrupt, SystemExit) as e:
        print('Exiting...')
        raise e


def _tile_extractor(
    path: str,
    tfrecord_dir: str,
    tiles_dir: str,
    reports: Dict,
    qc: str,
    wsi_kwargs: Dict,
    generator_kwargs: Dict,
    qc_kwargs: Dict,
    render_thumb: bool = True
) -> None:
    """Extract tiles. Internal function.

    Slide processing needs to be process-isolated when num_workers > 1.

    Args:
        path (str): Path to the slide to process.
        tfrecord_dir (str): Path to TFRecord directory.
        tiles_dir (str): Path to tiles directory (loose format).
        reports (dict): Multiprocessing-enabled dict.
        qc (str): Quality control method.
        wsi_kwargs (dict): Keyword arguments for sf.WSI.
        generator_kwargs (dict): Keyword arguments for WSI.extract_tiles().
        qc_kwargs (dict): Keyword arguments for quality control.
        render_thumb (bool): Render the SlideReport thumbnail after
            extraction. Defaults to True.
    """
    with _handle_slide_errors(path):
        log.debug(f'Extracting tiles for {path_to_name(path)}')
        slide = _prepare_slide(
            path,
            report_dir=tfrecord_dir,
            wsi_kwargs=wsi_kwargs,
            qc=qc,
            qc_kwargs=qc_kwargs)
        if slide is not None:
            report = slide.extract_tiles(
                tfrecord_dir=tfrecord_dir,
                tiles_dir=tiles_dir,
                **generator_kwargs
            )
            if render_thumb and isinstance(report, SlideReport):
                _ = report.thumb
            reports.update({path: report})


def _buffer_slide(path: str, dest: str) -> str:
    """Buffer a slide to a path."""
    buffered = join(dest, basename(path))
    shutil.copy(path, buffered)

    # If this is an MRXS file, copy the associated folder.
    if path.lower().endswith('mrxs'):
        folder_path = join(dirname(path), path_to_name(path))
        if exists(folder_path):
            shutil.copytree(folder_path, join(dest, path_to_name(path)))
        else:
            log.debug("Could not find associated MRXS folder for slide buffer")

    return buffered


def _debuffer_slide(path: str) -> None:
    """De-buffer a slide."""
    os.remove(path)
    # If this is an MRXS file, remove the associated folder.
    if path.lower().endswith('mrxs'):
        folder_path = join(dirname(path), path_to_name(path))
        if exists(folder_path):
            shutil.rmtree(folder_path)
        else:
            log.debug("Could not find MRXS folder for slide debuffer")


def _fill_queue(
    slide_list: Sequence[str],
    q: Queue,
    q_size: int,
    buffer: Optional[str] = None
) -> None:
    """Fill a queue with slide paths, using an optional buffer."""
    for path in slide_list:
        warned = False
        if buffer:
            while True:
                if q.qsize() < q_size:
                    try:
                        q.put(_buffer_slide(path, buffer))
                        break
                    except OSError:
                        if not warned:
                            slide = _shortname(path_to_name(path))
                            log.debug(f'OSError for {slide}: buffer full?')
                            log.debug(f'Queue size: {q.qsize()}')
                            warned = True
                        time.sleep(1)
                else:
                    time.sleep(1)
        else:
            q.put(path)
    q.put(None)
    q.join()


def _count_otsu_tiles(wsi):
    """Apply Otsu's thresholding QC and return the estimated tile count."""
    wsi.qc('otsu')
    return wsi.estimated_num_tiles


def _create_index(tfrecord, force=False):
    """Create an index file for a TFRecord, unless one exists and force=False."""
    index_name = join(
        dirname(tfrecord),
        path_to_name(tfrecord)+'.index'
    )
    if not tfrecord2idx.find_index(tfrecord) or force:
        tfrecord2idx.create_index(tfrecord, index_name)
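
# Illustrative sketch (assumption: a hypothetical directory of *.tfrecords):
# building index files by mapping ``_create_index`` over each record, which is
# what ``Dataset.build_index()`` below does with a multiprocessing pool.
#
#     for tfr in glob('/data/tfrecords/*.tfrecords'):
#         _create_index(tfr, force=False)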


def _get_tile_df(
    slide_path: str,
    tile_px: int,
    tile_um: Union[int, str],
    rois: Optional[List[str]],
    stride_div: int,
    roi_method: str
) -> Optional[pd.DataFrame]:
    """Build a tile-level dataframe for a single slide. Internal function."""
    try:
        wsi = sf.WSI(
            slide_path,
            tile_px,
            tile_um,
            rois=rois,
            stride_div=stride_div,
            roi_method=roi_method,
            verbose=False
        )
    except Exception as e:
        log.warning("Skipping slide {}, error raised: {}".format(
            path_to_name(slide_path), e
        ))
        return None
    _df = wsi.get_tile_dataframe()
    _df['slide'] = wsi.name
    return _df

# -----------------------------------------------------------------------------

def split_patients_preserved_site(
    patients_dict: Dict[str, Dict],
    n: int,
    balance: Optional[str] = None,
    method: str = 'auto'
) -> List[List[str]]:
    """Split a dictionary of patients into n groups, with site balancing.

    Splits are balanced according to key "balance", while preserving site.

    Args:
        patients_dict (dict): Nested dictionary mapping patient names to a
            dictionary of outcome labels. Each patient must also have a
            'site' key for site-preserved splitting.
        n (int): Number of splits to generate.
        balance (str): Annotation header to balance splits across.
        method (str): Solver method. 'auto', 'cplex', or 'bonmin'. If 'auto',
            will use CPLEX if available, otherwise will default to
            pyomo/bonmin.

    Returns:
        List of patient splits
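
    Examples
        An illustrative sketch; patient names, sites, and the 'outcome'
        header are hypothetical placeholders.

        .. code-block:: python

            patients_dict = {
                'patient1': {'outcome': 'A', 'site': 'site1'},
                'patient2': {'outcome': 'B', 'site': 'site1'},
                'patient3': {'outcome': 'A', 'site': 'site2'},
                'patient4': {'outcome': 'B', 'site': 'site2'},
            }
            splits = split_patients_preserved_site(
                patients_dict, n=2, balance='outcome'
            )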
    """
    patient_list = list(patients_dict.keys())
    shuffle(patient_list)

    def flatten(arr):
        """Flatten an array."""
        return [y for x in arr for y in x]

    # Get patient outcome labels
    if balance is not None:
        patient_outcome_labels = [
            patients_dict[p][balance] for p in patient_list
        ]
    else:
        patient_outcome_labels = [1 for _ in patient_list]
    # Get unique outcomes
    unique_labels = list(set(patient_outcome_labels))
    n_unique = len(set(unique_labels))
    # Delayed import in case CPLEX not installed
    import slideflow.io.preservedsite.crossfolds as cv

    site_list = [patients_dict[p]['site'] for p in patient_list]
    df = pd.DataFrame(
        list(zip(patient_list, patient_outcome_labels, site_list)),
        columns=['patient', 'outcome_label', 'site']
    )
    df = cv.generate(
        df, 'outcome_label', k=n, target_column='CV', method=method
    )
    log.info("[bold]Train/val split with Preserved-Site Cross-Val")
    log.info("[bold]Category\t" + "\t".join(
        [str(cat) for cat in range(n_unique)]
    ))
    for k in range(n):
        def num_labels_matching(o):
            match = df[(df.CV == str(k+1)) & (df.outcome_label == o)]
            return str(len(match))
        matching = [num_labels_matching(o) for o in unique_labels]
        log.info(f"K-fold-{k}\t" + "\t".join(matching))
    splits = [
        df.loc[df.CV == str(ni+1), "patient"].tolist()
        for ni in range(n)
    ]
    return splits


def split_patients_balanced(
    patients_dict: Dict[str, Dict],
    n: int,
    balance: str
) -> List[List[str]]:
    """Split a dictionary of patients into n groups, balancing by outcome.

    Splits are balanced according to key "balance".

    Args:
        patients_dict (dict): Nested dictionary mapping patient names to a
            dictionary of outcome labels.
        n (int): Number of splits to generate.
        balance (str): Annotation header to balance splits across.

    Returns:
        List of patient splits
    """
    patient_list = list(patients_dict.keys())
    shuffle(patient_list)

    def flatten(arr):
        """Flatten an array."""
        return [y for x in arr for y in x]

    # Get patient outcome labels
    patient_outcome_labels = [
        patients_dict[p][balance] for p in patient_list
    ]
    # Get unique outcomes
    unique_labels = list(set(patient_outcome_labels))
    n_unique = len(set(unique_labels))

    # Now, split patient_list according to outcomes
    pt_by_outcome = [
        [p for p in patient_list if patients_dict[p][balance] == uo]
        for uo in unique_labels
    ]
    # Then, for each sublist, split into n components
    pt_by_outcome_by_n = [
        list(sf.util.split_list(sub_l, n)) for sub_l in pt_by_outcome
    ]
    # Print splitting as a table
    log.info(
        "[bold]Category\t" + "\t".join([str(cat) for cat in range(n_unique)])
    )
    for k in range(n):
        matching = [str(len(clist[k])) for clist in pt_by_outcome_by_n]
        log.info(f"K-fold-{k}\t" + "\t".join(matching))
    # Join sublists
    splits = [
        flatten([
            item[ni] for item in pt_by_outcome_by_n
        ]) for ni in range(n)
    ]
    return splits


def split_patients(patients_dict: Dict[str, Dict], n: int) -> List[List[str]]:
    """Split a dictionary of patients into n groups.

    Args:
        patients_dict (dict): Nested dictionary mapping patient names to a
            dictionary of outcome labels.
        n (int): Number of splits to generate.

    Returns:
        List of patient splits
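
    Examples
        An illustrative sketch; patient names and the annotation values are
        hypothetical placeholders.

        .. code-block:: python

            patients_dict = {
                'patient1': {'HPV_status': 'positive'},
                'patient2': {'HPV_status': 'negative'},
                'patient3': {'HPV_status': 'positive'},
            }
            three_way_split = split_patients(patients_dict, n=3)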
    """
    patient_list = list(patients_dict.keys())
    shuffle(patient_list)
    return list(sf.util.split_list(patient_list, n))

# -----------------------------------------------------------------------------


class Dataset:
    """Supervises organization and processing of slides, tfrecords, and tiles.

    Datasets can be comprised of one or more sources, where a source is a
    combination of slides and any associated regions of interest (ROI)
    and extracted image tiles (stored as TFRecords or loose images).

    Datasets can be created in two ways: either by loading one dataset
    source, or by loading a dataset configuration that contains information
    about multiple dataset sources.

    For the first approach, the dataset source configuration is provided via
    keyword arguments (``tiles``, ``tfrecords``, ``slides``, and ``roi``).
    Each is a path to a directory containing the respective data.

    For the second approach, the first argument ``config`` is either a nested
    dictionary containing the configuration for multiple dataset sources, or
    a path to a JSON file with this information. The second argument is a
    list of dataset sources to load (keys from the ``config`` dictionary).

    With either approach, slide/patient-level annotations are provided through
    the ``annotations`` keyword argument, which can either be a path to a CSV
    file, or a pandas DataFrame, which must contain at minimum the column
    '`patient`'.
    """

    def __init__(
        self,
        config: Optional[Union[str, Dict[str, Dict[str, str]]]] = None,
        sources: Optional[Union[str, List[str]]] = None,
        tile_px: Optional[int] = None,
        tile_um: Optional[Union[str, int]] = None,
        *,
        tfrecords: Optional[str] = None,
        tiles: Optional[str] = None,
        roi: Optional[str] = None,
        slides: Optional[str] = None,
        filters: Optional[Dict] = None,
        filter_blank: Optional[Union[List[str], str]] = None,
        annotations: Optional[Union[str, pd.DataFrame]] = None,
        min_tiles: int = 0,
    ) -> None:
        """Initialize a Dataset to organize processed images.

        Examples
            Load a dataset via keyword arguments.

            .. code-block:: python

                dataset = Dataset(
                    tfrecords='../path',
                    slides='../path',
                    annotations='../file.csv'
                )

            Load a dataset configuration file and specify dataset source(s).

            .. code-block:: python

                dataset = Dataset(
                    config='../path/to/config.json',
                    sources=['Lung_Adeno', 'Lung_Squam'],
                    annotations='../file.csv'
                )

        Args:
            config (str, dict): Either a dictionary or a path to a JSON file.
                If a dictionary, keys should be dataset source names, and
                values should be dictionaries containing the keys 'tiles',
                'tfrecords', 'roi', and/or 'slides', specifying directories
                for each dataset source. If `config` is a str, it should be
                a path to a JSON file containing a dictionary with the same
                formatting. If None, tiles, tfrecords, roi and/or slides
                should be manually provided via keyword arguments.
                Defaults to None.
            sources (List[str]): List of dataset sources to include from
                configuration. If not provided, will use all sources in the
                provided configuration. Defaults to None.
            tile_px (int): Tile size in pixels.
            tile_um (int or str): Tile size in microns (int) or magnification
                (str, e.g. "20x").

        Keyword args:
            filters (dict, optional): Dataset filters to use for selecting
                slides. See :meth:`slideflow.Dataset.filter` for more
                information. Defaults to None.
            filter_blank (list(str) or str, optional): Skip slides that have
                blank values in these patient annotation columns.
                Defaults to None.
            min_tiles (int, optional): Only include slides with this many
                tiles at minimum. Defaults to 0.
            annotations (str or pd.DataFrame, optional): Path to annotations
                file or pandas DataFrame with slide-level annotations.
                Defaults to None.

        Raises:
            errors.SourceNotFoundError: If provided source does not exist
                in the dataset config.
""" if isinstance(tile_um, str): sf.util.assert_is_mag(tile_um) tile_um = tile_um.lower() self.tile_px = tile_px self.tile_um = tile_um self._filters = filters if filters else {} if filter_blank is None: self._filter_blank = [] else: self._filter_blank = sf.util.as_list(filter_blank) self._min_tiles = min_tiles self._clip = {} # type: Dict[str, int] self.prob_weights = None # type: Optional[Dict] self._annotations = None # type: Optional[pd.DataFrame] self.annotations_file = None # type: Optional[str] if (any(arg is not None for arg in (tfrecords, tiles, roi, slides)) and (config is not None or sources is not None)): raise ValueError( "When initializing a Dataset object via keywords (tiles, " "tfrecords, slides, roi), the arguments 'config' and 'sources'" " are invalid." ) elif any(arg is not None for arg in (tfrecords, tiles, roi, slides)): config = dict(dataset=dict( tfrecords=tfrecords, tiles=tiles, roi=roi, slides=slides )) sources = ['dataset'] if isinstance(config, str): self._config = config loaded_config = sf.util.load_json(config) else: self._config = "<dict>" loaded_config = config # Read dataset sources from the configuration if sources is None: raise ValueError("Missing argument 'sources'") sources = sources if isinstance(sources, list) else [sources] try: self.sources = { k: v for k, v in loaded_config.items() if k in sources } self.sources_names = list(self.sources.keys()) except KeyError: sources_list = ', '.join(sources) raise errors.SourceNotFoundError(sources_list, config) missing_sources = [s for s in sources if s not in self.sources] if len(missing_sources): log.warn( "The following sources were not found in the dataset " f"configuration: {', '.join(missing_sources)}" ) # Create labels for each source based on tile size if (tile_px is not None) and (tile_um is not None): label = sf.util.tile_size_label(tile_px, tile_um) else: label = None for source in self.sources: self.sources[source]['label'] = label # Load annotations if annotations is not None: self.load_annotations(annotations) def __repr__(self) -> str: # noqa D105 _b = "Dataset(config={!r}, sources={!r}, tile_px={!r}, tile_um={!r})" return _b.format( self._config, self.sources_names, self.tile_px, self.tile_um ) @property def annotations(self) -> Optional[pd.DataFrame]: """Pandas DataFrame of all loaded clinical annotations.""" return self._annotations @property def num_tiles(self) -> int: """Number of tiles in tfrecords after filtering/clipping.""" tfrecords = self.tfrecords() m = self.manifest() if not all([tfr in m for tfr in tfrecords]): self.update_manifest() n_tiles = [ m[tfr]['total'] if 'clipped' not in m[tfr] else m[tfr]['clipped'] for tfr in tfrecords ] return sum(n_tiles) @property def filters(self) -> Dict: """Returns the active filters, if any.""" return self._filters @property def filter_blank(self) -> Union[str, List[str]]: """Returns the active filter_blank filter, if any.""" return self._filter_blank @property def min_tiles(self) -> int: """Returns the active min_tiles filter, if any (defaults to 0).""" return self._min_tiles @property def filtered_annotations(self) -> pd.DataFrame: """Pandas DataFrame of clinical annotations, after filtering.""" if self.annotations is not None: f_ann = self.annotations # Only return slides with annotation values specified in "filters" if self.filters: for filter_key in self.filters.keys(): if filter_key not in f_ann.columns: raise IndexError( f"Filter header {filter_key} not in annotations." 
) filter_vals = sf.util.as_list(self.filters[filter_key]) f_ann = f_ann.loc[f_ann[filter_key].isin(filter_vals)] # Filter out slides that are blank in a given annotation # column ("filter_blank") if self.filter_blank and self.filter_blank != [None]: for fb in self.filter_blank: if fb not in f_ann.columns: raise errors.DatasetFilterError( f"Header {fb} not found in annotations." ) f_ann = f_ann.loc[f_ann[fb].notna()] f_ann = f_ann.loc[~f_ann[fb].isin(sf.util.EMPTY)] # Filter out slides that do not meet minimum number of tiles if self.min_tiles: manifest = self.manifest(key='name', filter=False) man_slides = [s for s in manifest if manifest[s]['total'] >= self.min_tiles] f_ann = f_ann.loc[f_ann.slide.isin(man_slides)] return f_ann else: return None @property def img_format(self) -> Optional[str]: """Format of images stored in TFRecords (jpg/png). Verifies all tfrecords share the same image format. Returns: str: Image format of tfrecords (PNG or JPG), or None if no tfrecords have been extracted. """ return self.verify_img_format(progress=False) def _tfrecords_set(self, source: str): if source not in self.sources: raise ValueError(f"Unrecognized dataset source {source}") config = self.sources[source] return 'tfrecords' in config and config['tfrecords'] def _tiles_set(self, source: str): if source not in self.sources: raise ValueError(f"Unrecognized dataset source {source}") config = self.sources[source] return 'tiles' in config and config['tiles'] def _roi_set(self, source: str): if source not in self.sources: raise ValueError(f"Unrecognized dataset source {source}") config = self.sources[source] return 'roi' in config and config['roi'] def _slides_set(self, source: str): if source not in self.sources: raise ValueError(f"Unrecognized dataset source {source}") config = self.sources[source] return 'slides' in config and config['slides'] def _assert_size_matches_hp(self, hp: Union[Dict, "ModelParams"]) -> None: """Check if dataset tile size (px/um) matches the given parameters.""" if isinstance(hp, dict): hp_px = hp['tile_px'] hp_um = hp['tile_um'] elif isinstance(hp, sf.ModelParams): hp_px = hp.tile_px hp_um = hp.tile_um else: raise ValueError(f"Unrecognized hyperparameter type {type(hp)}") if self.tile_px != hp_px or self.tile_um != hp_um: d_sz = f'({self.tile_px}px, tile_um={self.tile_um})' m_sz = f'({hp_px}px, tile_um={hp_um})' raise ValueError( f"Dataset tile size {d_sz} does not match model {m_sz}" ) def load_annotations(self, annotations: Union[str, pd.DataFrame]) -> None: """Load annotations. Args: annotations (Union[str, pd.DataFrame]): Either path to annotations in CSV format, or a pandas DataFrame. Raises: errors.AnnotationsError: If annotations are incorrectly formatted. 
""" if isinstance(annotations, str): if not exists(annotations): raise errors.AnnotationsError( f'Unable to find annotations file {annotations}' ) try: ann_df = pd.read_csv(annotations, dtype=str) ann_df.fillna('', inplace=True) self._annotations = ann_df self.annotations_file = annotations except pd.errors.EmptyDataError: log.error(f"Unable to load empty annotations {annotations}") elif isinstance(annotations, pd.core.frame.DataFrame): annotations.fillna('', inplace=True) self._annotations = annotations else: raise errors.AnnotationsError( 'Invalid annotations format; expected path or DataFrame' ) # Check annotations assert self.annotations is not None if len(self.annotations.columns) == 1: raise errors.AnnotationsError( "Only one annotations column detected (is it in CSV format?)" ) if len(self.annotations.columns) != len(set(self.annotations.columns)): raise errors.AnnotationsError( "Annotations file has duplicate headers; all must be unique" ) if 'patient' not in self.annotations.columns: raise errors.AnnotationsError( "Patient identifier 'patient' missing in annotations." ) if 'slide' not in self.annotations.columns: if isinstance(annotations, pd.DataFrame): raise errors.AnnotationsError( "If loading annotations from a pandas DataFrame," " must include column 'slide' containing slide names." ) log.info("Column 'slide' missing in annotations.") log.info("Attempting to associate patients with slides...") self.update_annotations_with_slidenames(annotations) self.load_annotations(annotations) # Check for duplicate slides ann = self.annotations.loc[self.annotations.slide.isin(self.slides())] if not ann.slide.is_unique: dup_slide_idx = ann.slide.duplicated() dup_slides = ann.loc[dup_slide_idx].slide.to_numpy().tolist() raise errors.DatasetError( f"Duplicate slides found in annotations: {dup_slides}." ) def balance( self, headers: Optional[Union[str, List[str]]] = None, strategy: Optional[str] = 'category', *, force: bool = False, ) -> "Dataset": """Return a dataset with mini-batch balancing configured. Mini-batch balancing can be configured at tile, slide, patient, or category levels. Balancing information is saved to the attribute ``prob_weights``, which is used by the interleaving dataloaders when sampling from tfrecords to create a batch. Tile level balancing will create prob_weights reflective of the number of tiles per slide, thus causing the batch sampling to mirror random sampling from the entire population of tiles (rather than randomly sampling from slides). Slide level balancing is the default behavior, where batches are assembled by randomly sampling from each slide/tfrecord with equal probability. This balancing behavior would be the same as no balancing. Patient level balancing is used to randomly sample from individual patients with equal probability. This is distinct from slide level balancing, as some patients may have multiple slides per patient. Category level balancing takes a list of annotation header(s) and generates prob_weights such that each category is sampled equally. This requires categorical outcomes. Args: headers (list of str, optional): List of annotation headers if balancing by category. Defaults to None. strategy (str, optional): 'tile', 'slide', 'patient' or 'category'. Create prob_weights used to balance dataset batches to evenly distribute slides, patients, or categories in a given batch. Tile-level balancing generates prob_weights reflective of the total number of tiles in a slide. Defaults to 'category.' 
force (bool, optional): If using category-level balancing, interpret all headers as categorical variables, even if the header appears to be a float. Returns: balanced :class:`slideflow.Dataset` object. """ ret = copy.deepcopy(self) manifest = ret.manifest() tfrecords = ret.tfrecords() slides = [path_to_name(tfr) for tfr in tfrecords] totals = { tfr: (manifest[tfr]['total'] if 'clipped' not in manifest[tfr] else manifest[tfr]['clipped']) for tfr in tfrecords } if not tfrecords: raise errors.DatasetBalanceError( "Unable to balance; no tfrecords found." ) if strategy == 'none' or strategy is None: return self if strategy == 'tile': ret.prob_weights = { tfr: totals[tfr] / sum(totals.values()) for tfr in tfrecords } if strategy == 'slide': ret.prob_weights = {tfr: 1/len(tfrecords) for tfr in tfrecords} if strategy == 'patient': pts = ret.patients() # Maps tfrecords to patients r_pts = {} # Maps patients to list of tfrecords for slide in pts: if slide not in slides: continue if pts[slide] not in r_pts: r_pts[pts[slide]] = [slide] else: r_pts[pts[slide]] += [slide] ret.prob_weights = { tfr: 1/(len(r_pts) * len(r_pts[pts[path_to_name(tfr)]])) for tfr in tfrecords } if strategy == 'category': if headers is None: raise ValueError('Category balancing requires headers.') # Ensure that header is not type 'float' headers = sf.util.as_list(headers) if any(ret.is_float(h) for h in headers) and not force: raise errors.DatasetBalanceError( f"Headers {','.join(headers)} appear to be `float`. " "Categorical outcomes required for balancing. " "To force balancing with these outcomes, pass " "`force=True` to Dataset.balance()" ) labels, _ = ret.labels(headers, use_float=False) cats = {} # type: Dict[str, Dict] cat_prob = {} tfr_cats = {} # type: Dict[str, str] for tfrecord in tfrecords: slide = path_to_name(tfrecord) balance_cat = sf.util.as_list(labels[slide]) balance_cat_str = '-'.join(map(str, balance_cat)) tfr_cats[tfrecord] = balance_cat_str tiles = totals[tfrecord] if balance_cat_str not in cats: cats.update({balance_cat_str: { 'num_slides': 1, 'num_tiles': tiles }}) else: cats[balance_cat_str]['num_slides'] += 1 cats[balance_cat_str]['num_tiles'] += tiles for category in cats: min_cat_slides = min([ cats[i]['num_slides'] for i in cats ]) slides_in_cat = cats[category]['num_slides'] cat_prob[category] = min_cat_slides / slides_in_cat total_prob = sum([cat_prob[tfr_cats[tfr]] for tfr in tfrecords]) ret.prob_weights = { tfr: cat_prob[tfr_cats[tfr]]/total_prob for tfr in tfrecords } return ret def build_index( self, force: bool = True, *, num_workers: Optional[int] = None ) -> None: """Build index files for TFRecords. Args: force (bool): Force re-build existing indices. Keyword args: num_workers (int, optional): Number of workers to use for building indices. Defaults to num_cpus, up to maximum of 16. Returns: None """ if num_workers is None: num_workers = min(sf.util.num_cpu(), 16) if force: index_to_update = self.tfrecords() # Remove existing indices for tfr in self.tfrecords(): index = tfrecord2idx.find_index(tfr) if index: os.remove(index) else: index_to_update = [] for tfr in self.tfrecords(): index = tfrecord2idx.find_index(tfr) if not index: index_to_update.append(tfr) elif (not tfrecord2idx.index_has_locations(index) and sf.io.tfrecord_has_locations(tfr)): os.remove(index) index_to_update.append(tfr) if not index_to_update: return if num_workers == 0: # Single thread. 
for tfr in track(index_to_update, description=f'Updating index files...', total=len(index_to_update), transient=True): _create_index(tfr, force=force) else: # Multiprocessing. index_fn = partial(_create_index, force=force) pool = mp.Pool( sf.util.num_cpu(), initializer=sf.util.set_ignore_sigint ) for _ in track(pool.imap_unordered(index_fn, index_to_update), description=f'Updating index files...', total=len(index_to_update), transient=True): pass pool.close() def cell_segmentation( self, diam_um: float, dest: str, *, model: Union["cellpose.models.Cellpose", str] = 'cyto2', window_size: int = 256, diam_mean: Optional[int] = None, qc: Optional[str] = None, qc_kwargs: Optional[dict] = None, buffer: Optional[str] = None, q_size: int = 2, force: bool = False, save_centroid: bool = True, save_flow: bool = False, **kwargs ) -> None: """Perform cell segmentation on slides, saving segmentation masks. Args: diam_um (int, optional): Cell segmentation diameter, in microns. dest (str): Destination in which to save cell segmentation masks. Keyword args: batch_size (int): Batch size for cell segmentation. Defaults to 8. cp_thresh (float): Cell probability threshold. All pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0. diam_mean (int, optional): Cell diameter to detect, in pixels (without image resizing). If None, uses Cellpose defaults (17 for the 'nuclei' model, 30 for all others). downscale (float): Factor by which to downscale generated masks after calculation. Defaults to None (keep masks at original size). flow_threshold (float): Flow error threshold (all cells with errors below threshold are kept). Defaults to 0.4. gpus (int, list(int)): GPUs to use for cell segmentation. Defaults to 0 (first GPU). interp (bool): Interpolate during 2D dynamics. Defaults to True. qc (str): Slide-level quality control method to use before performing cell segmentation. Defaults to "Otsu". model (str, :class:`cellpose.models.Cellpose`): Cellpose model to use for cell segmentation. May be any valid cellpose model. Defaults to 'cyto2'. mpp (float): Microns-per-pixel at which cells should be segmented. Defaults to 0.5. num_workers (int, optional): Number of workers. Defaults to 2 * num_gpus. save_centroid (bool): Save mask centroids. Increases memory utilization slightly. Defaults to True. save_flow (bool): Save flow values for the whole-slide image. Increases memory utilization. Defaults to False. sources (List[str]): List of dataset sources to include from configuration file. tile (bool): Tiles image to decrease GPU/CPU memory usage. Defaults to True. verbose (bool): Verbose log output at the INFO level. Defaults to True. window_size (int): Window size at which to segment cells across a whole-slide image. Defaults to 256. 
Returns: None """ from slideflow.cellseg import segment_slide if qc_kwargs is None: qc_kwargs = {} slide_list = self.slide_paths() if not force: n_all = len(slide_list) slide_list = [ s for s in slide_list if not exists( join(dest, sf.util.path_to_name(s)+'-masks.zip') ) ] n_skipped = n_all - len(slide_list) if n_skipped: log.info("Skipping {} slides (masks already generated)".format( n_skipped )) if slide_list: log.info(f"Segmenting cells for {len(slide_list)} slides.") else: log.info("No slides found.") return if diam_mean is None: diam_mean = 30 if model != 'nuclei' else 17 tile_um = int(window_size * (diam_um / diam_mean)) pb = TileExtractionProgress() speed_task = pb.add_task( "Speed: ", progress_type="speed", total=None ) slide_task = pb.add_task( "Slides: ", progress_type="slide_progress", total=len(slide_list) ) q = Queue() # type: Queue if buffer: thread = threading.Thread( target=_fill_queue, args=(slide_list, q, q_size, buffer)) thread.start() pb.start() with sf.util.cleanup_progress(pb): while True: slide_path = q.get() if slide_path is None: q.task_done() break wsi = sf.WSI( slide_path, tile_px=window_size, tile_um=tile_um, verbose=False ) if qc is not None: wsi.qc(qc, **qc_kwargs) segment_task = pb.add_task( "Segmenting... ", progress_type="slide_progress", total=wsi.estimated_num_tiles ) # Perform segmentation and save segmentation = segment_slide( wsi, pb=pb, pb_tasks=[speed_task, segment_task], show_progress=False, model=model, diam_mean=diam_mean, save_flow=save_flow, **kwargs) mask_dest = dest if dest is not None else dirname(slide_path) segmentation.save( join(mask_dest, f'{wsi.name}-masks.zip'), flows=save_flow, centroids=save_centroid) pb.advance(slide_task) pb.remove_task(segment_task) if buffer: _debuffer_slide(slide_path) q.task_done() if buffer: thread.join() def check_duplicates( self, dataset: Optional["Dataset"] = None, px: int = 64, mse_thresh: int = 100 ) -> List[Tuple[str, str]]: """Check for duplicate slides by comparing slide thumbnails. Args: dataset (`slideflow.Dataset`, optional): Also check for duplicate slides between this dataset and the provided dataset. px (int): Pixel size at which to compare thumbnails. Defaults to 64. mse_thresh (int): MSE threshold below which an image pair is considered duplicate. Defaults to 100. Returns: List[str], optional: List of path pairs of potential duplicates. 
""" import cv2 thumbs = {} dups = [] def mse(A, B): """Calulate the mean squared error between two image matrices.""" err = np.sum((A.astype("float") - B.astype("float")) ** 2) err /= float(A.shape[0] * A.shape[1]) return err def img_from_path(path): """Read and resize an image.""" img = cv2.imdecode( np.fromfile(path, dtype=np.uint8), cv2.IMREAD_UNCHANGED) img = img[..., 0:3] return cv2.resize(img, dsize=(px, px), interpolation=cv2.INTER_CUBIC) with tempfile.TemporaryDirectory() as temp_dir: os.makedirs(join(temp_dir, 'this_dataset')) self.thumbnails(join(temp_dir, 'this_dataset')) if dataset: os.makedirs(join(temp_dir, 'other_dataset')) dataset.thumbnails(join(temp_dir, 'other_dataset')) for subdir in os.listdir(temp_dir): files = os.listdir(join(temp_dir, subdir)) for file in tqdm(files, desc="Scanning for duplicates..."): if dataset and subdir == 'other_dataset': wsi_path = dataset.find_slide(slide=path_to_name(file)) else: wsi_path = self.find_slide(slide=path_to_name(file)) assert wsi_path is not None img = img_from_path(join(temp_dir, subdir, file)) thumbs[wsi_path] = img # Check if this thumbnail has a duplicate for existing_img in thumbs: if wsi_path != existing_img: img2 = thumbs[existing_img] img_mse = mse(img, img2) if img_mse < mse_thresh: tqdm.write( 'Possible duplicates: ' '{} and {} (MSE: {})'.format( wsi_path, existing_img, mse(img, img2) ) ) dups += [(wsi_path, existing_img)] if not dups: log.info("No duplicates found.") else: log.info(f"{len(dups)} possible duplicates found.") return dups def clear_filters(self) -> "Dataset": """Return a dataset with all filters cleared. Returns: :class:`slideflow.Dataset` object. """ ret = copy.deepcopy(self) ret._filters = {} ret._filter_blank = [] ret._min_tiles = 0 return ret def clip( self, max_tiles: int = 0, strategy: Optional[str] = None, headers: Optional[List[str]] = None ) -> "Dataset": """Return a dataset with TFRecords clipped to a max number of tiles. Clip the number of tiles per tfrecord to a given maximum value and/or to the min number of tiles per patient or category. Args: max_tiles (int, optional): Clip the maximum number of tiles per tfrecord to this number. Defaults to 0 (do not perform tfrecord-level clipping). strategy (str, optional): 'slide', 'patient', or 'category'. Clip the maximum number of tiles to the minimum tiles seen across slides, patients, or categories. If 'category', headers must be provided. Defaults to None (do not perform group-level clipping). headers (list of str, optional): List of annotation headers to use if clipping by minimum category count (strategy='category'). Defaults to None. Returns: clipped :class:`slideflow.Dataset` object. """ if strategy == 'category' and not headers: raise errors.DatasetClipError( "headers must be provided if clip strategy is 'category'." 
) if not max_tiles and strategy is None: return self.unclip() ret = copy.deepcopy(self) manifest = ret.manifest() tfrecords = ret.tfrecords() slides = [path_to_name(tfr) for tfr in tfrecords] totals = {tfr: manifest[tfr]['total'] for tfr in tfrecords} if not tfrecords: raise errors.DatasetClipError("No tfrecords found.") if strategy == 'slide': if max_tiles: clip = min(min(totals.values()), max_tiles) else: clip = min(totals.values()) ret._clip = { tfr: (clip if totals[tfr] > clip else totals[tfr]) for tfr in manifest } elif strategy == 'patient': patients = ret.patients() # Maps slide name to patient rev_patients = {} # Will map patients to list of slide names slide_totals = {path_to_name(tfr): t for tfr, t in totals.items()} for slide in patients: if slide not in slides: continue if patients[slide] not in rev_patients: rev_patients[patients[slide]] = [slide] else: rev_patients[patients[slide]] += [slide] tiles_per_patient = { pt: sum([slide_totals[slide] for slide in slide_list]) for pt, slide_list in rev_patients.items() } if max_tiles: clip = min(min(tiles_per_patient.values()), max_tiles) else: clip = min(tiles_per_patient.values()) ret._clip = { tfr: (clip if slide_totals[path_to_name(tfr)] > clip else totals[tfr]) for tfr in manifest } elif strategy == 'category': if headers is None: raise ValueError("Category clipping requires arg `headers`") labels, _ = ret.labels(headers, use_float=False) categories = {} cat_fraction = {} tfr_cats = {} for tfrecord in tfrecords: slide = path_to_name(tfrecord) balance_category = sf.util.as_list(labels[slide]) balance_cat_str = '-'.join(map(str, balance_category)) tfr_cats[tfrecord] = balance_cat_str tiles = totals[tfrecord] if balance_cat_str not in categories: categories[balance_cat_str] = tiles else: categories[balance_cat_str] += tiles for category in categories: min_cat_count = min([categories[i] for i in categories]) cat_fraction[category] = min_cat_count / categories[category] ret._clip = { tfr: int(totals[tfr] * cat_fraction[tfr_cats[tfr]]) for tfr in manifest } elif max_tiles: ret._clip = { tfr: (max_tiles if totals[tfr] > max_tiles else totals[tfr]) for tfr in manifest } return ret def convert_xml_rois(self): """Convert ImageScope XML ROI files to QuPath format CSV ROI files.""" n_converted = 0 xml_list = [] for source in self.sources: if self._roi_set(source): xml_list = glob(join(self.sources[source]['roi'], "*.xml")) if len(xml_list) == 0: raise errors.DatasetError( 'No XML files found. Check dataset configuration.' ) for xml in xml_list: try: sf.slide.utils.xml_to_csv(xml) except errors.ROIError as e: log.warning(f"Failed to convert XML roi {xml}: {e}") else: n_converted += 1 log.info(f'Converted {n_converted} XML ROIs -> CSV') def get_tile_dataframe( self, roi_method: str = 'auto', stride_div: int = 1, ) -> pd.DataFrame: """Generate a pandas dataframe with tile-level ROI labels. Returns: Pandas dataframe of all tiles, with the following columns: - ``loc_x``: X-coordinate of tile center - ``loc_y``: Y-coordinate of tile center - ``grid_x``: X grid index of the tile - ``grid_y``: Y grid index of the tile - ``roi_name``: Name of the ROI if tile is in an ROI, else None - ``roi_desc``: Description of the ROI if tile is in ROI, else None - ``label``: ROI label, if present. 
""" df = None with mp.Pool(4, initializer=sf.util.set_ignore_sigint) as pool: fn = partial( _get_tile_df, tile_px=self.tile_px, tile_um=self.tile_um, rois=self.rois(), stride_div=stride_div, roi_method=roi_method ) for _df in track(pool.imap_unordered(fn, self.slide_paths()), description=f'Building...', total=len(self.slide_paths()), transient=True): if df is None: df = _df else: df = pd.concat([df, _df], axis=0, join='outer') return df def get_unique_roi_labels(self, allow_empty: bool = False) -> List[str]: """Get a list of unique ROI labels for all slides in this dataset.""" # Get a list of unique labels. roi_unique_labels = [] for roi in self.rois(): _df = pd.read_csv(roi) if 'label' not in _df.columns: continue unique = [ l for l in _df.label.unique().tolist() if (l not in roi_unique_labels) ] roi_unique_labels += unique without_nan = sorted([ l for l in roi_unique_labels if (not isinstance(l, float) or not np.isnan(l)) ]) if allow_empty and (len(roi_unique_labels) > len(without_nan)): return without_nan + [None] else: return without_nan def extract_cells( self, masks_path: str, **kwargs ) -> Dict[str, SlideReport]: """Extract cell images from slides, with a tile at each cell centroid. Requires that cells have already been segmented with ``Dataset.cell_segmentation()``. Args: masks_path (str): Location of saved segmentation masks. Keyword Args: apply_masks (bool): Apply cell segmentation masks to the extracted tiles. Defaults to True. **kwargs: All other keyword arguments for :meth:`Dataset.extract_tiles()`. Returns: Dictionary mapping slide paths to each slide's SlideReport (:class:`slideflow.slide.report.SlideReport`) """ from slideflow.cellseg.seg_utils import ApplySegmentation # Add WSI segmentation as slide-level transformation. qc = [] if 'qc' not in kwargs else kwargs['qc'] if not isinstance(qc, list): qc = [qc] qc.append(ApplySegmentation(masks_path)) kwargs['qc'] = qc # Extract tiles from segmentation centroids. return self.extract_tiles( from_centroids=True, **kwargs ) def extract_tiles( self, *, save_tiles: bool = False, save_tfrecords: bool = True, source: Optional[str] = None, stride_div: int = 1, enable_downsample: bool = True, roi_method: str = 'auto', roi_filter_method: Union[str, float] = 'center', skip_extracted: bool = True, tma: bool = False, randomize_origin: bool = False, buffer: Optional[str] = None, q_size: int = 2, qc: Optional[Union[str, Callable, List[Callable]]] = None, report: bool = True, use_edge_tiles: bool = False, artifact_labels: Optional[Union[List[str], str]] = list(), mpp_override: Optional[float] = None, **kwargs: Any ) -> Dict[str, SlideReport]: r"""Extract tiles from a group of slides. Extracted tiles are saved either loose image or in TFRecord format. Extracted tiles are either saved in TFRecord format (``save_tfrecords=True``, default) or as loose \*.jpg / \*.png images (``save_tiles=True``). TFRecords or image tiles are saved in the the tfrecord and tile directories configured by :class:`slideflow.Dataset`. Keyword Args: save_tiles (bool, optional): Save tile images in loose format. Defaults to False. save_tfrecords (bool): Save compressed image data from extracted tiles into TFRecords in the corresponding TFRecord directory. Defaults to True. source (str, optional): Name of dataset source from which to select slides for extraction. Defaults to None. If not provided, will default to all sources in project. stride_div (int): Stride divisor for tile extraction. A stride of 1 will extract non-overlapping tiles. 
A stride_div of 2 will extract overlapping tiles, with a stride equal to 50% of the tile width. Defaults to 1. enable_downsample (bool): Enable downsampling for slides. This may result in corrupted image tiles if downsampled slide layers are corrupted or incomplete. Defaults to True. roi_method (str): Either 'inside', 'outside', 'auto', or 'ignore'. Determines how ROIs are used to extract tiles. If 'inside' or 'outside', will extract tiles in/out of an ROI, and skip the slide if an ROI is not available. If 'auto', will extract tiles inside an ROI if available, and across the whole-slide if no ROI is found. If 'ignore', will extract tiles across the whole-slide regardless of whether an ROI is available. Defaults to 'auto'. roi_filter_method (str or float): Method of filtering tiles with ROIs. Either 'center' or float (0-1). If 'center', tiles are filtered with ROIs based on the center of the tile. If float, tiles are filtered based on the proportion of the tile inside the ROI, and ``roi_filter_method`` is interpreted as a threshold. If the proportion of a tile inside the ROI is greater than this number, the tile is included. For example, if ``roi_filter_method=0.7``, a tile that is 80% inside of an ROI will be included, and a tile that is 50% inside of an ROI will be excluded. Defaults to 'center'. skip_extracted (bool): Skip slides that have already been extracted. Defaults to True. tma (bool): Reads slides as Tumor Micro-Arrays (TMAs). Deprecated argument; all slides are now read as standard WSIs. randomize_origin (bool): Randomize pixel starting position during extraction. Defaults to False. buffer (str, optional): Slides will be copied to this directory before extraction. Defaults to None. Using an SSD or ramdisk buffer vastly improves tile extraction speed. q_size (int): Size of queue when using a buffer. Defaults to 2. qc (str, optional): 'otsu', 'blur', 'both', or None. Perform blur detection quality control - discarding tiles with detected out-of-focus regions or artifact - and/or otsu's method. Increases tile extraction time. Defaults to None. report (bool): Save a PDF report of tile extraction. Defaults to True. normalizer (str, optional): Normalization strategy. Defaults to None. normalizer_source (str, optional): Stain normalization preset or path to a source image. Valid presets include 'v1', 'v2', and 'v3'. If None, will use the default present ('v3'). Defaults to None. whitespace_fraction (float, optional): Range 0-1. Discard tiles with this fraction of whitespace. If 1, will not perform whitespace filtering. Defaults to 1. whitespace_threshold (int, optional): Range 0-255. Defaults to 230. Threshold above which a pixel (RGB average) is whitespace. grayspace_fraction (float, optional): Range 0-1. Defaults to 0.6. Discard tiles with this fraction of grayspace. If 1, will not perform grayspace filtering. grayspace_threshold (float, optional): Range 0-1. Defaults to 0.05. Pixels in HSV format with saturation below this threshold are considered grayspace. img_format (str, optional): 'png' or 'jpg'. Defaults to 'jpg'. Image format to use in tfrecords. PNG (lossless) for fidelity, JPG (lossy) for efficiency. shuffle (bool, optional): Shuffle tiles prior to storage in tfrecords. Defaults to True. num_threads (int, optional): Number of worker processes for each tile extractor. When using cuCIM slide reading backend, defaults to the total number of available CPU cores, using the 'fork' multiprocessing method. 
With Libvips, this defaults to the total number of available CPU cores or 32, whichever is lower, using 'spawn' multiprocessing. qc_blur_radius (int, optional): Quality control blur radius for out-of-focus area detection. Used if qc=True. Defaults to 3. qc_blur_threshold (float, optional): Quality control blur threshold for detecting out-of-focus areas. Only used if qc=True. Defaults to 0.1 qc_filter_threshold (float, optional): Float between 0-1. Tiles with more than this proportion of blur will be discarded. Only used if qc=True. Defaults to 0.6. qc_mpp (float, optional): Microns-per-pixel indicating image magnification level at which quality control is performed. Defaults to mpp=4 (effective magnification 2.5 X) dry_run (bool, optional): Determine tiles that would be extracted, but do not export any images. Defaults to None. max_tiles (int, optional): Only extract this many tiles per slide. Defaults to None. use_edge_tiles (bool): Use edge tiles in extraction. Areas outside the slide will be padded white. Defaults to False. artifact_labels (list(str) or str, optional): List of ROI issue labels to treat as artifacts. Whenever this is not None, all the ROIs with referred label will be inverted with ROI.invert(). Defaults to an empty list. mpp_override (float, optional): Override the microns-per-pixel for each slide. If None, will auto-detect microns-per-pixel for all slides and raise an error if MPP is not found. Defaults to None. Returns: Dictionary mapping slide paths to each slide's SlideReport (:class:`slideflow.slide.report.SlideReport`) """ if tma: warnings.warn( "tma=True is deprecated and will be removed in a future " "version. Tumor micro-arrays are read as standard slides. " ) if not self.tile_px or not self.tile_um: raise errors.DatasetError( "Dataset tile_px and tile_um must be != 0 to extract tiles" ) if source: sources = sf.util.as_list(source) # type: List[str] else: sources = list(self.sources.keys()) all_reports = [] self.verify_annotations_slides() if isinstance(artifact_labels, str): artifact_labels = [artifact_labels] # Log the active slide reading backend col = 'green' if sf.slide_backend() == 'cucim' else 'cyan' log.info(f"Slide reading backend: [{col}]{sf.slide_backend()}[/]") # Set up kwargs for tile extraction generator and quality control qc_kwargs = {k[3:]: v for k, v in kwargs.items() if k[:3] == 'qc_'} kwargs = {k: v for k, v in kwargs.items() if k[:3] != 'qc_'} sf.slide.log_extraction_params(**kwargs) for source in sources: log.info(f'Working on dataset source [bold]{source}[/]...') if self._roi_set(source): roi_dir = self.sources[source]['roi'] else: roi_dir = None src_conf = self.sources[source] if 'dry_run' not in kwargs or not kwargs['dry_run']: if save_tfrecords and not self._tfrecords_set(source): log.error(f"tfrecords path not set for source {source}") continue elif save_tfrecords: tfrecord_dir = join( src_conf['tfrecords'], src_conf['label'] ) else: tfrecord_dir = None if save_tiles and not self._tiles_set(source): log.error(f"tiles path not set for source {source}") continue elif save_tiles: tiles_dir = join(src_conf['tiles'], src_conf['label']) else: tiles_dir = None if save_tfrecords and not exists(tfrecord_dir): os.makedirs(tfrecord_dir) if save_tiles and not exists(tiles_dir): os.makedirs(tiles_dir) else: save_tfrecords, save_tiles = False, False tfrecord_dir, tiles_dir = None, None # Prepare list of slides for extraction slide_list = self.slide_paths(source=source) # Check for interrupted or already-extracted tfrecords if skip_extracted and 
save_tfrecords: done = [ path_to_name(tfr) for tfr in self.tfrecords(source=source) ] _dir = tfrecord_dir if tfrecord_dir else tiles_dir unfinished = glob(join((_dir), '*.unfinished')) interrupted = [path_to_name(marker) for marker in unfinished] if len(interrupted): log.info(f'Re-extracting {len(interrupted)} interrupted:') for interrupted_slide in interrupted: log.info(interrupted_slide) if interrupted_slide in done: del done[done.index(interrupted_slide)] slide_list = [ s for s in slide_list if path_to_name(s) not in done ] if len(done): log.info(f'Skipping {len(done)} slides; already done.') _tail = f"(tile_px={self.tile_px}, tile_um={self.tile_um})" log.info(f'Extracting tiles from {len(slide_list)} slides {_tail}') # Use multithreading if specified, extracting tiles # from all slides in the filtered list if len(slide_list): q = Queue() # type: Queue # Forking incompatible with some libvips configurations ptype = 'spawn' if sf.slide_backend() == 'libvips' else 'fork' ctx = mp.get_context(ptype) manager = ctx.Manager() reports = manager.dict() kwargs['report'] = report # Use a single shared multiprocessing pool if 'num_threads' not in kwargs: num_threads = sf.util.num_cpu() if num_threads is None: num_threads = 8 if sf.slide_backend() == 'libvips': num_threads = min(num_threads, 32) else: num_threads = kwargs['num_threads'] if num_threads != 1: pool = kwargs['pool'] = ctx.Pool( num_threads, initializer=sf.util.set_ignore_sigint ) qc_kwargs['pool'] = pool else: pool = None ptype = None log.info(f'Using {num_threads} processes (pool={ptype})') # Set up the multiprocessing progress bar pb = TileExtractionProgress() pb.add_task( "Speed: ", progress_type="speed", total=None) slide_task = pb.add_task( f"Extracting ({source})...", progress_type="slide_progress", total=len(slide_list)) wsi_kwargs = { 'tile_px': self.tile_px, 'tile_um': self.tile_um, 'stride_div': stride_div, 'enable_downsample': enable_downsample, 'roi_dir': roi_dir, 'roi_method': roi_method, 'roi_filter_method': roi_filter_method, 'origin': 'random' if randomize_origin else (0, 0), 'pb': pb, 'use_edge_tiles': use_edge_tiles, 'artifact_labels': artifact_labels, 'mpp': mpp_override } extraction_kwargs = { 'tfrecord_dir': tfrecord_dir, 'tiles_dir': tiles_dir, 'reports': reports, 'qc': qc, 'generator_kwargs': kwargs, 'qc_kwargs': qc_kwargs, 'wsi_kwargs': wsi_kwargs, 'render_thumb': (buffer is not None) } pb.start() with sf.util.cleanup_progress(pb): if buffer: # Start the worker threads thread = threading.Thread( target=_fill_queue, args=(slide_list, q, q_size, buffer)) thread.start() # Grab slide path from queue and start extraction while True: path = q.get() if path is None: q.task_done() break _tile_extractor(path, **extraction_kwargs) pb.advance(slide_task) _debuffer_slide(path) q.task_done() thread.join() else: for slide in slide_list: with _handle_slide_errors(slide): wsi = _prepare_slide( slide, report_dir=tfrecord_dir, wsi_kwargs=wsi_kwargs, qc=qc, qc_kwargs=qc_kwargs) if wsi is not None: log.debug(f'Extracting tiles for {wsi.name}') wsi_report = wsi.extract_tiles( tfrecord_dir=tfrecord_dir, tiles_dir=tiles_dir, **kwargs ) reports.update({wsi.path: wsi_report}) del wsi pb.advance(slide_task) # Generate PDF report. 
if report: log.info('Generating PDF (this may take some time)...', ) rep_vals = list( reports.copy().values() ) # type: List[SlideReport] all_reports += rep_vals num_slides = len(slide_list) img_kwargs = defaultdict(lambda: None) # type: Dict img_kwargs.update(kwargs) img_kwargs = sf.slide.utils._update_kw_with_defaults(img_kwargs) report_meta = types.SimpleNamespace( tile_px=self.tile_px, tile_um=self.tile_um, qc=qc, total_slides=num_slides, slides_skipped=len([r for r in rep_vals if r is None]), roi_method=roi_method, stride=stride_div, gs_frac=img_kwargs['grayspace_fraction'], gs_thresh=img_kwargs['grayspace_threshold'], ws_frac=img_kwargs['whitespace_fraction'], ws_thresh=img_kwargs['whitespace_threshold'], normalizer=img_kwargs['normalizer'], img_format=img_kwargs['img_format'] ) pdf_report = ExtractionReport( [r for r in rep_vals if r is not None], meta=report_meta, pool=pool ) _time = datetime.now().strftime('%Y%m%d-%H%M%S') pdf_dir = tfrecord_dir if tfrecord_dir else '' pdf_report.save( join(pdf_dir, f'tile_extraction_report-{_time}.pdf') ) pdf_report.update_csv( join(pdf_dir, 'extraction_report.csv') ) warn_path = join(pdf_dir, f'warn_report-{_time}.txt') if pdf_report.warn_txt: with open(warn_path, 'w') as warn_f: warn_f.write(pdf_report.warn_txt) # Close the multiprocessing pool. if pool is not None: pool.close() # Update manifest & rebuild indices self.update_manifest(force_update=True) self.build_index(True) all_reports = [r for r in all_reports if r is not None] return {report.path: report for report in all_reports} def extract_tiles_from_tfrecords(self, dest: str) -> None: """Extract tiles from a set of TFRecords. Args: dest (str): Path to directory in which to save tile images. If None, uses dataset default. Defaults to None. """ for source in self.sources: to_extract_tfrecords = self.tfrecords(source=source) if dest: tiles_dir = dest elif self._tiles_set(source): tiles_dir = join(self.sources[source]['tiles'], self.sources[source]['label']) if not exists(tiles_dir): os.makedirs(tiles_dir) else: log.error(f"tiles directory not set for source {source}") continue for tfr in to_extract_tfrecords: sf.io.extract_tiles(tfr, tiles_dir) def filter(self, *args: Any, **kwargs: Any) -> "Dataset": """Return a filtered dataset. This method can either accept a single argument (``filters``) or any combination of keyword arguments (``filters``, ``filter_blank``, or ``min_tiles``). Keyword Args: filters (dict, optional): Dictionary used for filtering the dataset. Dictionary keys should be column headers in the patient annotations, and the values should be the variable states to be included in the dataset. For example, ``filters={'HPV_status': ['negative', 'positive']}`` would filter the dataset by the column ``HPV_status`` and only include slides with values of either ``'negative'`` or ``'positive'`` in this column. See `Filtering <https://slideflow.dev/datasets_and_val/#filtering>`_ for further discussion. Defaults to None. filter_blank (list(str) or str, optional): Skip slides that have blank values in these patient annotation columns. Defaults to None. min_tiles (int): Filter out tfrecords that have less than this minimum number of tiles. Defaults to 0. Returns: :class:`slideflow.Dataset`: Dataset with filter added. 
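Example
    A minimal usage sketch; the ``'HPV_status'`` and ``'grade'`` annotation columns are hypothetical, and ``dataset`` is assumed to be an existing :class:`slideflow.Dataset`.

    .. code-block:: python

        # Keep only HPV-positive slides with at least 100 tiles.
        filtered = dataset.filter(
            filters={'HPV_status': ['positive']},
            min_tiles=100
        )

        # Filters are cumulative; this additionally drops slides with
        # a blank value in the 'grade' column.
        filtered = filtered.filter(filter_blank='grade')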
""" if len(args) == 1 and 'filters' not in kwargs: kwargs['filters'] = args[0] elif len(args): raise ValueError( "filter() accepts either one argument (filters), or any " "combination of keywords (filters, filter_blank, min_tiles)" ) for kwarg in kwargs: if kwarg not in ('filters', 'filter_blank', 'min_tiles'): raise ValueError(f'Unknown filtering argument {kwarg}') ret = copy.deepcopy(self) if 'filters' in kwargs and kwargs['filters'] is not None: if not isinstance(kwargs['filters'], dict): raise TypeError("'filters' must be a dict.") ret._filters.update(kwargs['filters']) if 'filter_blank' in kwargs and kwargs['filter_blank'] is not None: if not isinstance(kwargs['filter_blank'], list): kwargs['filter_blank'] = [kwargs['filter_blank']] ret._filter_blank += kwargs['filter_blank'] if 'min_tiles' in kwargs and kwargs['min_tiles'] is not None: if not isinstance(kwargs['min_tiles'], int): raise TypeError("'min_tiles' must be an int.") ret._min_tiles = kwargs['min_tiles'] return ret def filter_bags_by_roi( self, bags_path: str, dest: str, *, tile_df: Optional[pd.DataFrame] = None ) -> None: """Filter bags by tiles in an ROI.""" import torch #TODO: extend to tfrecords #TODO: accelerate with multiprocessing #TODO: save filtered indices #TODO: copy bags config if tile_df is None: tile_df = self.get_tile_dataframe() if not exists(dest): os.makedirs(dest) # Subset the dataframe to only include tiles with an ROI roi_df = tile_df.loc[tile_df.roi_name.notnull()] n_complete = 0 for slide in tqdm(roi_df.slide.unique()): if not exists(join(bags_path, slide+'.pt')): continue # Get the bag bag = torch.load(join(bags_path, slide+'.pt')) bag_index = np.load(join(bags_path, slide+'.index.npz'))['arr_0'] # Subset the ROI based on this slide slide_df = roi_df.loc[roi_df.slide == slide] # Get the common locations (in an ROI) bag_locs = {tuple(r) for r in bag_index} roi_locs = {tuple(r) for r in np.stack([slide_df.loc_x.values, slide_df.loc_y.values], axis=1)} common_locs = bag_locs.intersection(roi_locs) # Find indices in the bag that match the common locations (in an ROI) bag_i = [i for i, row in enumerate(bag_index) if tuple(row) in common_locs] if not len(bag_i): log.debug("No common locations found for {}".format(slide)) continue # Subset and save the bag bag = bag[bag_i] torch.save(bag, join(dest, slide+'.pt')) log.debug("Subset size ({}): {} -> {}".format(slide, len(bag_index), len(bag))) n_complete += 1 log.info("Bag filtering complete. {} bags filtered.".format(n_complete)) def find_rois(self, slide: str) -> Optional[str]: """Find an ROI path from a given slide. Args: slide (str): Slide name. Returns: str: Matching path to ROI, if found. If not found, returns None """ rois = self.rois() if not rois: return None for roi in rois: if path_to_name(roi) == slide: return roi return None def find_slide( self, *, slide: Optional[str] = None, patient: Optional[str] = None ) -> Optional[str]: """Find a slide path from a given slide or patient. Keyword args: slide (str): Find a tfrecord associated with this slide name. patient (str): Find a tfrecord associated with this patient. Returns: str: Matching path to slide, if found. 
If not found, returns None """ if slide is None and patient is None: raise ValueError("Must supply either slide or patient.") if slide is not None and patient is not None: raise ValueError("Must supply either slide or patient, not both.") if slide is not None: filtered = self.filter({'slide': slide}) if patient is not None: filtered = self.filter({'patient': patient}) matching = filtered.slide_paths() if not len(matching): return None else: return matching[0] def find_tfrecord( self, *, slide: Optional[str] = None, patient: Optional[str] = None ) -> Optional[str]: """Find a TFRecord path from a given slide or patient. Keyword args: slide (str): Find a tfrecord associated with this slide name. patient (str): Find a tfrecord associated with this patient. Returns: str: Matching path to tfrecord, if found. Otherwise, returns None """ if slide is None and patient is None: raise ValueError("Must supply either slide or patient.") if slide is not None and patient is not None: raise ValueError("Must supply either slide or patient, not both.") if slide is not None: filtered = self.filter({'slide': slide}) if patient is not None: filtered = self.filter({'patient': patient}) matching = filtered.tfrecords() if not len(matching): return None else: return matching[0] def generate_feature_bags( self, model: Union[str, "BaseFeatureExtractor"], outdir: str, *, force_regenerate: bool = False, batch_size: int = 32, slide_batch_size: int = 16, num_gpus: int = 0, **kwargs: Any ) -> None: """Generate bags of tile-level features for slides for use with MIL models. Args: model (str): Path to model from which to generate activations. May provide either this or "pt_files" outdir (str, optional): Save exported activations in .pt format. Keyword Args: layers (list): Which model layer(s) generate activations. If ``model`` is a saved model, this defaults to 'postconv'. Not used if ``model`` is a pretrained feature extractor. Defaults to None. force_regenerate (bool): Forcibly regenerate activations for all slides even if .pt file exists. Defaults to False. batch_size (int): Batch size during feature calculation. Defaults to 32. slide_batch_size (int): Interleave feature calculation across this many slides. Higher values may improve performance but require more memory. Defaults to 16. num_gpus (int): Number of GPUs to use for feature extraction. Defaults to 0. **kwargs: Additional keyword arguments are passed to :class:`slideflow.DatasetFeatures`. """ if not sf.util.torch_available: raise RuntimeError( "Pytorch is required for generating feature bags. " "Please install Pytorch and try again." ) # Interpret model argument.
if isinstance(model, str) and sf.model.is_extractor(model): # Model is a architecture name (for Imagenet pretrained model) log.info(f"Building feature extractor: [green]{model}[/]") layer_kw = dict(layers=kwargs['layers']) if 'layers' in kwargs else dict() model = sf.build_feature_extractor(model, **layer_kw) elif isinstance(model, str) and exists(model): # Model is a path to a trained slideflow model log.info(f"Using model: [green]{model}[/]") elif isinstance(model, str) and not exists(model): # Model is a string but not a path to a saved model raise ValueError( f"'{model}' is neither a path to a saved model nor the name " "of a valid feature extractor (use sf.model.list_extractors() " "for a list of all available feature extractors).") elif not isinstance(model, str): # Model is a feature extractor object from slideflow.model.base import BaseFeatureExtractor if not isinstance(model, BaseFeatureExtractor): raise ValueError( f"'{model}' is neither a path to a saved model nor the name " "of a valid feature extractor (use sf.model.list_extractors() " "for a list of all available feature extractors).") log.info(f"Using feature extractor: [green]{model.tag}[/]") # Create the pt_files directory if not exists(outdir): os.makedirs(outdir) # Detect already generated pt files done = [ path_to_name(f) for f in os.listdir(outdir) if sf.util.path_to_ext(join(outdir, f)) == 'pt' ] # Work from this dataset. dataset = self if not force_regenerate and len(done): all_slides = dataset.slides() slides_to_generate = [s for s in all_slides if s not in done] if len(slides_to_generate) != len(all_slides): to_skip = len(all_slides) - len(slides_to_generate) skip_p = f'{to_skip}/{len(all_slides)}' log.info(f"Skipping {skip_p} finished slides.") if not slides_to_generate: log.warn("No slides for which to generate features.") return outdir dataset = dataset.filter(filters={'slide': slides_to_generate}) filtered_slides_to_generate = dataset.slides() log.info(f'Working on {len(filtered_slides_to_generate)} slides') # Verify TFRecords are available n_tfrecords = len(dataset.tfrecords()) n_slides = len(dataset.slides()) if not n_tfrecords: log.warning("Unable to generate features; no TFRecords found.") return outdir elif n_tfrecords < n_slides: log.warning("{} tfrecords missing.".format(n_slides - n_tfrecords)) # Rebuild any missing index files. # Must be done before the progress bar is started. dataset.build_index(False) # Set up progress bar. pb = sf.util.FeatureExtractionProgress() pb.add_task( "Speed: ", progress_type="speed", total=self.num_tiles ) slide_task = pb.add_task( "Generating...", progress_type="slide_progress", total=n_slides ) pb.start() # Prepare keyword arguments. dts_kwargs = dict( include_preds=False, include_uncertainty=False, batch_size=batch_size, verbose=False, progress=False, **kwargs ) # Set up activations interface. # Calculate features one slide at a time to reduce memory consumption. with sf.util.cleanup_progress(pb): if not num_gpus > 1: sf.model.features._export_bags( model, dataset, slides=dataset.slides(), slide_batch_size=slide_batch_size, pb=pb, outdir=outdir, slide_task=slide_task, **dts_kwargs ) else: if not hasattr(model, 'dump_config'): raise ValueError( "Feature extraction with multiple GPUs is only " "supported for feature extractors with a dump_config() " "attribute. Please set num_gpus=1 or use a different " "feature extractor." 
) import torch model_cfg = sf.model.extractors.extractor_to_config(model) # Mixed precision and channels_last config if hasattr(model, "mixed_precision"): mixed_precision = model.mixed_precision else: mixed_precision = None if hasattr(model, "channels_last"): channels_last = model.channels_last else: channels_last = None with MultiprocessProgress(pb) as mp_pb: torch.multiprocessing.spawn( sf.model.features._distributed_export, args=( model_cfg, dataset, [n.tolist() for n in np.array_split(dataset.slides(), num_gpus)], slide_batch_size, mp_pb.tracker, outdir, slide_task, dts_kwargs, mixed_precision, channels_last ), nprocs=num_gpus ) def generate_rois( self, model: str, *, overwrite: bool = False, dest: Optional[str] = None, **kwargs ): """Generate ROIs using a U-Net model. Args: model (str): Path to model (zip) or model configuration (json). Keyword args: overwrite (bool, optional): Overwrite existing ROIs. Defaults to False. dest (str, optional): Destination directory for generated ROIs. If not provided, uses the dataset's default ROI directory. sq_mm_threshold (float, optional): If not None, filter out ROIs with an area less than the given threshold (in square millimeters). Defaults to None. """ # Load the model configuration. segment = sf.slide.qc.StridedSegment(model) for slide in track(self.slide_paths(), description='Generating...'): # Set the destination directory source = self.get_slide_source(slide) if 'roi' not in self.sources[source] and dest is None: raise errors.DatasetError( "No ROI directory set. Please set an ROI directory in the " "dataset configuration, or provide a destination directory " "with the `dest` argument." ) if dest is None: dest = self.sources[source]['roi'] if not exists(dest): os.makedirs(dest) # Check if an ROI already exists. existing_rois = [path_to_name(f) for f in os.listdir(dest) if f.endswith('csv')] if path_to_name(slide) in existing_rois: if overwrite: log.info(f"Overwriting ROI for slide {path_to_name(slide)} at {dest}") else: log.info(f"ROI already exists for slide {path_to_name(slide)} at {dest}") continue # Load the slide and remove any existing auto-loaded ROIs. log.info("Working on {}...".format(slide)) try: wsi = sf.WSI(slide, 299, 512, verbose=False) wsi.rois = [] # Generate and apply ROIs. segment.generate_rois(wsi, apply=True, **kwargs) except Exception as e: log.error(f"Failed to generate ROIs for {slide}: {e}") continue # Export ROIs to CSV. wsi.export_rois(join(dest, wsi.name + '.csv')) def get_slide_source(self, slide: str) -> str: """Return the source of a given slide. Args: slide (str): Slide name. Returns: str: Source name. """ for source in self.sources: paths = self.slide_paths(source=source) names = [path_to_name(path) for path in paths] if slide in paths or slide in names: return source raise errors.DatasetError(f"Could not find slide '{slide}'") def get_tfrecord_locations(self, slide: str) -> List[Tuple[int, int]]: """Return a list of locations stored in an associated TFRecord. Args: slide (str): Slide name. Returns: List of tuples of (x, y) coordinates. 
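Example
    A brief sketch; the slide name ``'slide1'`` is hypothetical and must have an associated TFRecord in this dataset.

    .. code-block:: python

        locations = dataset.get_tfrecord_locations('slide1')
        print(f"{len(locations)} tile locations stored")
        x, y = locations[0]  # Each entry is an (x, y) coordinate pair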
""" tfr = self.find_tfrecord(slide=slide) if tfr is None: raise errors.TFRecordsError( f"Could not find associated TFRecord for slide '{slide}'" ) tfr_idx = sf.util.tfrecord2idx.find_index(tfr) if not tfr_idx: _create_index(tfr) elif tfr_idx.endswith('index'): log.info(f"Updating index for {tfr}...") os.remove(tfr_idx) _create_index(tfr) return sf.io.get_locations_from_tfrecord(tfr) def harmonize_labels( self, *args: "Dataset", header: Optional[str] = None ) -> Dict[str, int]: """Harmonize labels with another dataset. Returns categorical label assignments converted to int, harmonized with another dataset to ensure label consistency between datasets. Args: *args (:class:`slideflow.Dataset`): Any number of Datasets. header (str): Categorical annotation header. Returns: Dict mapping slide names to categories. """ if header is None: raise ValueError("Must supply kwarg 'header'") if not isinstance(header, str): raise ValueError('Harmonized labels require a single header.') _, my_unique = self.labels(header, use_float=False) other_uniques = [ np.array(dts.labels(header, use_float=False)[1]) for dts in args ] other_uniques = other_uniques + [np.array(my_unique)] uniques_list = np.concatenate(other_uniques).tolist() all_unique = sorted(list(set(uniques_list))) labels_to_int = dict(zip(all_unique, range(len(all_unique)))) return labels_to_int def is_float(self, header: str) -> bool: """Check if labels in the given header can all be converted to float. Args: header (str): Annotations column header. Returns: bool: If all values from header can be converted to float. """ if self.annotations is None: raise errors.DatasetError("Annotations not loaded.") filtered_labels = self.filtered_annotations[header] try: filtered_labels = [float(o) for o in filtered_labels] return True except ValueError: return False def kfold_split( self, k: int, *, labels: Optional[Union[Dict, str]] = None, preserved_site: bool = False, site_labels: Optional[Union[str, Dict[str, str]]] = 'site', splits: Optional[str] = None, read_only: bool = False, ) -> Tuple[Tuple["Dataset", "Dataset"], ...]: """Split the dataset into k cross-folds. Args: k (int): Number of cross-folds. Keyword args: labels (dict or str, optional): Either a dictionary mapping slides to labels, or an outcome label (``str``). Used for balancing outcome labels in training and validation cohorts. If None, will not balance k-fold splits by outcome labels. Defaults to None. preserved_site (bool): Split with site-preserved cross-validation. Defaults to False. site_labels (dict, optional): Dict mapping patients to site labels, or an outcome column with site labels. Only used for site preserved cross validation. Defaults to 'site'. splits (str, optional): Path to JSON file containing validation splits. Defaults to None. read_only (bool): Prevents writing validation splits to file. Defaults to False. 
""" if splits is None: temp_dir = tempfile.TemporaryDirectory() splits = join(temp_dir.name, '_splits.json') else: temp_dir = None crossval_splits = [] for k_fold_iter in range(k): split_kw = dict( labels=labels, val_strategy=('k-fold-preserved-site' if preserved_site else 'k-fold'), val_k_fold=k, k_fold_iter=k_fold_iter+1, site_labels=site_labels, splits=splits, read_only=read_only ) crossval_splits.append(self.split(**split_kw)) if temp_dir is not None: temp_dir.cleanup() return tuple(crossval_splits) def labels( self, headers: Union[str, List[str]], use_float: Union[bool, Dict, str] = False, assign: Optional[Dict[str, Dict[str, int]]] = None, format: str = 'index' ) -> Tuple[Labels, Union[Dict[str, Union[List[str], List[float]]], List[str], List[float]]]: """Return a dict of slide names mapped to patient id and label(s). Args: headers (list(str)) Annotation header(s) that specifies label. May be a list or string. use_float (bool, optional) Either bool, dict, or 'auto'. If true, convert data into float; if unable, raise TypeError. If false, interpret all data as categorical. If a dict(bool), look up each header to determine type. If 'auto', will try to convert all data into float. For each header in which this fails, will interpret as categorical. assign (dict, optional): Dictionary mapping label ids to label names. If not provided, will map ids to names by sorting alphabetically. format (str, optional): Either 'index' or 'name.' Indicates which format should be used for categorical outcomes when returning the label dictionary. If 'name', uses the string label name. If 'index', returns an int (index corresponding with the returned list of unique outcomes as str). Defaults to 'index'. Returns: A tuple containing **dict**: Dictionary mapping slides to outcome labels in numerical format (float for continuous outcomes, int of outcome label id for categorical outcomes). **list**: List of unique labels. For categorical outcomes, this will be a list of str; indices correspond with the outcome label id. """ if self.annotations is None: raise errors.DatasetError("Annotations not loaded.") if not len(self.filtered_annotations): raise errors.DatasetError( "Cannot generate labels: dataset is empty after filtering." ) results = {} # type: Dict headers = sf.util.as_list(headers) unique_labels = {} filtered_pts = self.filtered_annotations.patient filtered_slides = self.filtered_annotations.slide for header in headers: if assign and (len(headers) > 1 or header in assign): assigned_for_header = assign[header] elif assign is not None: raise errors.DatasetError( f"Unable to read outcome assignments for header {header}" f" (assign={assign})" ) else: assigned_for_header = None unique_labels_for_this_header = [] try: filtered_labels = self.filtered_annotations[header] except KeyError: raise errors.AnnotationsError(f"Missing column {header}.") # Determine whether values should be converted into float if isinstance(use_float, dict) and header not in use_float: raise ValueError( f"use_float is dict, but header {header} is missing." 
) elif isinstance(use_float, dict): header_is_float = use_float[header] elif isinstance(use_float, bool): header_is_float = use_float elif use_float == 'auto': header_is_float = self.is_float(header) else: raise ValueError(f"Invalid use_float option {use_float}") # Ensure labels can be converted to desired type, # then assign values if header_is_float and not self.is_float(header): raise TypeError( f"Unable to convert all labels of {header} into 'float' " f"({','.join(filtered_labels)})." ) elif header_is_float: log.debug(f'Interpreting column "{header}" as continuous') filtered_labels = filtered_labels.astype(float) else: log.debug(f'Interpreting column "{header}" as categorical') unique_labels_for_this_header = list(set(filtered_labels)) unique_labels_for_this_header.sort() for i, ul in enumerate(unique_labels_for_this_header): n_matching_filtered = sum(f == ul for f in filtered_labels) if assigned_for_header and ul not in assigned_for_header: raise KeyError( f"assign was provided, but label {ul} missing" ) elif assigned_for_header: val_msg = assigned_for_header[ul] n_s = str(n_matching_filtered) log.debug( f"{header} {ul} assigned {val_msg} [{n_s} slides]" ) else: n_s = str(n_matching_filtered) log.debug( f"{header} {ul} assigned {i} [{n_s} slides]" ) def _process_cat_label(o): if assigned_for_header: return assigned_for_header[o] elif format == 'name': return o else: return unique_labels_for_this_header.index(o) # Check for multiple, different labels per patient and warn pt_assign = np.array(list(set(zip(filtered_pts, filtered_labels)))) unique_pt, counts = np.unique(pt_assign[:, 0], return_counts=True) for pt in unique_pt[np.argwhere(counts > 1)][:, 0]: dup_vals = pt_assign[pt_assign[:, 0] == pt][:, 1] dups = ", ".join([str(d) for d in dup_vals]) log.error( f'Multiple labels for patient "{pt}" (header {header}): ' f'{dups}' ) # Assemble results dictionary for slide, lbl in zip(filtered_slides, filtered_labels): if slide in sf.util.EMPTY: continue if not header_is_float: lbl = _process_cat_label(lbl) if slide in results: results[slide] = sf.util.as_list(results[slide]) results[slide] += [lbl] elif header_is_float: results[slide] = [lbl] else: results[slide] = lbl unique_labels[header] = unique_labels_for_this_header if len(headers) == 1: return results, unique_labels[headers[0]] else: return results, unique_labels def load_indices(self, verbose=False) -> Dict[str, np.ndarray]: """Return TFRecord indices.""" pool = DPool(8) tfrecords = self.tfrecords() indices = {} def load_index(tfr): tfr_name = path_to_name(tfr) index = tfrecord2idx.load_index(tfr) return tfr_name, index log.debug("Loading indices...") for tfr_name, index in pool.imap(load_index, tfrecords): indices[tfr_name] = index pool.close() return indices def manifest( self, key: str = 'path', filter: bool = True ) -> Dict[str, Dict[str, int]]: """Generate a manifest of all tfrecords. Args: key (str): Either 'path' (default) or 'name'. Determines key format in the manifest dictionary. filter (bool): Apply active filters to manifest. Returns: dict: Dict mapping key (path or slide name) to number of tiles. 
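Example
    A brief sketch printing tile counts per slide, using slide names as keys.

    .. code-block:: python

        manifest = dataset.manifest(key='name')
        for slide, counts in manifest.items():
            # 'total' is the number of tiles in the TFRecord;
            # 'clipped' reflects any active clipping.
            print(slide, counts['total'], counts['clipped'])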
""" if key not in ('path', 'name'): raise ValueError("'key' must be in ['path, 'name']") all_manifest = {} for source in self.sources: if self.sources[source]['label'] is None: continue if not self._tfrecords_set(source): log.warning(f"tfrecords path not set for source {source}") continue tfrecord_dir = join( self.sources[source]['tfrecords'], self.sources[source]['label'] ) manifest_path = join(tfrecord_dir, "manifest.json") if not exists(manifest_path): log.debug(f"No manifest at {tfrecord_dir}; creating now") sf.io.update_manifest_at_dir(tfrecord_dir) if exists(manifest_path): relative_manifest = sf.util.load_json(manifest_path) else: relative_manifest = {} global_manifest = {} for record in relative_manifest: k = join(tfrecord_dir, record) global_manifest.update({k: relative_manifest[record]}) all_manifest.update(global_manifest) # Now filter out any tfrecords that would be excluded by filters if filter: filtered_tfrecords = self.tfrecords() manifest_tfrecords = list(all_manifest.keys()) for tfr in manifest_tfrecords: if tfr not in filtered_tfrecords: del all_manifest[tfr] # Log clipped tile totals if applicable for tfr in all_manifest: if tfr in self._clip: all_manifest[tfr]['clipped'] = min(self._clip[tfr], all_manifest[tfr]['total']) else: all_manifest[tfr]['clipped'] = all_manifest[tfr]['total'] if key == 'path': return all_manifest else: return {path_to_name(t): v for t, v in all_manifest.items()} def manifest_histogram( self, by: Optional[str] = None, binrange: Optional[Tuple[int, int]] = None ) -> None: """Plot histograms of tiles-per-slide. Example Create histograms of tiles-per-slide, stratified by site. .. code-block:: python import matplotlib.pyplot as plt dataset.manifest_histogram(by='site') plt.show() Args: by (str, optional): Stratify histograms by this annotation column header. Defaults to None. binrange (tuple(int, int)): Histogram bin ranges. If None, uses full range. Defaults to None. 
""" import seaborn as sns import matplotlib.pyplot as plt if by is not None: _, unique_vals = self.labels(by, format='name') val_counts = [ [ m['total'] for m in self.filter({by: val}).manifest().values() ] for val in unique_vals ] all_counts = [c for vc in val_counts for c in vc] else: unique_vals = [''] all_counts = [m['total'] for m in self.manifest().values()] val_counts = [all_counts] if binrange is None: max_count = (max(all_counts) // 20) * 20 binrange = (0, max_count) fig, axes = plt.subplots(len(unique_vals), 1, figsize=(3, len(unique_vals))) if not isinstance(axes, np.ndarray): axes = [axes] fig.set_tight_layout({"pad": .0}) for a, ax in enumerate(axes): sns.histplot(val_counts[a], bins=20, binrange=binrange, ax=ax) ax.yaxis.set_tick_params(labelleft=False) ax.set_ylabel(unique_vals[a], rotation='horizontal', ha='right') ax.set_xlim(binrange) if a != (len(axes) - 1): ax.xaxis.set_tick_params(labelbottom=False) ax.set(xlabel=None) ax.set(xlabel="Tiles per slide") def patients(self) -> Dict[str, str]: """Return a list of patient IDs from this dataset.""" if self.annotations is None: raise errors.DatasetError("Annotations not loaded.") result = {} # type: Dict[str, str] pairs = list(zip( self.filtered_annotations['slide'], self.filtered_annotations['patient'] )) for slide, patient in pairs: if slide in result and result[slide] != patient: raise errors.AnnotationsError( f'Slide "{slide}" assigned to multiple patients: ' f"({patient}, {result[slide]})" ) else: if slide not in sf.util.EMPTY: result[slide] = patient return result def pt_files(self, *args, **kwargs): """Deprecated function. Please use `Dataset.get_bags()`.""" warnings.warn( "pt_files() is deprecated. Please use Dataset.get_bags()", DeprecationWarning ) return self.get_bags(*args, **kwargs) def get_bags(self, path, warn_missing=True): """Return list of all \*.pt files with slide names in this dataset. May return more than one \*.pt file for each slide. Args: path (str, list(str)): Directory(ies) to search for \*.pt files. warn_missing (bool): Raise a warning if any slides in this dataset do not have a \*.pt file. """ slides = self.slides() if isinstance(path, str): path = [path] bags = [] for p in path: if not exists(p): raise ValueError(f"Path {p} does not exist.") bags_at_path = np.array([ join(p, f) for f in os.listdir(p) if f.endswith('.pt') and path_to_name(f) in slides ]) bags.append(bags_at_path) bags = np.concatenate(bags) unique_slides_with_bags = np.unique([path_to_name(b) for b in bags]) if (len(unique_slides_with_bags) != len(slides)) and warn_missing: log.warning(f"Bags missing for {len(slides) - len(unique_slides_with_bags)} slides.") return bags def read_tfrecord_by_location( self, slide: str, loc: Tuple[int, int], decode: Optional[bool] = None ) -> Any: """Read a record from a TFRecord, indexed by location. Finds the associated TFRecord for a slide, and returns the record inside which corresponds to a given tile location. Args: slide (str): Name of slide. Will search for the slide's associated TFRecord. loc ((int, int)): ``(x, y)`` tile location. Searches the TFRecord for the tile that corresponds to this location. decode (bool): Decode the associated record, returning Tensors. Defaults to True. Returns: Unprocessed raw TFRecord bytes if ``decode=False``, otherwise a tuple containing ``(slide, image)``, where ``image`` is a uint8 Tensor. 
""" tfr = self.find_tfrecord(slide=slide) if tfr is None: raise errors.TFRecordsError( f"Could not find associated TFRecord for slide '{slide}'" ) if decode is None: decode = True else: warnings.warn( "The 'decode' argument to `Dataset.read_tfrecord_by_location` " "is deprecated and will be removed in a future version. In the " "future, all records will be decoded." ) return sf.io.get_tfrecord_by_location(tfr, loc, decode=decode) def remove_filter(self, **kwargs: Any) -> "Dataset": """Remove a specific filter from the active filters. Keyword Args: filters (list of str): Filter keys. Will remove filters with these keys. filter_blank (list of str): Will remove these headers stored in filter_blank. Returns: :class:`slideflow.Dataset`: Dataset with filter removed. """ for kwarg in kwargs: if kwarg not in ('filters', 'filter_blank'): raise ValueError(f'Unknown filtering argument {kwarg}') ret = copy.deepcopy(self) if 'filters' in kwargs: if isinstance(kwargs['filters'], str): kwargs['filters'] = [kwargs['filters']] elif not isinstance(kwargs['filters'], list): raise TypeError("'filters' must be a list.") for f in kwargs['filters']: if f not in ret._filters: raise errors.DatasetFilterError( f"Filter {f} not found in dataset (active filters:" f"{','.join(list(ret._filters.keys()))})" ) else: del ret._filters[f] if 'filter_blank' in kwargs: kwargs['filter_blank'] = sf.util.as_list(kwargs['filter_blank']) for f in kwargs['filter_blank']: if f not in ret._filter_blank: raise errors.DatasetFilterError( f"Filter_blank {f} not found in dataset (active " f"filter_blank: {','.join(ret._filter_blank)})" ) elif isinstance(ret._filter_blank, dict): del ret._filter_blank[ret._filter_blank.index(f)] return ret def rebuild_index(self) -> None: """Rebuild index files for TFRecords. Equivalent to ``Dataset.build_index(force=True)``. Args: None Returns: None """ self.build_index(force=True) def resize_tfrecords(self, tile_px: int) -> None: """Resize images in a set of TFRecords to a given pixel size. Args: tile_px (int): Target pixel size for resizing TFRecord images. """ if not sf.util.tf_available: raise NotImplementedError( "Dataset.resize_tfrecords() requires Tensorflow, which is " "not installed.") log.info(f'Resizing TFRecord tiles to ({tile_px}, {tile_px})') tfrecords_list = self.tfrecords() log.info(f'Resizing {len(tfrecords_list)} tfrecords') for tfr in tfrecords_list: sf.io.tensorflow.transform_tfrecord( tfr, tfr+'.transformed', resize=tile_px ) def rois(self) -> List[str]: """Return a list of all ROIs.""" rois_list = [] for source in self.sources: if self._roi_set(source): rois_list += glob(join(self.sources[source]['roi'], "*.csv")) else: log.warning(f"roi path not set for source {source}") slides = self.slides() return [r for r in list(set(rois_list)) if path_to_name(r) in slides] def slide_manifest( self, roi_method: str = 'auto', stride_div: int = 1, tma: bool = False, source: Optional[str] = None, low_memory: bool = False ) -> Dict[str, int]: """Return a dictionary of slide names and estimated number of tiles. Uses Otsu thresholding for background filtering, and the ROI strategy. Args: roi_method (str): Either 'inside', 'outside', 'auto', or 'ignore'. Determines how ROIs are used to extract tiles. If 'inside' or 'outside', will extract tiles in/out of an ROI, and skip a slide if an ROI is not available. If 'auto', will extract tiles inside an ROI if available, and across the whole-slide if no ROI is found. 
If 'ignore', will extract tiles across the whole-slide regardless of whether an ROI is available. Defaults to 'auto'. stride_div (int): Stride divisor for tile extraction. A stride of 1 will extract non-overlapping tiles. A stride_div of 2 will extract overlapping tiles, with a stride equal to 50% of the tile width. Defaults to 1. tma (bool): Deprecated argument. Tumor micro-arrays are read as standard slides. Defaults to False. source (str, optional): Dataset source name. Defaults to None (using all sources). low_memory (bool): Operate in low-memory mode at the cost of worse performance. Returns: Dict[str, int]: Dictionary mapping slide names to number of estimated non-background tiles in the slide. """ if tma: warnings.warn( "tma=True is deprecated and will be removed in a future " "version. Tumor micro-arrays are read as standard slides. " ) if self.tile_px is None or self.tile_um is None: raise errors.DatasetError( "tile_px and tile_um must be set to calculate a slide manifest" ) paths = self.slide_paths(source=source) pb = Progress(transient=True) read_task = pb.add_task('Reading slides...', total=len(paths)) if not low_memory: otsu_task = pb.add_task("Otsu thresholding...", total=len(paths)) pb.start() pool = mp.Pool( sf.util.num_cpu(default=16), initializer=sf.util.set_ignore_sigint ) wsi_list = [] to_remove = [] counts = [] for path in paths: try: wsi = sf.WSI( path, self.tile_px, self.tile_um, rois=self.rois(), stride_div=stride_div, roi_method=roi_method, verbose=False) if low_memory: wsi.qc('otsu') counts += [wsi.estimated_num_tiles] else: wsi_list += [wsi] pb.advance(read_task) except errors.SlideLoadError as e: log.error(f"Error reading slide {path}: {e}") to_remove += [path] for path in to_remove: paths.remove(path) pb.update(read_task, total=len(paths)) pb.update(otsu_task, total=len(paths)) if not low_memory: for count in pool.imap(_count_otsu_tiles, wsi_list): counts += [count] pb.advance(otsu_task) pb.stop() pool.close() return {path: counts[p] for p, path in enumerate(paths)} def slide_paths( self, source: Optional[str] = None, apply_filters: bool = True ) -> List[str]: """Return a list of paths to slides. Either returns a list of paths to all slides, or slides only matching dataset filters. Args: source (str, optional): Dataset source name. Defaults to None (using all sources). filter (bool, optional): Return only slide paths meeting filter criteria. If False, return all slides. Defaults to True. Returns: list(str): List of slide paths. """ if source and source not in self.sources.keys(): raise errors.DatasetError(f"Dataset {source} not found.") # Get unfiltered paths if source: if not self._slides_set(source): log.warning(f"slides path not set for source {source}") return [] else: paths = sf.util.get_slide_paths(self.sources[source]['slides']) else: paths = [] for src in self.sources: if not self._slides_set(src): log.warning(f"slides path not set for source {src}") else: paths += sf.util.get_slide_paths( self.sources[src]['slides'] ) # Remove any duplicates from shared dataset paths paths = list(set(paths)) # Filter paths if apply_filters: filtered_slides = self.slides() filtered_paths = [ p for p in paths if path_to_name(p) in filtered_slides ] return filtered_paths else: return paths def slides(self) -> List[str]: """Return a list of slide names in this dataset.""" if self.annotations is None: raise errors.AnnotationsError( "No annotations loaded; is the annotations file empty?" 
) if 'slide' not in self.annotations.columns: raise errors.AnnotationsError( f"{'slide'} not found in annotations file." ) ann = self.filtered_annotations ann = ann.loc[~ann.slide.isin(sf.util.EMPTY)] slides = ann.slide.unique().tolist() return slides def split( self, model_type: Optional[str] = None, labels: Optional[Union[Dict, str]] = None, val_strategy: str = 'fixed', splits: Optional[str] = None, val_fraction: Optional[float] = None, val_k_fold: Optional[int] = None, k_fold_iter: Optional[int] = None, site_labels: Optional[Union[str, Dict[str, str]]] = 'site', read_only: bool = False, from_wsi: bool = False, ) -> Tuple["Dataset", "Dataset"]: """Split this dataset into a training and validation dataset. If a validation split has already been prepared (e.g. K-fold iterations were already determined), the previously generated split will be used. Otherwise, create a new split and log the result in the TFRecord directory so future models may use the same split for consistency. Args: model_type (str): Either 'classification' or 'regression'. Defaults to 'classification' if ``labels`` is provided. labels (dict or str): Either a dictionary of slides: labels, or an outcome label (``str``). Used for balancing outcome labels in training and validation cohorts. Defaults to None. val_strategy (str): Either 'k-fold', 'k-fold-preserved-site', 'bootstrap', or 'fixed'. Defaults to 'fixed'. splits (str, optional): Path to JSON file containing validation splits. Defaults to None. outcome_key (str, optional): Key indicating outcome label in slide_labels_dict. Defaults to 'outcome_label'. val_fraction (float, optional): Proportion of data for validation. Not used if strategy is k-fold. Defaults to None. val_k_fold (int): K, required if using K-fold validation. Defaults to None. k_fold_iter (int, optional): Which K-fold iteration to generate starting at 1. Fequired if using K-fold validation. Defaults to None. site_labels (dict, optional): Dict mapping patients to site labels, or an outcome column with site labels. Only used for site preserved cross validation. Defaults to 'site'. read_only (bool): Prevents writing validation splits to file. Defaults to False. Returns: A tuple containing :class:`slideflow.Dataset`: Training dataset. :class:`slideflow.Dataset`: Validation dataset. 
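Example
    A minimal sketch of a fixed train/validation split; the ``'subtype'`` outcome column and the splits filename are hypothetical.

    .. code-block:: python

        train_dts, val_dts = dataset.split(
            model_type='classification',
            labels='subtype',
            val_strategy='fixed',
            val_fraction=0.2,
            splits='splits.json'  # Re-used if a compatible split exists
        )
        print(len(train_dts.tfrecords()), len(val_dts.tfrecords()))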
""" if (not k_fold_iter and 'k-fold' in val_strategy): raise errors.DatasetSplitError( "If strategy is 'k-fold', must supply k_fold_iter " "(int starting at 1)" ) if (not val_k_fold and 'k-fold' in val_strategy): raise errors.DatasetSplitError( "If strategy is 'k-fold', must supply val_k_fold (K)" ) if val_strategy == 'k-fold-preserved-site' and not site_labels: raise errors.DatasetSplitError( "k-fold-preserved-site requires site_labels (dict of " "patients:sites, or name of annotation column header" ) if (val_strategy == 'k-fold-preserved-site' and isinstance(site_labels, str)): site_labels, _ = self.labels(site_labels, format='name') if val_strategy == 'k-fold-preserved-site' and site_labels is None: raise errors.DatasetSplitError( f"Must supply site_labels for strategy {val_strategy}" ) if val_strategy in ('bootstrap', 'fixed') and val_fraction is None: raise errors.DatasetSplitError( f"Must supply val_fraction for strategy {val_strategy}" ) if isinstance(labels, str): labels = self.labels(labels)[0] if labels is None and model_type is None: labels = self.patients() model_type = 'regression' elif model_type is None: model_type = 'classification' if model_type == 'categorical': warnings.warn( "model_type='categorical' is deprecated. Please use " "'classification' instead." ) model_type = 'classification' if model_type == 'linear': warnings.warn( "model_type='linear' is deprecated. Please use " "'regression' instead." ) model_type = 'regression' if model_type not in ('classification', 'regression'): raise ValueError( f"Invalid model_type {model_type}; must be either " "'classification' or 'regression'" ) # Prepare dataset patients = self.patients() splits_file = splits training_tfr = [] val_tfr = [] accepted_split = None slide_list = list(labels.keys()) # Assemble dict of patients linking to list of slides & outcome labels # dataset.labels() ensures no duplicate labels for a single patient tfr_dir_list = self.tfrecords() if not from_wsi else self.slide_paths() skip_tfr_verification = False if not len(tfr_dir_list) and not from_wsi: log.warning("No tfrecords found; splitting from annotations only.") tfr_dir_list = tfr_dir_list_names = self.slides() skip_tfr_verification = True elif not len(tfr_dir_list): log.warning("No slides found; splitting from annotations only.") tfr_dir_list = tfr_dir_list_names = self.slides() skip_tfr_verification = True else: tfr_dir_list_names = [ sf.util.path_to_name(tfr) for tfr in tfr_dir_list ] patients_dict = {} num_warned = 0 for slide in slide_list: patient = slide if not patients else patients[slide] # Skip slides not found in directory if slide not in tfr_dir_list_names: log.debug(f"Slide {slide} missing tfrecord, skipping") num_warned += 1 continue if patient not in patients_dict: patients_dict[patient] = { 'outcome_label': labels[slide], 'slides': [slide] } elif patients_dict[patient]['outcome_label'] != labels[slide]: ol = patients_dict[patient]['outcome_label'] ok = labels[slide] raise errors.DatasetSplitError( f"Multiple labels found for {patient} ({ol}, {ok})" ) else: patients_dict[patient]['slides'] += [slide] # Add site labels to the patients dict if doing # preserved-site cross-validation if val_strategy == 'k-fold-preserved-site': assert site_labels is not None site_slide_list = list(site_labels.keys()) for slide in site_slide_list: patient = slide if not patients else patients[slide] # Skip slides not found in directory if slide not in tfr_dir_list_names: continue if 'site' not in patients_dict[patient]: patients_dict[patient]['site'] = 
site_labels[slide] elif patients_dict[patient]['site'] != site_labels[slide]: ol = patients_dict[patient]['slide'] ok = site_labels[slide] _tail = f"{patient} ({ol}, {ok})" raise errors.DatasetSplitError( f"Multiple site labels found for {_tail}" ) if num_warned: log.warning(f"{num_warned} slides missing tfrecords, skipping") patients_list = list(patients_dict.keys()) sorted_patients = [p for p in patients_list] sorted_patients.sort() shuffle(patients_list) # Create and log a validation subset if val_strategy == 'none': log.info("val_strategy is None; skipping validation") train_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in patients_dict.keys() ]).tolist() val_slides = [] elif val_strategy == 'bootstrap': assert val_fraction is not None num_val = int(val_fraction * len(patients_list)) log.info( f"Boostrap validation: selecting {num_val} " "patients at random for validation testing" ) val_patients = patients_list[0:num_val] train_patients = patients_list[num_val:] if not len(val_patients) or not len(train_patients): raise errors.InsufficientDataForSplitError val_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in val_patients ]).tolist() train_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in train_patients ]).tolist() else: # Try to load validation split if (not splits_file or not exists(splits_file)): loaded_splits = [] else: loaded_splits = sf.util.load_json(splits_file) for split_id, split in enumerate(loaded_splits): # First, see if strategy is the same if split['strategy'] != val_strategy: continue # If k-fold, check that k-fold length is the same if (val_strategy in ('k-fold', 'k-fold-preserved-site') and len(list(split['tfrecords'].keys())) != val_k_fold): continue # Then, check if patient lists are the same sp_pts = list(split['patients'].keys()) sp_pts.sort() if sp_pts == sorted_patients: # Finally, check if outcome variables are the same c1 = [patients_dict[p]['outcome_label'] for p in sp_pts] c2 = [split['patients'][p]['outcome_label']for p in sp_pts] if c1 == c2: log.info( f"Using {val_strategy} validation split detected" f" at [green]{splits_file}[/] (ID: {split_id})" ) accepted_split = split break # If no split found, create a new one if not accepted_split: if splits_file: log.info("No compatible train/val split found.") log.info(f"Logging new split at [green]{splits_file}") else: log.info("No training/validation splits file provided.") log.info("Unable to save or load validation splits.") new_split = { 'strategy': val_strategy, 'patients': patients_dict, 'tfrecords': {} } # type: Any if val_strategy == 'fixed': assert val_fraction is not None num_val = int(val_fraction * len(patients_list)) val_patients = patients_list[0:num_val] train_patients = patients_list[num_val:] if not len(val_patients) or not len(train_patients): raise errors.InsufficientDataForSplitError val_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in val_patients ]).tolist() train_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in train_patients ]).tolist() new_split['tfrecords']['validation'] = val_slides new_split['tfrecords']['training'] = train_slides elif val_strategy in ('k-fold', 'k-fold-preserved-site'): assert val_k_fold is not None if (val_strategy == 'k-fold-preserved-site'): k_fold_patients = split_patients_preserved_site( patients_dict, val_k_fold, balance=('outcome_label' if model_type == 'classification' else None) ) elif model_type == 'classification': k_fold_patients = 
split_patients_balanced( patients_dict, val_k_fold, balance='outcome_label' ) else: k_fold_patients = split_patients( patients_dict, val_k_fold ) # Verify at least one patient is in each k_fold group if (len(k_fold_patients) != val_k_fold or not min([len(pl) for pl in k_fold_patients])): raise errors.InsufficientDataForSplitError train_patients = [] for k in range(1, val_k_fold+1): new_split['tfrecords'][f'k-fold-{k}'] = np.concatenate( [patients_dict[patient]['slides'] for patient in k_fold_patients[k-1]] ).tolist() if k == k_fold_iter: val_patients = k_fold_patients[k-1] else: train_patients += k_fold_patients[k-1] val_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in val_patients ]).tolist() train_slides = np.concatenate([ patients_dict[patient]['slides'] for patient in train_patients ]).tolist() else: raise errors.DatasetSplitError( f"Unknown validation strategy {val_strategy}." ) # Write the new split to log loaded_splits += [new_split] if not read_only and splits_file: sf.util.write_json(loaded_splits, splits_file) else: # Use existing split if val_strategy == 'fixed': val_slides = accepted_split['tfrecords']['validation'] train_slides = accepted_split['tfrecords']['training'] elif val_strategy in ('k-fold', 'k-fold-preserved-site'): assert val_k_fold is not None k_id = f'k-fold-{k_fold_iter}' val_slides = accepted_split['tfrecords'][k_id] train_slides = np.concatenate([ accepted_split['tfrecords'][f'k-fold-{ki}'] for ki in range(1, val_k_fold+1) if ki != k_fold_iter ]).tolist() else: raise errors.DatasetSplitError( f"Unknown val_strategy {val_strategy} requested." ) # Perform final integrity check to ensure no patients # are in both training and validation slides if patients: validation_pt = list(set([patients[s] for s in val_slides])) training_pt = list(set([patients[s] for s in train_slides])) else: validation_pt, training_pt = val_slides, train_slides if sum([pt in training_pt for pt in validation_pt]): raise errors.DatasetSplitError( "At least one patient is in both val and training sets." ) # Assemble list of tfrecords if val_strategy != 'none': val_tfr = [ tfr for tfr in tfr_dir_list if path_to_name(tfr) in val_slides or tfr in val_slides ] training_tfr = [ tfr for tfr in tfr_dir_list if path_to_name(tfr) in train_slides or tfr in train_slides ] if not len(val_tfr) == len(val_slides): raise errors.DatasetError( f"Number of validation tfrecords ({len(val_tfr)}) does " f"not match number of validation slides ({len(val_slides)}). " "This may happen if multiple tfrecords were found for a slide." ) if not len(training_tfr) == len(train_slides): raise errors.DatasetError( f"Number of training tfrecords ({len(training_tfr)}) does " f"not match number of training slides ({len(train_slides)}). " "This may happen if multiple tfrecords were found for a slide." 
) training_dts = copy.deepcopy(self) training_dts = training_dts.filter(filters={'slide': train_slides}) val_dts = copy.deepcopy(self) val_dts = val_dts.filter(filters={'slide': val_slides}) if not skip_tfr_verification and not from_wsi: assert sorted(training_dts.tfrecords()) == sorted(training_tfr) assert sorted(val_dts.tfrecords()) == sorted(val_tfr) elif not skip_tfr_verification: assert sorted(training_dts.slide_paths()) == sorted(training_tfr) assert sorted(val_dts.slide_paths()) == sorted(val_tfr) return training_dts, val_dts def split_tfrecords_by_roi( self, destination: str, roi_filter_method: Union[str, float] = 'center' ) -> None: """Split dataset tfrecords into separate tfrecords according to ROI. Will generate two sets of tfrecords, with identical names: one with tiles inside the ROIs, one with tiles outside the ROIs. Will skip any tfrecords that are missing ROIs. Requires slides to be available. Args: destination (str): Destination path. roi_filter_method (str or float): Method of filtering tiles with ROIs. Either 'center' or float (0-1). If 'center', tiles are filtered with ROIs based on the center of the tile. If float, tiles are filtered based on the proportion of the tile inside the ROI, and ``roi_filter_method`` is interpreted as a threshold. If the proportion of a tile inside the ROI is greater than this number, the tile is included. For example, if ``roi_filter_method=0.7``, a tile that is 80% inside of an ROI will be included, and a tile that is 50% inside of an ROI will be excluded. Defaults to 'center'. Returns: None """ tfrecords = self.tfrecords() slides = {path_to_name(s): s for s in self.slide_paths()} rois = self.rois() manifest = self.manifest() if self.tile_px is None or self.tile_um is None: raise errors.DatasetError( "tile_px and tile_um must be non-zero to process TFRecords." ) for tfr in tfrecords: slidename = path_to_name(tfr) if slidename not in slides: continue try: slide = WSI( slides[slidename], self.tile_px, self.tile_um, rois=rois, roi_method='inside', roi_filter_method=roi_filter_method ) except errors.SlideLoadError as e: log.error(e) continue parser = sf.io.get_tfrecord_parser( tfr, decode_images=False, to_numpy=True ) if parser is None: log.error(f"Could not read TFRecord {tfr}; skipping") continue reader = sf.io.TFRecordDataset(tfr) if not exists(join(destination, 'inside')): os.makedirs(join(destination, 'inside')) if not exists(join(destination, 'outside')): os.makedirs(join(destination, 'outside')) in_path = join(destination, 'inside', f'{slidename}.tfrecords') out_path = join(destination, 'outside', f'{slidename}.tfrecords') inside_roi_writer = sf.io.TFRecordWriter(in_path) outside_roi_writer = sf.io.TFRecordWriter(out_path) for record in track(reader, total=manifest[tfr]['total']): parsed = parser(record) loc_x, loc_y = parsed['loc_x'], parsed['loc_y'] tile_in_roi = any([ roi.poly.contains(sg.Point(loc_x, loc_y)) for roi in slide.rois ]) # Convert from a Tensor -> Numpy array if hasattr(record, 'numpy'): record = record.numpy() if tile_in_roi: inside_roi_writer.write(record) else: outside_roi_writer.write(record) inside_roi_writer.close() outside_roi_writer.close() def summary(self) -> None: """Print a summary of this dataset.""" # Get ROI information. 
patients = self.patients() has_rois = defaultdict(bool) slides_with_roi = {} patients_with_roi = defaultdict(bool) for r in self.rois(): s = sf.util.path_to_name(r) with open(r, 'r') as f: has_rois[s] = len(f.read().split('\n')) > 2 for sp in self.slide_paths(): s = sf.util.path_to_name(sp) slides_with_roi[s] = has_rois[s] for s in self.slides(): p = patients[s] if s in slides_with_roi and slides_with_roi[s]: patients_with_roi[p] = True # Print summary. if self.annotations is not None: num_patients = len(self.annotations.patient.unique()) else: num_patients = 0 print("Overview:") table = [("Configuration file:", self._config), ("Tile size (px):", self.tile_px), ("Tile size (um):", self.tile_um), ("Slides:", len(self.slides())), ("Patients:", num_patients), ("Slides with ROIs:", len([s for s in slides_with_roi if slides_with_roi[s]])), ("Patients with ROIs:", len([p for p in patients_with_roi if patients_with_roi[p]]))] print(tabulate(table, tablefmt='fancy_outline')) print("\nFilters:") table = [("Filters:", pformat(self.filters)), ("Filter Blank:", pformat(self.filter_blank)), ("Min Tiles:", pformat(self.min_tiles))] print(tabulate(table, tablefmt='fancy_grid')) print("\nSources:") if not self.sources: print("<None>") else: for source in self.sources: print(f"\n{source}") d = self.sources[source] print(tabulate(zip(d.keys(), d.values()), tablefmt="fancy_outline")) print("\nNumber of tiles in TFRecords:", self.num_tiles) print("Annotation columns:") print("<NA>" if self.annotations is None else self.annotations.columns) def tensorflow( self, labels: Labels = None, batch_size: Optional[int] = None, from_wsi: bool = False, **kwargs: Any ) -> "tf.data.Dataset": """Return a Tensorflow Dataset object that interleaves tfrecords. The returned dataset yields a batch of (image, label) for each tile. Labels may be specified either via a dict mapping slide names to outcomes, or a parsing function which accept and image and slide name, returning a dict {'image_raw': image(tensor)} and label (int or float). Args: labels (dict or str, optional): Dict or function. If dict, must map slide names to outcome labels. If function, function must accept an image (tensor) and slide name (str), and return a dict {'image_raw': image (tensor)} and label (int or float). If not provided, all labels will be None. batch_size (int): Batch size. Keyword Args: augment (str or bool): Image augmentations to perform. Augmentations include: * ``'x'``: Random horizontal flip * ``'y'``: Random vertical flip * ``'r'``: Random 90-degree rotation * ``'j'``: Random JPEG compression (50% chance to compress with quality between 50-100) * ``'b'``: Random Gaussian blur (10% chance to blur with sigma between 0.5-2.0) * ``'n'``: Random :ref:`stain_augmentation` (requires stain normalizer) Combine letters to define augmentations, such as ``'xyrjn'``. A value of True will use ``'xyrjb'``. deterministic (bool, optional): When num_parallel_calls is specified, if this boolean is specified, it controls the order in which the transformation produces elements. If set to False, the transformation is allowed to yield elements out of order to trade determinism for performance. Defaults to False. drop_last (bool, optional): Drop the last non-full batch. Defaults to False. from_wsi (bool): Generate predictions from tiles dynamically extracted from whole-slide images, rather than TFRecords. Defaults to False (use TFRecords). incl_loc (str, optional): 'coord', 'grid', or None. 
Return (x,y) origin coordinates ('coord') for each tile center along with tile images, or the (x,y) grid coordinates for each tile. Defaults to 'coord'. incl_slidenames (bool, optional): Include slidenames as third returned variable. Defaults to False. infinite (bool, optional): Create an finite dataset. WARNING: If infinite is False && balancing is used, some tiles will be skipped. Defaults to True. img_size (int): Image width in pixels. normalizer (:class:`slideflow.norm.StainNormalizer`, optional): Normalizer to use on images. Defaults to None. num_parallel_reads (int, optional): Number of parallel reads for each TFRecordDataset. Defaults to 4. num_shards (int, optional): Shard the tfrecord datasets, used for multiprocessing datasets. Defaults to None. pool (multiprocessing.Pool): Shared multiprocessing pool. Useful if ``from_wsi=True``, for sharing a unified processing pool between dataloaders. Defaults to None. rois (list(str), optional): List of ROI paths. Only used if from_wsi=True. Defaults to None. roi_method (str, optional): Method for extracting ROIs. Only used if from_wsi=True. Defaults to 'auto'. shard_idx (int, optional): Index of the tfrecord shard to use. Defaults to None. standardize (bool, optional): Standardize images to (0,1). Defaults to True. tile_um (int, optional): Size of tiles to extract from WSI, in microns. Only used if from_wsi=True. Defaults to None. tfrecord_parser (Callable, optional): Custom parser for TFRecords. Defaults to None. transform (Callable, optional): Arbitrary transform function. Performs transformation after augmentations but before standardization. Defaults to None. **decode_kwargs (dict): Keyword arguments to pass to :func:`slideflow.io.tensorflow.decode_image`. Returns: tf.data.Dataset """ from slideflow.io.tensorflow import interleave if self.tile_px is None: raise errors.DatasetError("tile_px and tile_um must be non-zero" "to create dataloaders.") if self.prob_weights is not None and from_wsi: log.warning("Dataset balancing is disabled when `from_wsi=True`") if self._clip not in (None, {}) and from_wsi: log.warning("Dataset clipping is disabled when `from_wsi=True`") if from_wsi: tfrecords = self.slide_paths() kwargs['rois'] = self.rois() kwargs['tile_um'] = self.tile_um kwargs['from_wsi'] = True prob_weights = None clip = None else: tfrecords = self.tfrecords() prob_weights = self.prob_weights clip = self._clip if not tfrecords: raise errors.TFRecordsNotFoundError self.verify_img_format(progress=False) return interleave(paths=tfrecords, labels=labels, img_size=self.tile_px, batch_size=batch_size, prob_weights=prob_weights, clip=clip, **kwargs) def tfrecord_report( self, dest: str, normalizer: Optional["StainNormalizer"] = None ) -> None: """Create a PDF report of TFRecords. Reports include 10 example tiles per TFRecord. Report is saved in the target destination. Args: dest (str): Directory in which to save the PDF report. normalizer (`slideflow.norm.StainNormalizer`, optional): Normalizer to use on image tiles. Defaults to None. 
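Example
    A brief sketch; the destination directory is hypothetical, and the Reinhard normalizer is shown only to illustrate the ``normalizer`` argument.

    .. code-block:: python

        import slideflow as sf

        normalizer = sf.norm.autoselect('reinhard')
        dataset.tfrecord_report('/path/to/reports', normalizer=normalizer)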
""" if normalizer is not None: log.info(f'Using realtime {normalizer.method} normalization') tfrecord_list = self.tfrecords() reports = [] log.info('Generating TFRecords report...') # Get images for report for tfr in track(tfrecord_list, description='Generating report...'): dataset = sf.io.TFRecordDataset(tfr) parser = sf.io.get_tfrecord_parser( tfr, ('image_raw',), to_numpy=True, decode_images=False ) if not parser: continue sample_tiles = [] for i, record in enumerate(dataset): if i > 9: break image_raw_data = parser(record)[0] if normalizer: image_raw_data = normalizer.jpeg_to_jpeg(image_raw_data) sample_tiles += [image_raw_data] reports += [SlideReport(sample_tiles, tfr, tile_px=self.tile_px, tile_um=self.tile_um, ignore_thumb_errors=True)] # Generate and save PDF log.info('Generating PDF (this may take some time)...') pdf_report = ExtractionReport(reports, title='TFRecord Report') timestring = datetime.now().strftime('%Y%m%d-%H%M%S') if exists(dest) and isdir(dest): filename = join(dest, f'tfrecord_report-{timestring}.pdf') elif sf.util.path_to_ext(dest) == 'pdf': filename = join(dest) else: raise ValueError(f"Could not find destination directory {dest}.") pdf_report.save(filename) log.info(f'TFRecord report saved to [green]{filename}') def tfrecord_heatmap( self, tfrecord: Union[str, List[str]], tile_dict: Dict[int, float], filename: str, **kwargs ) -> None: """Create a tfrecord-based WSI heatmap. Uses a dictionary of tile values for heatmap display, and saves to the specified directory. Args: tfrecord (str or list(str)): Path(s) to tfrecord(s). tile_dict (dict): Dictionary mapping tfrecord indices to a tile-level value for display in heatmap format filename (str): Destination filename for heatmap. """ slide_paths = { sf.util.path_to_name(sp): sp for sp in self.slide_paths() } if not self.tile_px or not self.tile_um: raise errors.DatasetError( "Dataset tile_px & tile_um must be set to create TFRecords." ) for tfr in sf.util.as_list(tfrecord): name = sf.util.path_to_name(tfr) if name not in slide_paths: raise errors.SlideNotFoundError(f'Unable to find slide {name}') sf.util.tfrecord_heatmap( tfrecord=tfr, slide=slide_paths[name], tile_px=self.tile_px, tile_um=self.tile_um, tile_dict=tile_dict, filename=filename, **kwargs ) def tfrecords(self, source: Optional[str] = None) -> List[str]: """Return a list of all tfrecords. Args: source (str, optional): Only return tfrecords from this dataset source. Defaults to None (return all tfrecords in dataset). Returns: List of tfrecords paths. 
""" if source and source not in self.sources.keys(): log.error(f"Dataset {source} not found.") return [] if source is None: sources_to_search = list(self.sources.keys()) # type: List[str] else: sources_to_search = [source] tfrecords_list = [] folders_to_search = [] for source in sources_to_search: if not self._tfrecords_set(source): log.warning(f"tfrecords path not set for source {source}") continue tfrecords = self.sources[source]['tfrecords'] label = self.sources[source]['label'] if label is None: continue tfrecord_path = join(tfrecords, label) if not exists(tfrecord_path): log.debug( f"TFRecords path not found: {tfrecord_path}" ) continue folders_to_search += [tfrecord_path] for folder in folders_to_search: tfrecords_list += glob(join(folder, "*.tfrecords")) tfrecords_list = list(set(tfrecords_list)) # Filter the list by filters if self.annotations is not None: slides = self.slides() filtered_tfrecords_list = [ tfrecord for tfrecord in tfrecords_list if path_to_name(tfrecord) in slides ] filtered = filtered_tfrecords_list else: log.warning("Error filtering TFRecords, are annotations empty?") filtered = tfrecords_list # Filter by min_tiles manifest = self.manifest(filter=False) if not all([f in manifest for f in filtered]): self.update_manifest() manifest = self.manifest(filter=False) if self.min_tiles: return [ f for f in filtered if f in manifest and manifest[f]['total'] >= self.min_tiles ] else: return [f for f in filtered if f in manifest and manifest[f]['total'] > 0] def tfrecords_by_subfolder(self, subfolder: str) -> List[str]: """Return a list of all tfrecords in a specific subfolder. Ignores any dataset filters. Args: subfolder (str): Path to subfolder to check for tfrecords. Returns: List of tfrecords paths. """ tfrecords_list = [] folders_to_search = [] for source in self.sources: if self.sources[source]['label'] is None: continue if not self._tfrecords_set(source): log.warning(f"tfrecords path not set for source {source}") continue base_dir = join( self.sources[source]['tfrecords'], self.sources[source]['label'] ) tfrecord_path = join(base_dir, subfolder) if not exists(tfrecord_path): raise errors.DatasetError( f"Unable to find subfolder [bold]{subfolder}[/] in " f"source [bold]{source}[/], tfrecord directory: " f"[green]{base_dir}" ) folders_to_search += [tfrecord_path] for folder in folders_to_search: tfrecords_list += glob(join(folder, "*.tfrecords")) return tfrecords_list def tfrecords_folders(self) -> List[str]: """Return folders containing tfrecords.""" folders = [] for source in self.sources: if self.sources[source]['label'] is None: continue if not self._tfrecords_set(source): log.warning(f"tfrecords path not set for source {source}") continue folders += [join( self.sources[source]['tfrecords'], self.sources[source]['label'] )] return folders def tfrecords_from_tiles(self, delete_tiles: bool = False) -> None: """Create tfrecord files from a collection of raw images. Images must be stored in the dataset source(s) tiles directory. Args: delete_tiles (bool): Remove tiles after storing in tfrecords. Returns: None """ if not self.tile_px or not self.tile_um: raise errors.DatasetError( "Dataset tile_px & tile_um must be set to create TFRecords." 
    def tfrecords_have_locations(self) -> bool:
        """Check if TFRecords have associated tile location information."""
        for tfr in self.tfrecords():
            try:
                tfr_has_loc = sf.io.tfrecord_has_locations(tfr)
            except errors.TFRecordsError:
                # Encountered when the TFRecord is empty.
                continue
            if not tfr_has_loc:
                log.info(f"{tfr}: Tile location information missing.")
                return False
        return True

    def thumbnails(
        self,
        outdir: str,
        size: int = 512,
        roi: bool = False,
        enable_downsample: bool = True
    ) -> None:
        """Generate square slide thumbnails with black borders of fixed size.

        Saves thumbnails to the specified directory.

        Args:
            outdir (str): Directory in which to save thumbnail images.
            size (int, optional): Width/height of thumbnail in pixels.
                Defaults to 512.
            roi (bool, optional): Include ROI in the thumbnail images.
                Defaults to False.
            enable_downsample (bool, optional): If True and a thumbnail is
                not embedded in the slide file, downsampling is permitted
                to accelerate thumbnail calculation.
        """
        slide_list = self.slide_paths()
        rois = self.rois()
        log.info(f'Saving thumbnails to [green]{outdir}')
        for slide_path in tqdm(slide_list, desc="Generating thumbnails..."):
            log.debug(f'Working on [green]{path_to_name(slide_path)}[/]...')
            try:
                whole_slide = WSI(slide_path,
                                  tile_px=1000,
                                  tile_um=1000,
                                  stride_div=1,
                                  enable_downsample=enable_downsample,
                                  rois=rois,
                                  verbose=False)
            except errors.MissingROIError:
                log.info(f"Skipping {slide_path}; missing ROI")
                continue
            except Exception as e:
                log.error(
                    f"Error generating thumbnail for {slide_path}: {e}"
                )
                continue
            if roi:
                thumb = whole_slide.thumb(rois=True)
            else:
                thumb = whole_slide.square_thumb(size)
            thumb.save(join(outdir, f'{whole_slide.name}.png'))
        log.info('Thumbnail generation complete.')

    def train_val_split(
        self,
        *args: Any,
        **kwargs: Any
    ) -> Tuple["Dataset", "Dataset"]:
        """Deprecated function."""  # noqa: D401
        warnings.warn(
            "Dataset.train_val_split() is deprecated and will be "
            "removed in a future version. Please use Dataset.split()",
            DeprecationWarning
        )
        return self.split(*args, **kwargs)

    def transform_tfrecords(self, dest: str, **kwargs) -> None:
        """Transform TFRecords, saving to a target path.

        TFRecords will be saved in the output directory, nested by
        source name.

        Args:
            dest (str): Destination directory.
        """
        if not exists(dest):
            os.makedirs(dest)
        total = len(self.tfrecords())
        pb = tqdm(total=total)
        for source in self.sources:
            log.debug(f"Working on source {source}")
            tfr_dest = join(dest, source)
            if not exists(tfr_dest):
                os.makedirs(tfr_dest)
            for tfr in self.tfrecords(source=source):
                sf.io.tensorflow.transform_tfrecord(
                    tfr,
                    join(tfr_dest, basename(tfr)),
                    **kwargs
                )
                pb.update(1)
        log.info(f"Saved {total} transformed tfrecords to {dest}.")
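    # Example usage (illustrative sketch; the output directory is
    # hypothetical and assumed to exist):
    #
    #   dataset.thumbnails('/path/to/thumbnails', size=256, roi=True)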
""" if not exists(dest): os.makedirs(dest) total = len(self.tfrecords()) pb = tqdm(total=total) for source in self.sources: log.debug(f"Working on source {source}") tfr_dest = join(dest, source) if not exists(tfr_dest): os.makedirs(tfr_dest) for tfr in self.tfrecords(source=source): sf.io.tensorflow.transform_tfrecord( tfr, join(tfr_dest, basename(tfr)), **kwargs ) pb.update(1) log.info(f"Saved {total} transformed tfrecords to {dest}.") def torch( self, labels: Optional[Union[Dict[str, Any], str, pd.DataFrame]] = None, batch_size: Optional[int] = None, rebuild_index: bool = False, from_wsi: bool = False, **kwargs: Any ) -> "DataLoader": """Return a PyTorch DataLoader object that interleaves tfrecords. The returned dataloader yields a batch of (image, label) for each tile. Args: labels (dict, str, or pd.DataFrame, optional): If a dict is provided, expect a dict mapping slide names to outcome labels. If a str, will intepret as categorical annotation header. For regression tasks, or outcomes with manually assigned labels, pass the first result of dataset.labels(...). If None, returns slide instead of label. batch_size (int): Batch size. rebuild_index (bool): Re-build index files even if already present. Defaults to True. Keyword Args: augment (str or bool): Image augmentations to perform. Augmentations include: * ``'x'``: Random horizontal flip * ``'y'``: Random vertical flip * ``'r'``: Random 90-degree rotation * ``'j'``: Random JPEG compression (50% chance to compress with quality between 50-100) * ``'b'``: Random Gaussian blur (10% chance to blur with sigma between 0.5-2.0) * ``'n'``: Random :ref:`stain_augmentation` (requires stain normalizer) Combine letters to define augmentations, such as ``'xyrjn'``. A value of True will use ``'xyrjb'``. chunk_size (int, optional): Chunk size for image decoding. Defaults to 1. drop_last (bool, optional): Drop the last non-full batch. Defaults to False. from_wsi (bool): Generate predictions from tiles dynamically extracted from whole-slide images, rather than TFRecords. Defaults to False (use TFRecords). incl_loc (bool, optional): Include loc_x and loc_y (image tile center coordinates, in base / level=0 dimension) as additional returned variables. Defaults to False. incl_slidenames (bool, optional): Include slidenames as third returned variable. Defaults to False. infinite (bool, optional): Infinitely repeat data. Defaults to True. max_size (bool, optional): Unused argument present for legacy compatibility; will be removed. model_type (str, optional): Used to generate random labels (for StyleGAN2). Not required. Defaults to 'classification'. num_replicas (int, optional): Number of GPUs or unique instances which will have their own DataLoader. Used to interleave results among workers without duplications. Defaults to 1. num_workers (int, optional): Number of DataLoader workers. Defaults to 2. normalizer (:class:`slideflow.norm.StainNormalizer`, optional): Normalizer. Defaults to None. onehot (bool, optional): Onehot encode labels. Defaults to False. persistent_workers (bool, optional): Sets the DataLoader persistent_workers flag. Defaults toNone (4 if not using a SPAMS normalizer, 1 if using SPAMS). pin_memory (bool, optional): Pin memory to GPU. Defaults to True. pool (multiprocessing.Pool): Shared multiprocessing pool. Useful if from_wsi=True, for sharing a unified processing pool between dataloaders. Defaults to None. prefetch_factor (int, optional): Number of batches to prefetch in each SlideflowIterator. Defaults to 1. 
    def unclip(self) -> "Dataset":
        """Return a dataset object with all clips removed.

        Returns:
            :class:`slideflow.Dataset`: Dataset with clips removed.
        """
        ret = copy.deepcopy(self)
        ret._clip = {}
        return ret

    def update_manifest(self, force_update: bool = False) -> None:
        """Update tfrecord manifests.

        Args:
            force_update (bool, optional): Force regeneration of the
                manifests from scratch.
        """
        tfrecords_folders = self.tfrecords_folders()
        for tfr_folder in tfrecords_folders:
            sf.io.update_manifest_at_dir(
                directory=tfr_folder,
                force_update=force_update
            )

    def update_annotations_with_slidenames(
        self,
        annotations_file: str
    ) -> None:
        """Automatically associate slide names and paths in the annotations.

        Attempts to automatically associate slide names from a directory
        with patients in a given annotations file, skipping any slide names
        that are already present in the annotations file.

        Args:
            annotations_file (str): Path to annotations file.
        """
        header, _ = sf.util.read_annotations(annotations_file)
        slide_list = self.slide_paths(apply_filters=False)

        # First, load all patient names from the annotations file
        try:
            patient_index = header.index('patient')
        except ValueError:
            raise errors.AnnotationsError(
                f"Patient header {'patient'} not found in annotations."
            )
        patients = []
        pt_to_slide = {}
        with open(annotations_file) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            header = next(csv_reader)
            for row in csv_reader:
                patients.extend([row[patient_index]])
        patients = list(set(patients))
        log.debug(f"Number of patients in annotations: {len(patients)}")
        log.debug(f"Slides found: {len(slide_list)}")

        # Then, check for sets of slides that would match to the same
        # patient; due to ambiguity, these will be skipped.
        n_occur = {}
        for slide in slide_list:
            if _shortname(slide) not in n_occur:
                n_occur[_shortname(slide)] = 1
            else:
                n_occur[_shortname(slide)] += 1
        slides_to_skip = [s for s in slide_list
                          if n_occur[_shortname(s)] > 1]

        # Next, search through the slides folder for all valid slide files
        for file in slide_list:
            slide = path_to_name(file)
            # First, skip this slide due to ambiguity if needed
            if slide in slides_to_skip:
                log.warning(f"Skipping slide {slide} due to ambiguity")
            # Then, make sure the shortname and long name
            # aren't both in the annotation file
            if ((slide != _shortname(slide))
               and (slide in patients)
               and (_shortname(slide) in patients)):
                log.warning(f"Skipping slide {slide} due to ambiguity")
            # Check if either the slide name or the shortened version
            # are in the annotation file
            if any(x in patients for x in [slide, _shortname(slide)]):
                slide = slide if slide in patients else _shortname(slide)
                pt_to_slide.update({slide: slide})

        # Now, write the associations
        n_updated = 0
        n_missing = 0
        with open(annotations_file) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            header = next(csv_reader)
            with open('temp.csv', 'w') as csv_outfile:
                csv_writer = csv.writer(csv_outfile, delimiter=',')
                # Write to existing "slide" column in the annotations file,
                # otherwise create a new column
                try:
                    slide_index = header.index('slide')
                except ValueError:
                    header.extend(['slide'])
                    csv_writer.writerow(header)
                    for row in csv_reader:
                        patient = row[patient_index]
                        if patient in pt_to_slide:
                            row.extend([pt_to_slide[patient]])
                            n_updated += 1
                        else:
                            row.extend([""])
                            n_missing += 1
                        csv_writer.writerow(row)
                else:
                    csv_writer.writerow(header)
                    for row in csv_reader:
                        pt = row[patient_index]
                        # Only write column if no slide is in the annotation
                        if (pt in pt_to_slide) and (row[slide_index] == ''):
                            row[slide_index] = pt_to_slide[pt]
                            n_updated += 1
                        elif ((pt not in pt_to_slide)
                              and (row[slide_index] == '')):
                            n_missing += 1
                        csv_writer.writerow(row)
        if n_updated:
            log.info(f"Done; associated slides with {n_updated} annotations.")
            if n_missing:
                log.info(f"Slides not found for {n_missing} annotations.")
        elif n_missing:
            log.debug(f"Slides missing for {n_missing} annotations.")
        else:
            log.debug("Annotations up-to-date, no changes made.")

        # Finally, backup the old annotation file and overwrite
        # existing with the new data
        backup_file = f"{annotations_file}.backup"
        if exists(backup_file):
            os.remove(backup_file)
        assert isinstance(annotations_file, str)
        shutil.move(annotations_file, backup_file)
        shutil.move('temp.csv', annotations_file)
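    # Example usage (illustrative sketch; 'annotations.csv' is a
    # hypothetical annotations file containing a 'patient' column):
    #
    #   dataset.update_annotations_with_slidenames('annotations.csv')
    #
    # The original file is backed up to 'annotations.csv.backup' before
    # the updated annotations are written.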
    def verify_annotations_slides(self) -> None:
        """Verify that annotations are correctly loaded."""
        if self.annotations is None:
            log.warn("Annotations not loaded.")
            return

        # Verify no duplicate slide names are found
        ann = self.annotations.loc[
            self.annotations.slide.isin(self.slides())
        ]
        if not ann.slide.is_unique:
            raise errors.AnnotationsError(
                "Duplicate slide names detected in the annotation file."
            )

        # Verify that there are no tfrecords with the same name.
        # This is a problem because the tfrecord name is used to
        # identify the slide.
        tfrecords = self.tfrecords()
        if len(tfrecords):
            tfrecord_names = [sf.util.path_to_name(tfr) for tfr in tfrecords]
            if not len(set(tfrecord_names)) == len(tfrecord_names):
                duplicate_tfrs = [
                    tfr for tfr in tfrecords
                    if tfrecord_names.count(sf.util.path_to_name(tfr)) > 1
                ]
                raise errors.AnnotationsError(
                    "Multiple TFRecords with the same names detected: "
                    "{}".format(', '.join(duplicate_tfrs))
                )

        # Verify all slides in the annotation column are valid
        n_missing = len(self.annotations.loc[
            (self.annotations.slide.isin(['', ' '])
             | self.annotations.slide.isna())
        ])
        if n_missing == 1:
            log.warn("1 patient does not have a slide assigned.")
        if n_missing > 1:
            log.warn(f"{n_missing} patients do not have a slide assigned.")

    def verify_img_format(self, *, progress: bool = True) -> Optional[str]:
        """Verify that all tfrecords have the same image format (PNG/JPG).

        Returns:
            str: Image format (png or jpeg), or None if no tfrecords
            were found.
        """
        tfrecords = self.tfrecords()
        if len(tfrecords):
            with mp.Pool(sf.util.num_cpu(),
                         initializer=sf.util.set_ignore_sigint) as pool:
                img_formats = []
                mapped = pool.imap_unordered(
                    sf.io.detect_tfrecord_format,
                    tfrecords
                )
                if progress:
                    mapped = track(
                        mapped,
                        description="Verifying tfrecord formats...",
                        transient=True
                    )
                for *_, fmt in mapped:
                    if fmt is not None:
                        img_formats += [fmt]
            if len(set(img_formats)) > 1:
                log_msg = "Mismatched TFRecord image formats:\n"
                for tfr, fmt in zip(tfrecords, img_formats):
                    log_msg += f"{tfr}: {fmt}\n"
                log.error(log_msg)
                raise errors.MismatchedImageFormatsError(
                    "Mismatched TFRecord image formats detected"
                )
            if len(img_formats):
                return img_formats[0]
            else:
                return None
        else:
            return None

    def verify_slide_names(self, allow_errors: bool = False) -> bool:
        """Verify that slide names inside TFRecords match the file names.

        Args:
            allow_errors (bool): Do not raise an error if there is a
                mismatch. Defaults to False.

        Returns:
            bool: Whether all slide names inside TFRecords match the
            TFRecord file names.

        Raises:
            sf.errors.MismatchedSlideNamesError: If any slide names inside
                TFRecords do not match the TFRecord file names, and
                allow_errors=False.
        """
        tfrecords = self.tfrecords()
        if len(tfrecords):
            pb = track(
                tfrecords,
                description="Verifying tfrecord slide names...",
                transient=True
            )
            for tfr in pb:
                first_record = sf.io.get_tfrecord_by_index(tfr, 0)
                if first_record['slide'] == sf.util.path_to_name(tfr):
                    continue
                elif allow_errors:
                    return False
                else:
                    raise errors.MismatchedSlideNamesError(
                        "Mismatched slide name in TFRecord {}: expected "
                        "slide name {} based on filename, but found "
                        "{}.".format(
                            tfr,
                            sf.util.path_to_name(tfr),
                            first_record['slide']
                        )
                    )
        return True
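    # Example usage (illustrative sketch): sanity-check TFRecords before
    # building dataloaders.
    #
    #   img_format = dataset.verify_img_format()          # 'png' or 'jpeg'
    #   names_ok = dataset.verify_slide_names(allow_errors=True)
    #   has_locations = dataset.tfrecords_have_locations()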