diff --git a/aaanalysis/__init__.py b/aaanalysis/__init__.py index 4fc1edee..17989828 100644 --- a/aaanalysis/__init__.py +++ b/aaanalysis/__init__.py @@ -1,4 +1,4 @@ -from .data_handling import (load_dataset, load_scales, load_features, +from .data_handling import (load_dataset, load_scales, load_features, get_labels, read_fasta, to_fasta, SequencePreprocessor, EmbeddingPreprocessor, @@ -14,6 +14,8 @@ comp_per_protein_ap, comp_detection_metrics, comp_bootstrap_ci, comp_smooth_scores) from .config import options +from ._constants import (COLOR_SAMPLES_POS, COLOR_SAMPLES_NEG, + COLOR_SAMPLES_UNL, COLOR_SAMPLES_REL_NEG) from importlib.metadata import version as _version, PackageNotFoundError @@ -28,6 +30,7 @@ "load_dataset", "load_scales", "load_features", + "get_labels", "read_fasta", "to_fasta", "SequencePreprocessor", @@ -72,6 +75,10 @@ "comp_detection_metrics", "comp_bootstrap_ci", "comp_smooth_scores", + "COLOR_SAMPLES_POS", + "COLOR_SAMPLES_NEG", + "COLOR_SAMPLES_UNL", + "COLOR_SAMPLES_REL_NEG", "options" ] diff --git a/aaanalysis/_constants.py b/aaanalysis/_constants.py index 2c22e606..2ce81416 100644 --- a/aaanalysis/_constants.py +++ b/aaanalysis/_constants.py @@ -478,6 +478,15 @@ def _folder_path(super_folder, folder_name): COLOR_NEG = "#ad4570" # (173,69,112) COLOR_REL_NEG = "#ad9745" # (173, 151, 69) +# Public, named aliases for the canonical sample-group colors (positive / negative / +# unlabeled / reliable-negative). They mirror the ``DICT_COLOR["SAMPLES_*"]`` entries +# exactly, so users can reference a named constant (``aa.COLOR_SAMPLES_POS``) instead +# of indexing ``plot_get_cdict("DICT_COLOR")`` by string key. +COLOR_SAMPLES_POS = COLOR_POS +COLOR_SAMPLES_NEG = COLOR_NEG +COLOR_SAMPLES_UNL = COLOR_UNL +COLOR_SAMPLES_REL_NEG = COLOR_REL_NEG + DICT_COLOR = {"SHAP_POS": COLOR_SHAP_POS, "SHAP_NEG": COLOR_SHAP_NEG, "FEAT_POS": COLOR_FEAT_POS, diff --git a/aaanalysis/data_handling/__init__.py b/aaanalysis/data_handling/__init__.py index 1fad9397..7ea8c663 100644 --- a/aaanalysis/data_handling/__init__.py +++ b/aaanalysis/data_handling/__init__.py @@ -1,8 +1,8 @@ """ Data loading and sequence/embedding preprocessing — the package's data entry point. -Public objects: load_dataset, load_scales, load_features, read_fasta, to_fasta, -SequencePreprocessor, EmbeddingPreprocessor, combine_dict_nums. +Public objects: load_dataset, load_scales, load_features, get_labels, read_fasta, +to_fasta, SequencePreprocessor, EmbeddingPreprocessor, combine_dict_nums. Produces the core data objects the rest of the pipeline consumes: ``load_dataset`` yields ``df_seq``, ``load_scales`` yields ``df_scales`` (fed to ``feature_engineering.AAclust`` / ``CPP``), ``load_features`` yields a reference @@ -17,6 +17,7 @@ from ._load_dataset import load_dataset from ._load_scales import load_scales from ._load_features import load_features +from ._get_labels import get_labels from ._read_fasta import read_fasta from ._to_fasta import to_fasta from ._seq_preproc import SequencePreprocessor @@ -27,6 +28,7 @@ "load_dataset", "load_scales", "load_features", + "get_labels", "read_fasta", "to_fasta", "SequencePreprocessor", diff --git a/aaanalysis/data_handling/_get_labels.py b/aaanalysis/data_handling/_get_labels.py new file mode 100644 index 00000000..58f29e21 --- /dev/null +++ b/aaanalysis/data_handling/_get_labels.py @@ -0,0 +1,70 @@ +""" +This is a script for the frontend of the get_labels function, deriving a binary +label vector from a sequence DataFrame's label column. +""" +from typing import Any +import numpy as np +import pandas as pd + +import aaanalysis.utils as ut + + +# I Helper Functions +def check_match_df_positive_label(df=None, col_label=None, positive_label=None) -> None: + """Check that the positive label value is present in the label column.""" + present = set(df[col_label].tolist()) + if positive_label not in present: + raise ValueError(f"'positive_label' ({positive_label}) is not among the values of " + f"column '{col_label}' ({sorted(present, key=str)}).") + + +# II Main Functions +def get_labels(df: pd.DataFrame, + positive_label: Any = 1, + col_label: str = "label", + ) -> np.ndarray: + """ + Derive a binary ``int`` label vector from a column of a sequence DataFrame. + + Maps the value flagged as positive (``positive_label``) onto ``1`` and every other + value onto ``0``, the binary encoding consumed across the package (e.g. by + :meth:`CPP.run`, :class:`TreeModel`, and the ``labels`` argument of most tools). + This is the single-call form of the recurring ``(df[col] == x).astype(int).to_numpy()`` + expression. + + .. versionadded:: 1.1.0 + + Parameters + ---------- + df : pd.DataFrame, shape (n_samples, n_seq_info) + Sequence DataFrame (``df_seq``) containing the label column ``col_label``. + positive_label : int or str, default=1 + Value in ``col_label`` marking the positive class. All rows equal to it become + ``1``; all remaining rows become ``0``. Must be present in ``col_label``. + col_label : str, default='label' + Name of the column holding the (multi-value or already binary) labels. + + Returns + ------- + labels : array-like, shape (n_samples,) + Binary ``int`` label vector (``1`` = positive, ``0`` = otherwise), row-aligned + to ``df``. + + Notes + ----- + * The result equals ``(df[col_label] == positive_label).astype(int).to_numpy()``. + * Pass the resulting vector directly as the ``labels`` argument of CPP, TreeModel, + or other tools. For Positive-Unlabeled mining keep the package ``1`` (positive) / + ``2`` (unlabeled) markers instead and pass ``X_pos`` / ``X_unlabeled`` to :meth:`dPULearn.fit`. + + Examples + -------- + .. include:: examples/get_labels.rst + """ + # Check input + ut.check_str(name="col_label", val=col_label, accept_none=False) + ut.check_df(name="df", df=df, cols_required=col_label) + check_match_df_positive_label(df=df, col_label=col_label, positive_label=positive_label) + # Derive binary int label vector + labels = (df[col_label] == positive_label).astype(int).to_numpy() + return labels diff --git a/aaanalysis/pu_learning/_dpulearn.py b/aaanalysis/pu_learning/_dpulearn.py index 23b5f729..a2332d58 100644 --- a/aaanalysis/pu_learning/_dpulearn.py +++ b/aaanalysis/pu_learning/_dpulearn.py @@ -133,6 +133,15 @@ def check_match_X_X_neg(X=None, X_neg=None) -> None: raise ValueError(f"'n_features' does not match between 'X' (n={n_features}) and 'X_neg' (n={n_features_neg})") +def check_match_X_pos_X_unlabeled(X_pos=None, X_unlabeled=None) -> None: + """Check that positive and unlabeled feature matrices share the same feature dimension.""" + n_features_pos = X_pos.shape[1] + n_features_unl = X_unlabeled.shape[1] + if n_features_pos != n_features_unl: + raise ValueError(f"'n_features' does not match between 'X_pos' (n={n_features_pos}) and " + f"'X_unlabeled' (n={n_features_unl})") + + # II Main Functions class dPULearn(Wrapper): """ @@ -210,11 +219,14 @@ def __init__(self, # Output parameters (will be set during model fitting) self.labels_ = None self.df_pu_ = None + self.mask_neg_ = None # Main method def fit(self, - X: ut.ArrayLike2D, - labels: ut.ArrayLike1D, + X: Optional[ut.ArrayLike2D] = None, + labels: Optional[ut.ArrayLike1D] = None, + X_pos: Optional[ut.ArrayLike2D] = None, + X_unlabeled: Optional[ut.ArrayLike2D] = None, label_pos: int = 1, label_unl: int = 2, label_neg: Optional[int] = None, @@ -239,15 +251,30 @@ def fit(self, .. versionadded:: 0.1.0 + There are two input modes (provide exactly one): pass ``X`` + ``labels`` (a single feature + matrix with per-sample markers), or — for the common positives-vs-unlabeled setup — pass the + two matrices ``X_pos`` and ``X_unlabeled`` separately, which are stacked internally with the + package markers. Either way, after fitting :attr:`dPULearn.mask_neg_` is the boolean mask of + reliable negatives (over ``X_unlabeled`` in the split mode, over ``X`` otherwise). + Parameters ---------- - X : array-like, shape (n_samples, n_features) - Feature matrix. `Rows` typically correspond to proteins and `columns` to features. - labels : array-like, shape (n_samples,) + X : array-like, shape (n_samples, n_features), optional + Feature matrix. `Rows` typically correspond to proteins and `columns` to features. Provide + ``X`` + ``labels``, or ``X_pos`` + ``X_unlabeled`` (exactly one of the two modes). + labels : array-like, shape (n_samples,), optional Dataset labels of samples in ``X``. Must contain the positive marker (``label_pos``) and the unlabeled marker (``label_unl``); pre-labeled negatives (``label_neg``) are optional. By default positives are ``1`` and unlabeled are ``2``; set ``label_unl=0`` to pass the standard ``{0, 1}`` encoding directly (``0`` = unlabeled, ``1`` = positive). + X_pos : array-like, shape (n_pos, n_features), optional + Feature matrix of the positive samples (split-input mode). Provided together with + ``X_unlabeled`` instead of ``X`` + ``labels``; the two are stacked and marked internally + (positives ``label_pos``, unlabeled ``label_unl``), so no manual label vector is needed. + X_unlabeled : array-like, shape (n_unl, n_features), optional + Feature matrix of the unlabeled candidate pool (split-input mode). Must have the same number + of features as ``X_pos``. After fitting, :attr:`dPULearn.mask_neg_` is a boolean mask over its + rows marking the identified reliable negatives. label_pos : int, default=1 Value marking positive samples in ``labels``. Must be present. label_unl : int, default=2 @@ -322,6 +349,24 @@ def fit(self, -------- .. include:: examples/dpul_fit.rst """ + # Resolve the input mode: (X, labels) or the positives/unlabeled split. In the split + # mode, stack X_pos over X_unlabeled and build the label vector internally with the + # package markers, so the caller does not hand-roll the vstack + 1/2 vector + slice. + split_mode = X_pos is not None or X_unlabeled is not None + n_pos = None + if split_mode: + if X is not None or labels is not None: + raise ValueError("Pass either 'X'/'labels' or 'X_pos'/'X_unlabeled', not both.") + if X_pos is None or X_unlabeled is None: + raise ValueError("'X_pos' and 'X_unlabeled' must both be given for the split-input mode.") + X_pos = ut.check_X(X=X_pos, X_name="X_pos", min_n_samples=1) + X_unlabeled = ut.check_X(X=X_unlabeled, X_name="X_unlabeled", min_n_samples=1) + check_match_X_pos_X_unlabeled(X_pos=X_pos, X_unlabeled=X_unlabeled) + n_pos = X_pos.shape[0] + X = np.vstack([X_pos, X_unlabeled]) + labels = np.array([label_pos] * n_pos + [label_unl] * X_unlabeled.shape[0]) + elif X is None or labels is None: + raise ValueError("'X' and 'labels' are required (or pass 'X_pos' + 'X_unlabeled').") # Check input X = ut.check_X(X=X) check_match_labels_markers(label_pos=label_pos, label_unl=label_unl, label_neg=label_neg) @@ -353,9 +398,11 @@ def fit(self, # Identify most far away negatives in PCA compressed feature space else: new_labels, df_pu = get_neg_via_pca(**args, n_components=n_components, **self._model_kwargs) - # Set new labels + # Set new labels + the reliable-negative mask. In the split-input mode the mask is over + # the rows of X_unlabeled (True = mined reliable negative); otherwise over all rows of X. self.labels_ = np.asarray(new_labels) self.df_pu_ = df_pu + self.mask_neg_ = self.labels_[n_pos:] == 0 if n_pos is not None else self.labels_ == 0 return self @staticmethod diff --git a/docs/_cheatsheet/content.py b/docs/_cheatsheet/content.py index 801adc7f..3f0d58f3 100644 --- a/docs/_cheatsheet/content.py +++ b/docs/_cheatsheet/content.py @@ -188,6 +188,7 @@ ("Load benchmark sequences", "load_dataset(name) → df_seq", None), ("Load AAontology scales", "load_scales() → df_scales", None), ("Load precomputed features", "load_features(name) → df_feat", None), + ("Binary labels from df column", "get_labels(df, positive_label) → labels", None, "v1.1"), ("Read / write FASTA", "read_fasta(file) → df_seq", None), ("Cluster redundant homologs", "filter_seq(df_seq) → df_clust [pro]", None), ]}, @@ -221,6 +222,7 @@ {"name": "Modeling & Explainability", "tag": "PU · classify · SHAP", "rows": [ ("Train with positives + unlabeled data", "dPULearn().fit(X, labels) [Wrapper]", None), + ("Mine reliable negatives (mask)", "dPULearn().fit(X_pos=, X_unlabeled=).mask_neg_ → mask", None, "v1.1"), ("Train + RFE + MC importance", "TreeModel().fit(X, labels) [Wrapper]", None), ("Per-feature / sample SHAP impact", "ShapModel().fit(X, labels) [pro]", None), ]}, diff --git a/docs/source/index/release_notes.rst b/docs/source/index/release_notes.rst index c715505d..1c1bc782 100644 --- a/docs/source/index/release_notes.rst +++ b/docs/source/index/release_notes.rst @@ -35,6 +35,11 @@ Added per-residue PTM and functional-site annotations and encodes them into tensors (``fetch_uniprot``, ``ingest``, ``register_feature``, ``encode``, ``build_scales``, ``build_cat``, ``to_df_seq``). +- **combine_dict_nums**: Concatenates per-residue tensors (embedding / structure / + annotation) along the feature axis into one combined ``CPP.run_num`` input. +- **get_labels**: Derives a binary ``int`` label vector from a sequence DataFrame's + label column (``positive_label`` mapped to ``1``, everything else to ``0``) — the + single-call form of the recurring ``(df[col] == x).astype(int).to_numpy()`` expression. - :func:`~aaanalysis.combine_dict_nums`: Concatenates per-residue tensors (embedding / structure / annotation) along the feature axis into one combined :meth:`~aaanalysis.CPP.run_num` input. @@ -132,6 +137,16 @@ Added switches the pre-computed prediction per P1 (feature map + structure restyle) with no kernel, keeping the column-residue linking (warned past 40 sites, hard-capped at 200). +**PU Learning** + +- **dPULearn.fit — positives/unlabeled split input**: for the common positive / unlabeled + setup, ``fit`` now accepts ``X_pos`` and ``X_unlabeled`` separately (an alternative to + ``X`` + ``labels``) instead of stacking them by hand and building a ``1`` / ``2`` label + vector. After fitting, the new ``dPULearn.mask_neg_`` attribute holds the **boolean mask + of reliable negatives** — over the rows of ``X_unlabeled`` in the split mode, over ``X`` + otherwise (equal to the manual ``labels_[len(X_pos):] == 0`` result exactly). ``fit`` still + returns ``self`` and the existing ``fit(X, labels=...)`` path is unchanged. + **Sequence Analysis** - :class:`~aaanalysis.AAWindowSampler`: Samples fixed-length sequence windows for PU-learning and @@ -191,6 +206,10 @@ Added - :func:`~aaanalysis.plot_rank`: Standalone per-protein max-score-vs-rank scatter with group coloring and optional threshold lines (pairs with the new ``aa.metrics`` functions). +- **COLOR_SAMPLES_POS / COLOR_SAMPLES_NEG / COLOR_SAMPLES_UNL / COLOR_SAMPLES_REL_NEG**: + Public, named constants for the canonical sample-group colors (positive / negative / + unlabeled / reliable-negative). They equal the ``plot_get_cdict("DICT_COLOR")["SAMPLES_*"]`` + values exactly, so a named constant replaces indexing the color dict by string key. **Golden Pipelines** diff --git a/examples/data_handling/get_labels.ipynb b/examples/data_handling/get_labels.ipynb new file mode 100644 index 00000000..ef84686a --- /dev/null +++ b/examples/data_handling/get_labels.ipynb @@ -0,0 +1,297 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "30e77e51", + "metadata": {}, + "source": [ + "The ``get_labels`` function derives a binary ``int`` label vector from a column of a sequence DataFrame (``df_seq``). It is the single-call form of the recurring ``(df[col] == positive_label).astype(int).to_numpy()`` expression: the value flagged as positive becomes ``1`` and every other value becomes ``0``." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "14210edc", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-30T23:29:34.142992Z", + "iopub.status.busy": "2026-06-30T23:29:34.142637Z", + "iopub.status.idle": "2026-06-30T23:29:35.599732Z", + "shell.execute_reply": "2026-06-30T23:29:35.599479Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataFrame shape: (10, 8)\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 entrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_c
1P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH
2P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD
3P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER
4Q03157MGPTSPAARGQGRRW...HGYENPTYRFLEERP1585607APSGTGVSREALSGLLIMGAGGGSLIVLSLLLLRKKKPYGTIS
5Q06481MAATGTAAAAATGRL...GYENPTYKYLEQMQI1694716LREDFSLSSSALIGLLVIAVAIATVIVISLVMLRKRQYGTISH
6P12821MGAASGRRGPGLLLP...SHGPQFGSEVELRHS212571276GLDLDAQQARVGQWLLLFLGIALLVATLGLSQRLFSIRHR
7P36896MAESAGASSFFPLVV...KKTLSQLSVQEDVKI2127149EHPSMWGPVELVGIIAGPVFLLFLIIIIVFLVINYHQRVYHNR
8Q8NER5MTRALCSALRQALLL...KKTISQLCVKEDCKA2114136PNAPKLGPMELAIIITVPVCLLSIAAMLTVWACQGRQCSYRKK
9P37023MTLGSPRKGLLMLLM...LQKISNSPEKPKVIQ2119141PSEQPGTDGQLALILGPVLALLALVALGVLGLWHVRRRQEKQR
10O43184MAARPLPVSPARALL...YPHQVPRSTHTAYIK2707729DSGPIRQADNQGLTIGILVTILCLLAAGFVVYLKRKTLIRLLF
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import aaanalysis as aa\n", + "aa.options[\"verbose\"] = False\n", + "\n", + "# A Positive-Unlabeled (PU) dataset: substrates (1) and unlabeled others (2).\n", + "df_seq = aa.load_dataset(name=\"DOM_GSEC_PU\", n=5)\n", + "aa.display_df(df=df_seq, n_rows=10, show_shape=True)" + ] + }, + { + "cell_type": "markdown", + "id": "e84a0e34", + "metadata": {}, + "source": [ + "By default ``positive_label=1``: substrates map to ``1`` and everything else to ``0``." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a5ccd25b", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-30T23:29:35.600825Z", + "iopub.status.busy": "2026-06-30T23:29:35.600758Z", + "iopub.status.idle": "2026-06-30T23:29:35.602749Z", + "shell.execute_reply": "2026-06-30T23:29:35.602535Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1 1 1 1 1 0 0 0 0 0]\n" + ] + } + ], + "source": [ + "labels = aa.get_labels(df=df_seq, positive_label=1)\n", + "print(labels)" + ] + }, + { + "cell_type": "markdown", + "id": "5db29b65", + "metadata": {}, + "source": [ + "Pick any value as the positive class via ``positive_label`` (e.g. treat the unlabeled ``2`` as positive), and select a different column with ``col_label``. The result equals the manual ``(df[col_label] == positive_label).astype(int).to_numpy()`` expression." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d9a49747", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-30T23:29:35.603686Z", + "iopub.status.busy": "2026-06-30T23:29:35.603626Z", + "iopub.status.idle": "2026-06-30T23:29:35.605515Z", + "shell.execute_reply": "2026-06-30T23:29:35.605342Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0 0 0 0 0 1 1 1 1 1]\n" + ] + } + ], + "source": [ + "labels_unl = aa.get_labels(df=df_seq, positive_label=2, col_label=\"label\")\n", + "print(labels_unl)" + ] + }, + { + "cell_type": "markdown", + "id": "82fd2bd1", + "metadata": {}, + "source": [ + "Pass the resulting vector straight into the ``labels`` argument of tools such as :meth:`CPP.run` or :class:`TreeModel`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pu_learning/dpul_fit.ipynb b/examples/pu_learning/dpul_fit.ipynb index 3155b6d6..a76bb4f0 100644 --- a/examples/pu_learning/dpul_fit.ipynb +++ b/examples/pu_learning/dpul_fit.ipynb @@ -7,7 +7,12 @@ "collapsed": false }, "source": [ - "To demonstrate the ``dPULearn().fit()``method, we create a small example dataset containing positive (1) and unlabeled (2) data samples:" + "``dPULearn().fit()`` identifies reliable negatives (labeled ``0``) from unlabeled samples, based on their distance to the positives. There are **two ways to provide the input** (choose one):\n", + "\n", + "- **Option 1 — ``X`` + ``labels``**: a single feature matrix with a per-sample label vector (positives ``1``, unlabeled ``2``; pre-labeled negatives optional). The general, flexible form — the input encoding is configurable via ``label_pos`` / ``label_unl`` / ``label_neg``.\n", + "- **Option 2 — ``X_pos`` + ``X_unlabeled``**: the positives and the unlabeled pool passed as two separate matrices (the common positives-vs-unlabeled setup). They are stacked internally, so no manual label vector is needed, and the mined negatives are exposed as the ``mask_neg_`` attribute.\n", + "\n", + "Either way ``fit`` returns the fitted model. We start with **Option 1**, creating a small example dataset with positive (``1``) and unlabeled (``2``) samples:" ] }, { @@ -21,10 +26,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:31.979361Z", - "iopub.status.busy": "2026-06-13T11:04:31.979269Z", - "iopub.status.idle": "2026-06-13T11:04:33.219287Z", - "shell.execute_reply": "2026-06-13T11:04:33.218979Z" + "iopub.execute_input": "2026-07-03T17:44:19.857813Z", + "iopub.status.busy": "2026-07-03T17:44:19.857731Z", + "iopub.status.idle": "2026-07-03T17:44:21.698305Z", + "shell.execute_reply": "2026-07-03T17:44:21.697468Z" } }, "outputs": [], @@ -59,10 +64,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.220594Z", - "iopub.status.busy": "2026-06-13T11:04:33.220480Z", - "iopub.status.idle": "2026-06-13T11:04:33.244615Z", - "shell.execute_reply": "2026-06-13T11:04:33.244408Z" + "iopub.execute_input": "2026-07-03T17:44:21.701160Z", + "iopub.status.busy": "2026-07-03T17:44:21.700652Z", + "iopub.status.idle": "2026-07-03T17:44:21.821262Z", + "shell.execute_reply": "2026-07-03T17:44:21.820775Z" } }, "outputs": [ @@ -70,58 +75,58 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 selection_viaPC1 (100.0%)PC1 (100.0%)_abs_difselection_viaPC1 (100.0%)PC1 (100.0%)_abs_dif
1nan-0.4000000.0000001nan-0.4000000.000000
2nan-0.2000000.2000002nan-0.2000000.200000
3nan0.4000000.8000003nan0.4000000.800000
4PC10.8000001.2000004PC10.8000001.200000
\n" @@ -156,10 +161,10 @@ "id": "9f79293e", "metadata": { "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.245783Z", - "iopub.status.busy": "2026-06-13T11:04:33.245704Z", - "iopub.status.idle": "2026-06-13T11:04:33.249988Z", - "shell.execute_reply": "2026-06-13T11:04:33.249784Z" + "iopub.execute_input": "2026-07-03T17:44:21.828319Z", + "iopub.status.busy": "2026-07-03T17:44:21.827465Z", + "iopub.status.idle": "2026-07-03T17:44:21.849251Z", + "shell.execute_reply": "2026-07-03T17:44:21.848839Z" } }, "outputs": [ @@ -167,58 +172,58 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 selection_viaPC1 (100.0%)PC1 (100.0%)_abs_difselection_viaPC1 (100.0%)PC1 (100.0%)_abs_dif
1nan-0.4000000.0000001nan-0.4000000.000000
2nan-0.2000000.2000002nan-0.2000000.200000
3nan0.4000000.8000003nan0.4000000.800000
4PC10.8000001.2000004PC10.8000001.200000
\n" @@ -254,10 +259,10 @@ "id": "c25bfee1", "metadata": { "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.250939Z", - "iopub.status.busy": "2026-06-13T11:04:33.250878Z", - "iopub.status.idle": "2026-06-13T11:04:33.255134Z", - "shell.execute_reply": "2026-06-13T11:04:33.254925Z" + "iopub.execute_input": "2026-07-03T17:44:21.855718Z", + "iopub.status.busy": "2026-07-03T17:44:21.851787Z", + "iopub.status.idle": "2026-07-03T17:44:21.889166Z", + "shell.execute_reply": "2026-07-03T17:44:21.887236Z" } }, "outputs": [ @@ -275,70 +280,70 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 selection_viaPC1 (100.0%)PC1 (100.0%)_abs_difselection_viaPC1 (100.0%)PC1 (100.0%)_abs_dif
1nan-0.3086000.0000001nan-0.3086000.000000
2nan-0.1543000.1543002nan-0.1543000.154300
3nan0.3086000.6172003nan0.3086000.617200
4PC10.6172000.9258004PC10.6172000.925800
5PC10.6172000.9258005PC10.6172000.925800
6nan0.1543000.4629006nan0.1543000.462900
\n" @@ -366,8 +371,50 @@ { "cell_type": "markdown", "id": "4e027d56", - "source": "``n_neg`` is the **total** number of negatives wanted (pre-labeled plus newly identified). Alternatively, use ``n_unl_to_neg`` to control the number identified **directly from the unlabeled pool**, independent of any pre-labeled negatives (final negatives = pre-labeled + ``n_unl_to_neg``). Provide exactly one of the two.", - "metadata": {} + "metadata": {}, + "source": [ + "``n_neg`` is the **total** number of negatives wanted (pre-labeled plus newly identified). Alternatively, use ``n_unl_to_neg`` to control the number identified **directly from the unlabeled pool**, independent of any pre-labeled negatives (final negatives = pre-labeled + ``n_unl_to_neg``). Provide exactly one of the two." + ] + }, + { + "cell_type": "markdown", + "id": "split-md", + "metadata": {}, + "source": [ + "**Option 2 — ``X_pos`` + ``X_unlabeled`` (positives/unlabeled split).** For the common positives-vs-unlabeled setup, pass the two matrices to ``fit`` separately instead of stacking them and building a ``1``/``2`` label vector by hand. After fitting, ``dPULearn.mask_neg_`` is the boolean mask of reliable negatives over the rows of ``X_unlabeled``:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "split-code", + "metadata": { + "execution": { + "iopub.execute_input": "2026-07-03T17:44:21.892077Z", + "iopub.status.busy": "2026-07-03T17:44:21.891845Z", + "iopub.status.idle": "2026-07-03T17:44:21.901188Z", + "shell.execute_reply": "2026-07-03T17:44:21.900160Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2 reliable negatives mined from 4 unlabeled samples\n" + ] + } + ], + "source": [ + "X_pos = np.array([[0.2, 0.1], [0.25, 0.2]])\n", + "X_unlabeled = np.array([[0.2, 0.3], [0.5, 0.7], [0.6, 0.8], [0.1, 0.15]])\n", + "\n", + "dpul = aa.dPULearn()\n", + "dpul.fit(X_pos=X_pos, X_unlabeled=X_unlabeled, n_neg=2)\n", + "mask_neg = dpul.mask_neg_ # boolean mask over X_unlabeled (True = reliable negative)\n", + "X_neg = X_unlabeled[mask_neg] # the mined reliable negatives\n", + "print(f\"{X_neg.shape[0]} reliable negatives mined from {len(X_unlabeled)} unlabeled samples\")" + ] }, { "cell_type": "markdown", @@ -381,7 +428,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "fb7ea317dcf0c81c", "metadata": { "ExecuteTime": { @@ -390,10 +437,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.256104Z", - "iopub.status.busy": "2026-06-13T11:04:33.256031Z", - "iopub.status.idle": "2026-06-13T11:04:33.274095Z", - "shell.execute_reply": "2026-06-13T11:04:33.273857Z" + "iopub.execute_input": "2026-07-03T17:44:21.909294Z", + "iopub.status.busy": "2026-07-03T17:44:21.908835Z", + "iopub.status.idle": "2026-07-03T17:44:21.940023Z", + "shell.execute_reply": "2026-07-03T17:44:21.938888Z" } }, "outputs": [ @@ -408,76 +455,76 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 entrysequencelabeltmd_starttmd_stopentrysequencelabeltmd_starttmd_stop
690P60852MAGGSATTWGYPVAL...LSQTWAQKLWESNRQ2602624690P60852MAGGSATTWGYPVAL...LSQTWAQKLWESNRQ2602624
691P20239MARWQRKASVSSPCG...FICYLYKKRTIRFNH2684703691P20239MARWQRKASVSSPCG...FICYLYKKRTIRFNH2684703
692P21754MELSYRLFICLLLWG...TRRCRTASHPVSASE2387409692P21754MELSYRLFICLLLWG...TRRCRTASHPVSASE2387409
693Q12836MWLLRCVLLCVSLSL...LAVKKQKSCPDQMCQ2506528693Q12836MWLLRCVLLCVSLSL...LAVKKQKSCPDQMCQ2506528
694Q8TCW7MEQIWLLLLLTIRVL...PTSLVLNGIRNPVFD2374396694Q8TCW7MEQIWLLLLLTIRVL...PTSLVLNGIRNPVFD2374396
\n" @@ -509,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "772eb0fd613630d6", "metadata": { "ExecuteTime": { @@ -518,10 +565,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.275142Z", - "iopub.status.busy": "2026-06-13T11:04:33.275034Z", - "iopub.status.idle": "2026-06-13T11:04:33.509151Z", - "shell.execute_reply": "2026-06-13T11:04:33.508920Z" + "iopub.execute_input": "2026-07-03T17:44:21.942383Z", + "iopub.status.busy": "2026-07-03T17:44:21.942175Z", + "iopub.status.idle": "2026-07-03T17:44:22.055488Z", + "shell.execute_reply": "2026-07-03T17:44:22.054329Z" } }, "outputs": [ @@ -539,208 +586,200 @@ "DataFrame shape: (63, 15)\n" ] }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/stephanbreimann/Programming/1Packages/aaanalysis-wt-dpulearn-157/aaanalysis/feature_engineering/_backend/cpp_run.py:143: UserWarning: CPP is using the Python kernel fallback — the compiled Cython extension is not available in this install. Output is bit-exact with the Cython path but ~2x slower. Reinstall via `pip install --force-reinstall aaanalysis` to fetch a prebuilt wheel.\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 selection_viaPC1 (56.2%)PC2 (7.4%)PC3 (2.9%)PC4 (2.8%)selection_viaPC1 (56.2%)PC2 (7.4%)PC3 (2.9%)PC4 (2.8%)
81PC30.0336000.0073000.098200-0.00780081PC30.0336000.0073000.098200-0.007800
82PC70.033400-0.0411000.033500-0.00520082PC70.033400-0.0411000.033500-0.005200
84PC10.021000-0.0478000.075200-0.00540084PC10.021000-0.0478000.075200-0.005400
90PC40.039000-0.032000-0.0013000.11090090PC40.039000-0.032000-0.0013000.110900
95PC20.032000-0.0821000.025800-0.03770095PC20.032000-0.0821000.025800-0.037700
109PC10.026100-0.0585000.075700-0.020900109PC10.026100-0.0585000.075700-0.020900
149PC10.026500-0.0380000.0191000.045500149PC10.026500-0.0380000.0191000.045500
158PC10.023500-0.0607000.0540000.000900158PC10.023500-0.0607000.0540000.000900
161PC10.0259000.0314000.0449000.055400161PC10.0259000.0314000.0449000.055400
169PC10.026500-0.0099000.012500-0.016700169PC10.026500-0.0099000.012500-0.016700
170PC10.026100-0.0353000.0583000.025800170PC10.026100-0.0353000.0583000.025800
187PC10.0261000.0188000.0506000.038600187PC10.0261000.0188000.0506000.038600
192PC60.040100-0.0022000.004300-0.053600192PC60.040100-0.0022000.004300-0.053600
193PC10.024700-0.0569000.051300-0.035600193PC10.024700-0.0569000.051300-0.035600
195PC50.0299000.0065000.0358000.050200195PC50.0299000.0065000.0358000.050200
200PC10.021200-0.0562000.0057000.072600200PC10.021200-0.0562000.0057000.072600
204PC10.025500-0.0071000.062900-0.052500204PC10.025500-0.0071000.062900-0.052500
223PC10.018800-0.0436000.048500-0.072700223PC10.018800-0.0436000.048500-0.072700
254PC10.021500-0.0129000.0715000.038500254PC10.021500-0.0129000.0715000.038500
264PC40.0405000.023100-0.0247000.113800264PC40.0405000.023100-0.0247000.113800
\n" @@ -785,7 +824,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "6e9aa4b93419060c", "metadata": { "ExecuteTime": { @@ -794,10 +833,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.510285Z", - "iopub.status.busy": "2026-06-13T11:04:33.510223Z", - "iopub.status.idle": "2026-06-13T11:04:33.557421Z", - "shell.execute_reply": "2026-06-13T11:04:33.557158Z" + "iopub.execute_input": "2026-07-03T17:44:22.058397Z", + "iopub.status.busy": "2026-07-03T17:44:22.058162Z", + "iopub.status.idle": "2026-07-03T17:44:22.185037Z", + "shell.execute_reply": "2026-07-03T17:44:22.184451Z" } }, "outputs": [], @@ -817,7 +856,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "a3f2a24092217134", "metadata": { "ExecuteTime": { @@ -826,10 +865,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.558645Z", - "iopub.status.busy": "2026-06-13T11:04:33.558583Z", - "iopub.status.idle": "2026-06-13T11:04:33.563060Z", - "shell.execute_reply": "2026-06-13T11:04:33.562866Z" + "iopub.execute_input": "2026-07-03T17:44:22.188324Z", + "iopub.status.busy": "2026-07-03T17:44:22.188099Z", + "iopub.status.idle": "2026-07-03T17:44:22.200066Z", + "shell.execute_reply": "2026-07-03T17:44:22.199231Z" } }, "outputs": [ @@ -844,94 +883,94 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 selection_viaeuclidean_difeuclidean_abs_difselection_viaeuclidean_difeuclidean_abs_dif
84euclidean3.4807003.48070084euclidean3.4807003.480700
505euclidean3.2327003.232700505euclidean3.2327003.232700
509euclidean3.3363003.336300509euclidean3.3363003.336300
526euclidean3.3897003.389700526euclidean3.3897003.389700
533euclidean3.3639003.363900533euclidean3.3639003.363900
542euclidean3.0750003.075000542euclidean3.0750003.075000
546euclidean3.1625003.162500546euclidean3.1625003.162500
548euclidean3.1119003.111900548euclidean3.1119003.111900
552euclidean3.2886003.288600552euclidean3.2886003.288600
553euclidean3.6208003.620800553euclidean3.6208003.620800
\n" @@ -961,7 +1000,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "77d2558171e91dee", "metadata": { "ExecuteTime": { @@ -970,10 +1009,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T11:04:33.564068Z", - "iopub.status.busy": "2026-06-13T11:04:33.563993Z", - "iopub.status.idle": "2026-06-13T11:04:33.577083Z", - "shell.execute_reply": "2026-06-13T11:04:33.576867Z" + "iopub.execute_input": "2026-07-03T17:44:22.202407Z", + "iopub.status.busy": "2026-07-03T17:44:22.202162Z", + "iopub.status.idle": "2026-07-03T17:44:22.238387Z", + "shell.execute_reply": "2026-07-03T17:44:22.236293Z" } }, "outputs": [ @@ -988,483 +1027,483 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 selection_viaPC1 (56.2%)PC2 (7.4%)PC3 (2.9%)selection_viaPC1 (56.2%)PC2 (7.4%)PC3 (2.9%)
497PC10.022500-0.0512000.013400497PC10.022500-0.0512000.013400
615PC10.026100-0.0533000.099300615PC10.026100-0.0533000.099300
406PC10.025400-0.0308000.027200406PC10.025400-0.0308000.027200
446PC10.026200-0.0137000.054500446PC10.026200-0.0137000.054500
455PC10.026600-0.0521000.089500455PC10.026600-0.0521000.089500
468PC10.025600-0.0688000.011800468PC10.025600-0.0688000.011800
471PC10.025000-0.0055000.083500471PC10.025000-0.0055000.083500
668PC10.023200-0.0169000.076500668PC10.023200-0.0169000.076500
605PC10.025800-0.0545000.006700605PC10.025800-0.0545000.006700
505PC10.023100-0.0484000.033900505PC10.023100-0.0484000.033900
509PC10.022800-0.0563000.086300509PC10.022800-0.0563000.086300
604PC10.024800-0.0788000.027600604PC10.024800-0.0788000.027600
526PC10.022500-0.0557000.038200526PC10.022500-0.0557000.038200
600PC10.019600-0.0442000.092700600PC10.019600-0.0442000.092700
534PC10.026100-0.032300-0.019400534PC10.026100-0.032300-0.019400
542PC10.026400-0.0391000.049200542PC10.026400-0.0391000.049200
545PC10.0262000.0072000.039200545PC10.0262000.0072000.039200
548PC10.025200-0.0569000.039300548PC10.025200-0.0569000.039300
552PC10.026000-0.072800-0.031000552PC10.026000-0.072800-0.031000
553PC10.020500-0.0775000.079700553PC10.020500-0.0775000.079700
336PC10.026200-0.0207000.032200336PC10.026200-0.0207000.032200
329PC10.025800-0.0149000.043100329PC10.025800-0.0149000.043100
624PC10.0265000.0345000.046800624PC10.0265000.0345000.046800
308PC10.025400-0.0306000.033100308PC10.025400-0.0306000.033100
84PC10.021000-0.0478000.07520084PC10.021000-0.0478000.075200
649PC10.022800-0.0324000.108000649PC10.022800-0.0324000.108000
637PC10.022600-0.0578000.044500637PC10.022600-0.0578000.044500
109PC10.026100-0.0585000.075700109PC10.026100-0.0585000.075700
149PC10.026500-0.0380000.019100149PC10.026500-0.0380000.019100
158PC10.023500-0.0607000.054000158PC10.023500-0.0607000.054000
161PC10.0259000.0314000.044900161PC10.0259000.0314000.044900
169PC10.026500-0.0099000.012500169PC10.026500-0.0099000.012500
569PC10.022100-0.0436000.065400569PC10.022100-0.0436000.065400
170PC10.026100-0.0353000.058300170PC10.026100-0.0353000.058300
635PC10.0254000.0406000.054600635PC10.0254000.0406000.054600
193PC10.024700-0.0569000.051300193PC10.024700-0.0569000.051300
634PC10.026000-0.0422000.007900634PC10.026000-0.0422000.007900
200PC10.021200-0.0562000.005700200PC10.021200-0.0562000.005700
204PC10.025500-0.0071000.062900204PC10.025500-0.0071000.062900
223PC10.018800-0.0436000.048500223PC10.018800-0.0436000.048500
254PC10.021500-0.0129000.071500254PC10.021500-0.0129000.071500
628PC10.025600-0.0272000.051300628PC10.025600-0.0272000.051300
300PC10.024900-0.0135000.052900300PC10.024900-0.0135000.052900
187PC10.0261000.0188000.050600187PC10.0261000.0188000.050600
585PC10.022400-0.0222000.087800585PC10.022400-0.0222000.087800
658PC20.035200-0.081100-0.040700658PC20.035200-0.081100-0.040700
683PC20.028700-0.1032000.011900683PC20.028700-0.1032000.011900
533PC20.039300-0.094200-0.045800533PC20.039300-0.094200-0.045800
337PC20.035100-0.102700-0.021700337PC20.035100-0.102700-0.021700
322PC20.041300-0.096300-0.075700322PC20.041300-0.096300-0.075700
95PC20.032000-0.0821000.02580095PC20.032000-0.0821000.025800
524PC30.0316000.0284000.106200524PC30.0316000.0284000.106200
632PC30.0301000.0225000.090800632PC30.0301000.0225000.090800
81PC30.0336000.0073000.09820081PC30.0336000.0073000.098200
264PC40.0405000.023100-0.024700264PC40.0405000.023100-0.024700
90PC40.039000-0.032000-0.00130090PC40.039000-0.032000-0.001300
591PC40.031300-0.0040000.032100591PC40.031300-0.0040000.032100
195PC50.0299000.0065000.035800195PC50.0299000.0065000.035800
641PC50.0435000.0065000.015200641PC50.0435000.0065000.015200
501PC60.042100-0.018500-0.050200501PC60.042100-0.018500-0.050200
192PC60.040100-0.0022000.004300192PC60.040100-0.0022000.004300
82PC70.033400-0.0411000.03350082PC70.033400-0.0411000.033500
666PC70.0352000.075600-0.011600666PC70.0352000.075600-0.011600
1nan0.0524000.039300-0.0663001nan0.0524000.039300-0.066300
\n" @@ -1499,9 +1538,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.14.0" + "version": "3.13.11" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tests/unit/data_handling_tests/test_get_labels.py b/tests/unit/data_handling_tests/test_get_labels.py new file mode 100644 index 00000000..0d63e291 --- /dev/null +++ b/tests/unit/data_handling_tests/test_get_labels.py @@ -0,0 +1,114 @@ +""" +This script tests the top-level get_labels() function (issue #308). + +get_labels is the single-call form of the recurring +``(df[col] == positive_label).astype(int).to_numpy()`` expression that appears in 4+ places +of the gamma-secretase use case. It maps the positive value onto 1 and everything else onto 0. +""" +import numpy as np +import pandas as pd +import pytest + +import aaanalysis as aa + + +# Helper functions +def _manual(df, positive_label, col="label"): + return (df[col] == positive_label).astype(int).to_numpy() + + +# Normal Cases Test Class +class TestGetLabels: + """Test get_labels() for each parameter individually.""" + + def test_returns_int_numpy_array(self): + df = pd.DataFrame({"entry": ["a", "b", "c"], "label": [1, 2, 1]}) + labels = aa.get_labels(df=df, positive_label=1) + assert isinstance(labels, np.ndarray) + assert labels.dtype.kind == "i" + assert labels.shape == (3,) + + def test_positive_label_default(self): + df = pd.DataFrame({"label": [1, 0, 1, 0]}) + labels = aa.get_labels(df=df) + assert np.array_equal(labels, np.array([1, 0, 1, 0])) + + def test_df_parameter(self): + df = pd.DataFrame({"label": [2, 2, 1]}) + labels = aa.get_labels(df=df, positive_label=1) + assert np.array_equal(labels, np.array([0, 0, 1])) + + def test_col_label_parameter(self): + df = pd.DataFrame({"y": [1, 2, 1, 2]}) + labels = aa.get_labels(df=df, positive_label=2, col_label="y") + assert np.array_equal(labels, np.array([0, 1, 0, 1])) + + +# Golden equivalence to the manual expression (KPI: >= 2 encodings) +class TestGetLabelsEquivalence: + """Result equals the manual expression on multiple label encodings (KPI #308).""" + + def test_pu_encoding_1_2(self): + # PU encoding: 1 = positive, 2 = unlabeled + df = pd.DataFrame({"label": [1, 2, 1, 2, 2, 1]}) + assert np.array_equal(aa.get_labels(df=df, positive_label=1), + _manual(df, 1)) + + def test_binary_encoding_0_1(self): + # Standard {0, 1} encoding + df = pd.DataFrame({"label": [0, 1, 1, 0]}) + assert np.array_equal(aa.get_labels(df=df, positive_label=1), + _manual(df, 1)) + + def test_multiclass_encoding(self): + # Multi-class: pick one class as positive + df = pd.DataFrame({"label": [0, 1, 2, 0, 1, 2]}) + for pos in (0, 1, 2): + assert np.array_equal(aa.get_labels(df=df, positive_label=pos), + _manual(df, pos)) + + def test_string_labels(self): + df = pd.DataFrame({"label": ["sub", "non", "sub", "unl"]}) + assert np.array_equal(aa.get_labels(df=df, positive_label="sub"), + _manual(df, "sub")) + + def test_single_class_column_maps_all_ones(self): + # Pure mapping: unlike dPULearn.fit, get_labels does not require >1 distinct value, + # so an all-positive column maps to all ones rather than raising. + df = pd.DataFrame({"label": [1, 1, 1]}) + assert np.array_equal(aa.get_labels(df=df, positive_label=1), + np.array([1, 1, 1])) + + def test_nan_maps_to_zero(self): + # NaN never equals positive_label, so it becomes 0. + df = pd.DataFrame({"label": [1.0, np.nan, 1.0]}) + assert np.array_equal(aa.get_labels(df=df, positive_label=1.0), + np.array([1, 0, 1])) + + +# Negative Cases Test Class +class TestGetLabelsNegative: + """Invalid inputs must raise informative ValueErrors.""" + + def test_df_none(self): + with pytest.raises(ValueError): + aa.get_labels(df=None, positive_label=1) + + def test_df_not_dataframe(self): + with pytest.raises(ValueError): + aa.get_labels(df=[1, 2, 3], positive_label=1) + + def test_missing_label_column(self): + df = pd.DataFrame({"entry": ["a", "b"], "y": [1, 0]}) + with pytest.raises(ValueError): + aa.get_labels(df=df, positive_label=1) + + def test_custom_col_missing(self): + df = pd.DataFrame({"label": [1, 0]}) + with pytest.raises(ValueError): + aa.get_labels(df=df, positive_label=1, col_label="missing") + + def test_positive_label_absent(self): + df = pd.DataFrame({"label": [1, 2, 1]}) + with pytest.raises(ValueError): + aa.get_labels(df=df, positive_label=9) diff --git a/tests/unit/dpulearn_tests/test_dpulearn_mine_negatives.py b/tests/unit/dpulearn_tests/test_dpulearn_mine_negatives.py new file mode 100644 index 00000000..e0b6b9f9 --- /dev/null +++ b/tests/unit/dpulearn_tests/test_dpulearn_mine_negatives.py @@ -0,0 +1,191 @@ +""" +Tests the dPULearn.fit positives/unlabeled split-input mode + mask_neg_ (issue #308). + +For the common positive/unlabeled setup, ``fit`` accepts ``X_pos`` and ``X_unlabeled`` +separately (instead of ``X`` + a hand-built 1/2 label vector), stacks them internally, and +sets ``mask_neg_`` — the boolean mask of identified reliable negatives over the rows of +``X_unlabeled``. The key contract is that ``mask_neg_`` equals the manual +``labels_[len(X_pos):] == 0`` result exactly, and that the existing ``fit(X, labels=...)`` +path stays byte-identical (no algorithm change). +""" +import numpy as np +import pytest + +import aaanalysis as aa + + +# Helper functions +def _make_data(n_pos=20, n_unl=50, n_features=8, seed=0): + rng = np.random.default_rng(seed) + X_pos = rng.normal(0.0, 1.0, size=(n_pos, n_features)) + X_unl = rng.normal(0.6, 1.0, size=(n_unl, n_features)) + return X_pos, X_unl + + +def _manual_mask(X_pos, X_unl, random_state=42, **fit_kwargs): + """Reproduce the notebook cell 18/24 manual stacking path.""" + X_pool = np.vstack([X_pos, X_unl]) + y_pool = np.array([1] * len(X_pos) + [2] * len(X_unl)) + dpul = aa.dPULearn(random_state=random_state, verbose=False) + dpul.fit(X=X_pool, labels=y_pool, **fit_kwargs) + return np.asarray(dpul.labels_)[len(X_pos):] == 0, dpul + + +# Normal Cases Test Class +class TestFitSplitInput: + """fit(X_pos=, X_unlabeled=) sets mask_neg_ over the unlabeled rows, per parameter.""" + + def test_mask_neg_is_boolean_over_unlabeled(self): + X_pos, X_unl = _make_data() + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=10).mask_neg_ + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.shape == (X_unl.shape[0],) + assert mask.sum() == 10 + + def test_X_pos_parameter(self): + X_pos, X_unl = _make_data(n_pos=30) + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=5).mask_neg_ + assert mask.shape[0] == X_unl.shape[0] + + def test_X_unlabeled_parameter(self): + X_pos, X_unl = _make_data(n_unl=70) + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=12).mask_neg_ + assert mask.shape[0] == 70 + assert mask.sum() == 12 + + def test_n_neg_parameter(self): + X_pos, X_unl = _make_data() + for n in (1, 5, 25): + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=n).mask_neg_ + assert mask.sum() == n + + def test_n_unl_to_neg_equivalent_to_n_neg(self): + # With no pre-labeled negatives the two count params are equivalent. + X_pos, X_unl = _make_data() + m1 = aa.dPULearn(random_state=42, verbose=False).fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=8).mask_neg_ + m2 = aa.dPULearn(random_state=42, verbose=False).fit(X_pos=X_pos, X_unlabeled=X_unl, n_unl_to_neg=8).mask_neg_ + assert np.array_equal(m1, m2) + + def test_metric_parameter(self): + X_pos, X_unl = _make_data() + for metric in ("euclidean", "manhattan", "cosine"): + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=10, metric=metric).mask_neg_ + assert mask.sum() == 10 + + def test_n_components_parameter(self): + X_pos, X_unl = _make_data() + for n_components in (2, 3, 0.5): + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=10, n_components=n_components).mask_neg_ + assert mask.sum() == 10 + + def test_fit_returns_self_in_split_mode(self): + X_pos, X_unl = _make_data() + dpul = aa.dPULearn(random_state=42, verbose=False) + out = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=10) + assert out is dpul # sklearn contract preserved + + def test_instance_attributes_set(self): + X_pos, X_unl = _make_data() + dpul = aa.dPULearn(random_state=42, verbose=False) + dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=10) + assert dpul.labels_ is not None + assert dpul.labels_.shape[0] == X_pos.shape[0] + X_unl.shape[0] + assert dpul.df_pu_ is not None + + +# Regression / golden equivalence +class TestSplitMaskEquivalence: + """mask_neg_ must equal the manual stacking path exactly (KPI #308).""" + + @pytest.mark.parametrize("seed", [0, 1, 7]) + def test_mask_equals_manual_pca(self, seed): + X_pos, X_unl = _make_data(seed=seed) + manual_mask, dpul_m = _manual_mask(X_pos, X_unl, n_unl_to_neg=10) + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=10).mask_neg_ + assert np.array_equal(mask, manual_mask) + assert np.array_equal(np.asarray(dpul.labels_), np.asarray(dpul_m.labels_)) + + def test_mask_equals_manual_metric(self): + X_pos, X_unl = _make_data(seed=3) + manual_mask, _ = _manual_mask(X_pos, X_unl, n_unl_to_neg=8, metric="cosine") + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=8, metric="cosine").mask_neg_ + assert np.array_equal(mask, manual_mask) + + def test_mask_equals_manual_few_positives(self): + # n_pos < 3: the stacked matrix carries the >=3 floor, so the split path accepts a + # small positive set exactly as the manual path does. + X_pos, X_unl = _make_data(n_pos=1, seed=5) + manual_mask, _ = _manual_mask(X_pos, X_unl, n_unl_to_neg=6) + dpul = aa.dPULearn(random_state=42, verbose=False) + mask = dpul.fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=6).mask_neg_ + assert np.array_equal(mask, manual_mask) + + def test_manual_mode_mask_neg_is_labels_zero(self): + # In the (X, labels) mode, mask_neg_ is over all rows and equals labels_ == 0. + X_pos, X_unl = _make_data(seed=2) + _, dpul = _manual_mask(X_pos, X_unl, n_unl_to_neg=9) + assert np.array_equal(dpul.mask_neg_, np.asarray(dpul.labels_) == 0) + assert dpul.mask_neg_.shape[0] == X_pos.shape[0] + X_unl.shape[0] + + +# Negative Cases Test Class +class TestSplitInputNegative: + """Invalid inputs must raise informative ValueErrors.""" + + def test_feature_mismatch(self): + X_pos, _ = _make_data(n_features=8) + _, X_unl = _make_data(n_features=6) + with pytest.raises(ValueError): + aa.dPULearn(random_state=42, verbose=False).fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=5) + + def test_n_neg_below_one(self): + X_pos, X_unl = _make_data() + with pytest.raises(ValueError): + aa.dPULearn(random_state=42, verbose=False).fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=0) + + def test_too_many_negatives_requested(self): + X_pos, X_unl = _make_data(n_unl=10) + with pytest.raises(ValueError): + aa.dPULearn(random_state=42, verbose=False).fit(X_pos=X_pos, X_unlabeled=X_unl, n_neg=999) + + def test_X_unlabeled_missing(self): + X_pos, _ = _make_data() + with pytest.raises(ValueError): + aa.dPULearn(random_state=42, verbose=False).fit(X_pos=X_pos, n_neg=5) + + def test_both_input_modes_rejected(self): + X_pos, X_unl = _make_data() + X = np.vstack([X_pos, X_unl]) + y = np.array([1] * len(X_pos) + [2] * len(X_unl)) + with pytest.raises(ValueError): + aa.dPULearn(random_state=42, verbose=False).fit(X=X, labels=y, X_pos=X_pos, + X_unlabeled=X_unl, n_neg=5) + + def test_no_input_given(self): + with pytest.raises(ValueError): + aa.dPULearn(random_state=42, verbose=False).fit(n_neg=5) + + +# Existing-fit byte-identical regression +class TestFitUnchanged: + """The pre-existing fit(X, labels=...) path stays byte-identical (#308 no-change).""" + + def test_fit_pca_unchanged(self): + X_pos, X_unl = _make_data(seed=11) + X_pool = np.vstack([X_pos, X_unl]) + y_pool = np.array([1] * len(X_pos) + [2] * len(X_unl)) + dpul = aa.dPULearn(random_state=42, verbose=False) + dpul.fit(X=X_pool, labels=y_pool, n_unl_to_neg=10) + labels = np.asarray(dpul.labels_) + assert (labels[:len(X_pos)] == 1).all() + assert (labels == 0).sum() == 10 + assert set(np.unique(labels)).issubset({0, 1, 2}) diff --git a/tests/unit/plotting_tests/test_color_samples_constants.py b/tests/unit/plotting_tests/test_color_samples_constants.py new file mode 100644 index 00000000..52826a1b --- /dev/null +++ b/tests/unit/plotting_tests/test_color_samples_constants.py @@ -0,0 +1,36 @@ +""" +This script tests the named sample-color constants exposed at top level (issue #308). + +COLOR_SAMPLES_POS / NEG / UNL / REL_NEG are public, named aliases for the canonical sample +colors. They must equal today's ``plot_get_cdict("DICT_COLOR")["SAMPLES_*"]`` values exactly, +so users can reference a named constant instead of indexing the color dict by string key. +""" +import pytest + +import aaanalysis as aa + + +# Golden equivalence test +class TestColorSamplesConstants: + """Named constants must equal the plot_get_cdict values (golden KPI #308).""" + + def test_constants_exist_at_top_level(self): + for name in ("COLOR_SAMPLES_POS", "COLOR_SAMPLES_NEG", + "COLOR_SAMPLES_UNL", "COLOR_SAMPLES_REL_NEG"): + assert hasattr(aa, name) + assert name in aa.__all__ + + @pytest.mark.parametrize("const_name,dict_key", [ + ("COLOR_SAMPLES_POS", "SAMPLES_POS"), + ("COLOR_SAMPLES_NEG", "SAMPLES_NEG"), + ("COLOR_SAMPLES_UNL", "SAMPLES_UNL"), + ("COLOR_SAMPLES_REL_NEG", "SAMPLES_REL_NEG"), + ]) + def test_constant_equals_cdict_value(self, const_name, dict_key): + dict_color = aa.plot_get_cdict(name="DICT_COLOR") + assert getattr(aa, const_name) == dict_color[dict_key] + + def test_constants_are_strings(self): + for name in ("COLOR_SAMPLES_POS", "COLOR_SAMPLES_NEG", + "COLOR_SAMPLES_UNL", "COLOR_SAMPLES_REL_NEG"): + assert isinstance(getattr(aa, name), str)