From 2e4b156c3e2e17dc3f3e2b8255ed007c41a87bec Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 13 Feb 2025 20:48:07 +0100 Subject: [PATCH 1/8] init ITRAP --- dextrademixer/model/ITRAP.py | 316 +++++++++++++++++++++++++++ dextrademixer/test/TestITRAPModel.py | 35 +++ 2 files changed, 351 insertions(+) create mode 100644 dextrademixer/model/ITRAP.py create mode 100644 dextrademixer/test/TestITRAPModel.py diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py new file mode 100644 index 0000000..2e94356 --- /dev/null +++ b/dextrademixer/model/ITRAP.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import os.path +from typing import TYPE_CHECKING, Tuple + +import mudata as md +import pandas as pd + +import numpy as np +import jax.lax +import jax +from scipy import stats + +from dextrademixer.model import ApMHCDeconvolution + +if TYPE_CHECKING: + from jax._src.typing import Array + + +class ITRAP(ApMHCDeconvolution): + """ + This class implements the ITRAP algorithm introduced by Povlsen et al. (2023). + First each clonotype with more than 10 cells is assigned an expected target if the highest UMI count is + significantly higher than the second most abundant pMHC using Wilcoxon p < 0.05. + Each cell's specificity is then assigned to the most abundant pMHC based on UMI count. + Using this expected target per clonotype, ITRAP calculates ideal UMI thresholds using a grid-search by optimizing + the accuracy (if the epitope with highest UMI count of a cell matches the expected target) while preserving the + ratio of retained cells using a weighted average between both objectives. + The optimal thresholds are then used to filter cells. Further filtering steps may be + included if the respective data, e.g., donor HLA, is available. + """ + __name = "ITRAP" + __version = "0.0.1" + + def __init__(self, umi_cols=None, umi_count_TRA=None, umi_count_TRB=None, filters=None): + """ + Args: + umi_cols: List of columns containing UMI counts for pMHCs (default set to ['neg_control', 'pmhc1']) + umi_count_TRA: List of columns containing UMI counts for TRA (default: None) + umi_count_TRB: List of columns containing UMI counts for TRB (default: None) + filters: List of filters to apply, options=['opt_thr', 'hashing_singlets', 'matching_HLA', 'complete_TCRs', + 'specificity_multiplets', 'is_cell', 'viable_cells'] (default: ['opt_thr']) + """ + super().__init__() + self.opt_thr = None + self.umi_cols_mhc = umi_cols + self.umi_cols_TRA = umi_count_TRA + self.umi_cols_TRB = umi_count_TRB + self.filters = filters if filters is not None else ['opt_thr'] + self.data = None + self.ir_clone_key = None + self.specificity_to_idx = None + self.idx_to_specificity = None + + def preprocess_model_data(self, mdata: md.MuData, pmhc_key: str, gex_key: str = "gex", neg_ctrl_key: str = None, + ir_key: str = "airr", ir_clone_key: str = None, ir_cov_key: str = None, **kwargs): + if ir_clone_key is None: + raise ValueError(f"{self.__name} requires a clonotype definition. Please specify a `ir_clone_key`.") + + gex = mdata.mod[gex_key] + N = gex.shape[0] + + x = gex[:, pmhc_key].X.toarray().reshape((N,)) + x_neg = gex[:, neg_ctrl_key].X.toarray().reshape((N,)) + + self._check_parameters(x, x_neg, None, None) + self.ir_clone_key = ir_clone_key + + if self.umi_cols_mhc is None: + if neg_ctrl_key is None: + raise ValueError("No negative control specified and no umi_cols_mhc. Please provide a `neg_ctrl_key` " + "or set umi_cols_mhc during initialization.") + self.umi_cols_mhc = [neg_ctrl_key, pmhc_key] + self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} + self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} + + data = mdata['airr'].obs.copy() + for col in self.umi_cols_mhc: + data[col] = mdata['gex'][:, col].X.toarray().reshape(-1) + + def calc_delta(x): + """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" + if len(x) == 1: + return x[-1] / 0.25 + elif len(x) == 0: + return 0 + else: + return (x.nlargest(2).iloc()[0]) / (x.nlargest(2).iloc()[1] + 0.25) + + # Calculate UMI count and delta for pMHCs, TRA and TRB. Nomenclature follows original implementation + # umi_count_X = max(UMI count of X) + # delta_umi_X = ratio between highest and second highest UMI counts + data['umi_count_mhc'] = data[self.umi_cols_mhc].max(1) + data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) + data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') + if self.umi_cols_TRA is not None: + data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) + if self.umi_cols_TRB is not None: + data['umi_count_TRB'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) + + self.data = data + + def fit(self): + """ + Fit the ITRAP model to the data. Calculate the ideal UMI thresholds for filtering + """ + if self.data is None: + raise Exception("Model is not initialized. Please call `preprocess_model_data` first.") + + # Calculate ideal thresholds + self.opt_thr = self._calculate_ideal_umi_thresholds(self.data) + + def predict_posterior_class(self, threshold: float = None, target_fdr: float = None) -> Tuple[Array, Array]: + """ + Returns the binder assignments based on the most abundant UMI count for each cell. + To filter out noise, different filters are applied to the data. + ITRAP does not return a posterior probability, so the assignment is returned as pseudo value. + Threshold and target_fdr are ignored in this implementation. + + Args: + threshold: (Optional) ignored + target_fdr: (Optional) ignored + Returns: + A tuple (p, assignment) of arrays with p being the pseudo value for compatibility of binding and assignment + the class assignment decision + """ + if self.opt_thr is None: + raise RuntimeError("Model has not been fit yet. Please call first `fit`.") + + # Assign cells to most abundant pMHC based on UMI count, then set assignment to 0 if it fails filters + filters = self._generate_filters(self.data) + self.data['assignment'] = self.data[self.umi_cols_mhc].idxmax(1).values + self.data['assignment'] = self.data['assignment'].map(self.specificity_to_idx) + self.data['assignment_before_filtering'] = self.data['assignment'].copy() + self.data.loc[~filters, 'assignment'] = 0 + + return self.data['assignment'].values.astype(int), self.data['assignment'].values.astype(float) + + def _generate_filters(self, data): + filters = pd.Series([True] * len(data), index=data.index) + + # Filter 1: UMI thresholds + if 'opt_thr' in self.filters: + for k, thr in self.opt_thr.items(): + if k in data.columns: + filters &= data[k] >= thr + # filters &= eval(' & '.join([f'(data["{k}"] >= {abs(v)})' for k, v in self.opt_thr.items() if k in data.columns])) + + # TODO Other filters are not implemented yet + # Filter 2: Hashing singlets + if 'hashing_singlets' in self.filters: + raise NotImplementedError("Hashing singlets filter is not implemented yet.") + + # Filter 3: Matching HLA + if 'matching_HLA' in self.filters: + raise NotImplementedError("Matching HLA filter is not implemented yet.") + + # Filter 4: Complete TCRs + if 'complete_TCRs' in self.filters: + raise NotImplementedError("Complete TCRs filter is not implemented yet.") + + # Filter 5: Specificity multiplets + if 'specificity_multiplets' in self.filters: + raise NotImplementedError("Specificity multiplets filter is not implemented yet.") + + # Filter 6: Is cell (Cellranger) + if 'is_cell' in self.filters: + raise NotImplementedError("Is cell filter is not implemented yet.") + + # Filter 7: Viable cells (GEX) + if 'viable_cells' in self.filters: + raise NotImplementedError("Viable cells filter is not implemented yet.") + + return filters + + def _calculate_expected_target(self, data): + # Select two most abundant pMHC based on UMI count + most_abundant_epitope = data[self.umi_cols_mhc].sum(0).nlargest(2).index + + w, p = stats.wilcoxon(data[most_abundant_epitope[0]].fillna(0) - data[most_abundant_epitope[1]].fillna(0), + alternative='greater') + + if p <= 0.05: + return True, most_abundant_epitope[0] + else: + return False, most_abundant_epitope[0] + + def _calculate_ideal_umi_thresholds(self, data): + # TODO What if tie in specificity? Should we just take the first one? + data['cell_specificity'] = data[self.umi_cols_mhc].idxmax(1).values + + # Calculate expected target for each clonotype + ct_pep = data.groupby(self.ir_clone_key).filter(lambda x: len(x) >= 10) + ct_pep = ct_pep.groupby(self.ir_clone_key).apply(self._calculate_expected_target).to_frame() + ct_pep[['significant', 'expected_target']] = ct_pep[0].apply(pd.Series) + ct_pep = ct_pep[ct_pep['significant']].drop(columns=0) + + # Add expected target of each clonotype to full data and filter out non-significant clonotype targets + data['ct_pep'] = data[self.ir_clone_key].map(ct_pep['expected_target']) + cells_with_ct_pep = data[data['ct_pep'].notna()].copy() + cells_with_ct_pep['pep_match'] = cells_with_ct_pep['cell_specificity'] == cells_with_ct_pep['ct_pep'] + + # Grid search for optimal UMI threshold, hparams extracted from ITRAP code + if self.umi_cols_TRA is None: + umi_count_TRA_l = [None] + delta_umi_TRA_l = [None] + else: + umi_count_TRA_l = np.arange(0, data['umi_count_TRA'].quantile(0.4, interpolation='higher')) + delta_umi_TRA_l = np.arange(0, 4) + if self.umi_cols_TRB is None: + umi_count_TRB_l = [None] + delta_umi_TRB_l = [None] + else: + umi_count_TRB_l = np.arange(0, data['umi_count_TRB'].quantile(0.4, interpolation='higher')) + delta_umi_TRB_l = np.arange(0, 4) + # TODO Why start from 1 in original implementation? + umi_count_mhc_l = np.arange(1, data['umi_count_mhc'].quantile(0.5, interpolation='higher')) + delta_umi_mhc_l = [0, 1, 2] # hparam from itrap Snakefile + umi_relat_mhc_l = [0] # seems unused in original implementation + + table = pd.DataFrame(columns=['accuracy', 'ratio_retained_gems', 'umi_count_mhc', 'umi_relat_mhc_l', + 'delta_umi_mhc', 'umi_count_TRA', 'delta_umi_TRA', 'umi_count_TRB', + 'delta_umi_TRB',]) + + n_total_gems = len(cells_with_ct_pep) + + i = -1 + for uca in umi_count_TRA_l: + for dua in delta_umi_TRA_l: + for ucb in umi_count_TRB_l: + for dub in delta_umi_TRB_l: + for ucm in umi_count_mhc_l: + for urm in umi_relat_mhc_l: + for dum in delta_umi_mhc_l: + i += 1 + filter_bool = ((cells_with_ct_pep['umi_count_mhc'] >= ucm) & + (cells_with_ct_pep['delta_umi_mhc'] >= dum) & + (cells_with_ct_pep['umi_count_mhc_rel'] >= urm)) + + if self.umi_cols_TRA is not None: + filter_bool &= (cells_with_ct_pep['umi_count_TRA'] >= uca) & ( + cells_with_ct_pep['delta_umi_TRA'] >= dua) + if self.umi_cols_TRB is not None: + filter_bool &= (cells_with_ct_pep['umi_count_TRB'] >= ucb) & ( + cells_with_ct_pep['delta_umi_TRB'] >= dub) + + flt = cells_with_ct_pep[filter_bool].copy() + + n_gems = len(flt) + n_mat = flt['pep_match'].sum() + + g_ratio = round(n_gems / n_total_gems, 3) + acc = round(n_mat / n_gems, 3) + + table.loc[i] = (acc, g_ratio, ucm, urm, dum, uca, dua, ucb, dub,) + + table['mix_mean'] = (table['accuracy'] * 2 + table['ratio_retained_gems']) / 3 + optimal_thresholds = (table.sort_values(by=['mix_mean', 'accuracy', 'ratio_retained_gems', 'umi_count_mhc'], + ascending=[True, True, True, False])) + opt_thr = optimal_thresholds.iloc()[-1][['umi_count_mhc', 'delta_umi_mhc', 'umi_count_TRA', + 'delta_umi_TRA', 'umi_count_TRB', 'delta_umi_TRB']] + + return opt_thr + + +if __name__ == "__main__": + from dextrademixer.utils import DextramerSimulator + import muon as mu + + + if os.path.exists('../../data/test.h5mu'): + mdata = mu.read('../../data/test.h5mu') + else: + + sim = DextramerSimulator() + mdata = sim.simulate_pmhc_data_from_distribution(total_cells=10000, + nof_clones=150, + p_binding_outlier=0.05, + binding_ratio=0.1, + binding_fold_increase_range=[5], + variance_fold_increase_range=[1.2], + simulate_neg_control=True, + use_clonotype_cov=True, + plot_data=False, + rng_key=42) + mdata.write('../../mdata/test.h5mu') + + itrap = ITRAP(filters=['opt_thr']) + itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_key="airr", ir_clone_key='clone_id') + itrap.fit() + p, assignment = itrap.predict_posterior_class() + print(assignment) + + +from dextrademixer.utils.simulation import DextramerSimulator + +sim = DextramerSimulator() +mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, + simulate_neg_control=True, rng_key=42 + ) + + + +binder = mdata.mod["airr"].obs["is_binder"].to_numpy() + +itrap = ITRAP() +itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") +itrap.fit() +p, assignment = itrap.predict_posterior_class(target_fdr=0.05) +print(assignment) +print(binder) +N = len(binder) +accuracy = (binder == assignment).sum() / N +print("Accuracy", accuracy) diff --git a/dextrademixer/test/TestITRAPModel.py b/dextrademixer/test/TestITRAPModel.py new file mode 100644 index 0000000..ee83710 --- /dev/null +++ b/dextrademixer/test/TestITRAPModel.py @@ -0,0 +1,35 @@ +import unittest + +import numpy as np +import muon as mu +import pandas as pd + +from dextrademixer.model.ITRAP import ITRAP +from dextrademixer.utils.simulation import DextramerSimulator + + +class MyTestCase(unittest.TestCase): + + def setUp(self): + sim = DextramerSimulator() + self.mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, + simulate_neg_control=True, rng_key=42 + ) + + self.binder = self.mdata.mod["airr"].obs["is_binder"].to_numpy() + + def test_ITRAP(self): + itrap = ITRAP() + itrap.preprocess_model_data(self.mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") + itrap.fit() + p, assignment = itrap.predict_posterior_class(target_fdr=0.05) + print(assignment) + print(self.binder) + N = len(self.binder) + accuracy = (self.binder == assignment).sum() / N + print("Accuracy", accuracy) + + + +if __name__ == '__main__': + unittest.main() From 7de0dc70d9a8547e5b163972069e7f2378e72d11 Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 13 Feb 2025 20:55:24 +0100 Subject: [PATCH 2/8] refactor --- dextrademixer/model/ITRAP.py | 64 ++++++++-------------------- dextrademixer/test/TestITRAPModel.py | 1 - 2 files changed, 18 insertions(+), 47 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 2e94356..1cdc693 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -148,7 +148,7 @@ def _generate_filters(self, data): filters &= data[k] >= thr # filters &= eval(' & '.join([f'(data["{k}"] >= {abs(v)})' for k, v in self.opt_thr.items() if k in data.columns])) - # TODO Other filters are not implemented yet + # TODO Other filters are not implemented yet, only makes sense once we have the respective data # Filter 2: Hashing singlets if 'hashing_singlets' in self.filters: raise NotImplementedError("Hashing singlets filter is not implemented yet.") @@ -188,7 +188,7 @@ def _calculate_expected_target(self, data): return False, most_abundant_epitope[0] def _calculate_ideal_umi_thresholds(self, data): - # TODO What if tie in specificity? Should we just take the first one? + # In case of a tie, in default params negative control is the first column and hence chosen as most abundant data['cell_specificity'] = data[self.umi_cols_mhc].idxmax(1).values # Calculate expected target for each clonotype @@ -215,7 +215,7 @@ def _calculate_ideal_umi_thresholds(self, data): else: umi_count_TRB_l = np.arange(0, data['umi_count_TRB'].quantile(0.4, interpolation='higher')) delta_umi_TRB_l = np.arange(0, 4) - # TODO Why start from 1 in original implementation? + umi_count_mhc_l = np.arange(1, data['umi_count_mhc'].quantile(0.5, interpolation='higher')) delta_umi_mhc_l = [0, 1, 2] # hparam from itrap Snakefile umi_relat_mhc_l = [0] # seems unused in original implementation @@ -266,51 +266,23 @@ def _calculate_ideal_umi_thresholds(self, data): if __name__ == "__main__": - from dextrademixer.utils import DextramerSimulator - import muon as mu - - - if os.path.exists('../../data/test.h5mu'): - mdata = mu.read('../../data/test.h5mu') - else: - - sim = DextramerSimulator() - mdata = sim.simulate_pmhc_data_from_distribution(total_cells=10000, - nof_clones=150, - p_binding_outlier=0.05, - binding_ratio=0.1, - binding_fold_increase_range=[5], - variance_fold_increase_range=[1.2], - simulate_neg_control=True, - use_clonotype_cov=True, - plot_data=False, - rng_key=42) - mdata.write('../../mdata/test.h5mu') - - itrap = ITRAP(filters=['opt_thr']) - itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_key="airr", ir_clone_key='clone_id') - itrap.fit() - p, assignment = itrap.predict_posterior_class() - print(assignment) - + from dextrademixer.utils.simulation import DextramerSimulator -from dextrademixer.utils.simulation import DextramerSimulator + sim = DextramerSimulator() + mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, + simulate_neg_control=True, rng_key=42 + ) -sim = DextramerSimulator() -mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, - simulate_neg_control=True, rng_key=42 - ) + binder = mdata.mod["airr"].obs["is_binder"].to_numpy() -binder = mdata.mod["airr"].obs["is_binder"].to_numpy() - -itrap = ITRAP() -itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") -itrap.fit() -p, assignment = itrap.predict_posterior_class(target_fdr=0.05) -print(assignment) -print(binder) -N = len(binder) -accuracy = (binder == assignment).sum() / N -print("Accuracy", accuracy) + itrap = ITRAP() + itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") + itrap.fit() + p, assignment = itrap.predict_posterior_class(target_fdr=0.05) + print(assignment) + print(binder) + N = len(binder) + accuracy = (binder == assignment).sum() / N + print("Accuracy", accuracy) diff --git a/dextrademixer/test/TestITRAPModel.py b/dextrademixer/test/TestITRAPModel.py index ee83710..6f54463 100644 --- a/dextrademixer/test/TestITRAPModel.py +++ b/dextrademixer/test/TestITRAPModel.py @@ -30,6 +30,5 @@ def test_ITRAP(self): print("Accuracy", accuracy) - if __name__ == '__main__': unittest.main() From 9a0795ba066423de5071094cb67b69fa332ec0b5 Mon Sep 17 00:00:00 2001 From: Yang Date: Fri, 14 Feb 2025 10:12:46 +0100 Subject: [PATCH 3/8] refactor: remove debugging block --- dextrademixer/model/ITRAP.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 1cdc693..0a74164 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -146,7 +146,6 @@ def _generate_filters(self, data): for k, thr in self.opt_thr.items(): if k in data.columns: filters &= data[k] >= thr - # filters &= eval(' & '.join([f'(data["{k}"] >= {abs(v)})' for k, v in self.opt_thr.items() if k in data.columns])) # TODO Other filters are not implemented yet, only makes sense once we have the respective data # Filter 2: Hashing singlets @@ -263,26 +262,3 @@ def _calculate_ideal_umi_thresholds(self, data): 'delta_umi_TRA', 'umi_count_TRB', 'delta_umi_TRB']] return opt_thr - - -if __name__ == "__main__": - from dextrademixer.utils.simulation import DextramerSimulator - - sim = DextramerSimulator() - mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, - simulate_neg_control=True, rng_key=42 - ) - - - - binder = mdata.mod["airr"].obs["is_binder"].to_numpy() - - itrap = ITRAP() - itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") - itrap.fit() - p, assignment = itrap.predict_posterior_class(target_fdr=0.05) - print(assignment) - print(binder) - N = len(binder) - accuracy = (binder == assignment).sum() / N - print("Accuracy", accuracy) From 8de2726adf5952924e4f649782d088ef1d2edb3e Mon Sep 17 00:00:00 2001 From: "benjamin.schubert" Date: Wed, 30 Jul 2025 17:20:27 +0200 Subject: [PATCH 4/8] - fixed hard coded gex_key and ir_key in process_model_data --- dextrademixer/model/ITRAP.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 0a74164..5e096b5 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -74,9 +74,9 @@ def preprocess_model_data(self, mdata: md.MuData, pmhc_key: str, gex_key: str = self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} - data = mdata['airr'].obs.copy() + data = mdata[ir_key].obs.copy() for col in self.umi_cols_mhc: - data[col] = mdata['gex'][:, col].X.toarray().reshape(-1) + data[col] = mdata[gex_key][:, col].X.toarray().reshape(-1) def calc_delta(x): """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" From 6460587974837f0f7d8c8367a93a0b005a4987b0 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 30 Mar 2026 14:49:42 +0200 Subject: [PATCH 5/8] at least 2 columns for TCR chain optimization + TCRA/TCRB typo --- dextrademixer/model/ITRAP.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 5e096b5..ee5cf6c 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -94,11 +94,13 @@ def calc_delta(x): data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') if self.umi_cols_TRA is not None: - data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) + if data[[self.umi_cols_TRA]].shape[1] > 1: + data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) if self.umi_cols_TRB is not None: - data['umi_count_TRB'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) + if data[[self.umi_cols_TRB]].shape[1] > 1: + data['umi_count_TRB'] = data[self.umi_cols_TRB].max(1) + data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) self.data = data From 79c06d7227dcd944dba5efa232a60484bbea9d7c Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 30 Mar 2026 18:33:09 +0200 Subject: [PATCH 6/8] clean-up pipeline --- dextrademixer/model/ITRAP.py | 156 ++++++++++++++++++++--------------- 1 file changed, 88 insertions(+), 68 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index ee5cf6c..2ee8a10 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -1,23 +1,22 @@ from __future__ import annotations -import os.path -from typing import TYPE_CHECKING, Tuple +from typing import List, Union -import mudata as md -import pandas as pd - -import numpy as np -import jax.lax -import jax from scipy import stats +import numpy as np +import pandas as pd -from dextrademixer.model import ApMHCDeconvolution - -if TYPE_CHECKING: - from jax._src.typing import Array +import anndata as ad +import mudata as md +# Ignore small sample warning from scipy when calculating expected target for clonotypes with few cells +import warnings +warnings.filterwarnings( + "ignore", + message=".*sample arguments is too small.*" +) -class ITRAP(ApMHCDeconvolution): +class ITRAP: """ This class implements the ITRAP algorithm introduced by Povlsen et al. (2023). First each clonotype with more than 10 cells is assigned an expected target if the highest UMI count is @@ -32,52 +31,44 @@ class ITRAP(ApMHCDeconvolution): __name = "ITRAP" __version = "0.0.1" - def __init__(self, umi_cols=None, umi_count_TRA=None, umi_count_TRB=None, filters=None): + def __init__(self, filters=None): """ Args: - umi_cols: List of columns containing UMI counts for pMHCs (default set to ['neg_control', 'pmhc1']) - umi_count_TRA: List of columns containing UMI counts for TRA (default: None) - umi_count_TRB: List of columns containing UMI counts for TRB (default: None) filters: List of filters to apply, options=['opt_thr', 'hashing_singlets', 'matching_HLA', 'complete_TCRs', 'specificity_multiplets', 'is_cell', 'viable_cells'] (default: ['opt_thr']) """ super().__init__() self.opt_thr = None - self.umi_cols_mhc = umi_cols - self.umi_cols_TRA = umi_count_TRA - self.umi_cols_TRB = umi_count_TRB self.filters = filters if filters is not None else ['opt_thr'] self.data = None self.ir_clone_key = None self.specificity_to_idx = None self.idx_to_specificity = None - def preprocess_model_data(self, mdata: md.MuData, pmhc_key: str, gex_key: str = "gex", neg_ctrl_key: str = None, - ir_key: str = "airr", ir_clone_key: str = None, ir_cov_key: str = None, **kwargs): - if ir_clone_key is None: - raise ValueError(f"{self.__name} requires a clonotype definition. Please specify a `ir_clone_key`.") - - gex = mdata.mod[gex_key] - N = gex.shape[0] - - x = gex[:, pmhc_key].X.toarray().reshape((N,)) - x_neg = gex[:, neg_ctrl_key].X.toarray().reshape((N,)) - - self._check_parameters(x, x_neg, None, None) - self.ir_clone_key = ir_clone_key - - if self.umi_cols_mhc is None: - if neg_ctrl_key is None: - raise ValueError("No negative control specified and no umi_cols_mhc. Please provide a `neg_ctrl_key` " - "or set umi_cols_mhc during initialization.") - self.umi_cols_mhc = [neg_ctrl_key, pmhc_key] - self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} - self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} - - data = mdata[ir_key].obs.copy() - for col in self.umi_cols_mhc: - data[col] = mdata[gex_key][:, col].X.toarray().reshape(-1) - + def preprocess_model_data( + self, + adata: Union[md.MuData, ad.AnnData], + pmhc_keys: Union[str, List[str]] = None, + neg_ctrl_key: str = None, + ir_clone_key: str = 'clone_id', + dex_key: str = "dex", + ir_key: str = "airr", + umi_cols_TRA: list=None, umi_cols_TRB: list=None, + **kwargs + ): + """ + Args: + adata: A MuData object containing only dextramer counts and clonotype information, + or an AnnData object containing the dextramer counts and clonotype information in the specified obsm and obs keys. + pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `dex_key` modality which should be deconvolved. + If None is given, the full dextramer matrix is used, excluding the negative control. + neg_ctrl_key: A string specifying the negative control column in the `dex_key` matrix. + ir_clone_key: A string specifying the field in `obs` that holds clonotype ids. If adata is a MuData object, this will be prefixed with `{ir_key}:` + dex_key: the dextramer signal MuData module key, or the obsm key if adata is an AnnData object + ir_key: the MuData module key where the immune receptor data is stored, only relevant if adata is a MuData object. + umi_cols_TRA: list of strings specifying the columns in `obs` that hold the UMI counts for TRA, if available. If adata is a MuData object, these will be prefixed with `{ir_key}:` + umi_cols_TRB: list of strings specifying the columns in `obs` that hold the UMI counts for TRB, if available. If adata is a MuData object, these will be prefixed with `{ir_key}:` + """ def calc_delta(x): """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" if len(x) == 1: @@ -86,6 +77,39 @@ def calc_delta(x): return 0 else: return (x.nlargest(2).iloc()[0]) / (x.nlargest(2).iloc()[1] + 0.25) + + # Check inputs + if ir_clone_key is None: + raise ValueError(f"{self.__name} requires a clonotype definition. Please specify a `ir_clone_key`.") + if neg_ctrl_key is None: + raise ValueError("No negative control specified. Please provide a `neg_ctrl_key` ") + + # Adjust data access for mudata and anndata + if isinstance(adata, md.MuData): + dex = adata.mod[dex_key] + dex = pd.DataFrame(dex.X.toarray(), index=dex.obs_names, columns=dex.var_names) + ir_clone_key = f'{ir_key}:{ir_clone_key}' + umi_cols_TRA = [f'{ir_key}:{col}' for col in umi_cols_TRA] if umi_cols_TRA is not None else None + umi_cols_TRB = [f'{ir_key}:{col}' for col in umi_cols_TRB] if umi_cols_TRB is not None else None + adata.pull_obs() # make sure adata.obs is updated with prefixed columns from ir module + elif isinstance(adata, ad.AnnData): + dex = adata.obsm[dex_key] + + if pmhc_keys is None: + pmhc_keys = dex.columns[dex.columns != neg_ctrl_key].tolist() + + self.umi_cols_TRA = umi_cols_TRA + self.umi_cols_TRB = umi_cols_TRB + self.umi_cols_mhc = [neg_ctrl_key] + pmhc_keys if type(pmhc_keys) == list else [neg_ctrl_key, pmhc_keys] + + # get dextramer counts + data = dex.loc[:, self.umi_cols_mhc] + self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} + self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} + + # Get clonotype information + self.ir_clone_key = ir_clone_key + data[self.ir_clone_key] = adata.obs[self.ir_clone_key].values # Calculate UMI count and delta for pMHCs, TRA and TRB. Nomenclature follows original implementation # umi_count_X = max(UMI count of X) @@ -93,15 +117,12 @@ def calc_delta(x): data['umi_count_mhc'] = data[self.umi_cols_mhc].max(1) data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') - if self.umi_cols_TRA is not None: - if data[[self.umi_cols_TRA]].shape[1] > 1: - data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) - if self.umi_cols_TRB is not None: - if data[[self.umi_cols_TRB]].shape[1] > 1: - data['umi_count_TRB'] = data[self.umi_cols_TRB].max(1) - data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) - + if umi_cols_TRA is not None: + data['umi_count_TRA'] = adata.obs[umi_cols_TRA].max(1) if len(umi_cols_TRA) > 1 else adata.obs[umi_cols_TRA].values + data['delta_umi_TRA'] = adata.obs[umi_cols_TRA].apply(calc_delta, axis=1) + if umi_cols_TRB is not None: + data['umi_count_TRB'] = adata.obs[umi_cols_TRB].max(1) if len(umi_cols_TRB) > 1 else adata.obs[umi_cols_TRB].values + data['delta_umi_TRB'] = adata.obs[umi_cols_TRB].apply(calc_delta, axis=1) self.data = data def fit(self): @@ -114,22 +135,17 @@ def fit(self): # Calculate ideal thresholds self.opt_thr = self._calculate_ideal_umi_thresholds(self.data) - def predict_posterior_class(self, threshold: float = None, target_fdr: float = None) -> Tuple[Array, Array]: + def assign_pmhc(self, adata=None) -> np.array: """ Returns the binder assignments based on the most abundant UMI count for each cell. To filter out noise, different filters are applied to the data. - ITRAP does not return a posterior probability, so the assignment is returned as pseudo value. - Threshold and target_fdr are ignored in this implementation. - - Args: - threshold: (Optional) ignored - target_fdr: (Optional) ignored Returns: - A tuple (p, assignment) of arrays with p being the pseudo value for compatibility of binding and assignment - the class assignment decision + An assignment array with the class assignment decision. + If adata is not none, the assignment will be added to adata.obsm['itrap_pMHC_assignment']. """ if self.opt_thr is None: - raise RuntimeError("Model has not been fit yet. Please call first `fit`.") + print("Model has not been fit yet. Finding optimal thresholds...") + self.fit() # Assign cells to most abundant pMHC based on UMI count, then set assignment to 0 if it fails filters filters = self._generate_filters(self.data) @@ -137,8 +153,12 @@ def predict_posterior_class(self, threshold: float = None, target_fdr: float = N self.data['assignment'] = self.data['assignment'].map(self.specificity_to_idx) self.data['assignment_before_filtering'] = self.data['assignment'].copy() self.data.loc[~filters, 'assignment'] = 0 + assignments = pd.Series(self.data['assignment'].values.astype(int)).map(self.idx_to_specificity).values - return self.data['assignment'].values.astype(int), self.data['assignment'].values.astype(float) + if adata is not None: + adata.obs['itrap_pMHC_assignment'] = assignments + adata.obsm['itrap_pMHC_assignment'] = pd.get_dummies(assignments).astype(int).set_index(adata.obs_names) + return assignments def _generate_filters(self, data): filters = pd.Series([True] * len(data), index=data.index) @@ -193,8 +213,8 @@ def _calculate_ideal_umi_thresholds(self, data): data['cell_specificity'] = data[self.umi_cols_mhc].idxmax(1).values # Calculate expected target for each clonotype - ct_pep = data.groupby(self.ir_clone_key).filter(lambda x: len(x) >= 10) - ct_pep = ct_pep.groupby(self.ir_clone_key).apply(self._calculate_expected_target).to_frame() + ct_pep = data.groupby(self.ir_clone_key, observed=True).filter(lambda x: len(x) >= 10) + ct_pep = ct_pep.groupby(self.ir_clone_key, observed=True).apply(self._calculate_expected_target).to_frame() ct_pep[['significant', 'expected_target']] = ct_pep[0].apply(pd.Series) ct_pep = ct_pep[ct_pep['significant']].drop(columns=0) From 628411ed60f4bf4e6e7ee1ea26d2407a6fba2b51 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Tue, 31 Mar 2026 11:48:29 +0200 Subject: [PATCH 7/8] add additional filters --- dextrademixer/model/ITRAP.py | 60 +++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 2ee8a10..602e1a6 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -35,7 +35,7 @@ def __init__(self, filters=None): """ Args: filters: List of filters to apply, options=['opt_thr', 'hashing_singlets', 'matching_HLA', 'complete_TCRs', - 'specificity_multiplets', 'is_cell', 'viable_cells'] (default: ['opt_thr']) + 'specificity_multiplets', 'is_cell'] (default: ['opt_thr']) """ super().__init__() self.opt_thr = None @@ -54,6 +54,9 @@ def preprocess_model_data( dex_key: str = "dex", ir_key: str = "airr", umi_cols_TRA: list=None, umi_cols_TRB: list=None, + is_cell_key: str = 'is_cell', + chain_pairing_key: str = 'chain_pairing', + hashing_classification_key: str = 'HTO_classification', **kwargs ): """ @@ -68,6 +71,9 @@ def preprocess_model_data( ir_key: the MuData module key where the immune receptor data is stored, only relevant if adata is a MuData object. umi_cols_TRA: list of strings specifying the columns in `obs` that hold the UMI counts for TRA, if available. If adata is a MuData object, these will be prefixed with `{ir_key}:` umi_cols_TRB: list of strings specifying the columns in `obs` that hold the UMI counts for TRB, if available. If adata is a MuData object, these will be prefixed with `{ir_key}:` + is_cell_key: string specifying the column in `obs` that indicates whether a barcode is classified as a cell, only relevant if 'is_cell' filter is applied. + chain_pairing_key: string specifying the column in `obs` that indicates whether a cell has complete TCR chain pairing, only relevant if 'complete_TCRs' filter is applied. + hashing_classification_key: string specifying the column in `obs` that indicates the hashing classification of a cell, only relevant if 'hashing_singlets' filter is applied. """ def calc_delta(x): """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" @@ -88,10 +94,12 @@ def calc_delta(x): if isinstance(adata, md.MuData): dex = adata.mod[dex_key] dex = pd.DataFrame(dex.X.toarray(), index=dex.obs_names, columns=dex.var_names) - ir_clone_key = f'{ir_key}:{ir_clone_key}' - umi_cols_TRA = [f'{ir_key}:{col}' for col in umi_cols_TRA] if umi_cols_TRA is not None else None - umi_cols_TRB = [f'{ir_key}:{col}' for col in umi_cols_TRB] if umi_cols_TRB is not None else None + ir_clone_key = f'{ir_key}:{ir_clone_key}' if not ir_clone_key in adata.obs.columns else ir_clone_key + chain_pairing_key = f'{ir_key}:{chain_pairing_key}' if not chain_pairing_key in adata.obs.columns else chain_pairing_key + umi_cols_TRA = [f'{ir_key}:{col}' if not col in adata.obs.columns else col for col in umi_cols_TRA] if umi_cols_TRA is not None else None + umi_cols_TRB = [f'{ir_key}:{col}' if not col in adata.obs.columns else col for col in umi_cols_TRB] if umi_cols_TRB is not None else None adata.pull_obs() # make sure adata.obs is updated with prefixed columns from ir module + elif isinstance(adata, ad.AnnData): dex = adata.obsm[dex_key] @@ -107,9 +115,16 @@ def calc_delta(x): self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} - # Get clonotype information + # Get clonotype information and filters self.ir_clone_key = ir_clone_key - data[self.ir_clone_key] = adata.obs[self.ir_clone_key].values + self.is_cell_key = is_cell_key if 'is_cell' in self.filters else None + self.chain_pairing_key = chain_pairing_key if 'complete_TCRs' in self.filters else None + self.hashing_classification_key = hashing_classification_key if 'hashing_singlets' in self.filters else None + for col in [self.ir_clone_key, self.is_cell_key, self.chain_pairing_key, self.hashing_classification_key]: + if col is not None: + if not col in adata.obs.columns: + raise ValueError(f"Filter {col} specified but column not found in adata.obs.") + data[col] = adata.obs[col].values # Calculate UMI count and delta for pMHCs, TRA and TRB. Nomenclature follows original implementation # umi_count_X = max(UMI count of X) @@ -135,10 +150,20 @@ def fit(self): # Calculate ideal thresholds self.opt_thr = self._calculate_ideal_umi_thresholds(self.data) - def assign_pmhc(self, adata=None) -> np.array: + def assign_pmhc( + self, adata=None, + is_cell_keep_values: List=[True], + chain_pairing_keep_values: List=['single pair', 'extra VDJ', 'extra VJ'], + hashing_classification_keep_values: List=['singlet', 'Singlet'], + ) -> np.array: """ Returns the binder assignments based on the most abundant UMI count for each cell. To filter out noise, different filters are applied to the data. + Args: + adata: If provided, the pMHC assignment will be added to adata.obs['itrap_pMHC_assignment'] and adata.obsm['itrap_pMHC_assignment']. + is_cell_keep_values: List of values in `is_cell_key` column that indicate a barcode is classified as a cell, only relevant if 'is_cell' filter is applied. + chain_pairing_keep_values: List of values in `chain_pairing_key` column that indicate a cell has complete TCR, only relevant if 'complete_TCRs' filter is applied. + hashing_classification_keep_values: List of values in `hashing_classification_key` column that indicate a cell is a singlet, only relevant if 'hashing_singlets' filter is applied. Returns: An assignment array with the class assignment decision. If adata is not none, the assignment will be added to adata.obsm['itrap_pMHC_assignment']. @@ -148,10 +173,10 @@ def assign_pmhc(self, adata=None) -> np.array: self.fit() # Assign cells to most abundant pMHC based on UMI count, then set assignment to 0 if it fails filters - filters = self._generate_filters(self.data) self.data['assignment'] = self.data[self.umi_cols_mhc].idxmax(1).values self.data['assignment'] = self.data['assignment'].map(self.specificity_to_idx) self.data['assignment_before_filtering'] = self.data['assignment'].copy() + filters = self._generate_filters(self.data, is_cell_keep_values, chain_pairing_keep_values, hashing_classification_keep_values) self.data.loc[~filters, 'assignment'] = 0 assignments = pd.Series(self.data['assignment'].values.astype(int)).map(self.idx_to_specificity).values @@ -160,7 +185,9 @@ def assign_pmhc(self, adata=None) -> np.array: adata.obsm['itrap_pMHC_assignment'] = pd.get_dummies(assignments).astype(int).set_index(adata.obs_names) return assignments - def _generate_filters(self, data): + def _generate_filters( + self, data, is_cell_keep_values, chain_pairing_keep_values, hashing_classification_keep_values, + ): filters = pd.Series([True] * len(data), index=data.index) # Filter 1: UMI thresholds @@ -172,7 +199,7 @@ def _generate_filters(self, data): # TODO Other filters are not implemented yet, only makes sense once we have the respective data # Filter 2: Hashing singlets if 'hashing_singlets' in self.filters: - raise NotImplementedError("Hashing singlets filter is not implemented yet.") + filters &= data[self.hashing_classification_key].isin(hashing_classification_keep_values).values # Filter 3: Matching HLA if 'matching_HLA' in self.filters: @@ -180,19 +207,16 @@ def _generate_filters(self, data): # Filter 4: Complete TCRs if 'complete_TCRs' in self.filters: - raise NotImplementedError("Complete TCRs filter is not implemented yet.") + filters &= data[self.chain_pairing_key].isin(chain_pairing_keep_values).values # Filter 5: Specificity multiplets if 'specificity_multiplets' in self.filters: - raise NotImplementedError("Specificity multiplets filter is not implemented yet.") + multiplets = data.groupby([self.ir_clone_key, 'assignment_before_filtering'], observed=True).size() > 1 + filters &= data.set_index([self.ir_clone_key, 'assignment_before_filtering']).index.map(multiplets).values - # Filter 6: Is cell (Cellranger) + # Filter 6: Is cell (GEX/cellranger/TCR) - user defined if 'is_cell' in self.filters: - raise NotImplementedError("Is cell filter is not implemented yet.") - - # Filter 7: Viable cells (GEX) - if 'viable_cells' in self.filters: - raise NotImplementedError("Viable cells filter is not implemented yet.") + filters &= data[self.is_cell_key].isin(is_cell_keep_values).values return filters From 988dbc5232ef8b446502ef96d30809d9ddd6c04d Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Tue, 31 Mar 2026 11:54:34 +0200 Subject: [PATCH 8/8] update test --- dextrademixer/test/TestITRAPModel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dextrademixer/test/TestITRAPModel.py b/dextrademixer/test/TestITRAPModel.py index 6f54463..1c57e60 100644 --- a/dextrademixer/test/TestITRAPModel.py +++ b/dextrademixer/test/TestITRAPModel.py @@ -22,7 +22,7 @@ def test_ITRAP(self): itrap = ITRAP() itrap.preprocess_model_data(self.mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") itrap.fit() - p, assignment = itrap.predict_posterior_class(target_fdr=0.05) + assignment = itrap.assign_pmhc() print(assignment) print(self.binder) N = len(self.binder)