diff --git a/dextrademixer/model/ICON.py b/dextrademixer/model/ICON.py index 95f4bf4..6eb45be 100644 --- a/dextrademixer/model/ICON.py +++ b/dextrademixer/model/ICON.py @@ -1,169 +1,105 @@ -import warnings from typing import List, Union - +import numpy as np import pandas as pd import mudata as md -import scanpy as sc -import scipy.stats -from scipy.stats import zscore - -import jax -import jax.lax -import jax.numpy as jnp - -from dextrademixer.utils import calculate_pmhc_clonal_purity - - -def threshold_assign_pmhc(mdata: md.MuData, - threshold: float, - threshold_type: str = "absolute", - pmhc_keys: Union[str, List[str]] = None, - total_normalization: bool = False, - target_sum: float = None, - z_score_normalization: bool = False, - gex_key: str = "gex", - neg_ctrl_key: str = None, - ir_key: str = "airr", - ir_clone_key: str = None, - inplace=False): - """ - Assigns dextramer-specificities based on specified threshold. - Depending on additional information provided different assignment strategies are applied. - - - Args: - mdata: A Mudata containing only dextramer counts and clonotype information - threshold: A UMI count, or relative threshold to determine dextramer-specificity - threshold_type: A string specifiying whether the threshold is absolut or relative. if relative than X in gex_key - will be normalized by the column means - pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `gex_key` modality`s `X` which - should be deconvolved. If None is given, the full X is used - total_normalization: boolean whether or not normalization of each cell by total counts over all pMHCs - (including negative control) should be applied, so that every cell has the same total count - after normalization. - target_sum: If None, after normalization, each observation (cell) has a total count equal to the median of total - counts for observations (cells) before normalization. - z_score_normalization: z-score normalize within pMHC across cells - gex_key: the MuData transcriptome module key - neg_ctrl_key: (Optional) a string specifying the negative control column in `gex_key` modality`s `X` - ir_key: the MuData AIRR module key - ir_clone_key: (Optional) a string specifying the field in `obs` of `ir_key` that holds clonotype ids - inplace: boolean indicating whether assignment should be stored in mdata on `gex_key` `obsm` - kwargs: dictionary of additional information pasted to the Model object (used for custom model prior) - - Returns: An array of pMHC assignments per cell, or modifies the mdata object adding an obsm matrix at `gex_key` - """ - - gex = mdata.mod[gex_key] - air = mdata.mod[ir_key] - N = gex.shape[0] - - if pmhc_keys is None: - pmhc_keys = gex.var.index - - if neg_ctrl_key is not None and neg_ctrl_key not in pmhc_keys: - pmhc_keys.append(neg_ctrl_key) - - gex_filtered = gex[:, pmhc_keys] - - if total_normalization: - x = sc.pp.normalize_total(gex_filtered, target_sum=target_sum, inplace=False)['X'] - else: - x = gex_filtered.X.toarray() - - if z_score_normalization: - x = (x - jnp.nanmean(x, axis=0, keepdims=True)) / jnp.nanstd(x, axis=0, keepdims=True) - - neg_ctrl_key_idx = gex.var.index.tolist().index(neg_ctrl_key) if neg_ctrl_key else None - x_neg = x[:, neg_ctrl_key_idx].reshape((N,)) if neg_ctrl_key else None - x = jnp.delete(x, neg_ctrl_key_idx, axis=1) if neg_ctrl_key else x - c = air.obs[ir_clone_key].to_numpy().astype("int32") if ir_clone_key is not None else None - - if threshold_type == "relative" and neg_ctrl_key is None: - x = x / (x.sum(axis=1, keepdims=True)+1e-8) - elif threshold_type == "relative" and neg_ctrl_key is not None: - x = x / (jnp.nanmax(x_neg)+1e-8) - elif threshold_type == "absolut" and neg_ctrl_key is not None: - x = x - jnp.nanmax(x_neg) - - assignment = ((x == jnp.max(x, axis=1, keepdims=True)) & (x >= threshold)).astype("uint8") - - if inplace: - mdata.mod[gex_key].obsm["pMHC_assignment"] = assignment - else: - return assignment - - -def icon_assign_pmhc(mdata: md.MuData, - neg_ctrl_key: str, - ir_clone_key: str, - threshold: float = 0, - pmhc_keys: Union[str, List[str]] = None, - gex_key: str = "gex", - ir_key: str = "airr", - inplace=False): +import anndata as ad + + +def icon_assign_pmhc(adata: Union[md.MuData, ad.AnnData], + ir_clone_key: str, + neg_ctrl_key: str = None, + threshold: float = 0, + bg_noise: float = None, + bg_noise_quantile: float = 0.975, + pmhc_keys: Union[str, List[str]] = None, + dex_key: str = "dex", + inplace=False, + faithful: bool = False, + ): """ implements the ICON assignment procedure + requires clonal information and dextramer counts, and optionally a negative control column to estimate background noise. Args: - mdata: A Mudata containing only dextramer counts and clonotype information - threshold: A UMI count, or relative threshold to determine dextramer-specificity - threshold_type: A string specifiying whether the threshold is absolut or relative. if relative than X in gex_key - will be normalized by the column means - pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `gex_key` modality`s `X` which should be - deconvolved. If None is given, the full X is used - gex_key: the MuData transcriptome module key - neg_ctrl_key: (Optional) a string specifying the negative control column in `gex_key` modality`s `X` - ir_key: the MuData AIRR module key - ir_clone_key: (Optional) a string specifying the field in `obs` of `ir_key` that holds clonotype ids - inplace: boolean indicating whether assignment should be stored in mdata on `gex_key` `obsm` - kwargs: dictionary of additional information pasted to the Model object (used for custom model prior) - - - Returns: An array of pMHC assignments per cell, or modifies the mdata object adding an obsm matrix at `gex_key` + adata: A MuData object containing only dextramer counts and clonotype information, + or an AnnData object containing the dextramer counts and clonotype information in the specified obsm and obs keys. + threshold: A relative threshold to determine dextramer-specificity + bg_noise: (Optional) A value to substract from dextramer counts to account for background noise. + If None is given, the bg_noise_quantile of the negative control column is used if specified, otherwise 10. + pmhc_keys (Optional): A string or list of strings indicating the pMHC columns in `dex_key` modality which should be + deconvolved. If None is given, the full matrix is used, excluding the negative control if specified. + dex_key: the dextramer signal MuData module key, or the obsm key if adata is an AnnData object + neg_ctrl_key: (Optional) a string specifying the negative control column in the `dex_key` matrix. + ir_clone_key: A string specifying the field in `obs` that holds clonotype ids. + If in the immune receptor modality of a mudata object, should be `ir_key:clone_key`. + inplace: boolean indicating whether assignment should be stored in `obsm` + faithful: boolean indicating whether to use the original ICON procedure (True) or a debuged version based on the paper description + + Returns: An array of pMHC assignments per cell, or modifies the adata object adding an obsm matrix at `dex_key` """ - gex = mdata.mod[gex_key] - air = mdata.mod[ir_key] - N = gex.shape[0] + # check if clone key contains NA values + if adata.obs[ir_clone_key].isna().sum() > 0: + raise ValueError(f"NA values found in clone key {ir_clone_key} of adata.obs. ICON works only for cells with TCR information. Please filter the object.") + c = adata.obs[ir_clone_key].to_numpy().astype("int32") + + # get dextramer counts + if isinstance(adata, md.MuData): + is_mudata = True + dex = adata.mod[dex_key] + dex = pd.DataFrame(dex.X.toarray(), index=dex.obs_names, columns=dex.var_names) + elif isinstance(adata, ad.AnnData): + is_mudata = False + dex = adata.obsm[dex_key] if pmhc_keys is None: - pmhc_keys = gex.var.index - - neg_ctrl_key_idx = gex.var.index(neg_ctrl_key) + X = dex.loc[:,dex.columns != neg_ctrl_key].values - X = gex[:, pmhc_keys].X.toarray() - x_neg = gex.X[:, neg_ctrl_key_idx].max() if neg_ctrl_key else None - c = air.obs[ir_clone_key].to_numpy().astype("int32") if ir_clone_key is not None else None + # get background noise + if bg_noise is None: + bg_noise = np.quantile(dex.loc[:, neg_ctrl_key], q=bg_noise_quantile) if neg_ctrl_key is not None else 10 - # Subtract background noise - E = X - x_neg - E = E.at[E < 0].set(0) + # substract background + E = np.maximum(0, X - bg_noise) # calc pMHC ratio per cell - C = E / (E.sum(axis=1, keepdims=True)+1) - - # raw assignment with UMI > 0 - rA = (E > 0).astype("int32") - - # calc clonotype purity - R = calculate_pmhc_clonal_purity(rA, c) - - S = jnp.log(E + 0.01) * C ** 2 * R - S = jnp.nan_to_num(S) - S = S.at[S < 0].set(0) - - # pMHC-wise log-ratio normalization per cell - colSum = S.sum(axis=1, keepdims=True) - colSum = colSum.at[colSum <= 0].set(1) - S = S / colSum - - # cell-wise z-score normalization - S = (S - jnp.nanmean(S)) / jnp.nanstd(S) - S = jnp.nan_to_num(S, nan=jnp.nanmin(S)) + if faithful: + # +1 in the denominator can have large effects + C = E / (E.sum(axis=1, keepdims=True) + 1) + else: + cellnorm = E.sum(axis=1, keepdims=True) + cellnorm[cellnorm == 0] = 1 # 0/1 instead of 0/0 for cells with no dextramer signal + C = E / cellnorm + + # clone purity + clonal_counts = pd.DataFrame(E > 0).groupby(c).sum() + total = clonal_counts.sum(axis=1) + R = clonal_counts.div(total, axis=0).fillna(0) + + if faithful: + non_zero = (clonal_counts != 0).astype(int) + pure = non_zero.sum(axis=1) == 1 + R[pure] = non_zero[pure].div(total[pure], axis=0).fillna(0) + + R = R.loc[c].values + + # Dextramer signal correction (rows that summed 0 remain as 0) + S = np.log(E+0.01) * R * C**2 + S[S<1] = 0 + + # Per cell normalization: pMHC-wise log-ratio normalization + cellnorm = S.sum(axis=1, keepdims=True) + cellnorm[cellnorm == 0] = 1 + S = S / cellnorm + + # Dextramer normalization: cell-wise z-score normalization + S = (S - S.mean(axis=0, keepdims=True)) / S.std(axis=0, ddof=1, keepdims=True) + S[np.isnan(S)] = np.nanmin(S) # set NA's to smalles observed value assignment = (S > threshold).astype("uint8") - if inplace: - mdata.mod[gex_key].obsm["pMHC_assignment"] = assignment + if is_mudata: + adata.mod[dex_key].obsm["icon_pMHC_assignment"] = assignment + else: + adata.obsm["icon_pMHC_assignment"] = assignment else: return assignment diff --git a/dextrademixer/model/__init__.py b/dextrademixer/model/__init__.py index 3cf6620..ea5330e 100644 --- a/dextrademixer/model/__init__.py +++ b/dextrademixer/model/__init__.py @@ -1,3 +1,4 @@ from dextrademixer.model.ApMHCDeconvolution import ApMHCDeconvolution from dextrademixer.model.Dextrademixer import DextraDemixer from dextrademixer.model.BEAMT import BEAMT +from dextrademixer.model.ICON import icon_assign_pmhc diff --git a/dextrademixer/test/TestThresholdAssignment.py b/dextrademixer/test/TestThresholdAssignment.py index 052b47e..8ce430b 100644 --- a/dextrademixer/test/TestThresholdAssignment.py +++ b/dextrademixer/test/TestThresholdAssignment.py @@ -1,8 +1,107 @@ import unittest +import numpy as np +import jax.numpy as jnp + +from dextrademixer.model import threshold_assign_pmhc +from dextrademixer.utils import DextramerSimulator + + +class TestThresholdAssignment(unittest.TestCase): + def test_threshold_based_assignment(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 10) + print(mdat.mod["airr"].obs.clone_id) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_inplace(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 10, neg_ctrl_key="neg_control", inplace=True) + print(mdat.mod["gex"].X) + print(mdat.mod["gex"].obsm["pMHC_assignment"]) + + + def test_threshold_based_assignment_total_normalization(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 5, neg_ctrl_key="neg_control", + total_normalization=True, + target_sum=10e6) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_z_score(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 0, neg_ctrl_key="neg_control", + z_score_normalization=True) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_z_score_and_total(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, 0, neg_ctrl_key="neg_control", + total_normalization=True, + z_score_normalization=True) + print(mdat.mod["gex"].X) + print(assignment) + + def test_threshold_based_assignment_relative_threshold(self): + sim = DextramerSimulator() + mdat = sim.simulate_pmhc_data_from_distribution(total_cells=10, + nof_clones=3, + binding_ratio=0.5, + simulate_neg_control=True, + use_clonotype_cov=False, + binding_fold_increase_range=[100], + variance_fold_increase_range=[1.2], + plot_data=False) + assignment = threshold_assign_pmhc(mdat, + threshold=0.5, + z_score_normalization=True, + threshold_type="relative") + print(mdat.mod["gex"].X) + print(assignment) + -class MyTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, False) # add assertion here if __name__ == '__main__': unittest.main() diff --git a/dextrademixer/utils/utils.py b/dextrademixer/utils/utils.py index e3f8c02..7d5bb5e 100644 --- a/dextrademixer/utils/utils.py +++ b/dextrademixer/utils/utils.py @@ -13,7 +13,6 @@ import scirpy as ir - def gower_centering(distance_matrix): """ Applies Gower's 1966 centering method to the distance matrix to obtain a covariance matrix.