Skip to content
4 changes: 3 additions & 1 deletion aaanalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .explainable_ai import TreeModel
from .protein_design import AAMut, AAMutPlot, SeqMut, SeqMutPlot, SeqOpt, SeqOptPlot
from .plotting import (plot_get_clist, plot_get_cmap, plot_get_cdict,
plot_settings, plot_legend, plot_gcfs, plot_rank)
plot_settings, plot_legend, plot_gcfs, plot_rank,
plot_eval_heatmap)
from .metrics import (comp_auc_adjusted, comp_bic_score, comp_kld,
comp_per_protein_ap, comp_detection_metrics,
comp_bootstrap_ci, comp_smooth_scores)
Expand Down Expand Up @@ -65,6 +66,7 @@
"plot_legend",
"plot_gcfs",
"plot_rank",
"plot_eval_heatmap",
"comp_auc_adjusted",
"comp_bic_score",
"comp_kld",
Expand Down
4 changes: 3 additions & 1 deletion aaanalysis/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Plotting utilities: shared styling, colors, and legends for every ``*Plot`` class.

Public objects: plot_get_clist, plot_get_cmap, plot_get_cdict, plot_settings,
plot_legend, plot_gcfs, plot_rank.
plot_legend, plot_gcfs, plot_rank, plot_eval_heatmap.
A cross-cutting subpackage (not a pipeline stage): ``plot_settings`` sets the house
rcParams, the ``plot_get_*`` helpers supply the color list / map / dict, and
``plot_legend`` / ``plot_gcfs`` / ``plot_rank`` are reused by the paired plot classes
Expand All @@ -18,6 +18,7 @@
from ._plot_gcfs import plot_gcfs
from ._plot_legend import plot_legend
from ._plot_rank import plot_rank
from ._plot_eval_heatmap import plot_eval_heatmap


__all__ = [
Expand All @@ -28,4 +29,5 @@
"plot_legend",
"plot_gcfs",
"plot_rank",
"plot_eval_heatmap",
]
123 changes: 123 additions & 0 deletions aaanalysis/plotting/_plot_eval_heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
This is a script for the frontend of the plot_eval_heatmap function — a house-preset
annotated evaluation heatmap for a static score grid.
"""
from typing import Optional, Union, Tuple
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

from aaanalysis import utils as ut


# I Helper Functions


# II Main Functions
def plot_eval_heatmap(df_eval: pd.DataFrame,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
vmin: Union[int, float] = 50,
vmax: Union[int, float] = 100,
cbar_label: Optional[str] = "Balanced accuracy [%]",
xtick_rotation: Union[int, float] = 0,
ytick_rotation: Union[int, float] = 0,
square: bool = True,
figsize: Optional[Tuple[Union[int, float], Union[int, float]]] = None,
ax: Optional[Axes] = None,
) -> Axes:
"""
Plot a house-preset annotated evaluation heatmap from a static score grid.

Renders ``df_eval`` (rows × columns of evaluation scores, e.g. balanced accuracy in
percent) as a ``viridis`` heatmap with integer annotations, fixed ``[vmin, vmax]``
color limits, and a labeled colorbar — collapsing the hand-built seaborn block that is
otherwise copied for every sweep result into a single call. Left/bottom ticks are
removed for the package look; tick labels are horizontal by default and can be rotated
(e.g. for long column labels) via ``xtick_rotation`` / ``ytick_rotation``.

.. versionadded:: 1.1.0

