Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 86 additions & 150 deletions dextrademixer/model/ICON.py
Original file line number Diff line number Diff line change
@@ -1,169 +1,105 @@
import warnings
from typing import List, Union

import numpy as np
import pandas as pd
import mudata as md
import scanpy as sc
import scipy.stats
from scipy.stats import zscore

import jax
import jax.lax
import jax.numpy as jnp

from dextrademixer.utils import calculate_pmhc_clonal_purity


def threshold_assign_pmhc(mdata: md.MuData,
                          threshold: float,
                          threshold_type: str = "absolute",
                          pmhc_keys: Union[str, List[str]] = None,
                          total_normalization: bool = False,
                          target_sum: float = None,
                          z_score_normalization: bool = False,
                          gex_key: str = "gex",
                          neg_ctrl_key: str = None,
                          ir_key: str = "airr",
                          ir_clone_key: str = None,
                          inplace=False):
    """
    Assign dextramer specificities by thresholding the pMHC count matrix.

    A cell is assigned to a pMHC when that pMHC holds the (possibly tied) maximum of the
    cell's row in the transformed matrix and its value is at least `threshold`.

    Args:
        mdata: A MuData object whose `gex_key` modality holds the dextramer counts in `X`.
        threshold: The cut-off applied to the transformed matrix.
        threshold_type: Either "absolute" (the legacy spelling "absolut" is also accepted) or "relative".
            * "relative" without a negative control: each cell is divided by its row sum.
            * "relative" with a negative control: all values are divided by the maximum of the
              negative control column.
            * "absolute" with a negative control: the maximum of the negative control column is
              subtracted from all values.
            * "absolute" without a negative control: values are thresholded as they are.
        pmhc_keys (Optional): A string or list of strings naming the pMHC columns of the `gex_key`
            modality to use. If None is given, all columns are used. The caller's list is not modified.
        total_normalization: Whether to normalize every cell to the same total count over all selected
            columns (including the negative control) via `sc.pp.normalize_total`.
        target_sum: Passed to `sc.pp.normalize_total`; only used if `total_normalization` is True.
        z_score_normalization: Whether to z-score every column across cells.
        gex_key: The MuData key of the modality holding the dextramer counts.
        neg_ctrl_key (Optional): Name of the negative control column. It is added to the selected
            columns if missing and removed again before the assignment.
        ir_key: The MuData AIRR modality key. Currently not used by this function.
        ir_clone_key (Optional): Clonotype id field. Currently not used by this function.
        inplace: Whether the assignment is stored in `obsm["pMHC_assignment"]` of the `gex_key`
            modality instead of being returned.

    Returns:
        A binary (uint8) cells x pMHCs assignment matrix, or None if `inplace` is True.

    Raises:
        ValueError: If `threshold_type` is not one of the supported values.
    """
    if threshold_type not in ("absolute", "absolut", "relative"):
        raise ValueError(f"Unknown threshold_type '{threshold_type}'; expected 'absolute' or 'relative'.")

    gex = mdata.mod[gex_key]

    # Work on a private list so that neither the caller's list nor gex.var.index is touched.
    if pmhc_keys is None:
        columns = list(gex.var.index)
    elif isinstance(pmhc_keys, str):
        columns = [pmhc_keys]
    else:
        columns = list(pmhc_keys)

    if neg_ctrl_key is not None and neg_ctrl_key not in columns:
        columns.append(neg_ctrl_key)

    gex_filtered = gex[:, columns]

    if total_normalization:
        x = sc.pp.normalize_total(gex_filtered, target_sum=target_sum, inplace=False)["X"]
    else:
        x = gex_filtered.X
    # X may be stored sparse or dense; only sparse matrices need to be densified.
    if hasattr(x, "toarray"):
        x = x.toarray()
    x = np.asarray(x)

    if z_score_normalization:
        x = (x - jnp.nanmean(x, axis=0, keepdims=True)) / jnp.nanstd(x, axis=0, keepdims=True)

    x_neg = None
    if neg_ctrl_key is not None:
        # The column position has to be looked up in the filtered column order, not in gex.var.
        neg_ctrl_key_idx = columns.index(neg_ctrl_key)
        x_neg = x[:, neg_ctrl_key_idx]
        x = jnp.delete(x, neg_ctrl_key_idx, axis=1)

    if threshold_type == "relative":
        if x_neg is None:
            x = x / (x.sum(axis=1, keepdims=True) + 1e-8)
        else:
            x = x / (jnp.nanmax(x_neg) + 1e-8)
    elif x_neg is not None:
        # absolute threshold: remove the background estimated from the negative control
        x = x - jnp.nanmax(x_neg)

    assignment = ((x == jnp.max(x, axis=1, keepdims=True)) & (x >= threshold)).astype("uint8")

    if inplace:
        mdata.mod[gex_key].obsm["pMHC_assignment"] = assignment
    else:
        return assignment


import anndata as ad


