import shutil
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import numpy as np
from skmisc.loess import loess
from scipy import stats
from tqdm import tqdm
from statistics import mean
from os.path import join, exists
import slideflow as sf
from slideflow.util import log
from . import utils, threshold
from . import hp as biscuit_hp
from .errors import MatchError, ModelNotFoundError, ThresholdError
# -----------------------------------------------------------------------------
ALL_EXP = {
'AA': 'full',
'U': 800,
'T': 700,
'S': 600,
'R': 500,
'A': 400,
'L': 350,
'M': 300,
'N': 250,
'D': 200,
'O': 176,
'P': 150,
'Q': 126,
'G': 100,
'V': 90,
'W': 80,
'X': 70,
'Y': 60,
'Z': 50,
'ZA': 40,
'ZB': 30,
'ZC': 20,
'ZD': 10
}
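# Note: each key of ALL_EXP is an experiment ID and each value is the total
# slide count for that sample-size experiment ('full' = the entire dataset).
# For example, config('{}', subset=['G'], ratio=1) splits ALL_EXP['G'] == 100
# slides evenly, yielding {'G': {'out1': 50, 'out2': 50}}.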
# -----------------------------------------------------------------------------
class Experiment:
def __init__(
self,
train_project,
eval_projects=None,
outcome='cohort',
outcome1='LUAD',
outcome2='LUSC',
outdir='results'
):
"""Supervises uncertainty thresholding experiments."""
if eval_projects is None:
eval_projects = []
if isinstance(train_project, str):
self.train_project = sf.Project(train_project)
elif isinstance(train_project, sf.Project):
self.train_project = train_project
else:
raise ValueError(f"Unrecognized value for train_project: {train_project}")
self.eval_projects = []
for ep in eval_projects:
if isinstance(ep, str):
self.eval_projects += [sf.Project(ep)]
elif isinstance(ep, sf.Project):
self.eval_projects += [ep]
else:
raise ValueError(f"Unrecognized value for eval_project: {eval_projects}")
self.outcome = outcome
self.outcome1 = outcome1
self.outcome2 = outcome2
self.outdir = outdir
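# Example usage (paths are hypothetical): supervise a training project
# and one external evaluation project.
#
#   experiment = Experiment(
#       train_project='/data/projects/TCGA_lung',
#       eval_projects=['/data/projects/CPTAC_lung'],
#       outcome='cohort', outcome1='LUAD', outcome2='LUSC',
#       outdir='results')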
def add(self, path, label, out1, out2, order='f', order_col='order', gan=0):
"""Adds a sample size experiment to the given project annotations file.
Args:
path (str): Path to project annotations file.
label (str): Experimental label.
out1 (int): Number of lung adenocarcinomas (LUAD) to include in the
experiment.
out2 (int): Number of lung squamous cell carcinomas (LUSC) to include
in the experiment.
order (str, optional): 'f' (forward) or 'r' (reverse). Indicates which
direction to follow when sequentially adding slides.
Defaults to 'f'.
order_col (str, optional): Annotation header column to use when
sequentially adding slides. Defaults to 'order'.
gan (float, optional): Fraction (0 to 1) of the included slides to
draw from GAN-generated data. Defaults to 0.
Returns:
None
"""
assert isinstance(out1, int)
assert isinstance(out2, int)
assert isinstance(gan, (int, float)) and 0 <= gan < 1
assert order in ('f', 'r')
ann = pd.read_csv(path, dtype=str)
print(f"Setting up exp. {label} with order {order} (sort by {order_col})")
ann[order_col] = pd.to_numeric(ann[order_col])
ann.sort_values(
['gan', self.outcome, order_col],
ascending=[True, True, (order != 'r')],
inplace=True
)
gan_out1 = round(gan * out1)
gan_out2 = round(gan * out2)
out1_indices = np.where((ann['site'].to_numpy() != 'GAN')
& (ann[self.outcome] == self.outcome1))[0]
out2_indices = np.where((ann['site'].to_numpy() != 'GAN')
& (ann[self.outcome] == self.outcome2))[0]
gan_out1_indices = np.where((ann['site'].to_numpy() == 'GAN')
& (ann[self.outcome] == self.outcome1))[0]
gan_out2_indices = np.where((ann['site'].to_numpy() == 'GAN')
& (ann[self.outcome] == self.outcome2))[0]
assert out1 <= out1_indices.shape[0]
assert out2 <= out2_indices.shape[0]
assert gan_out1 <= gan_out1_indices.shape[0]
assert gan_out2 <= gan_out2_indices.shape[0]
include = np.array(['exclude' for _ in range(len(ann))])
include[out1_indices[:out1]] = 'include'
include[out2_indices[:out2]] = 'include'
include[gan_out1_indices[:gan_out1]] = 'include'
include[gan_out2_indices[:gan_out2]] = 'include'
ann[f'include_{label}'] = include
ann.to_csv(path, index=False)
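# Example (hypothetical annotations file with 'site', 'gan', and 'order'
# columns): mark 50 LUAD and 50 LUSC slides for inclusion under label 'G',
# sampling in reverse order with 10% of slides drawn from GAN data.
#
#   experiment.add('experiments.csv', label='G', out1=50, out2=50,
#                  order='r', gan=0.1)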
@staticmethod
def config(name_pattern, subset, ratio, **kwargs):
"""Configures a set of experiments.
Args:
name_pattern (str): String pattern for experiment naming.
subset (list(str)): List of experiment IDs/labels to include.
ratio (float): Ratio of outcome counts, n_out1/n_out2 (or its
inverse). Must be >= 1. Ratios > 1 generate both the forward
experiment and its inverse (label suffix 'i').
**kwargs: Additional keyword arguments stored in each experiment's
configuration.
"""
if not isinstance(ratio, (int, float)) or ratio < 1:
raise ValueError("Invalid ratio; must be a number >= 1")
config = {}
for exp in ALL_EXP:
if exp not in subset:
continue
if exp == 'AA' and ratio != 1:
raise ValueError("Cannot create full dataset exp. with ratio != 1")
exp_name = name_pattern.format(exp)
if ratio != 1:
n1 = round(ALL_EXP[exp] / (1 + (1/ratio)))
n2 = ALL_EXP[exp] - n1
config.update({
exp_name: {'out1': n1, 'out2': n2, **kwargs},
exp_name+'i': {'out1': n2, 'out2': n1, **kwargs}
})
else:
if ALL_EXP[exp] == 'full':
n_out1 = 467
n_out2 = 474
else:
n_out1 = n_out2 = int(ALL_EXP[exp] / 2)
config.update({
exp_name: {'out1': n_out1, 'out2': n_out2, **kwargs},
})
return config
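# Example: configure experiments 'G' (100 slides) and 'P' (150 slides) at
# a 3:1 outcome ratio. Each ratio != 1 yields a forward and an inverse
# ('i') experiment; e.g. 'G' is split as out1=75/out2=25, 'Gi' as the
# reverse.
#
#   cfg = Experiment.config('{}', subset=['G', 'P'], ratio=3)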
def display(self, df, eval_dfs, hue='uq', palette='tab10', relplot_uq_compare=True,
boxplot_uq_compare=True, ttest_uq_groups=['all', 'include'],
prefix=''):
"""Creates plots from assmebled results, exports results to CSV.
Args:
df (pandas.DataFrame): Cross-validation results metrics, as generated
by results()
eval_dfs (dict(pandas.DataFrame)): Dict of external eval dataset names
(keys) mapped to pandas DataFrame of result metrics (values).
hue (str, optional): Comparison to show with different hue on plots.
Defaults to 'uq'.
palette (str, optional): Seaborn color palette. Defaults to 'tab10'.
relplot_uq_compare (bool, optional): For the Relplot display, ensure
non-UQ and UQ results are generated from the same models/preds.
boxplot_uq_compare (bool, optional): For the boxplot display, ensure
non-UQ and UQ results are generated from the same models/preds.
ttest_uq_groups (list(str)): UQ groups to compare via t-test. Defaults
to ['all', 'include'].
prefix (str, optional): Prefix to use when saving figures.
Defaults to empty string.
Returns:
None
"""
if not len(df):
log.error("No results to display")
return
# Filter out UQ results if n_slides < 100
df = df.loc[~ ((df['n_slides'] < 100)
& (df['uq'].isin(['include', 'exclude'])))]
# --- Paired t-tests ---------------------------------------------------
if ttest_uq_groups and len(ttest_uq_groups) != 2:
raise ValueError("Length of ttest_uq_groups must be exactly 2")
ttest_df = df.loc[df['uq'].isin(ttest_uq_groups)].copy()
ttest_df = ttest_df.sort_values(['id', 'fold'])
def perform_paired_testing(level):
print(f"Paired t-tests ({level}-level):")
for n in sorted(ttest_df['n_slides'].unique()):
exp_df = ttest_df[ttest_df['n_slides'] == n]
try:
ttest_result = stats.ttest_rel(
exp_df.loc[exp_df['uq'] == ttest_uq_groups[0]][f'{level}_auc'],
exp_df.loc[exp_df['uq'] == ttest_uq_groups[1]][f'{level}_auc'],
alternative='less')
print(n, '\t', 'p =', ttest_result.pvalue)
except ValueError:
print(n, '\t', 'p = (error)')
perform_paired_testing('patient')
perform_paired_testing('slide')
# --- Cross-validation plots -------------------------------------------
if len(df):
# AUC (relplot)
if relplot_uq_compare:
rel_df = df.loc[df['uq'] != 'none']
else:
rel_df = df
sns.relplot(
x='n_slides',
y='slide_auc',
data=rel_df,
hue=hue,
marker='o',
kind='line',
palette=palette
)
plt.title('Cross-val AUC')
ax = plt.gca()
ax.set_ylim([0.5, 1])
ax.grid(visible=True, which='both', axis='both', color='white')
ax.set_facecolor('#EAEAF2')
ax.xaxis.set_minor_locator(plticker.MultipleLocator(100))
plt.subplots_adjust(top=0.9)
plt.savefig(join(self.outdir, f'{prefix}relplot.svg'))
f, axes = plt.subplots(1, 3)
f.set_size_inches(18, 6)
# AUC boxplot
if boxplot_uq_compare:
box_df = df.loc[df['uq'] != 'none']
else:
box_df = df
sns.boxplot(
x='n_slides',
y='slide_auc',
hue=hue,
data=box_df,
ax=axes[0],
palette=palette
)
axes[0].title.set_text('Cross-val AUC')
axes[0].set_ylabel('')
axes[0].tick_params(labelrotation=90)
# AUC scatter - LOESS & standard error
df = df.sort_values(by=['n_slides'])
x = df['n_slides'].to_numpy().astype(np.float32)
y = df['slide_auc'].to_numpy()
lo = loess(x, y)
try:
lo.fit()
pred = lo.predict(x, stderror=True)
conf = pred.confidence()
z = pred.values
ll = conf.lower
ul = conf.upper
axes[1].plot(x, y, '+', ms=6)
axes[1].plot(x, z)
axes[1].fill_between(x, ll, ul, alpha=.33)
except ValueError:
pass
axes[1].xaxis.set_minor_locator(plticker.MultipleLocator(20))
axes[1].spines['bottom'].set_linewidth(0.5)
axes[1].spines['bottom'].set_color('black')
axes[1].tick_params(axis='x', colors='black')
axes[1].grid(visible=True, which='both', axis='both', color='white')
axes[1].set_facecolor('#EAEAF2')
axes[1].set_xscale('log')
axes[1].title.set_text('Cross-val AUC')
# % slides included
sns.lineplot(
x='n_slides',
y='patient_uq_perc',
data=df,
marker='o',
ax=axes[2],
zorder=3
)
axes[2].set_ylabel('')
axes[2].title.set_text('% Patients Included with UQ (cross-val)')
axes[2].xaxis.set_minor_locator(plticker.MultipleLocator(100))
axes[2].tick_params(labelrotation=90)
axes[2].grid(visible=True, which='both', axis='both', color='white', zorder=0)
axes[2].set_facecolor('#EAEAF2')
axes[2].set_xlim(left=100)
medians = df.groupby('n_slides', as_index=False).median()
axes[2].scatter(
x=medians.n_slides.values,
y=medians.patient_uq_perc.values,
marker='x',
zorder=5
)
plt.subplots_adjust(bottom=0.2)
plt.savefig(join(self.outdir, f'{prefix}crossval.svg'))
# --- Evaluation plots ----------------------------------------------------
if eval_dfs:
for eval_name, eval_df in eval_dfs.items():
if not len(eval_df):
continue
has_uq = len(eval_df.loc[eval_df['uq'].isin(['include', 'exclude'])])
# Prepare figure
sns.set(rc={"xtick.bottom": True, "ytick.left": True})
f, axes = plt.subplots(1, (4 if has_uq else 3))
f.suptitle(f'{eval_name} Evaluation Dataset')
f.set_size_inches(16, 4)
# AUC
eval_df = eval_df.loc[~ ((eval_df['n_slides'] < 100)
& (eval_df['uq'].isin(['include', 'exclude'])))]
sns.lineplot(
x='n_slides',
y='patient_auc',
hue=hue,
data=eval_df,
marker="o",
ax=axes[0]
)
sns.scatterplot(
x='n_slides',
y='slide_auc',
hue=hue,
data=eval_df,
marker="x",
ax=axes[0]
)
axes[0].get_legend().remove()
axes[0].title.set_text('AUC')
# Accuracy
sns.lineplot(
x='n_slides',
y='patient_acc',
hue=hue,
data=eval_df,
marker="o",
ax=axes[1]
)
sns.scatterplot(
x='n_slides',
y='slide_acc',
hue=hue,
data=eval_df,
marker="x",
ax=axes[1]
)
axes[1].get_legend().remove()
axes[1].title.set_text('Accuracy')
# Youden's index
sns.lineplot(
x='n_slides',
y='patient_youden',
hue=hue,
data=eval_df,
marker="o",
ax=axes[2]
)
sns.scatterplot(
x='n_slides',
y='slide_youden',
hue=hue,
data=eval_df,
marker="x",
ax=axes[2]
)
axes[2].title.set_text("Youden's J")
axes[2].get_legend().remove()
# % slides included
if has_uq:
sns.lineplot(
x='n_slides',
y='patient_incl',
data=eval_df.loc[eval_df['uq'] == 'include'],
marker='o',
ax=axes[3]
)
sns.scatterplot(
x='n_slides',
y='slide_incl',
data=eval_df.loc[eval_df['uq'] == 'include'],
marker='x',
ax=axes[3]
)
axes[3].title.set_text('% Included')
for ax in axes:
ax.set_ylabel('')
ax.xaxis.set_major_locator(plticker.MultipleLocator(base=100))
ax.tick_params(labelrotation=90)
plt.subplots_adjust(top=0.8)
plt.subplots_adjust(bottom=0.2)
plt.savefig(join(self.outdir, f'{prefix}eval.svg'))
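# Example: render summary figures from metrics previously assembled by
# results(), comparing UQ groups by hue.
#
#   experiment.display(df, eval_dfs, hue='uq', palette='tab10',
#                      ttest_uq_groups=['all', 'include'], prefix='lung_')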
def plot_uq_calibration(self, label, tile_uq, slide_uq, slide_pred, epoch=1):
"""Plots a graph of predictions vs. uncertainty.
Args:
label (str): Experiment label.
tile_uq (float): Tile-level uncertainty threshold.
slide_uq (float): Slide-level uncertainty threshold.
slide_pred (float): Slide-level prediction threshold.
epoch (int, optional): Epoch from which to load validation
predictions. Defaults to 1.
Returns:
None
"""
val_dfs = [
pd.read_csv(
join(
utils.find_model(self.train_project, label, kfold=k, outcome=self.outcome),
f'tile_predictions_val_epoch{epoch}.csv'),
dtype={'slide': str})
for k in range(1, 4)
]
for v_df in val_dfs:
utils.rename_cols(v_df, outcome=self.outcome)
_df = pd.concat(val_dfs, axis=0, join='outer', ignore_index=True)
# Plot tile-level uncertainty
patients = self.train_project.dataset().patients()
_df, _ = threshold.process_tile_predictions(_df, patients=patients)
threshold.plot_uncertainty(
_df,
kind='tile',
threshold=tile_uq,
title=f'CV UQ Calibration: {label}'
)
# Plot slide-level uncertainty
_df = _df[_df['uncertainty'] < tile_uq]
_s_df, _ = threshold.process_group_predictions(
_df,
pred_thresh=slide_pred,
level='slide'
)
threshold.plot_uncertainty(
_s_df,
kind='slide',
threshold=slide_uq,
title=f'CV UQ Calibration: {label}'
)
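# Example (threshold values are illustrative only): plot cross-validated
# uncertainty calibration for the full-dataset UQ experiment.
#
#   experiment.plot_uq_calibration('EXP_AA_UQ', tile_uq=0.033,
#                                  slide_uq=0.021, slide_pred=0.58)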
def results(self, exp_to_run, uq=True, eval=True, plot=False):
"""Assembles results from experiments, applies UQ thresholding,
and returns pandas dataframes with metrics.
Args:
exp_to_run (list): List of experiment IDs to search for results.
uq (bool, optional): Apply UQ thresholds. Defaults to True.
eval (bool, optional): Calculate results of external evaluation models.
Defaults to True.
plot (bool, optional): Show plots. Defaults to False.
Returns:
pandas.DataFrame: Cross-validation results.
pandas.DataFrame: External evaluation results.
"""
# === Initialize projects & prepare experiments ===========================
P = self.train_project
eval_Ps = self.eval_projects
df = pd.DataFrame()
eval_dfs = {val_P.name: pd.DataFrame() for val_P in eval_Ps}
prediction_thresholds = {}
slide_uq_thresholds = {}
tile_uq_thresholds = {}
pred_uq_thresholds = {}
# === Show results from designated epoch ==================================
for exp in exp_to_run:
try:
models = utils.find_cv(P, f'EXP_{exp}', outcome=self.outcome)
except MatchError:
log.debug(f"Unable to find cross-val results for {exp}; skipping")
continue
for i, m in enumerate(models):
try:
results = utils.get_model_results(m, outcome=self.outcome, epoch=1)
except FileNotFoundError:
print(f"Unable to open cross-val results for {exp}; skipping")
continue
m_slides = sf.util.get_slides_from_model_manifest(m, dataset=None)
df = pd.concat([df, pd.DataFrame([{
'id': exp,
'n_slides': len(m_slides),
'fold': i+1,
'uq': 'none',
'patient_auc': results['pt_auc'],
'patient_ap': results['pt_ap'],
'slide_auc': results['slide_auc'],
'slide_ap': results['slide_ap'],
'tile_auc': results['tile_auc'],
'tile_ap': results['tile_ap'],
}])], axis=0, join='outer', ignore_index=True)
# === Add UQ Crossval results (non-thresholded) ===========================
for exp in exp_to_run:
try:
skip = False
models = utils.find_cv(P, f'EXP_{exp}_UQ', outcome=self.outcome)
except MatchError:
continue
all_pred_thresh = []
for i, m in enumerate(models):
try:
results = utils.get_model_results(m, outcome=self.outcome, epoch=1)
all_pred_thresh += [results['opt_thresh']]
df = pd.concat([df, pd.DataFrame([{
'id': exp,
'n_slides': len(sf.util.get_slides_from_model_manifest(m, dataset=None)),
'fold': i+1,
'uq': 'all',
'patient_auc': results['pt_auc'],
'patient_ap': results['pt_ap'],
'slide_auc': results['slide_auc'],
'slide_ap': results['slide_ap'],
'tile_auc': results['tile_auc'],
'tile_ap': results['tile_ap'],
}])], axis=0, join='outer', ignore_index=True)
except FileNotFoundError:
log.debug(f"Skipping UQ crossval (non-thresholded) results for {exp}; not found")
skip = True
break
if not skip:
prediction_thresholds[exp] = mean(all_pred_thresh)
# === Get & Apply Nested UQ Threshold =====================================
if uq:
pb = tqdm(exp_to_run)
for exp in pb:
# Skip UQ for experiments with n_slides < 100
if exp in ('V', 'W', 'X', 'Y', 'Z', 'ZA', 'ZB', 'ZC', 'ZD'):
continue
pb.set_description(f"Calculating thresholds (exp {exp})...")
try:
_df, thresh = self.thresholds_from_nested_cv(
f'EXP_{exp}_UQ', id=exp
)
df = pd.concat([df, _df], axis=0, join='outer', ignore_index=True)
except (MatchError, FileNotFoundError, ModelNotFoundError) as e:
log.debug(str(e))
log.debug(f"Skipping UQ crossval results for {exp}; not found")
continue
except ThresholdError as e:
log.debug(str(e))
log.debug(f'Skipping UQ crossval results for {exp}; could not find thresholds in cross-validation')
continue
tile_uq_thresholds[exp] = thresh['tile_uq']
slide_uq_thresholds[exp] = thresh['slide_uq']
pred_uq_thresholds[exp] = thresh['slide_pred']
# Show CV uncertainty calibration
if plot and exp == 'AA':
print("Plotting UQ calibration for cross-validation (exp. AA)")
self.plot_uq_calibration(
label=f'EXP_{exp}_UQ',
**thresh
)
plt.show()
# === Show external validation results ====================================
if eval:
# --- Step 7A: Show non-UQ external validation results ----------------
for val_P in eval_Ps:
name = val_P.name
pb = tqdm(exp_to_run, ncols=80)
for exp in pb:
pb.set_description(f'Working on {name} eval (EXP {exp})...')
# Read and prepare model results
try:
eval_dir = utils.find_eval(val_P, f'EXP_{exp}_FULL', outcome=self.outcome)
results = utils.get_eval_results(eval_dir, outcome=self.outcome)
except (FileNotFoundError, MatchError):
log.debug(f"Skipping eval for exp {exp}; eval not found")
continue
if not utils.model_exists(P, f'EXP_{exp}_FULL', outcome=self.outcome, epoch=1):
log.debug(f'Skipping eval for exp {exp}; trained model not found')
continue
if exp not in prediction_thresholds:
log.warn(f"No predictions threshold for experiment {exp}; using slide-level pred threshold of 0.5")
pred_thresh = 0.5
else:
pred_thresh = prediction_thresholds[exp]
# Patient-level and slide-level predictions & metrics
patient_yt, patient_yp = utils.read_group_predictions(
join(
eval_dir,
f'patient_predictions_{self.outcome}_eval.csv'
)
)
patient_metrics = utils.prediction_metrics(
patient_yt,
patient_yp,
threshold=pred_thresh
)
patient_metrics = {
f'patient_{m}': patient_metrics[m]
for m in patient_metrics
}
slide_yt, slide_yp = utils.read_group_predictions(
join(
eval_dir,
f'slide_predictions_{self.outcome}_eval.csv'
)
)
slide_metrics = utils.prediction_metrics(
slide_yt,
slide_yp,
threshold=pred_thresh
)
slide_metrics = {
f'slide_{m}': slide_metrics[m]
for m in slide_metrics
}
model = utils.find_model(P, f'EXP_{exp}_FULL', outcome=self.outcome, epoch=1)
n_slides = len(sf.util.get_slides_from_model_manifest(model, dataset=None))
eval_dfs[name] = pd.concat([eval_dfs[name], pd.DataFrame([{
'id': exp,
'n_slides': n_slides,
'uq': 'none',
'incl': 1,
'patient_auc': results['pt_auc'],
'patient_ap': results['pt_ap'],
'slide_auc': results['slide_auc'],
'slide_ap': results['slide_ap'],
**patient_metrics,
**slide_metrics,
}])], axis=0, join='outer', ignore_index=True)
# --- [end patient-level predictions] -------------------------
if exp not in prediction_thresholds:
log.debug(f"Unable to calculate eval UQ performance; no prediction thresholds found for exp {exp}")
continue
# --- Step 7B: Show UQ external validation results ------------
if uq:
if exp in tile_uq_thresholds:
for keep in ('high_confidence', 'low_confidence'):
tile_pred_df = pd.read_csv(
join(
eval_dir,
'tile_predictions_eval.csv'
), dtype={'slide': str}
)
new_cols = {
f'{self.outcome}_y_pred1': 'y_pred',
f'{self.outcome}_y_true0': 'y_true',
f'{self.outcome}_uncertainty1': 'uncertainty'
}
tile_pred_df.rename(columns=new_cols, inplace=True)
thresh_tile = tile_uq_thresholds[exp]
thresh_slide = slide_uq_thresholds[exp]
val_patients = val_P.dataset(verification=None).patients()
def get_metrics_by_level(level):
return threshold.apply(
tile_pred_df,
tile_uq=thresh_tile,
slide_uq=thresh_slide,
tile_pred=0.5,
slide_pred=pred_uq_thresholds[exp],
plot=(plot and level == 'slide' and keep == 'high_confidence' and exp == 'AA'),
title=f'{name}: Exp. {exp} Uncertainty',
keep=keep, # Keeps only LOW or HIGH-confidence slide predictions
patients=val_patients,
level=level
)
s_results, _ = get_metrics_by_level('slide')
p_results, _ = get_metrics_by_level('patient')
if (plot and keep == 'high_confidence' and exp == 'AA'):
plt.savefig(join(self.outdir, f'{name}_uncertainty_v_preds.svg'))
full_model = utils.find_model(P, f'EXP_{exp}_FULL', outcome=self.outcome, epoch=1)
n_slides = len(sf.util.get_slides_from_model_manifest(full_model, dataset=None))
eval_dfs[name] = pd.concat([eval_dfs[name], pd.DataFrame([{
'id': exp,
'n_slides': n_slides,
'uq': ('include' if keep == 'high_confidence' else 'exclude'),
'slide_incl': s_results['percent_incl'],
'slide_auc': s_results['auc'],
'slide_acc': s_results['acc'],
'slide_sens': s_results['sensitivity'],
'slide_spec': s_results['specificity'],
'slide_youden': s_results['sensitivity'] + s_results['specificity'] - 1,
'patient_incl': p_results['percent_incl'],
'patient_auc': p_results['auc'],
'patient_acc': p_results['acc'],
'patient_sens': p_results['sensitivity'],
'patient_spec': p_results['specificity'],
'patient_youden': p_results['sensitivity'] + p_results['specificity'] - 1,
}])], axis=0, join='outer', ignore_index=True)
for eval_name in eval_dfs:
eval_dfs[eval_name].to_csv(
join(self.outdir, f'{eval_name}_results.csv'),
index=False
)
else:
eval_dfs = None
df.to_csv(join(self.outdir, 'crossval_results.csv'), index=False)
return df, eval_dfs
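# Example: assemble cross-validation and external evaluation metrics for
# a subset of experiments, then render the summary figures.
#
#   df, eval_dfs = experiment.results(['AA', 'G', 'Z'], uq=True, eval=True)
#   experiment.display(df, eval_dfs)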
def run(self, exp_to_run, steps=None, hp='nature2022'):
"""Trains the designated experiments.
Args:
exp_to_run (dict): Dict containing experiment configuration,
as provided by config().
steps (list(int)): Steps to run. Defaults to all steps, 1-6.
hp (str or slideflow.ModelParams, optional): Hyperparameters object,
or 'nature2022' to use the hyperparameters from the associated
publication. Defaults to 'nature2022'.
Returns:
None
"""
# === Initialize projects & prepare experiments ===========================
print(sf.util.bold("Initializing experiments..."))
P = self.train_project
eval_Ps = self.eval_projects
exp_annotations = join(P.root, 'experiments.csv')
if P.annotations != exp_annotations:
if not exists(exp_annotations):
shutil.copy(P.annotations, exp_annotations)
P.annotations = exp_annotations
exp_to_add = [
e for e in exp_to_run
if f'include_{e}' not in pd.read_csv(exp_annotations).columns.tolist()
]
for exp in exp_to_add:
self.add(exp_annotations, label=exp, **exp_to_run[exp])
full_epoch_exp = [e for e in exp_to_run if e in ('AA', 'A', 'D', 'G')]
if hp == 'nature2022':
exp_hp = biscuit_hp.nature2022()
else:
exp_hp = hp
# Configure steps to run
if steps is None:
steps = range(1, 7)
# === Step 1: Initialize full-epochs experiments ==========================
if 1 in steps:
print(sf.util.bold("[Step 1] Running full-epoch experiments..."))
exp_hp.epochs = [1, 3, 5, 10]
for exp in full_epoch_exp:
val_k = [
k for k in range(1, 4)
if not utils.model_exists(P, f'EXP_{exp}', outcome=self.outcome, kfold=k)
]
if not len(val_k):
print(f'Skipping Step 1 for experiment {exp}; already done.')
continue
elif val_k != list(range(1, 4)):
print(f'[Step 1] Some k-folds done; running {val_k} for {exp}')
self.train(
hp=exp_hp,
label=f'EXP_{exp}',
filters={f'include_{exp}': ['include']},
splits=f'splits_{exp}.json',
val_k=val_k,
val_strategy='k-fold',
save_model=False
)
# === Step 2: Run the rest of the experiments at the designated epoch =====
if 2 in steps:
print(sf.util.bold("[Step 2] Running experiments at target epoch..."))
exp_hp.epochs = [1]
for exp in exp_to_run:
if exp in full_epoch_exp:
continue # Already done in Step 1
val_k = [
k for k in range(1, 4)
if not utils.model_exists(P, f'EXP_{exp}', outcome=self.outcome, kfold=k)
]
if not len(val_k):
print(f'Skipping Step 2 for experiment {exp}; already done.')
continue
elif val_k != list(range(1, 4)):
print(f'[Step 2] Some k-folds done; running {val_k} for {exp}')
self.train(
hp=exp_hp,
label=f'EXP_{exp}',
filters={f'include_{exp}': ['include']},
save_predictions=True,
splits=f'splits_{exp}.json',
val_k=val_k,
val_strategy='k-fold',
save_model=False
)
# === Step 3: Run experiments with UQ & save predictions ==================
if 3 in steps:
print(sf.util.bold("[Step 3] Running experiments with UQ..."))
exp_hp.epochs = [1]
exp_hp.uq = True
for exp in exp_to_run:
val_k = [
k for k in range(1, 4)
if not utils.model_exists(P, f'EXP_{exp}_UQ', outcome=self.outcome, kfold=k)
]
if not len(val_k):
print(f'Skipping Step 3 for experiment {exp}; already done.')
continue
elif val_k != list(range(1, 4)):
print(f'[Step 3] Some k-folds done; running {val_k} for {exp}')
self.train(
hp=exp_hp,
label=f'EXP_{exp}_UQ',
filters={f'include_{exp}': ['include']},
save_predictions=True,
splits=f'splits_{exp}.json',
val_k=val_k,
val_strategy='k-fold',
save_model=False
)
# === Step 4: Run nested UQ cross-validation ==============================
if 4 in steps:
print(sf.util.bold("[Step 4] Running nested UQ experiments..."))
exp_hp.epochs = [1]
exp_hp.uq = True
for exp in exp_to_run:
total_slides = exp_to_run[exp]['out2'] + exp_to_run[exp]['out1']
if total_slides >= 50:
self.train_nested_cv(
hp=exp_hp,
label=f'EXP_{exp}_UQ',
val_strategy='k-fold'
)
else:
print(f"[Step 4] Skipping UQ for {exp}, need >=50 slides")
# === Step 5: Train models across full datasets ===========================
if 5 in steps:
print(sf.util.bold("[Step 5] Training across full datasets..."))
exp_hp.epochs = [1]
exp_hp.uq = True
for exp in exp_to_run:
if utils.model_exists(P, f'EXP_{exp}_FULL', outcome=self.outcome):
print(f'Skipping Step 5 for experiment {exp}; already done.')
else:
stop_batch = utils.find_cv_early_stop(P, f'EXP_{exp}', outcome=self.outcome, k=3)
print(f"Using detected early stop batch {stop_batch}")
self.train(
hp=exp_hp,
label=f'EXP_{exp}_FULL',
filters={f'include_{exp}': ['include']},
save_model=True,
val_strategy='none',
steps_per_epoch_override=stop_batch
)
# === Step 6: External validation ========================================
if 6 in steps:
for val_P in eval_Ps:
print(sf.util.bold(f"[Step 6] Running eval ({val_P.name})..."))
for exp in exp_to_run:
full_model = utils.find_model(P, f'EXP_{exp}_FULL', outcome=self.outcome, epoch=1)
if utils.eval_exists(val_P, f'EXP_{exp}_FULL', outcome=self.outcome, epoch=1):
print(f'Skipping eval for experiment {exp}; already done.')
else:
filters = {self.outcome: [self.outcome1, self.outcome2]}
val_P.evaluate(
full_model,
self.outcome,
filters=filters,
save_predictions=True,
)
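# Example: build a balanced experiment configuration and run all six
# steps with the publication hyperparameters.
#
#   cfg = Experiment.config('{}', subset=['AA', 'G', 'Z'], ratio=1)
#   experiment.run(cfg, steps=[1, 2, 3, 4, 5, 6], hp='nature2022')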
def thresholds_from_nested_cv(self, label, outer_k=3, inner_k=5, id=None,
threshold_params=None, epoch=1,
tile_filename='tile_predictions_val_epoch1.csv',
y_true=None, y_pred=None, uncertainty=None):
"""Detects tile- and slide-level UQ thresholds and slide-level prediction
thresholds from nested cross-validation."""
if id is None:
id = label
patients = self.train_project.dataset(verification=None).patients()
if threshold_params is None:
threshold_params = {
'tile_pred': 'detect',
'slide_pred': 'detect',
'plot': False,
'patients': patients
}
all_tile_uq_thresh = []
all_slide_uq_thresh = []
all_slide_pred_thresh = []
df = pd.DataFrame()
for k in range(1, outer_k+1):
try:
dfs = utils.df_from_cv(
self.train_project,
f'{label}-k{k}',
outcome=self.outcome,
k=inner_k,
y_true=y_true,
y_pred=y_pred,
uncertainty=uncertainty)
except ModelNotFoundError:
log.warn(f"Could not find {label} k-fold {k}; skipping")
continue
val_path = join(
utils.find_model(self.train_project, f'{label}', kfold=k, outcome=self.outcome),
tile_filename
)
if not exists(val_path):
log.warn(f"Could not find {label} k-fold {k}; skipping")
continue
tile_uq = threshold.from_cv(
dfs,
tile_uq='detect',
slide_uq=None,
**threshold_params
)['tile_uq']
thresholds = threshold.from_cv(
dfs,
tile_uq=tile_uq,
slide_uq='detect',
**threshold_params
)
all_tile_uq_thresh += [tile_uq]
all_slide_uq_thresh += [thresholds['slide_uq']]
all_slide_pred_thresh += [thresholds['slide_pred']]
if sf.util.path_to_ext(val_path).lower() == 'csv':
tile_pred_df = pd.read_csv(val_path, dtype={'slide': str})
elif sf.util.path_to_ext(val_path).lower() in ('parquet', 'gzip'):
tile_pred_df = pd.read_parquet(val_path)
else:
raise OSError(f"Unrecognized prediction filetype {val_path}")
utils.rename_cols(tile_pred_df, self.outcome, y_true=y_true, y_pred=y_pred, uncertainty=uncertainty)
def uq_auc_by_level(level):
results, _ = threshold.apply(
tile_pred_df,
plot=False,
patients=patients,
level=level,
**thresholds
)
return results['auc'], results['percent_incl']
pt_auc, pt_perc = uq_auc_by_level('patient')
slide_auc, slide_perc = uq_auc_by_level('slide')
model = utils.find_model(
self.train_project,
f'{label}',
kfold=k,
epoch=1,
outcome=self.outcome
)
m_slides = sf.util.get_slides_from_model_manifest(model, dataset=None)
df = pd.concat([df, pd.DataFrame([{
'id': id,
'n_slides': len(m_slides),
'fold': k,
'uq': 'include',
'patient_auc': pt_auc,
'patient_uq_perc': pt_perc,
'slide_auc': slide_auc,
'slide_uq_perc': slide_perc
}])], axis=0, join='outer', ignore_index=True)
thresholds = {
'tile_uq': None if not all_tile_uq_thresh else mean(all_tile_uq_thresh),
'slide_uq': None if not all_slide_uq_thresh else mean(all_slide_uq_thresh),
'slide_pred': None if not all_slide_pred_thresh else mean(all_slide_pred_thresh),
}
return df, thresholds
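# Example: derive tile- and slide-level thresholds from the nested
# cross-validation of the full-dataset UQ experiment.
#
#   df, thresh = experiment.thresholds_from_nested_cv('EXP_AA_UQ', id='AA')
#   # thresh -> {'tile_uq': ..., 'slide_uq': ..., 'slide_pred': ...}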
def train(self, hp, label, filters=None, save_predictions='csv',
validate_on_batch=32, validation_steps=32, **kwargs):
r"""Train outer cross-validation models.
Args:
hp (:class:`slideflow.ModelParams`): Hyperparameters object.
label (str): Experimental label.
filters (dict, optional): Dataset filters to use for
selecting slides. See :meth:`slideflow.Dataset.filter` for
more information. Defaults to None.
save_predictions (str or bool, optional): Save validation predictions
to the model folder in the given format. Defaults to 'csv'.
Keyword args:
validate_on_batch (int): Frequency of validation checks during
training, in steps. Defaults to 32.
validation_steps (int): Number of validation steps to perform
during each mid-training evaluation check. Defaults to 32.
**kwargs: All remaining keyword arguments are passed to
:meth:`slideflow.Project.train`.
Returns:
None
"""
self.train_project.train(
self.outcome,
exp_label=label,
filters=filters,
params=hp,
save_predictions=save_predictions,
validate_on_batch=validate_on_batch,
validation_steps=validation_steps,
**kwargs
)
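# Example: train 3-fold cross-validation models for experiment 'G' with
# the publication hyperparameters (val_strategy and val_k are forwarded
# to slideflow.Project.train).
#
#   experiment.train(biscuit_hp.nature2022(), label='EXP_G',
#                    filters={'include_G': ['include']},
#                    val_strategy='k-fold', val_k=[1, 2, 3])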
def train_nested_cv(self, hp, label, outer_k=3, inner_k=5, **kwargs):
r"""Train models using nested cross-validation (outer_k=3, inner_k=5),
skipping already-generated models.
Args:
hp (slideflow.ModelParams): Hyperparameters object.
label (str): Experimental label.
Keyword args:
outer_k (int): Number of outer cross-folds. Defaults to 3.
inner_k (int): Number of inner cross-folds. Defaults to 5.
**kwargs: All remaining keyword arguments are passed to
:meth:`slideflow.Project.train`.
Returns:
None
"""
k_models = utils.find_cv(self.train_project, label, k=outer_k, outcome=self.outcome)
for ki, k_model in enumerate(k_models):
inner_k_to_run = [
k for k in range(1, inner_k+1)
if not utils.model_exists(self.train_project, f'{label}-k{ki+1}', outcome=self.outcome, kfold=k)
]
if not len(inner_k_to_run):
print(f'Skipping nested cross-val (inner k{ki+1}) for experiment '
f'{label}; already done.')
else:
if inner_k_to_run != list(range(1, inner_k+1)):
print(f'Only running k-folds {inner_k_to_run} for nested '
f'cross-val k{ki+1} in experiment {label}; '
'some k-folds already done.')
train_slides = sf.util.get_slides_from_model_manifest(
k_model, dataset='training'
)
self.train(
hp=hp,
label=f"{label}-k{ki+1}",
filters={'slide': train_slides},
val_k_fold=inner_k,
val_k=inner_k_to_run,
save_predictions=True,
save_model=False,
**kwargs
)
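# Example: nested cross-validation (3 outer x 5 inner folds), assuming
# the outer k-fold models for this label were already trained via train().
#
#   experiment.train_nested_cv(biscuit_hp.nature2022(), label='EXP_G_UQ')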