From 0b0a94c7e915de597eb74b5dc7ed0c0377337615 Mon Sep 17 00:00:00 2001 From: seqasim Date: Sat, 15 Nov 2025 13:53:04 -0500 Subject: [PATCH] updated all docstrings --- LFPAnalysis/analysis_utils.py | 245 ++++-- LFPAnalysis/iowa_utils.py | 43 +- LFPAnalysis/lfp_preprocess_utils.py | 632 +++++++-------- LFPAnalysis/nlx_utils.py | 159 +++- LFPAnalysis/oscillation_utils.py | 1155 +++++++++++++++++---------- LFPAnalysis/statistics_utils.py | 258 +++--- LFPAnalysis/sync_utils.py | 169 ++-- 7 files changed, 1605 insertions(+), 1056 deletions(-) diff --git a/LFPAnalysis/analysis_utils.py b/LFPAnalysis/analysis_utils.py index 6f36378..db3075d 100644 --- a/LFPAnalysis/analysis_utils.py +++ b/LFPAnalysis/analysis_utils.py @@ -14,10 +14,22 @@ # There are some things that MNE is not that good at, or simply does not do. Let's write our own code for these. -def select_rois_picks(elec_data, chan_name, manual_col='collapsed_manual'): +def select_rois_picks(elec_data: pd.DataFrame, chan_name: str, manual_col: str = 'collapsed_manual'): + """Select ROI for specific channel. - """ - Grab specific roi for the channel you are looking at + Parameters + ---------- + elec_data : pd.DataFrame + Electrode data DataFrame. + chan_name : str + Channel name. + manual_col : str, optional + Manual column name. Default is 'collapsed_manual'. + + Returns + ------- + str + ROI label. """ # Load the YBA ROI labels, custom assigned by Salman: @@ -105,9 +117,20 @@ def select_rois_picks(elec_data, chan_name, manual_col='collapsed_manual'): return roi -def select_picks_rois(elec_data, roi=None): - """ - Grab specific electrodes that you care about +def select_picks_rois(elec_data: pd.DataFrame, roi=None): + """Select electrodes for specific ROI. + + Parameters + ---------- + elec_data : pd.DataFrame + Electrode data DataFrame. + roi : str or list, optional + ROI name or list of ROI names. + + Returns + ------- + list + List of electrode labels. 
""" # Site specific processing: @@ -142,16 +165,27 @@ def select_picks_rois(elec_data, roi=None): return picks -def lfp_sta(ev_times, signal, sr, pre, post): - ''' - Compute the STA for a vector of stimuli. - - Input: - spikes - raw spike times used to compute STA, should be in s - signal - signal for averaging. can be filtered or unfiltered. - bound - bound of the STA in ms, +- this number +def lfp_sta(ev_times: np.ndarray, signal: np.ndarray, sr: float, pre: float, post: float): + """Compute spike-triggered average for LFP signal. + + Parameters + ---------- + ev_times : np.ndarray + Event times in seconds. + signal : np.ndarray + Signal for averaging. + sr : float + Sampling rate. + pre : float + Pre-event window in seconds. + post : float + Post-event window in seconds. - ''' + Returns + ------- + tuple + Tuple containing (sta, ste). + """ num_evs = len(ev_times) ev_in_samples = (ev_times * sr).astype(int) @@ -175,10 +209,28 @@ def lfp_sta(ev_times, signal, sr, pre, post): return sta, ste -def plot_TFR(data, freqs, pre_win, post_win, sr, title): - """ - - pre_win should be in seconds +def plot_TFR(data: np.ndarray, freqs: np.ndarray, pre_win: float, post_win: float, sr: float, title: str): + """Plot time-frequency representation. + + Parameters + ---------- + data : np.ndarray + TFR data array. + freqs : np.ndarray + Frequency array. + pre_win : float + Pre-window in seconds. + post_win : float + Post-window in seconds. + sr : float + Sampling rate. + title : str + Plot title. + + Returns + ------- + matplotlib.figure.Figure + Figure object. 
""" f, tfr = plt.subplots(1, 1, figsize=[7, 4], dpi=300) @@ -199,16 +251,28 @@ def plot_TFR(data, freqs, pre_win, post_win, sr, title): return f -def detect_fast_burst_evs(mne_data, - baseline_data, - burst_frequency = (70, 200), - smooth_win_s=0.02, - sd_upper_cutoff=6, - sd_lower_cutoff=1): - """ +def detect_fast_burst_evs(mne_data, baseline_data, burst_frequency: tuple = (70, 200), smooth_win_s: float = 0.02, sd_upper_cutoff: float = 6, sd_lower_cutoff: float = 1): + """Detect fast burst events in HFA band. - HFA band: 70-200 Hz - Ripple range: 80-120 + Parameters + ---------- + mne_data + MNE epochs object. + baseline_data + Baseline MNE epochs object. + burst_frequency : tuple, optional + Frequency range for burst detection. Default is (70, 200). + smooth_win_s : float, optional + Smoothing window in seconds. Default is 0.02. + sd_upper_cutoff : float, optional + Upper SD cutoff. Default is 6. + sd_lower_cutoff : float, optional + Lower SD cutoff. Default is 1. + + Returns + ------- + dict + Dictionary of burst events per channel. """ @@ -534,39 +598,35 @@ def detect_fast_burst_evs(mne_data, # # then reject any electrode with a low ripple count (< 20 ripples detected per electrode per task) or high rejection rate (greater than 30% rejection rate) # return allts, ripple_categories, ripple_psds -def FOOOF_continuous(signal): - """ - TODO +def FOOOF_continuous(signal: np.ndarray): + """Compute FOOOF on continuous signal. + + Parameters + ---------- + signal : np.ndarray + Continuous signal. """ pass -def FOOOF_compute_epochs(epochs, tmin=0, tmax=1.5, **kwargs): - """ - - This function is meant to enable easy computation of FOOOF. - +def FOOOF_compute_epochs(epochs, tmin: float = 0, tmax: float = 1.5, **kwargs): + """Compute FOOOF on epoched data. 
+ Parameters ---------- - epochs : mne Epochs object - mne object - - tmin : time to start (s) - float - - tmax : time to end (s) - float - - band_dict : definitions of the bands of interest - dict - - kwargs : input arguments to the FOOOFGroup function, including: 'min_peak_height', 'peak_threshold', 'max_n_peaks' - dict - + epochs + MNE Epochs object. + tmin : float, optional + Start time in seconds. Default is 0. + tmax : float, optional + End time in seconds. Default is 1.5. + **kwargs + Additional FOOOFGroup arguments. + Returns ------- - mne_data_reref : mne object - mne object with re-referenced data + tuple + Tuple containing (FOOOFGroup_res, pd.DataFrame). """ # bands = fooof.bands.Bands(band_dict) @@ -825,39 +885,40 @@ def FOOOF_compute_epochs(epochs, tmin=0, tmax=1.5, **kwargs): # We put all of our basic FOOOF usage into a slightly clunky function that is meant to be used for running the regression # over multiple channels in parallel using joblib/Dask/multiprocessing.Pool: -def compute_FOOOF_parallel(chan_name, MNE_object, subj_id, elec_df, event_name, ev_dict, band_dict, conditions, - do_plot=False, save_path='/sc/arion/projects/guLab/Salman/EphysAnalyses', - do_save=False, **kwargs): - """ - Compute FOOOF for a single channel across all trials and for each condition of interest. - Meant to be used in parallel, hence a little clunky. - +def compute_FOOOF_parallel(chan_name: str, MNE_object, subj_id: str, elec_df: pd.DataFrame, event_name: str, ev_dict: dict, band_dict: dict, conditions: list, do_plot: bool = False, save_path: str = '/sc/arion/projects/guLab/Salman/EphysAnalyses', do_save: bool = False, **kwargs): + """Compute FOOOF for single channel in parallel. + Parameters - ---------- + ---------- chan_name : str - Name of the channel to compute FOOOF for - MNE_object : mne.Epochs - MNE object containing the data + Channel name. + MNE_object + MNE Epochs object. subj_id : str - Subject ID + Subject ID. 
elec_df : pd.DataFrame - DataFrame containing the electrode information - event : str - Event to compute FOOOF for + Electrode DataFrame. + event_name : str + Event name. ev_dict : dict - Dictionary containing the start and end times for each event + Event time dictionary. band_dict : dict - Dictionary containing the frequency bands to compute FOOOF for + Frequency band dictionary. conditions : list - List of conditions to compute FOOOF for - do_plot : bool - Whether to plot the FOOOF results - save_path : str - Path to save the FOOOF results - do_save : bool - Whether to save the FOOOF results - **kwargs : dict - Additional arguments to pass to FOOOF_compute_epochs + List of conditions. + do_plot : bool, optional + Whether to plot. Default is False. + save_path : str, optional + Save path. Default is '/sc/arion/projects/guLab/Salman/EphysAnalyses'. + do_save : bool, optional + Whether to save. Default is False. + **kwargs + Additional FOOOF arguments. + + Returns + ------- + pd.DataFrame or None + Results DataFrame if not saving. """ # First, compute FOOOF across all trials @@ -917,19 +978,29 @@ def compute_FOOOF_parallel(chan_name, MNE_object, subj_id, elec_df, event_name, return chan_df -def sliding_FOOOF(signal): - """ - Implement time-resolved FOOOF: - https://github.com/lucwilson/SPRiNT now has a python implementation we can borrow from! +def sliding_FOOOF(signal: np.ndarray): + """Compute time-resolved FOOOF. - + Parameters + ---------- + signal : np.ndarray + Signal array. """ pass -def hctsa_signal_features(signal): - """ - Implement https://github.com/DynamicsAndNeuralSystems/catch22 +def hctsa_signal_features(signal: np.ndarray): + """Extract catch22 signal features. + + Parameters + ---------- + signal : np.ndarray + Signal array. + + Returns + ------- + pd.DataFrame + DataFrame with signal features. 
""" signal_features = pycatch22.catch22_all(signal) diff --git a/LFPAnalysis/iowa_utils.py b/LFPAnalysis/iowa_utils.py index d4f21a2..9e8bb0a 100644 --- a/LFPAnalysis/iowa_utils.py +++ b/LFPAnalysis/iowa_utils.py @@ -3,9 +3,18 @@ from itertools import chain from LFPAnalysis import lfp_preprocess_utils -def extract_names_connect_table(connect_table_path): - """ - Utility function for extracting channel types from Iowa connection table +def extract_names_connect_table(connect_table_path: str): + """Extract channel types from Iowa connection table. + + Parameters + ---------- + connect_table_path : str + Path to the connection table CSV file. + + Returns + ------- + tuple + Tuple containing (eeg_names, resp_names, ekg_names, seeg_names, drop_names). """ connect_table = pd.read_csv(connect_table_path) @@ -95,10 +104,18 @@ def extract_names_connect_table(connect_table_path): return eeg_names, resp_names, ekg_names, seeg_names, drop_names -def extract_names_elec_table(elec_table_path): - """ - In some instances we just have Kiril's electrode table. In this case, we need a different extractor for - the data +def extract_names_elec_table(elec_table_path: str): + """Extract channel names from electrode table. + + Parameters + ---------- + elec_table_path : str + Path to the electrode table file. + + Returns + ------- + list + List of sEEG channel names. """ elec_data = lfp_preprocess_utils.load_elec(elec_table_path, site='UI') @@ -134,9 +151,15 @@ def extract_names_elec_table(elec_table_path): # return mapping_name -def rename_mne_channels(mne_data, location_table_path): +def rename_mne_channels(mne_data, location_table_path: str): + """Rename MNE channels based on location table. + + Parameters + ---------- + mne_data + MNE data object. + location_table_path : str + Path to the location table CSV file. 
""" - """ - location_table = pd.read_csv(location_table_path) \ No newline at end of file diff --git a/LFPAnalysis/lfp_preprocess_utils.py b/LFPAnalysis/lfp_preprocess_utils.py index 7de930e..f060110 100644 --- a/LFPAnalysis/lfp_preprocess_utils.py +++ b/LFPAnalysis/lfp_preprocess_utils.py @@ -105,40 +105,32 @@ def baseline_avg_TFR(data, baseline, mode='zscore'): return baseline_corrected -def baseline_trialwise_TFR(data=None, baseline_mne=None, mode='zscore', include_epoch_in_baseline=True, - ev_axis=0, elec_axis=1, freq_axis=2, time_axis=3): - - """ - This function zscores the task data and the baseline data together. Then, it subtracts the mean of the z-scored - baseline from the task data. - - TODO: Make this flexible in case the number of events (baseline_mne) ! = (data) +def baseline_trialwise_TFR(data=None, baseline_mne=None, mode: str = 'zscore', include_epoch_in_baseline: bool = True, ev_axis: int = 0, elec_axis: int = 1, freq_axis: int = 2, time_axis: int = 3): + """Baseline correct trialwise TFR data. Parameters ---------- - data : np.ndarray, shape (n_trials, n_channels, n_freqs, n_times) - The original time-frequency data. - baseline_mne : mne.epochs.Epochs or np.ndarray, shape (n_trials, n_channels, n_freqs, n_times) - The baseline data. If `trialwise` is True, this should contain baseline data for each trial. + data : np.ndarray, optional + Time-frequency data array. + baseline_mne + Baseline MNE epochs or array. mode : str, optional - The type of baseline correction to apply. Valid options are 'mean', 'ratio', 'logratio', 'percent', 'zscore', and 'zlogratio'. Default is 'zscore'. - trialwise : bool, optional - Whether to baseline each trial separately. Default is True. - baseline_only : bool, optional - Whether to only use the baseline data for correction. Default is False. But depends on 'trialwise'. + Baseline correction mode. Default is 'zscore'. + include_epoch_in_baseline : bool, optional + Whether to include epoch in baseline. Default is True. 
ev_axis : int, optional - The axis corresponding to the event dimension. Default is 0. + Event axis. Default is 0. elec_axis : int, optional - The axis corresponding to the electrode dimension. Default is 1. + Electrode axis. Default is 1. freq_axis : int, optional - The axis corresponding to the frequency dimension. Default is 2. + Frequency axis. Default is 2. time_axis : int, optional - The axis corresponding to the time dimension. Default is 3. - + Time axis. Default is 3. + Returns ------- - baseline_corrected : np.ndarray, shape (n_trials, n_channels, n_freqs, n_times) - The baseline-corrected time-frequency data. + np.ndarray + Baseline-corrected data. """ # The reason I want baseline_mne to be an mne input was to specify these axes in a foolproof way for when @@ -355,38 +347,32 @@ def baseline_trialwise_TFR(data=None, baseline_mne=None, mode='zscore', include_ -def baseline_TFR_permute(data=None, baseline_mne=None, mode='zscore', num_samples=1000, - ev_axis=0, elec_axis=1, freq_axis=2, time_axis=3): - - """ - This function samples from all the baseline periods N times with replacement - and computes the mean and std for normalization of task-related activity. +def baseline_TFR_permute(data=None, baseline_mne=None, mode: str = 'zscore', num_samples: int = 1000, ev_axis: int = 0, elec_axis: int = 1, freq_axis: int = 2, time_axis: int = 3): + """Baseline correct TFR using permutation sampling. Parameters ---------- - data : np.ndarray, shape (n_trials, n_channels, n_freqs, n_times) - The original time-frequency data. - baseline_mne : mne.epochs.Epochs or np.ndarray, shape (n_trials, n_channels, n_freqs, n_times) - The baseline data. If `trialwise` is True, this should contain baseline data for each trial. + data : np.ndarray, optional + Time-frequency data array. + baseline_mne + Baseline MNE epochs or array. mode : str, optional - The type of baseline correction to apply. Valid options are 'mean', 'ratio', 'logratio', 'percent', 'zscore', and 'zlogratio'. 
Default is 'zscore'. - trialwise : bool, optional - Whether to baseline each trial separately. Default is True. - baseline_only : bool, optional - Whether to only use the baseline data for correction. Default is False. But depends on 'trialwise'. + Baseline correction mode. Default is 'zscore'. + num_samples : int, optional + Number of permutation samples. Default is 1000. ev_axis : int, optional - The axis corresponding to the event dimension. Default is 0. + Event axis. Default is 0. elec_axis : int, optional - The axis corresponding to the electrode dimension. Default is 1. + Electrode axis. Default is 1. freq_axis : int, optional - The axis corresponding to the frequency dimension. Default is 2. + Frequency axis. Default is 2. time_axis : int, optional - The axis corresponding to the time dimension. Default is 3. - + Time axis. Default is 3. + Returns ------- - baseline_corrected : np.ndarray, shape (n_trials, n_channels, n_freqs, n_times) - The baseline-corrected time-frequency data. + np.ndarray + Baseline-corrected data. """ @@ -455,49 +441,28 @@ def baseline_TFR_permute(data=None, baseline_mne=None, mode='zscore', num_sample -def wm_ref(mne_data=None, elec_path=None, bad_channels=None, unmatched_seeg=None, site='MSSM', average=False): - """ - Define a custom reference using the white matter electrodes. Originated here: https://doi.org/10.1016/j.neuroimage.2015.02.031 - - (as in https://www.science.org/doi/10.1126/sciadv.abf4198) +def wm_ref(mne_data=None, elec_path=None, bad_channels=None, unmatched_seeg=None, site: str = 'MSSM', average: bool = False): + """Create white matter reference. - Identify all white matter electrodes (based on the electrode names), and make sure they are not bad electrodes (based on the bad channels list). - - 1. iterate through each electrode, compute distance to all white matter electrodes - 2. find 3 closest wm electrodes, compute amplitude (rms) - 3. 
lowest amplitude electrode = wm reference - - Make sure it's the same hemisphere. - - TODO: implement average reference option, whereby the mean activity across all white matter electrodes is used as a reference [separate per hemi]... - see: https://www.sciencedirect.com/science/article/pii/S1053811922005559#bib0349 - - TODO: this is SLOW; any vectorization to speed it up or parallelization? - Parameters ---------- - mne_data : mne object - non-referenced data stored in an MNE object - elec_data : pandas df - dataframe containing the electrode localization information - bad_channels : list - bad channels - unmatched_seeg : list - list of channels that were not in the edf file - site : str - hospital where the recording took place - average : bool - should we construct an average white matter reference instead of a default? - + mne_data + MNE data object. + elec_path : str, optional + Path to electrode file. + bad_channels : list, optional + List of bad channels. + unmatched_seeg : list, optional + List of unmatched sEEG channels. + site : str, optional + Site name. Default is 'MSSM'. + average : bool, optional + Whether to use average reference. Default is False. + Returns ------- - anode_list : list - list of channels to subtract from - cathode_list : list - list of channels to subtract - drop_wm_channels : list - list of white matter channels which were not used for reference and now serve no purpose - + tuple + Tuple containing (anode_list, cathode_list, drop_wm_channels). """ elec_data = load_elec(elec_path, site=site) @@ -686,31 +651,26 @@ def wm_ref(mne_data=None, elec_path=None, bad_channels=None, unmatched_seeg=None return anode_list, cathode_list, drop_wm_channels -def laplacian_ref(mne_data, elec_path, bad_channels, unmatched_seeg=None, site=None): - """ - Return the cathode list and anode list for mne to use for laplacian referencing. - - In this case, the cathode is the average of the surrounding electrodes. 
If an edge electrode, it's just bipolar. - +def laplacian_ref(mne_data, elec_path: str, bad_channels: list, unmatched_seeg=None, site=None): + """Create laplacian reference. + Parameters ---------- - mne_data : MNE Raw object - MNE Raw object containing the EEG data + mne_data + MNE Raw object. elec_path : str - Path to the electrode localization file - bad_channels : list - List of bad channels - unmatched_seeg : list - List of channels that were not in the edf file - site : str - Hospital where the recording took place - + Path to electrode localization file. + bad_channels : list + List of bad channels. + unmatched_seeg : list, optional + List of unmatched sEEG channels. + site : str, optional + Site name. + Returns ------- - anode_list : list - List of channels to subtract from - cathode_list : list - List of channels to subtract + tuple + Tuple containing (anode_list, cathode_list). """ # TODO: for someone clever. Note that you have to bypass the mne reference script because that specific a single reference for each electrode. @@ -761,27 +721,24 @@ def laplacian_ref(mne_data, elec_path, bad_channels, unmatched_seeg=None, site=N # return anode_list, cathode_list -def bipolar_ref(elec_path, bad_channels, unmatched_seeg=None, site='MSSM'): - """ - Return the cathode list and anode list for mne to use for bipolar referencing. - +def bipolar_ref(elec_path: str, bad_channels: list, unmatched_seeg=None, site: str = 'MSSM'): + """Create bipolar reference. + Parameters ---------- - elec_data : pandas df - dataframe containing the electrode localization information - bad_channels : list - bad channels - unmatched_seeg : list - list of channels that were not in the edf file - site : str - hospital where the recording took place - + elec_path : str + Path to electrode file. + bad_channels : list + List of bad channels. + unmatched_seeg : list, optional + List of unmatched sEEG channels. + site : str, optional + Site name. Default is 'MSSM'. 
+ Returns ------- - anode_list : list - list of channels to subtract from - cathode_list : list - list of channels to subtract + tuple + Tuple containing (anode_list, cathode_list, drop_wm_channels, oob_channels). """ elec_data = load_elec(elec_path, site=site) @@ -922,29 +879,22 @@ def sort_strings(strings): return anode_list, cathode_list, drop_wm_channels, oob_channels -def match_elec_names(mne_names, loc_names, method='levenshtein'): - """ - The electrode names read out of the edf file do not always match those - in the pdf (used for localization). This could be error on the side of the tech who input the labels, - or on the side of MNE reading the labels in. Usually there's a mixup between lowercase 'l' and capital 'I', or between 'R' and 'P'... - - This function matches the MNE channel names to those used in the localization. - +def match_elec_names(mne_names: list, loc_names, method: str = 'levenshtein'): + """Match MNE channel names to localization names. + Parameters ---------- mne_names : list - list of electrode names in the recording data (mne) - loc_names : list - list of electrode names in the pdf, used for the localization - + List of MNE channel names. + loc_names + Localization names. + method : str, optional + Matching method. Default is 'levenshtein'. + Returns ------- - new_mne_names : list - revised mne names merged across sources - unmatched_names : list - names that do not match (mostly scalp EEG and misc) - unmatched_seeg : list - sEEG channels that do not match (should be rare) + tuple + Tuple containing (new_mne_names, unmatched_names, unmatched_seeg). 
""" # strip spaces from mne_names and put in lower case mne_names = [x.replace(" ", "").lower() for x in mne_names] @@ -1053,28 +1003,20 @@ def match_elec_names(mne_names, loc_names, method='levenshtein'): return new_mne_names, unmatched_names, unmatched_seeg -def detect_bad_elecs(mne_data, sEEG_mapping_dict): - """ - Find outlier channels using a combination of kurtosis, variance, and standard deviation. Also use the elec_data to find channels out of the brain - - https://www-sciencedirect-com.eresources.mssm.edu/science/article/pii/S016502701930278X - https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7472198/ - https://www.biorxiv.org/content/10.1101/2021.05.14.444176v2.full.pdf - +def detect_bad_elecs(mne_data, sEEG_mapping_dict: dict): + """Detect bad electrodes using statistical measures. - Plot these channels for manual verification. - Parameters ---------- - mne_data : mne object - mne data to check for bad channels - sEEG_mapping_dict : dict - dict of sEEG channels - + mne_data + MNE data object. + sEEG_mapping_dict : dict + Dictionary of sEEG channels. + Returns ------- - bad_channels : list - list of bad channels + list + List of bad channels. """ # Get the data @@ -1092,10 +1034,20 @@ def detect_bad_elecs(mne_data, sEEG_mapping_dict): # return bad_channels -def detect_misc_artifacts(mne_data, peak_thresh=6): - """ - This function detects artifacts (sharp transients) in the LFP signal automatically. +def detect_misc_artifacts(mne_data, peak_thresh: float = 6): + """Detect miscellaneous artifacts in LFP signal. + + Parameters + ---------- + mne_data + MNE data object. + peak_thresh : float, optional + Peak threshold. Default is 6. + Returns + ------- + dict + Dictionary of artifact times per channel. """ # 1. 
take the gradient of the signal: gradient_signal = np.gradient(mne_data.copy()._data, axis=-1) @@ -1115,37 +1067,24 @@ def detect_misc_artifacts(mne_data, peak_thresh=6): return artifact_sec_dict -def detect_IEDs(mne_data, peak_thresh=5, closeness_thresh=0.25, width_thresh=0.2): - """ - This function detects IEDs in the LFP signal automatically. Alternative to manual marking of each ied. - - From: https://academic.oup.com/brain/article/142/11/3502/5566384 - - Method 1: - 1. Bandpass filter in the [25-80] Hz band. - 2. Rectify. - 3. Find filtered envelope > 3. - 4. Eliminate events with peaks with unfiltered envelope < 3. - 5. Eliminate close IEDs (peaks within 250 ms). - 6. Eliminate IEDs that are not present on at least 4 electrodes. - (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6821283/) - +def detect_IEDs(mne_data, peak_thresh: float = 5, closeness_thresh: float = 0.25, width_thresh: float = 0.2): + """Detect interictal epileptiform discharges. + Parameters ---------- - mne_data : mne object - mne data to check for bad channels - peak_thresh : float - the peak threshold in amplitude - closeness_thresh : float - the closeness threshold in time - width_thresh : float - the width threshold for IEDs - + mne_data + MNE data object. + peak_thresh : float, optional + Peak threshold. Default is 5. + closeness_thresh : float, optional + Closeness threshold in seconds. Default is 0.25. + width_thresh : float, optional + Width threshold in seconds. Default is 0.2. + Returns ------- - IED_samps_dict : dict - dict with every IED index - + dict + Dictionary of IED times per channel. """ # What type of data is this? Continuous or epoched? @@ -1455,17 +1394,20 @@ def detect_IEDs(mne_data, peak_thresh=5, closeness_thresh=0.25, width_thresh=0.2 # Below are code that condense the Jupyter notebooks for pre-processing into individual functions. 
-def load_elec(elec_path=None, site='MSSM'): - """ - Load the electrode data from a CSV or Excel file, correct for small idiosyncracies, and return as a pandas dataframe. - +def load_elec(elec_path=None, site: str = 'MSSM'): + """Load electrode data from file. + Parameters ---------- - elec_path (str): Path to the electrode data file. The file should be in CSV or Excel format. - + elec_path : str, optional + Path to electrode file. + site : str, optional + Site name. Default is 'MSSM'. + Returns - ---------- - pandas.DataFrame: A dataframe containing the electrode data. The dataframe has columns for the electrode label, the x, y, and z coordinates in MNI space, and any other metadata associated with the electrodes. + ------- + pd.DataFrame + Electrode data DataFrame. """ # Load electrode data (should already be manually localized!) @@ -1575,31 +1517,22 @@ def load_elec(elec_path=None, site='MSSM'): return elec_data -def make_mne_scalp(load_path=None, overwrite=True, return_data=False): - """ - Make a mne object from the scalp data file, and save out the sync. - Following this step, you can indicate bad electrodes manually. - - This function requires users to input the file format of the raw data. - - Optionally, users can input the names of special channel types as these might be communicated manually rather than hardcoded into the raw data. - - (On that note, a better idea would be for someone to go back and edit the original data to include informative names...) +def make_mne_scalp(load_path=None, overwrite: bool = True, return_data: bool = False): + """Create MNE object from scalp data. Parameters ---------- - load_path : str - path to the neural data - format : str - how was this data collected? options: ['edf', 'nlx] - overwrite: bool - whether to overwrite existing data for this person if it exists - return_data: bool - whether to actually return the data or just save it in the directory + load_path : str, optional + Path to neural data. 
+ overwrite : bool, optional + Whether to overwrite existing data. Default is True. + return_data : bool, optional + Whether to return data. Default is False. + Returns ------- - mne_data : mne object - mne object + mne object or None + MNE object if return_data is True. """ edf_file = glob(f'{load_path}/*.edf')[0] @@ -1631,53 +1564,50 @@ def is_scalp_eeg_channel(name): return mne_data if return_data else mne_data.save(f'{load_path}/scalp_raw.fif', overwrite=overwrite) -def make_mne(load_path=None, elec_path=None, format='edf', site='MSSM', resample_sr = 500, overwrite=True, return_data=False, -include_micros=False, eeg_names=None, resp_names=None, ekg_names=None, sync_name=None, sync_type='photodiode', seeg_names=None, drop_names=None, -seeg_only=True, check_bad=False): - """ - Make a mne object from the data and electrode files, and save out the sync. - Following this step, you can indicate bad electrodes manually. - - This function requires users to input the file format of the raw data, and the location the data was recorded for site-specific steps. - - Optionally, users can input the names of special channel types as these might be communicated manually rather than hardcoded into the raw data. - - (On that note, a better idea would be for someone to go back and edit the original data to include informative names...) +def make_mne(load_path=None, elec_path=None, format: str = 'edf', site: str = 'MSSM', resample_sr: int = 500, overwrite: bool = True, return_data: bool = False, include_micros: bool = False, eeg_names=None, resp_names=None, ekg_names=None, sync_name=None, sync_type: str = 'photodiode', seeg_names=None, drop_names=None, seeg_only: bool = True, check_bad: bool = False): + """Create MNE object from data and electrode files. Parameters ---------- - load_path : str - path to the neural data - elec_data : pandas df - dataframe with all the electrode localization information - format : str - how was this data collected? 
options: ['edf', 'nlx] - site: str - where was the data collected? options: ['UI', 'MSSM']. - TODO: add site specificity for UC Davis - overwrite: bool - whether to overwrite existing data for this person if it exists - return_data: bool - whether to actually return the data or just save it in the directory - include_micros : bool - whether to include the microwire LFP in the LFP data object or not - eeg_names : list - list of channel names that pertain to scalp EEG in case the hardcoded options don't work - resp_names : list - list of channel names that pertain to respiration in case the hardcoded options don't work - ekg_names : list - list of channel names that pertain to the EKG in case the hardcoded options don't work - sync_name : str - provide the sync name in case the hardcoded options don't work - drop_names: str - provide the drop names in case you know certain channels that should be thrown out asap - seeg_only: bool (default=True) - indicate whether you want non seeg channels included - + load_path : str, optional + Path to neural data. + elec_path : str, optional + Path to electrode file. + format : str, optional + Data format. Default is 'edf'. + site : str, optional + Site name. Default is 'MSSM'. + resample_sr : int, optional + Resampling rate. Default is 500. + overwrite : bool, optional + Whether to overwrite. Default is True. + return_data : bool, optional + Whether to return data. Default is False. + include_micros : bool, optional + Whether to include microwires. Default is False. + eeg_names : list, optional + List of EEG channel names. + resp_names : list, optional + List of respiration channel names. + ekg_names : list, optional + List of EKG channel names. + sync_name : str, optional + Sync channel name. + sync_type : str, optional + Sync type. Default is 'photodiode'. + seeg_names : list, optional + List of sEEG channel names. + drop_names : list, optional + List of channels to drop. 
+ seeg_only : bool, optional + Whether to include only sEEG. Default is True. + check_bad : bool, optional + Whether to check for bad channels. Default is False. + Returns ------- - mne_data : mne object - mne object + mne object or None + MNE object if return_data is True. """ if not sync_name: @@ -1938,25 +1868,24 @@ def make_mne(load_path=None, elec_path=None, format='edf', site='MSSM', resample return mne_data -def ref_mne(mne_data=None, elec_path=None, method='wm', site='MSSM'): - """ - Following this step, you can indicate IEDs manually. - +def ref_mne(mne_data=None, elec_path=None, method: str = 'wm', site: str = 'MSSM'): + """Re-reference MNE data. + Parameters ---------- - mne_data : mne object - mne object - elec_data : pandas df - dataframe with all the electrode localization information - method : str - how should we reference the data ['wm', 'bipolar'] - site : str - where was this data collected? Options: ['MSSM', 'UI', 'Davis'] - + mne_data + MNE data object. + elec_path : str, optional + Path to electrode file. + method : str, optional + Reference method. Default is 'wm'. + site : str, optional + Site name. Default is 'MSSM'. + Returns ------- - mne_data_reref : mne object - mne object with re-referenced data + mne object + Re-referenced MNE object. 
""" elec_data = load_elec(elec_path, site=site) @@ -1998,17 +1927,22 @@ def ref_mne(mne_data=None, elec_path=None, method='wm', site='MSSM'): return mne_data_reref -def _bin_channelwise_times_into_behav_evs(channel_dict_seconds, ev_starts, ev_ends): - """ - feed in a dictionary of format {['channel_name']: [time1,...n]} - timepoints should be in seconds - every key corresponds to a channel in your mne object - - returns a dataframe of these timepoints binned relative to your behavioral epoch of interest - useful for detecting artifacts and IEDs in the signal prior to epoching and carrying over those - detections to the epoched data +def _bin_channelwise_times_into_behav_evs(channel_dict_seconds: dict, ev_starts: list, ev_ends: list): + """Bin channelwise times into behavioral events. - ev_starts and ev_ends should be the start and end of each epoch in seconds + Parameters + ---------- + channel_dict_seconds : dict + Dictionary of channel times in seconds. + ev_starts : list + Event start times. + ev_ends : list + Event end times. + + Returns + ------- + pd.DataFrame + DataFrame with binned timestamps. """ allts = {f'{x}': np.nan for x in channel_dict_seconds.keys()} for key in channel_dict_seconds.keys(): @@ -2038,55 +1972,40 @@ def _bin_channelwise_times_into_behav_evs(channel_dict_seconds, ev_starts, ev_en event_metadata.where(pd.notna(event_metadata), None) return event_metadata -def make_epochs(load_path=None, slope=None, offset=None, behav_name=None, behav_times=None, -ev_start_s=0, ev_end_s=1.5, buf_s=1, downsamp_factor=None, IED_args=None, baseline=None, detrend=None): - - # elec_path=None, - """ - - TODO: allow for a dict of pre and post times so they can vary across evs +def make_epochs(load_path=None, slope=None, offset=None, behav_name=None, behav_times=None, ev_start_s: float = 0, ev_end_s: float = 1.5, buf_s: float = 1, downsamp_factor=None, IED_args=None, baseline=None, detrend=None): + """Create epochs from continuous data. 
- behav_times: dict with format {'event_name': np.array([times])} - baseline_times: dict with format {'event_name': np.array([times])} - IED_args: dict with format {'peak_thresh':5, 'closeness_thresh':0.5, 'width_thresh':0.2} - - elec_data : pandas df - dataframe with all the electrode localization information - Parameters ---------- - load_path : str - path to the re-referenced neural data - slope : float - slope used for syncing behavioral and neural data - offset : float - offset used for syncing behavioral and neural data - behav_name : str - what event are we epoching to? - behav_times : dict - format - baseline_times : dict - format - ev_start_s: - - ev_end_s: - - method : str - how should we reference the data ['wm', 'bipolar'] - site : str - where was this data collected? Options: ['MSSM', 'UI', 'Davis'] - - buf_s : float - time to add as buffer in epochs - downsamp_factor : float - factor by which to downsample the data - IED_args: dict - format {'peak_thresh':5, 'closeness_thresh':0.5, 'width_thresh':0.2} - + load_path : str, optional + Path to re-referenced data. + slope : float, optional + Sync slope. + offset : float, optional + Sync offset. + behav_name : str, optional + Behavioral event name. + behav_times : dict, optional + Behavioral times dictionary. + ev_start_s : float, optional + Event start time. Default is 0. + ev_end_s : float, optional + Event end time. Default is 1.5. + buf_s : float, optional + Buffer time. Default is 1. + downsamp_factor : float, optional + Downsampling factor. + IED_args : dict, optional + IED detection arguments. + baseline : dict, optional + Baseline times dictionary. + detrend : int, optional + Detrend order. + Returns ------- - ev_epochs : mne object - mne Epoch object with re-referenced data + mne.Epochs + Epochs object. 
""" # Load the data @@ -2263,10 +2182,22 @@ def make_epochs(load_path=None, slope=None, offset=None, behav_name=None, behav_ # return revised_annot -def rename_elec_df_reref(reref_labels, elec_path, site='MSSM'): - - """ - Sometimes we want to filter and relabel our electrode dataframe based on the renamed channels from the re-referenced data +def rename_elec_df_reref(reref_labels: list, elec_path: str, site: str = 'MSSM'): + """Rename electrode DataFrame after re-referencing. + + Parameters + ---------- + reref_labels : list + List of re-referenced labels. + elec_path : str + Path to electrode file. + site : str, optional + Site name. Default is 'MSSM'. + + Returns + ------- + pd.DataFrame + Renamed electrode DataFrame. """ elec_data = load_elec(elec_path, site=site) @@ -2318,37 +2249,36 @@ def rename_elec_df_reref(reref_labels, elec_path, site='MSSM'): return elec_df # -def compute_and_baseline_tfr(baseline_event, task_events, freqs, n_cycles, load_path, save_path, - IED_artifact_thresh=True, uncaptured_z_thresh=True, output='save', tfr_method='morlet'): +def compute_and_baseline_tfr(baseline_event: dict, task_events: dict, freqs: np.ndarray, n_cycles: float, load_path: str, save_path: str, IED_artifact_thresh: bool = True, uncaptured_z_thresh: bool = True, output: str = 'save', tfr_method: str = 'morlet'): + """Compute and baseline TFR for events. - """ - This function computes the TFRs for the baseline and task events of interest, and baselines the task events of interest - Parameters ---------- baseline_event : dict - Dictionary with the key being the name of the baseline event, and the value being a list of the start and end time of the baseline event + Baseline event dictionary. task_events : dict - Dictionary with the key being the name of the task event, and the value being a list of the start and end time of the task event - tfr_method : str - The method to compute the TFR. 
Options: ['morlet', 'multitaper'] - freqs : array - The frequencies of interest for the TFR + Task events dictionary. + freqs : np.ndarray + Frequency array. n_cycles : float - The number of cycles for the Morlet wavelet + Number of cycles. load_path : str - The path to the directory where the epochs are stored + Path to epochs. save_path : str - The path to the directory where the TFRs will be saved - IED_artifact_thresh : bool - If True, will remove 100 ms before and after IEDs and artifacts from the TFRs - uncaptured_z_thresh : bool - If True, will iteratively remove absurd z-scores from the TFRs - output : str - If 'save', will save the TFRs to the save_path - If 'return', will return the TFRs - If 'both', will save and return the TFRs + Path to save TFRs. + IED_artifact_thresh : bool, optional + Whether to remove IEDs/artifacts. Default is True. + uncaptured_z_thresh : bool, optional + Whether to remove extreme z-scores. Default is True. + output : str, optional + Output mode. Default is 'save'. + tfr_method : str, optional + TFR method. Default is 'morlet'. + Returns + ------- + mne.time_frequency.EpochsTFR or None + TFR object if output is 'return' or 'both'. """ diff --git a/LFPAnalysis/nlx_utils.py b/LFPAnalysis/nlx_utils.py index a8786e4..a03eaa5 100644 --- a/LFPAnalysis/nlx_utils.py +++ b/LFPAnalysis/nlx_utils.py @@ -51,6 +51,18 @@ def read_header(fid): + """Read raw header data from file object. + + Parameters + ---------- + fid + File object. + + Returns + ------- + bytes + Raw header data. + """ # Read the raw header data (16 kb) from the file object fid. Restores the position in the file object after reading. pos = fid.tell() fid.seek(0) @@ -61,6 +73,18 @@ def read_header(fid): def parse_header(raw_hdr): + """Parse header string into dictionary. + + Parameters + ---------- + raw_hdr : bytes + Raw header bytes. + + Returns + ------- + dict + Parsed header dictionary. 
+ """ # Parse the header string into a dictionary of name value pairs hdr = dict() @@ -101,7 +125,25 @@ def parse_header(raw_hdr): return hdr -def read_records(fid, record_dtype, record_skip=0, count=None): +def read_records(fid, record_dtype, record_skip: int = 0, count=None): + """Read records from file object. + + Parameters + ---------- + fid + File object. + record_dtype + NumPy dtype for records. + record_skip : int, optional + Number of records to skip. Default is 0. + count : int, optional + Number of records to read. Default is None (all). + + Returns + ------- + np.ndarray + Array of records. + """ # Read count records (default all) from the file object fid skipping the first record_skip records. Restores the # position of the file object after reading. if count is None: @@ -116,7 +158,21 @@ def read_records(fid, record_dtype, record_skip=0, count=None): return rec -def estimate_record_count(file_path, record_dtype): +def estimate_record_count(file_path: str, record_dtype): + """Estimate number of records from file size. + + Parameters + ---------- + file_path : str + Path to file. + record_dtype + NumPy dtype for records. + + Returns + ------- + float + Estimated number of records. + """ # Estimate the number of records from the file size file_size = os.path.getsize(file_path) file_size -= HEADER_LENGTH @@ -127,7 +183,19 @@ def estimate_record_count(file_path, record_dtype): return file_size / record_dtype.itemsize -def parse_neuralynx_time_string(time_string): +def parse_neuralynx_time_string(time_string: str): + """Parse datetime from Neuralynx time string. + + Parameters + ---------- + time_string : str + Time string from Neuralynx header. + + Returns + ------- + datetime.datetime or None + Parsed datetime object or None if parsing fails. 
+ """ # Parse a datetime object from the idiosyncratic time string in Neuralynx file headers try: tmp_date = [int(x) for x in time_string.split()[4].split('/')] @@ -143,6 +211,18 @@ def parse_neuralynx_time_string(time_string): def check_ncs_records(records): + """Check that all records are similar. + + Parameters + ---------- + records : np.ndarray + Array of NCS records. + + Returns + ------- + bool + True if all records are similar, False otherwise. + """ # Check that all the records in the array are "similar" (have the same sampling frequency etc. dt = np.diff(records['TimeStamp']) dt = np.abs(dt - dt[0]) @@ -162,7 +242,25 @@ def check_ncs_records(records): return True -def load_ncs(file_path, load_time=True, rescale_data=True, signal_scaling=VOLT_SCALING): +def load_ncs(file_path: str, load_time: bool = True, rescale_data: bool = True, signal_scaling=VOLT_SCALING): + """Load Neuralynx .ncs continuous acquisition file. + + Parameters + ---------- + file_path : str + Path to .ncs file. + load_time : bool, optional + Whether to load time points. Default is True. + rescale_data : bool, optional + Whether to rescale data. Default is True. + signal_scaling + Signal scaling tuple. Default is VOLT_SCALING. + + Returns + ------- + dict + Dictionary containing file data and metadata. + """ # Load the given file as a Neuralynx .ncs continuous acquisition file and extract the contents file_path = os.path.abspath(file_path) with open(file_path, 'rb') as fid: @@ -204,7 +302,19 @@ def load_ncs(file_path, load_time=True, rescale_data=True, signal_scaling=VOLT_S return ncs -def load_nev(file_path): +def load_nev(file_path: str): + """Load Neuralynx .nev event file. + + Parameters + ---------- + file_path : str + Path to .nev file. + + Returns + ------- + dict + Dictionary containing event data and metadata. 
+ """ # Load the given file as a Neuralynx .nev event file and extract the contents file_path = os.path.abspath(file_path) with open(file_path, 'rb') as fid: @@ -228,9 +338,30 @@ def load_nev(file_path): return nev -def parse_subject_nlx_data(ncs_files, eeg_names=None, resp_names=None, ekg_names=None, seeg_names=None, drop_names=None, include_micros=False): - """ - Iterate through a list of ncs files and extract the relevant data: signal, sr, channel type and channel name +def parse_subject_nlx_data(ncs_files, eeg_names=None, resp_names=None, ekg_names=None, seeg_names=None, drop_names=None, include_micros: bool = False): + """Parse subject NLX data from NCS files. + + Parameters + ---------- + ncs_files : list + List of NCS file paths. + eeg_names : list, optional + List of EEG channel names. + resp_names : list, optional + List of respiratory channel names. + ekg_names : list, optional + List of EKG channel names. + seeg_names : list, optional + List of sEEG channel names. + drop_names : list, optional + List of channel names to drop. + include_micros : bool, optional + Whether to include microwire data. Default is False. + + Returns + ------- + tuple + Tuple containing (signals, srs, ch_name, ch_type). """ signals = [] @@ -292,11 +423,13 @@ def parse_subject_nlx_data(ncs_files, eeg_names=None, resp_names=None, ekg_names return signals, srs, ch_name, ch_type -def merge_multiple_ncs_files(ncs_files): - """ - TODO - - Merge multiple ncs files. Usually done if recording was paused for whatever reason. +def merge_multiple_ncs_files(ncs_files): + """Merge multiple NCS files. + + Parameters + ---------- + ncs_files : list + List of NCS file paths to merge. 
""" merged_ncs_dict = {} diff --git a/LFPAnalysis/oscillation_utils.py b/LFPAnalysis/oscillation_utils.py index 1d08964..e73b217 100644 --- a/LFPAnalysis/oscillation_utils.py +++ b/LFPAnalysis/oscillation_utils.py @@ -11,6 +11,8 @@ from mne.filter import next_fast_len from IPython.display import clear_output from joblib import delayed, Parallel +import os +from typing import Union, Tuple, List, Optional, Dict, Any, Generator import scipy.special import warnings @@ -21,23 +23,39 @@ # Helper functions -def find_nearest_value(array, value): - """Find nearest value and index of float in array - Parameters: - array : Array of values [1d array] - value : Value of interest [float] - Returns: - array[idx] : Nearest value [1d float] - idx : Nearest index [1d float] +def find_nearest_value(array: np.ndarray, value: float) -> Tuple[float, int]: + """Find nearest value and index in array. + + Parameters + ---------- + array : np.ndarray + Array of values. + value : float + Value of interest. + + Returns + ------- + tuple + Tuple containing (nearest_value, index). """ array = np.asarray(array) idx = (np.abs(array - value)).argmin() return array[idx], idx -def getTimeFromFTmat(fname, var_name='data'): - """ - Get original timing from FieldTrip structure - Solution based on https://github.com/mne-tools/mne-python/issues/2476 +def getTimeFromFTmat(fname: str, var_name: str = 'data') -> np.ndarray: + """Get original timing from FieldTrip structure. + + Parameters + ---------- + fname : str + Path to MATLAB file. + var_name : str, optional + Variable name. Default is 'data'. + + Returns + ------- + np.ndarray + Time array. """ # load Matlab/Fieldtrip data mat = sio.loadmat(fname, squeeze_me=True, struct_as_record=False) @@ -54,6 +72,13 @@ def getTimeFromFTmat(fname, var_name='data'): return time def get_project_root() -> Path: + """Get project root path. + + Returns + ------- + Path + Project root path. 
+ """ return Path(__file__) # def swap_time_blocks(data, random_state=None): @@ -94,22 +119,26 @@ def get_project_root() -> Path: # return np.concatenate(surr, axis=-1) -def make_surrogate_data(data, method='swap_epochs', n_shuffles=1000, rng_seed=42, return_generator=False): - """Create surrogate data for a null hypothesis of connectivity. +def make_surrogate_data(data: mne.Epochs, method: str = 'swap_epochs', n_shuffles: int = 1000, rng_seed: int = 42, return_generator: bool = False) -> Union[List[mne.Epochs], Generator[mne.Epochs, None, None]]: + """Create surrogate data for connectivity null hypothesis. Parameters ---------- data : mne.Epochs - The data to be shuffled. - method : str - The method to use for shuffling. Options are 'swap_time_blocks' and - 'swap_epochs'. - n_shuffles : int - The number of shuffles to perform. - rng_seed : int - The random seed to use for the shuffling. - return_generator : bool - If True, returns a generator object. If False, returns a list of surrogates. + MNE Epochs object. + method : str, optional + Shuffling method. Default is 'swap_epochs'. + n_shuffles : int, optional + Number of shuffles. Default is 1000. + rng_seed : int, optional + Random seed. Default is 42. + return_generator : bool, optional + Whether to return generator. Default is False. + + Returns + ------- + list or generator + Surrogate data. """ if method =='swap_time_blocks': surrogate = _shuffle_within_epochs(data, n_shuffles, rng_seed) @@ -119,9 +148,23 @@ def make_surrogate_data(data, method='swap_epochs', n_shuffles=1000, rng_seed=42 surrogate = [shuffle for shuffle in surrogate] return surrogate -def _shuffle_epochs(data, n_shuffles, rng_seed): - """Shuffle epochs in data. - This function shuffles the order of the epochs (dim 0) separately for each channel (dim 1)""" +def _shuffle_epochs(data: mne.Epochs, n_shuffles: int, rng_seed: int) -> Generator[mne.Epochs, None, None]: + """Shuffle epochs in data. 
+ + Parameters + ---------- + data : mne.Epochs + MNE Epochs object. + n_shuffles : int + Number of shuffles. + rng_seed : int + Random seed. + + Yields + ------ + mne.Epochs + Shuffled epochs. + """ data_arr = data.get_data(copy=True) rng = np.random.default_rng(rng_seed) for _ in range(n_shuffles): @@ -134,10 +177,22 @@ def _shuffle_epochs(data, n_shuffles, rng_seed): new_epochs.set_annotations(data.annotations) yield new_epochs -def _shuffle_within_epochs(data, n_shuffles, rng_seed): - """Shuffle within epochs in data. - This function cuts the timeseries at a random time point. Then, both time - blocks are swapped. +def _shuffle_within_epochs(data: mne.Epochs, n_shuffles: int, rng_seed: int) -> Generator[mne.Epochs, None, None]: + """Shuffle within epochs by swapping time blocks. + + Parameters + ---------- + data : mne.Epochs + MNE Epochs object. + n_shuffles : int + Number of shuffles. + rng_seed : int + Random seed. + + Yields + ------ + mne.Epochs + Shuffled epochs. """ data_arr = data.get_data(copy=True) rng = np.random.default_rng(rng_seed) @@ -153,18 +208,43 @@ def _shuffle_within_epochs(data, n_shuffles, rng_seed): new_epochs.set_annotations(data.annotations) yield new_epochs -def _swap_time_blocks(data, cut_at): - """Swap time blocks in data at a given cutpoint.""" +def _swap_time_blocks(data: np.ndarray, cut_at: int) -> np.ndarray: + """Swap time blocks at cutpoint. + + Parameters + ---------- + data : np.ndarray + Data array. + cut_at : int + Cut point index. + + Returns + ------- + np.ndarray + Swapped data. + """ surr = np.array_split(data, [cut_at], axis=-1) surr.reverse() return np.concatenate(surr, axis=-1) -def make_seed_target_df(elec_df, epochs, source_roi, target_roi): - - """ - Create arrays of indices for mapping electrodes for connectivity analyses. 
Use the Epoch - ch_names list itself to find the index of the electrode within the mne object +def make_seed_target_df(elec_df: pd.DataFrame, epochs: mne.Epochs, source_roi: str, target_roi: str) -> pd.DataFrame: + """Create seed-target DataFrame for connectivity. + Parameters + ---------- + elec_df : pd.DataFrame + Electrode DataFrame. + epochs : mne.Epochs + MNE Epochs object. + source_roi : str + Source ROI name. + target_roi : str + Target ROI name. + + Returns + ------- + pd.DataFrame + Seed-target DataFrame. """ seed_target_df = pd.DataFrame(columns=['seed', 'target'], index=['l', 'r']) @@ -202,12 +282,18 @@ def make_seed_target_df(elec_df, epochs, source_roi, target_roi): Gaussian copula mutual information estimation """ -def ctransform(x): - """Copula transformation (empirical CDF) - - cx = ctransform(x) returns the empirical CDF value along the first - axis of x. Data is ranked and scaled within [0 1] (open interval). - +def ctransform(x: np.ndarray) -> np.ndarray: + """Copula transformation (empirical CDF). + + Parameters + ---------- + x : np.ndarray + Input data. + + Returns + ------- + np.ndarray + Empirical CDF values. """ xi = np.argsort(np.atleast_2d(x)) @@ -216,26 +302,38 @@ def ctransform(x): return cx -def copnorm(x): - """Copula normalization +def copnorm(x: np.ndarray) -> np.ndarray: + """Copula normalization. - cx = copnorm(x) returns standard normal samples with the same empirical - CDF value as the input. Operates along the last axis. - + Parameters + ---------- + x : np.ndarray + Input data. + + Returns + ------- + np.ndarray + Standard normal samples. """ #cx = sp.stats.norm.ppf(ctransform(x)) cx = sp.special.ndtri(ctransform(x)) return cx -def ent_g(x, biascorrect=True): - """Entropy of a Gaussian variable in bits - - H = ent_g(x) returns the entropy of a (possibly - multidimensional) Gaussian variable x with bias correction. - Columns of x correspond to samples, rows to dimensions/variables. 
- (Samples last axis) - +def ent_g(x: np.ndarray, biascorrect: bool = True) -> float: + """Compute entropy of Gaussian variable. + + Parameters + ---------- + x : np.ndarray + Input data. + biascorrect : bool, optional + Whether to apply bias correction. Default is True. + + Returns + ------- + float + Entropy in bits. """ x = np.atleast_2d(x) if x.ndim > 2: @@ -262,19 +360,24 @@ def ent_g(x, biascorrect=True): return HX / ln2 -def mi_gg(x, y, biascorrect=True, demeaned=False): - """Mutual information (MI) between two Gaussian variables in bits - - I = mi_gg(x,y) returns the MI between two (possibly multidimensional) - Gassian variables, x and y, with bias correction. - If x and/or y are multivariate columns must correspond to samples, rows - to dimensions/variables. (Samples last axis) - - biascorrect : true / false option (default true) which specifies whether - bias correction should be applied to the esimtated MI. - demeaned : false / true option (default false) which specifies whether th - input data already has zero mean (true if it has been copula-normalized) - +def mi_gg(x: np.ndarray, y: np.ndarray, biascorrect: bool = True, demeaned: bool = False) -> float: + """Compute mutual information between Gaussian variables. + + Parameters + ---------- + x : np.ndarray + First variable. + y : np.ndarray + Second variable. + biascorrect : bool, optional + Whether to apply bias correction. Default is True. + demeaned : bool, optional + Whether data is already demeaned. Default is False. + + Returns + ------- + float + Mutual information in bits. """ x = np.atleast_2d(x) @@ -321,15 +424,20 @@ def mi_gg(x, y, biascorrect=True, demeaned=False): return I -def gcmi_cc(x,y): - """Gaussian-Copula Mutual Information between two continuous variables. - - I = gcmi_cc(x,y) returns the MI between two (possibly multidimensional) - continuous variables, x and y, estimated via a Gaussian copula. 
- If x and/or y are multivariate columns must correspond to samples, rows - to dimensions/variables. (Samples first axis) - This provides a lower bound to the true MI value. - +def gcmi_cc(x: np.ndarray, y: np.ndarray) -> float: + """Compute Gaussian-copula mutual information. + + Parameters + ---------- + x : np.ndarray + First variable. + y : np.ndarray + Second variable. + + Returns + ------- + float + Mutual information in bits. """ x = np.atleast_2d(x) @@ -361,24 +469,26 @@ def gcmi_cc(x,y): return I -def mi_model_gd(x, y, Ym, biascorrect=True, demeaned=False): - """Mutual information (MI) between a Gaussian and a discrete variable in bits - based on ANOVA style model comparison. - - I = mi_model_gd(x,y,Ym) returns the MI between the (possibly multidimensional) - Gaussian variable x and the discrete variable y. - For 1D x this is a lower bound to the mutual information. - Columns of x correspond to samples, rows to dimensions/variables. - (Samples last axis) - y should contain integer values in the range [0 Ym-1] (inclusive). - - biascorrect : true / false option (default true) which specifies whether - bias correction should be applied to the esimtated MI. - demeaned : false / true option (default false) which specifies whether the - input data already has zero mean (true if it has been copula-normalized) - - See also: mi_mixture_gd - +def mi_model_gd(x: np.ndarray, y: np.ndarray, Ym: int, biascorrect: bool = True, demeaned: bool = False) -> float: + """Compute MI between Gaussian and discrete variable. + + Parameters + ---------- + x : np.ndarray + Gaussian variable. + y : np.ndarray + Discrete variable. + Ym : int + Number of discrete values. + biascorrect : bool, optional + Whether to apply bias correction. Default is True. + demeaned : bool, optional + Whether data is demeaned. Default is False. + + Returns + ------- + float + Mutual information in bits. 
""" x = np.atleast_2d(x) @@ -442,19 +552,22 @@ def mi_model_gd(x, y, Ym, biascorrect=True, demeaned=False): return I -def gcmi_model_cd(x,y,Ym): - """Gaussian-Copula Mutual Information between a continuous and a discrete variable - based on ANOVA style model comparison. - - I = gcmi_model_cd(x,y,Ym) returns the MI between the (possibly multidimensional) - continuous variable x and the discrete variable y. - For 1D x this is a lower bound to the mutual information. - Columns of x correspond to samples, rows to dimensions/variables. - (Samples last axis) - y should contain integer values in the range [0 Ym-1] (inclusive). - - See also: gcmi_mixture_cd - +def gcmi_model_cd(x: np.ndarray, y: np.ndarray, Ym: int) -> float: + """Compute Gaussian-copula MI between continuous and discrete variable. + + Parameters + ---------- + x : np.ndarray + Continuous variable. + y : np.ndarray + Discrete variable. + Ym : int + Number of discrete values. + + Returns + ------- + float + Mutual information in bits. """ x = np.atleast_2d(x) @@ -491,18 +604,22 @@ def gcmi_model_cd(x,y,Ym): return I -def mi_mixture_gd(x, y, Ym): - """Mutual information (MI) between a Gaussian and a discrete variable in bits - calculated from a Gaussian mixture. - - I = mi_mixture_gd(x,y,Ym) returns the MI between the (possibly multidimensional) - Gaussian variable x and the discrete variable y. - Columns of x correspond to samples, rows to dimensions/variables. - (Samples last axis) - y should contain integer values in the range [0 Ym-1] (inclusive). - - See also: mi_model_gd - +def mi_mixture_gd(x: np.ndarray, y: np.ndarray, Ym: int) -> float: + """Compute MI using Gaussian mixture model. + + Parameters + ---------- + x : np.ndarray + Gaussian variable. + y : np.ndarray + Discrete variable. + Ym : int + Number of discrete values. + + Returns + ------- + float + Mutual information in bits. 
""" x = np.atleast_2d(x) @@ -586,28 +703,42 @@ def mi_mixture_gd(x, y, Ym): I = (Hmix - np.sum(w*Hcond)) / np.log(2.0) return I -def _norm_innerv(x, chC): - """ normalised innervations """ +def _norm_innerv(x: np.ndarray, chC: np.ndarray) -> np.ndarray: + """Compute normalized inner products. + + Parameters + ---------- + x : np.ndarray + Input data. + chC : np.ndarray + Cholesky decomposition. + + Returns + ------- + np.ndarray + Normalized inner products. + """ m = np.linalg.solve(chC,x) w = -0.5 * (m * m).sum(axis=0) return w -def gcmi_mixture_cd(x,y,Ym): - """Gaussian-Copula Mutual Information between a continuous and a discrete variable - calculated from a Gaussian mixture. - - The Gaussian mixture is fit using robust measures of location (median) and scale - (median absolute deviation) for each class. - I = gcmi_mixture_cd(x,y,Ym) returns the MI between the (possibly multidimensional) - continuous variable x and the discrete variable y. - For 1D x this is a lower bound to the mutual information. - Columns of x correspond to samples, rows to dimensions/variables. - (Samples last axis) - y should contain integer values in the range [0 Ym-1] (inclusive). - - See also: gcmi_model_cd - +def gcmi_mixture_cd(x: np.ndarray, y: np.ndarray, Ym: int) -> float: + """Compute Gaussian-copula MI using Gaussian mixture. + + Parameters + ---------- + x : np.ndarray + Continuous variable. + y : np.ndarray + Discrete variable. + Ym : int + Number of discrete values. + + Returns + ------- + float + Mutual information in bits. """ x = np.atleast_2d(x) @@ -663,20 +794,26 @@ def gcmi_mixture_cd(x,y,Ym): return I -def cmi_ggg(x, y, z, biascorrect=True, demeaned=False): - """Conditional Mutual information (CMI) between two Gaussian variables - conditioned on a third - - I = cmi_ggg(x,y,z) returns the CMI between two (possibly multidimensional) - Gassian variables, x and y, conditioned on a third, z, with bias correction. 
- If x / y / z are multivariate columns must correspond to samples, rows - to dimensions/variables. (Samples last axis) - - biascorrect : true / false option (default true) which specifies whether - bias correction should be applied to the esimtated MI. - demeaned : false / true option (default false) which specifies whether the - input data already has zero mean (true if it has been copula-normalized) - +def cmi_ggg(x: np.ndarray, y: np.ndarray, z: np.ndarray, biascorrect: bool = True, demeaned: bool = False) -> float: + """Compute conditional mutual information. + + Parameters + ---------- + x : np.ndarray + First variable. + y : np.ndarray + Second variable. + z : np.ndarray + Conditioning variable. + biascorrect : bool, optional + Whether to apply bias correction. Default is True. + demeaned : bool, optional + Whether data is demeaned. Default is False. + + Returns + ------- + float + Conditional mutual information in bits. """ x = np.atleast_2d(x) @@ -736,15 +873,22 @@ def cmi_ggg(x, y, z, biascorrect=True, demeaned=False): return I -def gccmi_ccc(x,y,z): - """Gaussian-Copula CMI between three continuous variables. - - I = gccmi_ccc(x,y,z) returns the CMI between two (possibly multidimensional) - continuous variables, x and y, conditioned on a third, z, estimated via a - Gaussian copula. - If x and/or y are multivariate columns must correspond to samples, rows - to dimensions/variables. (Samples first axis) - +def gccmi_ccc(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> float: + """Compute Gaussian-copula conditional mutual information. + + Parameters + ---------- + x : np.ndarray + First variable. + y : np.ndarray + Second variable. + z : np.ndarray + Conditioning variable. + + Returns + ------- + float + Conditional mutual information in bits. """ x = np.atleast_2d(x) @@ -784,16 +928,24 @@ def gccmi_ccc(x,y,z): return I -def gccmi_ccd(x,y,z,Zm): - """Gaussian-Copula CMI between 2 continuous variables conditioned on a discrete variable. 
- - I = gccmi_ccd(x,y,z,Zm) returns the CMI between two (possibly multidimensional) - continuous variables, x and y, conditioned on a third discrete variable z, estimated - via a Gaussian copula. - If x and/or y are multivariate columns must correspond to samples, rows - to dimensions/variables. (Samples first axis) - z should contain integer values in the range [0 Zm-1] (inclusive). - +def gccmi_ccd(x: np.ndarray, y: np.ndarray, z: np.ndarray, Zm: int) -> Tuple[float, float]: + """Compute Gaussian-copula CMI conditioned on discrete variable. + + Parameters + ---------- + x : np.ndarray + First continuous variable. + y : np.ndarray + Second continuous variable. + z : np.ndarray + Discrete conditioning variable. + Zm : int + Number of discrete values. + + Returns + ------- + tuple + Tuple containing (CMI, I). """ x = np.atleast_2d(x) @@ -849,23 +1001,24 @@ def gccmi_ccd(x,y,z,Zm): I = mi_gg(np.hstack(cx),np.hstack(cy),True,False) return (CMI,I) -def phase_gcmi(mne_data, seed_to_target, freqs0, freqs1=None): - """ +def phase_gcmi(mne_data: mne.Epochs, seed_to_target: Tuple[np.ndarray, np.ndarray], freqs0: Tuple[float, float], freqs1: Optional[Tuple[float, float]] = None) -> np.ndarray: + """Compute phase-based Gaussian-copula mutual information. - Compute the gaussian-copula condition mutual information between the phase of two signals. - Can be within-frequency or between-frequency coupling. - Parameters ---------- - mne_data : epochs object - MNE epochs object containing the data to be analyzed. - seed_to_target : list of tuples - List of tuples containing the indices of the seed and target electrodes. - freqs0 : list or tuple - Frequency range for the first signal. - freqs1 : list or tuple - Frequency range for the second signal. - + mne_data : mne.Epochs + MNE epochs object. + seed_to_target : tuple + Seed-to-target indices as (seed_indices, target_indices). + freqs0 : tuple + Frequency range for first signal as (low, high). 
+ freqs1 : tuple, optional + Frequency range for second signal as (low, high). Default is None. + + Returns + ------- + np.ndarray + Pairwise connectivity matrix. """ nevents = mne_data._data.shape[0] @@ -929,23 +1082,24 @@ def phase_gcmi(mne_data, seed_to_target, freqs0, freqs1=None): return pairwise_connectivity -def amp_amp_coupling(mne_data, seed_to_target, freqs0, freqs1=None): - """ - Compute the correlation between the amplitude envelope of two signals. - Can be within-frequency or between-frequency coupling. - +def amp_amp_coupling(mne_data: mne.Epochs, seed_to_target: Tuple[np.ndarray, np.ndarray], freqs0: Tuple[float, float], freqs1: Optional[Tuple[float, float]] = None) -> np.ndarray: + """Compute amplitude-amplitude coupling. + Parameters ---------- - mne_data : epochs object - MNE epochs object containing the data to be analyzed. - seed_to_target : list of tuples - List of tuples containing the indices of the seed and target electrodes. - freqs0 : list or tuple - Frequency range for the first signal. - freqs1 : list or tuple - Frequency range for the second signal. If None, assume within-frequency coupling. - - Note: inspired by MNE's pairwise orthogonal envelope connectivity metric but altered for iEEG data + mne_data : mne.Epochs + MNE epochs object. + seed_to_target : tuple + Seed-to-target indices as (seed_indices, target_indices). + freqs0 : tuple + Frequency range for first signal as (low, high). + freqs1 : tuple, optional + Frequency range for second signal as (low, high). Default is None. + + Returns + ------- + np.ndarray + Pairwise connectivity matrix. 
""" nevents = mne_data._data.shape[0] @@ -1015,17 +1169,34 @@ def amp_amp_coupling(mne_data, seed_to_target, freqs0, freqs1=None): return pairwise_connectivity -def compute_gc_tr(mne_data=None, - band=None, - indices=None, - freqs=None, - n_cycles=None, - rank=None, - gc_n_lags=15, - buf_ms=1000, - avg_over_dim='time'): - """ - Following https://mne.tools/mne-connectivity/stable/auto_examples/granger_causality.html#sphx-glr-auto-examples-granger-causality-py +def compute_gc_tr(mne_data: Optional[mne.Epochs] = None, band: Optional[Tuple[float, float]] = None, indices: Optional[Tuple[np.ndarray, np.ndarray]] = None, freqs: Optional[np.ndarray] = None, n_cycles: Optional[Union[float, np.ndarray]] = None, rank: Optional[int] = None, gc_n_lags: int = 15, buf_ms: int = 1000, avg_over_dim: str = 'time') -> np.ndarray: + """Compute Granger causality time-resolved. + + Parameters + ---------- + mne_data : mne.Epochs, optional + MNE epochs object. + band : tuple, optional + Frequency band as (low, high). + indices : tuple, optional + Connectivity indices as (seed_indices, target_indices). + freqs : np.ndarray, optional + Frequency array. + n_cycles : float or np.ndarray, optional + Number of cycles. + rank : int, optional + Rank parameter. + gc_n_lags : int, optional + Number of lags. Default is 15. + buf_ms : int, optional + Buffer in milliseconds. Default is 1000. + avg_over_dim : str, optional + Dimension to average over. Default is 'time'. + + Returns + ------- + np.ndarray + Granger causality results. 
""" indices_ab = (np.array([np.unique(indices[0]).tolist()]), np.array([np.unique(indices[1]).tolist()])) # A => B @@ -1148,7 +1319,37 @@ def compute_gc_tr(mne_data=None, else: return np.squeeze(gc_tr) -def compute_surr_connectivity_epochs(surr_mne, indices, metric, band, freqs, n_cycles, surr_method = 'swap_epochs', rng_seed=None, gc_n_lags=15, buf_ms=1000): +def compute_surr_connectivity_epochs(surr_mne: mne.Epochs, indices: Tuple[np.ndarray, np.ndarray], metric: str, band: Tuple[float, float], freqs: np.ndarray, n_cycles: Union[float, np.ndarray], surr_method: str = 'swap_epochs', rng_seed: Optional[int] = None, gc_n_lags: int = 15, buf_ms: int = 1000) -> np.ndarray: + """Compute surrogate connectivity over epochs. + + Parameters + ---------- + surr_mne : mne.Epochs + Surrogate MNE epochs. + indices : tuple + Connectivity indices as (seed_indices, target_indices). + metric : str + Connectivity metric. + band : tuple + Frequency band as (low, high). + freqs : np.ndarray + Frequency array. + n_cycles : float or np.ndarray + Number of cycles. + surr_method : str, optional + Surrogate method. Default is 'swap_epochs'. + rng_seed : int, optional + Random seed. + gc_n_lags : int, optional + Number of lags. Default is 15. + buf_ms : int, optional + Buffer in milliseconds. Default is 1000. + + Returns + ------- + np.ndarray + Surrogate connectivity results. 
+ """ n_pairs = len(indices[0]) # data = np.swapaxes(mne_data.get_data(copy=False), 0, 1) # swap so now it's chan, events, times @@ -1233,7 +1434,37 @@ def compute_surr_connectivity_epochs(surr_mne, indices, metric, band, freqs, n_c return surr_conn -def compute_surr_connectivity_time(surr_mne, indices, metric, band, freqs, n_cycles, buf_ms, surr_method = 'swap_epochs', rng_seed=42, gc_n_lags=15): +def compute_surr_connectivity_time(surr_mne: mne.Epochs, indices: Tuple[np.ndarray, np.ndarray], metric: str, band: Tuple[float, float], freqs: np.ndarray, n_cycles: Union[float, np.ndarray], buf_ms: Union[int, Tuple[int, int]], surr_method: str = 'swap_epochs', rng_seed: int = 42, gc_n_lags: int = 15) -> np.ndarray: + """Compute surrogate connectivity over time. + + Parameters + ---------- + surr_mne : mne.Epochs + Surrogate MNE epochs. + indices : tuple + Connectivity indices as (seed_indices, target_indices). + metric : str + Connectivity metric. + band : tuple + Frequency band as (low, high). + freqs : np.ndarray + Frequency array. + n_cycles : float or np.ndarray + Number of cycles. + buf_ms : int or tuple + Buffer in milliseconds. + surr_method : str, optional + Surrogate method. Default is 'swap_epochs'. + rng_seed : int, optional + Random seed. Default is 42. + gc_n_lags : int, optional + Number of lags. Default is 15. + + Returns + ------- + np.ndarray + Surrogate connectivity results. + """ n_pairs = len(indices[0]) # data = np.swapaxes(mne_data.get_data(copy=False), 0, 1) # swap so now it's chan, events, times @@ -1300,29 +1531,42 @@ def compute_surr_connectivity_time(surr_mne, indices, metric, band, freqs, n_cyc return surr_conn -def compute_connectivity(mne_data=None, - band=None, - metric=None, - indices=None, - freqs=None, - n_cycles=None, - buf_ms=1000, - avg_over_dim='time', - surr_method = 'swap_epochs', - n_surr=500, - parallelize=False, - band1=None, - gc_n_lags=7): - """ - Compute different connectivity metrics using mne. 
- :param eeg_mne: MNE formatted EEG - :param samplerate: sample rate of the data - :param band: tuple of band of interest - :param metric: 'psi' for directional, or for non_directional: ['coh', 'cohy', 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', pli2_unbiased', 'dpli', 'wpli', 'wpli2_debiased'] - see: https://mne.tools/mne-connectivity/stable/generated/mne_connectivity.spectral_connectivity_epochs.html - :param indices: determine the source and target for connectivity. Matters most for directional metrics i.e. 'psi' - :return: - pairwise connectivity: array of pairwise weights for the connectivity metric with some number of timepoints +def compute_connectivity(mne_data: Optional[mne.Epochs] = None, band: Optional[Tuple[float, float]] = None, metric: Optional[str] = None, indices: Optional[Tuple[np.ndarray, np.ndarray]] = None, freqs: Optional[np.ndarray] = None, n_cycles: Optional[Union[float, np.ndarray]] = None, buf_ms: int = 1000, avg_over_dim: str = 'time', surr_method: str = 'swap_epochs', n_surr: int = 500, parallelize: bool = False, band1: Optional[Tuple[float, float]] = None, gc_n_lags: int = 7) -> np.ndarray: + """Compute connectivity metrics. + + Parameters + ---------- + mne_data : mne.Epochs, optional + MNE epochs object. + band : tuple, optional + Frequency band as (low, high). + metric : str, optional + Connectivity metric. + indices : tuple, optional + Connectivity indices as (seed_indices, target_indices). + freqs : np.ndarray, optional + Frequency array. + n_cycles : float or np.ndarray, optional + Number of cycles. + buf_ms : int, optional + Buffer in milliseconds. Default is 1000. + avg_over_dim : str, optional + Dimension to average over. Default is 'time'. + surr_method : str, optional + Surrogate method. Default is 'swap_epochs'. + n_surr : int, optional + Number of surrogates. Default is 500. + parallelize : bool, optional + Whether to parallelize. Default is False. + band1 : tuple, optional + Second frequency band as (low, high). 
+ gc_n_lags : int, optional + Number of lags. Default is 7. + + Returns + ------- + np.ndarray + Connectivity results. """ if metric == 'gr_tc': return (ValueError('Use the function compute_gc_tr')) @@ -1766,20 +2010,24 @@ def _process_surrogate_time(ns): --- """ -def BOSC_tf(eegsignal,F,Fsample,wavenumber): - """ - Computes the Better Oscillation Detection (BOSC) time-frequency matrix for a given LFP signal. - - Args: - - eegsignal (numpy.ndarray): The LFP signal to compute the BOSC time-frequency matrix for. - - F (numpy.ndarray): The frequency range to compute the BOSC time-frequency matrix over. - - Fsample (float): The sampling frequency of the LFP signal. - - wavenumber (float): The wavenumber to use for the Morlet wavelet. - - Returns: - - B (numpy.ndarray): The BOSC time-frequency matrix. - - T (numpy.ndarray): The time vector corresponding to the BOSC time-frequency matrix. - - F (numpy.ndarray): The frequency vector corresponding to the BOSC time-frequency matrix. +def BOSC_tf(eegsignal: np.ndarray, F: np.ndarray, Fsample: float, wavenumber: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute BOSC time-frequency matrix. + + Parameters + ---------- + eegsignal : np.ndarray + LFP signal. + F : np.ndarray + Frequency range. + Fsample : float + Sampling frequency. + wavenumber : float + Morlet wavelet wavenumber. + + Returns + ------- + tuple + Tuple containing (B, T, F). """ st=1./(2*np.pi*(F/wavenumber)) @@ -1800,30 +2048,24 @@ def BOSC_tf(eegsignal,F,Fsample,wavenumber): return B, T, F -def BOSC_detect(b,powthresh,durthresh,Fsample): - """ - detected=BOSC_detect(b,powthresh,durthresh,Fsample) - This function detects oscillations based on a wavelet power - timecourse, b, a power threshold (powthresh) and duration - threshold (durthresh) returned from BOSC_thresholds.m. +def BOSC_detect(b: np.ndarray, powthresh: float, durthresh: float, Fsample: float) -> np.ndarray: + """Detect oscillations using BOSC. 
- It now returns the detected vector which is already episode-detected. - - b - the power timecourse (at one frequency of interest) - - durthresh - duration threshold in required to be deemed oscillatory - powthresh - power threshold - - returns: - detected - a binary vector containing the value 1 for times at - which oscillations (at the frequency of interest) were - detected and 0 where no oscillations were detected. - - note: Remember to account for edge effects by including - "shoulder" data and accounting for it afterwards! - - To calculate Pepisode: - Pepisode=length(find(detected))/(length(detected)); + Parameters + ---------- + b : np.ndarray + Power timecourse. + powthresh : float + Power threshold. + durthresh : float + Duration threshold. + Fsample : float + Sampling frequency. + + Returns + ------- + np.ndarray + Binary detection vector. """ # number of time points @@ -1889,23 +2131,22 @@ def BOSC_detect(b,powthresh,durthresh,Fsample): detected = np.array(list(map(np.int, detected))) return detected -def eBOSC_getThresholds(cfg_eBOSC, TFR, eBOSC): - """This function estimates the static duration and power thresholds and - saves information regarding the overall spectrum and background. - Inputs: - cfg | config structure with cfg.eBOSC field - TFR | time-frequency matrix - eBOSC | main eBOSC output structure; will be updated - - Outputs: - eBOSC | updated w.r.t. background info (see below) - | bg_pow: overall power spectrum - | bg_log10_pow: overall power spectrum (log10) - | pv: intercept and slope of fit - | mp: linear background power - | pt: power threshold - pt | empirical power threshold - dt | duration threshold +def eBOSC_getThresholds(cfg_eBOSC: dict, TFR: np.ndarray, eBOSC: dict) -> Tuple[dict, np.ndarray, np.ndarray]: + """Estimate static duration and power thresholds. + + Parameters + ---------- + cfg_eBOSC : dict + Configuration dictionary. + TFR : np.ndarray + Time-frequency matrix. + eBOSC : dict + eBOSC output structure. 
+ + Returns + ------- + tuple + Tuple containing (eBOSC, pt, dt). """ # concatenate power estimates in time across trials of interest @@ -1989,8 +2230,22 @@ def eBOSC_getThresholds(cfg_eBOSC, TFR, eBOSC): return eBOSC, pt, dt -def eBOSC_episode_sparsefreq(cfg_eBOSC, detected, TFR): - """Sparsen the detected matrix along the frequency dimension +def eBOSC_episode_sparsefreq(cfg_eBOSC: dict, detected: np.ndarray, TFR: np.ndarray) -> np.ndarray: + """Sparsen detected matrix along frequency dimension. + + Parameters + ---------- + cfg_eBOSC : dict + Configuration dictionary. + detected : np.ndarray + Detected matrix. + TFR : np.ndarray + Time-frequency matrix. + + Returns + ------- + np.ndarray + Sparsed detected matrix. """ # print('Creating sparse detected matrix ...') @@ -2037,20 +2292,26 @@ def eBOSC_episode_sparsefreq(cfg_eBOSC, detected, TFR): detected = detected[freqs_to_search,:] return detected -def eBOSC_episode_postproc_fwhm(cfg_eBOSC, episodes, TFR): - """ - % This function performs post-processing of input episodes by checking - % whether 'detected' time points can trivially be explained by the FWHM of - % the wavelet used in the time-frequency transform. - % - % Inputs: - % cfg | config structure with cfg.eBOSC field - % episodes | table of episodes - % TFR | time-frequency matrix - % - % Outputs: - % episodes_new | updated table of episodes - % detected_new | updated binary detected matrix +def eBOSC_episode_postproc_fwhm(cfg_eBOSC: dict, episodes: dict, TFR: np.ndarray) -> Tuple[dict, np.ndarray]: + """Perform post-processing of episodes using FWHM correction. + + This function performs post-processing of input episodes by checking + whether 'detected' time points can trivially be explained by the FWHM of + the wavelet used in the time-frequency transform. + + Parameters + ---------- + cfg_eBOSC : dict + Configuration dictionary with eBOSC field. + episodes : dict + Table of episodes. + TFR : np.ndarray + Time-frequency matrix. 
+ + Returns + ------- + tuple + Tuple containing (episodes_new, detected_new). """ print("Applying FWHM post-processing ...") @@ -2178,31 +2439,37 @@ def eBOSC_episode_postproc_fwhm(cfg_eBOSC, episodes, TFR): # return post-processed episode dictionary and updated binary detected matrix return episodesTable, detected_new -def eBOSC_episode_postproc_maxbias(cfg_eBOSC, episodes, TFR): - """ - % This function performs post-processing of input episodes by checking - % whether 'detected' time points can be explained by the simulated extension of - % the wavelet used in the time-frequency transform. - % - % Inputs: - % cfg | config structure with cfg.eBOSC field - % episodes | table of episodes - % TFR | time-frequency matrix - % - % Outputs: - % episodes_new | updated table of episodes - % detected_new | updated binary detected matrix - - % This method works as follows: we estimate the bias introduced by - % wavelet convolution. The bias is represented by the amplitudes - % estimated for the zero-shouldered signal (i.e. for which no real - % data was initially available). The influence of episodic - % amplitudes on neighboring time points is assessed by scaling each - % time point's amplitude with the last 'rhythmic simulated time - % point', i.e. the first time wavelet amplitude in the simulated - % rhythmic time points. At this time point the 'bias' is maximal, - % although more precisely, this amplitude does not represent a - % bias per se. +def eBOSC_episode_postproc_maxbias(cfg_eBOSC: dict, episodes: dict, TFR: np.ndarray) -> Tuple[dict, np.ndarray]: + """Perform post-processing of episodes using maxbias correction. + + This function performs post-processing of input episodes by checking + whether 'detected' time points can be explained by the simulated extension of + the wavelet used in the time-frequency transform. + + This method works as follows: we estimate the bias introduced by + wavelet convolution. 
The bias is represented by the amplitudes + estimated for the zero-shouldered signal (i.e. for which no real + data was initially available). The influence of episodic + amplitudes on neighboring time points is assessed by scaling each + time point's amplitude with the last 'rhythmic simulated time + point', i.e. the first time wavelet amplitude in the simulated + rhythmic time points. At this time point the 'bias' is maximal, + although more precisely, this amplitude does not represent a + bias per se. + + Parameters + ---------- + cfg_eBOSC : dict + Configuration dictionary with eBOSC field. + episodes : dict + Table of episodes. + TFR : np.ndarray + Time-frequency matrix. + + Returns + ------- + tuple + Tuple containing (episodes_new, detected_new). """ print("Applying maxbias post-processing ...") @@ -2326,11 +2593,22 @@ def eBOSC_episode_postproc_maxbias(cfg_eBOSC, episodes, TFR): # return post-processed episode dictionary and updated binary detected matrix return episodesTable, detected_new -def eBOSC_episode_rm_shoulder(cfg_eBOSC,detected1,episodes): - """ Remove parts of the episode that fall into the 'shoulder' of individual - trials. There is no check for adherence to a given duration criterion necessary, - as the point of the padding of the detected matrix is exactly to account - for allowing the presence of a few cycles. +def eBOSC_episode_rm_shoulder(cfg_eBOSC: dict, detected1: np.ndarray, episodes: dict): + """Remove episode parts in trial shoulders. + + Parameters + ---------- + cfg_eBOSC : dict + Configuration dictionary. + detected1 : np.ndarray + Detected matrix. + episodes : dict + Episodes dictionary. + + Returns + ------- + dict + Updated episodes dictionary. 
""" print("Removing padding from detected episodes") @@ -2376,46 +2654,55 @@ def eBOSC_episode_rm_shoulder(cfg_eBOSC,detected1,episodes): episodes[entry] = [v for i, v in enumerate(episodes[entry]) if i not in rmv] return episodes -def eBOSC_episode_create(cfg_eBOSC,TFR,detected,eBOSC): - """This function creates continuous rhythmic "episodes" and attempts to control for the impact of wavelet parameters. - Time-frequency points that best represent neural rhythms are identified by - heuristically removing temporal and frequency leakage. - - Frequency leakage: at each frequency x time point, power has to exceed neighboring frequencies. - Then it is checked whether the detected time-frequency points belong to - a continuous episode for which (1) the frequency maximally changes by - +/- n steps (cfg.eBOSC.fstp) from on time point to the next and (2) that is at - least as long as n number of cycles (cfg.eBOSC.threshold.duration) of the average freqency - of that episode (a priori duration threshold). - - Temporal leakage: The impact of the amplitude at each time point within a rhythmic episode on previous - and following time points is tested with the goal to exclude supra-threshold time - points that are due to the wavelet extension in time. - - Input: - cfg | config structure with cfg.eBOSC field - TFR | time-frequency matrix (excl. 
WLpadding) - detected | detected oscillations in TFR (based on power and duration threshold) - eBOSC | main eBOSC output structure; necessary to read in - prior eBOSC.episodes if they exist in a loop - - Output: - detected_new | new detected matrix with frequency leakage removed - episodesTable | table with specific episode information: - Trial: trial index (corresponds to cfg.eBOSC.trial) - Channel: channel index - FrequencyMean: mean frequency of episode (Hz) - DurationS: episode duration (in sec) - DurationC: episode duration (in cycles, based on mean frequency) - PowerMean: mean amplitude of amplitude - Onset: episode onset in s - Offset: episode onset in s - Power: (cell) time-resolved wavelet-based amplitude estimates during episode - Frequency: (cell) time-resolved wavelet-based frequency - RowID: (cell) row index (frequency dimension): following eBOSC_episode_rm_shoulder relative to data excl. detection padding - ColID: (cell) column index (time dimension) - SNR: (cell) time-resolved signal-to-noise ratio: momentary amplitude/static background estimate at channel*frequency - SNRMean: mean signal-to-noise ratio +def eBOSC_episode_create(cfg_eBOSC: dict, TFR: np.ndarray, detected: np.ndarray, eBOSC: dict) -> Tuple[dict, np.ndarray]: + """Create continuous rhythmic episodes and control for wavelet parameter impact. + + This function creates continuous rhythmic "episodes" and attempts to control for the impact of wavelet parameters. + Time-frequency points that best represent neural rhythms are identified by + heuristically removing temporal and frequency leakage. + + Frequency leakage: at each frequency x time point, power has to exceed neighboring frequencies. 
+ Then it is checked whether the detected time-frequency points belong to + a continuous episode for which (1) the frequency maximally changes by + +/- n steps (cfg.eBOSC.fstp) from one time point to the next and (2) that is at + least as long as n number of cycles (cfg.eBOSC.threshold.duration) of the average frequency + of that episode (a priori duration threshold). + + Temporal leakage: The impact of the amplitude at each time point within a rhythmic episode on previous + and following time points is tested with the goal to exclude supra-threshold time + points that are due to the wavelet extension in time. + + Parameters + ---------- + cfg_eBOSC : dict + Configuration structure with eBOSC field. + TFR : np.ndarray + Time-frequency matrix (excl. WLpadding). + detected : np.ndarray + Detected oscillations in TFR (based on power and duration threshold). + eBOSC : dict + Main eBOSC output structure; necessary to read in + prior eBOSC.episodes if they exist in a loop. + + Returns + ------- + tuple + Tuple containing (episodesTable, detected_new). 
+ episodesTable contains: + - Trial: trial index (corresponds to cfg.eBOSC.trial) + - Channel: channel index + - FrequencyMean: mean frequency of episode (Hz) + - DurationS: episode duration (in sec) + - DurationC: episode duration (in cycles, based on mean frequency) + - PowerMean: mean amplitude of episode + - Onset: episode onset in s + - Offset: episode offset in s + - Power: (list) time-resolved wavelet-based amplitude estimates during episode + - Frequency: (list) time-resolved wavelet-based frequency + - RowID: (list) row index (frequency dimension) + - ColID: (list) column index (time dimension) + - SNR: (list) time-resolved signal-to-noise ratio + - SNRMean: mean signal-to-noise ratio """ # initialize dictionary to save results in @@ -2570,36 +2857,44 @@ def eBOSC_episode_create(cfg_eBOSC,TFR,detected,eBOSC): return episodesTable, detected_new -def eBOSC_wrapper(cfg_eBOSC, data): +def eBOSC_wrapper(cfg_eBOSC: dict, data: pd.DataFrame) -> Tuple[dict, dict]: """Main eBOSC wrapper function. Executes eBOSC subfunctions. 
- Inputs: - cfg_eBOSC | dictionary containing the following entries: - F | frequency sampling - wavenumber | wavelet family parameter (time-frequency tradeoff) - fsample | current sampling frequency of EEG data - pad.tfr_s | padding following wavelet transform to avoid edge artifacts in seconds (bi-lateral) - pad.detection_s | padding following rhythm detection in seconds (bi-lateral); 'shoulder' for BOSC eBOSC.detected matrix to account for duration threshold - pad.total_s | complete padding (WL + shoulder) - pad.background_s | padding of segments for BG (only avoiding edge artifacts) - threshold.excludePeak | lower and upper bound of frequencies to be excluded during background fit (Hz) (previously: LowFreqExcludeBG HighFreqExcludeBG) - threshold.duration | vector of duration thresholds at each frequency (previously: ncyc) - threshold.percentile | percentile of background fit for power threshold - postproc.use | Post-processing of rhythmic eBOSC.episodes, i.e., wavelet 'deconvolution' (default = 'no') - postproc.method | Deconvolution method (default = 'MaxBias', FWHM: 'FWHM') - postproc.edgeOnly | Deconvolution only at on- and offsets of eBOSC.episodes? (default = 'yes') - postproc.effSignal | Power deconvolution on whole signal or signal above power threshold? (default = 'PT') - channel | Subset of channels? (default: [] = all) - trial | Subset of trials? (default: [] = all) - trial_background | Subset of trials for background? 
(default: [] = all) - data | input time series data as a Pandas DataFrame: - - channels as columns - - multiindex containing: 'time', 'epoch', - Outputs: - eBOSC | main eBOSC output dictionary containing the following entries: - episodes | Dictionary: individual rhythmic episodes (see eBOSC_episode_create) - detected | DataFrame: binary detected time-frequency points (prior to episode creation), pepisode = temporal average - detected_ep | DataFrame: binary detected time-frequency points (following episode creation), abundance = temporal average - cfg | config structure (see input) + + Parameters + ---------- + cfg_eBOSC : dict + Configuration dictionary containing the following entries: + - F: frequency sampling + - wavenumber: wavelet family parameter (time-frequency tradeoff) + - fsample: current sampling frequency of EEG data + - pad.tfr_s: padding following wavelet transform to avoid edge artifacts in seconds (bi-lateral) + - pad.detection_s: padding following rhythm detection in seconds (bi-lateral); 'shoulder' for BOSC eBOSC.detected matrix to account for duration threshold + - pad.total_s: complete padding (WL + shoulder) + - pad.background_s: padding of segments for BG (only avoiding edge artifacts) + - threshold.excludePeak: lower and upper bound of frequencies to be excluded during background fit (Hz) + - threshold.duration: vector of duration thresholds at each frequency + - threshold.percentile: percentile of background fit for power threshold + - postproc.use: Post-processing of rhythmic eBOSC.episodes, i.e., wavelet 'deconvolution' (default = 'no') + - postproc.method: Deconvolution method (default = 'MaxBias', FWHM: 'FWHM') + - postproc.edgeOnly: Deconvolution only at on- and offsets of eBOSC.episodes? (default = 'yes') + - postproc.effSignal: Power deconvolution on whole signal or signal above power threshold? (default = 'PT') + - channel: Subset of channels? (default: [] = all) + - trial: Subset of trials? 
(default: [] = all) + - trial_background: Subset of trials for background? (default: [] = all) + data : pd.DataFrame + Input time series data as a Pandas DataFrame with: + - channels as columns + - multiindex containing: 'time', 'epoch' + + Returns + ------- + tuple + Tuple containing (eBOSC, cfg). + eBOSC is the main eBOSC output dictionary containing: + - episodes: Dictionary of individual rhythmic episodes (see eBOSC_episode_create) + - detected: DataFrame of binary detected time-frequency points (prior to episode creation) + - detected_ep: DataFrame of binary detected time-frequency points (following episode creation) + cfg is the config structure (see input) """ # %% get list of channel names (very manual solution, replace if possible) @@ -2757,14 +3052,48 @@ def eBOSC_wrapper(cfg_eBOSC, data): return eBOSC, cfg_eBOSC -def compute_eBOSC_parallel(chan_name, MNE_object, subj_id, elec_df, event_name, ev_dict, conditions, - do_plot=False, save_path='/sc/arion/projects/guLab/Salman/EphysAnalyses', - do_save=False, mean_across_time=False, mean_across_freqs=False, both_dfs=True, **kwargs): - """ - +def compute_eBOSC_parallel(chan_name: str, MNE_object: mne.Epochs, subj_id: str, elec_df: pd.DataFrame, event_name: str, ev_dict: dict, conditions: List[str], + do_plot: bool = False, save_path: str = '/sc/arion/projects/guLab/Salman/EphysAnalyses', + do_save: bool = False, mean_across_time: bool = False, mean_across_freqs: bool = False, both_dfs: bool = True, **kwargs) -> None: + """Parallelize eBOSC computation over many channels simultaneously. + This function is meant to parallelize our BOSC code to be computed over many channels simultaneously and save the results - to individual dataframes. - + to individual dataframes. + + Parameters + ---------- + chan_name : str + Channel name. + MNE_object : mne.Epochs + MNE epochs object. + subj_id : str + Subject ID. + elec_df : pd.DataFrame + Electrode DataFrame. + event_name : str + Event name. 
+ ev_dict : dict + Event dictionary. + conditions : list + List of conditions. + do_plot : bool, optional + Whether to plot. Default is False. + save_path : str, optional + Path to save results. Default is '/sc/arion/projects/guLab/Salman/EphysAnalyses'. + do_save : bool, optional + Whether to save results. Default is False. + mean_across_time : bool, optional + Whether to average across time. Default is False. + mean_across_freqs : bool, optional + Whether to average across frequencies. Default is False. + both_dfs : bool, optional + Whether to create both dataframes. Default is True. + **kwargs + Additional keyword arguments for eBOSC configuration. + + Returns + ------- + None """ if not os.path.exists(f'{save_path}/{subj_id}/scratch/eBOSC/{event_name}/dfs'): diff --git a/LFPAnalysis/statistics_utils.py b/LFPAnalysis/statistics_utils.py index f41b2aa..f2fce79 100644 --- a/LFPAnalysis/statistics_utils.py +++ b/LFPAnalysis/statistics_utils.py @@ -16,21 +16,40 @@ warnings.filterwarnings('ignore') def fit_permuted_model(y_permuted, X): - """ - Convenience function for running backend OLS with surrogates + """Fit OLS model with permuted data. + + Parameters + ---------- + y_permuted + Permuted dependent variable. + X + Design matrix. + + Returns + ------- + np.ndarray + Model parameters. """ return OLS(y_permuted, X).fit().params -def permutation_regression_zscore(data, formula, n_permutations=1000, plot_res=False): - """ - - A quick way to perform single-electrode regression with many permutations: - # Example usage: - # data = pd.DataFrame({'y': y, 'x1': x1, 'x2': x2, 'category': ['A', 'B', 'A', 'B', ...]}) - # formula = 'y ~ x1 + x2 + C(category)' - # results = permutation_regression_zscore(data, formula, plot_res=True) - # print(results) - +def permutation_regression_zscore(data: pd.DataFrame, formula: str, n_permutations: int = 1000, plot_res: bool = False): + """Perform regression with permutation-based z-scores. 
+ + Parameters + ---------- + data : pd.DataFrame + DataFrame containing dependent and independent variables. + formula : str + Regression formula. + n_permutations : int, optional + Number of permutations. Default is 1000. + plot_res : bool, optional + Whether to plot results. Default is False. + + Returns + ------- + pd.DataFrame + Results DataFrame with raw and z-scored statistics. """ # Perform original regression y, X = patsy.dmatrices(formula, data, return_type='dataframe') @@ -123,47 +142,26 @@ def permutation_regression_zscore(data, formula, n_permutations=1000, plot_res=F return results -def shuffle_data_for_mlm(df, - y='tfr', - lower_group='unique_label', - higher_group='participant', - trial_key='trial'): - """ - For mixed-effects models where we have two hierarchies: trials within electrodes, and electrodes within participants. - - A good shuffle will permute the trial-level data within electrode, but do it the same way for each electrode within a participant to preserve - any structure that might exist across electrodes. +def shuffle_data_for_mlm(df: pd.DataFrame, y: str = 'tfr', lower_group: str = 'unique_label', higher_group: str = 'participant', trial_key: str = 'trial'): + """Shuffle data for mixed-effects models preserving hierarchical structure. - Parameters: + Parameters ---------- df : pd.DataFrame - DataFrame containing the data to shuffle. - y : str - Name of the dependent variable to shuffle. - lower_group : str - Name of the lower-level grouping variable. - higher_group : str - Name of the higher-level grouping variable. - trial_key : str - Name of the trial identifier variable. - - Returns: - -------- - surr_df : pd.DataFrame + DataFrame containing data to shuffle. + y : str, optional + Name of dependent variable. Default is 'tfr'. + lower_group : str, optional + Name of lower-level grouping variable. Default is 'unique_label'. + higher_group : str, optional + Name of higher-level grouping variable. Default is 'participant'. 
+ trial_key : str, optional + Name of trial identifier variable. Default is 'trial'. + + Returns + ------- + pd.DataFrame DataFrame with shuffled dependent variable. - - - Example: - -------- - surr_df = shuffle_data_for_mlm(df, - y='tfr', - lower_group='unique_label', - higher_group='participant', - trial_key='trial') - - surr_model = smf.mixedlm(formula, - data=surr_df, - groups=surr_df[lower_group]).fit() """ surr_df = df.copy() @@ -197,39 +195,30 @@ def shuffle_data_for_mlm(df, return surr_df -def generate_surrogate_results(df, - formula = 'tfr ~ 1 + zrpe*phit', - y='tfr', - lower_group='unique_label', - higher_group='participant', - trial_key='trial', - n_permutations=100): +def generate_surrogate_results(df: pd.DataFrame, formula: str = 'tfr ~ 1 + zrpe*phit', y: str = 'tfr', lower_group: str = 'unique_label', higher_group: str = 'participant', trial_key: str = 'trial', n_permutations: int = 100): + """Generate surrogate estimates for mixed-effects model. - """ - Generate surrogate estimates for a mixed-effects model by shuffling the dependent variable within electrodes. - - Parameters: + Parameters ---------- df : pd.DataFrame - DataFrame containing the data to shuffle. - formula : str - Formula to use for the mixed-effects model. - y : str - Name of the dependent variable to shuffle. - lower_group : str - Name of the lower-level grouping variable. - higher_group : str - Name of the higher-level grouping variable. - trial_key : str - Name of the trial identifier variable. - n_iterations : int - Number of surrogate estimates to generate. - - Returns: - -------- - surr_results : list - List of DataFrames containing surrogate estimates for each iteration. - + DataFrame containing data to shuffle. + formula : str, optional + Formula for mixed-effects model. Default is 'tfr ~ 1 + zrpe*phit'. + y : str, optional + Name of dependent variable. Default is 'tfr'. + lower_group : str, optional + Name of lower-level grouping variable. Default is 'unique_label'. 
+ higher_group : str, optional + Name of higher-level grouping variable. Default is 'participant'. + trial_key : str, optional + Name of trial identifier variable. Default is 'trial'. + n_permutations : int, optional + Number of surrogate estimates. Default is 100. + + Returns + ------- + list + List of DataFrames with surrogate estimates. """ surr_results = [] @@ -250,36 +239,26 @@ def generate_surrogate_results(df, return surr_results -def time_resolved_regression_single_channel(smoothed_df=None, - y='tfr', - formula='1 + zrpe*phit', - permute=False, - n_permutations=100): - """ - In this function, if you provide a 2D array of z-scored time-varying neural data and a sert of regressors, - this function will run a time-resolved generalized linear model with the provided regressor dataframe. - - Typically, this timeseries will be HFA, and the default win_len and slide_len reflect this - - timeseries: ndarray, trials x times - regressors: pandas df, index = trials, columns = regressors - +def time_resolved_regression_single_channel(smoothed_df: pd.DataFrame = None, y: str = 'tfr', formula: str = '1 + zrpe*phit', permute: bool = False, n_permutations: int = 100): + """Run time-resolved regression for single channel. + Parameters ---------- - timeseries : 2D ndarray, dimensions = trials x times - Time-varying neural data. - regressors : pandas.DataFrame, dimensions = trials x regressors - Dataframe containing the regressors. - win_len : int - Length of the window for the time-resolved regression. - slide_len : int - Step size for the time-resolved regression. - standardize : bool - Whether to standardize the regressors. The default is True. - smooth : bool - Whether to bin the timeseries according to win_len and slide_len. The default is Fault. - sr: int - sampling rate to determine the proper timing of the resulting timeseries of coefficients + smoothed_df : pd.DataFrame, optional + DataFrame with time-varying neural data. 
+ y : str, optional + Name of dependent variable. Default is 'tfr'. + formula : str, optional + Regression formula. Default is '1 + zrpe*phit'. + permute : bool, optional + Whether to use permutation testing. Default is False. + n_permutations : int, optional + Number of permutations. Default is 100. + + Returns + ------- + pd.DataFrame + Results DataFrame with regression statistics. """ # # Optional: bin the data @@ -334,9 +313,32 @@ def time_resolved_regression_single_channel(smoothed_df=None, import statsmodels.formula.api as smf from tqdm import tqdm -def process_single_timepoint(ts, smoothed_df, formula, lower_group, y, higher_group, trial_key, n_permutations): - """ - Processes a single timepoint. +def process_single_timepoint(ts, smoothed_df: pd.DataFrame, formula: str, lower_group: str, y: str, higher_group: str, trial_key: str, n_permutations: int): + """Process single timepoint for mixed-effects model. + + Parameters + ---------- + ts + Timepoint value. + smoothed_df : pd.DataFrame + DataFrame with time-varying data. + formula : str + Mixed-effects model formula. + lower_group : str + Lower-level grouping variable. + y : str + Dependent variable name. + higher_group : str + Higher-level grouping variable. + trial_key : str + Trial identifier variable. + n_permutations : int + Number of permutations. + + Returns + ------- + pd.DataFrame + Results DataFrame for this timepoint. """ model_df = smoothed_df[smoothed_df.ts == ts] test_model = smf.mixedlm(formula, @@ -378,16 +380,32 @@ def process_single_timepoint(ts, smoothed_df, formula, lower_group, y, higher_gr return results -def time_resolved_mlm(smoothed_df, - y='tfr', - formula='tfr ~ 1 + zrpe*phit', - lower_group='unique_label', - higher_group='participant', - trial_key='trial', - n_permutations=100, - n_jobs=-1): - """ - Parallelized version of the function with progress bar. 
+def time_resolved_mlm(smoothed_df: pd.DataFrame, y: str = 'tfr', formula: str = 'tfr ~ 1 + zrpe*phit', lower_group: str = 'unique_label', higher_group: str = 'participant', trial_key: str = 'trial', n_permutations: int = 100, n_jobs: int = -1): + """Run time-resolved mixed-effects model with parallelization. + + Parameters + ---------- + smoothed_df : pd.DataFrame + DataFrame with time-varying data. + y : str, optional + Dependent variable name. Default is 'tfr'. + formula : str, optional + Mixed-effects model formula. Default is 'tfr ~ 1 + zrpe*phit'. + lower_group : str, optional + Lower-level grouping variable. Default is 'unique_label'. + higher_group : str, optional + Higher-level grouping variable. Default is 'participant'. + trial_key : str, optional + Trial identifier variable. Default is 'trial'. + n_permutations : int, optional + Number of permutations. Default is 100. + n_jobs : int, optional + Number of parallel jobs. Default is -1. + + Returns + ------- + pd.DataFrame + Results DataFrame with all timepoints. """ unique_ts = smoothed_df.ts.unique() diff --git a/LFPAnalysis/sync_utils.py b/LFPAnalysis/sync_utils.py index b6452db..64bedbe 100644 --- a/LFPAnalysis/sync_utils.py +++ b/LFPAnalysis/sync_utils.py @@ -8,9 +8,13 @@ # Might be nice to synergize with https://github.com/alexrockhill/pd-parser to see if there's some improvements to be made -def get_behav_ts(logfile): - """ - Insert custom function to extract the behavioral timestamps from the logfile for your task. +def get_behav_ts(logfile): + """Extract behavioral timestamps from logfile. + + Parameters + ---------- + logfile + Logfile to extract timestamps from. 
""" pass @@ -43,9 +47,22 @@ def moving_average(a, n=11) : # r = c[0, 0] / (np.std(x) * np.std(y)) # return r -def get_neural_ts_photodiode(mne_sync, smoothSize=11, height=0.5): - """ - get neural ts from photodiode +def get_neural_ts_photodiode(mne_sync, smoothSize: int = 11, height: float = 0.5): + """Extract neural timestamps from photodiode signal. + + Parameters + ---------- + mne_sync + MNE sync data object. + smoothSize : int, optional + Smoothing window size. Default is 11. + height : float, optional + Threshold for detecting rising edge. Default is 0.5. + + Returns + ------- + np.ndarray + Neural timestamps. """ sig = np.squeeze(moving_average(mne_sync._data, n=smoothSize)) @@ -60,30 +77,37 @@ def get_neural_ts_photodiode(mne_sync, smoothSize=11, height=0.5): return neural_ts def get_neural_ts_ttl(nev_data): - """ - get neural ts from ttl recording on nlx + """Extract neural timestamps from TTL recording. + + Parameters + ---------- + nev_data : dict + NEV data dictionary containing records. + + Returns + ------- + np.ndarray + Neural timestamps in seconds. """ return nev_data['records']['TimeStamp'][nev_data['records']['ttl']==1] * 1e-6 -def pulsealign(beh_ms=None, - pulses=None, - windSize=15): - """ - Aligns the behavioral timestamps with the EEG pulses by finding the chunks of behavioral pulse times - where the inter-pulse intervals are correlated with the EEG pulses. - +def pulsealign(beh_ms=None, pulses=None, windSize: int = 15): + """Align behavioral timestamps with EEG pulses. + Parameters ---------- - beh_ms (np.ndarray): A vector of ms times extracted from the log file. - pulses (np.ndarray): Vector of EEG pulses extracted from the EEG. - windSize (int): The size of the chunks to step through the recorded sync pulses. Default is 15. - + beh_ms : np.ndarray, optional + Vector of ms times extracted from the log file. + pulses : np.ndarray, optional + Vector of EEG pulses extracted from the EEG. 
+ windSize : int, optional + Size of chunks to step through recorded sync pulses. Default is 15. + Returns ------- - A tuple of two np.ndarrays: - - beh_ms: The truncated beh_ms values that match the eeg_offset. - - eeg_offset: The truncated pulses that match the beh_ms. + tuple + Tuple of (beh_ms, eeg_offset) np.ndarrays. """ # these are parameters that one could potentially tweak.... @@ -139,18 +163,19 @@ def pulsealign(beh_ms=None, return good_beh_ms, eeg_offset def sync_matched_pulses(beh_pulse, neural_pulse): - """ - Compute the slope and offset of the linear regression between two sets of pulse timestamps. - - Parameters: - beh_pulse (array-like): The timestamps of the behavioral pulses. - neural_pulse (array-like): The timestamps of the neural pulses. - - Returns: - tuple: A tuple containing the slope, offset, and correlation coefficient of the linear regression. - - Note: Idea is similar to this: https://github.com/mne-tools/mne-python/blob/main/mne/preprocessing/realign.py#L13-L111 - + """Compute slope and offset of linear regression between pulse timestamps. + + Parameters + ---------- + beh_pulse : array-like + Timestamps of behavioral pulses. + neural_pulse : array-like + Timestamps of neural pulses. + + Returns + ------- + tuple + Tuple containing (slope, offset, rval). """ bfix = beh_pulse[0] res = scipy.stats.linregress(beh_pulse-bfix, neural_pulse) @@ -161,7 +186,27 @@ def sync_matched_pulses(beh_pulse, neural_pulse): return slope, offset, rval -def synchronize_data_robust(beh_ts=None, neural_ts=None, window_size=15, step_size=1, correlation_threshold=0.99): +def synchronize_data_robust(beh_ts=None, neural_ts=None, window_size: int = 15, step_size: int = 1, correlation_threshold: float = 0.99): + """Robustly synchronize behavioral and neural timestamps. + + Parameters + ---------- + beh_ts : array-like, optional + Behavioral timestamps. + neural_ts : array-like, optional + Neural timestamps. + window_size : int, optional + Window size for matching. 
Default is 15. + step_size : int, optional + Step size for iteration. Default is 1. + correlation_threshold : float, optional + Correlation threshold for matching. Default is 0.99. + + Returns + ------- + tuple + Tuple containing (slope, offset, rval). + """ # Calculate differences between consecutive timestamps neural_diff = np.diff(neural_ts) beh_diff = np.diff(beh_ts) @@ -215,33 +260,33 @@ def synchronize_data_robust(beh_ts=None, neural_ts=None, window_size=15, step_si return slope, offset, rval -def synchronize_data(beh_ts=None, mne_sync=None, - smoothSize=11, windSize=15, height=0.5, sync_source='photodiode'): - """ - Synchronize the behavioral timestamps from the logfile and the mne photodiode data and return the slope and offset for the session. - - Parameters: - beh_ts (array-like): The timestamps of the behavioral events. - mne_sync: The MNE photodiode data OR the nev data for TTL (UIowa) - smoothSize (int): The size of the smoothing window for the photodiode data. - windSize (int): The size of the window for pulse alignment. - height (float): The threshold for detecting the rising edge of the photodiode signal. - sync_source (str): the type of signal used to sync the data - - Returns: - tuple: A tuple containing the slope and offset of the linear regression between the behavioral and neural timestamps. - - Raises: - ValueError: If the synchronization fails. - - Note: - - The function uses a moving average filter to smooth the photodiode data. - - The function uses a z-score normalization for the photodiode data. - - The function uses a threshold to detect the rising edge of the photodiode signal. - - The function uses pulse alignment to match the behavioral and neural timestamps. - - The function uses linear regression to compute the slope and offset of the synchronization. - - The function increases the window size for pulse alignment until the correlation coefficient of the linear regression is greater than or equal to 0.99. 
- - The function raises a ValueError if the synchronization fails. +def synchronize_data(beh_ts=None, mne_sync=None, smoothSize: int = 11, windSize: int = 15, height: float = 0.5, sync_source: str = 'photodiode'): + """Synchronize behavioral timestamps with MNE photodiode data. + + Parameters + ---------- + beh_ts : array-like, optional + Timestamps of behavioral events. + mne_sync + MNE photodiode data or NEV data for TTL. + smoothSize : int, optional + Smoothing window size. Default is 11. + windSize : int, optional + Window size for pulse alignment. Default is 15. + height : float, optional + Threshold for detecting rising edge. Default is 0.5. + sync_source : str, optional + Type of signal used to sync data. Default is 'photodiode'. + + Returns + ------- + tuple + Tuple containing (slope, offset). + + Raises + ------ + ValueError + If synchronization fails. """ if isinstance(sync_source, str):