Merge pull request #76 from kymata-atlas/ollie
Gridsearch merge
neukym authored Jan 12, 2024
2 parents 175c308 + 9cd5eb1 commit 3a48a46
Showing 16 changed files with 668 additions and 12 deletions.
95 changes: 95 additions & 0 deletions invokers/run_gridsearch.py
@@ -0,0 +1,95 @@
from pathlib import Path
import argparse
from kymata.gridsearch.gridsearch import do_gridsearch
from kymata.io.functions import load_function
from kymata.io.mne import load_emeg_pack


def main():

parser = argparse.ArgumentParser(description='Gridsearch Params')
parser.add_argument('--emeg_sample_rate', type=int, default=1000,
help='sampling rate of the EMEG recording in Hz (resampling not yet implemented)')
parser.add_argument('--snr', type=float, default=3,
help='inverse solution SNR')
parser.add_argument('--downsample_rate', type=int, default=5,
help='factor by which to downsample the EMEG to match the function sample rate')
parser.add_argument('--base_dir', type=str, default="/imaging/projects/cbu/kymata/data/dataset_4-english-narratives/",
help='base data directory')
parser.add_argument('--data_path', type=str, default="intrim_preprocessing_files/3_trialwise_sensorspace/evoked_data",
help='data path after base dir')
parser.add_argument('--function_path', type=str, default="predicted_function_contours/GMSloudness/stimulisig",
help='path to the function contours, relative to the base directory')
parser.add_argument('--function_name', type=str, default="d_IL2",
help='function name in stimulisig')
parser.add_argument('--emeg_file', type=str, default="participant_01-ave",
help='name of the EMEG file to load')
parser.add_argument('--ave_mode', type=str, default="ave",
help="either 'ave' or 'add': average over the list of repetitions, or treat them as extra data")
parser.add_argument('--inverse_operator', type=str, default="intrim_preprocessing_files/4_hexel_current_reconstruction/inverse-operators",
help='path to the inverse operators, relative to the base directory')
parser.add_argument('--seconds_per_split', type=float, default=0.5,
help='seconds in each split of the recording; also the maximum range of latencies checked')
parser.add_argument('--n_splits', type=int, default=800,
help='number of splits to divide the recording into (set to 400 / seconds_per_split to use the full file)')
parser.add_argument('--n_derangements', type=int, default=1,
help='number of derangements used to build the null distribution')
parser.add_argument('--start_latency', type=float, default=-100,
help='earliest latency to check in the cross-correlation')
parser.add_argument('--emeg_t_start', type=float, default=-200,
help='start of the emeg evoked files relative to the start of the function')
parser.add_argument('--audio_shift_correction', type=float, default=0.000_537_5,
help='audio drift correction: shift the start of each successive EMEG split forward by this many seconds per second of EMEG seen')
args = parser.parse_args()
args.base_dir = Path(args.base_dir)


emeg_dir = Path(args.base_dir, args.data_path)
emeg_paths = [Path(emeg_dir, args.emeg_file)]

participants = ['participant_01',
'participant_01b',
'participant_02',
'participant_03',
'participant_04',
'participant_05',
'pilot_01',
'pilot_02']

reps = [f'_rep{i}' for i in range(8)] + ['-ave']

# emeg_paths = [Path(emeg_dir, p + r) for p in participants[:2] for r in reps[-1:]]

inverse_operator = Path(args.base_dir, args.inverse_operator, f"{participants[0]}_ico5-3L-loose02-cps-nodepth.fif")

# Load data
emeg, ch_names = load_emeg_pack(emeg_paths,
need_names=False,
ave_mode=args.ave_mode,
inverse_operator=None,  # None runs in sensor space; pass `inverse_operator` to run in source space
p_tshift=None,
snr=args.snr)

func = load_function(Path(args.base_dir, args.function_path),
func_name=args.function_name,
bruce_neurons=(5, 10))
func = func.downsampled(args.downsample_rate)

es = do_gridsearch(
emeg_values=emeg,
sensor_names=ch_names,
function=func,
seconds_per_split=args.seconds_per_split,
n_derangements=args.n_derangements,
n_splits=args.n_splits,
start_latency=args.start_latency,
emeg_t_start=args.emeg_t_start,
emeg_sample_rate=args.emeg_sample_rate,
audio_shift_correction=args.audio_shift_correction,
ave_mode=args.ave_mode,
)

# expression_plot(es)

if __name__ == '__main__':
main()
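To make the drift correction concrete, here is a small illustrative sketch (not part of the commit) of how --audio_shift_correction stretches the split start times, mirroring the split_initial_timesteps computation in do_gridsearch below:

# Illustrative only: accumulated drift with the default correction factor
start_latency = -100                  # ms
emeg_t_start = -200                   # ms
seconds_per_split = 0.5
audio_shift_correction = 0.000_537_5  # seconds per second of EMEG

for i in (0, 1, 400, 799):
    start = int(start_latency
                + round(i * 1000 * seconds_per_split * (1 + audio_shift_correction))
                - emeg_t_start)
    print(f"split {i:3d} starts at sample {start}")

# By split 799 the correction has shifted the window by
# 799 * 500 ms * 0.0005375 ≈ 215 ms relative to an uncorrected split.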
33 changes: 33 additions & 0 deletions invokers/run_preprocessing_dataset4.py
@@ -0,0 +1,33 @@
from pathlib import Path

import sys
sys.path.append('/imaging/projects/cbu/kymata/analyses/ollie/kymata-toolbox')

from kymata.io.yaml import load_config
from kymata.preproc.data_cleansing import run_preprocessing, create_trials, create_trialwise_data


# noinspection DuplicatedCode
def main():
config = load_config(str(Path(Path(__file__).parent.parent, "kymata", "config", "dataset4.yaml")))

create_trialwise_data(
dataset_directory_name=config['dataset_directory_name'],
list_of_participants=config['list_of_participants'],
repetitions_per_runs=config['repetitions_per_runs'],
number_of_runs=config['number_of_runs'],
number_of_trials=config['number_of_trials'],
input_streams=config['input_streams'],
eeg_thresh=float(config['eeg_thresh']),
grad_thresh=float(config['grad_thresh']),
mag_thresh=float(config['mag_thresh']),
visual_delivery_latency=config['visual_delivery_latency'],
audio_delivery_latency=config['audio_delivery_latency'],
audio_delivery_shift_correction=config['audio_delivery_shift_correction'],
tmin=config['tmin'],
tmax=config['tmax'],
)


if __name__ == '__main__':
main()
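For reference, a minimal sketch of the configuration keys this invoker reads from kymata/config/dataset4.yaml — the values below are illustrative placeholders, not the real config:

# Illustrative subset of kymata/config/dataset4.yaml
dataset_directory_name: dataset_4-english-narratives
list_of_participants:
  - participant_01
  - participant_02
repetitions_per_runs: 1          # illustrative
number_of_runs: 8                # illustrative
number_of_trials: 400            # illustrative
input_streams: [auditory]        # illustrative
eeg_thresh: 1                    # 200e-6
grad_thresh: 1                   # 4000e-13
mag_thresh: 1                    # illustrative
visual_delivery_latency: 17      # milliseconds
audio_delivery_latency: 26       # milliseconds
audio_delivery_shift_correction: 0.0005375  # seconds per second
tmin: -0.2                       # seconds
tmax: 1.8                        # seconds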
6 changes: 3 additions & 3 deletions kymata/config/dataset3.yaml
@@ -45,9 +45,9 @@ supress_excessive_plots_and_prompts: True
audio_delivery_latency: 26 # in milliseconds
visual_delivery_latency: 17 # in milliseconds
tactile_delivery_latency: 0 # in milliseconds
audio_delivery_shift_correction: 0 # in milliseconds per second
tmin: -0.2
tmax: 1.8
audio_delivery_shift_correction: 0 # in seconds per second
tmin: -0.2 # seconds
tmax: 1.8 # seconds

eeg_thresh: 1 #200e-6
grad_thresh: 1 #4000e-13
6 changes: 3 additions & 3 deletions kymata/config/dataset4.yaml
100644 → 100755
@@ -36,9 +36,9 @@ meg: True
audio_delivery_latency: 26 # in milliseconds
visual_delivery_latency: 17 # in milliseconds
tactile_delivery_latency: 0 # in milliseconds
audio_delivery_shift_correction: 0.5375 # in milliseconds per second
tmin: -0.2
tmax: 1.8
audio_delivery_shift_correction: 0.0005375 # in seconds per second
tmin: -0.2 # seconds
tmax: 1.8 # seconds

eeg_thresh: 1 #200e-6
grad_thresh: 1 #4000e-13
21 changes: 21 additions & 0 deletions kymata/entities/functions.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from numpy.typing import NDArray


@dataclass
class Function:
name: str
values: NDArray
sample_rate: float # Hertz

def downsampled(self, rate: int):
return Function(
name=self.name,
values=self.values[::rate],
sample_rate=self.sample_rate / rate,
)

@property
def time_step(self) -> float:
return 1 / self.sample_rate
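For context, a minimal usage sketch of this dataclass (hypothetical values):

import numpy as np
from kymata.entities.functions import Function

f = Function(name="d_IL2", values=np.arange(10_000, dtype=float), sample_rate=1000)
g = f.downsampled(5)           # keep every 5th sample
assert g.sample_rate == 200    # 1000 Hz / 5
assert g.time_step == 0.005    # seconds between samples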
5 changes: 3 additions & 2 deletions kymata/entities/iterables.py
@@ -6,6 +6,7 @@
from typing import Sequence

from numpy import ndarray
from numpy.typing import NDArray


def all_equal(sequence: Sequence) -> bool:
@@ -20,10 +20,10 @@ def all_equal(sequence: Sequence) -> bool:
# numpy arrays deal with equality weirdly

# Check first two items are equal, and equal to the rest
first: ndarray = sequence[0]
first: NDArray = sequence[0]
if not isinstance(sequence[1], ndarray):
return False
second: ndarray = sequence[1]
second: NDArray = sequence[1]
try:
# noinspection PyUnresolvedReferences
return (first == second).all() and all_equal(sequence[1:])
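As an aside, a quick illustration (not from the commit) of the numpy equality quirk the code above guards against:

import numpy as np

a = np.array([1, 2, 3])
b = np.array([1, 2, 3])
print(a == b)          # [ True  True  True] — an elementwise array, not one bool
print((a == b).all())  # True — this is what all_equal relies on

c = np.array([1, 2])
# Comparing mismatched shapes yields a bare False or raises, depending on the
# numpy version — hence the try/except guard in all_equal.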
Empty file modified kymata/entities/sparse_data.py
100644 → 100755
Empty file.
Empty file added kymata/gridsearch/__init__.py
Empty file.
190 changes: 190 additions & 0 deletions kymata/gridsearch/gridsearch.py
@@ -0,0 +1,190 @@
import numpy as np
from numpy.typing import NDArray
from scipy import stats

from kymata.entities.functions import Function
from kymata.math.combinatorics import generate_derangement
from kymata.math.vector import normalize, get_stds
#from kymata.entities.expression import SensorExpressionSet, p_to_logp
import matplotlib.pyplot as plt

def do_gridsearch(
emeg_values: NDArray,  # chan x reps x time
function: Function,
sensor_names: list[str],
start_latency: float, # ms
emeg_t_start: float, # ms
emeg_sample_rate: int = 1000, # Hertz
audio_shift_correction: float = 0.000_537_5,  # seconds/second: successive split start times are stretched by (1 + this factor)
n_derangements: int = 1,
seconds_per_split: float = 0.5,
n_splits: int = 800,
ave_mode: str = 'ave', # either ave or add, for averaging over input files or adding in as extra evidence
add_autocorr: bool = True,
plot_name: str = 'example'
):
"""
Do the Kymata gridsearch over all hexels for all latencies.
"""

# We'll need to downsample the EMEG to match the function's sample rate
downsample_rate: int = int(emeg_sample_rate / function.sample_rate)

n_samples_per_split = int(seconds_per_split * emeg_sample_rate * 2 // downsample_rate)

if ave_mode == 'add':
    # one repetition per loaded EMEG file (second axis of emeg_values)
    n_reps = emeg_values.shape[1]
else:
    n_reps = 1

func_length = n_splits * n_samples_per_split // 2
if func_length < function.values.shape[0]:
func = function.values[:func_length].reshape(n_splits, n_samples_per_split // 2)
print(f'WARNING: not using full 400s of the file (only using {round(n_splits * seconds_per_split, 2)}s)')
else:
func = function.values.reshape(n_splits, n_samples_per_split // 2)
n_channels = emeg_values.shape[0]

# Reshape EMEG into splits of `seconds_per_split` s
split_initial_timesteps = [int(start_latency + round(i * 1000 * seconds_per_split * (1 + audio_shift_correction)) - emeg_t_start)
for i in range(n_splits)]

emeg_reshaped = np.zeros((n_channels, n_splits * n_reps, n_samples_per_split))
for j in range(n_reps):
for split_i in range(n_splits):
split_start = split_initial_timesteps[split_i]
split_stop = split_start + int(2 * emeg_sample_rate * seconds_per_split)
emeg_reshaped[:, split_i + (j * n_splits), :] = emeg_values[:, j, split_start:split_stop:downsample_rate]

del emeg_values

# Derangements for null distribution
derangements = np.zeros((n_derangements, n_splits * n_reps), dtype=int)
for der_i in range(n_derangements):
derangements[der_i, :] = generate_derangement(n_splits * n_reps, n_splits)
derangements = np.vstack((np.arange(n_splits * n_reps), derangements)) # Include the identity on top

# Fast cross-correlation using FFT
emeg_reshaped = normalize(emeg_reshaped)
emeg_stds = get_stds(emeg_reshaped, n_samples_per_split // 2)
emeg_reshaped = np.fft.rfft(emeg_reshaped, n=n_samples_per_split, axis=-1)
F_func = np.conj(np.fft.rfft(normalize(func), n=n_samples_per_split, axis=-1))
corrs = np.zeros((n_channels, n_derangements + 1, n_splits * n_reps, n_samples_per_split // 2))
for der_i, derangement in enumerate(derangements):
deranged_emeg = emeg_reshaped[:, derangement, :]
corrs[:, der_i] = np.fft.irfft(deranged_emeg * F_func)[:, :, :n_samples_per_split//2] / emeg_stds[:, derangement]
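# Why this works: by the correlation theorem, irfft(rfft(x) * conj(rfft(y)))
# computes the circular cross-correlation of x and y in O(n log n) rather than
# O(n^2). Because the FFT length n_samples_per_split is twice the function
# length, the first half of the output holds non-wrapped lags 0..n/2 - 1 —
# exactly the latencies being scanned — and dividing by the windowed standard
# deviations in emeg_stds rescales each lagged product back toward a
# correlation coefficient.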

if add_autocorr:
    # Optionally perturb the function with noise before auto-correlating;
    # the `* 0` multiplier currently disables the noise
    noise = normalize(np.random.randn(func.shape[0], func.shape[1])) * 0
    noisy_func = normalize(np.copy(func)) + noise
    nn = n_samples_per_split // 2

    F_noisy_func = np.fft.rfft(normalize(noisy_func), n=nn, axis=-1)
    F_func = np.conj(np.fft.rfft(normalize(func), n=nn, axis=-1))

    auto_corrs = np.fft.irfft(F_noisy_func * F_func)

del F_func, deranged_emeg, emeg_reshaped

log_pvalues = _ttest(corrs)

latencies = np.linspace(start_latency, start_latency + (seconds_per_split * 1000), n_samples_per_split // 2 + 1)[:-1]

if plot_name:
plt.figure(1)
corr_avrs = np.mean(corrs[:, 0]**2, axis=-2)
maxs = np.max(corr_avrs, axis=1)
n_amaxs = 5
amaxs = np.argpartition(maxs, -n_amaxs)[-n_amaxs:]
amax = np.argmax(corr_avrs) // (n_samples_per_split // 2)
amaxs = [i for i in amaxs if i != amax] # + [209]

plt.plot(latencies, np.mean(corrs[amax, 0], axis=-2).T, 'r-', label=amax)
plt.plot(latencies, np.mean(corrs[amaxs, 0], axis=-2).T, label=amaxs)
std_null = np.mean(np.std(corrs[:, 1], axis=-2), axis=0).T * 3 / np.sqrt(n_reps * n_splits) # 3 pop std.s
std_real = np.std(corrs[amax, 0], axis=-2).T * 3 / np.sqrt(n_reps * n_splits)
av_real = np.mean(corrs[amax, 0], axis=-2).T
#print(std_null)
plt.fill_between(latencies, -std_null, std_null, alpha=0.5, color='grey')
plt.fill_between(latencies, av_real - std_real, av_real + std_real, alpha=0.25, color='red')

if add_autocorr:
peak_lat_ind = np.argmax(corr_avrs) % (n_samples_per_split // 2)
peak_lat = latencies[peak_lat_ind]
peak_corr = np.mean(corrs[amax, 0], axis=-2)[peak_lat_ind]
print(f'{function.name}: peak lat, peak corr, ind:', peak_lat, peak_corr, amax)

auto_corrs = np.mean(auto_corrs, axis=0)
plt.plot(latencies, np.roll(auto_corrs, peak_lat_ind) * peak_corr / np.max(auto_corrs), 'k--', label='func auto-corr')

plt.axvline(0, color='k')
plt.legend()
plt.xlabel('latencies (ms)')
plt.ylabel('Corr coef.')
plt.savefig(f'{plot_name}_1.png')
plt.clf()

plt.figure(2)
plt.plot(latencies, -log_pvalues[amax].T, 'r-', label=amax)
plt.plot(latencies, -log_pvalues[amaxs].T, label=amaxs)
plt.axvline(0, color='k')
plt.legend()
plt.xlabel('latencies (ms)')
plt.ylabel('-log10(p-value)')
plt.savefig(f'{plot_name}_2.png')
plt.clf()

# TODO: construct and return a SensorExpressionSet from these results
# (requires the commented-out SensorExpressionSet import above):
# es = SensorExpressionSet(
#     functions=function.name,
#     latencies=latencies / 1000,
#     sensors=sensor_names,
#     data=log_pvalues,
# )
# return es
return None


def _ttest(corrs: NDArray, use_all_lats: bool = True):
"""
Vectorised Welch's t-test.
"""
n_channels, n_derangements, n_splits, t_steps = corrs.shape

# Fisher Z-Transformation
corrs_z = 0.5 * np.log((1 + corrs) / (1 - corrs))

# Non-deranged values are on top
true_mean = np.mean(corrs_z[:, 0, :, :], axis=1)
true_var = np.var(corrs_z[:, 0, :, :], axis=1, ddof=1)
true_n = n_splits

# Recompute mean and var for null correlations
# TODO: why looking at only 1 in the n_derangements dimension?
if use_all_lats:
rand_mean = np.mean(corrs_z[:, 1:, :, :].reshape(n_channels, -1), axis=1).reshape(n_channels, 1)
rand_var = np.var(corrs_z[:, 1:, :, :].reshape(n_channels, -1), axis=1, ddof=1).reshape(n_channels, 1)
rand_n = n_splits * n_derangements * t_steps
else:
rand_mean = np.mean(corrs_z[:, 1:, :, :].reshape(n_channels, -1, t_steps), axis=1)
rand_var = np.var(corrs_z[:, 1:, :, :].reshape(n_channels, -1, t_steps), axis=1, ddof=1)
rand_n = n_splits * n_derangements

# Vectorized two-sample t-tests for all channels and time steps
numerator = true_mean - rand_mean
denominator = np.sqrt(true_var / true_n + rand_var / rand_n)
df = ((true_var / true_n + rand_var / rand_n) ** 2 /
((true_var / true_n) ** 2 / (true_n - 1) +
(rand_var / rand_n) ** 2 / (rand_n - 1)))

t_stat = numerator / denominator

if np.min(df) <= 300:
log_p = np.log(stats.t.sf(np.abs(t_stat), df) * 2) # two-tailed p-value
else:
    # the normal distribution is an excellent approximation at high df
    # (scipy's t.logsf is not computed in log space, so it would underflow)
    log_p = stats.norm.logsf(np.abs(t_stat)) + np.log(2)

return log_p / np.log(10)  # convert natural log to log base 10
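As a sanity check (not part of the commit), the vectorised statistic above can be compared against scipy's reference Welch implementation for a single channel and time step:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
true_z = rng.normal(0.1, 0.05, size=800)  # Fisher-z'd correlations, one condition
null_z = rng.normal(0.0, 0.05, size=800)

# Reference: Welch's t-test (unequal variances)
t_ref, p_ref = stats.ttest_ind(true_z, null_z, equal_var=False)

# Hand-rolled statistic matching _ttest's formulas
num = true_z.mean() - null_z.mean()
den = np.sqrt(true_z.var(ddof=1) / 800 + null_z.var(ddof=1) / 800)
assert np.isclose(num / den, t_ref)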