diff --git a/libemg/_datasets/emg_epn100.py b/libemg/_datasets/emg_epn100.py index 8ed80431..1cbaddae 100644 --- a/libemg/_datasets/emg_epn100.py +++ b/libemg/_datasets/emg_epn100.py @@ -129,14 +129,26 @@ def process_block(block, rep_offset: int, max_reps: int): classe = GESTURE_MAP[gesture] emg = np.asarray(entry.emg, dtype=np.float32) - point_begins = np.asarray(entry.pointGestureBegins, dtype=np.int64) + + if not hasattr(entry, "groundTruth"): + pb, pe = -1, -1 + + else: + gt = np.asarray(entry.groundTruth, dtype=np.uint8) + idx = np.where(gt != 0)[0] + if idx.size == 0: + pb, pe = -1, -1 + else: + pb = int(idx[0]) + pe = int(idx[-1]) + 1 rep_grp = reps_grp.create_group(f"rep_{rep_id:03d}") rep_grp.create_dataset("emg", data=emg) rep_grp.create_dataset("gesture", data=classe) rep_grp.create_dataset("subject", data=subject_id) rep_grp.create_dataset("rep", data=rep_id) - rep_grp.create_dataset("point_begins", data=point_begins) + rep_grp.create_dataset("pb", data=pb) + rep_grp.create_dataset("pe", data=pe) reps_written += 1 @@ -178,7 +190,7 @@ def process_dataset(root_in: str, root_out: str): class EMGEPN100(Dataset): def __init__(self, dataset_folder: str='DATASET_85'): Dataset.__init__(self, - sampling={'myo': 200, 'gForce': 500}, + sampling={'myo': 200, 'gForce': 1000}, num_channels={'myo': 8, 'gForce': 8}, recording_device=['myo', 'gForce'], num_subjects=85, @@ -186,14 +198,14 @@ def __init__(self, dataset_folder: str='DATASET_85'): num_reps="30 Reps x 12 Gestures x 43 Users (Train group), 15 Reps x 12 Gestures x 42 Users (Test group) --> Cross User Split", description="Multi-hardware EMG dataset for 12 different hand gesture categories using the myo armband and the G-force armband.", citation="https://doi.org/10.3390/s22249613") - self.resolution_bit = {'myo': 8, 'gForce': 12} + self.resolution_bit = {'myo': 8, 'gForce': 8} self.dataset_folder = dataset_folder self.url = "https://laboratorio-ia.epn.edu.ec/es/recursos/dataset/emg-imu-epn-100" def _get_odh(self, processed_root, subjects, segment, relabel_seg, channel_last): - splits = {"training", "testing"} + splits = ["training", "testing"] odhs = [] for split in splits: @@ -204,9 +216,11 @@ def _get_odh(self, processed_root, subjects, odh.subjects = [] odh.classes = [] odh.reps = [] + odh.base_class = [] odh.devices = [] odh.sampling_rates = [] odh.extra_attributes = ['subjects', 'classes', 'reps', + 'base_class', 'devices', 'sampling_rates'] for user_file in user_files: @@ -231,35 +245,60 @@ def _get_odh(self, processed_root, subjects, rep_id = int(rep_grp["rep"][()]) _emg = rep_grp["emg"][:].astype(np.float32, copy=False) # [T, CH] - if not channel_last: - _emg = np.transpose(_emg, (1, 0)) # [CH, T] - + if segment and gst != 0: - point_begins = rep_grp["point_begins"][()] - emg = _emg[point_begins:] + pb = int(rep_grp["pb"][()]) + pe = int(rep_grp["pe"][()]) + if pb < 0 or pe < 0: + pb, pe = None, None + emg = _emg[pb:pe] else: emg = _emg + + if emg.shape[0] == 0: + continue + + if not channel_last: + emg = np.transpose(emg, (1, 0)) # [CH, T] # ---- Preparing ODH ---- odh.data.append(emg) - odh.classes.append(np.ones((len(emg), 1)) * gst) - odh.subjects.append(np.ones((len(emg), 1)) * subject) - odh.reps.append(np.ones((len(emg), 1)) * rep_id) - odh.devices.append(np.ones((len(emg), 1)) * device) - odh.sampling_rates.append(np.ones((len(emg), 1)) * fs) - - if segment and gst != 0 and relabel_seg is not None: + odh.classes.append(np.ones((len(emg), 1), dtype=np.int64) * gst) + odh.subjects.append(np.ones((len(emg), 1), dtype=np.int64) * subject) + odh.reps.append(np.ones((len(emg), 1), dtype=np.int64) * rep_id) + odh.base_class.append(np.ones((len(emg), 1), dtype=np.int64) * gst) + odh.devices.append(np.ones((len(emg), 1), dtype=np.int64) * device) + odh.sampling_rates.append(np.ones((len(emg), 1), dtype=np.int64) * fs) + + if segment and gst != 0 and relabel_seg is not None \ + and pb is not None and pe is not None: assert type(relabel_seg) is int - gst = relabel_seg - - emg = _emg[:point_begins] + emg = _emg[:pb] + if emg.shape[0] == 0: + continue + if not channel_last: + emg = np.transpose(emg, (1, 0)) + odh.data.append(emg) + odh.classes.append(np.ones((len(emg), 1), dtype=np.int64) * relabel_seg) + odh.subjects.append(np.ones((len(emg), 1), dtype=np.int64) * subject) + odh.reps.append(np.ones((len(emg), 1), dtype=np.int64) * rep_id) + odh.base_class.append(np.ones((len(emg), 1), dtype=np.int64) * gst) + odh.devices.append(np.ones((len(emg), 1), dtype=np.int64) * device) + odh.sampling_rates.append(np.ones((len(emg), 1), dtype=np.int64) * fs) + + emg = _emg[pe:] + if emg.shape[0] == 0: + continue + if not channel_last: + emg = np.transpose(emg, (1, 0)) odh.data.append(emg) - odh.classes.append(np.ones((len(emg), 1)) * gst) - odh.subjects.append(np.ones((len(emg), 1)) * subject) - odh.reps.append(np.ones((len(emg), 1)) * rep_id) - odh.devices.append(np.ones((len(emg), 1)) * device) - odh.sampling_rates.append(np.ones((len(emg), 1)) * fs) + odh.classes.append(np.ones((len(emg), 1), dtype=np.int64) * relabel_seg) + odh.subjects.append(np.ones((len(emg), 1), dtype=np.int64) * subject) + odh.reps.append(np.ones((len(emg), 1), dtype=np.int64) * rep_id) + odh.base_class.append(np.ones((len(emg), 1), dtype=np.int64) * gst) + odh.devices.append(np.ones((len(emg), 1), dtype=np.int64) * device) + odh.sampling_rates.append(np.ones((len(emg), 1), dtype=np.int64) * fs) odhs.append(odh) @@ -298,15 +337,15 @@ def prepare_data(self, """ print('\nPlease cite: ' + self.citation+'\n') if (not self.check_exists(self.dataset_folder)) and \ - (not self.check_exists( self.dataset_folder + "PROCESSED")): + (not self.check_exists(self.dataset_folder + "_PROCESSED")): raise FileNotFoundError("Please download the EPN100+ dataset from: {} " "and place 'testing' and 'training' folders inside: " "'{}' folder.".format(self.url, self.dataset_folder)) - if (not self.check_exists( self.dataset_folder + "PROCESSED")): - process_dataset(self.dataset_folder, self.dataset_folder + "PROCESSED") + if (not self.check_exists(self.dataset_folder + "_PROCESSED")): + process_dataset(self.dataset_folder, self.dataset_folder + "_PROCESSED") - odh_tr, odh_te = self._get_odh(self.dataset_folder + "PROCESSED", + odh_tr, odh_te = self._get_odh(self.dataset_folder + "_PROCESSED", subjects, segment, relabel_seg, channel_last) return {'All': odh_tr + odh_te, 'Train': odh_tr, 'Test': odh_te} \ @@ -327,9 +366,4 @@ def get_device_ID(self, device_name: str): Device's ID """ - return DEVICE_MAP[device_name] - - - - - \ No newline at end of file + return DEVICE_MAP[device_name] \ No newline at end of file