import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import metrics
from sklearn.exceptions import UndefinedMetricWarning
from slideflow.util import log
from . import errors, utils
def plot_uncertainty(df, kind, threshold=None, title=None):
    """Plots figure of tile or slide-level predictions vs. uncertainty.

    Args:
        df (pandas.DataFrame): Processed dataframe containing columns
            'uncertainty', 'correct', 'y_pred'.
        kind (str): Kind of plot. If 'tile', subsample to only 1000 points.
            Included in title.
        threshold (float, optional): Uncertainty threshold.
            Defaults to None.
        title (str, optional): Title for plots. Defaults to None.

    Returns:
        None

    Raises:
        ImportError: If scikit-misc is not installed.
    """
    try:
        from skmisc.loess import loess
    except ImportError:
        raise ImportError(
            "Uncertainty plots with loess estimation require scikit-misc, "
            "which is not installed."
        )
    # Subsample tile-level predictions to keep the scatter plots legible.
    # Cap the sample size at len(df); df.sample(n=1000) raises ValueError
    # when the dataframe has fewer than 1000 rows.
    if kind == 'tile':
        df = df.sample(n=min(1000, len(df)))
    f, axes = plt.subplots(1, 3)
    f.set_size_inches(15, 5)
    palette = sns.color_palette("Set2")
    tf_pal = {True: palette[0], False: palette[1]}
    # Left figure - KDE -------------------------------------------------------
    # Density of uncertainty, split by correct vs. incorrect predictions.
    kde = sns.kdeplot(
        x='uncertainty',
        hue='correct',
        data=df,
        fill=True,
        palette=tf_pal,
        ax=axes[0]
    )
    kde.set(xlabel='Uncertainty')
    axes[0].title.set_text(f'Uncertainty density ({kind}-level)')
    # Middle figure - Scatter -------------------------------------------------
    # - Above threshold: correct predictions in gray, incorrect in muted red.
    if threshold is not None:
        axes[1].axhline(y=threshold, color='r', linestyle='--')
        at_df = df.loc[(df['uncertainty'] >= threshold)]
        c_a_df = at_df.loc[at_df['correct']]
        ic_a_df = at_df.loc[~at_df['correct']]
        axes[1].scatter(
            x=c_a_df['y_pred'],
            y=c_a_df['uncertainty'],
            marker='o',
            s=10,
            color='gray'
        )
        axes[1].scatter(
            x=ic_a_df['y_pred'],
            y=ic_a_df['uncertainty'],
            marker='x',
            color='#FC6D77'
        )
    # - Below threshold (or all points, if no threshold was given)
    if threshold is not None:
        bt_df = df.loc[(df['uncertainty'] < threshold)]
    else:
        bt_df = df
    c_df = bt_df.loc[bt_df['correct']]
    ic_df = bt_df.loc[~bt_df['correct']]
    axes[1].scatter(
        x=c_df['y_pred'],
        y=c_df['uncertainty'],
        marker='o',
        s=10
    )
    axes[1].scatter(
        x=ic_df['y_pred'],
        y=ic_df['uncertainty'],
        marker='x',
        color='red'
    )
    if title is not None:
        axes[1].title.set_text(title)
    # Right figure - probability calibration ----------------------------------
    # Loess fit of correctness (0/1) vs. uncertainty, with confidence band.
    l_df = df[['uncertainty', 'correct']].sort_values(by=['uncertainty'])
    x = l_df['uncertainty'].to_numpy()
    y = l_df['correct'].astype(float).to_numpy()
    ol = loess(x, y)
    ol.fit()
    pred = ol.predict(x, stderror=True)
    conf = pred.confidence()
    z = pred.values
    ll = conf.lower
    ul = conf.upper
    axes[2].plot(x, y, '+', ms=6)
    axes[2].plot(x, z)
    axes[2].fill_between(x, ll, ul, alpha=.2)
    axes[2].tick_params(labelrotation=90)
    axes[2].set_ylim(-0.1, 1.1)
    if threshold is not None:
        axes[2].axvline(x=threshold, color='r', linestyle='--')
    # - Figure style
    for ax in (axes[1], axes[2]):
        ax.spines['bottom'].set_linewidth(0.5)
        ax.spines['bottom'].set_color('black')
        ax.tick_params(axis='x', colors='black')
        ax.grid(visible=True, which='both', axis='both', color='white')
        ax.set_facecolor('#EAEAF2')
