def plot_prediction_hist(df_pred: pd.DataFrame,
                         col_score: str = "score",
                         col_group: str = "group",
                         group_order: Optional[List[str]] = None,
                         dict_color: Optional[Dict[str, str]] = None,
                         binwidth: Union[int, float] = 5,
                         binrange: Tuple[Union[int, float], Union[int, float]] = (0, 100),
                         stacked: bool = True,
                         kde: bool = False,
                         ax: Optional[Axes] = None,
                         figsize: Tuple[Union[int, float], Union[int, float]] = (7, 5),
                         xlabel: str = "Prediction score [%]",
                         ylabel: str = "Count",
                         fontsize_labels: Optional[Union[int, float]] = None,
                         legend: bool = True,
                         ) -> Tuple[Figure, Axes]:
    """
    Draw a per-class histogram of a per-sample prediction score (internal).

    The score column is binned over ``binrange`` in steps of ``binwidth`` and the bars are
    split by the class label in ``col_group`` (via :func:`seaborn.histplot`), so the
    separation between classes such as substrate / non-substrate / hold-out can be read off
    directly.

    When the largest score is ``<= 1`` the column is interpreted as a ``[0, 1]`` probability
    and multiplied by 100 before plotting; a column whose maximum is ``> 1`` is plotted as
    given. A notice about the rescale is printed only if ``ut.check_verbose(False)`` is
    truthy. The input frame itself is never modified (a copy is plotted).

    This function is **internal**: it is not part of the public ``aaanalysis`` namespace and
    may change without a deprecation cycle.

    Parameters
    ----------
    df_pred : pd.DataFrame, shape (n_samples, n_info)
        One row per sample; must contain ``col_score`` and ``col_group``.
    col_score : str, default="score"
        Name of the numeric column holding the prediction score (``[0, 1]`` or ``[0, 100]``).
    col_group : str, default="group"
        Name of the column holding the class label used to color / split the bars.
    group_order : list of str, optional
        Order in which classes are colored / stacked. If ``None``, classes are taken in
        order of first appearance in ``df_pred``.
    dict_color : dict, optional
        Mapping ``group -> color`` that overrides the canonical defaults. Canonical class
        names (``substrate``, ``non-substrate``, ``hold-out``) default to the locked sample
        palette.
    binwidth : int or float, default=5
        Width of the histogram bins (in score units).
    binrange : tuple, default=(0, 100)
        ``(low, high)`` range over which bins are computed; also applied as x-axis limits.
    stacked : bool, default=True
        ``True`` stacks the class histograms (``multiple="stack"``); ``False`` overlays
        them (``multiple="layer"``).
    kde : bool, default=False
        If ``True``, a kernel-density estimate is requested from seaborn.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If ``None``, a new figure and axes are created.
    figsize : tuple, default=(7, 5)
        Figure size used when ``ax`` is ``None``.
    xlabel, ylabel : str
        Axis labels.
    fontsize_labels : int or float, optional
        Font size passed to both axis labels (``None`` keeps the matplotlib default).
    legend : bool, default=True
        Whether seaborn should draw the class legend.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure (``ax.figure`` when an ``ax`` was passed in).
    ax : matplotlib.axes.Axes
        The axes holding the class-separated histogram.

    Raises
    ------
    ValueError
        If ``df_pred`` has no rows, if ``col_score`` is not numeric or has no finite value,
        or if ``group_order`` omits a group present in ``df_pred``. The ``ut.check_*``
        validators called first may raise as well.

    See Also
    --------
    * :func:`aaanalysis.plot_rank` for the companion ranked-candidates view.

    Examples
    --------
    .. code-block:: python

        import pandas as pd
        from aaanalysis.plotting._plot_prediction_hist import plot_prediction_hist

        df_pred = pd.DataFrame({"score": [95, 80, 60, 40, 20, 5],
                                "group": ["substrate", "substrate", "hold-out",
                                          "non-substrate", "non-substrate", "non-substrate"]})
        fig, ax = plot_prediction_hist(df_pred, col_score="score", col_group="group")

    .. note::
        This symbol is internal; an example notebook and a public re-export are tracked as a
        TODO (re-export under :mod:`aaanalysis` is CONFIRM-FIRST, pending maintainer review).
    """
    # --- Validate the frame and its score column ---------------------------------------
    ut.check_str(name="col_score", val=col_score)
    ut.check_str(name="col_group", val=col_group)
    ut.check_df(name="df_pred", df=df_pred, cols_required=[col_score, col_group])
    if not len(df_pred):
        raise ValueError("'df_pred' (0 rows) should contain at least one sample.")
    score_dtype = df_pred[col_score].dtype
    if not pd.api.types.is_numeric_dtype(df_pred[col_score]):
        raise ValueError(f"'{col_score}' column should be numeric (got dtype '{score_dtype}').")
    has_finite = np.isfinite(df_pred[col_score].to_numpy(dtype=float)).any()
    if not has_finite:
        raise ValueError(f"'{col_score}' column has no finite values to plot.")

    # --- Validate the remaining (mostly cosmetic) arguments ----------------------------
    ut.check_list_like(name="group_order", val=group_order, accept_none=True)
    ut.check_dict_color(name="dict_color", val=dict_color, accept_none=True)
    ut.check_number_range(name="binwidth", val=binwidth, min_val=0, exclusive_limits=True,
                          just_int=False)
    ut.check_lim(name="binrange", val=binrange, accept_none=False)
    ut.check_bool(name="stacked", val=stacked)
    ut.check_bool(name="kde", val=kde)
    ut.check_bool(name="legend", val=legend)
    ut.check_ax(ax=ax, accept_none=True)
    ut.check_figsize(figsize=figsize, accept_none=True)
    ut.check_str(name="xlabel", val=xlabel, accept_none=True)
    ut.check_str(name="ylabel", val=ylabel, accept_none=True)
    ut.check_number_range(name="fontsize_labels", val=fontsize_labels, min_val=0,
                          accept_none=True, just_int=False)

    # --- Class order and colors --------------------------------------------------------
    if group_order is not None:
        # An explicit order must cover every class that actually occurs in the data
        missing = set(df_pred[col_group]) - set(group_order)
        if missing:
            raise ValueError(f"'group_order' is missing groups present in 'df_pred': {missing}")
    else:
        # dict.fromkeys de-duplicates while keeping first-appearance order
        group_order = list(dict.fromkeys(df_pred[col_group].tolist()))
    dict_group_color = _resolve_group_colors(group_order=group_order, dict_color=dict_color)

    # --- Optional [0, 1] -> [0, 100] rescale on a private copy -------------------------
    df_plot = df_pred.copy()
    score_values = df_plot[col_score].to_numpy(dtype=float)
    looks_like_probability = np.isfinite(score_values).any() and np.nanmax(score_values) <= 1
    if looks_like_probability:
        df_plot[col_score] = score_values * 100
        # NOTE(review): presumably check_verbose(False) lets a global verbosity option
        # override the local default of "off" — confirm against ut.check_verbose.
        if ut.check_verbose(False):
            ut.print_out(f"Note: '{col_score}' (max <= 1) looks like a [0, 1] probability; "
                         f"rescaled to a [0, 100] % axis.")

    # --- Draw --------------------------------------------------------------------------
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=figsize)
    sns.histplot(data=df_plot, x=col_score, hue=col_group, hue_order=group_order,
                 palette=dict_group_color, binwidth=binwidth, binrange=tuple(binrange),
                 multiple="stack" if stacked else "layer", kde=kde, legend=legend, ax=ax)
    x_low, x_high = binrange[0], binrange[1]
    ax.set_xlim(x_low, x_high)
    ax.set_xlabel(xlabel, fontsize=fontsize_labels)
    ax.set_ylabel(ylabel, fontsize=fontsize_labels)
    # Counts are integers: force integer y-ticks with matplotlib's own nice spacing
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    sns.despine(ax=ax)
    # ``ax.figure`` is typed ``Figure | SubFigure | None`` by the matplotlib stubs,
    # but a top-level Axes here always belongs to a real Figure.
    return fig, ax  # pyright: ignore[reportReturnType]
