Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
b23ed2c
feat(prediction): AAPred + AAPredPlot — evaluate & deploy prediction …
breimanntools Jul 1, 2026
64b06d7
feat(prediction): sequence-level prediction (seq / domain / window) +…
breimanntools Jul 2, 2026
a493ef7
feat(prediction): shared model registry (get_cv_model_)
breimanntools Jul 2, 2026
6659b7d
feat(prediction): AAPred models= (names/instances) + GridSearchCV tuning
breimanntools Jul 2, 2026
a7476d5
Merge remote-tracking branch 'origin/master' into feat/prediction-class
breimanntools Jul 2, 2026
bf2cd2c
refactor(prediction): featurize_seq via feature_matrix(df_seq=, df_pa…
breimanntools Jul 2, 2026
d0cd9a7
docs(explainable_ai): frame TreeModel as global importance; point to …
breimanntools Jul 2, 2026
d2b544d
feat(prediction): AAPredPlot.comparison — grouped method x condition …
breimanntools Jul 2, 2026
663f223
feat(prediction): AAPredPlot.ranking — ranked-candidate bars (class c…
breimanntools Jul 2, 2026
a851c76
feat(prediction): AAPredPlot.clustermap + 3-type figure grouping
breimanntools Jul 2, 2026
4987cb3
fix(prediction): AAPredPlot new methods annotate -> Tuple[Figure, Axes]
breimanntools Jul 2, 2026
fb5be91
fix(prediction): clone configured estimators (voting/stacking/xgboost…
breimanntools Jul 2, 2026
a4307cc
simplify(prediction): trim model registry to the 4 established families
breimanntools Jul 2, 2026
63a7030
docs(prediction): re-execute example notebooks after the models= / cl…
breimanntools Jul 2, 2026
09892c1
Merge remote-tracking branch 'origin/master' into feat/prediction-class
breimanntools Jul 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aaanalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .feature_engineering import AAclust, AAclustPlot, SequenceFeature, NumericalFeature, CPP, CPPGrid, CPPPlot
from .pu_learning import dPULearn, dPULearnPlot
from .explainable_ai import TreeModel
from .prediction import AAPred, AAPredPlot
from .protein_engineering import AAMut, AAMutPlot, SeqMut, SeqMutPlot, SeqOpt, SeqOptPlot
from .plotting import (plot_get_clist, plot_get_cmap, plot_get_cdict,
plot_settings, plot_legend, plot_gcfs, plot_rank)
Expand Down Expand Up @@ -60,6 +61,8 @@
"SeqOpt",
"SeqOptPlot",
"TreeModel",
"AAPred",
"AAPredPlot",
# "ShapModel" # SHAP
"plot_get_clist",
"plot_get_cmap",
Expand Down
21 changes: 21 additions & 0 deletions aaanalysis/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,14 @@ def _folder_path(super_folder, folder_name):
MODEL_SVM = "svm"
MODEL_RF = "rf"
MODEL_LOG_REG = "log_reg"
MODEL_EXTRA_TREES = "extra_trees"
LIST_CV_MODELS = [MODEL_SVM, MODEL_RF, MODEL_LOG_REG]
# Prediction-model registry names (AAPred). Resolved to sklearn estimators by
# ut.get_cv_model_. Kept deliberately small — the four standard families the package
# already uses (matching pipe.predict_samples' default set). Any other estimator
# (MLP, gradient boosting, xgboost, a voting/stacking ensemble, ...) is used by
# passing a configured sklearn estimator instance instead of a name.
LIST_PRED_MODELS = [MODEL_SVM, MODEL_RF, MODEL_EXTRA_TREES, MODEL_LOG_REG]
DICT_VALUE_TYPE = {COL_ABS_AUC: "mean",
COL_ABS_MEAN_DIF: "mean",
COL_MEAN_DIF: "mean",
Expand Down Expand Up @@ -441,6 +448,20 @@ def _folder_path(super_folder, folder_name):
COL_AVG_ABS_AUC_NEG, COL_AVG_KLD_NEG]
COLS_EVAL_DPULEARN = [COL_N_REL_NEG] + COLS_EVAL_DPULEARN_SIMILARITY + COLS_EVAL_DPULEARN_DISSIMILARITY

# AAPred (model evaluation / deployment)
COL_MODEL = "model" # model class short name (e.g. 'RandomForestClassifier')
COL_METRIC = "metric" # performance metric name (e.g. 'balanced_accuracy')
COL_PRINCIPLE = "principle" # evaluation principle: 'cv' (cross-validation) | 'holdout'
COL_SCORE_STD = "score_std" # std of the score (across CV folds; NaN for a single holdout estimate)
COL_GROUP = "group" # per-sample/per-protein group label used for coloring
COL_OFFSET = "offset" # AAPred.predict_domain — boundary shift applied to tmd_start/tmd_stop
COL_RESIDUE_POS = "position" # AAPred.predict_window — 1-based anchor position scored
STR_PRINCIPLE_CV = "cv"
STR_PRINCIPLE_HOLDOUT = "holdout"
LIST_PRINCIPLES = [STR_PRINCIPLE_CV, STR_PRINCIPLE_HOLDOUT]
LIST_METRICS_PRED = ["accuracy", "balanced_accuracy", "precision", "recall", "f1", "roc_auc"]
COLS_EVAL_PRED = [COL_MODEL, COL_METRIC, COL_PRINCIPLE, COL_SCORE, COL_SCORE_STD]

# Labels
LABEL_FEAT_VAL = "Feature value"
LABEL_HIST_COUNT = "Number of proteins"
Expand Down
56 changes: 56 additions & 0 deletions aaanalysis/_utils/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Shared classifier registry: map a model-name string to a configured estimator.

One source of truth for the ``name -> sklearn estimator`` mapping used across the
package (AAPred, ``predict_samples``, ``find_features``), so the name list does not
drift between call sites. The roster is kept deliberately small — the four standard
families the package already uses. Any other estimator (MLP, gradient boosting,
xgboost, a voting/stacking ensemble, a full ``Pipeline``, ...) is used by passing a
configured sklearn estimator instance instead of a name; only names route here. The
default (``svm``) reproduces the linear-SVM recipe used throughout the γ-secretase
analysis.
"""
from .. import _constants as const


# I Helper Functions
def _svm(random_state=None):
from sklearn.svm import SVC
return SVC(kernel="linear", probability=True, random_state=random_state)


def _rf(random_state=None):
from sklearn.ensemble import RandomForestClassifier
return RandomForestClassifier(random_state=random_state)


def _extra_trees(random_state=None):
from sklearn.ensemble import ExtraTreesClassifier
return ExtraTreesClassifier(random_state=random_state)


def _log_reg(random_state=None):
from sklearn.linear_model import LogisticRegression
return LogisticRegression(max_iter=1000, random_state=random_state)


_FACTORIES = {
const.MODEL_SVM: _svm,
const.MODEL_RF: _rf,
const.MODEL_EXTRA_TREES: _extra_trees,
const.MODEL_LOG_REG: _log_reg,
}


# II Main Functions
def get_cv_model_(name=None, random_state=None):
"""Return a fresh configured estimator for a registry ``name``.

``name`` is one of ``ut.LIST_PRED_MODELS``. ``random_state`` is injected where the
estimator supports it. Raises ``ValueError`` for an unknown name; pass a configured
sklearn estimator instance instead of a name to use any other model.
"""
if name not in _FACTORIES:
valid = ", ".join(list(_FACTORIES))
raise ValueError(f"'model' name '{name}' is not in the registry. Valid names: {valid}. "
f"A configured sklearn estimator instance may be passed instead.")
return _FACTORIES[name](random_state=random_state)
5 changes: 5 additions & 0 deletions aaanalysis/explainable_ai/_tree_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,11 @@ def predict_proba(self,
.. note::
:meth:`TreeModel.fit` must be called before using this method.

.. note::
``TreeModel`` is focused on global Monte-Carlo feature *importance*. For training and
deploying prediction models (selecting estimators, tuning, and scoring at the sequence,
domain, and window level), :class:`AAPred` is the recommended entry point.

.. versionadded:: 0.1.0

Parameters
Expand Down
20 changes: 20 additions & 0 deletions aaanalysis/prediction/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Prediction: evaluate and deploy sequence-based prediction models.

Public objects: AAPred, AAPredPlot.
Downstream of feature engineering (``CPP`` / ``CPPGrid`` produce ``df_feat`` and the feature
matrix ``X``): ``AAPred`` evaluates one or more scikit-learn models across metrics by
cross-validation and an optional held-out set, and fits them for deployment
(``predict_proba`` / ``predict``); ``AAPredPlot`` visualizes the evaluation table and the
per-sample prediction scores. Complements ``explainable_ai.TreeModel`` (tree-ensemble
feature importance) — this subpackage owns the general evaluate-and-deploy path.

See ``.claude/rules/code-conventions.md`` for conventions and ``CONTEXT.md`` for domain terms.
"""
from ._aa_pred import AAPred
from ._aa_pred_plot import AAPredPlot

__all__ = [
"AAPred",
"AAPredPlot",
]
Loading
Loading