Shortcuts

Source code for slideflow.model

'''Submodule that includes tools for intermediate layer activations.

Supports both PyTorch and Tensorflow backends, importing either model.tensorflow
or model.torch based on the environment variable SF_BACKEND.
'''

import warnings
from typing import Any, Dict, List, Optional

import slideflow as sf
from slideflow import errors
from .base import BaseFeatureExtractor
from .features import DatasetFeatures
from .extractors import (
    list_extractors, list_torch_extractors, list_tensorflow_extractors,
    is_extractor, is_torch_extractor, is_tensorflow_extractor,
    build_feature_extractor, build_torch_feature_extractor,
    build_tensorflow_feature_extractor, rebuild_extractor
)

# --- Backend-specific imports ------------------------------------------------

# Both branches import the same set of names, so the rest of this module can
# refer to Trainer, ModelParams, etc. without knowing which backend is active.
if sf.backend() == 'tensorflow':
    from slideflow.model.tensorflow import (SurvivalTrainer, Features, load, # noqa F401
                                            RegressionTrainer, ModelParams,
                                            Trainer, UncertaintyInterface)
elif sf.backend() == 'torch':
    from slideflow.model.torch import (SurvivalTrainer, Features, load, # noqa F401
                                       RegressionTrainer, ModelParams,
                                       Trainer, UncertaintyInterface)
else:
    # Any value other than 'tensorflow' or 'torch' aborts the import of this
    # module.
    raise errors.UnrecognizedBackendError

# -----------------------------------------------------------------------------


def is_tensorflow_tensor(arg: Any) -> bool:
    """Return True if ``arg`` is a ``tf.Tensor``.

    Returns False without importing Tensorflow when ``sf.util.tf_available``
    is falsy.
    """
    if not sf.util.tf_available:
        return False
    # Only reached when Tensorflow is reported as available.
    import tensorflow as tf
    return isinstance(arg, tf.Tensor)
def is_torch_tensor(arg: Any) -> bool:
    """Checks if the given object is a PyTorch Tensor.

    Args:
        arg: Object to check.

    Returns:
        bool: True if ``arg`` is an instance of ``torch.Tensor``. False if
        it is not, or if ``sf.util.torch_available`` is falsy (in which case
        PyTorch is not imported).
    """
    if sf.util.torch_available:
        # Imported inside the guard so PyTorch is only loaded when available.
        import torch
        return isinstance(arg, torch.Tensor)
    else:
        return False
def is_tensorflow_model(arg: Any) -> bool:
    """Return True if ``arg`` is a Tensorflow model or a path to one.

    Strings are treated as paths and delegated to
    ``sf.util.is_tensorflow_model_path()``. Any other object is checked
    against ``tf.keras.models.Model``, which requires Tensorflow to be
    available; otherwise the result is False.
    """
    if isinstance(arg, str):
        return sf.util.is_tensorflow_model_path(arg)
    if not sf.util.tf_available:
        return False
    # Only reached when Tensorflow is reported as available.
    import tensorflow as tf
    return isinstance(arg, tf.keras.models.Model)
def is_torch_model(arg: Any) -> bool:
    """Return True if ``arg`` is a PyTorch module or a path to a PyTorch model.

    Strings are treated as paths and delegated to
    ``sf.util.is_torch_model_path()``. Any other object is checked against
    ``torch.nn.Module``, which requires PyTorch to be available; otherwise
    the result is False.
    """
    if isinstance(arg, str):
        return sf.util.is_torch_model_path(arg)
    if not sf.util.torch_available:
        return False
    # Only reached when PyTorch is reported as available.
    import torch
    return isinstance(arg, torch.nn.Module)
def trainer_from_hp(*args, **kwargs):
    """Deprecated alias for :func:`build_trainer`.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to :func:`build_trainer`, returning its result.
    """
    warnings.warn(
        "sf.model.trainer_from_hp() is deprecated. Please use "
        "sf.model.build_trainer().",
        DeprecationWarning,
        # Attribute the warning to the caller's line rather than this
        # wrapper, so the user sees where the deprecated call is made.
        stacklevel=2
    )
    return build_trainer(*args, **kwargs)