def _plot_ranked_candidates(ax, df_rank, col_score, col_class, col_std, group_order,
                            dict_group_color, list_thresholds, xlabel, ylabel,
                            fontsize_labels):
    """Draw the ranked-candidates horizontal-bar variant (port of plot_pred3_top_hits).

    Each row of ``df_rank`` becomes one horizontal bar whose length is its ``col_score``
    value and whose color is looked up in ``dict_group_color`` by its ``col_class`` label.
    The stringified DataFrame index supplies the y-tick labels. Rows are ordered by class
    position in ``group_order`` (descending) and then by ascending score. If ``col_std`` is
    given, symmetric horizontal error bars are added; every value in ``list_thresholds`` is
    drawn as a dashed grey vertical line. The class legend is anchored outside the axes on
    the right. Draws on ``ax`` in place and returns ``None``.
    """
    # Primary sort key: position of each class within group_order
    class_position = {label: pos for pos, label in enumerate(group_order)}
    df_bars = df_rank.copy()
    df_bars["_name"] = df_bars.index.astype(str)  # keep candidate names before the index is reset
    df_bars["_sorter"] = df_bars[col_class].map(class_position)
    df_bars = df_bars.sort_values(["_sorter", col_score], ascending=[False, True])
    df_bars = df_bars.reset_index(drop=True)

    # Bars (one per candidate), colored by class
    n_bars = len(df_bars)
    y_pos = np.arange(n_bars)
    bar_lengths = df_bars[col_score].to_numpy()
    bar_colors = [dict_group_color[label] for label in df_bars[col_class]]
    ax.barh(y_pos, bar_lengths, color=bar_colors)

    # Optional per-candidate error bars centred on the bar tips
    if col_std is not None:
        ax.errorbar(x=bar_lengths, y=y_pos, xerr=df_bars[col_std].to_numpy(),
                    fmt="none", ecolor="black", capsize=3, capthick=1, elinewidth=1)

    # Candidate names on the y-axis, no tick marks, score axis anchored at zero
    ax.set_yticks(y_pos)
    ax.set_yticklabels(df_bars["_name"].tolist())
    ax.tick_params(length=0, axis="y")
    ax.set_ylim(-0.75, n_bars)
    ax.set_xlim(left=0)

    # Cutoffs run vertically because the score lives on the x-axis in this mode
    for cutoff in list_thresholds:
        ax.axvline(cutoff, color="grey", linestyle="--", linewidth=1)

    # Substitute bar-appropriate axis labels only when the caller kept the scatter defaults
    if xlabel == _DEFAULT_XLABEL:
        label_x = "Prediction score"
    else:
        label_x = xlabel
    if ylabel == _DEFAULT_YLABEL:
        label_y = ""
    else:
        label_y = ylabel
    ax.set_xlabel(label_x, fontsize=fontsize_labels)
    ax.set_ylabel(label_y, fontsize=fontsize_labels)

    # Legend lists only classes that actually occur, in group_order
    classes_present = set(df_bars[col_class])
    handles = []
    for group in group_order:
        if group in classes_present:
            handles.append(Patch(color=dict_group_color[group], label=str(group)))
    # Horizontal bars fill the axes; keep the class legend outside (right) so it never
    # overlaps the top candidate bars.
    ax.legend(handles=handles, loc="upper left", bbox_to_anchor=(1.01, 1.0), frameon=False)
    sns.despine(ax=ax)
