# Source code for slideflow.model.extractors._factory
"""Factory for building feature extractors."""
import importlib
import slideflow as sf
from os.path import join, exists
from typing import Optional, Tuple, Dict, Any, TYPE_CHECKING
from slideflow import errors
from slideflow.model import BaseFeatureExtractor
from ._registry import (is_tensorflow_extractor, is_torch_extractor,
_tf_extractors, _torch_extractors, _extras_extractors)
from ._factory_tensorflow import build_tensorflow_feature_extractor
from ._factory_torch import build_torch_feature_extractor
if TYPE_CHECKING:
from slideflow.norm import StainNormalizer
# -----------------------------------------------------------------------------
def build_feature_extractor(
    name: str,
    backend: Optional[str] = None,
    **kwargs
) -> BaseFeatureExtractor:
    """Build a feature extractor.

    The returned feature extractor is a callable object, which returns
    features (often layer activations) for either a batch of images or a
    :class:`slideflow.WSI` object.

    If generating features for a batch of images, images are expected to be in
    (B, W, H, C) format and non-standardized (scaled 0-255) with dtype uint8.
    The feature extractors perform all needed preprocessing on the fly.

    If generating features for a slide, the slide is expected to be a
    :class:`slideflow.WSI` object. The feature extractor will generate features
    for each tile in the slide, returning a numpy array of shape (W, H, F),
    where F is the number of features.

    Args:
        name (str): Name of the feature extractor to build, or a path to a
            finetuned Tensorflow/PyTorch model. Available feature extractors
            are listed with :func:`slideflow.model.list_extractors()`.
        backend (str, optional): Backend to build the extractor with, either
            ``'tensorflow'`` or ``'torch'``. If None, the backend is chosen
            automatically: an extractor registered for both backends uses
            the active backend (``sf.backend()``), otherwise the only
            backend that provides it is used. This argument is validated but
            otherwise not consulted when ``name`` is a model path.
            Defaults to None.

    Keyword arguments:
        tile_px (int): Tile size (input image size), in pixels.
        **kwargs (Any): All remaining keyword arguments are passed
            to the feature extractor factory function, and may be different
            for each extractor.

    Returns:
        A callable object which accepts a batch of images (B, W, H, C) of dtype
        uint8 and returns a batch of features (dtype float32).

    Raises:
        ValueError: If ``backend`` is neither None, ``'tensorflow'``, nor
            ``'torch'``.
        errors.InvalidFeatureExtractor: If the extractor is not recognized,
            is not available in the requested backend, or requires an
            optional package that is not installed.

    Examples
        Create an extractor that calculates post-convolutional layer activations
        from an imagenet-pretrained Resnet50 model.

            .. code-block:: python

                import slideflow as sf

                extractor = sf.build_feature_extractor(
                    'resnet50_imagenet'
                )

        Create an extractor that calculates 'conv4_block4_2_relu' activations
        from an imagenet-pretrained Resnet50 model.

            .. code-block:: python

                extractor = sf.build_feature_extractor(
                    'resnet50_imagenet',
                    layers='conv4_block4_2_relu'
                )

        Create a pretrained "CTransPath" extractor.

            .. code-block:: python

                extractor = sf.build_feature_extractor('ctranspath')

        Use an extractor to calculate layer activations for an entire dataset.

            .. code-block:: python

                import slideflow as sf

                # Load a project and dataset
                P = sf.load_project(...)
                dataset = P.dataset(...)

                # Create a feature extractor
                resnet = sf.build_feature_extractor(
                    'resnet50_imagenet'
                )

                # Calculate features for the entire dataset
                features = sf.DatasetFeatures(
                    resnet,
                    dataset=dataset
                )

        Generate a map of features across a slide.

            .. code-block:: python

                import slideflow as sf

                # Load a slide
                wsi = sf.WSI(...)

                # Create a feature extractor
                retccl = sf.build_feature_extractor(
                    'retccl',
                    resize=True
                )

                # Create a feature map, a 2D array of shape
                # (W, H, F), where F is the number of features.
                features = retccl(wsi)

    """
    # Reject unknown backends up front, before any model is loaded.
    if backend is not None and backend not in ('tensorflow', 'torch'):
        raise ValueError(f"Invalid backend: {backend}")

    # Build a feature extractor from a finetuned model. The framework is
    # determined by the saved model itself, not by the `backend` argument.
    is_tf_model = sf.util.is_tensorflow_model_path(name)
    if is_tf_model or sf.util.is_torch_model_path(name):
        model_config = sf.util.get_model_config(name)
        # Imported lazily so that only the framework in use is required.
        if is_tf_model:
            from slideflow.model.tensorflow import (
                Features, UncertaintyInterface
            )
        else:
            from slideflow.model.torch import Features, UncertaintyInterface
        # A config without a 'uq' entry is treated as a model trained
        # without uncertainty quantification.
        if model_config['hp'].get('uq', False):
            return UncertaintyInterface(name, **kwargs)
        return Features(name, **kwargs)

    # Build feature extractor with a specific backend
    if backend == 'tensorflow':
        if not is_tensorflow_extractor(name):
            raise errors.InvalidFeatureExtractor(
                f"Feature extractor {name} not available in Tensorflow backend")
        return build_tensorflow_feature_extractor(name, **kwargs)
    if backend == 'torch':
        if not is_torch_extractor(name):
            raise errors.InvalidFeatureExtractor(
                f"Feature extractor {name} not available in PyTorch backend")
        return build_torch_feature_extractor(name, **kwargs)

    # Auto-build feature extractor according to available backends
    in_tensorflow = is_tensorflow_extractor(name)
    in_torch = is_torch_extractor(name)
    if in_tensorflow and in_torch:
        sf.log.info(
            f"Feature extractor {name} available in both Tensorflow and "
            f"PyTorch backends; using active backend {sf.backend()}")
        if sf.backend() == 'tensorflow':
            return build_tensorflow_feature_extractor(name, **kwargs)
        return build_torch_feature_extractor(name, **kwargs)
    if in_tensorflow:
        return build_tensorflow_feature_extractor(name, **kwargs)
    if in_torch:
        return build_torch_feature_extractor(name, **kwargs)

    # Known extractor whose optional dependency is missing: tell the user
    # which package to install rather than reporting it as unrecognized.
    if name in _extras_extractors:
        raise errors.InvalidFeatureExtractor(
            "{} requires the package {}, please install with 'pip install {}'".format(
                name, _extras_extractors[name], _extras_extractors[name]
            ))
    raise errors.InvalidFeatureExtractor(f"Unrecognized feature extractor: {name}")
