Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gridsearch nkg #127

Merged
merged 42 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
8071bf7
Style
caiw Jan 15, 2024
77906f2
Gridsearch saves nkg files and uses standard expression-plotting
caiw Jan 15, 2024
8e311fc
Better error logging
caiw Jan 15, 2024
0340f30
Clean up style
caiw Jan 15, 2024
025cfe1
Try force getting names
caiw Jan 15, 2024
20e48be
Unpair axes on sensor gridsearch
caiw Jan 15, 2024
bbf013c
Optionally overwrite existing files
caiw Jan 15, 2024
d76b735
Fixed a bug which caused eroneous FileExistErrors when saving figures
caiw Jan 15, 2024
0dc42fb
Review fix: Move default output location to 'output' dir
caiw Jan 15, 2024
729c475
Review fix: move Ollie's paths out of argument defaults and into the …
caiw Jan 15, 2024
887748e
More usual to have hyphens rather than underscores in cli args
caiw Jan 15, 2024
8ad3a1e
Review fix: add reality check plots back in
caiw Jan 15, 2024
8784099
sample invoker arguments in a script
caiw Jan 15, 2024
a4d58ca
Save reality-check plots in output location
caiw Jan 15, 2024
de0f5e0
Bugfix
caiw Jan 15, 2024
2c2c85c
Review fix: create `output` dir if it doesn't already exist
caiw Jan 15, 2024
cc2eb95
Add inverse operator into invoker script
caiw Jan 15, 2024
679e561
Progress bar code (currently unused)
caiw Jan 15, 2024
9d08231
Comment update
caiw Jan 15, 2024
5f9636d
Default sample data location needs to be created if it doesn't alread…
caiw Jan 15, 2024
0bc50f9
Reorder lines for clarity
caiw Jan 15, 2024
0218aec
Move all plotting to plot.py (#130)
neukym Jan 15, 2024
5e0fa44
Bug fix for gridsearch figure
neukym Jan 20, 2024
220fd5c
Update plot.py
neukym Jan 20, 2024
d23c778
Update mne.py
neukym Jan 20, 2024
feac249
Update run_gridsearch.py
neukym Jan 20, 2024
ea6d6ef
Update mne.py
neukym Jan 20, 2024
5e6b70c
Update mne.py
neukym Jan 20, 2024
813d5e1
Fix hexel number bug
neukym Jan 21, 2024
1eabf8b
Update run_gridsearch.py
neukym Jan 21, 2024
68262ed
add function name to begining of outputs
neukym Jan 22, 2024
991f2e1
Got poetry to work with CBU slurm setup
neukym Jan 22, 2024
819ea81
Adds 'time taken' to output
neukym Jan 22, 2024
0ab9c75
Minor changes
neukym Jan 24, 2024
087ec5f
couple of super minor changes
Jan 24, 2024
16064c7
a couple more changes, added in inverse_operator_name
Jan 24, 2024
d3d6dec
Removes "requirements" from invoker
neukym Jan 26, 2024
1229f89
Add s ability to do both hemipheres when inverse operator is selected
neukym Jan 26, 2024
67a23ec
Remove comment from plain
neukym Jan 26, 2024
f71c148
Remove hard-coded path to Andy's toolbox install
caiw Jan 26, 2024
dc0433d
bugfix
neukym Jan 27, 2024
346a4d1
Update submit_gridsearch.sh
neukym Jan 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions invokers/run_gridsearch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from pathlib import Path
import argparse

from kymata.gridsearch.plain import do_gridsearch
from kymata.io.functions import load_function
from kymata.io.mne import load_emeg_pack
from kymata.io.nkg import save_expression_set
from kymata.plot.plot import expression_plot


def main():
Expand All @@ -20,6 +23,11 @@ def main():
help='data path after base dir')
parser.add_argument('--function_path', type=str, default="predicted_function_contours/GMSloudness/stimulisig",
help='snr')
parser.add_argument('--save-expression-set', type=Path, default="gridsearch.nkg",
help="Save the results of the gridsearch into an ExpressionSet .nkg file")
parser.add_argument('--save-plot', type=Path, default="gridsearch.png",
help="Save an expression plot file")
parser.add_argument('--overwrite', action="store_true", help="Silently overwrite existing files.")
parser.add_argument('--function_name', type=str, default="d_IL2",
help='function name in stimulisig')
parser.add_argument('--emeg_file', type=str, default="participant_01-ave",
Expand Down Expand Up @@ -61,23 +69,27 @@ def main():
# emeg_paths = [Path(emeg_dir, p + r) for p in participants[:2] for r in reps[-1:]]

inverse_operator = Path(args.base_dir, args.inverse_operator, f"{participants[0]}_ico5-3L-loose02-cps-nodepth.fif")
inverse_operator = None # set to None/inverse_operator if you want to run on sensor space/source space
caiw marked this conversation as resolved.
Show resolved Hide resolved

# Load data
emeg, ch_names = load_emeg_pack(emeg_paths,
need_names=False,
ave_mode=args.ave_mode,
inverse_operator=None, #inverse_operator, # set to None/inverse_operator if you want to run on sensor space/source space
p_tshift=None,
snr=args.snr)
emeg_values, ch_names = load_emeg_pack(emeg_paths,
need_names=True,
ave_mode=args.ave_mode,
inverse_operator=inverse_operator,
p_tshift=None,
snr=args.snr)

func = load_function(Path(args.base_dir, args.function_path),
func_name=args.function_name,
bruce_neurons=(5, 10))
func = func.downsampled(args.downsample_rate)

channel_space = "source" if inverse_operator is not None else "sensor"

es = do_gridsearch(
emeg_values=emeg,
sensor_names=ch_names,
emeg_values=emeg_values,
channel_names=ch_names,
channel_space=channel_space,
function=func,
seconds_per_split=args.seconds_per_split,
n_derangements=args.n_derangements,
Expand All @@ -89,7 +101,11 @@ def main():
ave_mode=args.ave_mode,
)

# expression_plot(es)
if args.save_expression_set is not None:
save_expression_set(es, args.save_expression_set, overwrite=args.overwrite)

expression_plot(es, paired_axes=channel_space == "source", save_to=args.save_plot, overwrite=args.overwrite)
caiw marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == '__main__':
main()
7 changes: 5 additions & 2 deletions kymata/entities/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def __init__(self,
data = data[i]
data = self._init_prep_data(data)
# Check validity of input data dimensions
assert len(channels) == data.shape[0], f"{channel_coord_name} mismatch for {f}"
assert len(latencies) == data.shape[1], f"Latencies mismatch for {f}"
assert len(channels) == data.shape[0], f"{channel_coord_name} mismatch for {f}: {len(channels)} {channel_coord_name} versus data shape {data.shape}"
assert len(latencies) == data.shape[1], f"Latencies mismatch for {f}: {len(latencies)} latencies versus data shape {data.shape}"
dataset_dict[layer] = DataArray(
data=data,
dims=self._dims,
Expand Down Expand Up @@ -403,6 +403,9 @@ def best_functions(self) -> DataFrame:
return super()._best_functions_for_layer(LAYER_SCALP)


log_base = 10


def p_to_logp(arraylike: ArrayLike) -> ArrayLike:
"""The one-stop-shop for converting from p-values to log p-values."""
return log10(arraylike)
Expand Down
112 changes: 36 additions & 76 deletions kymata/gridsearch/plain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from kymata.entities.functions import Function
from kymata.math.combinatorics import generate_derangement
from kymata.math.vector import normalize, get_stds
#from kymata.entities.expression import SensorExpressionSet, p_to_logp
import matplotlib.pyplot as plt
from kymata.entities.expression import ExpressionSet, SensorExpressionSet, HexelExpressionSet, p_to_logp, log_base


def do_gridsearch(
emeg_values: NDArray, # chan x time
function: Function,
sensor_names: list[str],
channel_names: list,
channel_space: str,
start_latency: float, # ms
emeg_t_start: float, # ms
emeg_sample_rate: int = 1000, # Hertz
Expand All @@ -21,13 +21,15 @@ def do_gridsearch(
seconds_per_split: float = 0.5,
n_splits: int = 800,
ave_mode: str = 'ave', # either ave or add, for averaging over input files or adding in as extra evidence
add_autocorr: bool = True,
plot_name: str = 'example'
):
) -> ExpressionSet:
"""
Do the Kymata gridsearch over all hexels for all latencies.
"""

channel_space = channel_space.lower()
if channel_space not in {"sensor", "source"}:
raise NotImplementedError(channel_space)

# We'll need to downsample the EMEG to match the function's sample rate
downsample_rate: int = int(emeg_sample_rate / function.sample_rate)

Expand All @@ -47,8 +49,10 @@ def do_gridsearch(
n_channels = emeg_values.shape[0]

# Reshape EMEG into splits of `seconds_per_split` s
split_initial_timesteps = [int(start_latency + round(i * 1000 * seconds_per_split * (1 + audio_shift_correction)) - emeg_t_start)
for i in range(n_splits)]
split_initial_timesteps = [
int(start_latency + round(i * 1000 * seconds_per_split * (1 + audio_shift_correction)) - emeg_t_start)
for i in range(n_splits)
]

emeg_reshaped = np.zeros((n_channels, n_splits * n_reps, n_samples_per_split))
for j in range(n_reps):
Expand All @@ -75,75 +79,29 @@ def do_gridsearch(
deranged_emeg = emeg_reshaped[:, derangement, :]
corrs[:, der_i] = np.fft.irfft(deranged_emeg * F_func)[:, :, :n_samples_per_split//2] / emeg_stds[:, derangement]

if add_autocorr:
auto_corrs = np.zeros((n_splits, n_samples_per_split//2))
noise = normalize(np.random.randn(func.shape[0], func.shape[1])) * 0
noisy_func = normalize(np.copy(func)) + noise
nn = n_samples_per_split // 2

F_noisy_func = np.fft.rfft(normalize(noisy_func), n=nn, axis=-1)
F_func = np.conj(np.fft.rfft(normalize(func), n=nn, axis=-1))

auto_corrs = np.fft.irfft(F_noisy_func * F_func)

del F_func, deranged_emeg, emeg_reshaped

log_pvalues = _ttest(corrs)

latencies = np.linspace(start_latency, start_latency + (seconds_per_split * 1000), n_samples_per_split // 2 + 1)[:-1]

if plot_name:
caiw marked this conversation as resolved.
Show resolved Hide resolved
plt.figure(1)
corr_avrs = np.mean(corrs[:, 0]**2, axis=-2)
maxs = np.max(corr_avrs, axis=1)
n_amaxs = 5
amaxs = np.argpartition(maxs, -n_amaxs)[-n_amaxs:]
amax = np.argmax(corr_avrs) // (n_samples_per_split // 2)
amaxs = [i for i in amaxs if i != amax] # + [209]

plt.plot(latencies, np.mean(corrs[amax, 0], axis=-2).T, 'r-', label=amax)
plt.plot(latencies, np.mean(corrs[amaxs, 0], axis=-2).T, label=amaxs)
std_null = np.mean(np.std(corrs[:, 1], axis=-2), axis=0).T * 3 / np.sqrt(n_reps * n_splits) # 3 pop std.s
std_real = np.std(corrs[amax, 0], axis=-2).T * 3 / np.sqrt(n_reps * n_splits)
av_real = np.mean(corrs[amax, 0], axis=-2).T
#print(std_null)
plt.fill_between(latencies, -std_null, std_null, alpha=0.5, color='grey')
plt.fill_between(latencies, av_real - std_real, av_real + std_real, alpha=0.25, color='red')

if add_autocorr:
peak_lat_ind = np.argmax(corr_avrs) % (n_samples_per_split // 2)
peak_lat = latencies[peak_lat_ind]
peak_corr = np.mean(corrs[amax, 0], axis=-2)[peak_lat_ind]
print(f'{function.name}: peak lat, peak corr, ind:', peak_lat, peak_corr, amax)

auto_corrs = np.mean(auto_corrs, axis=0)
plt.plot(latencies, np.roll(auto_corrs, peak_lat_ind) * peak_corr / np.max(auto_corrs), 'k--', label='func auto-corr')

plt.axvline(0, color='k')
plt.legend()
plt.xlabel('latencies (ms)')
plt.ylabel('Corr coef.')
plt.savefig(f'{plot_name}_1.png')
plt.clf()

plt.figure(2)
plt.plot(latencies, -log_pvalues[amax].T, 'r-', label=amax)
plt.plot(latencies, -log_pvalues[amaxs].T, label=amaxs)
plt.axvline(0, color='k')
plt.legend()
plt.xlabel('latencies (ms)')
plt.ylabel('p-values')
plt.savefig(f'{plot_name}_2.png')
plt.clf()

return

"""es = SensorExpressionSet(
functions=function.name,
latencies=latencies / 1000,
sensors=sensor_names,
data=log_pvalues,
)"""
latencies_ms = np.linspace(start_latency, start_latency + (seconds_per_split * 1000), n_samples_per_split // 2 + 1)[:-1]

if channel_space == "sensor":
es = SensorExpressionSet(
functions=function.name,
latencies=latencies_ms / 1000, # seconds
sensors=channel_names,
data=log_pvalues,
)
elif channel_space == "source":
es = HexelExpressionSet(
functions=function.name + f"_mirrored-lh", # TODO: revert to just `function.name` when we have both hemispheres in place
latencies=latencies_ms / 1000, # seconds
hexels=channel_names,
data_lh=log_pvalues,
data_rh=log_pvalues, # TODO: distribute data correctly when we have both hemispheres in place
)
else:
raise NotImplementedError(channel_space)

return es

Expand Down Expand Up @@ -183,9 +141,11 @@ def _ttest(corrs: NDArray, use_all_lats: bool = True):
t_stat = numerator / denominator

if np.min(df) <= 300:
log_p = np.log(stats.t.sf(np.abs(t_stat), df) * 2) # two-tailed p-value
p = stats.t.sf(np.abs(t_stat), df) * 2 # two-tailed p-value
log_p = p_to_logp(p)
else:
# norm v good approx for this, (logsf for t not implemented in logspace)
log_p = stats.norm.logsf(np.abs(t_stat)) + np.log(2)
log_p = stats.norm.logsf(np.abs(t_stat)) + np.log(2)
log_p /= np.log(log_base) # log base correction

return log_p / np.log(10) # log base correction
return log_p
14 changes: 4 additions & 10 deletions kymata/io/mne.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from pathlib import Path

from mne import read_evokeds, minimum_norm, set_eeg_reference
import numpy as np
from numpy.typing import NDArray
from os.path import isfile

from kymata.io.file import path_type



def load_single_emeg(emeg_path, need_names=False, inverse_operator=None, snr=4):
emeg_path_npy = f"{emeg_path}.npy"
Expand All @@ -19,16 +14,13 @@ def load_single_emeg(emeg_path, need_names=False, inverse_operator=None, snr=4):
evoked = read_evokeds(emeg_path_fif, verbose=False) # should be len 1 list
if inverse_operator is not None:
lh_emeg, rh_emeg, ch_names = inverse_operate(evoked[0], inverse_operator, snr)
# TODO: I think ch_names here is the wrong thing

emeg = None #np.concatenate((lh_emeg, rh_emeg), axis=0)
# TODO: I think ch_names here is the wrong thing

# TODO: currently this goes OOM (node-h04 atleast):
# looks like this will be faster when split up anyway
# note, don't run the inv_op twice for rh and lh!
# TODO: move inverse operator to run after EMEG channel combination

emeg = lh_emeg
emeg = lh_emeg #np.concatenate((lh_emeg, rh_emeg), axis=0)
del lh_emeg, rh_emeg
else:
emeg = evoked[0].get_data() # numpy array shape (sensor_num, N) = (370, 403_001)
Expand All @@ -39,13 +31,15 @@ def load_single_emeg(emeg_path, need_names=False, inverse_operator=None, snr=4):
del evoked
return emeg, ch_names


def inverse_operate(evoked, inverse_operator, snr=4):
lambda2 = 1.0 / snr ** 2
inverse_operator = minimum_norm.read_inverse_operator(inverse_operator, verbose=False)
set_eeg_reference(evoked, projection=True, verbose=False)
stc = minimum_norm.apply_inverse(evoked, inverse_operator, lambda2, 'MNE', pick_ori='normal', verbose=False)
return stc.lh_data, stc.rh_data, evoked.ch_names


def load_emeg_pack(emeg_paths, need_names=False, ave_mode=None, inverse_operator=None, p_tshift=None, snr=4): # TODO: FIX PRE-AVE-NORMALISATION
if p_tshift is None:
p_tshift = [0]*len(emeg_paths)
Expand Down
8 changes: 7 additions & 1 deletion kymata/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def expression_plot(
hidden_functions_in_legend: bool = True,
# I/O args
save_to: Optional[Path] = None,
overwrite: bool = True,
):
"""
Generates an expression plot
Expand Down Expand Up @@ -274,7 +275,12 @@ def expression_plot(

if save_to is not None:
pyplot.rcParams['savefig.dpi'] = 300
pyplot.savefig(Path(save_to), bbox_inches='tight')
save_to = Path(save_to)

if overwrite or not save_to.exists():
pyplot.savefig(Path(save_to), bbox_inches='tight')
else:
raise FileExistsError(save_to)

pyplot.show()
pyplot.close()
Expand Down