Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 158 additions & 87 deletions LFPAnalysis/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,22 @@


# There are some things that MNE is not that good at, or simply does not do. Let's write our own code for these.
def select_rois_picks(elec_data, chan_name, manual_col='collapsed_manual'):
def select_rois_picks(elec_data: pd.DataFrame, chan_name: str, manual_col: str = 'collapsed_manual'):
"""Select ROI for specific channel.

"""
Grab specific roi for the channel you are looking at
Parameters
----------
elec_data : pd.DataFrame
Electrode data DataFrame.
chan_name : str
Channel name.
manual_col : str, optional
Manual column name. Default is 'collapsed_manual'.

Returns
-------
str
ROI label.
"""

# Load the YBA ROI labels, custom assigned by Salman:
Expand Down Expand Up @@ -105,9 +117,20 @@ def select_rois_picks(elec_data, chan_name, manual_col='collapsed_manual'):

return roi

def select_picks_rois(elec_data, roi=None):
"""
Grab specific electrodes that you care about
def select_picks_rois(elec_data: pd.DataFrame, roi=None):
"""Select electrodes for specific ROI.

Parameters
----------
elec_data : pd.DataFrame
Electrode data DataFrame.
roi : str or list, optional
ROI name or list of ROI names.

Returns
-------
list
List of electrode labels.
"""

# Site specific processing:
Expand Down Expand Up @@ -142,16 +165,27 @@ def select_picks_rois(elec_data, roi=None):

return picks

def lfp_sta(ev_times, signal, sr, pre, post):
'''
Compute the STA for a vector of stimuli.

Input:
spikes - raw spike times used to compute STA, should be in s
signal - signal for averaging. can be filtered or unfiltered.
bound - bound of the STA in ms, +- this number
def lfp_sta(ev_times: np.ndarray, signal: np.ndarray, sr: float, pre: float, post: float):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Scope Error: Variable Name Mismatch

The variable elec_df is used but the function parameter is named elec_data. This causes a NameError since elec_df is not defined in the function scope. The code should use elec_data instead of elec_df.

Fix in Cursor Fix in Web

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Naming Error Prevents Code Execution

The variable elec_df is used but the function parameter is named elec_data. This causes a NameError since elec_df is not defined in the function scope. The code should use elec_data instead of elec_df.

Fix in Cursor Fix in Web

"""Compute spike-triggered average for LFP signal.

Parameters
----------
ev_times : np.ndarray
Event times in seconds.
signal : np.ndarray
Signal for averaging.
sr : float
Sampling rate.
pre : float
Pre-event window in seconds.
post : float
Post-event window in seconds.

'''
Returns
-------
tuple
Tuple containing (sta, ste).
"""

num_evs = len(ev_times)
ev_in_samples = (ev_times * sr).astype(int)
Expand All @@ -175,10 +209,28 @@ def lfp_sta(ev_times, signal, sr, pre, post):
return sta, ste


def plot_TFR(data, freqs, pre_win, post_win, sr, title):
"""

pre_win should be in seconds
def plot_TFR(data: np.ndarray, freqs: np.ndarray, pre_win: float, post_win: float, sr: float, title: str):
"""Plot time-frequency representation.

Parameters
----------
data : np.ndarray
TFR data array.
freqs : np.ndarray
Frequency array.
pre_win : float
Pre-window in seconds.
post_win : float
Post-window in seconds.
sr : float
Sampling rate.
title : str
Plot title.

Returns
-------
matplotlib.figure.Figure
Figure object.
"""

f, tfr = plt.subplots(1, 1, figsize=[7, 4], dpi=300)
Expand All @@ -199,16 +251,28 @@ def plot_TFR(data, freqs, pre_win, post_win, sr, title):

return f

def detect_fast_burst_evs(mne_data,
baseline_data,
burst_frequency = (70, 200),
smooth_win_s=0.02,
sd_upper_cutoff=6,
sd_lower_cutoff=1):
"""
def detect_fast_burst_evs(mne_data, baseline_data, burst_frequency: tuple = (70, 200), smooth_win_s: float = 0.02, sd_upper_cutoff: float = 6, sd_lower_cutoff: float = 1):
"""Detect fast burst events in HFA band.

HFA band: 70-200 Hz
Ripple range: 80-120
Parameters
----------
mne_data
MNE epochs object.
baseline_data
Baseline MNE epochs object.
burst_frequency : tuple, optional
Frequency range for burst detection. Default is (70, 200).
smooth_win_s : float, optional
Smoothing window in seconds. Default is 0.02.
sd_upper_cutoff : float, optional
Upper SD cutoff. Default is 6.
sd_lower_cutoff : float, optional
Lower SD cutoff. Default is 1.

Returns
-------
dict
Dictionary of burst events per channel.
"""


Expand Down Expand Up @@ -534,39 +598,35 @@ def detect_fast_burst_evs(mne_data,
# # then reject any electrode with a low ripple count (< 20 ripples detected per electrode per task) or high rejection rate (greater than 30% rejection rate)
# return allts, ripple_categories, ripple_psds

def FOOOF_continuous(signal):
"""
TODO
def FOOOF_continuous(signal: np.ndarray):
"""Compute FOOOF on continuous signal.

Parameters
----------
signal : np.ndarray
Continuous signal.
"""
pass


def FOOOF_compute_epochs(epochs, tmin=0, tmax=1.5, **kwargs):
"""

This function is meant to enable easy computation of FOOOF.

def FOOOF_compute_epochs(epochs, tmin: float = 0, tmax: float = 1.5, **kwargs):
"""Compute FOOOF on epoched data.

Parameters
----------
epochs : mne Epochs object
mne object

tmin : time to start (s)
float

