Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions aaanalysis/explainable_ai_pro/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
"""
Pro explainable AI: CPP-SHAP feature explanations (``pro`` extra).

Public objects: ShapModel.
Gated behind the ``pro`` extra (needs ``shap``). Wraps a fitted model (typically
Public objects: ShapModel, shap_to_feat_imp.
Gated behind the ``pro`` extra (needs ``shap``). ``ShapModel`` wraps a fitted model (typically
``explainable_ai.TreeModel``) to compute SHAP values, which ``feature_engineering.CPPPlot``
renders as per-feature impact. Imported lazily from the top-level package and replaced
by an install-hint stub when ``shap`` is absent.
renders as per-feature impact; ``shap_to_feat_imp`` converts a per-sample SHAP vector into signed
feature impact / absolute importance. Imported lazily from the top-level package and replaced by an
install-hint stub when ``shap`` is absent.

See ``.claude/rules/pro-core-boundary.md`` for the pro/core boundary, ``CONTEXT.md``
for domain terms (explainability (CPP-SHAP) vocabulary).
"""
from ._shap_model import ShapModel
from ._shap_model_plot import shap_to_feat_imp

# NOTE: shap_to_feat_imp is intentionally NOT yet re-exported from the top-level
# aaanalysis namespace.
# Until then it is imported from this subpackage directly, e.g.
# ``from aaanalysis.explainable_ai_pro import shap_to_feat_imp``.
# ``__all__`` lists the names bound by ``from aaanalysis.explainable_ai_pro import *``.
__all__ = [
"ShapModel",
"shap_to_feat_imp",
]
74 changes: 74 additions & 0 deletions aaanalysis/explainable_ai_pro/_shap_model_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
This is a script for the stand-alone ``shap_to_feat_imp`` helper.

It complements :class:`ShapModel` (which computes the SHAP values) with a helper that turns a
per-sample SHAP vector into a normalized signed feature impact / absolute importance.
"""
import numpy as np

import aaanalysis.utils as ut
from ._backend.shap_model.sm_add_feat_impact import _comp_sample_shap_feat_impact


# II Main Functions
def shap_to_feat_imp(shap_values: ut.ArrayLike1D,
                     impact: bool = True,
                     ) -> np.ndarray:
    """
    Rescale one SHAP-value vector into signed feature impact or absolute feature importance
    (**[pro]**, requires ``aaanalysis[pro]``).

    The input is the SHAP vector of a single sample, or the mean SHAP vector of several
    samples belonging to the same class. The result is expressed in percent of the total
    absolute SHAP mass, i.e. the absolute values of the output add up to 100:

    - ``impact=True`` gives the **feature impact**, ``shap / sum(|shap|) * 100``. The sign is
      kept: positive entries push the prediction up, negative entries push it down.
    - ``impact=False`` gives the **feature importance**, ``|shap| / sum(|shap|) * 100``. Only
      the magnitude is kept, the per-sample analogue of SHAP value-based feature importance.

    The signed variant is not re-implemented here: it is obtained from the per-sample backend
    that :meth:`ShapModel.add_feat_impact` uses, so both entry points stay in agreement.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    shap_values : array-like, shape (n_features,)
        One-dimensional SHAP values of one sample (or the mean SHAP vector of a group of
        same-class samples). The computation is only meaningful within one class.
    impact : bool, default=True
        ``True`` returns the signed feature impact, ``False`` the absolute feature importance.

    Returns
    -------
    feat_imp : np.ndarray, shape (n_features,)
        Float array holding the normalized feature impact (signed) or importance (absolute),
        summing in absolute value to 100.

    Raises
    ------
    ValueError
        If the NaN-ignoring sum of the absolute SHAP values is 0, because the normalization
        is then undefined. Argument validation itself is delegated to
        ``ut.check_array_like`` and ``ut.check_bool``.

    See Also
    --------
    * :meth:`ShapModel.add_feat_impact` for attaching impact/importance columns to a feature DataFrame.

    Examples
    --------
    >>> import numpy as np
    >>> from aaanalysis.explainable_ai_pro import shap_to_feat_imp
    >>> shap_vec = np.array([0.2, -0.1, 0.3, -0.4])
    >>> impact = shap_to_feat_imp(shap_vec, impact=True)
    >>> float(np.round(np.abs(impact).sum(), 6))
    100.0
    """
    # Validate arguments first; the checked array replaces the raw input
    shap_values = ut.check_array_like(name="shap_values", val=shap_values,
                                      dtype="numeric", expected_dim=1)
    ut.check_bool(name="impact", val=impact)
    # Total absolute SHAP mass (NaN entries are ignored) is the normalization denominator
    magnitudes = np.abs(shap_values)
    total_magnitude = np.nansum(magnitudes)
    if total_magnitude == 0:
        raise ValueError("'shap_values' are all zero; feature impact/importance is undefined.")
    if not impact:
        # Absolute importance: share of each magnitude in the total, in percent
        return np.asarray(magnitudes / total_magnitude * 100, dtype=float)
    # Signed impact: the backend works on a (n_samples, n_features) matrix, so the vector is
    # wrapped as a single row and that row (i=0) is requested.
    # NOTE(review): the backend is called with normalize=False, so the documented scaling to
    # 100% for the signed branch depends on what that flag means inside the backend - confirm.
    single_row = shap_values.reshape(1, -1)
    signed_impact = _comp_sample_shap_feat_impact(shap_values=single_row,
                                                  i=0, normalize=False)
    return np.asarray(signed_impact, dtype=float)
101 changes: 101 additions & 0 deletions aaanalysis/feature_engineering/_cpp_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This is a script for the frontend of the CPPPlot class.
"""
from typing import Optional, Dict, Union, List, Tuple, Type, Literal
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
Expand Down Expand Up @@ -184,6 +185,51 @@ def check_match_ax_seq_len(ax=None, jmd_n_len=10, jmd_c_len=10) -> None:
f"\n Sequence (len: {len(tmd_jmd_seq)}) retrieved from 'ax' is: '{tmd_jmd_seq}'")