def icon_assign_pmhc(adata: Union[md.MuData, ad.AnnData],
                     ir_clone_key: str,
                     neg_ctrl_key: str = None,
                     threshold: float = 0,
                     bg_noise: float = None,
                     bg_noise_quantile: float = 0.975,
                     pmhc_keys: Union[str, List[str]] = None,
                     dex_key: str = "dex",
                     inplace=False,
                     faithful: bool = False,
                     ):
    """
    Implements the ICON assignment procedure.

    Requires clonal information and dextramer counts, and optionally a negative control column
    to estimate background noise.

    Args:
        adata: A MuData object holding the dextramer counts in the `dex_key` modality, or an AnnData
            object holding them as a DataFrame in `obsm[dex_key]`. Clonotype ids are read from `adata.obs`.
        ir_clone_key: The field in `adata.obs` that holds the clonotype ids. The ids are cast to int32
            and must not contain NA values. If the ids live in the immune receptor modality of a MuData
            object, the key should be `ir_key:clone_key`.
        neg_ctrl_key (Optional): Name of the negative control column in the `dex_key` matrix.
        threshold: Cut-off applied to the normalized score; scores strictly above it are assigned.
        bg_noise (Optional): Value to subtract from the dextramer counts to account for background noise.
            If None is given, the `bg_noise_quantile` quantile of the negative control column is used
            if `neg_ctrl_key` is specified, otherwise 10.
        bg_noise_quantile: Quantile of the negative control column used as background noise estimate.
        pmhc_keys (Optional): A string or list of strings naming the pMHC columns of the `dex_key` matrix
            to use. If None is given, the full matrix is used. The negative control column is always excluded.
        dex_key: The dextramer MuData modality key, or the `obsm` key if `adata` is an AnnData object.
        inplace: Whether the assignment is stored in `obsm["icon_pMHC_assignment"]` (of the `dex_key`
            modality for MuData, of `adata` itself for AnnData) instead of being returned.
        faithful: Whether to use the original ICON procedure (True) or a debugged version based on the
            paper description (False).

    Returns:
        A binary (uint8) cells x pMHCs assignment matrix, or None if `inplace` is True.

    Raises:
        ValueError: If the clonotype column contains NA values.
        TypeError: If `adata` is neither a MuData nor an AnnData object, or if the AnnData `obsm`
            entry is not a DataFrame.
    """
    # check if clone key contains NA values
    if adata.obs[ir_clone_key].isna().sum() > 0:
        raise ValueError(f"NA values found in clone key {ir_clone_key} of adata.obs. ICON works only for cells with TCR information. Please filter the object.")
    c = adata.obs[ir_clone_key].to_numpy().astype("int32")

    # get dextramer counts as a cells x pMHCs DataFrame
    if isinstance(adata, md.MuData):
        is_mudata = True
        dex_mod = adata.mod[dex_key]
        counts = dex_mod.X
        # X may be stored sparse or dense; only sparse matrices need to be densified
        if hasattr(counts, "toarray"):
            counts = counts.toarray()
        dex = pd.DataFrame(counts, index=dex_mod.obs_names, columns=dex_mod.var_names)
    elif isinstance(adata, ad.AnnData):
        is_mudata = False
        dex = adata.obsm[dex_key]
        if not isinstance(dex, pd.DataFrame):
            raise TypeError(f"adata.obsm['{dex_key}'] must be a pandas DataFrame with pMHC names as columns.")
    else:
        raise TypeError(f"adata must be a MuData or AnnData object, got {type(adata).__name__}.")

    # select the pMHC columns, never including the negative control
    if pmhc_keys is None:
        X = dex.loc[:, dex.columns != neg_ctrl_key].values
    else:
        selected = [pmhc_keys] if isinstance(pmhc_keys, str) else list(pmhc_keys)
        selected = [key for key in selected if key != neg_ctrl_key]
        X = dex.loc[:, selected].values

    # get background noise
    if bg_noise is None:
        bg_noise = np.quantile(dex.loc[:, neg_ctrl_key], q=bg_noise_quantile) if neg_ctrl_key is not None else 10

    # subtract background, clipping at zero
    E = np.maximum(0, X - bg_noise)

    # pMHC ratio per cell
    if faithful:
        # +1 in the denominator can have large effects
        C = E / (E.sum(axis=1, keepdims=True) + 1)
    else:
        cellnorm = E.sum(axis=1, keepdims=True)
        cellnorm[cellnorm == 0] = 1  # 0/1 instead of 0/0 for cells with no dextramer signal
        C = E / cellnorm

    # clone purity: per clonotype, the share of its positive (cell, pMHC) entries that fall on each pMHC
    clonal_counts = pd.DataFrame(E > 0).groupby(c).sum()
    total = clonal_counts.sum(axis=1)
    R = clonal_counts.div(total, axis=0).fillna(0)

    if faithful:
        # clonotypes with signal for exactly one pMHC get 1 / total instead of the share
        non_zero = (clonal_counts != 0).astype(int)
        pure = non_zero.sum(axis=1) == 1
        R[pure] = non_zero[pure].div(total[pure], axis=0).fillna(0)

    # broadcast the clonotype-level purity back to the cells
    R = R.loc[c].values

    # Dextramer signal correction (rows that summed 0 remain as 0)
    S = np.log(E + 0.01) * R * C ** 2
    S[S < 1] = 0

    # Per cell normalization: pMHC-wise log-ratio normalization
    cellnorm = S.sum(axis=1, keepdims=True)
    cellnorm[cellnorm == 0] = 1
    S = S / cellnorm

    # Dextramer normalization: z-score of every pMHC column across cells
    S = (S - S.mean(axis=0, keepdims=True)) / S.std(axis=0, ddof=1, keepdims=True)
    S[np.isnan(S)] = np.nanmin(S)  # set NaNs to the smallest observed value

    assignment = (S > threshold).astype("uint8")

    if inplace:
        if is_mudata:
            adata.mod[dex_key].obsm["icon_pMHC_assignment"] = assignment
        else:
            adata.obsm["icon_pMHC_assignment"] = assignment
    else:
        return assignment