def rebuild_extractor(
    bags_or_model: str,
    allow_errors: bool = False,
    native_normalizer: bool = True
) -> Tuple[Optional["BaseFeatureExtractor"], Optional["StainNormalizer"]]:
    """Recreate the extractor used to generate features stored in bags.

    Args:
        bags_or_model (str): Either a path to directory containing feature bags,
            a path to a ``bags_config.json`` file, or a path to a trained MIL
            model. If a path to a trained MIL model, the extractor used to
            generate features will be recreated.
        allow_errors (bool): If True, return ``(None, None)`` if the extractor
            cannot be rebuilt. If False, raise an error. Defaults to False.
        native_normalizer (bool, optional): Whether to use PyTorch/Tensorflow-native
            stain normalization, if applicable. If False, will use the OpenCV/Numpy
            implementations. Defaults to True.

    Returns:
        Optional[BaseFeatureExtractor]: Extractor function, or None if ``allow_errors`` is
            True and the extractor cannot be rebuilt.

        Optional[StainNormalizer]: Stain normalizer used when generating
            feature bags, or None if no stain normalization was used.

    Raises:
        ValueError: If the configuration cannot be found, is incomplete, or
            the extractor cannot be instantiated, and ``allow_errors`` is
            False.
        errors.InvalidFeatureExtractor: If the extractor could not be
            instantiated and its submodule name is registered as requiring
            an optional package. Raised regardless of ``allow_errors``.

    """
    def _unavailable(message: str) -> Tuple[None, None]:
        """Return ``(None, None)`` if errors are allowed, else raise."""
        if allow_errors:
            return None, None
        raise ValueError(message)

    # Work out which kind of path was supplied.
    is_bag_config = bags_or_model.endswith('bags_config.json')
    is_bag_dir = exists(join(bags_or_model, 'bags_config.json'))
    is_model_dir = exists(join(bags_or_model, 'mil_params.json'))
    if not (is_bag_dir or is_model_dir or is_bag_config):
        return _unavailable(
            'Could not find bags or MIL model configuration at '
            f'{bags_or_model}.'
        )

    # Load bags configuration
    if is_bag_config:
        bags_config = sf.util.load_json(bags_or_model)
    elif is_model_dir:
        mil_config = sf.util.load_json(join(bags_or_model, 'mil_params.json'))
        if 'bags_extractor' not in mil_config:
            return _unavailable(
                'Could not rebuild extractor from configuration at '
                f'{bags_or_model}; missing "bags_extractor" key in '
                'mil_params.json.'
            )
        bags_config = mil_config['bags_extractor']
    else:
        bags_config = sf.util.load_json(join(bags_or_model, 'bags_config.json'))

    if ('extractor' not in bags_config
       or any(n not in bags_config['extractor'] for n in ['class', 'kwargs'])):
        return _unavailable(
            'Could not rebuild extractor from configuration at '
            f'{bags_or_model}; missing "extractor" class or kwargs.'
        )

    # Rebuild extractor from its fully qualified class name and saved kwargs
    extractor_name = bags_config['extractor']['class'].split('.')
    extractor_class = extractor_name[-1]
    extractor_kwargs = bags_config['extractor']['kwargs']
    try:
        module = importlib.import_module('.'.join(extractor_name[:-1]))
        extractor = getattr(module, extractor_class)(**extractor_kwargs)
    except Exception as err:
        # The second-to-last path component is the submodule name, which is
        # what the extras registry is keyed on. Guard against a class path
        # with no module component so this handler cannot itself fail.
        submodule_name = extractor_name[-2] if len(extractor_name) > 1 else None
        if submodule_name in _extras_extractors:
            raise errors.InvalidFeatureExtractor(
                "{} requires the package {}, please install with 'pip install {}'".format(
                    submodule_name,
                    _extras_extractors[submodule_name],
                    _extras_extractors[submodule_name]
                )) from err
        if allow_errors:
            # Always hand back a 2-tuple so callers can unpack the result.
            return None, None
        raise ValueError(
            f'Could not rebuild extractor from configuration at {bags_or_model}.'
        ) from err

    # Rebuild stain normalizer. A missing 'normalizer' entry is treated the
    # same as an explicit null (no stain normalization was used).
    normalizer_config = bags_config.get('normalizer')
    if normalizer_config is not None:
        normalizer = sf.norm.autoselect(
            normalizer_config['method'],
            backend=(extractor.backend if native_normalizer else 'opencv')
        )
        normalizer.set_fit(**normalizer_config['fit'])
    else:
        normalizer = None

    # Reconcile with any normalizer the extractor already carries.
    extractor_normalizer = getattr(extractor, 'normalizer', None)
    if extractor_normalizer is not None and normalizer is not None:
        sf.log.warning(
            'Extractor already has a stain normalizer. Overwriting with '
            'normalizer from bags configuration.'
        )
        extractor.normalizer = normalizer
    elif extractor_normalizer is not None:
        # No normalizer in the bags configuration; report the extractor's own.
        normalizer = extractor_normalizer

    return extractor, normalizer