def resolve_sample_kws(sample=None, df_seq=None, df_parts=None,
                       jmd_n_seq=None, tmd_seq=None, jmd_c_seq=None,
                       resolve_seq=True):
    """Translate the ``sample=`` shortcut of the SHAP plot methods into plot arguments.

    The feature-impact column is keyed by the protein **entry name** (as written by
    :meth:`ShapModel.add_feat_impact`), so the shortcut always ends up as
    ``feat_impact_<entry>``. Two modes exist:

    - ``resolve_seq=True`` (sequence-aware plots such as ``profile`` / ``feature_map``):
      ``sample`` may be an entry name (str) or a row position (int). The TMD-JMD parts are
      fetched through :meth:`SequenceFeature.get_seq_kws`, and a row position is mapped to
      its entry name via the index of ``df_parts``. Sequence parts that were passed in
      explicitly are returned unchanged; only the ones left as ``None`` are filled.
    - ``resolve_seq=False`` (``ranking``): there is no ``df_parts`` to map a position, so
      only an entry name (str) is accepted and the sequence parts are passed through as given.

    Parameters
    ----------
    sample : str or int
        Entry name, or row position (``bool`` is rejected although it is an ``int`` subclass).
    df_seq, df_parts : pd.DataFrame, optional
        Both are required when ``resolve_seq=True``.
    jmd_n_seq, tmd_seq, jmd_c_seq : str, optional
        Explicit sequence parts; never overridden.
    resolve_seq : bool, default=True
        Whether the TMD-JMD parts should be resolved from ``df_parts``.

    Returns
    -------
    tuple
        ``(col_imp, jmd_n_seq, tmd_seq, jmd_c_seq)``.

    Raises
    ------
    ValueError
        If ``sample`` is neither str nor int, if ``df_seq`` / ``df_parts`` are missing while
        ``resolve_seq=True``, or if a row position is given while ``resolve_seq=False``.
    """
    given_as_name = isinstance(sample, str)
    # bool is an int subclass in Python, so it has to be excluded explicitly
    given_as_position = not isinstance(sample, bool) and isinstance(sample, (int, np.integer))
    if not given_as_name and not given_as_position:
        raise ValueError(f"'sample' ({sample}) should be an entry name (str) or a row position (int).")

    # 'ranking' mode: nothing to look up, the entry name directly forms the column name
    if not resolve_seq:
        if given_as_position:
            raise ValueError(f"'sample' ({sample}) must be an entry name (str) for 'ranking'; a row "
                             f"position is only supported for 'profile' / 'feature_map' (which have "
                             f"'df_parts' to map it).")
        return f"{ut.COL_FEAT_IMPACT}_{sample}", jmd_n_seq, tmd_seq, jmd_c_seq

    # Sequence-aware mode: both DataFrames are needed to resolve the TMD-JMD parts
    if df_seq is None or df_parts is None:
        raise ValueError("'df_seq' and 'df_parts' are required when 'sample' is given for "
                         "sequence-level plots, to resolve the TMD-JMD parts.")
    # Lazy import to avoid a circular import at module load (same subpackage).
    from aaanalysis.feature_engineering._sequence_feature import SequenceFeature
    # get_seq_kws runs before the index lookup below because it also validates 'sample'
    resolved_parts = SequenceFeature().get_seq_kws(df_seq=df_seq, df_parts=df_parts, sample=sample)
    entry_name = sample if given_as_name else df_parts.index[int(sample)]
    # Fill only the parts the caller left unset; explicit values always win
    jmd_n_seq = resolved_parts[f"{ut.COL_JMD_N}_seq"] if jmd_n_seq is None else jmd_n_seq
    tmd_seq = resolved_parts[f"{ut.COL_TMD}_seq"] if tmd_seq is None else tmd_seq
    jmd_c_seq = resolved_parts[f"{ut.COL_JMD_C}_seq"] if jmd_c_seq is None else jmd_c_seq
    return f"{ut.COL_FEAT_IMPACT}_{entry_name}", jmd_n_seq, tmd_seq, jmd_c_seq


