Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update suggest_bads() for OMEGA data #269

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions jumeg/jumeg_suggest_bads.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def clustered_afp(epochs, sensitivity_steps, fraction, mode='adaptive',
return afps, afp_suspects, afp_nearest_neighbour, zlimit_afp


def clustered_psd(epochs, sensitivity_psd, picks, min_samples=1, n_jobs = None):
def clustered_psd(epochs, sensitivity_psd, picks=None, min_samples=1, n_jobs=None):
"""
Perform clustering on PSDs to identify bad channels.

Expand All @@ -145,7 +145,7 @@ def clustered_psd(epochs, sensitivity_psd, picks, min_samples=1, n_jobs = None):
sensitivity_psd: float in range of [0,100]
Percentile to compute threshold used for clustering PSDs,
which must be between 0 and 100 inclusive.
picks: list
picks: None | list
Picks of the channels to be used.
min_samples: int
Number of samples to be chosen for DBSCAN clustering.
Expand Down Expand Up @@ -330,7 +330,7 @@ def plot_autosuggest_summary(afp_nearest_neighbour, psd_nearest_neighbour,

def suggest_bads(raw, sensitivity_steps=97, sensitivity_psd=95,
fraction=0.001, epoch_length=None, summary_plot=False,
show_raw=False, n_jobs = 1, validation=True):
show_raw=False, clipping=1.5, n_jobs=1, validation=True):
"""
Function to suggest bad channels. The bad channels are identified using
time domain methods looking for sharp jumps in short windows of data and
Expand Down Expand Up @@ -365,7 +365,7 @@ def suggest_bads(raw, sensitivity_steps=97, sensitivity_psd=95,

raw = check_read_raw(raw, preload=False)
picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
ecg=False, exclude=[])
ecg=False, ref_meg=False, exclude=[])
# if epoch length is not provided, chose a suitable length
if not epoch_length:
epoch_length = int(raw.n_times/(raw.info['sfreq'] * 20))
Expand All @@ -375,15 +375,14 @@ def suggest_bads(raw, sensitivity_steps=97, sensitivity_psd=95,
duration=epoch_length)
epochs = mne.Epochs(raw, events, event_id=42, tmin=-epoch_length/2,
tmax=epoch_length/2, picks=picks)
picks_bad = [raw.ch_names.index(l) for l in raw.info['bads']]

# compute differences in time domain to identify abrupt jumps in the data
afps, afp_suspects, afp_nearest_neighbour, zlimit_afp = \
clustered_afp(epochs, sensitivity_steps, fraction, n_jobs = n_jobs)

# compute the psds and do the clustering to identify unusual channels
psds, psd_suspects, psd_nearest_neighbour, zlimit_psd = \
clustered_psd(epochs, sensitivity_psd, picks, n_jobs = n_jobs)
clustered_psd(epochs, sensitivity_psd, n_jobs=n_jobs)

# if any of the channels' psds are all zeros, mark as suspect
zero_suspects = [ind for ind in range(psds.shape[1]) if not np.any(psds[:, ind, :])]
Expand All @@ -394,25 +393,26 @@ def suggest_bads(raw, sensitivity_steps=97, sensitivity_psd=95,
[item for sublist in afp_suspects for item in sublist]))

# get the bads suggested but not previosuly marked
picks_fp = [x for x in set(picks_autodetect) if x not in set(picks_bad)]
picks_fp = [x for x in set(picks_autodetect) if epochs.ch_names[x] not in raw.info['bads']]

# marks are all channels of interest, including premarked bad channels
# and zero channels (channel indices)

jumps = list(set([item for sublist in afp_suspects for item in sublist]))
jumps_ch_names = [raw.ch_names[i] for i in jumps]
jumps_ch_names = [epochs.ch_names[i] for i in jumps]
unusual = list(set([item for sublist in psd_suspects for item in sublist]))
unusual_ch_names = [raw.ch_names[i] for i in unusual]
dead_ch_names = [raw.ch_names[i] for i in zero_suspects]
unusual_ch_names = [epochs.ch_names[i] for i in unusual]
dead_ch_names = [epochs.ch_names[i] for i in zero_suspects]

print("Suggested bads [jumps]:", jumps_ch_names)
print("Suggested bads [unusual]:", unusual_ch_names)
print("Suggested bads [dead]:", dead_ch_names)

picks_bad = [epochs.ch_names.index(x) for x in raw.info['bads'] if x in epochs.ch_names]
marks = list(set(picks_autodetect) | set(picks_bad) | set(zero_suspects))

# show summary plot for enhanced manual inspection
#TODO zero suspects do not have any colour coding for the moment
# TODO zero suspects do not have any colour coding for the moment
if summary_plot:
fig = \
plot_autosuggest_summary(afp_nearest_neighbour, psd_nearest_neighbour,
Expand All @@ -423,13 +423,17 @@ def suggest_bads(raw, sensitivity_steps=97, sensitivity_psd=95,
fig.show()

# channel names in str
suggested = [raw.ch_names[i] for i in marks]
# add suggested channels to the raw.info
raw.info['bads'] = suggested
print('\nOriginal bad channels: ', raw.info['bads'])
suggested = [epochs.ch_names[i] for i in marks if i not in picks_bad]
suggested.sort()
print('Suggested bad channels: ', suggested)

# add suggested channels to the raw.info
raw.info['bads'] += suggested
raw.info['bads'].sort()

if show_raw:
raw.plot(block=True)
raw.plot(block=True, clipping=clipping)
visual = raw.info['bads']
visual.sort()
print('Bad channels after visual inspection: ', visual)
Expand Down