# -----------------------------------------------------------------------------
def extractor_to_config(extractor: BaseFeatureExtractor) -> Dict[str, Any]:
    """Serialize a feature extractor into a configuration dictionary.

    The resulting dictionary can be passed to ``build_extractor_from_cfg()``
    to reconstruct the feature extractor.

    Args:
        extractor (BaseFeatureExtractor): Feature extractor.

    Returns:
        Dict[str, Any]: Configuration dictionary, as produced by the
        extractor's ``dump_config()``. For extractors whose ``backend`` is
        ``'torch'``, the ``mixed_precision`` and ``channels_last`` attributes
        are added as well.

    """
    config = extractor.dump_config()
    if extractor.backend != 'torch':
        return config
    # Torch-only runtime flags are stored next to 'class' and 'kwargs'.
    for attr in ('mixed_precision', 'channels_last'):
        config[attr] = getattr(extractor, attr)
    return config
def build_extractor_from_cfg(
    cfg: Dict[str, Any],
    **kwargs: Any
) -> BaseFeatureExtractor:
    """Rebuild an extractor from a configuration dictionary.

    Args:
        cfg (Dict[str, Any]): Configuration dictionary. Must contain
            ``'class'`` (the dotted import path of the extractor class) and
            ``'kwargs'`` (keyword arguments for its constructor). Every other
            entry is set as an attribute on the rebuilt extractor.
        **kwargs (Any): All remaining keyword arguments are passed
            to the extractor class constructor, in addition to those stored
            in ``cfg['kwargs']``, and may be different for each extractor.

    Returns:
        BaseFeatureExtractor: The rebuilt feature extractor.

    """
    # Split "package.module.ClassName" into module path and class name.
    module_path, _, class_name = cfg['class'].rpartition('.')
    init_kwargs = cfg['kwargs']
    extractor_type = getattr(importlib.import_module(module_path), class_name)
    extractor = extractor_type(**init_kwargs, **kwargs)

    # Restore the remaining entries as plain attributes.
    for key, value in cfg.items():
        if key in ('class', 'kwargs'):
            continue
        setattr(extractor, key, value)
    return extractor