Skip to content
179 changes: 179 additions & 0 deletions aaanalysis/plotting/_plot_prediction_hist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
This is a script for the frontend of the (internal) plot_prediction_hist function — a
class-separated histogram of a 0-100 model prediction score.

This symbol is **not** re-exported in ``aaanalysis/__init__.py`` yet; it is reached via the
submodule import ``from aaanalysis.plotting._plot_prediction_hist import plot_prediction_hist``.
"""
from typing import Optional, List, Dict, Union, Tuple
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

from aaanalysis import utils as ut
from ._plot_rank import _resolve_group_colors

# TODO(#305): re-export plot_prediction_hist in __init__ (CONFIRM-FIRST, maintainer review)


# I Helper Functions
# (all shared helpers live in ``_plot_rank``; nothing local needed)


# II Main Functions
def plot_prediction_hist(df_pred: pd.DataFrame,
col_score: str = "score",
col_group: str = "group",
group_order: Optional[List[str]] = None,
dict_color: Optional[Dict[str, str]] = None,
binwidth: Union[int, float] = 5,
binrange: Tuple[Union[int, float], Union[int, float]] = (0, 100),
stacked: bool = True,
kde: bool = False,
ax: Optional[Axes] = None,
figsize: Tuple[Union[int, float], Union[int, float]] = (7, 5),
xlabel: str = "Prediction score [%]",
ylabel: str = "Count",
fontsize_labels: Optional[Union[int, float]] = None,
legend: bool = True,
) -> Tuple[Figure, Axes]:
"""
Plot a class-separated histogram of a per-sample prediction score (internal).

Shows how a model's prediction score (e.g. a substrate-probability in ``[0, 100]`` %)
distributes across known classes such as substrate / non-substrate / hold-out, so the
class separation of a deployed predictor can be read off at a glance.

A score column whose maximum is ``<= 1`` is treated as a ``[0, 1]`` probability and
rescaled to the ``[0, 100]`` percent axis (a notice is printed when ``verbose`` is on).
Pass an already-percent score (max ``> 1``) to skip the rescale.

This function is **internal**: it is not part of the public ``aaanalysis`` namespace and
may change without a deprecation cycle.

Parameters
----------
df_pred : pd.DataFrame, shape (n_samples, n_info)
One row per sample; must contain ``col_score`` (the prediction score) and
``col_group`` (the class label used to split the histogram).
col_score : str, default="score"
Column with the per-sample prediction score (numeric; ``[0, 1]`` or ``[0, 100]``).
col_group : str, default="group"
Column with the per-sample class label used to color / separate the bars.
group_order : list of str, optional
Order in which classes are colored / stacked. Defaults to first-appearance order.
dict_color : dict, optional
Mapping ``group -> color`` (overrides the canonical defaults). Canonical class names
(``substrate``, ``non-substrate``, ``hold-out``) default to the locked sample palette.
binwidth : int or float, default=5
Width of the histogram bins (in score units).
binrange : tuple, default=(0, 100)
``(low, high)`` range over which bins are computed; also used as the x-axis limits.
stacked : bool, default=True
If ``True``, class histograms are stacked (``multiple="stack"``); else overlaid
(``multiple="layer"``).
kde : bool, default=False
If ``True``, overlay a kernel-density estimate per class.
ax : matplotlib.axes.Axes, optional
Axes to draw on. If ``None``, a new figure and axes are created.
figsize : tuple, default=(7, 5)
Figure size when ``ax`` is ``None``.
xlabel, ylabel : str
Axis labels.
fontsize_labels : int or float, optional
Font size for the axis labels (matplotlib default if ``None``).
legend : bool, default=True
Whether to draw the class legend.

Returns
-------
fig : matplotlib.figure.Figure
The figure.
ax : matplotlib.axes.Axes
The axes with the class-separated histogram.

See Also
--------
* :func:`aaanalysis.plot_rank` for the companion ranked-candidates view.

Examples
--------
.. code-block:: python

import pandas as pd
from aaanalysis.plotting._plot_prediction_hist import plot_prediction_hist

df_pred = pd.DataFrame({"score": [95, 80, 60, 40, 20, 5],
"group": ["substrate", "substrate", "hold-out",
"non-substrate", "non-substrate", "non-substrate"]})
fig, ax = plot_prediction_hist(df_pred, col_score="score", col_group="group")

