Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .claude/rules/notebooks.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,14 @@ Rules:

## General

- Cover every public parameter of the demonstrated method (grouped sensibly).
- **Cover every public parameter of the demonstrated method (grouped sensibly).** This is a
hard rule, not a nicety: pass every public param by name in a code cell, related params
grouped into a "further parameters" cell. A minimal "call it with one arg" notebook is
rejected. **Enforced in CI** by `tests/unit/api_tests/test_notebook_param_coverage.py`
(name-based, with a committed `notebook_param_coverage_baseline.txt` backlog ratchet): a
new notebook — or any prediction notebook — must have **zero** undemonstrated params, and
the baseline may only shrink, never grow. When you add a param to a public signature, add
it to the notebook in the same change.
- Set `aa.options["verbose"] = False` at the top to keep outputs clean.
- Commit notebooks **with executed outputs** (tables + images), and re-run
before pushing — a stale/unexecuted notebook is a recurring miss.
Expand Down
22 changes: 15 additions & 7 deletions aaanalysis/prediction/_aa_pred_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ def eval(df_eval: pd.DataFrame,
ylabel: str = "Score",
) -> Tuple[Figure, Axes]:
"""
Grouped bar plot of the model x metric evaluation table.
Grouped bar plot comparing methods across metrics (hue = model).

Bars are grouped by metric and colored by model; cross-validation bars carry error bars
from ``score_std`` and held-out bars (if present) are drawn hatched.
Each metric is a group on the x-axis and each model is a colored bar (the hue), so the
different **methods** are compared side by side. Cross-validation bars carry ``score_std``
error bars and held-out bars are hatched; pass ``baseline`` for a chance line. This is the
plot for **method** comparison — to compare **CPP parameter combinations** (parameter
ranges) instead, use the feature-optimization protocol :func:`aaanalysis.pipe.find_features`
and its evaluation-grid heatmap :func:`aaanalysis.pipe.plot_eval`.

.. versionadded:: 1.1.0

Expand All @@ -88,7 +92,7 @@ def eval(df_eval: pd.DataFrame,
figsize : tuple, default=(7, 5)
Figure size when ``ax`` is ``None``.
dict_color : dict, optional
Mapping ``model -> color``. Defaults to the house categorical palette.
Mapping ``model -> color`` (the bar hue). Defaults to the house categorical palette.
baseline : int or float, optional
If given, a horizontal reference line is drawn at this score (e.g. ``0.5`` for chance).
ylabel : str, default="Score"
Expand All @@ -101,6 +105,11 @@ def eval(df_eval: pd.DataFrame,
ax : matplotlib.axes.Axes
The axes with the grouped bar plot.

See Also
--------
* :func:`aaanalysis.pipe.find_features` and :func:`aaanalysis.pipe.plot_eval` for
comparing CPP parameter combinations (a heatmap over the parameter grid).

Examples
--------
.. include:: examples/aapred_plot_eval.rst
Expand All @@ -114,15 +123,14 @@ def eval(df_eval: pd.DataFrame,
if baseline is not None:
ut.check_number_val(name="baseline", val=baseline, just_int=False)
ut.check_str(name="ylabel", val=ylabel, accept_none=True)
# Resolve layout
# Grouped bar plot: metrics on the x-axis, one hued bar per model
metrics = list(dict.fromkeys(df_eval[ut.COL_METRIC].tolist()))
models = list(dict.fromkeys(df_eval[ut.COL_MODEL].tolist()))
principles = list(dict.fromkeys(df_eval[ut.COL_PRINCIPLE].tolist()))
fig, ax = _new_ax(ax=ax, figsize=figsize)
clist = ut.plot_get_clist_(n_colors=max(len(models), 2))
dict_color = dict(dict_color) if dict_color is not None else {}
dict_model_color = {m: dict_color.get(m, clist[i % len(clist)]) for i, m in enumerate(models)}
# Draw grouped bars
fig, ax = _new_ax(ax=ax, figsize=figsize)
n_groups = len(models) * len(principles)
width = 0.8 / max(n_groups, 1)
x = np.arange(len(metrics))
Expand Down
367 changes: 324 additions & 43 deletions examples/prediction/aapred.ipynb

Large diffs are not rendered by default.

98 changes: 86 additions & 12 deletions examples/prediction/aapred_fit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
"id": "9a3b2221",
"metadata": {
"execution": {
"iopub.execute_input": "2026-07-02T12:04:35.878829Z",
"iopub.status.busy": "2026-07-02T12:04:35.878741Z",
"iopub.status.idle": "2026-07-02T12:04:37.401413Z",
"shell.execute_reply": "2026-07-02T12:04:37.401110Z"
"iopub.execute_input": "2026-07-04T13:09:13.955770Z",
"iopub.status.busy": "2026-07-04T13:09:13.955617Z",
"iopub.status.idle": "2026-07-04T13:09:15.361031Z",
"shell.execute_reply": "2026-07-04T13:09:15.360738Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -50,10 +50,10 @@
"id": "f7672511",
"metadata": {
"execution": {
"iopub.execute_input": "2026-07-02T12:04:37.402828Z",
"iopub.status.busy": "2026-07-02T12:04:37.402743Z",
"iopub.status.idle": "2026-07-02T12:04:37.444450Z",
"shell.execute_reply": "2026-07-02T12:04:37.444216Z"
"iopub.execute_input": "2026-07-04T13:09:15.362317Z",
"iopub.status.busy": "2026-07-04T13:09:15.362254Z",
"iopub.status.idle": "2026-07-04T13:09:15.402484Z",
"shell.execute_reply": "2026-07-04T13:09:15.402274Z"
}
},
"outputs": [
Expand Down Expand Up @@ -85,10 +85,10 @@
"id": "0c93d6ef",
"metadata": {
"execution": {
"iopub.execute_input": "2026-07-02T12:04:37.445548Z",
"iopub.status.busy": "2026-07-02T12:04:37.445481Z",
"iopub.status.idle": "2026-07-02T12:04:37.448901Z",
"shell.execute_reply": "2026-07-02T12:04:37.448697Z"
"iopub.execute_input": "2026-07-04T13:09:15.403635Z",
"iopub.status.busy": "2026-07-04T13:09:15.403573Z",
"iopub.status.idle": "2026-07-04T13:09:15.406814Z",
"shell.execute_reply": "2026-07-04T13:09:15.406619Z"
}
},
"outputs": [
Expand Down Expand Up @@ -123,6 +123,80 @@
"source": [
"The positive class whose probability :meth:`AAPred.predict_proba` returns is set by ``label_pos`` (default=1)."
]
},
{
"cell_type": "markdown",
"id": "f1md0001",
"metadata": {},
"source": [
"**Further parameters.** ``label_pos`` sets which class is treated as positive. Hyperparameters can be tuned per model by ``GridSearchCV``: enable ``optimize_hyperparams``, optionally passing an explicit ``param_grids`` (a single dict is applied to every model) and the number of stratified folds ``n_cv``:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f1cd0001",
"metadata": {
"execution": {
"iopub.execute_input": "2026-07-04T13:09:15.407865Z",
"iopub.status.busy": "2026-07-04T13:09:15.407807Z",
"iopub.status.idle": "2026-07-04T13:09:15.426862Z",
"shell.execute_reply": "2026-07-04T13:09:15.426665Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of fitted models: 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n",
"/Users/stephanbreimann/Programming/1Packages/aaanalysis/.venv/lib/python3.13/site-packages/sklearn/svm/_base.py:239: FutureWarning: The `probability` parameter was deprecated in 1.9 and will be removed in version 1.11. Use `CalibratedClassifierCV(SVC(), ensemble=False)` instead of `SVC(probability=True)`\n",
" warnings.warn(\n"
]
}
],
"source": [
"aapred = aa.AAPred(models=[\"svm\"], random_state=42)\n",
"aapred.fit(X, labels, label_pos=1, optimize_hyperparams=True,\n",
" param_grids={\"C\": [0.1, 1.0, 10.0]}, n_cv=5)\n",
"print(\"Number of fitted models:\", len(aapred.list_models_))"
]
}
],
"metadata": {
Expand Down
57 changes: 49 additions & 8 deletions examples/prediction/aapred_plot_clustermap.ipynb

Large diffs are not rendered by default.

114 changes: 81 additions & 33 deletions examples/prediction/aapred_plot_comparison.ipynb

Large diffs are not rendered by default.

56 changes: 48 additions & 8 deletions examples/prediction/aapred_plot_cutoff.ipynb

Large diffs are not rendered by default.

55 changes: 47 additions & 8 deletions examples/prediction/aapred_plot_domain.ipynb

Large diffs are not rendered by default.

142 changes: 134 additions & 8 deletions examples/prediction/aapred_plot_eval.ipynb

Large diffs are not rendered by default.

57 changes: 49 additions & 8 deletions examples/prediction/aapred_plot_hist.ipynb

Large diffs are not rendered by default.

Loading
Loading