When given, the plot switches to the + ranked-candidates horizontal-bar variant (bars colored by class, candidate names + taken from the DataFrame index); ``col_group`` is then unused. ax : matplotlib.axes.Axes, optional Axes to draw on. If ``None``, a new figure and axes are created. figsize : tuple, default=(7, 5) @@ -114,9 +176,24 @@ def plot_rank(df_rank: pd.DataFrame, # Check input ut.check_str(name="col_score", val=col_score) ut.check_str(name="col_group", val=col_group) - ut.check_df(name="df_rank", df=df_rank, cols_required=[col_score, col_group]) + ut.check_str(name="col_std", val=col_std, accept_none=True) + ut.check_str(name="col_class", val=col_class, accept_none=True) + if col_std is not None and col_class is None: + raise ValueError("'col_std' (error bars) requires 'col_class' (ranked-candidates mode).") + # Column that carries the class/group labels (col_group for scatter, col_class for bars) + col_groups = col_group if col_class is None else col_class + cols_required = [col_score, col_groups] + if col_std is not None: + cols_required.append(col_std) + ut.check_df(name="df_rank", df=df_rank, cols_required=cols_required) if len(df_rank) == 0: raise ValueError("'df_rank' (0 rows) should contain at least one protein.") + if not pd.api.types.is_numeric_dtype(df_rank[col_score]): + raise ValueError(f"'{col_score}' column should be numeric (got dtype " + f"'{df_rank[col_score].dtype}').") + if col_std is not None and not pd.api.types.is_numeric_dtype(df_rank[col_std]): + raise ValueError(f"'{col_std}' column should be numeric (got dtype " + f"'{df_rank[col_std].dtype}').") ut.check_list_like(name="group_order", val=group_order, accept_none=True) ut.check_dict_color(name="dict_color", val=dict_color, accept_none=True) ut.check_ax(ax=ax, accept_none=True) @@ -134,22 +211,32 @@ def plot_rank(df_rank: pd.DataFrame, # Resolve order + colors if group_order is None: - group_order = list(dict.fromkeys(df_rank[col_group].tolist())) + group_order = 
list(dict.fromkeys(df_rank[col_groups].tolist())) else: - missing = set(df_rank[col_group]) - set(group_order) + missing = set(df_rank[col_groups]) - set(group_order) if missing: raise ValueError(f"'group_order' is missing groups present in 'df_rank': {missing}") dict_group_color = _resolve_group_colors(group_order=group_order, dict_color=dict_color) - # Build the ranking (descending score -> rank 1..N on the x-axis) - df = df_rank.sort_values(col_score, ascending=False).reset_index(drop=True) - df["_rank"] = np.arange(1, len(df) + 1) - - # Draw + # Create / reuse axes if ax is None: fig, ax = plt.subplots(figsize=figsize) else: fig = ax.figure + + if col_class is not None: + # Additive ranked-candidates (horizontal-bar) mode + _plot_ranked_candidates(ax=ax, df_rank=df_rank, col_score=col_score, col_class=col_class, + col_std=col_std, group_order=group_order, + dict_group_color=dict_group_color, list_thresholds=list_thresholds, + xlabel=xlabel, ylabel=ylabel, fontsize_labels=fontsize_labels) + # ``ax.figure`` is typed ``Figure | SubFigure | None`` by the matplotlib stubs, + # but a top-level Axes here always belongs to a real Figure. + return fig, ax # pyright: ignore[reportReturnType] + + # Default per-protein rank scatter (descending score -> rank 1..N on the x-axis) + df = df_rank.sort_values(col_score, ascending=False).reset_index(drop=True) + df["_rank"] = np.arange(1, len(df) + 1) for g in group_order: sub = df[df[col_group] == g] if len(sub) == 0: diff --git a/docs/source/index/release_notes.rst b/docs/source/index/release_notes.rst index 89fa7d74..3339edec 100644 --- a/docs/source/index/release_notes.rst +++ b/docs/source/index/release_notes.rst @@ -189,6 +189,14 @@ Added **Plotting** +- **plot_rank**: Standalone per-protein max-score-vs-rank scatter with group coloring and + optional threshold lines (pairs with the new ``aa.metrics`` functions). 
# Helper functions
def _df(n_sub=10, n_hold=5, n_non=15, seed=0, scale=100.0):
    """Build a seeded toy frame with a uniform random 'score' in [0, scale) and a 'group' label."""
    generator = np.random.default_rng(seed)
    labels = ["substrate"] * n_sub + ["hold-out"] * n_hold + ["non-substrate"] * n_non
    values = generator.random(len(labels)) * scale
    return pd.DataFrame({"score": values, "group": labels})


def _right_edges(ax):
    """Right x-edge of every drawn bar that has a non-zero height."""
    return [patch.get_x() + patch.get_width() for patch in ax.patches if patch.get_height() > 0]


def _face_rgb(ax):
    """Set of bar face colors as RGB tuples rounded to 4 decimals."""
    return {tuple(np.round(patch.get_facecolor()[:3], 4)) for patch in ax.patches}


def _raises_value_error(df_pred=None, **kwargs):
    """Assert that plot_prediction_hist raises ValueError (default frame unless one is given)."""
    with pytest.raises(ValueError):
        plot_prediction_hist(df_pred=_df() if df_pred is None else df_pred, **kwargs)


@pytest.fixture(autouse=True)
def _close_figs():
    """Close every figure after each test so figures do not accumulate."""
    yield
    plt.close("all")


class TestPlotPredictionHist:
    """Normal cases for plot_prediction_hist."""

    def test_returns_fig_ax(self):
        fig, ax = plot_prediction_hist(df_pred=_df())
        assert isinstance(fig, plt.Figure)
        assert isinstance(ax, plt.Axes)

    def test_returns_axes(self):
        result = plot_prediction_hist(df_pred=_df())
        assert isinstance(result[1], plt.Axes)

    def test_draws_bars(self):
        _, ax = plot_prediction_hist(df_pred=_df())
        assert len(ax.patches) > 0

    def test_xlim_matches_binrange(self):
        _, ax = plot_prediction_hist(df_pred=_df(), binrange=(0, 100))
        assert ax.get_xlim() == (0.0, 100.0)

    def test_custom_binrange_applied(self):
        _, ax = plot_prediction_hist(df_pred=_df(), binrange=(0, 50))
        assert ax.get_xlim() == (0.0, 50.0)

    def test_auto_rescale_probability_to_percent(self):
        # scores in [0, 1] should map onto the [0, 100] axis (bars beyond x=1)
        _, ax = plot_prediction_hist(df_pred=_df(scale=1.0))
        assert max(_right_edges(ax)) > 1.0

    def test_percent_scores_not_rescaled(self):
        # An already-percent score (max > 1) must be left untouched (no x100 blow-up).
        df_percent = pd.DataFrame({"score": [2.0, 40.0, 95.0], "group": ["a", "a", "a"]})
        _, ax = plot_prediction_hist(df_pred=df_percent, binrange=(0, 100))
        assert max(_right_edges(ax)) <= 100.0  # would be 9500 if wrongly rescaled

    def test_stacked_true_uses_stack(self):
        _, ax = plot_prediction_hist(df_pred=_df(), stacked=True)
        assert len(ax.patches) > 0

    def test_layer_mode_runs(self):
        _, ax = plot_prediction_hist(df_pred=_df(), stacked=False)
        assert len(ax.patches) > 0

    def test_kde_adds_lines(self):
        _, ax = plot_prediction_hist(df_pred=_df(), kde=True)
        assert len(ax.get_lines()) > 0

    def test_no_kde_no_lines(self):
        _, ax = plot_prediction_hist(df_pred=_df(), kde=False)
        assert len(ax.get_lines()) == 0

    def test_legend_drawn(self):
        _, ax = plot_prediction_hist(df_pred=_df(), legend=True)
        assert ax.get_legend() is not None

    def test_legend_off(self):
        _, ax = plot_prediction_hist(df_pred=_df(), legend=False)
        assert ax.get_legend() is None

    def test_custom_columns(self):
        df_renamed = _df().rename(columns={"score": "s", "group": "g"})
        _, ax = plot_prediction_hist(df_pred=df_renamed, col_score="s", col_group="g")
        assert len(ax.patches) > 0

    def test_draws_on_passed_ax(self):
        fig_given, ax_given = plt.subplots()
        fig, ax = plot_prediction_hist(df_pred=_df(), ax=ax_given)
        assert ax is ax_given
        assert fig is fig_given

    def test_figsize_applied(self):
        fig, _ = plot_prediction_hist(df_pred=_df(), figsize=(10, 4))
        assert tuple(fig.get_size_inches()) == (10.0, 4.0)

    def test_labels_custom(self):
        _, ax = plot_prediction_hist(df_pred=_df(), xlabel="Score X", ylabel="N")
        assert ax.get_xlabel() == "Score X"
        assert ax.get_ylabel() == "N"

    def test_fontsize_labels_applied(self):
        _, ax = plot_prediction_hist(df_pred=_df(), fontsize_labels=15)
        assert ax.xaxis.label.get_size() == 15

    def test_canonical_substrate_color_is_green(self):
        _, ax = plot_prediction_hist(df_pred=_df(),
                                     group_order=["substrate", "hold-out", "non-substrate"])
        expected_green = tuple(np.round(matplotlib.colors.to_rgb(ut.COLOR_POS), 4))
        assert expected_green in _face_rgb(ax)

    def test_custom_dict_color_applied(self):
        palette = {"substrate": "#000000", "non-substrate": "#ffffff"}
        _, ax = plot_prediction_hist(df_pred=_df(n_sub=5, n_hold=0, n_non=5),
                                     group_order=["substrate", "non-substrate"],
                                     dict_color=palette)
        assert (0.0, 0.0, 0.0) in _face_rgb(ax)


class TestPlotPredictionHistComplex:
    """Negative cases and combinations."""

    def test_missing_score_col_raises(self):
        _raises_value_error(df_pred=_df().drop(columns=["score"]))

    def test_missing_group_col_raises(self):
        _raises_value_error(df_pred=_df().drop(columns=["group"]))

    def test_empty_df_raises(self):
        _raises_value_error(df_pred=_df(n_sub=0, n_hold=0, n_non=0))

    def test_bad_score_name_raises(self):
        _raises_value_error(col_score="nope")

    def test_bad_group_name_raises(self):
        _raises_value_error(col_group="nope")

    def test_group_order_missing_group_raises(self):
        _raises_value_error(group_order=["substrate"])

    def test_binwidth_zero_raises(self):
        _raises_value_error(binwidth=0)

    def test_binwidth_negative_raises(self):
        _raises_value_error(binwidth=-5)

    @pytest.mark.parametrize("bad", [(0,), (0, 50, 100), "rng"])
    def test_binrange_wrong_shape_raises(self, bad):
        _raises_value_error(binrange=bad)

    def test_binrange_low_ge_high_raises(self):
        _raises_value_error(binrange=(100, 0))

    def test_stacked_wrong_type_raises(self):
        _raises_value_error(stacked="yes")

    def test_kde_wrong_type_raises(self):
        _raises_value_error(kde="yes")

    def test_legend_wrong_type_raises(self):
        _raises_value_error(legend="yes")

    def test_figsize_wrong_type_raises(self):
        _raises_value_error(figsize="big")

    def test_xlabel_wrong_type_raises(self):
        _raises_value_error(xlabel=123)

    def test_fontsize_labels_wrong_type_raises(self):
        _raises_value_error(fontsize_labels="big")

    def test_ax_wrong_type_raises(self):
        _raises_value_error(ax="not_an_axes")

    def test_dict_color_wrong_type_raises(self):
        _raises_value_error(dict_color="red")

    def test_group_order_wrong_type_raises(self):
        _raises_value_error(group_order="substrate")

    def test_single_group(self):
        df_one = pd.DataFrame({"score": [10, 50, 90], "group": ["a", "a", "a"]})
        _, ax = plot_prediction_hist(df_pred=df_one)
        assert len(ax.patches) > 0

    def test_non_numeric_score_raises(self):
        # A non-numeric score column must fail loudly, not be silently coerced to NaN.
        df_text = pd.DataFrame({"score": ["high", "low", "mid"], "group": ["a", "a", "a"]})
        _raises_value_error(df_pred=df_text)

    def test_all_nan_score_raises(self):
        # An all-NaN score column must fail with a clear message, not a cryptic seaborn error.
        df_nan = pd.DataFrame({"score": [np.nan, np.nan], "group": ["a", "b"]})
        _raises_value_error(df_pred=df_nan)

    def test_integer_yticks(self):
        # Counts are integers; every y-tick within the data range must be a whole number.
        _, ax = plot_prediction_hist(df_pred=_df())
        y_top = ax.get_ylim()[1]
        ticks_in_range = [tick for tick in ax.get_yticks() if 0 <= tick <= y_top]
        assert all(float(tick).is_integer() for tick in ticks_in_range)
# Helper for the additive ranked-candidates (bar) mode
def _df_cand(n_sub=3, n_non=3, seed=0):
    """Build a seeded candidate frame (score, std, class) indexed by GENE<i> names."""
    rng = np.random.default_rng(seed)
    n_total = n_sub + n_non
    # Draw order (score first, then std) is part of the fixture: keep it stable
    col_score_values = rng.uniform(0, 100, n_total)
    col_std_values = rng.uniform(1, 10, n_total)
    classes = ["substrate"] * n_sub + ["non-substrate"] * n_non
    gene_names = [f"GENE{i}" for i in range(n_total)]
    return pd.DataFrame({"score": col_score_values, "std": col_std_values, "class": classes},
                        index=gene_names)


def _rounded_offsets(collection):
    """(x, y) offsets of a scatter collection, rounded to 6 decimals, as a list of tuples."""
    return [tuple(np.round(point, 6)) for point in collection.get_offsets()]


class TestPlotRankCandidatesMode:
    """The additive col_class / col_std ranked-candidates (horizontal-bar) variant."""

    def test_col_class_switches_to_bars(self):
        _, ax = plot_rank(df_rank=_df_cand(), col_score="score", col_class="class")
        # barh produces a BarContainer (no scatter collections in this mode)
        assert len(ax.containers) >= 1
        assert len(ax.collections) == 0

    def test_yticklabels_are_candidate_names(self):
        df_cand = _df_cand()
        _, ax = plot_rank(df_rank=df_cand, col_score="score", col_class="class")
        tick_texts = {tick.get_text() for tick in ax.get_yticklabels()}
        assert tick_texts == set(df_cand.index.astype(str))

    def test_col_std_draws_error_bars(self):
        df_cand = _df_cand()
        _, ax = plot_rank(df_rank=df_cand, col_score="score", col_class="class",
                          col_std="std")
        # errorbar adds an extra container beyond the BarContainer
        assert len(ax.containers) >= 2

    def test_threshold_draws_vertical_line(self):
        _, ax = plot_rank(df_rank=_df_cand(), col_score="score", col_class="class",
                          threshold=50)
        n_dashed = sum(1 for line in ax.get_lines() if line.get_linestyle() == "--")
        assert n_dashed == 1

    def test_bar_colors_follow_class(self):
        df_cand = _df_cand(n_sub=2, n_non=2)
        _, ax = plot_rank(df_rank=df_cand, col_score="score", col_class="class",
                          group_order=["substrate", "non-substrate"])
        expected_green = tuple(np.round(matplotlib.colors.to_rgb(ut.COLOR_POS), 4))
        drawn = {tuple(np.round(patch.get_facecolor()[:3], 4)) for patch in ax.patches}
        assert expected_green in drawn

    def test_default_labels_substituted_in_bar_mode(self):
        # scatter defaults ("Protein rank") must not leak onto the score axis
        _, ax = plot_rank(df_rank=_df_cand(), col_score="score", col_class="class")
        assert ax.get_xlabel() == "Prediction score"
        assert ax.get_ylabel() == ""

    def test_explicit_labels_respected_in_bar_mode(self):
        _, ax = plot_rank(df_rank=_df_cand(), col_score="score", col_class="class",
                          xlabel="My score", ylabel="My genes")
        assert ax.get_xlabel() == "My score"
        assert ax.get_ylabel() == "My genes"

    def test_col_std_without_col_class_raises(self):
        with pytest.raises(ValueError):
            plot_rank(df_rank=_df_cand(), col_score="score", col_std="std")

    def test_missing_col_class_col_raises(self):
        with pytest.raises(ValueError):
            plot_rank(df_rank=_df_cand(), col_score="score", col_class="nope")

    def test_missing_col_std_col_raises(self):
        with pytest.raises(ValueError):
            plot_rank(df_rank=_df_cand(), col_score="score", col_class="class",
                      col_std="nope")

    def test_col_class_wrong_type_raises(self):
        with pytest.raises(ValueError):
            plot_rank(df_rank=_df_cand(), col_score="score", col_class=123)

    def test_col_std_wrong_type_raises(self):
        with pytest.raises(ValueError):
            plot_rank(df_rank=_df_cand(), col_score="score", col_class="class", col_std=123)

    def test_non_numeric_score_bar_mode_raises(self):
        df_text = pd.DataFrame({"score": ["hi", "lo"], "class": ["substrate", "non-substrate"]},
                               index=["A", "B"])
        with pytest.raises(ValueError):
            plot_rank(df_rank=df_text, col_score="score", col_class="class")

    def test_non_numeric_std_raises(self):
        df_cand = _df_cand()
        df_cand["std"] = df_cand["std"].astype(str)
        with pytest.raises(ValueError):
            plot_rank(df_rank=df_cand, col_score="score", col_class="class", col_std="std")

    def test_draws_on_passed_ax_bar_mode(self):
        fig_given, ax_given = plt.subplots()
        fig, ax = plot_rank(df_rank=_df_cand(), col_score="score", col_class="class",
                            ax=ax_given)
        assert ax is ax_given
        assert fig is fig_given

    def test_bar_legend_placed_outside_axes(self):
        # The class legend must sit outside the plot area (right) so it never overlaps bars.
        fig, ax = plot_rank(df_rank=_df_cand(), col_score="score", col_class="class")
        fig.canvas.draw()  # extents are only meaningful after a draw
        legend = ax.get_legend()
        assert legend is not None
        assert not legend.get_frame_on()
        legend_left = legend.get_window_extent().x0
        axes_right = ax.get_window_extent().x1
        assert legend_left >= axes_right  # legend starts at/after the axes' right edge


class TestPlotRankDefaultRegression:
    """Guard: the default scatter path stays identical to the pre-``col_class`` output.

    The expected values below are FROZEN from the scatter implementation before the additive
    ranked-candidates mode was added (``_df()`` seed=0, threshold=[0.5, 0.8]). They are NOT
    recomputed from the current code, so any change to the scatter branch (sort order, ranking,
    group->color mapping, or threshold drawing) makes these assertions fail.
    """

    # Golden snapshot of the first drawn group ("substrate") as (rank, score) pairs.
    GOLDEN_SUBSTRATE_OFFSETS = [
        (3.0, 0.935072), (4.0, 0.912756), (8.0, 0.81327), (10.0, 0.729497),
        (15.0, 0.636962), (17.0, 0.606636), (18.0, 0.543625), (23.0, 0.269787),
        (26.0, 0.040974), (29.0, 0.016528),
    ]
    GOLDEN_LABELS = ["substrate", "hold-out", "non-substrate"]
    GOLDEN_SIZES = [10, 5, 15]

    def test_scatter_offsets_frozen(self):
        # Exact ranked (rank, score) coordinates of the substrate group must not drift.
        _, ax = plot_rank(df_rank=_df(), threshold=[0.5, 0.8])
        golden = [tuple(np.round(pair, 6)) for pair in self.GOLDEN_SUBSTRATE_OFFSETS]
        assert _rounded_offsets(ax.collections[0]) == golden

    def test_scatter_groups_labels_and_sizes_frozen(self):
        _, ax = plot_rank(df_rank=_df(), threshold=[0.5, 0.8])
        legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
        group_sizes = [len(coll.get_offsets()) for coll in ax.collections]
        assert legend_labels == self.GOLDEN_LABELS
        assert group_sizes == self.GOLDEN_SIZES

    def test_scatter_group_colors_frozen(self):
        _, ax = plot_rank(df_rank=_df(),
                          group_order=["substrate", "hold-out", "non-substrate"])
        expected = [ut.COLOR_POS, ut.COLOR_REL_NEG, ut.COLOR_NEG]  # green / brown / magenta
        for collection, color in zip(ax.collections, expected):
            drawn_rgb = collection.get_facecolor()[0][:3]
            assert np.allclose(drawn_rgb, matplotlib.colors.to_rgb(color))

    def test_scatter_thresholds_are_horizontal_frozen(self):
        # Scatter mode draws thresholds as HORIZONTAL dashed lines at the given y-values.
        _, ax = plot_rank(df_rank=_df(), threshold=[0.5, 0.8])
        drawn = sorted((round(float(line.get_ydata()[0]), 6), line.get_linestyle())
                       for line in ax.get_lines())
        assert drawn == [(0.5, "--"), (0.8, "--")]
        # each threshold line is flat (constant y) -> truly horizontal
        for line in ax.get_lines():
            assert len(set(np.round(line.get_ydata(), 6))) == 1

    def test_new_args_default_dont_touch_scatter(self):
        # Passing the new args at their explicit defaults yields exactly the golden scatter.
        _, ax = plot_rank(df_rank=_df(), threshold=[0.5, 0.8], col_std=None, col_class=None)
        golden = [tuple(np.round(pair, 6)) for pair in self.GOLDEN_SUBSTRATE_OFFSETS]
        assert _rounded_offsets(ax.collections[0]) == golden
        assert len(ax.patches) == 0  # scatter path, not the bar path

    def test_default_axis_labels_unchanged(self):
        _, ax = plot_rank(df_rank=_df())
        assert ax.get_xlabel() == "Protein rank"
        assert ax.get_ylabel() == "Max score per protein"