Parameters
----------
df_eval : pd.DataFrame, shape (n_rows, n_cols)
Numeric score grid. Each cell holds the score for one row × column configuration
(the ``index`` becomes the y-axis levels, the ``columns`` the x-axis levels).
xlabel : str, optional
Label for the x-axis. If ``None``, seaborn's default (the ``columns`` name, if any)
is kept.
ylabel : str, optional
Label for the y-axis. If ``None``, seaborn's default (the ``index`` name, if any)
is kept.
vmin : int or float, default=50
Lower bound of the color scale (and colorbar).
vmax : int or float, default=100
Upper bound of the color scale (and colorbar). Must be greater than ``vmin``.
cbar_label : str, optional
Label drawn next to the colorbar. If ``None``, no colorbar label is set.
xtick_rotation : int or float, default=0
Rotation (degrees) of the x-axis tick labels. Non-zero values are right-aligned to
keep long column labels legible without overlap.
ytick_rotation : int or float, default=0
Rotation (degrees) of the y-axis tick labels.
square : bool, default=True
If ``True``, force each heatmap cell to be equal in width and height (a square grid),
the convention for an evaluation map. Set ``False`` to let the cells stretch to fill
the axes/figure.
figsize : tuple, optional
Figure ``(width, height)`` used when ``ax`` is ``None``. If ``None``, matplotlib's
default figure size is used.
ax : matplotlib.axes.Axes, optional
Axes to draw on. If ``None``, a new figure and axes are created.

Returns
-------
ax : matplotlib.axes.Axes
The axes with the annotated heatmap.

See Also
--------
* :func:`aaanalysis.pipe.plot_eval` is the adaptive sibling: it inspects a
``find_features`` sweep table and lays out one or more panels automatically.
``plot_eval_heatmap`` is the simple **static** entry point — hand it a ready-made
score grid and it applies the house preset, with no axis selection or multi-panel
logic.

Examples
--------
.. include:: examples/plot_eval_heatmap.rst
"""
# Check input
ut.check_df(name="df_eval", df=df_eval, accept_none=False)
if df_eval.empty:
raise ValueError(f"'df_eval' (shape {df_eval.shape}) should contain at least one "
f"row and one column.")
non_numeric = df_eval.select_dtypes(exclude="number").columns.tolist()
if non_numeric:
raise ValueError(f"'df_eval' should be all-numeric; non-numeric columns: {non_numeric}.")
ut.check_str(name="xlabel", val=xlabel, accept_none=True)
ut.check_str(name="ylabel", val=ylabel, accept_none=True)
ut.check_vmin_vmax(vmin=vmin, vmax=vmax)
ut.check_str(name="cbar_label", val=cbar_label, accept_none=True)
ut.check_number_val(name="xtick_rotation", val=xtick_rotation, just_int=False)
ut.check_number_val(name="ytick_rotation", val=ytick_rotation, just_int=False)
ut.check_bool(name="square", val=square)
ut.check_figsize(figsize=figsize, accept_none=True)
ut.check_ax(ax=ax, accept_none=True)

# Draw
if ax is None:
_, ax = plt.subplots(figsize=figsize)
cbar_kws = dict(label=cbar_label) if cbar_label is not None else None
sns.heatmap(df_eval, ax=ax, vmin=vmin, vmax=vmax, cmap="viridis", annot=True,
fmt=".0f", linewidths=0.1, square=square, cbar_kws=cbar_kws)
ax.tick_params(left=False, bottom=False)
x_ha = "right" if xtick_rotation % 360 != 0 else "center"
ax.set_yticklabels(ax.get_yticklabels(), rotation=ytick_rotation)
ax.set_xticklabels(ax.get_xticklabels(), rotation=xtick_rotation, ha=x_ha)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
return ax
1 change: 1 addition & 0 deletions docs/_cheatsheet/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
("BIC score · KL divergence", "comp_bic_score(X, labels) · comp_kld", None),
("Per-protein / detection (v1.1)", "comp_per_protein_ap · comp_detection_metrics", None),
("Plot style, fonts & standalone legend", "plot_settings(font_scale) · plot_legend(ax)", None),
("Evaluation score grid heatmap (v1.1)", "plot_eval_heatmap(df_eval)", None),
]},
{"name": "Protein Design", "tag": "mutations · design",
"under_construction": True,
Expand Down
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ Utility Functions
comp_smooth_scores
display_df
options
plot_eval_heatmap
plot_gcfs
plot_get_cdict
plot_get_clist
Expand Down
4 changes: 4 additions & 0 deletions docs/source/index/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ Added

- :func:`~aaanalysis.plot_rank`: Standalone per-protein max-score-vs-rank scatter with group coloring and
optional threshold lines (pairs with the new ``aa.metrics`` functions).
- **plot_eval_heatmap**: House-preset annotated evaluation heatmap (``viridis``, fixed
``[vmin, vmax]`` color limits, integer annotations, labeled colorbar) for a static score
grid. Collapses the hand-built seaborn block previously copied for every sweep result
into one call; the simple static sibling of the adaptive ``aap.plot_eval``.

**Golden Pipelines**

Expand Down
127 changes: 127 additions & 0 deletions examples/plotting/plot_eval_heatmap.ipynb

Large diffs are not rendered by default.

207 changes: 207 additions & 0 deletions tests/unit/plotting_tests/test_plot_eval_heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""This is a script to test the plot_eval_heatmap() house-preset evaluation heatmap (#310)."""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import pytest

