diff --git a/preprocessing/BCI-IV-2A/data_process.py b/preprocessing/BCI-IV-2A/data_process.py index 20a51ec..3fdca83 100644 --- a/preprocessing/BCI-IV-2A/data_process.py +++ b/preprocessing/BCI-IV-2A/data_process.py @@ -8,8 +8,8 @@ data_root = sys.argv[1] print(f"Data root: {data_root}") -raw_data_path = os.path.join(data_root,'BCI-4-2A/raw_data') -processed_data_path = os.path.join(data_root,'BCI-4-2A/processed_data') +raw_data_path = os.path.join(data_root,'BCI-IV-2A/raw_data') +processed_data_path = os.path.join(data_root,'BCI-IV-2A/processed_data') os.makedirs(processed_data_path, exist_ok=True) # Sampling rate and time split range diff --git a/preprocessing/SHU/cross_json_process.py b/preprocessing/SHU/cross_json_process.py index 1ba9372..c2d63a9 100644 --- a/preprocessing/SHU/cross_json_process.py +++ b/preprocessing/SHU/cross_json_process.py @@ -1,169 +1,168 @@ -import json -import os -import pickle -import numpy as np -import random -from collections import defaultdict -import sys - -data_root = sys.argv[1] -print(f"Data root: {data_root}") -processed_data_path = os.path.join(data_root,'SHU/processed_data') -data_split_path = './preprocessing/SHU/cross_subject_json' -os.makedirs(data_split_path, exist_ok=True) -save_train_path = os.path.join(data_split_path, 'train.json') -save_val_path = os.path.join(data_split_path, 'val.json') -save_test_path = os.path.join(data_split_path, 'test.json') - -# path1 = './Preprocessing/SHU/processed_data/' -# output_dir = './Preprocessing/SHU/cross_subject_json/' -# os.makedirs(output_dir, exist_ok=True) - -def save_to_json(data, filename): - with open(filename, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=4) - print(f"File has been saved to {filename}") - - -def calculate_dataset_stats(data_list): - max_value = -float('inf') - min_value = float('inf') - channel_means = 0 - channel_stds = 0 - i = 0 - - for file_data in data_list: - file_path = file_data['file'] - if not os.path.exists(file_path): - print(f"File does not exist: {file_path}") - continue - - with open(file_path, 'rb') as f: - data = pickle.load(f) - X = data['X'] - - current_max = np.max(X) - current_min = np.min(X) - if current_max > max_value: - max_value = current_max - if current_min < min_value: - min_value = current_min - - channel_means += np.mean(X, axis=-1) - channel_stds += np.std(X, axis=-1) - i += 1 - - mean_values = channel_means / i - std_values = channel_stds / i - - return max_value, min_value, mean_values, std_values - - -def split_train_val(all_train_data, val_ratio=0.2): - subject_label_dict = defaultdict(lambda: defaultdict(list)) - - for item in all_train_data: - subject_id = item["subject_id"] - label = item["label"] - subject_label_dict[subject_id][label].append(item) - - train_data = [] - val_data = [] - for subject_id, label_dict in subject_label_dict.items(): - for label, items in label_dict.items(): - random.shuffle(items) - split_idx = int(len(items) * (1 - val_ratio)) - - train_data.extend(items[:split_idx]) # The first 80% of the data is the training set. - val_data.extend(items[split_idx:]) # The last 20% of the data is the validation set. - - return train_data, val_data - - -dataset_info_template = { - "sampling_rate": 250, - "ch_names": ["Fp1", "Fp2", "Fz", "F3", "F4", "F7", "F8", "FC1", - "FC2", "FC5", "FC6", "Cz", "C3", "C4", "T3", "T4", - "A1", "A2", "CP1", "CP2", "CP5", "CP6", "Pz", "P3", - "P4", "T5", "T6", "PO3", "PO4", "Oz", "O1", "O2"], - "min": None, - "max": None, - "mean": None, - "std": None -} - - -# Split by subject -all_train_data = [] -all_val_data = [] -all_test_data = [] - -for subject_id in range(1, 26): - subject_path = str(subject_id) + '/' - data_folder = os.path.join(processed_data_path, subject_path) - - if not os.path.exists(data_folder): - print(f"Folder does not exist: {data_folder}") - continue - - subject_files = [] - for file_name in os.listdir(data_folder): - if file_name.endswith(".pkl"): - parts = file_name.split('_') - if len(parts) == 3: - subject = int(parts[0]) - session = int(parts[1]) - trial = int(parts[2].split('.')[0]) - file_path = os.path.join(data_folder, file_name) - - try: - with open(file_path, 'rb') as f: - data = pickle.load(f) - - file_data = { - "subject_id": subject - 1, - "subject_name": f"{subject:03d}", - "file": file_path, - "label": data['Y'].tolist() - } - subject_files.append(file_data) - except Exception as e: - print(f"Error loading file {file_path}: {str(e)}") - - if 1 <= subject_id <= 22: - train, val = split_train_val(subject_files) - all_train_data.extend(train) - all_val_data.extend(val) - elif 23 <= subject_id <= 25: - all_test_data.extend(subject_files) - -print(f"train_set: {len(all_train_data)}, val_set: {len(all_val_data)}, test_set: {len(all_test_data)}") - -# Compute normalization parameters -train_max, train_min, train_mean, train_std = calculate_dataset_stats(all_train_data) - -dataset_info = dataset_info_template.copy() -dataset_info.update({ - "min": train_max, - "max": train_min, - "mean": train_mean.tolist(), - "std": train_std.tolist() -}) - -final_train_data = { - "dataset_info": dataset_info, - "subject_data": all_train_data -} -final_val_data = { - "dataset_info": dataset_info, - "subject_data": all_val_data -} -final_test_data = { - "dataset_info": dataset_info, - "subject_data": all_test_data -} - -save_to_json(final_train_data, save_train_path) -save_to_json(final_val_data, save_val_path) -save_to_json(final_test_data, save_test_path) -print("Cross-subject splitting completed") +import json +import os +import pickle +import numpy as np +import random +from collections import defaultdict +import sys + +data_root = sys.argv[1] +print(f"Data root: {data_root}") +processed_data_path = os.path.join(data_root, 'SHU/processed_data') +data_split_path = './preprocessing/SHU/cross_subject_json' +os.makedirs(data_split_path, exist_ok=True) +save_train_path = os.path.join(data_split_path, 'train.json') +save_val_path = os.path.join(data_split_path, 'val.json') +save_test_path = os.path.join(data_split_path, 'test.json') + + +def save_to_json(data, filename): + with open(filename, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=4) + print(f"File has been saved to {filename}") + + +stats_map = {} + + +def compute_stats_from_map(data_list): + max_value = -np.inf + min_value = np.inf + channel_means = np.zeros(32) + channel_stds = np.zeros(32) + count = 0 + + for file_data in data_list: + stats = stats_map.get(file_data['file']) + if stats is None: + continue + channel_means += stats['mean'] + channel_stds += stats['std'] + max_value = max(max_value, stats['max']) + min_value = min(min_value, stats['min']) + count += 1 + + if count == 0: + return max_value, min_value, channel_means, channel_stds + + mean_values = channel_means / count + std_values = channel_stds / count + + return max_value, min_value, mean_values, std_values + + +def split_train_val(all_train_data, val_ratio=0.2): + subject_label_dict = defaultdict(lambda: defaultdict(list)) + + for item in all_train_data: + subject_id = item["subject_id"] + label = item["label"] + subject_label_dict[subject_id][label].append(item) + + train_data = [] + val_data = [] + for subject_id, label_dict in subject_label_dict.items(): + for label, items in label_dict.items(): + random.shuffle(items) + split_idx = int(len(items) * (1 - val_ratio)) + + train_data.extend(items[:split_idx]) + val_data.extend(items[split_idx:]) + + return train_data, val_data + + +dataset_info_template = { + "sampling_rate": 250, + "ch_names": ["Fp1", "Fp2", "Fz", "F3", "F4", "F7", "F8", "FC1", + "FC2", "FC5", "FC6", "Cz", "C3", "C4", "T3", "T4", + "A1", "A2", "CP1", "CP2", "CP5", "CP6", "Pz", "P3", + "P4", "T5", "T6", "PO3", "PO4", "Oz", "O1", "O2"], + "min": None, + "max": None, + "mean": None, + "std": None +} + +# Split by subject +all_train_data = [] +all_val_data = [] +all_test_data = [] + +for subject_id in range(1, 26): + subject_path = str(subject_id) + '/' + data_folder = os.path.join(processed_data_path, subject_path) + + if not os.path.exists(data_folder): + print(f"Folder does not exist: {data_folder}") + continue + + subject_files = [] + for file_name in os.listdir(data_folder): + if file_name.endswith(".pkl"): + parts = file_name.split('_') + if len(parts) == 3: + subject = int(parts[0]) + session = int(parts[1]) + trial = int(parts[2].split('.')[0]) + file_path = os.path.join(data_folder, file_name) + + try: + with open(file_path, 'rb') as f: + data = pickle.load(f) + + X = data['X'] + stats_map[file_path] = { + 'mean': np.mean(X, axis=-1), + 'std': np.std(X, axis=-1), + 'min': np.min(X), + 'max': np.max(X) + } + + file_data = { + "subject_id": subject - 1, + "subject_name": f"{subject:03d}", + "file": file_path, + "label": data['Y'].tolist() + } + subject_files.append(file_data) + except Exception as e: + print(f"Error loading file {file_path}: {str(e)}") + + if 1 <= subject_id <= 22: + train, val = split_train_val(subject_files) + all_train_data.extend(train) + all_val_data.extend(val) + elif 23 <= subject_id <= 25: + all_test_data.extend(subject_files) + +print(f"train_set: {len(all_train_data)}, val_set: {len(all_val_data)}, test_set: {len(all_test_data)}") + +# Compute normalization parameters +train_max, train_min, train_mean, train_std = compute_stats_from_map(all_train_data) + +dataset_info = dataset_info_template.copy() +dataset_info.update({ + "min": train_min, + "max": train_max, + "mean": train_mean.tolist(), + "std": train_std.tolist() +}) + +final_train_data = { + "dataset_info": dataset_info, + "subject_data": all_train_data +} +final_val_data = { + "dataset_info": dataset_info, + "subject_data": all_val_data +} +final_test_data = { + "dataset_info": dataset_info, + "subject_data": all_test_data +} + +save_to_json(final_train_data, save_train_path) +save_to_json(final_val_data, save_val_path) +save_to_json(final_test_data, save_test_path) +print("Cross-subject splitting completed") diff --git a/preprocessing/SHU/data_process.py b/preprocessing/SHU/data_process.py index a7cdaac..6e0a446 100644 --- a/preprocessing/SHU/data_process.py +++ b/preprocessing/SHU/data_process.py @@ -1,71 +1,70 @@ -import scipy.io as scio -import numpy as np -from scipy import signal -import pickle -import os -import sys - - -data_root = sys.argv[1] -print(f"Data root: {data_root}") -raw_data_path = os.path.join(data_root,'SHU/raw_data') -processed_data_path = os.path.join(data_root,'SHU/processed_data') -os.makedirs(processed_data_path, exist_ok=True) - -# file_path = './Preprocessing/SHU/raw_data/' -# save_path = './Preprocessing/SHU/processed_data/' - -# Define a bandpass filter (0.1Hz - 75Hz) -def bandpass_filter(data, lowcut=0.1, highcut=75.0, fs=250, order=4): - nyquist = 0.5 * fs - low = lowcut / nyquist - high = highcut / nyquist - b, a = signal.butter(order, [low, high], btype='band') - return signal.filtfilt(b, a, data, axis=-1) - - -# Define a notch filter (50Hz) -def notch_filter(data, freq=50.0, fs=250, Q=30.0): - nyquist = 0.5 * fs - w0 = freq / nyquist - b, a = signal.iirnotch(w0, Q) - return signal.filtfilt(b, a, data, axis=-1) - - -# Resampling -def resample_data(data, old_rate=250, new_rate=256): - number_of_samples = int(data.shape[-1] * new_rate / old_rate) - return signal.resample(data, number_of_samples, axis=-1) - - -# Save to .pkl files -def save_event_to_pkl(X, Y, save_dir, people, session): - event_count = 1 - for i in range(X.shape[0]): - event_data = X[i, :, :] - label = Y[i] - file_name = os.path.join(save_dir, str(people) + '_' + str(session) + '_' + f'{event_count}.pkl') - with open(file_name, 'wb') as f: - pickle.dump({'X': event_data, 'Y': label}, f) - event_count += 1 - -for i in range(25): - subject = i + 1 - for j in range(5): - session = j + 1 - eeg = 'sub-0' + f"{subject:02d}" + '_ses-' + f"{session:02d}" + '_task_motorimagery_eeg.mat' - data = scio.loadmat(os.path.join(raw_data_path, eeg)) - data1 = data['data'] - label = np.squeeze(data['labels']) - 1 - print(data1.shape) - print(label.shape) - for trial in range(len(label)): - eeg_trial = data1[trial, :, :] - eeg_trial = bandpass_filter(eeg_trial, lowcut=0.1, highcut=75.0, fs=250) - eeg_trial = notch_filter(eeg_trial, freq=50.0, fs=250) - save_dir = processed_data_path + str(subject) + '/' - if not os.path.exists(save_dir): - os.makedirs(save_dir) - file_name = os.path.join(save_dir, str(subject) + '_' + str(session) + '_' + str(trial+1) + '.pkl') - with open(file_name, 'wb') as f: - pickle.dump({'X': eeg_trial, 'Y': label[trial]}, f) +import scipy.io as scio +import numpy as np +from scipy import signal +import pickle +import os +import sys + + +data_root = sys.argv[1] +print(f"Data root: {data_root}") +raw_data_path = os.path.join(data_root,'SHU/raw_data') +processed_data_path = os.path.join(data_root,'SHU/processed_data') +os.makedirs(processed_data_path, exist_ok=True) + +# file_path = './Preprocessing/SHU/raw_data/' +# save_path = './Preprocessing/SHU/processed_data/' + +# Define a bandpass filter (0.1Hz - 75Hz) +def bandpass_filter(data, lowcut=0.1, highcut=75.0, fs=250, order=4): + nyquist = 0.5 * fs + low = lowcut / nyquist + high = highcut / nyquist + b, a = signal.butter(order, [low, high], btype='band') + return signal.filtfilt(b, a, data, axis=-1) + + +# Define a notch filter (50Hz) +def notch_filter(data, freq=50.0, fs=250, Q=30.0): + nyquist = 0.5 * fs + w0 = freq / nyquist + b, a = signal.iirnotch(w0, Q) + return signal.filtfilt(b, a, data, axis=-1) + + +# Resampling +def resample_data(data, old_rate=250, new_rate=256): + number_of_samples = int(data.shape[-1] * new_rate / old_rate) + return signal.resample(data, number_of_samples, axis=-1) + + +# Save to .pkl files +def save_event_to_pkl(X, Y, save_dir, people, session): + event_count = 1 + for i in range(X.shape[0]): + event_data = X[i, :, :] + label = Y[i] + file_name = os.path.join(save_dir, str(people) + '_' + str(session) + '_' + f'{event_count}.pkl') + with open(file_name, 'wb') as f: + pickle.dump({'X': event_data, 'Y': label}, f) + event_count += 1 + +for i in range(25): + subject = i + 1 + for j in range(5): + session = j + 1 + eeg = 'sub-0' + f"{subject:02d}" + '_ses-' + f"{session:02d}" + '_task_motorimagery_eeg.mat' + data = scio.loadmat(os.path.join(raw_data_path, eeg)) + data1 = data['data'] + label = np.squeeze(data['labels']) - 1 + print(data1.shape) + print(label.shape) + for trial in range(len(label)): + eeg_trial = data1[trial, :, :] + eeg_trial = bandpass_filter(eeg_trial, lowcut=0.1, highcut=75.0, fs=250) + eeg_trial = notch_filter(eeg_trial, freq=50.0, fs=250) + save_dir = os.path.join(processed_data_path, str(subject)) + os.makedirs(save_dir, exist_ok=True) + file_name = os.path.join(save_dir, str(subject) + '_' + str(session) + '_' + str(trial+1) + '.pkl') + with open(file_name, 'wb') as f: + pickle.dump({'X': eeg_trial, 'Y': label[trial]}, f) diff --git a/preprocessing/SHU/multi_json_process.py b/preprocessing/SHU/multi_json_process.py index 18617e6..b700d6b 100644 --- a/preprocessing/SHU/multi_json_process.py +++ b/preprocessing/SHU/multi_json_process.py @@ -1,149 +1,147 @@ -import json -import os -import pickle -import numpy as np -from collections import defaultdict -import sys - - -data_root = sys.argv[1] -print(f"Data root: {data_root}") -processed_data_path = os.path.join(data_root,'SHU/processed_data') -data_split_path = './preprocessing/SHU/multi_subject_json' -os.makedirs(data_split_path, exist_ok=True) -save_train_path = os.path.join(data_split_path, 'train.json') -save_val_path = os.path.join(data_split_path, 'val.json') -save_test_path = os.path.join(data_split_path, 'test.json') - -# path1 = './Preprocessing/SHU/processed_data/' -# output_dir = './Preprocessing/SHU/multi_subject_json/' -# os.makedirs(output_dir, exist_ok=True) - -def save_to_json(data, filename): - with open(filename, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=4) - print(f"File has been saved to {filename}") - - -def calculate_dataset_stats(data_list): - max_value = -float('inf') - min_value = float('inf') - channel_means = 0 - channel_stds = 0 - i = 0 - - for file_data in data_list: - file_path = file_data['file'] - if not os.path.exists(file_path): - print(f"File does not exist: {file_path}") - continue - - with open(file_path, 'rb') as f: - data = pickle.load(f) - X = data['X'] - - current_max = np.max(X) - current_min = np.min(X) - if current_max > max_value: - max_value = current_max - if current_min < min_value: - min_value = current_min - - channel_means += np.mean(X, axis=-1) - channel_stds += np.std(X, axis=-1) - i += 1 - - mean_values = channel_means / i - std_values = channel_stds / i - - return max_value, min_value, mean_values, std_values - - -dataset_info_template = { - "sampling_rate": 250, - "ch_names": ["Fp1", "Fp2", "Fz", "F3", "F4", "F7", "F8", "FC1", - "FC2", "FC5", "FC6", "Cz", "C3", "C4", "T3", "T4", - "A1", "A2", "CP1", "CP2", "CP5", "CP6", "Pz", "P3", - "P4", "T5", "T6", "PO3", "PO4", "Oz", "O1", "O2"], - "min": None, - "max": None, - "mean": None, - "std": None -} - - -# Split by session -train_data = [] -val_data = [] -test_data = [] - -for i in range(1,26): - subject_id = i - subject_path = str(subject_id) + '/' - data_folder = os.path.join(processed_data_path, subject_path) - - if not os.path.exists(data_folder): - print(f"Folder does not exist: {data_folder}") - continue - - for file_name in os.listdir(data_folder): - if file_name.endswith(".pkl"): - parts = file_name.split('_') - if len(parts) == 3: - subject = int(parts[0]) - session = int(parts[1]) - trial = int(parts[2].split('.')[0]) - file_path = os.path.join(data_folder, file_name) - - try: - with open(file_path, 'rb') as f: - data = pickle.load(f) - - label = data['Y'].tolist() - file_data = { - "subject_id": subject - 1, - "subject_name": f"{subject:03d}", - "file": file_path, - "label": label - } - - if session in [1, 2, 3]: - train_data.append(file_data) - elif session == 4: - val_data.append(file_data) - elif session == 5: - test_data.append(file_data) - - except Exception as e: - print(f"Error loading file {file_path}: {str(e)}") - -print(f"train_set: {len(train_data)}, val_set: {len(val_data)}, test_set: {len(test_data)}") - -# Compute normalization parameters -train_max, train_min, train_mean, train_std = calculate_dataset_stats(train_data) - -dataset_info = dataset_info_template.copy() -dataset_info.update({ - "min": train_max, - "max": train_min, - "mean": train_mean.tolist(), - "std": train_std.tolist() -}) - -final_train_data = { - "dataset_info": dataset_info, - "subject_data": train_data -} -final_val_data = { - "dataset_info": dataset_info, - "subject_data": val_data -} -final_test_data = { - "dataset_info": dataset_info, - "subject_data": test_data -} - -save_to_json(final_train_data, save_train_path) -save_to_json(final_val_data, save_val_path) -save_to_json(final_test_data, save_test_path) -print("Multi-subject splitting completed") +import json +import os +import pickle +import numpy as np +import sys + + +data_root = sys.argv[1] +print(f"Data root: {data_root}") +processed_data_path = os.path.join(data_root, 'SHU/processed_data') +data_split_path = './preprocessing/SHU/multi_subject_json' +os.makedirs(data_split_path, exist_ok=True) +save_train_path = os.path.join(data_split_path, 'train.json') +save_val_path = os.path.join(data_split_path, 'val.json') +save_test_path = os.path.join(data_split_path, 'test.json') + + +def save_to_json(data, filename): + with open(filename, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=4) + print(f"File has been saved to {filename}") + + +stats_map = {} + + +def compute_stats_from_map(data_list): + max_value = -np.inf + min_value = np.inf + channel_means = np.zeros(32) + channel_stds = np.zeros(32) + count = 0 + + for file_data in data_list: + stats = stats_map.get(file_data['file']) + if stats is None: + continue + channel_means += stats['mean'] + channel_stds += stats['std'] + max_value = max(max_value, stats['max']) + min_value = min(min_value, stats['min']) + count += 1 + + if count == 0: + return max_value, min_value, channel_means, channel_stds + + mean_values = channel_means / count + std_values = channel_stds / count + + return max_value, min_value, mean_values, std_values + + +dataset_info_template = { + "sampling_rate": 250, + "ch_names": ["Fp1", "Fp2", "Fz", "F3", "F4", "F7", "F8", "FC1", + "FC2", "FC5", "FC6", "Cz", "C3", "C4", "T3", "T4", + "A1", "A2", "CP1", "CP2", "CP5", "CP6", "Pz", "P3", + "P4", "T5", "T6", "PO3", "PO4", "Oz", "O1", "O2"], + "min": None, + "max": None, + "mean": None, + "std": None +} + +# Split by session +train_data = [] +val_data = [] +test_data = [] + +for i in range(1, 26): + subject_id = i + subject_path = str(subject_id) + '/' + data_folder = os.path.join(processed_data_path, subject_path) + + if not os.path.exists(data_folder): + print(f"Folder does not exist: {data_folder}") + continue + + for file_name in os.listdir(data_folder): + if file_name.endswith(".pkl"): + parts = file_name.split('_') + if len(parts) == 3: + subject = int(parts[0]) + session = int(parts[1]) + trial = int(parts[2].split('.')[0]) + file_path = os.path.join(data_folder, file_name) + + try: + with open(file_path, 'rb') as f: + data = pickle.load(f) + + X = data['X'] + stats_map[file_path] = { + 'mean': np.mean(X, axis=-1), + 'std': np.std(X, axis=-1), + 'min': np.min(X), + 'max': np.max(X) + } + + label = data['Y'].tolist() + file_data = { + "subject_id": subject - 1, + "subject_name": f"{subject:03d}", + "file": file_path, + "label": label + } + + if session in [1, 2, 3]: + train_data.append(file_data) + elif session == 4: + val_data.append(file_data) + elif session == 5: + test_data.append(file_data) + + except Exception as e: + print(f"Error loading file {file_path}: {str(e)}") + +print(f"train_set: {len(train_data)}, val_set: {len(val_data)}, test_set: {len(test_data)}") + +# Compute normalization parameters +train_max, train_min, train_mean, train_std = compute_stats_from_map(train_data) + +dataset_info = dataset_info_template.copy() +dataset_info.update({ + "min": train_min, + "max": train_max, + "mean": train_mean.tolist(), + "std": train_std.tolist() +}) + +final_train_data = { + "dataset_info": dataset_info, + "subject_data": train_data +} +final_val_data = { + "dataset_info": dataset_info, + "subject_data": val_data +} +final_test_data = { + "dataset_info": dataset_info, + "subject_data": test_data +} + +save_to_json(final_train_data, save_train_path) +save_to_json(final_val_data, save_val_path) +save_to_json(final_test_data, save_test_path) +print("Multi-subject splitting completed") diff --git a/preprocessing/Siena/cross_json_process.py b/preprocessing/Siena/cross_json_process.py index 81a8f58..99a79d1 100644 --- a/preprocessing/Siena/cross_json_process.py +++ b/preprocessing/Siena/cross_json_process.py @@ -1,150 +1,152 @@ -import json -import os -import pickle -import numpy as np -from natsort import natsorted -from collections import defaultdict -import random -import sys - - -data_root = sys.argv[1] -print(f"Data root: {data_root}") -processed_data_path = os.path.join(data_root,'Siena/processed_data') -data_split_path = './preprocessing/Siena/cross_subject_json' -os.makedirs(data_split_path, exist_ok=True) -save_train_path = os.path.join(data_split_path, 'train.json') -save_val_path = os.path.join(data_split_path, 'val.json') -save_test_path = os.path.join(data_split_path, 'test.json') - -# data_folder = "./Preprocessing/Siena/raw_data" -# os.makedirs('./Preprocessing/Siena/cross_subject_json', exist_ok=True) -# save_folder_train = './Preprocessing/Siena/cross_subject_json/train.json' -# save_folder_val = './Preprocessing/Siena/cross_subject_json/val.json' -# save_folder_test = './Preprocessing/Siena/cross_subject_json/test.json' - -sampling_rate = 512 -ch_names = ['Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'T5', 'Fc1', 'Fc5', 'Cp1', 'Cp5', 'F9', 'Fz', 'Cz', 'Pz', 'Fp2', - 'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Fc2', 'Fc6', 'Cp2', 'Cp6', 'F10'] -num_channels = len(ch_names) -random.seed(42) - - -def load_subject_metadata(subject_folder): - subject_data = [] - folder_name = os.path.basename(subject_folder) - subject_num = int(folder_name[2:]) - subject_name = f"PN{subject_num:02d}" - - for file in natsorted(f for f in os.listdir(subject_folder) if f.endswith('.pkl')): - file_path = os.path.join(subject_folder, file) - try: - with open(file_path, 'rb') as f: - eeg_data = pickle.load(f) - subject_data.append({ - "subject_id": subject_num, - "subject_name": subject_name, - "file": file_path, - "label": eeg_data['Y'] - }) - except Exception as e: - print(f"Error loading {file_path}: {str(e)}") - return subject_data - - -def compute_normalization_params(data_list): - """Calculate normalization parameters.""" - total_mean = np.zeros(num_channels) - total_std = np.zeros(num_channels) - max_val, min_val = -np.inf, np.inf - count = 0 - - for data in data_list: - with open(data["file"], 'rb') as f: - eeg = pickle.load(f)['X'] - - max_val = max(max_val, eeg.max()) - min_val = min(min_val, eeg.min()) - total_mean += eeg.mean(axis=1) - total_std += eeg.std(axis=1) - count += 1 - - return (total_mean / count).tolist(), (total_std / count).tolist(), max_val, min_val - - -def split_subject_data(subject_data, val_ratio=0.2): - """Split the data of a single subject into a validation set with class-balanced partitioning.""" - label_to_data = defaultdict(list) - for data in subject_data: - label_to_data[data["label"]].append(data) - - train_data, val_data = [], [] - for label, data_list in label_to_data.items(): - random.shuffle(data_list) - split_idx = int(len(data_list) * (1 - val_ratio)) - train_data.extend(data_list[:split_idx]) - val_data.extend(data_list[split_idx:]) - - return train_data, val_data - - -def save_dataset(data_list, save_path, norm_params=None): - if norm_params is None: - print("Computing normalization parameters...") - mean, std, max_val, min_val = compute_normalization_params(data_list) - else: - mean, std, max_val, min_val = norm_params - - dataset = { - "subject_data": data_list, # 只包含元数据 - "dataset_info": { - "sampling_rate": sampling_rate, - "ch_names": ch_names, - "min": min_val, - "max": max_val, - "mean": mean, - "std": std - } - } - - os.makedirs(os.path.dirname(save_path), exist_ok=True) - with open(save_path, 'w') as f: - json.dump(dataset, f, indent=2) - print(f"Saved to {save_path}") - - -def main(): - subject_folders = natsorted( - os.path.join(processed_data_path, f) - for f in os.listdir(processed_data_path) - if f.startswith("PN") and os.path.isdir(os.path.join(processed_data_path, f)) - ) - - train_subjects = [s for s in subject_folders if int(os.path.basename(s)[2:]) <= 14] - test_subjects = [s for s in subject_folders if int(os.path.basename(s)[2:]) > 14] # The last two subjects are PN16 and PN17. - - all_train_data, all_val_data = [], [] - for subject in train_subjects: - subject_data = load_subject_metadata(subject) - train_data, val_data = split_subject_data(subject_data) - all_train_data.extend(train_data) - all_val_data.extend(val_data) - - all_test_data = [] - for subject in test_subjects: - all_test_data.extend(load_subject_metadata(subject)) - - print(f"\nData counts:") - print(f"Train: {len(all_train_data)}, Val: {len(all_val_data)}, Test: {len(all_test_data)}") - - print("\nComputing normalization...") - norm_params = compute_normalization_params(all_train_data) - - print("\nSaving datasets...") - save_dataset(all_train_data, save_train_path, norm_params) - save_dataset(all_val_data, save_val_path, norm_params) - save_dataset(all_test_data, save_test_path, norm_params) - - -if __name__ == "__main__": - main() +import json +import os +import pickle +import numpy as np +from natsort import natsorted +from collections import defaultdict +import random +import sys + + +data_root = sys.argv[1] +print(f"Data root: {data_root}") +processed_data_path = os.path.join(data_root, 'Siena/processed_data') +data_split_path = './preprocessing/Siena/cross_subject_json' +os.makedirs(data_split_path, exist_ok=True) +save_train_path = os.path.join(data_split_path, 'train.json') +save_val_path = os.path.join(data_split_path, 'val.json') +save_test_path = os.path.join(data_split_path, 'test.json') + +sampling_rate = 512 +ch_names = ['Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'T5', 'Fc1', 'Fc5', 'Cp1', 'Cp5', 'F9', 'Fz', 'Cz', 'Pz', 'Fp2', + 'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Fc2', 'Fc6', 'Cp2', 'Cp6', 'F10'] +num_channels = len(ch_names) +random.seed(42) + +stats_map = {} + + +def load_subject_metadata(subject_folder): + subject_data = [] + folder_name = os.path.basename(subject_folder) + subject_num = int(folder_name[2:]) + subject_name = f"PN{subject_num:02d}" + + for file in natsorted(f for f in os.listdir(subject_folder) if f.endswith('.pkl')): + file_path = os.path.join(subject_folder, file) + try: + with open(file_path, 'rb') as f: + eeg_data = pickle.load(f) + X = eeg_data['X'] + stats_map[file_path] = { + "mean": X.mean(axis=1), + "std": X.std(axis=1), + "min": X.min(), + "max": X.max() + } + subject_data.append({ + "subject_id": subject_num, + "subject_name": subject_name, + "file": file_path, + "label": eeg_data['Y'] + }) + except Exception as e: + print(f"Error loading {file_path}: {str(e)}") + return subject_data + + +def compute_stats_from_map(data_list): + total_mean = np.zeros(num_channels) + total_std = np.zeros(num_channels) + max_val = -np.inf + min_val = np.inf + count = 0 + + for data in data_list: + stats = stats_map.get(data["file"]) + if stats is None: + continue + total_mean += stats["mean"] + total_std += stats["std"] + max_val = max(max_val, stats["max"]) + min_val = min(min_val, stats["min"]) + count += 1 + + if count == 0: + return total_mean.tolist(), total_std.tolist(), max_val, min_val + + return (total_mean / count).tolist(), (total_std / count).tolist(), max_val, min_val + + +def split_subject_data(subject_data, val_ratio=0.2): + """Split the data of a single subject into a validation set with class-balanced partitioning.""" + label_to_data = defaultdict(list) + for data in subject_data: + label_to_data[data["label"]].append(data) + + train_data, val_data = [], [] + for label, data_list in label_to_data.items(): + random.shuffle(data_list) + split_idx = int(len(data_list) * (1 - val_ratio)) + train_data.extend(data_list[:split_idx]) + val_data.extend(data_list[split_idx:]) + + return train_data, val_data + + +def save_dataset(data_list, save_path, norm_params): + mean, std, max_val, min_val = norm_params + + dataset = { + "subject_data": data_list, + "dataset_info": { + "sampling_rate": sampling_rate, + "ch_names": ch_names, + "min": min_val, + "max": max_val, + "mean": mean, + "std": std + } + } + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, 'w') as f: + json.dump(dataset, f, indent=2) + print(f"Saved to {save_path}") + + +def main(): + subject_folders = natsorted( + os.path.join(processed_data_path, f) + for f in os.listdir(processed_data_path) + if f.startswith("PN") and os.path.isdir(os.path.join(processed_data_path, f)) + ) + + train_subjects = [s for s in subject_folders if int(os.path.basename(s)[2:]) <= 14] + test_subjects = [s for s in subject_folders if int(os.path.basename(s)[2:]) > 14] + + all_train_data, all_val_data = [], [] + for subject in train_subjects: + subject_data = load_subject_metadata(subject) + train_data, val_data = split_subject_data(subject_data) + all_train_data.extend(train_data) + all_val_data.extend(val_data) + + all_test_data = [] + for subject in test_subjects: + all_test_data.extend(load_subject_metadata(subject)) + + print(f"\nData counts:") + print(f"Train: {len(all_train_data)}, Val: {len(all_val_data)}, Test: {len(all_test_data)}") + + print("\nComputing normalization...") + norm_params = compute_stats_from_map(all_train_data) + + print("\nSaving datasets...") + save_dataset(all_train_data, save_train_path, norm_params) + save_dataset(all_val_data, save_val_path, norm_params) + save_dataset(all_test_data, save_test_path, norm_params) + + +if __name__ == "__main__": + main() diff --git a/preprocessing/Siena/data_process.py b/preprocessing/Siena/data_process.py index 3338356..751f11f 100644 --- a/preprocessing/Siena/data_process.py +++ b/preprocessing/Siena/data_process.py @@ -1,205 +1,208 @@ -import os -import mne -import numpy as np -import pickle -import argparse -import re -import sys - -data_root = sys.argv[1] -print(f"Data root: {data_root}") -raw_data_path = os.path.join(data_root,'Siena/raw_data') -processed_data_path = os.path.join(data_root,'Siena/processed_data') -os.makedirs(processed_data_path, exist_ok=True) - -# base_dir = "./Preprocessing/Siena/raw_data" -# output_dir = "./Preprocessing/Siena/processed_data/" -# os.makedirs(output_dir, exist_ok=True) - -STANDARD_CHANNELS = [ - "EEG Fp1", "EEG F3", "EEG C3", "EEG P3", "EEG O1", "EEG F7", "EEG T3", "EEG T5", "EEG Fc1", "EEG Fc5", - "EEG Cp1", "EEG Cp5", "EEG F9", "EEG Fz", "EEG Cz", "EEG Pz", "EEG Fp2", "EEG F4", "EEG C4", "EEG P4", - "EEG O2", "EEG F8", "EEG T4", "EEG T6", "EEG Fc2", "EEG Fc6", "EEG Cp2", "EEG Cp6", "EEG F10" -] # All subjects' EDF files contain the 29 specified channels. - -# {'file': 'PN00-3.edf', 'reg_start': '18.15.44', 'start_time': '18.28.29', 'end_time': '19.29.29'}, # Incorrect seizure time record -seizure_records = [ - # PN00 - {'file': 'PN00-1.edf', 'reg_start': '19.39.33', 'start_time': '19.58.36', 'end_time': '19.59.46'}, - {'file': 'PN00-2.edf', 'reg_start': '02.18.17', 'start_time': '02.38.37', 'end_time': '02.39.31'}, - {'file': 'PN00-4.edf', 'reg_start': '20.51.43', 'start_time': '21.08.29', 'end_time': '21.09.43'}, - {'file': 'PN00-5.edf', 'reg_start': '22.22.04', 'start_time': '22.37.08', 'end_time': '22.38.15'}, - # PN01 - {'file': 'PN01-1.edf', 'reg_start': '19.00.44', 'start_time': '21.51.02', 'end_time': '21.51.56'}, - {'file': 'PN01-1.edf', 'reg_start': '19.00.44', 'start_time': '07.53.17', 'end_time': '07.54.31'}, - # PN03 - {'file': 'PN03-1.edf', 'reg_start': '22.44.37', 'start_time': '09.29.10', 'end_time': '09.31.01'}, - {'file': 'PN03-2.edf', 'reg_start': '21.31.04', 'start_time': '07.13.05', 'end_time': '07.15.18'}, - # PN05 - {'file': 'PN05-2.edf', 'reg_start': '06.46.02', 'start_time': '08.45.25', 'end_time': '08.46.00'}, - {'file': 'PN05-3.edf', 'reg_start': '06.01.23', 'start_time': '07.55.19', 'end_time': '07.55.49'}, - {'file': 'PN05-4.edf', 'reg_start': '06.38.35', 'start_time': '07.38.43', 'end_time': '07.39.22'}, - # PN06 - {'file': 'PNO6-1.edf', 'reg_start': '04.21.22', 'start_time': '05.54.25', 'end_time': '05.55.29'}, - {'file': 'PNO6-2.edf', 'reg_start': '21.11.29', 'start_time': '23.39.09', 'end_time': '23.40.18'}, - {'file': 'PN06-3.edf', 'reg_start': '06.25.51', 'start_time': '08.10.26', 'end_time': '08.11.08'}, - {'file': 'PNO6-4.edf', 'reg_start': '11.16.09', 'start_time': '12.55.08', 'end_time': '12.56.11'}, - {'file': 'PN06-5.edf', 'reg_start': '13.24.41', 'start_time': '14.44.24', 'end_time': '14.45.08'}, - # PN07 - {'file': 'PN07-1.edf', 'reg_start': '23.18.10', 'start_time': '05.25.49', 'end_time': '05.26.51'}, - # PN09 - {'file': 'PN09-1.edf', 'reg_start': '14.08.54', 'start_time': '16.09.43', 'end_time': '16.11.03'}, - {'file': 'PN09-2.edf', 'reg_start': '15.02.09', 'start_time': '17.00.56', 'end_time': '17.01.55'}, - {'file': 'PN09-3.edf', 'reg_start': '14.20.23', 'start_time': '16.20.44', 'end_time': '16.21.48'}, - # PN10 - {'file': 'PN10-1.edf', 'reg_start': '05.40.05', 'start_time': '07.45.50', 'end_time': '07.46.59'}, - {'file': 'PN10-2.edf', 'reg_start': '09.30.15', 'start_time': '11.40.13', 'end_time': '11.41.04'}, - {'file': 'PN10-3.edf', 'reg_start': '13.33.18', 'start_time': '15.43.53', 'end_time': '15.45.02'}, - {'file': 'PN10-4.5.6.edf', 'reg_start': '12.11.21', 'start_time': '12.49.50', 'end_time': '12.49.55'}, - {'file': 'PN10-4.5.6.edf', 'reg_start': '12.11.21', 'start_time': '14.00.25', 'end_time': '14.00.44'}, - {'file': 'PN10-4.5.6.edf', 'reg_start': '12.11.21', 'start_time': '15.18.26', 'end_time': '15.19.23'}, - {'file': 'PN10-7.8.9.edf', 'reg_start': '16.49.25', 'start_time': '17.35.13', 'end_time': '17.36.01'}, - {'file': 'PN10-7.8.9.edf', 'reg_start': '16.49.25', 'start_time': '18.20.24', 'end_time': '18.20.42'}, - {'file': 'PN10-7.8.9.edf', 'reg_start': '16.49.25', 'start_time': '20.24.48', 'end_time': '20.25.03'}, - {'file': 'PN10-10.edf', 'reg_start': '08.45.22', 'start_time': '10.58.19', 'end_time': '10.58.33'}, - # PN11 - {'file': 'PN11-1.edf', 'reg_start': '11.31.25', 'start_time': '13.37.19', 'end_time': '13.38.14'}, - # PN12 - {'file': 'PN12-1.2.edf', 'reg_start': '15.51.31', 'start_time': '16.13.23', 'end_time': '16.14.26'}, - {'file': 'PN12-1.2.edf', 'reg_start': '15.51.31', 'start_time': '18.31.01', 'end_time': '18.32.09'}, - {'file': 'PN12-3.edf', 'reg_start': '08.42.35', 'start_time': '08.55.27', 'end_time': '08.57.03'}, - {'file': 'PN12-4.edf', 'reg_start': '15.59.19', 'start_time': '18.42.51', 'end_time': '18.43.54'}, - # PN13 - {'file': 'PN13-1.edf', 'reg_start': '08.24.28', 'start_time': '10.22.10', 'end_time': '10.22.58'}, - {'file': 'PN13-2.edf', 'reg_start': '06.55.02', 'start_time': '08.55.51', 'end_time': '08.56.56'}, - {'file': 'PN13-3.edf', 'reg_start': '12.00.01', 'start_time': '14.05.54', 'end_time': '14.08.25'}, - # PN14 - {'file': 'PN14-1.edf', 'reg_start': '11.44.58', 'start_time': '13.46.00', 'end_time': '13.46.27'}, - {'file': 'PN14-2.edf', 'reg_start': '15.50.13', 'start_time': '17.54.52', 'end_time': '17.55.04'}, - {'file': 'PN14-3.edf', 'reg_start': '16.17.45', 'start_time': '21.10.05', 'end_time': '21.10.46'}, - {'file': 'PN14-4.edf', 'reg_start': '14.18.30', 'start_time': '15.49.33', 'end_time': '15.50.56'}, - # PN16 - {'file': 'PN16-1.edf', 'reg_start': '20.45.21', 'start_time': '22.45.05', 'end_time': '22.47.08'}, - {'file': 'PN16-2.edf', 'reg_start': '00.53.55', 'start_time': '03.16.49', 'end_time': '03.18.36'}, - # PN17 - {'file': 'PN17-1.edf', 'reg_start': '20.14.28', 'start_time': '22.34.48', 'end_time': '22.35.58'}, - {'file': 'PN17-2.edf', 'reg_start': '13.52.18', 'start_time': '16.01.09', 'end_time': '16.02.32'} -] - - -def time_to_samples(time_str, start_time_str, sampling_rate): - # Supports time formats: HH:MM:SS and HH.MM.SS - time_str = time_str.replace(':', '.') - start_time_str = start_time_str.replace(':', '.') - - h, m, s = map(int, time_str.split('.')) - start_h, start_m, start_s = map(int, start_time_str.split('.')) - - # Handle cross-day scenarios (e.g., registration time 19.00.44, seizure time 07.53.17 the next day) - if h < start_h: - h += 24 # Assume the cross-day interval does not exceed 24 hours. - - delta = (h - start_h) * 3600 + (m - start_m) * 60 + (s - start_s) - return delta * sampling_rate - - -def process_edf(edf_path, seizure_records, processed_data_path, sampling_rate=512): - try: - raw = mne.io.read_raw_edf(edf_path, preload=True) - raw.filter(l_freq=0.1, h_freq=75.0) - raw.notch_filter(freqs=50) - - # Add sampling rate verification - if raw.info['sfreq'] != float(sampling_rate): - return False, f"Sampling rate mismatch in {os.path.basename(edf_path)}: " \ - f"expected {sampling_rate} Hz, got {raw.info['sfreq']} Hz" - - # Process channel-related issues - existing_channels = raw.ch_names - existing_map = {ch.lower(): ch for ch in existing_channels} - selected_channels = [] - reordered_channels = [] - missing_channels = [] - for ch in STANDARD_CHANNELS: - ch_lower = ch.lower() - if ch_lower in existing_map: - real_name = existing_map[ch_lower] - selected_channels.append(real_name) - reordered_channels.append(real_name) - else: - missing_channels.append(ch) - print(selected_channels) - print(reordered_channels) - if missing_channels: - print( - f"Missing channels in {raw.filenames[0] if hasattr(raw, 'filenames') else 'EDF'}: {missing_channels}") - raw.pick_channels(selected_channels) - raw.reorder_channels(reordered_channels) - - # Extract seizure event records in the current file - data, _ = raw[:, :] - current_file = os.path.basename(edf_path) - file_seizures = [sz for sz in seizure_records if sz["file"] == current_file] - seizure_samples = [] - for sz in file_seizures: - start = time_to_samples(sz["start_time"], sz["reg_start"], sampling_rate) - end = time_to_samples(sz["end_time"], sz["reg_start"], sampling_rate) - seizure_samples.append((start, end)) - print(seizure_samples) - - # Create all data segments - segments = [] - segment_length = 10 * sampling_rate - # 1. Standard split segmentation - for i in range(0, data.shape[1], segment_length): - if i + segment_length <= data.shape[1]: - seg = data[:, i:i + segment_length] - label = 0 - for sz_start, sz_end in seizure_samples: - if (i < sz_start < i + segment_length) or (i < sz_end < i + segment_length): - label = 1 - break - segments.append((seg, label)) - # 2. Seizure-enhanced segmentation - for sz_start, sz_end in seizure_samples: - start = max(0, sz_start - sampling_rate) - end = min(data.shape[1], sz_end + sampling_rate) - - for i in range(start, end, 5 * sampling_rate): - seg = data[:, i:i + segment_length] - segments.append((seg, 1)) - - # Save all data segments - base_name = os.path.splitext(current_file)[0] - for idx, (seg, label) in enumerate(segments): - output_data = {"X": seg, "Y": label} - output_file = os.path.join(processed_data_path, f"{base_name}_{idx}.pkl") - with open(output_file, "wb") as f: - pickle.dump(output_data, f) - - return True, f"Processed {current_file}, generated {len(segments)} segments (with 0.1-75Hz bandpass + 50Hz notch)" - - except Exception as e: - return False, f"Error processing {os.path.basename(edf_path)}: {str(e)}" - - -def main(patient_id): - edf_files = sorted([ - f for f in os.listdir(os.path.join(raw_data_path, patient_id)) - if f.endswith('.edf') and f.startswith(patient_id) - ]) - - for edf_file in edf_files: - edf_path = os.path.join(raw_data_path, patient_id, edf_file) - print(f"Processing {edf_file}...") - success, message = process_edf(edf_path, seizure_records, processed_data_path, sampling_rate = 512) - print(message) - print(f"\nAll processing completed. Results saved to {processed_data_path}") - - -if __name__ == "__main__": - patient_ids = ["PN00", "PN01", "PN03", "PN05", "PN06", "PN07", "PN09", - "PN10", "PN11", "PN12", "PN13", "PN14", "PN16", "PN17"] - - for pid in patient_ids: - main(pid) +import os +import mne +import numpy as np +import pickle +import argparse +import re +import sys + +data_root = sys.argv[1] +print(f"Data root: {data_root}") +raw_data_path = os.path.join(data_root,'Siena/raw_data') +processed_data_path = os.path.join(data_root,'Siena/processed_data') +os.makedirs(processed_data_path, exist_ok=True) + +# base_dir = "./Preprocessing/Siena/raw_data" +# output_dir = "./Preprocessing/Siena/processed_data/" +# os.makedirs(output_dir, exist_ok=True) + +STANDARD_CHANNELS = [ + "EEG Fp1", "EEG F3", "EEG C3", "EEG P3", "EEG O1", "EEG F7", "EEG T3", "EEG T5", "EEG Fc1", "EEG Fc5", + "EEG Cp1", "EEG Cp5", "EEG F9", "EEG Fz", "EEG Cz", "EEG Pz", "EEG Fp2", "EEG F4", "EEG C4", "EEG P4", + "EEG O2", "EEG F8", "EEG T4", "EEG T6", "EEG Fc2", "EEG Fc6", "EEG Cp2", "EEG Cp6", "EEG F10" +] # All subjects' EDF files contain the 29 specified channels. + +# {'file': 'PN00-3.edf', 'reg_start': '18.15.44', 'start_time': '18.28.29', 'end_time': '19.29.29'}, # Incorrect seizure time record +seizure_records = [ + # PN00 + {'file': 'PN00-1.edf', 'reg_start': '19.39.33', 'start_time': '19.58.36', 'end_time': '19.59.46'}, + {'file': 'PN00-2.edf', 'reg_start': '02.18.17', 'start_time': '02.38.37', 'end_time': '02.39.31'}, + {'file': 'PN00-4.edf', 'reg_start': '20.51.43', 'start_time': '21.08.29', 'end_time': '21.09.43'}, + {'file': 'PN00-5.edf', 'reg_start': '22.22.04', 'start_time': '22.37.08', 'end_time': '22.38.15'}, + # PN01 + {'file': 'PN01-1.edf', 'reg_start': '19.00.44', 'start_time': '21.51.02', 'end_time': '21.51.56'}, + {'file': 'PN01-1.edf', 'reg_start': '19.00.44', 'start_time': '07.53.17', 'end_time': '07.54.31'}, + # PN03 + {'file': 'PN03-1.edf', 'reg_start': '22.44.37', 'start_time': '09.29.10', 'end_time': '09.31.01'}, + {'file': 'PN03-2.edf', 'reg_start': '21.31.04', 'start_time': '07.13.05', 'end_time': '07.15.18'}, + # PN05 + {'file': 'PN05-2.edf', 'reg_start': '06.46.02', 'start_time': '08.45.25', 'end_time': '08.46.00'}, + {'file': 'PN05-3.edf', 'reg_start': '06.01.23', 'start_time': '07.55.19', 'end_time': '07.55.49'}, + {'file': 'PN05-4.edf', 'reg_start': '06.38.35', 'start_time': '07.38.43', 'end_time': '07.39.22'}, + # PN06 + {'file': 'PNO6-1.edf', 'reg_start': '04.21.22', 'start_time': '05.54.25', 'end_time': '05.55.29'}, + {'file': 'PNO6-2.edf', 'reg_start': '21.11.29', 'start_time': '23.39.09', 'end_time': '23.40.18'}, + {'file': 'PN06-3.edf', 'reg_start': '06.25.51', 'start_time': '08.10.26', 'end_time': '08.11.08'}, + {'file': 'PNO6-4.edf', 'reg_start': '11.16.09', 'start_time': '12.55.08', 'end_time': '12.56.11'}, + {'file': 'PN06-5.edf', 'reg_start': '13.24.41', 'start_time': '14.44.24', 'end_time': '14.45.08'}, + # PN07 + {'file': 'PN07-1.edf', 'reg_start': '23.18.10', 'start_time': '05.25.49', 'end_time': '05.26.51'}, + # PN09 + {'file': 'PN09-1.edf', 'reg_start': '14.08.54', 'start_time': '16.09.43', 'end_time': '16.11.03'}, + {'file': 'PN09-2.edf', 'reg_start': '15.02.09', 'start_time': '17.00.56', 'end_time': '17.01.55'}, + {'file': 'PN09-3.edf', 'reg_start': '14.20.23', 'start_time': '16.20.44', 'end_time': '16.21.48'}, + # PN10 + {'file': 'PN10-1.edf', 'reg_start': '05.40.05', 'start_time': '07.45.50', 'end_time': '07.46.59'}, + {'file': 'PN10-2.edf', 'reg_start': '09.30.15', 'start_time': '11.40.13', 'end_time': '11.41.04'}, + {'file': 'PN10-3.edf', 'reg_start': '13.33.18', 'start_time': '15.43.53', 'end_time': '15.45.02'}, + {'file': 'PN10-4.5.6.edf', 'reg_start': '12.11.21', 'start_time': '12.49.50', 'end_time': '12.49.55'}, + {'file': 'PN10-4.5.6.edf', 'reg_start': '12.11.21', 'start_time': '14.00.25', 'end_time': '14.00.44'}, + {'file': 'PN10-4.5.6.edf', 'reg_start': '12.11.21', 'start_time': '15.18.26', 'end_time': '15.19.23'}, + {'file': 'PN10-7.8.9.edf', 'reg_start': '16.49.25', 'start_time': '17.35.13', 'end_time': '17.36.01'}, + {'file': 'PN10-7.8.9.edf', 'reg_start': '16.49.25', 'start_time': '18.20.24', 'end_time': '18.20.42'}, + {'file': 'PN10-7.8.9.edf', 'reg_start': '16.49.25', 'start_time': '20.24.48', 'end_time': '20.25.03'}, + {'file': 'PN10-10.edf', 'reg_start': '08.45.22', 'start_time': '10.58.19', 'end_time': '10.58.33'}, + # PN11 + {'file': 'PN11-1.edf', 'reg_start': '11.31.25', 'start_time': '13.37.19', 'end_time': '13.38.14'}, + # PN12 + {'file': 'PN12-1.2.edf', 'reg_start': '15.51.31', 'start_time': '16.13.23', 'end_time': '16.14.26'}, + {'file': 'PN12-1.2.edf', 'reg_start': '15.51.31', 'start_time': '18.31.01', 'end_time': '18.32.09'}, + {'file': 'PN12-3.edf', 'reg_start': '08.42.35', 'start_time': '08.55.27', 'end_time': '08.57.03'}, + {'file': 'PN12-4.edf', 'reg_start': '15.59.19', 'start_time': '18.42.51', 'end_time': '18.43.54'}, + # PN13 + {'file': 'PN13-1.edf', 'reg_start': '08.24.28', 'start_time': '10.22.10', 'end_time': '10.22.58'}, + {'file': 'PN13-2.edf', 'reg_start': '06.55.02', 'start_time': '08.55.51', 'end_time': '08.56.56'}, + {'file': 'PN13-3.edf', 'reg_start': '12.00.01', 'start_time': '14.05.54', 'end_time': '14.08.25'}, + # PN14 + {'file': 'PN14-1.edf', 'reg_start': '11.44.58', 'start_time': '13.46.00', 'end_time': '13.46.27'}, + {'file': 'PN14-2.edf', 'reg_start': '15.50.13', 'start_time': '17.54.52', 'end_time': '17.55.04'}, + {'file': 'PN14-3.edf', 'reg_start': '16.17.45', 'start_time': '21.10.05', 'end_time': '21.10.46'}, + {'file': 'PN14-4.edf', 'reg_start': '14.18.30', 'start_time': '15.49.33', 'end_time': '15.50.56'}, + # PN16 + {'file': 'PN16-1.edf', 'reg_start': '20.45.21', 'start_time': '22.45.05', 'end_time': '22.47.08'}, + {'file': 'PN16-2.edf', 'reg_start': '00.53.55', 'start_time': '03.16.49', 'end_time': '03.18.36'}, + # PN17 + {'file': 'PN17-1.edf', 'reg_start': '20.14.28', 'start_time': '22.34.48', 'end_time': '22.35.58'}, + {'file': 'PN17-2.edf', 'reg_start': '13.52.18', 'start_time': '16.01.09', 'end_time': '16.02.32'} +] + + +def time_to_samples(time_str, start_time_str, sampling_rate): + # Supports time formats: HH:MM:SS and HH.MM.SS + time_str = time_str.replace(':', '.') + start_time_str = start_time_str.replace(':', '.') + + h, m, s = map(int, time_str.split('.')) + start_h, start_m, start_s = map(int, start_time_str.split('.')) + + # Handle cross-day scenarios (e.g., registration time 19.00.44, seizure time 07.53.17 the next day) + if h < start_h: + h += 24 # Assume the cross-day interval does not exceed 24 hours. + + delta = (h - start_h) * 3600 + (m - start_m) * 60 + (s - start_s) + return delta * sampling_rate + + +def process_edf(edf_path, seizure_records, processed_data_path, sampling_rate=512): + try: + raw = mne.io.read_raw_edf(edf_path, preload=True) + raw.filter(l_freq=0.1, h_freq=75.0) + raw.notch_filter(freqs=50) + + # Add sampling rate verification + if raw.info['sfreq'] != float(sampling_rate): + return False, f"Sampling rate mismatch in {os.path.basename(edf_path)}: " \ + f"expected {sampling_rate} Hz, got {raw.info['sfreq']} Hz" + + # Process channel-related issues + existing_channels = raw.ch_names + existing_map = {ch.lower(): ch for ch in existing_channels} + selected_channels = [] + reordered_channels = [] + missing_channels = [] + for ch in STANDARD_CHANNELS: + ch_lower = ch.lower() + if ch_lower in existing_map: + real_name = existing_map[ch_lower] + selected_channels.append(real_name) + reordered_channels.append(real_name) + else: + missing_channels.append(ch) + print(selected_channels) + print(reordered_channels) + if missing_channels: + print( + f"Missing channels in {raw.filenames[0] if hasattr(raw, 'filenames') else 'EDF'}: {missing_channels}") + raw.pick_channels(selected_channels) + raw.reorder_channels(reordered_channels) + + # Extract seizure event records in the current file + data, _ = raw[:, :] + current_file = os.path.basename(edf_path) + file_seizures = [sz for sz in seizure_records if sz["file"] == current_file] + seizure_samples = [] + for sz in file_seizures: + start = time_to_samples(sz["start_time"], sz["reg_start"], sampling_rate) + end = time_to_samples(sz["end_time"], sz["reg_start"], sampling_rate) + seizure_samples.append((start, end)) + print(seizure_samples) + + # Create all data segments + segments = [] + segment_length = 10 * sampling_rate + # 1. Standard split segmentation + for i in range(0, data.shape[1], segment_length): + if i + segment_length <= data.shape[1]: + seg = data[:, i:i + segment_length] + label = 0 + for sz_start, sz_end in seizure_samples: + if (i < sz_start < i + segment_length) or (i < sz_end < i + segment_length): + label = 1 + break + segments.append((seg, label)) + # 2. Seizure-enhanced segmentation + for sz_start, sz_end in seizure_samples: + start = max(0, sz_start - sampling_rate) + end = min(data.shape[1], sz_end + sampling_rate) + + for i in range(start, end, 5 * sampling_rate): + seg = data[:, i:i + segment_length] + segments.append((seg, 1)) + + # Save all data segments + base_name = os.path.splitext(current_file)[0] + patient_id = base_name.split('-')[0] + patient_dir = os.path.join(processed_data_path, patient_id) + os.makedirs(patient_dir, exist_ok=True) + for idx, (seg, label) in enumerate(segments): + output_data = {"X": seg, "Y": label} + output_file = os.path.join(patient_dir, f"{base_name}_{idx}.pkl") + with open(output_file, "wb") as f: + pickle.dump(output_data, f) + + return True, f"Processed {current_file}, generated {len(segments)} segments (with 0.1-75Hz bandpass + 50Hz notch)" + + except Exception as e: + return False, f"Error processing {os.path.basename(edf_path)}: {str(e)}" + + +def main(patient_id): + edf_files = sorted([ + f for f in os.listdir(os.path.join(raw_data_path, patient_id)) + if f.endswith('.edf') and f.startswith(patient_id) + ]) + + for edf_file in edf_files: + edf_path = os.path.join(raw_data_path, patient_id, edf_file) + print(f"Processing {edf_file}...") + success, message = process_edf(edf_path, seizure_records, processed_data_path, sampling_rate = 512) + print(message) + print(f"\nAll processing completed. Results saved to {processed_data_path}") + + +if __name__ == "__main__": + patient_ids = ["PN00", "PN01", "PN03", "PN05", "PN06", "PN07", "PN09", + "PN10", "PN11", "PN12", "PN13", "PN14", "PN16", "PN17"] + + for pid in patient_ids: + main(pid) diff --git a/preprocessing/TUEV/cross_json_process.py b/preprocessing/TUEV/cross_json_process.py index c10d0e0..fb847ee 100644 --- a/preprocessing/TUEV/cross_json_process.py +++ b/preprocessing/TUEV/cross_json_process.py @@ -1,14 +1,13 @@ import json import os -import random import pickle import numpy as np from natsort import natsorted import sys -data_root = sys.argv[1] +data_root = sys.argv[1] print(f"Data root: {data_root}") -processed_data_path = os.path.join(data_root,'TUEV/processed_data/') +processed_data_path = os.path.join(data_root, 'TUEV/processed_data/') data_split_path = './preprocessing/TUEV/cross_subject_json' os.makedirs(data_split_path, exist_ok=True) train_folder = os.path.join(processed_data_path, "train") @@ -18,35 +17,36 @@ save_val_path = os.path.join(data_split_path, 'val.json') save_test_path = os.path.join(data_split_path, 'test.json') -# base_folder = ".Preprocessing/TUEV/processed_data/final_data" -# train_folder = os.path.join(base_folder, "train") -# val_folder = os.path.join(base_folder, "eval") -# eval_folder = os.path.join(base_folder, "test") -# cross_subject_json_folder = ".Preprocessing/TUEV/cross_subject_json" -# os.makedirs(os.path.dirname(cross_subject_json_folder), exist_ok=True) - -# save_folder_test = os.path.join(cross_subject_split_folder, 'test.json') -# save_folder_train = os.path.join(cross_subject_split_folder, 'train.json') -# save_folder_val = os.path.join(cross_subject_split_folder, 'val.json') - - sampling_rate = 250 ch_names = ['FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'A1', 'A2', 'FZ', 'CZ', 'PZ', 'T1', 'T2'] num_channels = len(ch_names) + total_mean = np.zeros(num_channels) total_std = np.zeros(num_channels) num_all = 0 -max_value = -1 -min_value = 1e6 +max_value = -np.inf +min_value = np.inf + -def is_23_channels(pkl_file): +def get_subject_name(file_path): + base = os.path.basename(file_path) + if "_" in base: + return base.split("_")[0] + return base[:8] + + +def load_eeg(file_path): try: - eeg_data = pickle.load(open(pkl_file, "rb")) + with open(file_path, "rb") as f: + eeg_data = pickle.load(f) eeg = eeg_data['X'] - return eeg.shape[0] == 23 + if eeg.shape[0] != num_channels: + return None + return eeg_data except Exception as e: - print(f"Error loading file {pkl_file}: {e}") - return False + print(f"Error loading file {file_path}: {e}") + return None + train_files = natsorted([os.path.join(train_folder, f) for f in os.listdir(train_folder) if f.endswith('.pkl')]) tuples_list_train = [] @@ -54,39 +54,36 @@ def is_23_channels(pkl_file): subject_id_map = {} for file in train_files: - if not is_23_channels(file): + eeg_data = load_eeg(file) + if eeg_data is None: continue - try: - eeg_data = pickle.load(open(file, "rb")) - label = eeg_data['Y'] - eeg = eeg_data['X'] - subject_folder = os.path.basename(file)[:8] - if subject_folder not in subject_id_map: - subject_id_map[subject_folder] = subject_id_counter - subject_id_counter += 1 - # Calculate normalization parameters. - for j in range(num_channels): - total_mean[j] += eeg[j].mean() - total_std[j] += eeg[j].std() - num_all += 1 - per_max_value = max(eeg.reshape(-1)) - per_min_value = min(eeg.reshape(-1)) - if per_max_value > max_value: - max_value = per_max_value - if per_min_value < min_value: - min_value = per_min_value - data = { - "subject_id": subject_id_map[subject_folder], - "subject_name": subject_folder, - "file": file, - "label": label - } - tuples_list_train.append(data) - except Exception as e: - print(f"Error loading file {file}: {e}") + label = eeg_data['Y'] + eeg = eeg_data['X'] + subject_name = get_subject_name(file) + if subject_name not in subject_id_map: + subject_id_map[subject_name] = subject_id_counter + subject_id_counter += 1 + + total_mean += eeg.mean(axis=1) + total_std += eeg.std(axis=1) + num_all += 1 + max_value = max(max_value, eeg.max()) + min_value = min(min_value, eeg.min()) + + data = { + "subject_id": subject_id_map[subject_name], + "subject_name": subject_name, + "file": file, + "label": label + } + tuples_list_train.append(data) -data_mean = (total_mean / num_all).tolist() -data_std = (total_std / num_all).tolist() +if num_all == 0: + data_mean = total_mean.tolist() + data_std = total_std.tolist() +else: + data_mean = (total_mean / num_all).tolist() + data_std = (total_std / num_all).tolist() train_dataset = { "subject_data": tuples_list_train, @@ -99,33 +96,28 @@ def is_23_channels(pkl_file): "std": data_std } } -formatted_json_train = json.dumps(train_dataset, indent=2) with open(save_train_path, 'w') as f: - f.write(formatted_json_train) + json.dump(train_dataset, f, indent=2) val_files = natsorted([os.path.join(val_folder, f) for f in os.listdir(val_folder) if f.endswith('.pkl')]) tuples_list_val = [] for file in val_files: - if not is_23_channels(file): + eeg_data = load_eeg(file) + if eeg_data is None: continue - try: - eeg_data = pickle.load(open(file, "rb")) - label = eeg_data['Y'] - eeg = eeg_data['X'] - subject_folder = os.path.basename(file)[:8] - if subject_folder not in subject_id_map: - subject_id_map[subject_folder] = subject_id_counter - subject_id_counter += 1 - data = { - "subject_id": subject_id_map[subject_folder], - "subject_name": subject_folder, - "file": file, - "label": label - } - tuples_list_val.append(data) - except Exception as e: - print(f"Error loading file {file}: {e}") + label = eeg_data['Y'] + subject_name = get_subject_name(file) + if subject_name not in subject_id_map: + subject_id_map[subject_name] = subject_id_counter + subject_id_counter += 1 + data = { + "subject_id": subject_id_map[subject_name], + "subject_name": subject_name, + "file": file, + "label": label + } + tuples_list_val.append(data) val_dataset = { "subject_data": tuples_list_val, @@ -138,29 +130,27 @@ def is_23_channels(pkl_file): "std": data_std } } -formatted_json_val = json.dumps(val_dataset, indent=2) with open(save_val_path, 'w') as f: - f.write(formatted_json_val) + json.dump(val_dataset, f, indent=2) eval_files = natsorted([os.path.join(eval_folder, f) for f in os.listdir(eval_folder) if f.endswith('.pkl')]) tuples_list_test = [] error_list = [] for file in eval_files: - if not is_23_channels(file): + eeg_data = load_eeg(file) + if eeg_data is None: continue try: - data_name = os.path.basename(file).split('_')[1][:3] - if data_name not in subject_id_map: - subject_id_map[data_name] = subject_id_counter - subject_id_counter += 1 - subject_id = subject_id_map[data_name] - eeg_data = pickle.load(open(file, "rb")) label = eeg_data['Y'] - eeg = eeg_data['X'] + subject_name = get_subject_name(file) + if subject_name not in subject_id_map: + subject_id_map[subject_name] = subject_id_counter + subject_id_counter += 1 + subject_id = subject_id_map[subject_name] data = { "subject_id": subject_id, - "subject_name": data_name, + "subject_name": subject_name, "file": file, "label": label } @@ -180,8 +170,7 @@ def is_23_channels(pkl_file): "std": data_std } } -formatted_json_test = json.dumps(test_dataset, indent=2) with open(save_test_path, 'w') as f: - f.write(formatted_json_test) + json.dump(test_dataset, f, indent=2) print("error list: ", error_list) diff --git a/preprocessing/TUEV/data_process.py b/preprocessing/TUEV/data_process.py index b815cc4..2959e5b 100644 --- a/preprocessing/TUEV/data_process.py +++ b/preprocessing/TUEV/data_process.py @@ -15,13 +15,12 @@ raw_data_path = os.path.join(data_root,'TUEV/raw_data/v2.0.1') processed_data_path = os.path.join(data_root,'TUEV/processed_data') os.makedirs(processed_data_path, exist_ok=True) -train_dir = os.path.join(processed_data_path, "train_dir") -eval_dir = os.path.join(processed_data_path, "eval_dir") -test_dir = os.path.join(processed_data_path, "test_dir") -if not os.path.exists(train_dir): - os.makedirs(train_dir) -if not os.path.exists(test_dir): - os.makedirs(test_dir) +train_dir = os.path.join(processed_data_path, "train") +eval_dir = os.path.join(processed_data_path, "eval") +test_dir = os.path.join(processed_data_path, "test") +os.makedirs(train_dir, exist_ok=True) +os.makedirs(eval_dir, exist_ok=True) +os.makedirs(test_dir, exist_ok=True) # final_data = os.path.join(processed_data_path, "final_data") diff --git a/util/eegdatasets.py b/util/eegdatasets.py index 028c239..2e9b5e1 100644 --- a/util/eegdatasets.py +++ b/util/eegdatasets.py @@ -41,7 +41,8 @@ def __init__(self, dataset, train=True, subject_mod='single', subject_id=1, samp test_root = f"{dataset_info['root']['multi']}/test.json" dataset_root = train_root if self.train else test_root - all_json_data = json.load(open(dataset_root, "r")) + with open(dataset_root, "r") as dataset_file: + all_json_data = json.load(dataset_file) data_info = all_json_data['dataset_info'] self.default_rate = data_info['sampling_rate'] self.ch_names = data_info['ch_names'] @@ -89,7 +90,8 @@ def _process_test_data(self, files): eeg_paths = [item["EEG"] for item in items] eeg_features = [] for eeg_path in eeg_paths: - sample = pickle.load(open(eeg_path, "rb"))['X'] + with open(eeg_path, "rb") as eeg_file: + sample = pickle.load(eeg_file)['X'] if self.sampling_rate != self.default_rate: sample = self.resample_data(sample) sample = self.normalize(sample) @@ -158,7 +160,8 @@ def __getitem__(self, index): eeg_sources = self.files[index] if self.train: datapath = eeg_sources['EEG'] - x = pickle.load(open(datapath, "rb"))['X'] + with open(datapath, "rb") as eeg_file: + x = pickle.load(eeg_file)['X'] if self.sampling_rate != self.default_rate: x = self.resample_data(x) x = self.normalize(x)