.. note::
This symbol is internal; an example notebook and a public re-export are tracked as a
TODO (re-export under :mod:`aaanalysis` is CONFIRM-FIRST, pending maintainer review).
"""
# Check input
ut.check_str(name="col_score", val=col_score)
ut.check_str(name="col_group", val=col_group)
ut.check_df(name="df_pred", df=df_pred, cols_required=[col_score, col_group])
if len(df_pred) == 0:
raise ValueError("'df_pred' (0 rows) should contain at least one sample.")
if not pd.api.types.is_numeric_dtype(df_pred[col_score]):
raise ValueError(f"'{col_score}' column should be numeric (got dtype "
f"'{df_pred[col_score].dtype}').")
if not np.isfinite(df_pred[col_score].to_numpy(dtype=float)).any():
raise ValueError(f"'{col_score}' column has no finite values to plot.")
ut.check_list_like(name="group_order", val=group_order, accept_none=True)
ut.check_dict_color(name="dict_color", val=dict_color, accept_none=True)
ut.check_number_range(name="binwidth", val=binwidth, min_val=0, exclusive_limits=True, just_int=False)
ut.check_lim(name="binrange", val=binrange, accept_none=False)
ut.check_bool(name="stacked", val=stacked)
ut.check_bool(name="kde", val=kde)
ut.check_bool(name="legend", val=legend)
ut.check_ax(ax=ax, accept_none=True)
ut.check_figsize(figsize=figsize, accept_none=True)
ut.check_str(name="xlabel", val=xlabel, accept_none=True)
ut.check_str(name="ylabel", val=ylabel, accept_none=True)
ut.check_number_range(name="fontsize_labels", val=fontsize_labels, min_val=0,
accept_none=True, just_int=False)

# Resolve order + colors
if group_order is None:
group_order = list(dict.fromkeys(df_pred[col_group].tolist()))
else:
missing = set(df_pred[col_group]) - set(group_order)
if missing:
raise ValueError(f"'group_order' is missing groups present in 'df_pred': {missing}")
dict_group_color = _resolve_group_colors(group_order=group_order, dict_color=dict_color)

# Rescale a [0, 1] probability to a [0, 100] percent axis (explicit, verbose-noticed)
df = df_pred.copy()
scores = df[col_score].to_numpy(dtype=float)
if np.isfinite(scores).any() and np.nanmax(scores) <= 1:
df[col_score] = scores * 100
if ut.check_verbose(False):
ut.print_out(f"Note: '{col_score}' (max <= 1) looks like a [0, 1] probability; "
f"rescaled to a [0, 100] % axis.")

# Draw
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
multiple = "stack" if stacked else "layer"
sns.histplot(data=df, x=col_score, hue=col_group, hue_order=group_order,
palette=dict_group_color, binwidth=binwidth, binrange=tuple(binrange),
multiple=multiple, kde=kde, legend=legend, ax=ax)
ax.set_xlim(binrange[0], binrange[1])
ax.set_xlabel(xlabel, fontsize=fontsize_labels)
ax.set_ylabel(ylabel, fontsize=fontsize_labels)
# Counts are integers: force integer y-ticks with matplotlib's own nice spacing
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
sns.despine(ax=ax)
# ``ax.figure`` is typed ``Figure | SubFigure | None`` by the matplotlib stubs,
# but a top-level Axes here always belongs to a real Figure.
return fig, ax # pyright: ignore[reportReturnType]
105 changes: 96 additions & 9 deletions aaanalysis/plotting/_plot_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Patch

from aaanalysis import utils as ut
from ._plot_get_clist import plot_get_clist

# Default axis labels for the scatter (rank) mode; used to detect "left at default"
# so the additive ranked-candidates (bar) mode can substitute sensible labels.
_DEFAULT_XLABEL = "Protein rank"
_DEFAULT_YLABEL = "Max score per protein"

# I Helper Functions
# Canonical group -> color defaults (overridable via dict_color); leans on the
# locked sample palette so substrate/non-substrate read green/magenta out of the box.
Expand Down Expand Up @@ -44,13 +50,55 @@ def _resolve_group_colors(group_order: List[str], dict_color: Optional[Dict[str,
return out


def _plot_ranked_candidates(ax, df_rank, col_score, col_class, col_std, group_order,
dict_group_color, list_thresholds, xlabel, ylabel,
fontsize_labels):
"""Draw the ranked-candidates horizontal-bar variant (port of plot_pred3_top_hits):
named candidates as horizontal bars colored by class, optional per-item error bars,
and vertical threshold (cutoff) lines."""
# Sort by class (in group_order) then ascending score, keep the index as candidate names
order_map = {c: i for i, c in enumerate(group_order)}
df = df_rank.copy()
df["_name"] = df.index.astype(str)
df["_sorter"] = df[col_class].map(order_map)
df = df.sort_values(["_sorter", col_score], ascending=[False, True]).reset_index(drop=True)
y_pos = np.arange(len(df))
scores = df[col_score].to_numpy()
colors = [dict_group_color[c] for c in df[col_class]]
ax.barh(y_pos, scores, color=colors)
if col_std is not None:
ax.errorbar(x=scores, y=y_pos, xerr=df[col_std].to_numpy(),
fmt="none", ecolor="black", capsize=3, capthick=1, elinewidth=1)
ax.set_yticks(y_pos)
ax.set_yticklabels(df["_name"].tolist())
ax.tick_params(length=0, axis="y")
ax.set_ylim(-0.75, len(df))
ax.set_xlim(left=0)
for t in list_thresholds:
ax.axvline(t, color="grey", linestyle="--", linewidth=1)
# Substitute bar-appropriate axis labels only when the caller kept the scatter defaults
xl = "Prediction score" if xlabel == _DEFAULT_XLABEL else xlabel
yl = "" if ylabel == _DEFAULT_YLABEL else ylabel
ax.set_xlabel(xl, fontsize=fontsize_labels)
ax.set_ylabel(yl, fontsize=fontsize_labels)
present = set(df[col_class])
handles = [Patch(color=dict_group_color[g], label=str(g))
for g in group_order if g in present]
# Horizontal bars fill the axes; keep the class legend outside (right) so it never
# overlaps the top candidate bars.
ax.legend(handles=handles, loc="upper left", bbox_to_anchor=(1.01, 1.0), frameon=False)
sns.despine(ax=ax)


# II Main Functions
def plot_rank(df_rank: pd.DataFrame,
col_score: str = "score",
col_group: str = "group",
group_order: Optional[List[str]] = None,
dict_color: Optional[Dict[str, str]] = None,
threshold: Optional[Union[int, float, List[Union[int, float]]]] = None,
col_std: Optional[str] = None,
col_class: Optional[str] = None,
ax: Optional[Axes] = None,
figsize: Tuple[Union[int, float], Union[int, float]] = (7, 5),
marker_size: Union[int, float] = 25,
Expand All @@ -65,6 +113,12 @@ def plot_rank(df_rank: pd.DataFrame,
ranked by their maximum score and colored by membership in groups such as substrate /
hold-out / non-substrate, with optional threshold lines for the deployment caller.

Passing ``col_class`` switches to an additive **ranked-candidates** variant: named
candidates (the DataFrame index) are drawn as horizontal bars colored by class, with
optional per-item error bars (``col_std``) and vertical cutoff lines (``threshold``).
This reproduces the recurring "top-hits with agreement" figure. The default scatter
output is unchanged when ``col_class`` is ``None``.

.. versionadded:: 1.1.0

Parameters
Expand All @@ -82,7 +136,15 @@ def plot_rank(df_rank: pd.DataFrame,
Mapping ``group -> color`` (overrides the canonical defaults). Canonical group names
(``substrate``, ``non-substrate``, ``hold-out``) default to the locked sample palette.
threshold : int, float, or list, optional
One or more y-values drawn as horizontal threshold lines.
One or more cutoff values drawn as threshold lines (horizontal in scatter mode,
vertical in ranked-candidates mode).
col_std : str, optional
Column with a per-item standard deviation. When given (only valid together with
``col_class``), symmetric horizontal error bars are drawn on the candidate bars.
col_class : str, optional
Column with the per-item class label. When given, the plot switches to the
ranked-candidates horizontal-bar variant (bars colored by class, candidate names
taken from the DataFrame index); ``col_group`` is then unused.
ax : matplotlib.axes.Axes, optional
Axes to draw on. If ``None``, a new figure and axes are created.
figsize : tuple, default=(7, 5)
Expand Down Expand Up @@ -114,9 +176,24 @@ def plot_rank(df_rank: pd.DataFrame,
# Check input
ut.check_str(name="col_score", val=col_score)
ut.check_str(name="col_group", val=col_group)
ut.check_df(name="df_rank", df=df_rank, cols_required=[col_score, col_group])
ut.check_str(name="col_std", val=col_std, accept_none=True)
ut.check_str(name="col_class", val=col_class, accept_none=True)
if col_std is not None and col_class is None:
raise ValueError("'col_std' (error bars) requires 'col_class' (ranked-candidates mode).")
# Column that carries the class/group labels (col_group for scatter, col_class for bars)
col_groups = col_group if col_class is None else col_class
cols_required = [col_score, col_groups]
if col_std is not None:
cols_required.append(col_std)
ut.check_df(name="df_rank", df=df_rank, cols_required=cols_required)
if len(df_rank) == 0:
raise ValueError("'df_rank' (0 rows) should contain at least one protein.")
if not pd.api.types.is_numeric_dtype(df_rank[col_score]):
raise ValueError(f"'{col_score}' column should be numeric (got dtype "
f"'{df_rank[col_score].dtype}').")
if col_std is not None and not pd.api.types.is_numeric_dtype(df_rank[col_std]):
raise ValueError(f"'{col_std}' column should be numeric (got dtype "
f"'{df_rank[col_std].dtype}').")
ut.check_list_like(name="group_order", val=group_order, accept_none=True)
ut.check_dict_color(name="dict_color", val=dict_color, accept_none=True)
ut.check_ax(ax=ax, accept_none=True)
Expand All @@ -134,22 +211,32 @@ def plot_rank(df_rank: pd.DataFrame,

# Resolve order + colors
if group_order is None:
group_order = list(dict.fromkeys(df_rank[col_group].tolist()))
group_order = list(dict.fromkeys(df_rank[col_groups].tolist()))
else:
missing = set(df_rank[col_group]) - set(group_order)
missing = set(df_rank[col_groups]) - set(group_order)
if missing:
raise ValueError(f"'group_order' is missing groups present in 'df_rank': {missing}")
dict_group_color = _resolve_group_colors(group_order=group_order, dict_color=dict_color)

# Build the ranking (descending score -> rank 1..N on the x-axis)
df = df_rank.sort_values(col_score, ascending=False).reset_index(drop=True)
df["_rank"] = np.arange(1, len(df) + 1)

# Draw
# Create / reuse axes
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure

if col_class is not None:
# Additive ranked-candidates (horizontal-bar) mode
_plot_ranked_candidates(ax=ax, df_rank=df_rank, col_score=col_score, col_class=col_class,
col_std=col_std, group_order=group_order,
dict_group_color=dict_group_color, list_thresholds=list_thresholds,
xlabel=xlabel, ylabel=ylabel, fontsize_labels=fontsize_labels)
# ``ax.figure`` is typed ``Figure | SubFigure | None`` by the matplotlib stubs,
# but a top-level Axes here always belongs to a real Figure.
return fig, ax # pyright: ignore[reportReturnType]

# Default per-protein rank scatter (descending score -> rank 1..N on the x-axis)
df = df_rank.sort_values(col_score, ascending=False).reset_index(drop=True)
df["_rank"] = np.arange(1, len(df) + 1)
for g in group_order:
sub = df[df[col_group] == g]
if len(sub) == 0:
Expand Down
8 changes: 8 additions & 0 deletions docs/source/index/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ Added

**Plotting**

- **plot_rank**: Standalone per-protein max-score-vs-rank scatter with group coloring and
optional threshold lines (pairs with the new ``aa.metrics`` functions). Additively extended
with a ranked-candidates variant (``col_class`` / ``col_std``): named candidates as
horizontal bars colored by class, with optional per-item error bars and vertical cutoff
lines. The default scatter output is unchanged.
- **plot_prediction_hist** (internal): Class-separated histogram of a 0-100 model prediction
score (substrate / non-substrate / unknown separation), with ``[0, 1]`` probabilities
auto-rescaled to a percent axis. Currently internal; a public re-export is pending review.
- :func:`~aaanalysis.plot_rank`: Standalone per-protein max-score-vs-rank scatter with group coloring and
optional threshold lines (pairs with the new ``aa.metrics`` functions).

Expand Down
Loading
Loading