import aaanalysis as aa
from aaanalysis.plotting import plot_eval_heatmap

aa.options["verbose"] = False


# Helper functions
def _df(n_rows=2, n_cols=3, seed=0):
rng = np.random.default_rng(seed)
vals = rng.uniform(55, 95, size=(n_rows, n_cols))
return pd.DataFrame(vals,
index=[f"row{i}" for i in range(n_rows)],
columns=[f"col{j}" for j in range(n_cols)])


@pytest.fixture(autouse=True)
def _close_figs():
yield
plt.close("all")


class TestPlotEvalHeatmap:
"""Normal cases for plot_eval_heatmap."""

def test_returns_ax(self):
ax = plot_eval_heatmap(df_eval=_df())
assert isinstance(ax, plt.Axes)

def test_is_public_export(self):
assert "plot_eval_heatmap" in aa.__all__
assert aa.plot_eval_heatmap is plot_eval_heatmap

def test_draws_on_passed_ax(self):
fig0, ax0 = plt.subplots()
ax = plot_eval_heatmap(df_eval=_df(), ax=ax0)
assert ax is ax0

def test_vmin_vmax_respected(self):
ax = plot_eval_heatmap(df_eval=_df(), vmin=40, vmax=90)
assert ax.collections[0].get_clim() == (40.0, 90.0)

def test_default_vmin_vmax(self):
ax = plot_eval_heatmap(df_eval=_df())
assert ax.collections[0].get_clim() == (50.0, 100.0)

def test_xlabel_ylabel_set(self):
ax = plot_eval_heatmap(df_eval=_df(), xlabel="Scales", ylabel="Parts")
assert ax.get_xlabel() == "Scales" and ax.get_ylabel() == "Parts"

def test_labels_default_none_keeps_seaborn(self):
# With unnamed axes and xlabel/ylabel=None, no custom label is forced.
ax = plot_eval_heatmap(df_eval=_df())
assert ax.get_xlabel() == "" and ax.get_ylabel() == ""

def test_cbar_label_default(self):
ax = plot_eval_heatmap(df_eval=_df())
# The colorbar lives on a sibling axes of the same figure.
cbar_axes = [a for a in ax.figure.axes if a is not ax]
assert any(a.get_ylabel() == "Balanced accuracy [%]" for a in cbar_axes)

def test_cbar_label_custom(self):
ax = plot_eval_heatmap(df_eval=_df(), cbar_label="F1 [%]")
cbar_axes = [a for a in ax.figure.axes if a is not ax]
assert any(a.get_ylabel() == "F1 [%]" for a in cbar_axes)

def test_cbar_label_none_no_colorbar_label(self):
ax = plot_eval_heatmap(df_eval=_df(), cbar_label=None)
cbar_axes = [a for a in ax.figure.axes if a is not ax]
assert all(a.get_ylabel() == "" for a in cbar_axes)

def test_annotation_count_matches_cells(self):
df = _df(n_rows=2, n_cols=3)
ax = plot_eval_heatmap(df_eval=df)
assert len(ax.texts) == df.size

def test_ticklabels_horizontal(self):
ax = plot_eval_heatmap(df_eval=_df())
assert all(t.get_rotation() == 0 for t in ax.get_xticklabels())
assert all(t.get_rotation() == 0 for t in ax.get_yticklabels())