1 change: 1 addition & 0 deletions dextrademixer/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dextrademixer.model.ApMHCDeconvolution import ApMHCDeconvolution
from dextrademixer.model.Dextrademixer import DextraDemixer
from dextrademixer.model.BEAMT import BEAMT
from dextrademixer.model.ICON import icon_assign_pmhc
105 changes: 102 additions & 3 deletions dextrademixer/test/TestThresholdAssignment.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,107 @@
import unittest
import numpy as np
import jax.numpy as jnp

from dextrademixer.model import threshold_assign_pmhc
from dextrademixer.utils import DextramerSimulator


class TestThresholdAssignment(unittest.TestCase):
    """Tests for `threshold_assign_pmhc` on small simulated dextramer data sets."""

    @staticmethod
    def _simulate():
        """Simulate the small data set (with negative control) shared by all tests."""
        sim = DextramerSimulator()
        return sim.simulate_pmhc_data_from_distribution(total_cells=10,
                                                        nof_clones=3,
                                                        binding_ratio=0.5,
                                                        simulate_neg_control=True,
                                                        use_clonotype_cov=False,
                                                        binding_fold_increase_range=[100],
                                                        variance_fold_increase_range=[1.2],
                                                        plot_data=False)

    def _check_assignment(self, mdat, assignment):
        """Check that the assignment is a binary matrix with one row per cell."""
        result = np.asarray(assignment)
        self.assertEqual(result.ndim, 2)
        self.assertEqual(result.shape[0], mdat.mod["gex"].shape[0])
        self.assertTrue(np.isin(result, (0, 1)).all())

    def test_threshold_based_assignment(self):
        mdat = self._simulate()
        assignment = threshold_assign_pmhc(mdat, 10)
        self._check_assignment(mdat, assignment)

    def test_threshold_based_assignment_inplace(self):
        mdat = self._simulate()
        returned = threshold_assign_pmhc(mdat, 10, neg_ctrl_key="neg_control", inplace=True)
        # with inplace=True nothing is returned; the result lives in obsm
        self.assertIsNone(returned)
        self.assertIn("pMHC_assignment", mdat.mod["gex"].obsm)
        self._check_assignment(mdat, mdat.mod["gex"].obsm["pMHC_assignment"])

    def test_threshold_based_assignment_total_normalization(self):
        mdat = self._simulate()
        assignment = threshold_assign_pmhc(mdat, 5, neg_ctrl_key="neg_control",
                                           total_normalization=True,
                                           target_sum=10e6)
        self._check_assignment(mdat, assignment)

    def test_threshold_based_assignment_z_score(self):
        mdat = self._simulate()
        assignment = threshold_assign_pmhc(mdat, 0, neg_ctrl_key="neg_control",
                                           z_score_normalization=True)
        self._check_assignment(mdat, assignment)

    def test_threshold_based_assignment_z_score_and_total(self):
        mdat = self._simulate()
        assignment = threshold_assign_pmhc(mdat, 0, neg_ctrl_key="neg_control",
                                           total_normalization=True,
                                           z_score_normalization=True)
        self._check_assignment(mdat, assignment)

    def test_threshold_based_assignment_relative_threshold(self):
        mdat = self._simulate()
        assignment = threshold_assign_pmhc(mdat,
                                           threshold=0.5,
                                           z_score_normalization=True,
                                           threshold_type="relative")
        self._check_assignment(mdat, assignment)


class MyTestCase(unittest.TestCase):
    """Placeholder test case left over from the IDE template."""

    def test_something(self):
        # The template asserted True == False, which fails on every run.
        # Skip explicitly until a real assertion is written.
        self.skipTest("placeholder test without a real assertion")


# Run all test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
1 change: 0 additions & 1 deletion dextrademixer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import scirpy as ir


def gower_centering(distance_matrix):
"""
Applies Gower's 1966 centering method to the distance matrix to obtain a covariance matrix.
Expand Down