def process_tile_predictions(df, pred_thresh=0.5, patients=None):
    '''Load and process tile-level predictions from CSV.

    Args:
        df (pandas.DataFrame): Unprocessed DataFrame from reading tile-level
            predictions.
        pred_thresh (float or str, optional): Tile-level prediction threshold.
            If 'detect', will auto-detect via Youden's J. Defaults to 0.5.
        patients (dict, optional): Dict mapping slides to patients, used for
            patient-level thresholding. Defaults to None.

    Returns:
        pandas.DataFrame, tile prediction threshold

    Raises:
        errors.PredsContainNaNError: If any tile-level prediction is NaN.
    '''
    # Tile-level AUC. NaN predictions cannot be ranked or thresholded.
    if np.isnan(df['y_pred'].to_numpy()).sum():
        raise errors.PredsContainNaNError
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        fpr, tpr, thresh = metrics.roc_curve(
            df['y_true'].to_numpy(),
            df['y_pred'].to_numpy()
        )
    tile_auc = metrics.auc(fpr, tpr)
    try:
        # Youden's J statistic: the threshold maximizing (TPR - FPR).
        opt_pred = thresh[np.argmax(tpr - fpr)]
    except ValueError:
        # np.argmax raises ValueError on an empty ROC curve.
        log.debug("Unable to calculate tile prediction threshold; using 0.5")
        opt_pred = 0.5
    if pred_thresh == 'detect':
        log.debug(f"Auto-detected tile prediction threshold: {opt_pred:.4f}")
        pred_thresh = opt_pred
    else:
        log.debug(f"Using tile prediction threshold: {pred_thresh:.4f}")
    if patients is not None:
        df['patient'] = df['slide'].map(patients)
    else:
        # logging.Logger.warn is deprecated; warning() is the supported name.
        log.warning('Patients not provided; assuming 1:1 slide:patient mapping')
    log.debug(f'Tile AUC: {tile_auc:.4f}')
    # Calculate tile-level prediction accuracy
    df['error'] = abs(df['y_true'] - df['y_pred'])
    df['correct'] = (
        ((df['y_pred'] < pred_thresh) & (df['y_true'] == 0))
        | ((df['y_pred'] >= pred_thresh) & (df['y_true'] == 1))
    )
    df['incorrect'] = (~df['correct']).astype(int)
    df['y_pred_bin'] = (df['y_pred'] >= pred_thresh).astype(int)
    return df, pred_thresh