tmax : time to end (s)
float

band_dict : definitions of the bands of interest
dict

kwargs : input arguments to the FOOOFGroup function, including: 'min_peak_height', 'peak_threshold', 'max_n_peaks'
dict

epochs
MNE Epochs object.
tmin : float, optional
Start time in seconds. Default is 0.
tmax : float, optional
End time in seconds. Default is 1.5.
**kwargs
Additional FOOOFGroup arguments.

Returns
-------
mne_data_reref : mne object
mne object with re-referenced data
tuple
Tuple containing (FOOOFGroup_res, pd.DataFrame).
"""

# bands = fooof.bands.Bands(band_dict)
Expand Down Expand Up @@ -825,39 +885,40 @@ def FOOOF_compute_epochs(epochs, tmin=0, tmax=1.5, **kwargs):

# We put all of our basic FOOOF usage into a slightly clunky function that is meant to be used for running the regression
# over multiple channels in parallel using joblib/Dask/multiprocessing.Pool:
def compute_FOOOF_parallel(chan_name, MNE_object, subj_id, elec_df, event_name, ev_dict, band_dict, conditions,
do_plot=False, save_path='/sc/arion/projects/guLab/Salman/EphysAnalyses',
do_save=False, **kwargs):
"""
Compute FOOOF for a single channel across all trials and for each condition of interest.
Meant to be used in parallel, hence a little clunky.

def compute_FOOOF_parallel(chan_name: str, MNE_object, subj_id: str, elec_df: pd.DataFrame, event_name: str, ev_dict: dict, band_dict: dict, conditions: list, do_plot: bool = False, save_path: str = '/sc/arion/projects/guLab/Salman/EphysAnalyses', do_save: bool = False, **kwargs):
"""Compute FOOOF for single channel in parallel.

Parameters
----------
----------
chan_name : str
Name of the channel to compute FOOOF for
MNE_object : mne.Epochs
MNE object containing the data
Channel name.
MNE_object
MNE Epochs object.
subj_id : str
Subject ID
Subject ID.
elec_df : pd.DataFrame
DataFrame containing the electrode information
event : str
Event to compute FOOOF for
Electrode DataFrame.
event_name : str
Event name.
ev_dict : dict
Dictionary containing the start and end times for each event
Event time dictionary.
band_dict : dict
Dictionary containing the frequency bands to compute FOOOF for
Frequency band dictionary.
conditions : list
List of conditions to compute FOOOF for
do_plot : bool
Whether to plot the FOOOF results
save_path : str
Path to save the FOOOF results
do_save : bool
Whether to save the FOOOF results
**kwargs : dict
Additional arguments to pass to FOOOF_compute_epochs
List of conditions.
do_plot : bool, optional
Whether to plot. Default is False.
save_path : str, optional
Save path. Default is '/sc/arion/projects/guLab/Salman/EphysAnalyses'.
do_save : bool, optional
Whether to save. Default is False.
**kwargs
Additional FOOOF arguments.

Returns
-------
pd.DataFrame or None
Results DataFrame if not saving.
"""

# First, compute FOOOF across all trials
Expand Down Expand Up @@ -917,19 +978,29 @@ def compute_FOOOF_parallel(chan_name, MNE_object, subj_id, elec_df, event_name,
return chan_df


def sliding_FOOOF(signal):
"""
Implement time-resolved FOOOF:
https://github.com/lucwilson/SPRiNT now has a python implementation we can borrow from!
def sliding_FOOOF(signal: np.ndarray):
"""Compute time-resolved FOOOF.


Parameters
----------
signal : np.ndarray
Signal array.
"""
pass


def hctsa_signal_features(signal):
"""
Implement https://github.com/DynamicsAndNeuralSystems/catch22
def hctsa_signal_features(signal: np.ndarray):
"""Extract catch22 signal features.

Parameters
----------
signal : np.ndarray
Signal array.

Returns
-------
pd.DataFrame
DataFrame with signal features.
"""

signal_features = pycatch22.catch22_all(signal)
Expand Down
43 changes: 33 additions & 10 deletions LFPAnalysis/iowa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
from itertools import chain
from LFPAnalysis import lfp_preprocess_utils

def extract_names_connect_table(connect_table_path):
"""
Utility function for extracting channel types from Iowa connection table
def extract_names_connect_table(connect_table_path: str):
"""Extract channel types from Iowa connection table.

Parameters
----------
connect_table_path : str
Path to the connection table CSV file.

Returns
-------
tuple
Tuple containing (eeg_names, resp_names, ekg_names, seeg_names, drop_names).
"""

connect_table = pd.read_csv(connect_table_path)
Expand Down Expand Up @@ -95,10 +104,18 @@ def extract_names_connect_table(connect_table_path):

return eeg_names, resp_names, ekg_names, seeg_names, drop_names

def extract_names_elec_table(elec_table_path):
"""
In some instances we just have Kiril's electrode table. In this case, we need a different extractor for
the data
def extract_names_elec_table(elec_table_path: str):
"""Extract channel names from electrode table.

Parameters
----------
elec_table_path : str
Path to the electrode table file.

Returns
-------
list
List of sEEG channel names.
"""

elec_data = lfp_preprocess_utils.load_elec(elec_table_path, site='UI')
Expand Down Expand Up @@ -134,9 +151,15 @@ def extract_names_elec_table(elec_table_path):

# return mapping_name

def rename_mne_channels(mne_data, location_table_path):
def rename_mne_channels(mne_data, location_table_path: str):
"""Rename MNE channels based on location table.

Parameters
----------
mne_data
MNE data object.
location_table_path : str
Path to the location table CSV file.
"""
"""

location_table = pd.read_csv(location_table_path)

Loading
Loading