Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions ifcb/data/products/class_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,21 @@ def _get_v3_file(self, bin_lid):
if path is not None:
return ClassScoresFile(path, bin_lid, version=3)
raise KeyError(bin_lid)
def _get_v4_file(self, bin_lid):
filename = '{}.csv'.format(bin_lid)
path = find_product_file(self.path, filename, exhaustive=self.exhaustive)
if path is not None:
return ClassScoresFile(path, bin_lid, version=4)
raise KeyError(bin_lid)
def __getitem__(self, bin_lid):
if self.version == 1:
return self._get_v1_file(bin_lid)
elif self.version == 2:
return self._get_v2_file(bin_lid)
elif self.version == 3:
return self._get_v3_file(bin_lid)
elif self.version == 4:
return self._get_v4_file(bin_lid)
else:
raise KeyError('unknown class scores version {}'.format(self.version))
Comment thread
joefutrelle marked this conversation as resolved.
def has_key(self, bin_lid):
Expand All @@ -71,6 +79,10 @@ def keys(self):
fn_regex = r'.*_class_v2\.h5'
elif self.version == 3:
fn_regex = r'.*_class.h5'
elif self.version == 4:
fn_regex = r'.*\.csv'
Comment thread
joefutrelle marked this conversation as resolved.
else:
raise KeyError('unknown class scores version {}'.format(self.version))
for p in list_product_files(self.path, fn_regex):
# parse the filename as a pid
bin_lid = Pid(os.path.basename(p)).bin_lid
Expand Down Expand Up @@ -108,12 +120,25 @@ def _class_scores_v3(self):
class_labels = [l.decode('ascii') for l in f['class_labels'][:]]
roi_numbers = f['roi_numbers'][:]
return self._cs2df(scores, class_labels, roi_numbers)
def _class_scores_v4(self):
# version 4 is a CSV file, the first column is "pid" which is a string that contains the ROI number
df = pd.read_csv(self.path, index_col='pid')
# the index is the pid, which contains the roi number
df.index = df.index.str.extract(r'_(\d+)$')[0]
df.index.name = 'roi_number'
# convert index to int
roi_numbers = df.index.astype(int)
class_labels = df.columns.tolist()
scores = df.values
return self._cs2df(scores, class_labels, roi_numbers)
def class_scores(self):
if self.version == 1:
return self._class_scores_v1()
elif self.version == 2:
return self._class_scores_v2()
elif self.version == 3:
return self._class_scores_v3()
elif self.version == 4:
return self._class_scores_v4()
else:
raise KeyError('unknown class scores version {}'.format(self.version))