def process_group_predictions(df, pred_thresh, level):
    '''From a given dataframe of tile-level predictions, calculate group-level
    predictions and uncertainty.

    Args:
        df (pandas.DataFrame): Tile-level predictions. Must contain columns
            'y_true', 'y_pred', 'uncertainty', and the grouping column
            named by ``level``.
        pred_thresh (float or str): Group-level prediction threshold, or
            'detect' to auto-detect via Youden's J.
        level (str): Name of the column to group by (e.g. 'slide', 'patient').

    Returns:
        pandas.DataFrame of group-level predictions, prediction threshold

    Raises:
        ValueError: If required columns are missing.
        errors.ROCFailedError: If predictions are empty or the ROC fails.
    '''
    if any(c not in df.columns for c in ('y_true', 'y_pred', 'uncertainty')):
        # Fixed: original message concatenation was missing a space
        # ("uncertainty.Got:").
        raise ValueError('Missing columns. Expected y_true, y_pred, uncertainty. '
                         f'Got: {df.columns}')
    # Calculate group-level predictions by averaging tile-level values.
    log.debug(f'Calculating {level}-level means from {len(df)} predictions')
    levels = [l for l in pd.unique(df[level]) if l is not np.nan]
    reduced_df = df[[level, 'y_pred', 'y_true', 'uncertainty']]
    # Index by group once; .loc[levels] preserves group order and avoids the
    # previous O(n) scan per group (accidental O(n^2)).
    grouped = reduced_df.groupby(level, as_index=False).mean().set_index(level)
    yp = grouped.loc[levels, 'y_pred'].to_numpy()
    yt = grouped.loc[levels, 'y_true'].to_numpy().astype(np.uint8)
    u = grouped.loc[levels, 'uncertainty'].to_numpy()
    if not len(yt):
        raise errors.ROCFailedError("Unable to generate ROC; preds are empty.")
    # Slide-level AUC
    log.debug(f'Calculating {level}-level ROC')
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
        l_fpr, l_tpr, l_thresh = metrics.roc_curve(yt, yp)
        log.debug('Calculating AUC')
        level_auc = metrics.auc(l_fpr, l_tpr)
    log.debug('Calculating optimal threshold')
    if pred_thresh == 'detect':
        try:
            # Youden's J statistic: the threshold maximizing (TPR - FPR).
            pred_thresh = l_thresh[np.argmax(l_tpr - l_fpr)]
        except ValueError:
            raise errors.ROCFailedError(f"Unable to generate {level}-level ROC")
        log.debug(f"Using detected prediction threshold: {pred_thresh:.4f}")
    else:
        log.debug(f"Using {level} prediction threshold: {pred_thresh:.4f}")
    log.debug(f'{level} AUC: {level_auc:.4f}')
    correct = (((yp < pred_thresh) & (yt == 0))
               | ((yp >= pred_thresh) & (yt == 1)))
    incorrect = pd.Series(
        ((yp < pred_thresh) & (yt == 1))
        | ((yp >= pred_thresh) & (yt == 0))
    ).astype(int)
    l_df = pd.DataFrame({
        level: pd.Series(levels),
        'error': pd.Series(abs(yt - yp)),
        'uncertainty': pd.Series(u),
        'correct': correct,
        'incorrect': incorrect,
        'y_true': pd.Series(yt),
        'y_pred': pd.Series(yp),
        'y_pred_bin': pd.Series(yp >= pred_thresh).astype(int)
    })
    return l_df, pred_thresh
