From 7740fbce96e10b11aa9293adc04d51d5df2c00c6 Mon Sep 17 00:00:00 2001 From: Brett Hannigan Date: Wed, 16 Jul 2025 15:13:00 -0400 Subject: [PATCH] Fix pseudo-count difference --- enrich2/selection.py | 54 +++++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/enrich2/selection.py b/enrich2/selection.py index 405cdbb..4cbbbfc 100644 --- a/enrich2/selection.py +++ b/enrich2/selection.py @@ -530,23 +530,19 @@ def calc_ratios(self, label): ) shared_counts = shared_counts.values + 0.5 elif self.logr_method == "complete": + counts_df = self.store.select( + "/main/{}/counts".format(label), "columns in ['c_0', c_last]" + ) shared_counts = ( - self.store.select( - "/main/{}/counts".format(label), "columns in ['c_0', c_last]" - ) - .sum(axis="index") - .values - + 0.5 + counts_df.sum(axis="index").values + 0.5 * len(counts_df) ) elif self.logr_method == "full": + counts_df = self.store.select( + "/main/{}/counts_unfiltered".format(label), + "columns in ['c_0', c_last]", + ) shared_counts = ( - self.store.select( - "/main/{}/counts_unfiltered".format(label), - "columns in ['c_0', c_last]", - ) - .sum(axis="index", skipna=True) - .values - + 0.5 + counts_df.sum(axis="index", skipna=True).values + 0.5 * len(counts_df) ) else: raise ValueError( @@ -612,20 +608,16 @@ def calc_log_ratios(self, label): ) ratios = ratios - np.log(wt_counts.values + 0.5) elif self.logr_method == "complete": + counts_df = self.store.select("/main/{}/counts".format(label), "columns=c_n") ratios = ratios - np.log( - self.store.select("/main/{}/counts".format(label), "columns=c_n") - .sum(axis="index") - .values - + 0.5 + counts_df.sum(axis="index").values + 0.5 * len(counts_df) ) elif self.logr_method == "full": + counts_df = self.store.select( + "/main/{}/counts_unfiltered".format(label), "columns=c_n" + ) ratios = ratios - np.log( - self.store.select( - "/main/{}/counts_unfiltered".format(label), "columns=c_n" - ) - .sum(axis="index", skipna=True) - .values - + 0.5 + counts_df.sum(axis="index", skipna=True).values + 0.5 * len(counts_df) ) else: raise ValueError( @@ -682,20 +674,16 @@ def calc_weights(self, label): ) variances = variances + 1.0 / (wt_counts.values + 0.5) elif self.logr_method == "complete": + counts_df = self.store.select("/main/{}/counts".format(label), "columns=c_n") variances = variances + 1.0 / ( - self.store.select("/main/{}/counts".format(label), "columns=c_n") - .sum(axis="index") - .values - + 0.5 + counts_df.sum(axis="index").values + 0.5 * len(counts_df) ) elif self.logr_method == "full": + counts_df = self.store.select( + "/main/{}/counts_unfiltered".format(label), "columns=c_n" + ) variances = variances + 1.0 / ( - self.store.select( - "/main/{}/counts_unfiltered".format(label), "columns=c_n" - ) - .sum(axis="index", skipna=True) - .values - + 0.5 + counts_df.sum(axis="index", skipna=True).values + 0.5 * len(counts_df) ) else: raise ValueError(