From 5256d8cbef3efcfd6ce425109b4cbb3ebe66ed99 Mon Sep 17 00:00:00 2001
From: Joe Futrelle
Date: Tue, 15 Jul 2025 08:57:55 -0400
Subject: [PATCH] support class scores CSVs, called "v4"

Version 4 class scores are stored in a file named '<bin_lid>.csv'. The
CSV is read with its "pid" column as the index; the trailing "_<digits>"
of each pid is taken as the ROI number, the remaining column names are
used as class labels and the cell values as scores.

- add _get_v4_file and dispatch to it from __getitem__
- keys(): use the pattern '.*\.csv' for version 4 and raise KeyError for
  an unknown version
- add _class_scores_v4 and dispatch to it from class_scores()
---
Reviewer notes (not part of the commit message):

- _class_scores_v4: a pid that does not end in "_<digits>" is extracted
  as NaN, so the following astype(int) fails. Confirm that v4 CSVs never
  contain such rows, or that failing there is acceptable.
- keys(): the version 4 pattern accepts any name ending in ".csv";
  presumably the product directory holds only class score CSVs. Verify.

 ifcb/data/products/class_scores.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/ifcb/data/products/class_scores.py b/ifcb/data/products/class_scores.py
index 55229ca..ccf77b6 100644
--- a/ifcb/data/products/class_scores.py
+++ b/ifcb/data/products/class_scores.py
@@ -49,6 +49,12 @@ def _get_v3_file(self, bin_lid):
         if path is not None:
             return ClassScoresFile(path, bin_lid, version=3)
         raise KeyError(bin_lid)
+    def _get_v4_file(self, bin_lid):
+        filename = '{}.csv'.format(bin_lid)
+        path = find_product_file(self.path, filename, exhaustive=self.exhaustive)
+        if path is not None:
+            return ClassScoresFile(path, bin_lid, version=4)
+        raise KeyError(bin_lid)
     def __getitem__(self, bin_lid):
         if self.version == 1:
             return self._get_v1_file(bin_lid)
@@ -56,6 +62,8 @@ def __getitem__(self, bin_lid):
             return self._get_v2_file(bin_lid)
         elif self.version == 3:
             return self._get_v3_file(bin_lid)
+        elif self.version == 4:
+            return self._get_v4_file(bin_lid)
         else:
             raise KeyError('unknown class scores version {}'.format(self.version))
     def has_key(self, bin_lid):
@@ -71,6 +79,10 @@ def keys(self):
             fn_regex = r'.*_class_v2\.h5'
         elif self.version == 3:
             fn_regex = r'.*_class.h5'
+        elif self.version == 4:
+            fn_regex = r'.*\.csv'
+        else:
+            raise KeyError('unknown class scores version {}'.format(self.version))
         for p in list_product_files(self.path, fn_regex):
             # parse the filename as a pid
             bin_lid = Pid(os.path.basename(p)).bin_lid
@@ -108,6 +120,17 @@ def _class_scores_v3(self):
             class_labels = [l.decode('ascii') for l in f['class_labels'][:]]
             roi_numbers = f['roi_numbers'][:]
             return self._cs2df(scores, class_labels, roi_numbers)
+    def _class_scores_v4(self):
+        # version 4 is a CSV file; its "pid" column (used as the index) holds strings that contain the ROI number
+        df = pd.read_csv(self.path, index_col='pid')
+        # keep only the digits that follow the last underscore at the end of each pid: the ROI number
+        df.index = df.index.str.extract(r'_(\d+)$')[0]
+        df.index.name = 'roi_number'
+        # convert the extracted digit strings to int
+        roi_numbers = df.index.astype(int)
+        class_labels = df.columns.tolist()
+        scores = df.values
+        return self._cs2df(scores, class_labels, roi_numbers)
     def class_scores(self):
         if self.version == 1:
             return self._class_scores_v1()
@@ -115,5 +138,7 @@ def class_scores(self):
             return self._class_scores_v2()
         elif self.version == 3:
             return self._class_scores_v3()
+        elif self.version == 4:
+            return self._class_scores_v4()
         else:
             raise KeyError('unknown class scores version {}'.format(self.version))