def apply(df, tile_uq, slide_uq, tile_pred=0.5,
          slide_pred=0.5, plot=False, keep='high_confidence',
          title=None, patients=None, level='slide'):
    '''Apply pre-calculated tile- and group-level uncertainty thresholds.

    Args:
        df (pandas.DataFrame): Must contain columns 'y_true', 'y_pred',
            and 'uncertainty'.
        tile_uq (float): Tile-level uncertainty threshold.
        slide_uq (float): Slide-level uncertainty threshold.
        tile_pred (float, optional): Tile-level prediction threshold.
            Defaults to 0.5.
        slide_pred (float, optional): Slide-level prediction threshold.
            Defaults to 0.5.
        plot (bool, optional): Plot slide-level uncertainty. Defaults to False.
        keep (str, optional): Either 'high_confidence' or 'low_confidence'.
            Cohort to keep after thresholding. Defaults to 'high_confidence'.
        title (str, optional): Title for uncertainty plot. Defaults to None.
        patients (dict, optional): Dictionary mapping slides to patients. Adds
            a 'patient' column in the tile prediction dataframe, enabling
            patient-level thresholding. Defaults to None.
        level (str, optional): Either 'slide' or 'patient'. Level at which to
            apply threshold. If 'patient', requires patient dict be supplied.
            Defaults to 'slide'.

    Returns:
        Dictionary of results, with keys auc, percent_incl, accuracy,
        sensitivity, and specificity

        DataFrame of thresholded group-level predictions
    '''
    assert keep in ('high_confidence', 'low_confidence')
    assert not (level == 'patient' and patients is None)
    log.debug(f"Applying tile UQ threshold of {tile_uq:.5f}")
    if patients:
        df['patient'] = df['slide'].map(patients)
    log.debug(f"Number of {level}s before tile UQ filter: {pd.unique(df[level]).shape[0]}")
    log.debug(f"Number of tiles before tile-level filter: {len(df)}")
    # Binarize tile predictions and flag correct/incorrect tiles.
    df, _ = process_tile_predictions(
        df,
        pred_thresh=tile_pred,
        patients=patients
    )
    # Group count before UQ filtering; denominator for percent_incl below.
    num_pre_filter = pd.unique(df[level]).shape[0]
    # NOTE(review): truthiness check — a tile_uq of exactly 0 (or None)
    # skips tile-level filtering; confirm that is intended for 0.
    if tile_uq:
        df = df[df['uncertainty'] < tile_uq]
    log.debug(f"Number of {level}s after tile-level filter: {pd.unique(df[level]).shape[0]}")
    log.debug(f"Number of tiles after tile-level filter: {len(df)}")
    # Build group-level predictions
    try:
        s_df, _ = process_group_predictions(
            df,
            pred_thresh=slide_pred,
            level=level
        )
    except errors.ROCFailedError:
        # Return a results dict of Nones so callers can distinguish
        # "failed to process" from a successful (possibly empty) run.
        log.error("Unable to process slide predictions")
        empty_results = {k: None for k in ['auc',
                                           'percent_incl',
                                           'acc',
                                           'sensitivity',
                                           'specificity']}
        return empty_results, None
    if plot:
        plot_uncertainty(s_df, threshold=slide_uq, kind=level, title=title)
    # Apply slide-level thresholds
    # NOTE(review): same truthiness caveat as tile_uq above.
    if slide_uq:
        log.debug(f"Using {level} uncertainty threshold of {slide_uq:.5f}")
        if keep == 'high_confidence':
            s_df = s_df.loc[s_df['uncertainty'] < slide_uq]
        elif keep == 'low_confidence':
            s_df = s_df.loc[s_df['uncertainty'] >= slide_uq]
        else:
            # Unreachable given the assertion above; kept as a safeguard.
            raise Exception(f"Unknown keep option {keep}")
    # Show post-filtering group-level predictions and AUC
    auc = utils.auc(s_df['y_true'].to_numpy(), s_df['y_pred'].to_numpy())
    num_post_filter = len(s_df)
    # NOTE(review): raises ZeroDivisionError if no groups existed before
    # filtering — confirm callers guarantee num_pre_filter > 0.
    percent_incl = num_post_filter / num_pre_filter
    log.debug(f"Percent {level} included: {percent_incl*100:.2f}%")
    # Calculate post-thresholded sensitivity/specificity from the
    # confusion-matrix counts of the surviving groups.
    y_true = s_df['y_true'].to_numpy().astype(bool)
    y_pred = s_df['y_pred'].to_numpy() > slide_pred
    tp = np.logical_and(y_true, y_pred).sum()
    fp = np.logical_and(np.logical_not(y_true), y_pred).sum()
    tn = np.logical_and(np.logical_not(y_true), np.logical_not(y_pred)).sum()
    fn = np.logical_and(y_true, np.logical_not(y_pred)).sum()
    acc = (tp + tn) / (tp + tn + fp + fn)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    log.debug(f"Accuracy: {acc:.4f}")
    log.debug(f"Sensitivity: {sensitivity:.4f}")
    log.debug(f"Specificity: {specificity:.4f}")
    results = {
        'auc': auc,
        'percent_incl': percent_incl,
        'acc': acc,
        'sensitivity': sensitivity,
        'specificity': specificity
    }
    return results, s_df
