From 6a30c7121b6004b18343555eed2798347528afb2 Mon Sep 17 00:00:00 2001 From: "benjamin.schubert" Date: Wed, 29 Jan 2025 17:02:22 +0100 Subject: [PATCH 01/13] - ICON implementation (unfinished) --- dextrademixer/model/ICON.py | 115 +++-------------- dextrademixer/model/__init__.py | 1 + dextrademixer/test/TestThresholdAssignment.py | 116 +++++++++++++++++- dextrademixer/utils/utils.py | 33 +++++ 4 files changed, 165 insertions(+), 100 deletions(-) diff --git a/dextrademixer/model/ICON.py b/dextrademixer/model/ICON.py index 95f4bf4..3376a5c 100644 --- a/dextrademixer/model/ICON.py +++ b/dextrademixer/model/ICON.py @@ -14,94 +14,15 @@ from dextrademixer.utils import calculate_pmhc_clonal_purity -def threshold_assign_pmhc(mdata: md.MuData, - threshold: float, - threshold_type: str = "absolute", - pmhc_keys: Union[str, List[str]] = None, - total_normalization: bool = False, - target_sum: float = None, - z_score_normalization: bool = False, - gex_key: str = "gex", - neg_ctrl_key: str = None, - ir_key: str = "airr", - ir_clone_key: str = None, - inplace=False): - """ - Assigns dextramer-specificities based on specified threshold. - Depending on additional information provided different assignment strategies are applied. - - - Args: - mdata: A Mudata containing only dextramer counts and clonotype information - threshold: A UMI count, or relative threshold to determine dextramer-specificity - threshold_type: A string specifiying whether the threshold is absolut or relative. if relative than X in gex_key - will be normalized by the column means - pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `gex_key` modality`s `X` which - should be deconvolved. If None is given, the full X is used - total_normalization: boolean whether or not normalization of each cell by total counts over all pMHCs - (including negative control) should be applied, so that every cell has the same total count - after normalization. - target_sum: If None, after normalization, each observation (cell) has a total count equal to the median of total - counts for observations (cells) before normalization. - z_score_normalization: z-score normalize within pMHC across cells - gex_key: the MuData transcriptome module key - neg_ctrl_key: (Optional) a string specifying the negative control column in `gex_key` modality`s `X` - ir_key: the MuData AIRR module key - ir_clone_key: (Optional) a string specifying the field in `obs` of `ir_key` that holds clonotype ids - inplace: boolean indicating whether assignment should be stored in mdata on `gex_key` `obsm` - kwargs: dictionary of additional information pasted to the Model object (used for custom model prior) - - Returns: An array of pMHC assignments per cell, or modifies the mdata object adding an obsm matrix at `gex_key` - """ - - gex = mdata.mod[gex_key] - air = mdata.mod[ir_key] - N = gex.shape[0] - - if pmhc_keys is None: - pmhc_keys = gex.var.index - - if neg_ctrl_key is not None and neg_ctrl_key not in pmhc_keys: - pmhc_keys.append(neg_ctrl_key) - - gex_filtered = gex[:, pmhc_keys] - - if total_normalization: - x = sc.pp.normalize_total(gex_filtered, target_sum=target_sum, inplace=False)['X'] - else: - x = gex_filtered.X.toarray() - - if z_score_normalization: - x = (x - jnp.nanmean(x, axis=0, keepdims=True)) / jnp.nanstd(x, axis=0, keepdims=True) - - neg_ctrl_key_idx = gex.var.index.tolist().index(neg_ctrl_key) if neg_ctrl_key else None - x_neg = x[:, neg_ctrl_key_idx].reshape((N,)) if neg_ctrl_key else None - x = jnp.delete(x, neg_ctrl_key_idx, axis=1) if neg_ctrl_key else x - c = air.obs[ir_clone_key].to_numpy().astype("int32") if ir_clone_key is not None else None - - if threshold_type == "relative" and neg_ctrl_key is None: - x = x / (x.sum(axis=1, keepdims=True)+1e-8) - elif threshold_type == "relative" and neg_ctrl_key is not None: - x = x / (jnp.nanmax(x_neg)+1e-8) - elif threshold_type == "absolut" and neg_ctrl_key is not None: - x = x - jnp.nanmax(x_neg) - - assignment = ((x == jnp.max(x, axis=1, keepdims=True)) & (x >= threshold)).astype("uint8") - - if inplace: - mdata.mod[gex_key].obsm["pMHC_assignment"] = assignment - else: - return assignment - - def icon_assign_pmhc(mdata: md.MuData, - neg_ctrl_key: str, - ir_clone_key: str, - threshold: float = 0, - pmhc_keys: Union[str, List[str]] = None, - gex_key: str = "gex", - ir_key: str = "airr", - inplace=False): + ir_clone_key: str, + neg_ctrl_key: str = None, + threshold: float = 0, + bg_noise: float = None, + pmhc_keys: Union[str, List[str]] = None, + gex_key: str = "gex", + ir_key: str = "airr", + inplace=False): """ implements the ICON assignment procedure @@ -124,23 +45,23 @@ def icon_assign_pmhc(mdata: md.MuData, """ gex = mdata.mod[gex_key] air = mdata.mod[ir_key] - N = gex.shape[0] if pmhc_keys is None: - pmhc_keys = gex.var.index + pmhc_keys = gex.var_names - neg_ctrl_key_idx = gex.var.index(neg_ctrl_key) + if bg_noise is None and neg_ctrl_key is None: + bg_noise = 10 - X = gex[:, pmhc_keys].X.toarray() - x_neg = gex.X[:, neg_ctrl_key_idx].max() if neg_ctrl_key else None - c = air.obs[ir_clone_key].to_numpy().astype("int32") if ir_clone_key is not None else None + X = jnp.array(gex[:, pmhc_keys].X.toarray()) + x_neg = gex.X[:, neg_ctrl_key].max() if bg_noise is None else bg_noise + c = air.obs[ir_clone_key].to_numpy().astype("int32") # Subtract background noise E = X - x_neg E = E.at[E < 0].set(0) # calc pMHC ratio per cell - C = E / (E.sum(axis=1, keepdims=True)+1) + C = E / (E.sum(axis=1, keepdims=True) + 1) # raw assignment with UMI > 0 rA = (E > 0).astype("int32") @@ -148,9 +69,9 @@ def icon_assign_pmhc(mdata: md.MuData, # calc clonotype purity R = calculate_pmhc_clonal_purity(rA, c) - S = jnp.log(E + 0.01) * C ** 2 * R + S = jnp.log(E + 0.01) * (C ** 2) * R S = jnp.nan_to_num(S) - S = S.at[S < 0].set(0) + S = S.at[S < 1].set(0) # pMHC-wise log-ratio normalization per cell colSum = S.sum(axis=1, keepdims=True) @@ -158,7 +79,7 @@ def icon_assign_pmhc(mdata: md.MuData, S = S / colSum # cell-wise z-score normalization - S = (S - jnp.nanmean(S)) / jnp.nanstd(S) + S = (S - jnp.nanmean(S, axis=0, keepdims=True)) / jnp.nanstd(S, axis=0, keepdims=True) S = jnp.nan_to_num(S, nan=jnp.nanmin(S)) assignment = (S > threshold).astype("uint8") diff --git a/dextrademixer/model/__init__.py b/dextrademixer/model/__init__.py index 3cf6620..ea5330e 100644 --- a/dextrademixer/model/__init__.py +++ b/dextrademixer/model/__init__.py @@ -1,3 +1,4 @@ from dextrademixer.model.ApMHCDeconvolution import ApMHCDeconvolution from dextrademixer.model.Dextrademixer import DextraDemixer from dextrademixer.model.BEAMT import BEAMT +from dextrademixer.model.ICON import icon_assign_pmhc diff --git a/dextrademixer/test/TestThresholdAssignment.py b/dextrademixer/test/TestThresholdAssignment.py index 052b47e..834483b 100644 --- a/dextrademixer/test/TestThresholdAssignment.py +++ b/dextrademixer/test/TestThresholdAssignment.py @@ -1,8 +1,118 @@ import unittest +import numpy as np +import jax.numpy as jnp + +from dextrademixer.model import threshold_assign_pmhc +from dextrademixer.utils import calculate_pmhc_clonal_purity, DextramerSimulator + + +class TestThresholdAssignment(unittest.TestCase): + def test_clonal_purity(self): + assignment = np.array([[1,0,0], + [1,0,0], + [1,1,0]]) + clonotypes = np.array([0,0,1]) + + purity = calculate_pmhc_clonal_purity(assignment, clonotypes) + self.assertTrue(np.allclose(purity, np.array([[1,0,0], + [1,0,0], + [0.5,0.5,0]]))) + + def test_threshold_based_assignment(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 10) + print(mdat.mod["airr"].obs.clone_id) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_inplace(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 10, neg_ctrl_key="neg_control", inplace=True) + print(mdat.mod["gex"].X) + print(mdat.mod["gex"].obsm["pMHC_assignment"]) + + + def test_threshold_based_assignment_total_normalization(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 5, neg_ctrl_key="neg_control", + total_normalization=True, + target_sum=10e6) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_z_score(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 0, neg_ctrl_key="neg_control", + z_score_normalization=True) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_z_score_and_total(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 0, neg_ctrl_key="neg_control", + total_normalization=True, + z_score_normalization=True) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_relative_threshold(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, + threshold=0.5, + z_score_normalization=True, + threshold_type="relative") + print(mdat.mod["gex"].X) + print(assignment) + -class MyTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, False) # add assertion here if __name__ == '__main__': unittest.main() diff --git a/dextrademixer/utils/utils.py b/dextrademixer/utils/utils.py index e3f8c02..ef4c1d0 100644 --- a/dextrademixer/utils/utils.py +++ b/dextrademixer/utils/utils.py @@ -14,6 +14,39 @@ import scirpy as ir +def calculate_pmhc_clonal_purity(assignment, + clonotypes): + """ + Calculates the pMHC purity of each clonotype, i.e., the fraction of cells of a clonotype being assigned to + a specific pMHC (accounting for multiple assignments). + + Args: + assignment: pMHC assignments of each cell + clonotypes: clonotype assignment of each cell + + Returns matrix cell x pMHC with clonotype purity for each pMHC + """ + unique_clonotypes = jnp.unique(clonotypes) + + def compute_clonotype_purity(c): + # Create a mask for the cells belonging to the current clonotype + mask = clonotypes == c + mask = mask.astype(jnp.float32) + + # Sum assignments for the current clonotype avoiding a where statement + Tki = (assignment * mask[:, None]).sum(axis=0, keepdims=True) + + # Normalize the assignment for the clonotype + purity = jnp.nan_to_num(Tki / Tki.sum()) + + # Reapply the mask to distribute purity values back to the original matrix + return purity * mask[:, None] + + # Compute purity for each clonotype and sum the results + purity_matrix = jax.vmap(compute_clonotype_purity)(unique_clonotypes) + return purity_matrix.sum(axis=0) + + def gower_centering(distance_matrix): """ Applies Gower's 1966 centering method to the distance matrix to obtain a covariance matrix. From 2e4b156c3e2e17dc3f3e2b8255ed007c41a87bec Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 13 Feb 2025 20:48:07 +0100 Subject: [PATCH 02/13] init ITRAP --- dextrademixer/model/ITRAP.py | 316 +++++++++++++++++++++++++++ dextrademixer/test/TestITRAPModel.py | 35 +++ 2 files changed, 351 insertions(+) create mode 100644 dextrademixer/model/ITRAP.py create mode 100644 dextrademixer/test/TestITRAPModel.py diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py new file mode 100644 index 0000000..2e94356 --- /dev/null +++ b/dextrademixer/model/ITRAP.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import os.path +from typing import TYPE_CHECKING, Tuple + +import mudata as md +import pandas as pd + +import numpy as np +import jax.lax +import jax +from scipy import stats + +from dextrademixer.model import ApMHCDeconvolution + +if TYPE_CHECKING: + from jax._src.typing import Array + + +class ITRAP(ApMHCDeconvolution): + """ + This class implements the ITRAP algorithm introduced by Povlsen et al. (2023). + First each clonotype with more than 10 cells is assigned an expected target if the highest UMI count is + significantly higher than the second most abundant pMHC using Wilcoxon p < 0.05. + Each cell's specificity is then assigned to the most abundant pMHC based on UMI count. + Using this expected target per clonotype, ITRAP calculates ideal UMI thresholds using a grid-search by optimizing + the accuracy (if the epitope with highest UMI count of a cell matches the expected target) while preserving the + ratio of retained cells using a weighted average between both objectives. + The optimal thresholds are then used to filter cells. Further filtering steps may be + included if the respective data, e.g., donor HLA, is available. + """ + __name = "ITRAP" + __version = "0.0.1" + + def __init__(self, umi_cols=None, umi_count_TRA=None, umi_count_TRB=None, filters=None): + """ + Args: + umi_cols: List of columns containing UMI counts for pMHCs (default set to ['neg_control', 'pmhc1']) + umi_count_TRA: List of columns containing UMI counts for TRA (default: None) + umi_count_TRB: List of columns containing UMI counts for TRB (default: None) + filters: List of filters to apply, options=['opt_thr', 'hashing_singlets', 'matching_HLA', 'complete_TCRs', + 'specificity_multiplets', 'is_cell', 'viable_cells'] (default: ['opt_thr']) + """ + super().__init__() + self.opt_thr = None + self.umi_cols_mhc = umi_cols + self.umi_cols_TRA = umi_count_TRA + self.umi_cols_TRB = umi_count_TRB + self.filters = filters if filters is not None else ['opt_thr'] + self.data = None + self.ir_clone_key = None + self.specificity_to_idx = None + self.idx_to_specificity = None + + def preprocess_model_data(self, mdata: md.MuData, pmhc_key: str, gex_key: str = "gex", neg_ctrl_key: str = None, + ir_key: str = "airr", ir_clone_key: str = None, ir_cov_key: str = None, **kwargs): + if ir_clone_key is None: + raise ValueError(f"{self.__name} requires a clonotype definition. Please specify a `ir_clone_key`.") + + gex = mdata.mod[gex_key] + N = gex.shape[0] + + x = gex[:, pmhc_key].X.toarray().reshape((N,)) + x_neg = gex[:, neg_ctrl_key].X.toarray().reshape((N,)) + + self._check_parameters(x, x_neg, None, None) + self.ir_clone_key = ir_clone_key + + if self.umi_cols_mhc is None: + if neg_ctrl_key is None: + raise ValueError("No negative control specified and no umi_cols_mhc. Please provide a `neg_ctrl_key` " + "or set umi_cols_mhc during initialization.") + self.umi_cols_mhc = [neg_ctrl_key, pmhc_key] + self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} + self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} + + data = mdata['airr'].obs.copy() + for col in self.umi_cols_mhc: + data[col] = mdata['gex'][:, col].X.toarray().reshape(-1) + + def calc_delta(x): + """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" + if len(x) == 1: + return x[-1] / 0.25 + elif len(x) == 0: + return 0 + else: + return (x.nlargest(2).iloc()[0]) / (x.nlargest(2).iloc()[1] + 0.25) + + # Calculate UMI count and delta for pMHCs, TRA and TRB. Nomenclature follows original implementation + # umi_count_X = max(UMI count of X) + # delta_umi_X = ratio between highest and second highest UMI counts + data['umi_count_mhc'] = data[self.umi_cols_mhc].max(1) + data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) + data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') + if self.umi_cols_TRA is not None: + data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) + if self.umi_cols_TRB is not None: + data['umi_count_TRB'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) + + self.data = data + + def fit(self): + """ + Fit the ITRAP model to the data. Calculate the ideal UMI thresholds for filtering + """ + if self.data is None: + raise Exception("Model is not initialized. Please call `preprocess_model_data` first.") + + # Calculate ideal thresholds + self.opt_thr = self._calculate_ideal_umi_thresholds(self.data) + + def predict_posterior_class(self, threshold: float = None, target_fdr: float = None) -> Tuple[Array, Array]: + """ + Returns the binder assignments based on the most abundant UMI count for each cell. + To filter out noise, different filters are applied to the data. + ITRAP does not return a posterior probability, so the assignment is returned as pseudo value. + Threshold and target_fdr are ignored in this implementation. + + Args: + threshold: (Optional) ignored + target_fdr: (Optional) ignored + Returns: + A tuple (p, assignment) of arrays with p being the pseudo value for compatibility of binding and assignment + the class assignment decision + """ + if self.opt_thr is None: + raise RuntimeError("Model has not been fit yet. Please call first `fit`.") + + # Assign cells to most abundant pMHC based on UMI count, then set assignment to 0 if it fails filters + filters = self._generate_filters(self.data) + self.data['assignment'] = self.data[self.umi_cols_mhc].idxmax(1).values + self.data['assignment'] = self.data['assignment'].map(self.specificity_to_idx) + self.data['assignment_before_filtering'] = self.data['assignment'].copy() + self.data.loc[~filters, 'assignment'] = 0 + + return self.data['assignment'].values.astype(int), self.data['assignment'].values.astype(float) + + def _generate_filters(self, data): + filters = pd.Series([True] * len(data), index=data.index) + + # Filter 1: UMI thresholds + if 'opt_thr' in self.filters: + for k, thr in self.opt_thr.items(): + if k in data.columns: + filters &= data[k] >= thr + # filters &= eval(' & '.join([f'(data["{k}"] >= {abs(v)})' for k, v in self.opt_thr.items() if k in data.columns])) + + # TODO Other filters are not implemented yet + # Filter 2: Hashing singlets + if 'hashing_singlets' in self.filters: + raise NotImplementedError("Hashing singlets filter is not implemented yet.") + + # Filter 3: Matching HLA + if 'matching_HLA' in self.filters: + raise NotImplementedError("Matching HLA filter is not implemented yet.") + + # Filter 4: Complete TCRs + if 'complete_TCRs' in self.filters: + raise NotImplementedError("Complete TCRs filter is not implemented yet.") + + # Filter 5: Specificity multiplets + if 'specificity_multiplets' in self.filters: + raise NotImplementedError("Specificity multiplets filter is not implemented yet.") + + # Filter 6: Is cell (Cellranger) + if 'is_cell' in self.filters: + raise NotImplementedError("Is cell filter is not implemented yet.") + + # Filter 7: Viable cells (GEX) + if 'viable_cells' in self.filters: + raise NotImplementedError("Viable cells filter is not implemented yet.") + + return filters + + def _calculate_expected_target(self, data): + # Select two most abundant pMHC based on UMI count + most_abundant_epitope = data[self.umi_cols_mhc].sum(0).nlargest(2).index + + w, p = stats.wilcoxon(data[most_abundant_epitope[0]].fillna(0) - data[most_abundant_epitope[1]].fillna(0), + alternative='greater') + + if p <= 0.05: + return True, most_abundant_epitope[0] + else: + return False, most_abundant_epitope[0] + + def _calculate_ideal_umi_thresholds(self, data): + # TODO What if tie in specificity? Should we just take the first one? + data['cell_specificity'] = data[self.umi_cols_mhc].idxmax(1).values + + # Calculate expected target for each clonotype + ct_pep = data.groupby(self.ir_clone_key).filter(lambda x: len(x) >= 10) + ct_pep = ct_pep.groupby(self.ir_clone_key).apply(self._calculate_expected_target).to_frame() + ct_pep[['significant', 'expected_target']] = ct_pep[0].apply(pd.Series) + ct_pep = ct_pep[ct_pep['significant']].drop(columns=0) + + # Add expected target of each clonotype to full data and filter out non-significant clonotype targets + data['ct_pep'] = data[self.ir_clone_key].map(ct_pep['expected_target']) + cells_with_ct_pep = data[data['ct_pep'].notna()].copy() + cells_with_ct_pep['pep_match'] = cells_with_ct_pep['cell_specificity'] == cells_with_ct_pep['ct_pep'] + + # Grid search for optimal UMI threshold, hparams extracted from ITRAP code + if self.umi_cols_TRA is None: + umi_count_TRA_l = [None] + delta_umi_TRA_l = [None] + else: + umi_count_TRA_l = np.arange(0, data['umi_count_TRA'].quantile(0.4, interpolation='higher')) + delta_umi_TRA_l = np.arange(0, 4) + if self.umi_cols_TRB is None: + umi_count_TRB_l = [None] + delta_umi_TRB_l = [None] + else: + umi_count_TRB_l = np.arange(0, data['umi_count_TRB'].quantile(0.4, interpolation='higher')) + delta_umi_TRB_l = np.arange(0, 4) + # TODO Why start from 1 in original implementation? + umi_count_mhc_l = np.arange(1, data['umi_count_mhc'].quantile(0.5, interpolation='higher')) + delta_umi_mhc_l = [0, 1, 2] # hparam from itrap Snakefile + umi_relat_mhc_l = [0] # seems unused in original implementation + + table = pd.DataFrame(columns=['accuracy', 'ratio_retained_gems', 'umi_count_mhc', 'umi_relat_mhc_l', + 'delta_umi_mhc', 'umi_count_TRA', 'delta_umi_TRA', 'umi_count_TRB', + 'delta_umi_TRB',]) + + n_total_gems = len(cells_with_ct_pep) + + i = -1 + for uca in umi_count_TRA_l: + for dua in delta_umi_TRA_l: + for ucb in umi_count_TRB_l: + for dub in delta_umi_TRB_l: + for ucm in umi_count_mhc_l: + for urm in umi_relat_mhc_l: + for dum in delta_umi_mhc_l: + i += 1 + filter_bool = ((cells_with_ct_pep['umi_count_mhc'] >= ucm) & + (cells_with_ct_pep['delta_umi_mhc'] >= dum) & + (cells_with_ct_pep['umi_count_mhc_rel'] >= urm)) + + if self.umi_cols_TRA is not None: + filter_bool &= (cells_with_ct_pep['umi_count_TRA'] >= uca) & ( + cells_with_ct_pep['delta_umi_TRA'] >= dua) + if self.umi_cols_TRB is not None: + filter_bool &= (cells_with_ct_pep['umi_count_TRB'] >= ucb) & ( + cells_with_ct_pep['delta_umi_TRB'] >= dub) + + flt = cells_with_ct_pep[filter_bool].copy() + + n_gems = len(flt) + n_mat = flt['pep_match'].sum() + + g_ratio = round(n_gems / n_total_gems, 3) + acc = round(n_mat / n_gems, 3) + + table.loc[i] = (acc, g_ratio, ucm, urm, dum, uca, dua, ucb, dub,) + + table['mix_mean'] = (table['accuracy'] * 2 + table['ratio_retained_gems']) / 3 + optimal_thresholds = (table.sort_values(by=['mix_mean', 'accuracy', 'ratio_retained_gems', 'umi_count_mhc'], + ascending=[True, True, True, False])) + opt_thr = optimal_thresholds.iloc()[-1][['umi_count_mhc', 'delta_umi_mhc', 'umi_count_TRA', + 'delta_umi_TRA', 'umi_count_TRB', 'delta_umi_TRB']] + + return opt_thr + + +if __name__ == "__main__": + from dextrademixer.utils import DextramerSimulator + import muon as mu + + + if os.path.exists('../../data/test.h5mu'): + mdata = mu.read('../../data/test.h5mu') + else: + + sim = DextramerSimulator() + mdata = sim.simulate_pmhc_data_from_distribution(total_cells=10000, + nof_clones=150, + p_binding_outlier=0.05, + binding_ratio=0.1, + binding_fold_increase_range=[5], + variance_fold_increase_range=[1.2], + simulate_neg_control=True, + use_clonotype_cov=True, + plot_data=False, + rng_key=42) + mdata.write('../../mdata/test.h5mu') + + itrap = ITRAP(filters=['opt_thr']) + itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_key="airr", ir_clone_key='clone_id') + itrap.fit() + p, assignment = itrap.predict_posterior_class() + print(assignment) + + +from dextrademixer.utils.simulation import DextramerSimulator + +sim = DextramerSimulator() +mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, + simulate_neg_control=True, rng_key=42 + ) + + + +binder = mdata.mod["airr"].obs["is_binder"].to_numpy() + +itrap = ITRAP() +itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") +itrap.fit() +p, assignment = itrap.predict_posterior_class(target_fdr=0.05) +print(assignment) +print(binder) +N = len(binder) +accuracy = (binder == assignment).sum() / N +print("Accuracy", accuracy) diff --git a/dextrademixer/test/TestITRAPModel.py b/dextrademixer/test/TestITRAPModel.py new file mode 100644 index 0000000..ee83710 --- /dev/null +++ b/dextrademixer/test/TestITRAPModel.py @@ -0,0 +1,35 @@ +import unittest + +import numpy as np +import muon as mu +import pandas as pd + +from dextrademixer.model.ITRAP import ITRAP +from dextrademixer.utils.simulation import DextramerSimulator + + +class MyTestCase(unittest.TestCase): + + def setUp(self): + sim = DextramerSimulator() + self.mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, + simulate_neg_control=True, rng_key=42 + ) + + self.binder = self.mdata.mod["airr"].obs["is_binder"].to_numpy() + + def test_ITRAP(self): + itrap = ITRAP() + itrap.preprocess_model_data(self.mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") + itrap.fit() + p, assignment = itrap.predict_posterior_class(target_fdr=0.05) + print(assignment) + print(self.binder) + N = len(self.binder) + accuracy = (self.binder == assignment).sum() / N + print("Accuracy", accuracy) + + + +if __name__ == '__main__': + unittest.main() From 7de0dc70d9a8547e5b163972069e7f2378e72d11 Mon Sep 17 00:00:00 2001 From: Yang Date: Thu, 13 Feb 2025 20:55:24 +0100 Subject: [PATCH 03/13] refactor --- dextrademixer/model/ITRAP.py | 64 ++++++++-------------------- dextrademixer/test/TestITRAPModel.py | 1 - 2 files changed, 18 insertions(+), 47 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 2e94356..1cdc693 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -148,7 +148,7 @@ def _generate_filters(self, data): filters &= data[k] >= thr # filters &= eval(' & '.join([f'(data["{k}"] >= {abs(v)})' for k, v in self.opt_thr.items() if k in data.columns])) - # TODO Other filters are not implemented yet + # TODO Other filters are not implemented yet, only makes sense once we have the respective data # Filter 2: Hashing singlets if 'hashing_singlets' in self.filters: raise NotImplementedError("Hashing singlets filter is not implemented yet.") @@ -188,7 +188,7 @@ def _calculate_expected_target(self, data): return False, most_abundant_epitope[0] def _calculate_ideal_umi_thresholds(self, data): - # TODO What if tie in specificity? Should we just take the first one? + # In case of a tie, in default params negative control is the first column and hence chosen as most abundant data['cell_specificity'] = data[self.umi_cols_mhc].idxmax(1).values # Calculate expected target for each clonotype @@ -215,7 +215,7 @@ def _calculate_ideal_umi_thresholds(self, data): else: umi_count_TRB_l = np.arange(0, data['umi_count_TRB'].quantile(0.4, interpolation='higher')) delta_umi_TRB_l = np.arange(0, 4) - # TODO Why start from 1 in original implementation? + umi_count_mhc_l = np.arange(1, data['umi_count_mhc'].quantile(0.5, interpolation='higher')) delta_umi_mhc_l = [0, 1, 2] # hparam from itrap Snakefile umi_relat_mhc_l = [0] # seems unused in original implementation @@ -266,51 +266,23 @@ def _calculate_ideal_umi_thresholds(self, data): if __name__ == "__main__": - from dextrademixer.utils import DextramerSimulator - import muon as mu - - - if os.path.exists('../../data/test.h5mu'): - mdata = mu.read('../../data/test.h5mu') - else: - - sim = DextramerSimulator() - mdata = sim.simulate_pmhc_data_from_distribution(total_cells=10000, - nof_clones=150, - p_binding_outlier=0.05, - binding_ratio=0.1, - binding_fold_increase_range=[5], - variance_fold_increase_range=[1.2], - simulate_neg_control=True, - use_clonotype_cov=True, - plot_data=False, - rng_key=42) - mdata.write('../../mdata/test.h5mu') - - itrap = ITRAP(filters=['opt_thr']) - itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_key="airr", ir_clone_key='clone_id') - itrap.fit() - p, assignment = itrap.predict_posterior_class() - print(assignment) - + from dextrademixer.utils.simulation import DextramerSimulator -from dextrademixer.utils.simulation import DextramerSimulator + sim = DextramerSimulator() + mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, + simulate_neg_control=True, rng_key=42 + ) -sim = DextramerSimulator() -mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, - simulate_neg_control=True, rng_key=42 - ) + binder = mdata.mod["airr"].obs["is_binder"].to_numpy() -binder = mdata.mod["airr"].obs["is_binder"].to_numpy() - -itrap = ITRAP() -itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") -itrap.fit() -p, assignment = itrap.predict_posterior_class(target_fdr=0.05) -print(assignment) -print(binder) -N = len(binder) -accuracy = (binder == assignment).sum() / N -print("Accuracy", accuracy) + itrap = ITRAP() + itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") + itrap.fit() + p, assignment = itrap.predict_posterior_class(target_fdr=0.05) + print(assignment) + print(binder) + N = len(binder) + accuracy = (binder == assignment).sum() / N + print("Accuracy", accuracy) diff --git a/dextrademixer/test/TestITRAPModel.py b/dextrademixer/test/TestITRAPModel.py index ee83710..6f54463 100644 --- a/dextrademixer/test/TestITRAPModel.py +++ b/dextrademixer/test/TestITRAPModel.py @@ -30,6 +30,5 @@ def test_ITRAP(self): print("Accuracy", accuracy) - if __name__ == '__main__': unittest.main() From 9a0795ba066423de5071094cb67b69fa332ec0b5 Mon Sep 17 00:00:00 2001 From: Yang Date: Fri, 14 Feb 2025 10:12:46 +0100 Subject: [PATCH 04/13] refactor: remove debugging block --- dextrademixer/model/ITRAP.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 1cdc693..0a74164 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -146,7 +146,6 @@ def _generate_filters(self, data): for k, thr in self.opt_thr.items(): if k in data.columns: filters &= data[k] >= thr - # filters &= eval(' & '.join([f'(data["{k}"] >= {abs(v)})' for k, v in self.opt_thr.items() if k in data.columns])) # TODO Other filters are not implemented yet, only makes sense once we have the respective data # Filter 2: Hashing singlets @@ -263,26 +262,3 @@ def _calculate_ideal_umi_thresholds(self, data): 'delta_umi_TRA', 'umi_count_TRB', 'delta_umi_TRB']] return opt_thr - - -if __name__ == "__main__": - from dextrademixer.utils.simulation import DextramerSimulator - - sim = DextramerSimulator() - mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, - simulate_neg_control=True, rng_key=42 - ) - - - - binder = mdata.mod["airr"].obs["is_binder"].to_numpy() - - itrap = ITRAP() - itrap.preprocess_model_data(mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") - itrap.fit() - p, assignment = itrap.predict_posterior_class(target_fdr=0.05) - print(assignment) - print(binder) - N = len(binder) - accuracy = (binder == assignment).sum() / N - print("Accuracy", accuracy) From 8de2726adf5952924e4f649782d088ef1d2edb3e Mon Sep 17 00:00:00 2001 From: "benjamin.schubert" Date: Wed, 30 Jul 2025 17:20:27 +0200 Subject: [PATCH 05/13] - fixed hard coded gex_key and ir_key in process_model_data --- dextrademixer/model/ITRAP.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 0a74164..5e096b5 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -74,9 +74,9 @@ def preprocess_model_data(self, mdata: md.MuData, pmhc_key: str, gex_key: str = self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} - data = mdata['airr'].obs.copy() + data = mdata[ir_key].obs.copy() for col in self.umi_cols_mhc: - data[col] = mdata['gex'][:, col].X.toarray().reshape(-1) + data[col] = mdata[gex_key][:, col].X.toarray().reshape(-1) def calc_delta(x): """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" From 078816e012a453a117369bb35ddf93e1652a8dd5 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 8 Sep 2025 18:01:10 +0200 Subject: [PATCH 06/13] revised ICON implementation --- dextrademixer/model/ICON.py | 66 ++++++++++++++----------------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/dextrademixer/model/ICON.py b/dextrademixer/model/ICON.py index 3376a5c..a085067 100644 --- a/dextrademixer/model/ICON.py +++ b/dextrademixer/model/ICON.py @@ -1,17 +1,8 @@ -import warnings from typing import List, Union - +import numpy as np import pandas as pd import mudata as md -import scanpy as sc -import scipy.stats -from scipy.stats import zscore - -import jax -import jax.lax -import jax.numpy as jnp -from dextrademixer.utils import calculate_pmhc_clonal_purity def icon_assign_pmhc(mdata: md.MuData, @@ -32,7 +23,7 @@ def icon_assign_pmhc(mdata: md.MuData, threshold_type: A string specifiying whether the threshold is absolut or relative. if relative than X in gex_key will be normalized by the column means pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `gex_key` modality`s `X` which should be - deconvolved. If None is given, the full X is used + deconvolved. If None is given, the full X is used, excluding the negative control if specified. gex_key: the MuData transcriptome module key neg_ctrl_key: (Optional) a string specifying the negative control column in `gex_key` modality`s `X` ir_key: the MuData AIRR module key @@ -47,44 +38,37 @@ def icon_assign_pmhc(mdata: md.MuData, air = mdata.mod[ir_key] if pmhc_keys is None: - pmhc_keys = gex.var_names + pmhc_keys = gex.var_names[gex.var_names != neg_ctrl_key] - if bg_noise is None and neg_ctrl_key is None: - bg_noise = 10 + if bg_noise is None: + bg_noise = gex[:, neg_ctrl_key].X.max() if neg_ctrl_key is not None else 10 - X = jnp.array(gex[:, pmhc_keys].X.toarray()) - x_neg = gex.X[:, neg_ctrl_key].max() if bg_noise is None else bg_noise + X = gex[:, pmhc_keys].X.toarray() c = air.obs[ir_clone_key].to_numpy().astype("int32") - # Subtract background noise - E = X - x_neg - E = E.at[E < 0].set(0) + # substract background + E = np.maximum(0, X - bg_noise) + ge0 = E.sum(axis=1) > 0 # 0 mask # calc pMHC ratio per cell - C = E / (E.sum(axis=1, keepdims=True) + 1) - - # raw assignment with UMI > 0 - rA = (E > 0).astype("int32") - - # calc clonotype purity - R = calculate_pmhc_clonal_purity(rA, c) - - S = jnp.log(E + 0.01) * (C ** 2) * R - S = jnp.nan_to_num(S) - S = S.at[S < 1].set(0) - - # pMHC-wise log-ratio normalization per cell - colSum = S.sum(axis=1, keepdims=True) - colSum = colSum.at[colSum <= 0].set(1) - S = S / colSum - - # cell-wise z-score normalization - S = (S - jnp.nanmean(S, axis=0, keepdims=True)) / jnp.nanstd(S, axis=0, keepdims=True) - S = jnp.nan_to_num(S, nan=jnp.nanmin(S)) + C = E.copy() + C[ge0] = E[ge0] / E[ge0].sum(axis=1, keepdims=True) + + # clone purity + R = pd.DataFrame(E > 0).groupby(c).sum() + R = R.div(R.sum(axis=1), axis=0).fillna(0).loc[c].values + + # Dextramer signal correction (rows that summed 0 remain as 0) + S = np.log(E+0.01) * R * C**2 + + # Per cell normalization: pMHC-wise log-ratio normalization + S[ge0] = S[ge0] / S[ge0].sum(axis=1, keepdims=True) + + # Dextramer normalization: cell-wise z-score normalization + S = (S - S.mean(axis=0, keepdims=True)) / S.std(axis=0, keepdims=True) assignment = (S > threshold).astype("uint8") - if inplace: - mdata.mod[gex_key].obsm["pMHC_assignment"] = assignment + mdata.mod[gex_key].obsm["icon_pMHC_assignment"] = assignment else: return assignment From 451d48a1abe3de8d7c953ff6854164aeec3fffb2 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Tue, 9 Sep 2025 13:49:47 +0200 Subject: [PATCH 07/13] fix NAs --- dextrademixer/model/ICON.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dextrademixer/model/ICON.py b/dextrademixer/model/ICON.py index a085067..7b1db0d 100644 --- a/dextrademixer/model/ICON.py +++ b/dextrademixer/model/ICON.py @@ -48,11 +48,11 @@ def icon_assign_pmhc(mdata: md.MuData, # substract background E = np.maximum(0, X - bg_noise) - ge0 = E.sum(axis=1) > 0 # 0 mask # calc pMHC ratio per cell - C = E.copy() - C[ge0] = E[ge0] / E[ge0].sum(axis=1, keepdims=True) + cellnorm = E.sum(axis=1, keepdims=True) + cellnorm[cellnorm == 0] = 1 + C = E / cellnorm # clone purity R = pd.DataFrame(E > 0).groupby(c).sum() @@ -60,13 +60,17 @@ def icon_assign_pmhc(mdata: md.MuData, # Dextramer signal correction (rows that summed 0 remain as 0) S = np.log(E+0.01) * R * C**2 + S[S<1] = 0 # Per cell normalization: pMHC-wise log-ratio normalization - S[ge0] = S[ge0] / S[ge0].sum(axis=1, keepdims=True) + cellnorm = S.sum(axis=1, keepdims=True) + cellnorm[cellnorm == 0] = 1 + S = S / cellnorm # Dextramer normalization: cell-wise z-score normalization S = (S - S.mean(axis=0, keepdims=True)) / S.std(axis=0, keepdims=True) - + S[np.isnan(S)] = np.nanmin(S) # set NA's to smalles observed value + assignment = (S > threshold).astype("uint8") if inplace: mdata.mod[gex_key].obsm["icon_pMHC_assignment"] = assignment From 77256905f6e38d22d3510cd90224ffdef2f09975 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Tue, 9 Sep 2025 15:43:33 +0200 Subject: [PATCH 08/13] changed std for reproducibility with R --- dextrademixer/model/ICON.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dextrademixer/model/ICON.py b/dextrademixer/model/ICON.py index 7b1db0d..212bc72 100644 --- a/dextrademixer/model/ICON.py +++ b/dextrademixer/model/ICON.py @@ -61,6 +61,7 @@ def icon_assign_pmhc(mdata: md.MuData, # Dextramer signal correction (rows that summed 0 remain as 0) S = np.log(E+0.01) * R * C**2 S[S<1] = 0 + S_raw = S.copy() # Per cell normalization: pMHC-wise log-ratio normalization cellnorm = S.sum(axis=1, keepdims=True) @@ -68,9 +69,9 @@ def icon_assign_pmhc(mdata: md.MuData, S = S / cellnorm # Dextramer normalization: cell-wise z-score normalization - S = (S - S.mean(axis=0, keepdims=True)) / S.std(axis=0, keepdims=True) + S = (S - S.mean(axis=0, keepdims=True)) / S.std(axis=0, ddof=1, keepdims=True) S[np.isnan(S)] = np.nanmin(S) # set NA's to smalles observed value - + assignment = (S > threshold).astype("uint8") if inplace: mdata.mod[gex_key].obsm["icon_pMHC_assignment"] = assignment From 94742a9e705101479041f07a939c42903aab5904 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 23 Mar 2026 14:01:43 +0100 Subject: [PATCH 09/13] ICON option as in original implementation, allow adata --- dextrademixer/model/ICON.py | 92 ++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/dextrademixer/model/ICON.py b/dextrademixer/model/ICON.py index 212bc72..6eb45be 100644 --- a/dextrademixer/model/ICON.py +++ b/dextrademixer/model/ICON.py @@ -2,66 +2,89 @@ import numpy as np import pandas as pd import mudata as md +import anndata as ad - -def icon_assign_pmhc(mdata: md.MuData, +def icon_assign_pmhc(adata: Union[md.MuData, ad.AnnData], ir_clone_key: str, neg_ctrl_key: str = None, threshold: float = 0, bg_noise: float = None, + bg_noise_quantile: float = 0.975, pmhc_keys: Union[str, List[str]] = None, - gex_key: str = "gex", - ir_key: str = "airr", - inplace=False): + dex_key: str = "dex", + inplace=False, + faithful: bool = False, + ): """ implements the ICON assignment procedure + requires clonal information and dextramer counts, and optionally a negative control column to estimate background noise. Args: - mdata: A Mudata containing only dextramer counts and clonotype information - threshold: A UMI count, or relative threshold to determine dextramer-specificity - threshold_type: A string specifiying whether the threshold is absolut or relative. if relative than X in gex_key - will be normalized by the column means - pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `gex_key` modality`s `X` which should be - deconvolved. If None is given, the full X is used, excluding the negative control if specified. - gex_key: the MuData transcriptome module key - neg_ctrl_key: (Optional) a string specifying the negative control column in `gex_key` modality`s `X` - ir_key: the MuData AIRR module key - ir_clone_key: (Optional) a string specifying the field in `obs` of `ir_key` that holds clonotype ids - inplace: boolean indicating whether assignment should be stored in mdata on `gex_key` `obsm` - kwargs: dictionary of additional information pasted to the Model object (used for custom model prior) - + adata: A MuData object containing only dextramer counts and clonotype information, + or an AnnData object containing the dextramer counts and clonotype information in the specified obsm and obs keys. + threshold: A relative threshold to determine dextramer-specificity + bg_noise: (Optional) A value to substract from dextramer counts to account for background noise. + If None is given, the bg_noise_quantile of the negative control column is used if specified, otherwise 10. + pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `dex_key` modality which should be + deconvolved. If None is given, the full matrix is used, excluding the negative control if specified. + dex_key: the dextramer signal MuData module key, or the obsm key if adata is an AnnData object + neg_ctrl_key: (Optional) a string specifying the negative control column in the `dex_key` matrix. + ir_clone_key: A string specifying the field in `obs` that holds clonotype ids. + If in the immune receptor modality of a mudata object, should be `ir_key:clone_key`. + inplace: boolean indicating whether assignment should be stored in `obsm` + faithful: boolean indicating whether to use the original ICON procedure (True) or a debuged version based on the paper description - Returns: An array of pMHC assignments per cell, or modifies the mdata object adding an obsm matrix at `gex_key` + Returns: An array of pMHC assignments per cell, or modifies the adata object adding an obsm matrix at `dex_key` """ - gex = mdata.mod[gex_key] - air = mdata.mod[ir_key] + # check if clone key contains NA values + if adata.obs[ir_clone_key].isna().sum() > 0: + raise ValueError(f"NA values found in clone key {ir_clone_key} of adata.obs. ICON works only for cells with TCR information. Please filter the object.") + c = adata.obs[ir_clone_key].to_numpy().astype("int32") + + # get dextramer counts + if isinstance(adata, md.MuData): + is_mudata = True + dex = adata.mod[dex_key] + dex = pd.DataFrame(dex.X.toarray(), index=dex.obs_names, columns=dex.var_names) + elif isinstance(adata, ad.AnnData): + is_mudata = False + dex = adata.obsm[dex_key] if pmhc_keys is None: - pmhc_keys = gex.var_names[gex.var_names != neg_ctrl_key] + X = dex.loc[:,dex.columns != neg_ctrl_key].values + # get background noise if bg_noise is None: - bg_noise = gex[:, neg_ctrl_key].X.max() if neg_ctrl_key is not None else 10 - - X = gex[:, pmhc_keys].X.toarray() - c = air.obs[ir_clone_key].to_numpy().astype("int32") + bg_noise = np.quantile(dex.loc[:, neg_ctrl_key], q=bg_noise_quantile) if neg_ctrl_key is not None else 10 # substract background E = np.maximum(0, X - bg_noise) # calc pMHC ratio per cell - cellnorm = E.sum(axis=1, keepdims=True) - cellnorm[cellnorm == 0] = 1 - C = E / cellnorm + if faithful: + # +1 in the denominator can have large effects + C = E / (E.sum(axis=1, keepdims=True) + 1) + else: + cellnorm = E.sum(axis=1, keepdims=True) + cellnorm[cellnorm == 0] = 1 # 0/1 instead of 0/0 for cells with no dextramer signal + C = E / cellnorm # clone purity - R = pd.DataFrame(E > 0).groupby(c).sum() - R = R.div(R.sum(axis=1), axis=0).fillna(0).loc[c].values + clonal_counts = pd.DataFrame(E > 0).groupby(c).sum() + total = clonal_counts.sum(axis=1) + R = clonal_counts.div(total, axis=0).fillna(0) + + if faithful: + non_zero = (clonal_counts != 0).astype(int) + pure = non_zero.sum(axis=1) == 1 + R[pure] = non_zero[pure].div(total[pure], axis=0).fillna(0) + + R = R.loc[c].values # Dextramer signal correction (rows that summed 0 remain as 0) S = np.log(E+0.01) * R * C**2 S[S<1] = 0 - S_raw = S.copy() # Per cell normalization: pMHC-wise log-ratio normalization cellnorm = S.sum(axis=1, keepdims=True) @@ -74,6 +97,9 @@ def icon_assign_pmhc(mdata: md.MuData, assignment = (S > threshold).astype("uint8") if inplace: - mdata.mod[gex_key].obsm["icon_pMHC_assignment"] = assignment + if is_mudata: + adata.mod[dex_key].obsm["icon_pMHC_assignment"] = assignment + else: + adata.obsm["icon_pMHC_assignment"] = assignment else: return assignment From 9558a263e97595fd89229ed3c514dae86c273c79 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 30 Mar 2026 14:18:27 +0200 Subject: [PATCH 10/13] at least 2 columns for TCR chain optimization + TCRA/TCRB typo --- dextrademixer/model/ITRAP.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index 5e096b5..ee5cf6c 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -94,11 +94,13 @@ def calc_delta(x): data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') if self.umi_cols_TRA is not None: - data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) + if data[[self.umi_cols_TRA]].shape[1] > 1: + data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) if self.umi_cols_TRB is not None: - data['umi_count_TRB'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) + if data[[self.umi_cols_TRB]].shape[1] > 1: + data['umi_count_TRB'] = data[self.umi_cols_TRB].max(1) + data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) self.data = data From 8b16a91284be911aa4eac6945cc60fe75365b512 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 30 Mar 2026 14:32:07 +0200 Subject: [PATCH 11/13] Revert "at least 2 columns for TCR chain optimization + TCRA/TCRB typo" This reverts commit 9558a263e97595fd89229ed3c514dae86c273c79. --- dextrademixer/model/ITRAP.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py index ee5cf6c..5e096b5 100644 --- a/dextrademixer/model/ITRAP.py +++ b/dextrademixer/model/ITRAP.py @@ -94,13 +94,11 @@ def calc_delta(x): data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') if self.umi_cols_TRA is not None: - if data[[self.umi_cols_TRA]].shape[1] > 1: - data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) + data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) if self.umi_cols_TRB is not None: - if data[[self.umi_cols_TRB]].shape[1] > 1: - data['umi_count_TRB'] = data[self.umi_cols_TRB].max(1) - data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) + data['umi_count_TRB'] = data[self.umi_cols_TRA].max(1) + data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) self.data = data From 9a8e654c2f486f607b1e3ddb7ccb03436679bdc6 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Mon, 30 Mar 2026 14:32:28 +0200 Subject: [PATCH 12/13] Revert "Merge branch 'feature/itrap' into feature/threshold_assignment" This reverts commit 792334c592a08467f00e247f3cd5c431f49411a3, reversing changes made to 94742a9e705101479041f07a939c42903aab5904. --- dextrademixer/model/ITRAP.py | 264 --------------------------- dextrademixer/test/TestITRAPModel.py | 34 ---- 2 files changed, 298 deletions(-) delete mode 100644 dextrademixer/model/ITRAP.py delete mode 100644 dextrademixer/test/TestITRAPModel.py diff --git a/dextrademixer/model/ITRAP.py b/dextrademixer/model/ITRAP.py deleted file mode 100644 index 5e096b5..0000000 --- a/dextrademixer/model/ITRAP.py +++ /dev/null @@ -1,264 +0,0 @@ -from __future__ import annotations - -import os.path -from typing import TYPE_CHECKING, Tuple - -import mudata as md -import pandas as pd - -import numpy as np -import jax.lax -import jax -from scipy import stats - -from dextrademixer.model import ApMHCDeconvolution - -if TYPE_CHECKING: - from jax._src.typing import Array - - -class ITRAP(ApMHCDeconvolution): - """ - This class implements the ITRAP algorithm introduced by Povlsen et al. (2023). - First each clonotype with more than 10 cells is assigned an expected target if the highest UMI count is - significantly higher than the second most abundant pMHC using Wilcoxon p < 0.05. - Each cell's specificity is then assigned to the most abundant pMHC based on UMI count. - Using this expected target per clonotype, ITRAP calculates ideal UMI thresholds using a grid-search by optimizing - the accuracy (if the epitope with highest UMI count of a cell matches the expected target) while preserving the - ratio of retained cells using a weighted average between both objectives. - The optimal thresholds are then used to filter cells. Further filtering steps may be - included if the respective data, e.g., donor HLA, is available. - """ - __name = "ITRAP" - __version = "0.0.1" - - def __init__(self, umi_cols=None, umi_count_TRA=None, umi_count_TRB=None, filters=None): - """ - Args: - umi_cols: List of columns containing UMI counts for pMHCs (default set to ['neg_control', 'pmhc1']) - umi_count_TRA: List of columns containing UMI counts for TRA (default: None) - umi_count_TRB: List of columns containing UMI counts for TRB (default: None) - filters: List of filters to apply, options=['opt_thr', 'hashing_singlets', 'matching_HLA', 'complete_TCRs', - 'specificity_multiplets', 'is_cell', 'viable_cells'] (default: ['opt_thr']) - """ - super().__init__() - self.opt_thr = None - self.umi_cols_mhc = umi_cols - self.umi_cols_TRA = umi_count_TRA - self.umi_cols_TRB = umi_count_TRB - self.filters = filters if filters is not None else ['opt_thr'] - self.data = None - self.ir_clone_key = None - self.specificity_to_idx = None - self.idx_to_specificity = None - - def preprocess_model_data(self, mdata: md.MuData, pmhc_key: str, gex_key: str = "gex", neg_ctrl_key: str = None, - ir_key: str = "airr", ir_clone_key: str = None, ir_cov_key: str = None, **kwargs): - if ir_clone_key is None: - raise ValueError(f"{self.__name} requires a clonotype definition. Please specify a `ir_clone_key`.") - - gex = mdata.mod[gex_key] - N = gex.shape[0] - - x = gex[:, pmhc_key].X.toarray().reshape((N,)) - x_neg = gex[:, neg_ctrl_key].X.toarray().reshape((N,)) - - self._check_parameters(x, x_neg, None, None) - self.ir_clone_key = ir_clone_key - - if self.umi_cols_mhc is None: - if neg_ctrl_key is None: - raise ValueError("No negative control specified and no umi_cols_mhc. Please provide a `neg_ctrl_key` " - "or set umi_cols_mhc during initialization.") - self.umi_cols_mhc = [neg_ctrl_key, pmhc_key] - self.specificity_to_idx = {s: i for i, s in enumerate(self.umi_cols_mhc)} - self.idx_to_specificity = {i: s for i, s in enumerate(self.umi_cols_mhc)} - - data = mdata[ir_key].obs.copy() - for col in self.umi_cols_mhc: - data[col] = mdata[gex_key][:, col].X.toarray().reshape(-1) - - def calc_delta(x): - """ Calculate UMI ratio of two most abundant pMHCs, 0.25 is a small constant to avoid division by zero""" - if len(x) == 1: - return x[-1] / 0.25 - elif len(x) == 0: - return 0 - else: - return (x.nlargest(2).iloc()[0]) / (x.nlargest(2).iloc()[1] + 0.25) - - # Calculate UMI count and delta for pMHCs, TRA and TRB. Nomenclature follows original implementation - # umi_count_X = max(UMI count of X) - # delta_umi_X = ratio between highest and second highest UMI counts - data['umi_count_mhc'] = data[self.umi_cols_mhc].max(1) - data['delta_umi_mhc'] = data[self.umi_cols_mhc].apply(calc_delta, axis=1) - data['umi_count_mhc_rel'] = data['umi_count_mhc'] / data['umi_count_mhc'].quantile(0.9, interpolation='lower') - if self.umi_cols_TRA is not None: - data['umi_count_TRA'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRA'] = data[self.umi_cols_TRA].apply(calc_delta) - if self.umi_cols_TRB is not None: - data['umi_count_TRB'] = data[self.umi_cols_TRA].max(1) - data['delta_umi_TRB'] = data[self.umi_cols_TRB].apply(calc_delta) - - self.data = data - - def fit(self): - """ - Fit the ITRAP model to the data. Calculate the ideal UMI thresholds for filtering - """ - if self.data is None: - raise Exception("Model is not initialized. Please call `preprocess_model_data` first.") - - # Calculate ideal thresholds - self.opt_thr = self._calculate_ideal_umi_thresholds(self.data) - - def predict_posterior_class(self, threshold: float = None, target_fdr: float = None) -> Tuple[Array, Array]: - """ - Returns the binder assignments based on the most abundant UMI count for each cell. - To filter out noise, different filters are applied to the data. - ITRAP does not return a posterior probability, so the assignment is returned as pseudo value. - Threshold and target_fdr are ignored in this implementation. - - Args: - threshold: (Optional) ignored - target_fdr: (Optional) ignored - Returns: - A tuple (p, assignment) of arrays with p being the pseudo value for compatibility of binding and assignment - the class assignment decision - """ - if self.opt_thr is None: - raise RuntimeError("Model has not been fit yet. Please call first `fit`.") - - # Assign cells to most abundant pMHC based on UMI count, then set assignment to 0 if it fails filters - filters = self._generate_filters(self.data) - self.data['assignment'] = self.data[self.umi_cols_mhc].idxmax(1).values - self.data['assignment'] = self.data['assignment'].map(self.specificity_to_idx) - self.data['assignment_before_filtering'] = self.data['assignment'].copy() - self.data.loc[~filters, 'assignment'] = 0 - - return self.data['assignment'].values.astype(int), self.data['assignment'].values.astype(float) - - def _generate_filters(self, data): - filters = pd.Series([True] * len(data), index=data.index) - - # Filter 1: UMI thresholds - if 'opt_thr' in self.filters: - for k, thr in self.opt_thr.items(): - if k in data.columns: - filters &= data[k] >= thr - - # TODO Other filters are not implemented yet, only makes sense once we have the respective data - # Filter 2: Hashing singlets - if 'hashing_singlets' in self.filters: - raise NotImplementedError("Hashing singlets filter is not implemented yet.") - - # Filter 3: Matching HLA - if 'matching_HLA' in self.filters: - raise NotImplementedError("Matching HLA filter is not implemented yet.") - - # Filter 4: Complete TCRs - if 'complete_TCRs' in self.filters: - raise NotImplementedError("Complete TCRs filter is not implemented yet.") - - # Filter 5: Specificity multiplets - if 'specificity_multiplets' in self.filters: - raise NotImplementedError("Specificity multiplets filter is not implemented yet.") - - # Filter 6: Is cell (Cellranger) - if 'is_cell' in self.filters: - raise NotImplementedError("Is cell filter is not implemented yet.") - - # Filter 7: Viable cells (GEX) - if 'viable_cells' in self.filters: - raise NotImplementedError("Viable cells filter is not implemented yet.") - - return filters - - def _calculate_expected_target(self, data): - # Select two most abundant pMHC based on UMI count - most_abundant_epitope = data[self.umi_cols_mhc].sum(0).nlargest(2).index - - w, p = stats.wilcoxon(data[most_abundant_epitope[0]].fillna(0) - data[most_abundant_epitope[1]].fillna(0), - alternative='greater') - - if p <= 0.05: - return True, most_abundant_epitope[0] - else: - return False, most_abundant_epitope[0] - - def _calculate_ideal_umi_thresholds(self, data): - # In case of a tie, in default params negative control is the first column and hence chosen as most abundant - data['cell_specificity'] = data[self.umi_cols_mhc].idxmax(1).values - - # Calculate expected target for each clonotype - ct_pep = data.groupby(self.ir_clone_key).filter(lambda x: len(x) >= 10) - ct_pep = ct_pep.groupby(self.ir_clone_key).apply(self._calculate_expected_target).to_frame() - ct_pep[['significant', 'expected_target']] = ct_pep[0].apply(pd.Series) - ct_pep = ct_pep[ct_pep['significant']].drop(columns=0) - - # Add expected target of each clonotype to full data and filter out non-significant clonotype targets - data['ct_pep'] = data[self.ir_clone_key].map(ct_pep['expected_target']) - cells_with_ct_pep = data[data['ct_pep'].notna()].copy() - cells_with_ct_pep['pep_match'] = cells_with_ct_pep['cell_specificity'] == cells_with_ct_pep['ct_pep'] - - # Grid search for optimal UMI threshold, hparams extracted from ITRAP code - if self.umi_cols_TRA is None: - umi_count_TRA_l = [None] - delta_umi_TRA_l = [None] - else: - umi_count_TRA_l = np.arange(0, data['umi_count_TRA'].quantile(0.4, interpolation='higher')) - delta_umi_TRA_l = np.arange(0, 4) - if self.umi_cols_TRB is None: - umi_count_TRB_l = [None] - delta_umi_TRB_l = [None] - else: - umi_count_TRB_l = np.arange(0, data['umi_count_TRB'].quantile(0.4, interpolation='higher')) - delta_umi_TRB_l = np.arange(0, 4) - - umi_count_mhc_l = np.arange(1, data['umi_count_mhc'].quantile(0.5, interpolation='higher')) - delta_umi_mhc_l = [0, 1, 2] # hparam from itrap Snakefile - umi_relat_mhc_l = [0] # seems unused in original implementation - - table = pd.DataFrame(columns=['accuracy', 'ratio_retained_gems', 'umi_count_mhc', 'umi_relat_mhc_l', - 'delta_umi_mhc', 'umi_count_TRA', 'delta_umi_TRA', 'umi_count_TRB', - 'delta_umi_TRB',]) - - n_total_gems = len(cells_with_ct_pep) - - i = -1 - for uca in umi_count_TRA_l: - for dua in delta_umi_TRA_l: - for ucb in umi_count_TRB_l: - for dub in delta_umi_TRB_l: - for ucm in umi_count_mhc_l: - for urm in umi_relat_mhc_l: - for dum in delta_umi_mhc_l: - i += 1 - filter_bool = ((cells_with_ct_pep['umi_count_mhc'] >= ucm) & - (cells_with_ct_pep['delta_umi_mhc'] >= dum) & - (cells_with_ct_pep['umi_count_mhc_rel'] >= urm)) - - if self.umi_cols_TRA is not None: - filter_bool &= (cells_with_ct_pep['umi_count_TRA'] >= uca) & ( - cells_with_ct_pep['delta_umi_TRA'] >= dua) - if self.umi_cols_TRB is not None: - filter_bool &= (cells_with_ct_pep['umi_count_TRB'] >= ucb) & ( - cells_with_ct_pep['delta_umi_TRB'] >= dub) - - flt = cells_with_ct_pep[filter_bool].copy() - - n_gems = len(flt) - n_mat = flt['pep_match'].sum() - - g_ratio = round(n_gems / n_total_gems, 3) - acc = round(n_mat / n_gems, 3) - - table.loc[i] = (acc, g_ratio, ucm, urm, dum, uca, dua, ucb, dub,) - - table['mix_mean'] = (table['accuracy'] * 2 + table['ratio_retained_gems']) / 3 - optimal_thresholds = (table.sort_values(by=['mix_mean', 'accuracy', 'ratio_retained_gems', 'umi_count_mhc'], - ascending=[True, True, True, False])) - opt_thr = optimal_thresholds.iloc()[-1][['umi_count_mhc', 'delta_umi_mhc', 'umi_count_TRA', - 'delta_umi_TRA', 'umi_count_TRB', 'delta_umi_TRB']] - - return opt_thr diff --git a/dextrademixer/test/TestITRAPModel.py b/dextrademixer/test/TestITRAPModel.py deleted file mode 100644 index 6f54463..0000000 --- a/dextrademixer/test/TestITRAPModel.py +++ /dev/null @@ -1,34 +0,0 @@ -import unittest - -import numpy as np -import muon as mu -import pandas as pd - -from dextrademixer.model.ITRAP import ITRAP -from dextrademixer.utils.simulation import DextramerSimulator - - -class MyTestCase(unittest.TestCase): - - def setUp(self): - sim = DextramerSimulator() - self.mdata = sim.simulate_pmhc_data_from_distribution(total_cells=500, nof_clones=10, binding_ratio=0.05, - simulate_neg_control=True, rng_key=42 - ) - - self.binder = self.mdata.mod["airr"].obs["is_binder"].to_numpy() - - def test_ITRAP(self): - itrap = ITRAP() - itrap.preprocess_model_data(self.mdata, "pmhc1", neg_ctrl_key="neg_control", ir_clone_key="clone_id") - itrap.fit() - p, assignment = itrap.predict_posterior_class(target_fdr=0.05) - print(assignment) - print(self.binder) - N = len(self.binder) - accuracy = (self.binder == assignment).sum() / N - print("Accuracy", accuracy) - - -if __name__ == '__main__': - unittest.main() From 1a56480fa6b75097eb6cfe39300962bb8d646362 Mon Sep 17 00:00:00 2001 From: irene-bonapa Date: Tue, 31 Mar 2026 12:02:30 +0200 Subject: [PATCH 13/13] remove clonal_purity function - not used --- dextrademixer/test/TestThresholdAssignment.py | 13 +------ dextrademixer/utils/utils.py | 34 ------------------- 2 files changed, 1 insertion(+), 46 deletions(-) diff --git a/dextrademixer/test/TestThresholdAssignment.py b/dextrademixer/test/TestThresholdAssignment.py index 834483b..8ce430b 100644 --- a/dextrademixer/test/TestThresholdAssignment.py +++ b/dextrademixer/test/TestThresholdAssignment.py @@ -3,21 +3,10 @@ import jax.numpy as jnp from dextrademixer.model import threshold_assign_pmhc -from dextrademixer.utils import calculate_pmhc_clonal_purity, DextramerSimulator +from dextrademixer.utils import DextramerSimulator class TestThresholdAssignment(unittest.TestCase): - def test_clonal_purity(self): - assignment = np.array([[1,0,0], - [1,0,0], - [1,1,0]]) - clonotypes = np.array([0,0,1]) - - purity = calculate_pmhc_clonal_purity(assignment, clonotypes) - self.assertTrue(np.allclose(purity, np.array([[1,0,0], - [1,0,0], - [0.5,0.5,0]]))) - def test_threshold_based_assignment(self): sim = DextramerSimulator() mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, diff --git a/dextrademixer/utils/utils.py b/dextrademixer/utils/utils.py index ef4c1d0..7d5bb5e 100644 --- a/dextrademixer/utils/utils.py +++ b/dextrademixer/utils/utils.py @@ -13,40 +13,6 @@ import scirpy as ir - -def calculate_pmhc_clonal_purity(assignment, - clonotypes): - """ - Calculates the pMHC purity of each clonotype, i.e., the fraction of cells of a clonotype being assigned to - a specific pMHC (accounting for multiple assignments). - - Args: - assignment: pMHC assignments of each cell - clonotypes: clonotype assignment of each cell - - Returns matrix cell x pMHC with clonotype purity for each pMHC - """ - unique_clonotypes = jnp.unique(clonotypes) - - def compute_clonotype_purity(c): - # Create a mask for the cells belonging to the current clonotype - mask = clonotypes == c - mask = mask.astype(jnp.float32) - - # Sum assignments for the current clonotype avoiding a where statement - Tki = (assignment * mask[:, None]).sum(axis=0, keepdims=True) - - # Normalize the assignment for the clonotype - purity = jnp.nan_to_num(Tki / Tki.sum()) - - # Reapply the mask to distribute purity values back to the original matrix - return purity * mask[:, None] - - # Compute purity for each clonotype and sum the results - purity_matrix = jax.vmap(compute_clonotype_purity)(unique_clonotypes) - return purity_matrix.sum(axis=0) - - def gower_centering(distance_matrix): """ Applies Gower's 1966 centering method to the distance matrix to obtain a covariance matrix.