# II Main Functions
class CPPPlot:
"""
Expand Down Expand Up @@ -573,6 +619,7 @@ def ranking(self,
xlim_dif: Union[Tuple[Union[int, float], Union[int, float]], None] = (-17.5, 17.5),
xlim_rank: Optional[Tuple[Union[int, float], Union[int, float]]] = (0, 4),
rank_info_xy: Optional[Tuple[Optional[Union[int, float]], Optional[Union[int, float]]]] = None,
sample: Optional[str] = None,
) -> Tuple[Figure, Axes]:
"""
Plot CPP/-SHAP feature ranking based on feature importance or sample-specific feature impact.
Expand Down Expand Up @@ -645,6 +692,11 @@ def ranking(self,
- When ``shap_plot=False``: Displays sum of feature importance.
- When ``shap_plot=True``: Show the sum of the absolute feature impact and the SHAP legend.

sample : str, optional
Convenience shortcut for sample-level CPP-SHAP ranking. When given (a protein entry name),
``col_imp`` is resolved to ``feat_impact_<sample>`` and ``shap_plot`` is set to ``True``
automatically, removing the manual ``col_imp=f"feat_impact_<name>"`` string-templating.

Returns
-------
fig : Figure
Expand All @@ -670,6 +722,10 @@ def ranking(self,
--------
.. include:: examples/cpp_plot_ranking.rst
"""
# Resolve sample-level SHAP shortcut: 'sample' sets col_imp='feat_impact_<sample>'
if sample is not None:
col_imp, _, _, _ = resolve_sample_kws(sample=sample, resolve_seq=False)
shap_plot = True
# Check input
ut.check_bool(name="shap_plot", val=shap_plot)
check_col_dif(col_dif=col_dif, shap_plot=shap_plot)
Expand Down Expand Up @@ -764,6 +820,9 @@ def profile(self,
ytick_size: Optional[Union[int, float]] = None,
ytick_width: Optional[Union[int, float]] = None,
ytick_length: Union[int, float] = 5.0,
sample: Optional[Union[str, int]] = None,
df_seq: Optional[pd.DataFrame] = None,
df_parts: Optional[pd.DataFrame] = None,
) -> Tuple[Figure, Axes]:
"""
Plot CPP/-SHAP profile showing feature importance/impact per residue position.
Expand Down Expand Up @@ -862,6 +921,18 @@ def profile(self,
Width of the y-ticks (>0).
ytick_length : int or float, default=5.0
Length of the y-ticks (>0).
sample : str or int, optional
Convenience shortcut for a sample-level CPP-SHAP profile. When given (an entry name or row
position), ``col_imp`` is resolved to ``feat_impact_<entry>`` (an int position is mapped to
its entry name via ``df_parts``), the TMD-JMD sequence parts (``jmd_n_seq``, ``tmd_seq``,
``jmd_c_seq``) are read from ``df_parts`` via :meth:`SequenceFeature.get_seq_kws`, and
``shap_plot`` is set to ``True`` automatically. Requires ``df_seq`` and ``df_parts``.
Explicitly passed sequence parts are not overridden.
df_seq : pd.DataFrame, optional
DataFrame containing an ``entry`` column with unique protein identifiers; required when
``sample`` is given, to cross-check the resolved TMD-JMD parts.
df_parts : pd.DataFrame, optional
Sequence parts DataFrame (indexed by ``entry``) as produced by
:meth:`SequenceFeature.get_df_parts`; required when ``sample`` is given.

Returns
-------
Expand All @@ -884,6 +955,13 @@ def profile(self,
--------
.. include:: examples/cpp_plot_profile.rst
"""
# Resolve sample-level SHAP shortcut: 'sample' sets col_imp='feat_impact_<sample>'
# and the TMD-JMD parts from df_parts (via SequenceFeature.get_seq_kws).
if sample is not None:
col_imp, jmd_n_seq, tmd_seq, jmd_c_seq = resolve_sample_kws(
sample=sample, df_seq=df_seq, df_parts=df_parts,
jmd_n_seq=jmd_n_seq, tmd_seq=tmd_seq, jmd_c_seq=jmd_c_seq)
shap_plot = True
# Check primary input
ut.check_bool(name="shap_plot", val=shap_plot)
if col_imp is None:
Expand Down Expand Up @@ -1261,6 +1339,9 @@ def feature_map(self,
xtick_size: Union[int, float] = 11.0,
xtick_width: Union[int, float] = 2.0,
xtick_length: Union[int, float] = 5.0,
sample: Optional[Union[str, int]] = None,
df_seq: Optional[pd.DataFrame] = None,
df_parts: Optional[pd.DataFrame] = None,
) -> Tuple[Figure, Axes]:
"""
Plot Comparative Physicochemical Profiling (CPP) feature map showing feature value mean
Expand Down Expand Up @@ -1401,6 +1482,19 @@ def feature_map(self,
Width of the x-ticks (>0).
xtick_length : int or float, default=5.0
Length of the x-ticks (>0).
sample : str or int, optional
Convenience shortcut for a sample-level CPP-SHAP feature map. When given (an entry name or
row position), ``col_imp`` is resolved to ``feat_impact_<entry>`` (an int position is
mapped to its entry name via ``df_parts``), the TMD-JMD sequence
parts (``jmd_n_seq``, ``tmd_seq``, ``jmd_c_seq``) are read from ``df_parts`` via
:meth:`SequenceFeature.get_seq_kws`, and ``shap_plot`` is set to ``True`` automatically.
Requires ``df_seq`` and ``df_parts``. Explicitly passed sequence parts are not overridden.
df_seq : pd.DataFrame, optional
DataFrame containing an ``entry`` column with unique protein identifiers; required when
``sample`` is given, to cross-check the resolved TMD-JMD parts.
df_parts : pd.DataFrame, optional
Sequence parts DataFrame (indexed by ``entry``) as produced by
:meth:`SequenceFeature.get_df_parts`; required when ``sample`` is given.

Returns
-------
Expand Down Expand Up @@ -1429,6 +1523,13 @@ def feature_map(self,
--------
.. include:: examples/cpp_plot_feature_map.rst
"""
# Resolve sample-level SHAP shortcut: 'sample' sets col_imp='feat_impact_<sample>'
# and the TMD-JMD parts from df_parts (via SequenceFeature.get_seq_kws).
if sample is not None:
col_imp, jmd_n_seq, tmd_seq, jmd_c_seq = resolve_sample_kws(
sample=sample, df_seq=df_seq, df_parts=df_parts,
jmd_n_seq=jmd_n_seq, tmd_seq=tmd_seq, jmd_c_seq=jmd_c_seq)
shap_plot = True
# Check primary input
ut.check_bool(name="shap_plot", val=shap_plot)
ut.check_str_options(name="col_cat", val=col_cat,
Expand Down
Loading
Loading