def detect(df, tile_uq='detect', slide_uq='detect', tile_pred='detect',
           slide_pred='detect', plot=False, patients=None):
    '''Detect optimal tile- and slide-level uncertainty thresholds.

    Args:
        df (pandas.DataFrame): Tile-level predictions. Must contain columns
            'y_true', 'y_pred', and 'uncertainty'.
        tile_uq (str or float): Either 'detect' or float. If 'detect',
            will detect tile-level uncertainty threshold. If float, will use
            the specified tile-level uncertainty threshold.
        slide_uq (str or float): Either 'detect' or float. If 'detect',
            will detect slide-level uncertainty threshold. If float, will use
            the specified slide-level uncertainty threshold.
        tile_pred (str or float): Either 'detect' or float. If 'detect',
            will detect tile-level prediction threshold. If float, will use the
            specified tile-level prediction threshold.
        slide_pred (str or float): Either 'detect' or float. If 'detect'
            will detect slide-level prediction threshold. If float, will use
            the specified slide-level prediction threshold.
        plot (bool, optional): Plot slide-level uncertainty. Defaults to False.
        patients (dict, optional): Dict mapping slides to patients. Required
            for patient-level thresholding.

    Returns:
        Dictionary with tile- and slide-level UQ and prediction thresholds,
        with keys: 'tile_uq', 'tile_pred', 'slide_uq', 'slide_pred'

        Float: Slide-level AUROC
    '''
    log.debug("Detecting thresholds...")
    # Returned in place of real thresholds when detection fails.
    empty_thresh = {k: None
                    for k in ['tile_uq', 'slide_uq', 'tile_pred', 'slide_pred']}
    try:
        df, detected_tile_pred = process_tile_predictions(
            df,
            pred_thresh=tile_pred,
            patients=patients
        )
    except errors.PredsContainNaNError:
        log.error("Tile-level predictions contain NaNs; unable to process.")
        return empty_thresh, None
    if tile_pred == 'detect':
        tile_pred = detected_tile_pred
    # Tile-level ROC and Youden's J
    if isinstance(tile_uq, (float, np.float16, np.float32, np.float64)):
        # Explicit threshold supplied: filter tiles directly.
        df = df[df['uncertainty'] < tile_uq]
    elif tile_uq != 'detect':
        # Any other value (e.g. None) disables tile-level thresholding.
        log.debug("Not performing tile-level uncertainty thresholding.")
        tile_uq = None
    else:
        # Detect the threshold from the ROC of uncertainty as a predictor
        # of incorrect tile predictions.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            t_fpr, t_tpr, t_thresh = metrics.roc_curve(
                df['incorrect'].to_numpy(),
                df['uncertainty'].to_numpy()
            )
        # Youden's J statistic: threshold maximizing (TPR - FPR).
        max_j = max(zip(t_tpr, t_fpr), key=lambda x: x[0] - x[1])
        tile_uq = t_thresh[list(zip(t_tpr, t_fpr)).index(max_j)]
        log.debug(f"Tile-level optimal UQ threshold: {tile_uq:.4f}")
        df = df[df['uncertainty'] < tile_uq]
    slides = list(set(df['slide']))
    log.debug(f"Number of slides after filter: {len(slides)}")
    log.debug(f"Number of tiles after filter: {len(df)}")
    # Build slide-level predictions
    try:
        s_df, slide_pred = process_group_predictions(
            df,
            pred_thresh=slide_pred,
            level='slide'
        )
    except errors.ROCFailedError:
        log.error("Unable to process slide predictions")
        return empty_thresh, None
    # Slide-level thresholding
    if slide_uq == 'detect':
        if not s_df['incorrect'].to_numpy().sum():
            # All slide-level predictions correct: the ROC is undefined.
            log.debug("Unable to calculate slide UQ threshold; no incorrect predictions made")
            slide_uq = None
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=UndefinedMetricWarning)
                s_fpr, s_tpr, s_thresh = metrics.roc_curve(
                    s_df['incorrect'],
                    s_df['uncertainty'].to_numpy()
                )
            # Youden's J statistic, as in the tile-level branch above.
            max_j = max(zip(s_tpr, s_fpr), key=lambda x: x[0]-x[1])
            slide_uq = s_thresh[list(zip(s_tpr, s_fpr)).index(max_j)]
            log.debug(f"Slide-level optimal UQ threshold: {slide_uq:.4f}")
            if plot:
                plot_uncertainty(s_df, threshold=slide_uq, kind='slide')
            s_df = s_df[s_df['uncertainty'] < slide_uq]
    else:
        log.debug("Not performing slide-level uncertainty thresholding.")
        # NOTE(review): slide_uq is overwritten with 0.5 even though
        # thresholding is skipped, and this value is returned to the caller
        # and drawn on the plot — confirm this default is intended.
        slide_uq = 0.5
        if plot:
            plot_uncertainty(s_df, threshold=slide_uq, kind='slide')
    # Show post-filtering slide predictions and AUC
    auc = utils.auc(s_df['y_true'].to_numpy(), s_df['y_pred'].to_numpy())
    thresholds = {
        'tile_uq': tile_uq,
        'slide_uq': slide_uq,
        'tile_pred': tile_pred,
        'slide_pred': slide_pred
    }
    return thresholds, auc
