Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
5217200
cleanup unused files
ArcaneEmergence Oct 14, 2025
642b983
feature: Use AutoNormal guide instead of AutoMultivariateNormal guide…
ArcaneEmergence Oct 14, 2025
a3d72f9
refactor: renaming variables of posterior bfdr thresholding
ArcaneEmergence Oct 14, 2025
a3878f0
figures: small improvements on figures: 1. shorter titles, 2. make sa…
ArcaneEmergence Oct 14, 2025
2dee937
refactor: cleanup unused files
ArcaneEmergence Oct 14, 2025
5f25647
update cluster params
ArcaneEmergence Oct 14, 2025
b227af7
feature: return mean_over_cell posterior parameters
ArcaneEmergence Oct 14, 2025
d1bb1ad
add estimate_sim_params.ipynb
ArcaneEmergence Oct 14, 2025
b009fe5
feature: new version of synthetic benchmark snakemake pipeline. 1. Op…
ArcaneEmergence Oct 14, 2025
1c53378
refactor: Remove duplicate line
ArcaneEmergence Oct 15, 2025
ea58ad4
feature: if y_true is None, create dummy zero values
ArcaneEmergence Oct 15, 2025
e69d361
feature: save and load models
ArcaneEmergence Oct 15, 2025
bcc96dd
feature: save and load models
ArcaneEmergence Oct 15, 2025
31f0a4c
Merge remote-tracking branch 'origin/experiment/synth_bench' into exp…
ArcaneEmergence Oct 15, 2025
e610cac
feature: parallelize BEAMT
ArcaneEmergence Oct 22, 2025
8f579ec
refactor: rename mixer to model
ArcaneEmergence Oct 22, 2025
3d20156
feature: sample variance based on mean, instead of sampling overdispe…
ArcaneEmergence Oct 24, 2025
aa7ae48
fix jax, jaxlib and numpyro versions
ArcaneEmergence Oct 24, 2025
ff646cd
scenario config: Update to version with variance sampled in relations…
ArcaneEmergence Oct 24, 2025
36a39a8
feature: add flags to sbatch script and snakemake file
ArcaneEmergence Oct 24, 2025
7b0472d
feature: resample so that sampled N_binder / N_total ~ p_binding_ratio
ArcaneEmergence Dec 1, 2025
43a951b
bugfix: Use noise mean for binder outliers instead of binder mean. Ad…
ArcaneEmergence Dec 2, 2025
211ce03
feature: ensure real outlier ratio roughly matches the specified ratio
ArcaneEmergence Dec 2, 2025
30a4181
feature: resample binding assignment to get close to target binding r…
ArcaneEmergence Feb 20, 2026
eb04360
feature: Add MCC
ArcaneEmergence Feb 20, 2026
2a66d7c
feature: Multiple enhancements
ArcaneEmergence Feb 20, 2026
0f6c8a7
update .gitignore
ArcaneEmergence Feb 20, 2026
1e1878b
Merge branch 'experiment/synth_bench' of https://github.com/SchubertL…
ArcaneEmergence Feb 20, 2026
915d50b
feature: remove kmeans outlier threshold, instead remove outlier from…
ArcaneEmergence Mar 16, 2026
b1bb15a
feature: float_or_none for argparse
ArcaneEmergence Mar 16, 2026
a2950de
experiment: Gemuend 2025 CMV data
ArcaneEmergence Mar 16, 2026
61854be
feature: save BEAMT results to csv
ArcaneEmergence Mar 18, 2026
9fa73ef
Update Figure design
ArcaneEmergence Mar 24, 2026
155e8d5
feature: clonotype median aggregation model
ArcaneEmergence Apr 8, 2026
3eae588
feature: memory-efficient `_predict_posterior_class_dist` through lax…
ArcaneEmergence Jun 18, 2026
3596302
feature: add to utils.py `mean_ci_t_interval`, `aggregate_csv`, `get_…
ArcaneEmergence Jun 18, 2026
a7b6fca
cleanup: remove hyperparameter tuning
ArcaneEmergence Jun 18, 2026
f42668b
cleanup: synthetic_benchmark unused files
ArcaneEmergence Jun 18, 2026
30b1f27
feature: new way of running synthetic benchmark by using apptainer
ArcaneEmergence Jun 18, 2026
269ab9f
feature: update slurm config
ArcaneEmergence Jun 18, 2026
331c37e
Figure: Add notebooks for figure plotting and add environment.yaml fi…
ArcaneEmergence Jun 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@ experiments/*/optuna_study
experiments/*/optuna
experiments/*/saved_models
experiments/*/simulation
experiments/*/logs
experiments/Ioanna_data/
experiments/Ioanna's experiments with Dextrademixer
*.db
*pkl

experiments/*/benchmarks
# Editors
.vscode/
.idea/
*.png
*.h5mu
*.csv
*.xlsx

# Vagrant
.vagrant/
Expand Down Expand Up @@ -138,3 +145,4 @@ venv.bak/
.mypy_cache/
.dmypy.json
dmypy.json
experiments/Ioanna_data/mudata_fromseurat-Copy1.h5mu
782 changes: 422 additions & 360 deletions dextrademixer/model/Dextrademixer.py

Large diffs are not rendered by default.

124 changes: 96 additions & 28 deletions dextrademixer/utils/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,42 @@ def generate_nb_val(mu, alpha, size):
return stats.poisson.rvs(g)


def sample_var_from_mean(mean: Union[float, np.ndarray],
                         a: float = 2.0221541172111164, b: float = 1.6969075027280063,
                         resid_std: float = 0.31049623532404225,
                         rng: Union[int, np.random.RandomState, None] = 42
                         ) -> Union[float, np.ndarray]:
    """
    Sample a realistic variance given a mean using the fitted power-law model:
        log(var) = log(a) + b*log(mean) + Normal(0, resid_std^2)

    Args:
        mean : float or array-like
            Mean(s) at which to sample the variance. Must be > 0. Any array-like
            (np.ndarray, list, tuple) receives one independent residual per element.
        a : float, default 2.0221541172111164
            Proportionality constant (exp(intercept) from log–log OLS).
        b : float, default 1.6969075027280063
            Scaling exponent (slope from log–log OLS).
        resid_std : float, default 0.31049623532404225
            Residual standard deviation on the *log-variance* scale (σ from OLS residuals).
        rng : int | np.random.RandomState | None, default 42
            Source of randomness. If an integer (Python or NumPy), it is used as the seed
            of a fresh RandomState. Any other value is forwarded unchanged to SciPy as
            `random_state` (None lets SciPy use its default RNG).
    Returns:
        float or np.ndarray
            A sample of variance values with the same shape as `mean`.
    Raises:
        ValueError: if any entry of `mean` is not strictly positive.
    """
    # NumPy integer scalars are not instances of `int`, so check both.
    if isinstance(rng, (int, np.integer)):
        rng = np.random.RandomState(seed=int(rng))

    mean_arr = np.asarray(mean, dtype=float)
    if np.any(mean_arr <= 0):
        raise ValueError("`mean` must be strictly positive")

    # Draw one residual per element for every array-like input; a scalar mean
    # keeps size=None so the return value stays a scalar.
    size = mean_arr.shape if mean_arr.ndim > 0 else None
    residual = stats.norm(0, resid_std).rvs(size=size, random_state=rng)

    log_var = np.log(a) + b * np.log(mean_arr) + residual
    return np.exp(log_var)


def t_cell_simulation(n_clones=3,
mean_binder_range=None,
shape_binder_range=None,
Expand Down Expand Up @@ -410,7 +446,8 @@ def simulate_pmhc_data_from_distribution(self,
use_clonotype_cov: bool = False,
simulate_neg_control: bool = False,
plot_data: bool = False,
rng_key: int = 42
rng_key: int = 42,
rep: int = 0,
) -> Union[Tuple[MuData, Any], MuData]:
"""
Given distribution parameters generate binding data for one pMHC. If certain parameters are not specified,
Expand Down Expand Up @@ -461,24 +498,47 @@ def simulate_pmhc_data_from_distribution(self,
if mean_neg_ctrl is None:
mean_neg_ctrl = np.exp(stats.truncnorm(-1.0539178917389445, 1.8375518345106903, loc=1.018115879390079, scale=0.4175162931163312).rvs(random_state=rng))
if concentration_neg_ctrl is None:
overdisp_neg_ctrl = stats.gamma(a=4.186062616134899, scale=1.2384303396204106).rvs(random_state=rng) + 1
var_neg_ctrl = mean_neg_ctrl * overdisp_neg_ctrl
var_neg_ctrl = sample_var_from_mean(mean_neg_ctrl, rng=rng)
concentration_neg_ctrl = convert_to_invdispersion(mean_neg_ctrl, var_neg_ctrl)
if mean_non_binder is None:
mean_non_binder = np.exp(stats.truncnorm(-1.4325807532116341, 1.9485510504360735, loc=2.0461540382126118, scale=0.6019089551720753).rvs(random_state=rng))
if concentration_non_binder is None:
overdisp_non_binder = stats.gamma(a=0.802396044662406, scale=6.554415080004114).rvs(random_state=rng) + 1
var_non_binder = mean_non_binder * overdisp_non_binder
var_non_binder = sample_var_from_mean(mean_non_binder, rng=rng)
concentration_non_binder = convert_to_invdispersion(mean_non_binder, var_non_binder)
if mean_inc is None:
mean_inc = stats.uniform(50, 450).rvs(random_state=rng) # between [50, 450+50]
mean_pos = mean_inc * mean_non_binder
if var_inc is None:
var_inc = stats.uniform(100, 400).rvs(random_state=rng) # between [100, 400+100]
assert var_inc > 1, "`var_inc` must be larger than 1"
concentration_pos = convert_to_invdispersion(mean_pos, mean_pos * var_inc)
var_pos = sample_var_from_mean(mean_pos, rng=rng)
else:
var_pos = var_inc * mean_non_binder
concentration_pos = convert_to_invdispersion(mean_pos, var_pos)
Comment on lines +512 to +515

# Sample binder assignments and cells per clone until empirical binding ratio is close to target
max_trials = 20
best_err = 10000
for _ in range(max_trials):
total_le = total_cells - nof_clones
raw_cells_per_clone = stats.boltzmann.rvs(*cells_per_clonotype, size=nof_clones, random_state=rng)
cells_per_clone_p = raw_cells_per_clone / raw_cells_per_clone.sum()
cells_per_clone_trial = (rng.multinomial(total_le, cells_per_clone_p) + np.ones(nof_clones)).astype("int32")

# Sample multiple binder assignments and pick the one that gives empirical binding ratio closest to target
binder_assignment_trial = rng.binomial(1, binding_ratio, size=(10000, nof_clones))
empirical_binding_ratio = ((cells_per_clone_trial * binder_assignment_trial).sum(1) / total_cells)
# mean of error from empirical cell and clone level binder ratio
err = ((np.abs(empirical_binding_ratio - binding_ratio) +
np.abs(binder_assignment_trial.mean(1) - binding_ratio))
/ 2)
Comment on lines +526 to +532

if err.min() < best_err:
best_idx = err.argmin()
binder_assignment = binder_assignment_trial[best_idx]
cells_per_clone = cells_per_clone_trial

if err.min() < binding_ratio * 0.05:
break

binder_assignment = rng.binomial(1, binding_ratio, size=nof_clones)
K = None
cc_assignment = None

Expand All @@ -494,10 +554,6 @@ def simulate_pmhc_data_from_distribution(self,

# generate cell per clonotype following a discrete exponentially decreasing distribution normalized to
# specified total cell count
total_le = total_cells - nof_clones
raw_cells_per_clone = np.array([stats.boltzmann.rvs(*cells_per_clonotype,random_state=rng) for _ in range(nof_clones)])
cells_per_clone_p = raw_cells_per_clone/raw_cells_per_clone.sum()
cells_per_clone = (rng.multinomial(total_le, cells_per_clone_p) + np.ones(nof_clones)).astype("int32")

d = {"x": [], "binder": [], "clone": [], "fold_increase": [], "outlier":[]}
if simulate_neg_control:
Expand All @@ -523,21 +579,6 @@ def simulate_pmhc_data_from_distribution(self,

x = DextramerSimulator.generate_nb_val(mean, concentration, size=n_cells, rng_key=key)

if p_binding_outlier > 0 and is_binder:
outlier = stats.binom.rvs(1, p_binding_outlier, size=n_cells, random_state=rng)
outlier_idx = np.where(outlier)

a = (0.001 - concentration_non_binder) / (concentration_non_binder / 3)
concentration = stats.truncnorm.rvs(a, np.inf, loc=concentration_non_binder,
scale=concentration_non_binder / 3, random_state=rng)

x = x.at[outlier_idx].set(
DextramerSimulator.generate_nb_val(mean, concentration, size=np.sum(outlier), rng_key=key)
)
d["outlier"].extend(outlier.tolist())
else:
d["outlier"].extend([0]*n_cells)

if simulate_neg_control:
key, subkey = jax.random.split(key)
x_neg = DextramerSimulator.generate_nb_val(mean_neg_ctrl, concentration_neg_ctrl, size=n_cells, rng_key=key)
Expand All @@ -548,6 +589,30 @@ def simulate_pmhc_data_from_distribution(self,
d["clone"].extend([i] * n_cells)
d["fold_increase"].extend([mean_inc] * n_cells)

if p_binding_outlier > 0:
outlier = np.zeros(total_cells, dtype=int)
binder_mask = np.array(d["binder"], dtype=bool)
n_binder = binder_mask.sum()

binder_outlier_trial = rng.binomial(1, p_binding_outlier, size=(10000, n_binder))

err = np.abs(p_binding_outlier - binder_outlier_trial.mean(1))
best_idx = err.argmin()
binder_outlier = binder_outlier_trial[best_idx]
outlier[binder_mask] = binder_outlier

a = (0.001 - concentration_non_binder) / (concentration_non_binder / 3)
concentration = stats.truncnorm.rvs(a, np.inf, loc=concentration_non_binder,
scale=concentration_non_binder / 3, random_state=rng)
x = np.array(d["x"])
x[outlier.astype(bool)] = (
DextramerSimulator.generate_nb_val(mean_non_binder, concentration, size=np.sum(outlier), rng_key=key)
)
d["x"] = x.tolist()
d["outlier"] = outlier.tolist()
else:
d["outlier"] = [0]*total_cells

mdat = DextramerSimulator.__generate_mdata(d, simulate_neg_control, K, cc_assignment)
# Best theoretical F1
precision, recall, thresholds = precision_recall_curve(mdat['airr'].obs['is_binder'], mdat['gex'].X[:, 0])
Expand All @@ -564,6 +629,7 @@ def simulate_pmhc_data_from_distribution(self,
'mean_inc': mean_inc,
'var_inc': var_inc,
'mean_pos': mean_pos,
'var_pos': var_pos,
'concentration_pos': concentration_pos,
'total_cells': total_cells,
'nof_clones': nof_clones,
Expand All @@ -573,6 +639,7 @@ def simulate_pmhc_data_from_distribution(self,
'rng_key': rng_key,
'best_f1': best_f1,
'best_threshold': best_threshold,
'rep': rep,
}
mdat['gex'].uns['sim_params'] = sim_params

Expand Down Expand Up @@ -766,6 +833,7 @@ def __generate_mdata(d, simulate_neg_control, cov, cc_assignment) -> MuData:
adata_tcr = ad.AnnData()
adata_tcr.obs["is_binder"] = d["binder"]
adata_tcr.obs["clone_id"] = d["clone"]
adata_tcr.obs["outlier"] = d["outlier"]

if cov is not None:
adata_tcr.obs["cc_aa_sim"] = cc_assignment[d["clone"]]
Expand Down
159 changes: 156 additions & 3 deletions dextrademixer/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import os
from collections import defaultdict
from typing import Any

Expand All @@ -7,11 +8,15 @@
import numpy as np
import optax
import pandas as pd
import scirpy as ir

from jax import pure_callback
from numpy import ndarray, dtype, bool_
from scipy.stats import ortho_group, random_correlation

import scirpy as ir
from scipy.stats import ortho_group, random_correlation, t
from sklearn.metrics import (roc_auc_score, average_precision_score, f1_score, precision_score, recall_score,
accuracy_score, matthews_corrcoef)
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm


def gower_centering(distance_matrix):
Expand Down Expand Up @@ -314,3 +319,151 @@ def str_to_bool(s):
setattr(args, key, str_to_bool(value))

return args


def float_or_none(value):
    """
    Convert a command-line style value to a float, or to None.

    Intended for use as an argparse `type=` callable.

    Args:
        value: None, the string 'none' (case-insensitive, surrounding whitespace ignored),
            or anything accepted by `float()` (numeric strings as well as numbers).
    Returns:
        None for None / 'none', otherwise the value converted to float.
    Raises:
        ValueError: if the value is neither 'None' nor convertible to float.
    """
    if value is None:
        return None
    # Only strings can spell 'none'; numeric inputs go straight to float().
    if isinstance(value, str) and value.strip().lower() == 'none':
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        raise ValueError(f"'{value}' is not a valid float or 'None'") from None


def get_slurm_cpu_count():
    """
    Return the number of CPUs to use, preferring SLURM-provided environment variables.

    The variables SLURM_CPUS_PER_TASK, SLURM_CPUS_ON_NODE, SLURM_NTASKS and
    SLURM_JOB_CPUS_PER_NODE are checked in that order; the first one holding a
    positive integer wins. If none is usable, `multiprocessing.cpu_count()` is used,
    and 1 is returned if even that is unavailable.

    Returns:
        int: the detected CPU count (always >= 1 when taken from the environment).
    """
    # Check for SLURM-provided variables
    for var in ("SLURM_CPUS_PER_TASK", "SLURM_CPUS_ON_NODE", "SLURM_NTASKS", "SLURM_JOB_CPUS_PER_NODE"):
        value = os.environ.get(var)
        if value is None:
            continue
        # SLURM_JOB_CPUS_PER_NODE can look like "4(x2)" (2 nodes with 4 CPUs each) or be a
        # comma-separated list such as "8,4(x2)"; keep only the first entry's count.
        first_entry = value.split(",")[0].split("(")[0].strip()
        try:
            count = int(first_entry)
        except ValueError:
            continue
        if count > 0:
            return count
    # Fallback
    try:
        import multiprocessing
        return multiprocessing.cpu_count()
    except NotImplementedError:
        return 1


def guess_worker_mem_limit_mb(nworkers: int):
    """
    Derive a per-worker memory limit from SLURM environment variables.

    Args:
        nworkers: number of workers sharing the node allocation.
    Returns:
        int | None: 95% of SLURM_MEM_PER_NODE split evenly across `nworkers` if that
        variable is set; otherwise 95% of SLURM_MEM_PER_CPU (not divided by `nworkers`)
        if that one is set; otherwise None, meaning no limit should be applied.
        Presumably in MB, as the function name suggests — confirm against the SLURM setup.
    """
    environment = os.environ

    # If SLURM resources are present
    node_mem = environment.get("SLURM_MEM_PER_NODE")
    if node_mem is not None:
        return int(int(node_mem) * 0.95 // nworkers)

    cpu_mem = environment.get("SLURM_MEM_PER_CPU")
    if cpu_mem is not None:
        return int(int(cpu_mem) * 0.95)

    # no good signal; skip limiting
    return None


def init_worker(worker_mem_limit_mb=None):
    """
    Worker initializer that optionally caps the process address space.

    Args:
        worker_mem_limit_mb: limit in MiB (multiplied by 1024 * 1024 to get bytes).
            If None, nothing is done.
    Returns:
        None. Failures while applying the limit are deliberately ignored.
    """
    if worker_mem_limit_mb is not None:
        try:
            import resource
            cap_bytes = int(worker_mem_limit_mb) * 1024 * 1024
            # Address space cap → allocations above this raise MemoryError
            resource.setrlimit(resource.RLIMIT_AS, (cap_bytes, cap_bytes))
        except Exception:
            # Best effort: if the cap cannot be set, proceed without it;
            # a kernel OOM may still occur.
            pass


def calculate_metrics(y_true: np.ndarray, p_pred: np.ndarray, assignment: np.ndarray, full_metrics: bool = True) -> dict:
    """
    Compute classification metrics from true labels, predicted probabilities and binary calls.

    Args:
        y_true (np.ndarray): True binary labels (0 or 1).
        p_pred (np.ndarray): Predicted probabilities for the positive class.
        assignment (np.ndarray): Binary predictions based on a threshold applied to p_pred.
        full_metrics (bool): If True, additionally computes AUROC, accuracy and MCC. Default is True.
    Returns:
        dict: keys 'aps', 'f1', 'precision', 'recall', optionally 'auroc', 'accuracy', 'mcc',
        followed by 'fdr' and the confusion-matrix counts 'tp', 'fp', 'tn', 'fn'.
    """
    results_dict = {
        'aps': average_precision_score(y_true, p_pred),
        'f1': f1_score(y_true, assignment),
        'precision': precision_score(y_true, assignment),
        'recall': recall_score(y_true, assignment),
    }

    if full_metrics:
        results_dict['auroc'] = roc_auc_score(y_true, p_pred)
        results_dict['accuracy'] = accuracy_score(y_true, assignment)
        results_dict['mcc'] = matthews_corrcoef(y_true, assignment)

    # Boolean views of the calls and the ground truth, computed once.
    called = assignment.astype(bool)
    truth = y_true.astype(bool)

    tp = np.sum(called & truth)
    fp = np.sum(called & ~truth)
    tn = np.sum(~called & ~truth)
    fn = np.sum(~called & truth)

    # With no positive calls the false discovery rate is defined as 0 here.
    n_called = tp + fp
    fdr = 0.0 if n_called == 0 else fp / n_called

    results_dict.update(fdr=fdr, tp=tp, fp=fp, tn=tn, fn=fn)
    return results_dict


def mean_ci_t_interval(x, confidence=0.95):
    """
    Format the mean of `x` with a two-sided Student-t confidence interval.

    Args:
        x: object providing `dropna()`, `mean()` and `std(ddof=...)`
            (e.g. a pandas Series); missing values are dropped first.
        confidence: two-sided confidence level, default 0.95.
    Returns:
        str: "mean [low, high]" with each number formatted to three decimals.
    """
    values = x.dropna()
    count = len(values)
    center = values.mean()

    # Upper quantile of the two-sided interval, e.g. 0.975 for a 95% CI.
    upper_q = 1 - (1 - confidence) / 2

    standard_error = values.std(ddof=1) / np.sqrt(count)
    half_width = t.ppf(upper_q, df=count - 1) * standard_error

    lower = center - half_width
    upper = center + half_width
    return f"{center:.3f} [{lower:.3f}, {upper:.3f}]"


def aggregate_csv(experiment_path='.', output_path='agg_results.csv', rerun=False, paths=None, fps=None) -> pd.DataFrame:
    """
    Aggregate the CSV files of single experiment outputs into one DataFrame and save it as CSV.

    The individual files are read concurrently with a thread pool. If `output_path`
    already exists and `rerun` is False, that file is loaded and returned instead.

    Args:
        experiment_path (str): The base directory where the CSV files are located.
        output_path (str): The file path the aggregated CSV file is saved to (and loaded from when cached).
        rerun (bool): If True, forces re-aggregation even if the aggregated file already exists. Default is False.
        paths (list of str): Subdirectories within experiment_path to search for CSV files. Defaults to ['csv'].
            Files whose name contains 'intermediate.csv' are skipped.
        fps (iterable of str): Alternative to directories: explicit file paths to aggregate, in the given order.
            If provided, `experiment_path` and `paths` are ignored.
    Returns:
        df (pd.DataFrame): The aggregated DataFrame containing data from all CSV files.
    Raises:
        ValueError: if there is no CSV file to aggregate.
    """
    def read_csv(fp):
        return pd.read_csv(fp, index_col=0)

    if os.path.exists(output_path) and not rerun:
        return pd.read_csv(output_path, index_col=0)

    if fps is None:
        paths = paths if paths is not None else ['csv']
        # Sort directory listings so the row order of the result is reproducible
        # (os.listdir returns entries in arbitrary order).
        fps = [
            os.path.join(experiment_path, path, f)
            for path in paths
            for f in sorted(os.listdir(os.path.join(experiment_path, path)))
            if f.endswith('.csv') and 'intermediate.csv' not in f
        ]
    else:
        # Materialize so generators work and len() is available for the progress bar.
        fps = list(fps)

    if not fps:
        raise ValueError("No CSV files found to aggregate")

    with ThreadPoolExecutor() as ex:
        dfs = list(tqdm(ex.map(read_csv, fps), total=len(fps)))

    df = pd.concat(dfs, ignore_index=True)
    df.to_csv(output_path)

    return df


def get_cpu_model():
    """
    Return the CPU model name read from /proc/cpuinfo.

    Returns:
        str: the text after the colon of the first line containing "model name",
        or "Unknown" if the file cannot be read or contains no such line.
    """
    try:
        with open("/proc/cpuinfo") as f:
            for line in f:
                if "model name" in line:
                    _, sep, model = line.partition(":")
                    if sep:
                        return model.strip()
    except (OSError, UnicodeDecodeError):
        # File missing or unreadable (e.g. non-Linux platform): fall through.
        pass
    return "Unknown"
Loading