def test_single_cell(self):
ax = plot_eval_heatmap(df_eval=pd.DataFrame([[88.0]]))
assert isinstance(ax, plt.Axes) and len(ax.texts) == 1

def test_xtick_rotation(self):
ax = plot_eval_heatmap(df_eval=_df(), xtick_rotation=45)
assert all(t.get_rotation() == 45 for t in ax.get_xticklabels())
assert all(t.get_ha() == "right" for t in ax.get_xticklabels())

def test_ytick_rotation(self):
ax = plot_eval_heatmap(df_eval=_df(), ytick_rotation=90)
assert all(t.get_rotation() == 90 for t in ax.get_yticklabels())

def test_figsize_respected(self):
ax = plot_eval_heatmap(df_eval=_df(), figsize=(8, 3))
assert tuple(ax.figure.get_size_inches()) == (8.0, 3.0)

def test_nan_cells_render_gracefully(self):
# A sweep table can carry NaN for a config that failed; NaN cells stay blank
# (not annotated) and the call must not raise.
df = pd.DataFrame([[88.0, np.nan], [70.0, 62.0]],
index=["a", "b"], columns=["x", "y"])
ax = plot_eval_heatmap(df_eval=df)
assert isinstance(ax, plt.Axes)
assert len(ax.texts) == 3 # only the three non-NaN cells are annotated

def test_integer_grid_accepted(self):
ax = plot_eval_heatmap(df_eval=pd.DataFrame([[88, 70], [60, 90]]))
assert isinstance(ax, plt.Axes) and len(ax.texts) == 4


class TestPlotEvalHeatmapEquivalence:
"""KPI #310: equivalent to the hand-built seaborn block it consolidates."""

def test_matches_raw_seaborn_block(self):
# The exact block duplicated in gamma-secretase notebook cells 12/25.
df = _df()
fig_raw, ax_raw = plt.subplots()
sns.heatmap(df, ax=ax_raw, vmin=50, vmax=100, cmap="viridis", annot=True,
fmt=".0f", linewidth=0.1, cbar_kws=dict(label="Balanced accuracy [%]"))
ax_raw.tick_params(left=False, bottom=False)
ax_new = plot_eval_heatmap(df_eval=df)
# Same heatmap data, color limits, colormap, and annotations.
raw_mesh, new_mesh = ax_raw.collections[0], ax_new.collections[0]
assert np.allclose(raw_mesh.get_array(), new_mesh.get_array())
assert raw_mesh.get_clim() == new_mesh.get_clim()
assert raw_mesh.get_cmap().name == new_mesh.get_cmap().name == "viridis"
assert ([t.get_text() for t in ax_raw.texts]
== [t.get_text() for t in ax_new.texts])


class TestPlotEvalHeatmapErrors:
"""Negative cases — bad input raises ValueError."""

def test_not_a_dataframe(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval="not a frame")

def test_none_df(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=None)

def test_empty_df(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=pd.DataFrame())

def test_non_numeric_df(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=pd.DataFrame({"a": ["x", "y"], "b": ["u", "v"]}))

def test_vmin_not_below_vmax(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), vmin=100, vmax=50)

def test_bad_xlabel_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), xlabel=123)

def test_bad_ylabel_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), ylabel=123)

def test_bad_cbar_label_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), cbar_label=123)

def test_bad_ytick_rotation_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), ytick_rotation="sideways")

def test_bad_ax_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), ax="not an ax")

def test_bad_xtick_rotation_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), xtick_rotation="sideways")

def test_bad_figsize_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), figsize="wide")


class TestPlotEvalHeatmapSquare:
"""`square` — evaluation-map cells are equal width and height by default."""

def test_square_default_equal_aspect(self):
ax = plot_eval_heatmap(df_eval=_df())
# square=True makes seaborn force a box (data) aspect of 1.0 -> equal cells.
assert ax.get_aspect() in (1.0, "equal")

def test_square_false_not_forced(self):
ax = plot_eval_heatmap(df_eval=_df(), square=False)
assert ax.get_aspect() not in (1.0, "equal")

def test_square_bad_type(self):
with pytest.raises(ValueError):
plot_eval_heatmap(df_eval=_df(), square="yes")
Loading