def from_cv(dfs, **kwargs):
    '''Finds the optimal tile and slide-level thresholds from a set of nested
    cross-validation experiments.

    Args:
        dfs (list(DataFrame)): List of DataFrames with tile predictions,
            containing headers 'y_true', 'y_pred', 'uncertainty', 'slide',
            and 'patient'.

    Keyword args:
        tile_uq (str or float): Either 'detect' or float. If 'detect',
            will detect tile-level uncertainty threshold. If float, will use
            the specified tile-level uncertainty threshold.
        slide_uq (str or float): Either 'detect' or float. If 'detect',
            will detect slide-level uncertainty threshold. If float, will use
            the specified slide-level uncertainty threshold.
        tile_pred (str or float): Either 'detect' or float. If 'detect',
            will detect tile-level prediction threshold. If float, will use the
            specified tile-level prediction threshold.
        slide_pred (str or float): Either 'detect' or float. If 'detect'
            will detect slide-level prediction threshold. If float, will use
            the specified slide-level prediction threshold.
        plot (bool, optional): Plot slide-level uncertainty. Defaults to False.
        patients (dict, optional): Dict mapping slides to patients. Required
            for patient-level thresholding.

    Returns:
        Dictionary with tile- and slide-level UQ and prediction thresholds,
        with keys: 'tile_uq', 'tile_pred', 'slide_uq', 'slide_pred'

    Raises:
        ValueError: If any DataFrame is missing required columns.
        errors.ThresholdError: If no fold yields a usable UQ threshold.
    '''
    required_cols = ('y_true', 'y_pred', 'uncertainty', 'slide', 'patient')
    k_tile_thresh, k_slide_thresh = [], []
    k_tile_pred_thresh, k_slide_pred_thresh = [], []
    k_auc = []  # Collected per fold; not currently returned.
    # NOTE(review): these kwarg keys ('tile_uq_thresh'/'slide_uq_thresh') do
    # not match detect()'s parameter names ('tile_uq'/'slide_uq') — confirm
    # which spelling callers use to disable thresholding.
    skip_tile = ('tile_uq_thresh' in kwargs
                 and kwargs['tile_uq_thresh'] is None)
    skip_slide = ('slide_uq_thresh' in kwargs
                  and kwargs['slide_uq_thresh'] is None)
    for idx, df in enumerate(dfs):
        log.debug(f"Detecting thresholds from fold {idx}")
        if not all(col in df.columns for col in required_cols):
            raise ValueError(
                f"DataFrame missing columns, expected {required_cols}, got: "
                f"{', '.join(df.columns.tolist())}"
            )
        thresholds, auc = detect(df, **kwargs)
        if thresholds['tile_uq'] is None or thresholds['slide_uq'] is None:
            log.debug(f"Skipping CV #{idx}, unable to detect threshold")
            continue
        # Bug fix: the tile and slide prediction thresholds were previously
        # appended to each other's lists (swapped), so the returned
        # 'tile_pred'/'slide_pred' means were crossed.
        k_tile_pred_thresh += [thresholds['tile_pred']]
        k_slide_pred_thresh += [thresholds['slide_pred']]
        k_auc += [auc]
        if not skip_tile:
            k_tile_thresh += [thresholds['tile_uq']]
        if not skip_slide:
            k_slide_thresh += [thresholds['slide_uq']]
    if not skip_tile and not len(k_tile_thresh):
        raise errors.ThresholdError('Unable to detect tile UQ threshold.')
    if not skip_slide and not len(k_slide_thresh):
        raise errors.ThresholdError('Unable to detect slide UQ threshold.')
    # Prediction thresholds: mean across folds.
    k_slide_pred_thresh = np.mean(k_slide_pred_thresh)
    k_tile_pred_thresh = np.mean(k_tile_pred_thresh)
    # UQ thresholds: minimum tile threshold and maximum slide threshold
    # observed across folds.
    if not skip_tile:
        k_tile_thresh = np.min(k_tile_thresh)
    if not skip_slide:
        k_slide_thresh = np.max(k_slide_thresh)
    return {
        'tile_uq': k_tile_thresh,
        'slide_uq': k_slide_thresh,
        'tile_pred': k_tile_pred_thresh,
        'slide_pred': k_slide_pred_thresh
    }