def build_trainer(
    hp: "ModelParams",
    outdir: str,
    labels: Dict[str, Any],
    **kwargs
) -> Trainer:
    """Build the trainer matching the model type of the given parameters.

    The trainer class is chosen from ``hp.model_type()``: ``Trainer`` for
    'classification', ``RegressionTrainer`` for 'regression', and
    ``SurvivalTrainer`` for 'survival'. All keyword arguments are passed
    through to the selected trainer's constructor.

    Args:
        hp (:class:`slideflow.ModelParams`): ModelParams object.
        outdir (str): Path for event logs and checkpoints.
        labels (dict): Dict mapping slide names to outcome labels (int or
            float format).

    Keyword Args:
        slide_input (dict): Dict mapping slide names to additional
            slide-level input, concatenated after post-conv.
        name (str, optional): Optional name describing the model, used for
            model saving. Defaults to 'Trainer'.
        feature_sizes (list, optional): List of sizes of input features.
            Required if providing additional input features as input to
            the model.
        feature_names (list, optional): List of names for input features.
            Used when permuting feature importance.
        outcome_names (list, optional): Name of each outcome. Defaults to
            "Outcome {X}" for each outcome.
        mixed_precision (bool, optional): Use FP16 mixed precision (rather
            than FP32). Defaults to True.
        allow_tf32 (bool): Allow internal use of Tensorfloat-32 format.
            Defaults to False.
        config (dict, optional): Training configuration dictionary, used
            for logging. Defaults to None.
        use_neptune (bool, optional): Use Neptune API logging.
            Defaults to False.
        neptune_api (str, optional): Neptune API token, used for logging.
            Defaults to None.
        neptune_workspace (str, optional): Neptune workspace.
            Defaults to None.
        load_method (str): Either 'full' or 'weights'. Method to use
            when loading a Tensorflow model. If 'full', loads the model with
            ``tf.keras.models.load_model()``. If 'weights', will read the
            ``params.json`` configuration file, build the model architecture,
            and then load weights from the given model with
            ``Model.load_weights()``. Loading with 'full' may improve
            compatibility across Slideflow versions. Loading with 'weights'
            may improve compatibility across hardware & environments.
        custom_objects (dict, Optional): Dictionary mapping names
            (strings) to custom classes or functions. Defaults to None.
        num_workers (int): Number of dataloader workers. Only used for
            PyTorch. Defaults to 4.

    Returns:
        The constructed trainer instance.

    Raises:
        ValueError: If ``hp.model_type()`` is not one of 'classification',
            'regression', or 'survival'.
    """
    # Pick the trainer class first, then construct it in a single place.
    if hp.model_type() == 'classification':
        trainer_cls = Trainer
    elif hp.model_type() == 'regression':
        trainer_cls = RegressionTrainer
    elif hp.model_type() == 'survival':
        trainer_cls = SurvivalTrainer
    else:
        raise ValueError(f"Unknown model type: {hp.model_type()}")
    return trainer_cls(hp, outdir, labels, **kwargs)
def read_hp_sweep(
    filename: str,
    models: Optional[List[str]] = None
) -> Dict[str, "ModelParams"]:
    """Loads hyperparameter objects and their model names from a sweep file.

    The JSON file is loaded as a list of dicts. For each entry, the first
    key is taken as the model name and its value as the hyperparameter
    configuration passed to ``ModelParams.from_dict()``.

    Args:
        filename (str): Path to hyperparameter sweep JSON file.
        models (list(str), optional): List of model names to load.
            Defaults to None. If None or empty, all models in the file
            are returned.

    Returns:
        Dict mapping each selected model name to its ``ModelParams`` object.

    Raises:
        ValueError: If ``models`` is supplied but is not a list, contains
            duplicate names, or names a model not found in the file.
    """
    if models is not None and not isinstance(models, list):
        raise ValueError("If supplying models, must be list(str) "
                         "with model names.")
    # Compare sizes instead of `list(set(models)) == models`: set iteration
    # order is arbitrary, so the list comparison can fail for distinct names.
    if models is not None and len(set(models)) != len(models):
        raise ValueError("Duplicate model names provided.")

    hp_list = sf.util.load_json(filename)

    # Pair every entry with its model name (the first key of the entry).
    named_configs = [
        (list(hp_dict.keys())[0], hp_dict) for hp_dict in hp_list
    ]

    if models:
        # Ensure all requested models are present before building anything.
        available = {name for name, _ in named_configs}
        missing = [m for m in models if m not in available]
        if missing:
            raise ValueError(f"Unable to find models {', '.join(missing)}")
        selected = set(models)
    else:
        # No filter requested: every model in the file is loaded.
        selected = None

    # Build a ModelParams object for each selected configuration. If a name
    # appears more than once in the file, the last entry wins.
    return {
        name: ModelParams.from_dict(hp_dict[name])
        for name, hp_dict in named_configs
        if selected is None or name in selected
    }  # type: ignore