diff --git a/invokers/run_gridsearch.py b/invokers/run_gridsearch.py
index 6fec8e3e..473f7962 100644
--- a/invokers/run_gridsearch.py
+++ b/invokers/run_gridsearch.py
@@ -1,95 +1,130 @@
 from pathlib import Path
 import argparse
+import time
+
+from kymata.datasets.data_root import data_root_path
 from kymata.gridsearch.plain import do_gridsearch
 from kymata.io.functions import load_function
 from kymata.io.mne import load_emeg_pack
+from kymata.io.nkg import save_expression_set
+from kymata.plot.plot import expression_plot
+
+_default_output_dir = Path(data_root_path(), "output")
 
 
 def main():
+    _default_output_dir.mkdir(exist_ok=True, parents=False)
+
     parser = argparse.ArgumentParser(description='Gridsearch Params')
-    parser.add_argument('--emeg_sample_rate', type=int, default=1000,
+    parser.add_argument('--emeg-sample-rate', type=int, default=1000,
                         help='sampling rate of the emeg machine (not implemented yet)')
-    parser.add_argument('--snr', type=float, default=3,
-                        help='inverse solution snr')
-    parser.add_argument('--downsample_rate', type=int, default=5,
-                        help='downsample_rate')
-    parser.add_argument('--base_dir', type=str, default="/imaging/projects/cbu/kymata/data/dataset_4-english-narratives/",
-                        help='base data directory')
-    parser.add_argument('--data_path', type=str, default="intrim_preprocessing_files/3_trialwise_sensorspace/evoked_data",
-                        help='data path after base dir')
-    parser.add_argument('--function_path', type=str, default="predicted_function_contours/GMSloudness/stimulisig",
-                        help='snr')
-    parser.add_argument('--function_name', type=str, default="d_IL2",
-                        help='function name in stimulisig')
-    parser.add_argument('--emeg_file', type=str, default="participant_01-ave",
-                        help='emeg_file_name')
-    parser.add_argument('--ave_mode', type=str, default="ave",
+    parser.add_argument('--snr', type=float, default=3, help='inverse solution snr')
+    parser.add_argument('--downsample-rate', type=int, default=5, help='downsample_rate')
+    parser.add_argument('--base-dir', type=str, default='/imaging/projects/cbu/kymata/data/dataset_4-english-narratives/', help='base data directory')
+    parser.add_argument('--data-path', type=str, default='intrim_preprocessing_files/3_trialwise_sensorspace/evoked_data', help='data path after base dir')
+    parser.add_argument('--function-path', type=str, default='predicted_function_contours/GMSloudness/stimulisig', help='location of function stimulisig')
+    parser.add_argument('--save-expression-set-location', type=Path, default=Path(_default_output_dir),
+                        help="Save the results of the gridsearch into an ExpressionSet .nkg file")
+    parser.add_argument('--save-plot-location', type=Path, default=Path(_default_output_dir),
+                        help="Save expression plots, and other plots, in this location")
+    parser.add_argument('--overwrite', action="store_true", help="Silently overwrite existing files.")
+    parser.add_argument('--function-name', type=str, default="IL", help='function name in stimulisig')
+    parser.add_argument('--emeg-file', type=str, default="participant_01-ave", help='emeg_file_name')
+    parser.add_argument('--ave-mode', type=str, default="ave",
                         help='either ave or add, either average over the list of repetitions or treat them as extra data')
-    parser.add_argument('--inverse_operator', type=str, default="intrim_preprocessing_files/4_hexel_current_reconstruction/inverse-operators",
-                        help='inverse solution path')
-    parser.add_argument('--seconds_per_split', type=float, default=0.5,
+    parser.add_argument('--inverse-operator-dir', type=str, default=None,
+                        help='inverse solution path')
+    parser.add_argument('--inverse-operator-name', type=str, default="participant_01_ico5-3L-loose02-cps-nodepth-fusion.fif",
+                        help='inverse solution name')
+    parser.add_argument('--seconds-per-split', type=float, default=0.5,
                         help='seconds in each split of the recording, also maximum range of latencies being checked')
-    parser.add_argument('--n_splits', type=int, default=800,
+    parser.add_argument('--n-splits', type=int, default=800,
                         help='number of splits to split the recording into, (set to 400/seconds_per_split for full file)')
-    parser.add_argument('--n_derangements', type=int, default=1,
-                        help='inverse solution snr')
-    parser.add_argument('--start_latency', type=float, default=-100,
+    parser.add_argument('--n-derangements', type=int, default=1,
+                        help='number of derangements for the null distribution')
+    parser.add_argument('--start-latency', type=float, default=-100,
                         help='earliest latency to check in cross correlation')
-    parser.add_argument('--emeg_t_start', type=float, default=-200,
+    parser.add_argument('--emeg-t-start', type=float, default=-200,
                         help='start of the emeg evoked files relative to the start of the function')
-    parser.add_argument('--audio_shift_correction', type=float, default=0.000_537_5,
+    parser.add_argument('--audio-shift-correction', type=float, default=0.000_537_5,
                         help='audio shift correction, for every second of function, add this number of seconds (to the start of the emeg split) per seconds of emeg seen')
     args = parser.parse_args()
     args.base_dir = Path(args.base_dir)
 
-
     emeg_dir = Path(args.base_dir, args.data_path)
     emeg_paths = [Path(emeg_dir, args.emeg_file)]
 
-    participants = ['participant_01',
+    participants = ['pilot_01',
+                    'pilot_02',
+                    'participant_01',
                     'participant_01b',
                     'participant_02',
                     'participant_03',
                     'participant_04',
                     'participant_05',
-                    'pilot_01',
-                    'pilot_02']
+                    'participant_07',
+                    'participant_08',
+                    'participant_09',
+                    'participant_10',
+                    'participant_11',
+                    'participant_12',
+                    'participant_13',
+                    'participant_14',
+                    'participant_15',
+                    'participant_16',
+                    'participant_17'
+                    ]
 
     reps = [f'_rep{i}' for i in range(8)] + ['-ave']
 
     # emeg_paths = [Path(emeg_dir, p + r) for p in participants[:2] for r in reps[-1:]]
 
-    inverse_operator = Path(args.base_dir, args.inverse_operator, f"{participants[0]}_ico5-3L-loose02-cps-nodepth.fif")
+    start = time.time()
+
+    if args.inverse_operator_dir is None:
+        inverse_operator = None
+    else:
+        inverse_operator = Path(args.base_dir, args.inverse_operator_dir, args.inverse_operator_name)
 
     # Load data
-    emeg, ch_names = load_emeg_pack(emeg_paths,
-                                    need_names=False,
-                                    ave_mode=args.ave_mode,
-                                    inverse_operator=None,  # inverse_operator, # set to None/inverse_operator if you want to run on sensor space/source space
-                                    p_tshift=None,
-                                    snr=args.snr)
+    emeg_values, ch_names = load_emeg_pack(emeg_paths,
+                                           need_names=True,
+                                           ave_mode=args.ave_mode,
+                                           inverse_operator=inverse_operator,
+                                           p_tshift=None,
+                                           snr=args.snr)
 
     func = load_function(Path(args.base_dir, args.function_path),
                          func_name=args.function_name,
                          bruce_neurons=(5, 10))
     func = func.downsampled(args.downsample_rate)
 
+    channel_space = "source" if inverse_operator is not None else "sensor"
+
     es = do_gridsearch(
-        emeg_values=emeg,
-        sensor_names=ch_names,
+        emeg_values=emeg_values,
+        channel_names=ch_names,
+        channel_space=channel_space,
        function=func,
         seconds_per_split=args.seconds_per_split,
         n_derangements=args.n_derangements,
         n_splits=args.n_splits,
         start_latency=args.start_latency,
+        plot_location=args.save_plot_location,
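+        # (where do_gridsearch saves its top-five-channels diagnostic plot)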
         emeg_t_start=args.emeg_t_start,
         emeg_sample_rate=args.emeg_sample_rate,
         audio_shift_correction=args.audio_shift_correction,
         ave_mode=args.ave_mode,
+        overwrite=args.overwrite,
     )
 
-    # expression_plot(es)
+    if args.save_expression_set_location is not None:
+        save_expression_set(es, to_path_or_file=Path(args.save_expression_set_location, args.function_name + '_gridsearch.nkg'), overwrite=args.overwrite)
+
+    expression_plot(es, paired_axes=channel_space == "source", save_to=Path(args.save_plot_location, args.function_name + '_gridsearch.png'), overwrite=args.overwrite)
+
+    print(f'Time taken for code to run: {time.time() - start:.4f}')
+
 
 if __name__ == '__main__':
     main()
diff --git a/kymata/config/dataset4.yaml b/kymata/config/dataset4.yaml
index ce0ac379..fd242469 100755
--- a/kymata/config/dataset4.yaml
+++ b/kymata/config/dataset4.yaml
@@ -31,6 +31,7 @@ supress_excessive_plots_and_prompts: True
 # Inverse operator
 eeg: True
 meg: True
+inverse_operator: "intrim_preprocessing_files/4_hexel_current_reconstruction/inverse-operators"
 
 # Method to estimate noise covariance matrix
 cov_method: 'grand_ave'  # grand_ave | empty_room | run_start
diff --git a/kymata/datasets/data_root.py b/kymata/datasets/data_root.py
new file mode 100644
index 00000000..1e997d0a
--- /dev/null
+++ b/kymata/datasets/data_root.py
@@ -0,0 +1,73 @@
+from os import getcwd, getenv
+from pathlib import Path
+from typing import Optional
+
+from kymata.io.file import path_type
+
+
+_DATA_PATH_ENVIRONMENT_VAR_NAME = "KYMATA_DATA_ROOT"
+DATA_DIR_NAME = "kymata-toolbox-data"
+
+# Places downloaded datasets could go, in order of preference
+_preferred_default_data_locations = [
+    Path(Path(__file__).parent.parent.parent),  # kymata/../data_dir (next to kymata dir)
+    Path(getcwd()),                             # /data_dir
+    Path(Path.home(), "Documents"),             # ~/Documents/data_dir
+    Path(Path.home()),                          # ~/data_dir
+]
+
+
+def data_root_path(data_root: Optional[path_type] = None) -> Path:
+
+    # Check if the data root has been specified
+
+    # Might be in an environmental variable
+    if data_root is None:
+        data_root: path_type | None = getenv(_DATA_PATH_ENVIRONMENT_VAR_NAME, default=None)
+
+    # Might have been supplied as an argument
+    if data_root is not None:
+        if isinstance(data_root, str):
+            data_root = Path(data_root)
+        # Data root specified
+        if not data_root.exists():
+            raise FileNotFoundError(f"data_root {str(data_root)} specified but does not exist")
+        if not data_root.is_dir():
+            raise NotADirectoryError(f"Please specify a directory ({str(data_root)} is not a directory)")
+
+        return data_root
+
+    else:
+        # Data root not specified
+
+        # Check if the data root already exists
+        for loc in _preferred_default_data_locations:
+            if (here := Path(loc, DATA_DIR_NAME)).exists():
+                data_root = here
+                break
+
+        # If not, attempt to create it
+        if data_root is None:
+            here: Path | None = None
+            for loc in _preferred_default_data_locations:
+                here = Path(loc, DATA_DIR_NAME)
+                try:
+                    here.mkdir()
+                    break
+                # If it fails for sensible reasons, no sweat, we'll fall through to the next option
+                except (FileNotFoundError, OSError):
+                    # Parent didn't exist, not writeable, etc
+                    pass
+            # Did we make it?
+            if here is not None and here.exists():
+                data_root = here
+            else:
+                raise FileNotFoundError("Failed to create data root directory")
+
+        # Data root location has been derived, rather than prespecified, so feed that back to the user to avoid a
+        # different location somehow being derived next time
+        print(f"Data root set at {str(data_root)}.")
+        print(f"Consider setting this as environmental variable {_DATA_PATH_ENVIRONMENT_VAR_NAME} to ensure it's reused"
+              f" next time.")
+        print(f"Hint: $> {_DATA_PATH_ENVIRONMENT_VAR_NAME}=\"{str(data_root)}\"")
+        return data_root
diff --git a/kymata/datasets/sample.py b/kymata/datasets/sample.py
index 719f0968..260aa8eb 100644
--- a/kymata/datasets/sample.py
+++ b/kymata/datasets/sample.py
@@ -1,23 +1,15 @@
 from abc import ABC, abstractmethod
-from os import getenv, getcwd, remove, rmdir
+from os import remove, rmdir
 from pathlib import Path
 from typing import Optional
 from urllib import request
 
+from kymata.datasets.data_root import data_root_path
 from kymata.entities.expression import HexelExpressionSet, SensorExpressionSet
 from kymata.io.file import path_type
 from kymata.io.nkg import load_expression_set
 
-_DATA_PATH_ENVIRONMENT_VAR_NAME = "KYMATA_DATA_ROOT"
-_DATA_DIR_NAME = "kymata-toolbox-data/tutorial_nkg_data"
-
-# Places downloaded datasets could go, in order of preference
-_preferred_default_data_locations = [
-    Path(Path(__file__).parent.parent.parent),  # kymata/../data_dir (next to kymata dir)
-    Path(getcwd()),                             # /data_dir
-    Path(Path.home(), "Documents"),             # ~/Documents/data_dir
-    Path(Path.home()),                          # ~/data_dir
-]
+_SAMPLE_DATA_DIR_NAME = "tutorial_nkg_data"
 
 
 class SampleDataset(ABC):
@@ -36,9 +28,13 @@ def __init__(self,
                  download: bool):
         self.name: str = name
         self.filenames: list[str] = filenames
-        self.data_root: Path = data_root_path(data_root)
+        self.data_root: Path = Path(data_root_path(data_root), _SAMPLE_DATA_DIR_NAME)
         self.remote_root: str = remote_root
 
+        # Create the default location, if it's being used
+        if data_root is None:
+            self.data_root.mkdir(exist_ok=True)
+
         if download:
             self.download()
 
@@ -141,62 +137,6 @@ def to_expressionset(self) -> SensorExpressionSet:
         return es
 
 
-def data_root_path(data_root: Optional[path_type] = None) -> Path:
-
-    # Check if the data root has been specified
-
-    # Might be in an environmental variable
-    if data_root is None:
-        data_root: path_type | None = getenv(_DATA_PATH_ENVIRONMENT_VAR_NAME, default=None)
-
-    # Might have been supplied as an argument
-    if data_root is not None:
-        if isinstance(data_root, str):
-            data_root = Path(data_root)
-        # Data root specified
-        if not data_root.exists():
-            raise FileNotFoundError(f"data_root {str(data_root)} specified but does not exist")
-        if not data_root.is_dir():
-            raise NotADirectoryError(f"Please specify a directory ({str(data_root)} is not a directory)")
-
-        return data_root
-
-    else:
-        # Data root not specified
-
-        # Check if the data root already exists
-        for loc in _preferred_default_data_locations:
-            if (here := Path(loc, _DATA_DIR_NAME)).exists():
-                data_root = here
-                break
-
-        # If not, attempt to create it
-        if data_root is None:
-            here: Path | None = None
-            for loc in _preferred_default_data_locations:
-                here = Path(loc, _DATA_DIR_NAME)
-                try:
-                    here.mkdir()
-                    break
-                # If it fails for sensible reasons, no sweat, we'll fall through to the next option
-                except (FileNotFoundError, OSError):
-                    # Parent didn't exist, not writeable, etc
-                    pass
-            # Did we make it?
-            if here is not None and here.exists():
-                data_root = here
-            else:
-                raise FileNotFoundError("Failed to create data root directory")
-
-        # Data root location has been derived, rather than prespecified, so feed that back to the user to avoid a
-        # different location somehow being derived next time
-        print(f"Data root set at {str(data_root)}.")
-        print(f"Consider setting this as environmental variable {_DATA_PATH_ENVIRONMENT_VAR_NAME} to ensure it's reused"
-              f" next time.")
-        print(f"Hint: $> {_DATA_PATH_ENVIRONMENT_VAR_NAME}=\"{str(data_root)}\"")
-        return data_root
-
-
 def delete_dataset(local_dataset: SampleDataset):
     # Make sure it's not silent
     print(f"Deleting dataset {local_dataset.name}")
diff --git a/kymata/entities/expression.py b/kymata/entities/expression.py
index c27de851..7f44d226 100644
--- a/kymata/entities/expression.py
+++ b/kymata/entities/expression.py
@@ -92,8 +92,8 @@ def __init__(self,
                 data = data[i]
             data = self._init_prep_data(data)
             # Check validity of input data dimensions
-            assert len(channels) == data.shape[0], f"{channel_coord_name} mismatch for {f}"
-            assert len(latencies) == data.shape[1], f"Latencies mismatch for {f}"
+            assert len(channels) == data.shape[0], f"{channel_coord_name} mismatch for {f}: {len(channels)} {channel_coord_name} versus data shape {data.shape}"
+            assert len(latencies) == data.shape[1], f"Latencies mismatch for {f}: {len(latencies)} latencies versus data shape {data.shape}"
             dataset_dict[layer] = DataArray(
                 data=data,
                 dims=self._dims,
@@ -403,6 +403,9 @@ def best_functions(self) -> DataFrame:
         return super()._best_functions_for_layer(LAYER_SCALP)
 
 
+log_base = 10
+
+
 def p_to_logp(arraylike: ArrayLike) -> ArrayLike:
     """The one-stop-shop for converting from p-values to log p-values."""
     return log10(arraylike)
diff --git a/kymata/gridsearch/plain.py b/kymata/gridsearch/plain.py
index 94ee55d2..1bd77d9e 100644
--- a/kymata/gridsearch/plain.py
+++ b/kymata/gridsearch/plain.py
@@ -1,35 +1,42 @@
+from pathlib import Path
+from typing import Optional
+
 import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import NDArray, ArrayLike
 from scipy import stats
 
 from kymata.entities.functions import Function
 from kymata.math.combinatorics import generate_derangement
 from kymata.math.vector import normalize, get_stds
-#from kymata.entities.expression import SensorExpressionSet, p_to_logp
-import matplotlib.pyplot as plt
-
+from kymata.entities.expression import ExpressionSet, SensorExpressionSet, HexelExpressionSet, p_to_logp, log_base
+from kymata.plot.plot import plot_top_five_channels_of_gridsearch
 
 
 def do_gridsearch(
         emeg_values: NDArray,  # chan x time
         function: Function,
-        sensor_names: list[str],
+        channel_names: list,
+        channel_space: str,
         start_latency: float,  # ms
         emeg_t_start: float,   # ms
+        plot_location: Optional[Path] = None,
         emeg_sample_rate: int = 1000,  # Hertz
         audio_shift_correction: float = 0.000_537_5,  # seconds/second  # TODO: describe in which direction?
         n_derangements: int = 1,
        seconds_per_split: float = 0.5,
         n_splits: int = 800,
         ave_mode: str = 'ave',  # either ave or add, for averaging over input files or adding in as extra evidence
-        add_autocorr: bool = True,
-        plot_name: str = 'example'
-    ):
+        overwrite: bool = True,
+) -> ExpressionSet:
     """
     Do the Kymata gridsearch over all hexels for all latencies.
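+
+    emeg_values: EMEG data, shape (channels, time).
+    function: the predicted function contour to correlate against the EMEG data.
+    channel_names: one name per channel (sensor names, or source-space vertices).
+    channel_space: "sensor" or "source".
+
+    Returns a SensorExpressionSet or HexelExpressionSet of log p-values.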
""" + channel_space = channel_space.lower() + if channel_space not in {"sensor", "source"}: + raise NotImplementedError(channel_space) + # We'll need to downsample the EMEG to match the function's sample rate - downsample_rate: int = int(emeg_sample_rate / function.sample_rate) + downsample_rate: int = int(emeg_sample_rate / function.sample_rate) # TODO: implement for general emeg_sample_rate n_samples_per_split = int(seconds_per_split * emeg_sample_rate * 2 // downsample_rate) @@ -47,8 +54,10 @@ def do_gridsearch( n_channels = emeg_values.shape[0] # Reshape EMEG into splits of `seconds_per_split` s - split_initial_timesteps = [int(start_latency + round(i * 1000 * seconds_per_split * (1 + audio_shift_correction)) - emeg_t_start) - for i in range(n_splits)] + split_initial_timesteps = [ + int(start_latency + round(i * 1000 * seconds_per_split * (1 + audio_shift_correction)) - emeg_t_start) + for i in range(n_splits) + ] emeg_reshaped = np.zeros((n_channels, n_splits * n_reps, n_samples_per_split)) for j in range(n_reps): @@ -75,80 +84,62 @@ def do_gridsearch( deranged_emeg = emeg_reshaped[:, derangement, :] corrs[:, der_i] = np.fft.irfft(deranged_emeg * F_func)[:, :, :n_samples_per_split//2] / emeg_stds[:, derangement] - if add_autocorr: - auto_corrs = np.zeros((n_splits, n_samples_per_split//2)) - noise = normalize(np.random.randn(func.shape[0], func.shape[1])) * 0 - noisy_func = normalize(np.copy(func)) + noise - nn = n_samples_per_split // 2 + # work out autocorrelation for channel-by-channel plots + noise = normalize(np.random.randn(func.shape[0], func.shape[1])) * 0 + noisy_func = normalize(np.copy(func)) + noise + nn = n_samples_per_split // 2 - F_noisy_func = np.fft.rfft(normalize(noisy_func), n=nn, axis=-1) - F_func = np.conj(np.fft.rfft(normalize(func), n=nn, axis=-1)) + F_noisy_func = np.fft.rfft(normalize(noisy_func), n=nn, axis=-1) + F_func = np.conj(np.fft.rfft(normalize(func), n=nn, axis=-1)) - auto_corrs = np.fft.irfft(F_noisy_func * F_func) + auto_corrs = np.fft.irfft(F_noisy_func * F_func) del F_func, deranged_emeg, emeg_reshaped + # derive pvalues log_pvalues = _ttest(corrs) - latencies = np.linspace(start_latency, start_latency + (seconds_per_split * 1000), n_samples_per_split // 2 + 1)[:-1] - - if plot_name: - plt.figure(1) - corr_avrs = np.mean(corrs[:, 0]**2, axis=-2) - maxs = np.max(corr_avrs, axis=1) - n_amaxs = 5 - amaxs = np.argpartition(maxs, -n_amaxs)[-n_amaxs:] - amax = np.argmax(corr_avrs) // (n_samples_per_split // 2) - amaxs = [i for i in amaxs if i != amax] # + [209] - - plt.plot(latencies, np.mean(corrs[amax, 0], axis=-2).T, 'r-', label=amax) - plt.plot(latencies, np.mean(corrs[amaxs, 0], axis=-2).T, label=amaxs) - std_null = np.mean(np.std(corrs[:, 1], axis=-2), axis=0).T * 3 / np.sqrt(n_reps * n_splits) # 3 pop std.s - std_real = np.std(corrs[amax, 0], axis=-2).T * 3 / np.sqrt(n_reps * n_splits) - av_real = np.mean(corrs[amax, 0], axis=-2).T - #print(std_null) - plt.fill_between(latencies, -std_null, std_null, alpha=0.5, color='grey') - plt.fill_between(latencies, av_real - std_real, av_real + std_real, alpha=0.25, color='red') - - if add_autocorr: - peak_lat_ind = np.argmax(corr_avrs) % (n_samples_per_split // 2) - peak_lat = latencies[peak_lat_ind] - peak_corr = np.mean(corrs[amax, 0], axis=-2)[peak_lat_ind] - print(f'{function.name}: peak lat, peak corr, ind:', peak_lat, peak_corr, amax) - - auto_corrs = np.mean(auto_corrs, axis=0) - plt.plot(latencies, np.roll(auto_corrs, peak_lat_ind) * peak_corr / np.max(auto_corrs), 'k--', label='func auto-corr') - 
- plt.axvline(0, color='k') - plt.legend() - plt.xlabel('latencies (ms)') - plt.ylabel('Corr coef.') - plt.savefig(f'{plot_name}_1.png') - plt.clf() - - plt.figure(2) - plt.plot(latencies, -log_pvalues[amax].T, 'r-', label=amax) - plt.plot(latencies, -log_pvalues[amaxs].T, label=amaxs) - plt.axvline(0, color='k') - plt.legend() - plt.xlabel('latencies (ms)') - plt.ylabel('p-values') - plt.savefig(f'{plot_name}_2.png') - plt.clf() - - return - - """es = SensorExpressionSet( - functions=function.name, - latencies=latencies / 1000, - sensors=sensor_names, - data=log_pvalues, - )""" + latencies_ms = np.linspace(start_latency, start_latency + (seconds_per_split * 1000), n_samples_per_split // 2 + 1)[:-1] + + plot_top_five_channels_of_gridsearch( + corrs=corrs, + auto_corrs=auto_corrs, + function=function, + n_reps=n_reps, + n_splits=n_splits, + n_samples_per_split=n_samples_per_split, + latencies=latencies_ms, + save_to=plot_location, + log_pvalues=log_pvalues, + overwrite=overwrite, + ) + + if channel_space == "sensor": + es = SensorExpressionSet( + functions=function.name, + latencies=latencies_ms / 1000, # seconds + sensors=channel_names, + data=log_pvalues, + ) + elif channel_space == "source": + + log_pvalues_lh, log_pvalues_rh = np.split(log_pvalues, 2, axis=0) + + es = HexelExpressionSet( + functions=function.name, + latencies=latencies_ms / 1000, # seconds + hexels=channel_names[1,:10239], # TODO: HACK - FIX WITH ISSUE #141 + data_lh=log_pvalues_lh[:10239,], # TODO: HACK - FIX WITH ISSUE #141 + data_rh=log_pvalues_rh[:10239,], # TODO: HACK - FIX WITH ISSUE #141 + ) + else: + raise NotImplementedError(channel_space) return es -def _ttest(corrs: NDArray, use_all_lats: bool = True): +def _ttest(corrs: NDArray, use_all_lats: bool = True) -> ArrayLike: + """ Vectorised Welch's t-test. """ @@ -183,9 +174,11 @@ def _ttest(corrs: NDArray, use_all_lats: bool = True): t_stat = numerator / denominator if np.min(df) <= 300: - log_p = np.log(stats.t.sf(np.abs(t_stat), df) * 2) # two-tailed p-value + p = stats.t.sf(np.abs(t_stat), df) * 2 # two-tailed p-value + log_p = p_to_logp(p) else: # norm v good approx for this, (logsf for t not implemented in logspace) - log_p = stats.norm.logsf(np.abs(t_stat)) + np.log(2) + log_p = stats.norm.logsf(np.abs(t_stat)) + np.log(2) + log_p /= np.log(log_base) # log base correction - return log_p / np.log(10) # log base correction + return log_p diff --git a/kymata/io/cli.py b/kymata/io/cli.py index 6034a22b..659d1ca0 100644 --- a/kymata/io/cli.py +++ b/kymata/io/cli.py @@ -1,3 +1,4 @@ +from sys import stdout from colorama import Style @@ -9,3 +10,47 @@ def print_with_color(message: str, fore_color, style=Style.BRIGHT): def input_with_color(message: str, fore_color, style=Style.BRIGHT) -> str: """Get input in a bright style""" return input(f"{fore_color}{style}{message}{Style.RESET_ALL}") + + +def print_progress(iteration: int, total: int, + prefix: str = '', + suffix: str = '', + *, + decimals: int = 1, + bar_length: int = 100, + update_downsample: int = 1, + full_char='▒', + empty_char='┄', + terminal_char='║', + clear_on_completion: bool = False): + """ + Call in a loop to create terminal progress bar. 
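+
+    Example (illustrative):
+        for i, item in enumerate(items):
+            process(item)
+            print_progress(i, len(items), prefix="Processing ")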
+    Based on https://github.com/emcoglab/ldm-core/blob/main/utils/log.py
+    @params:
+        iteration            - Required : current iteration (0-indexed) (Int)
+                                          (so at the last iteration, `iteration == total - 1`)
+        total                - Required : total iterations (Int)
+        prefix               - Optional : prefix string (Str)
+        suffix               - Optional : suffix string (Str)
+        decimals             - Optional : positive number of decimals in percent complete (Int)
+        bar_length           - Optional : character length of bar (Int)
+        clear_on_completion  - Optional : clear the bar when it reaches 100% (bool)
+    """
+
+    iteration += 1  # convert to 1-indexed to make the maths easier
+
+    if iteration < total and iteration % update_downsample != 0:
+        stdout.flush()
+        return
+
+    portion_complete = iteration / float(total)
+    percents = f"{100 * portion_complete:.{decimals}f}%"
+    filled_length = int(round(bar_length * portion_complete))
+    bar = (full_char * filled_length) + (empty_char * (bar_length - filled_length))
+
+    stdout.write(f'\r{prefix}{terminal_char}{bar}{terminal_char} {percents}{suffix}')
+
+    if iteration == total:
+        stdout.write("\r" if clear_on_completion else "\n")
+
+    stdout.flush()
diff --git a/kymata/io/mne.py b/kymata/io/mne.py
index bd4e1525..1920eee2 100644
--- a/kymata/io/mne.py
+++ b/kymata/io/mne.py
@@ -1,15 +1,13 @@
-from pathlib import Path
-
 from mne import read_evokeds, minimum_norm, set_eeg_reference
 import numpy as np
 from numpy.typing import NDArray
 
 from os.path import isfile
 
-from kymata.io.file import path_type
-
 
 def load_single_emeg(emeg_path, need_names=False, inverse_operator=None, snr=4):
+    """
+    When using the inverse operator, returns left and right hemispheres concatenated
+    """
     emeg_path_npy = f"{emeg_path}.npy"
     emeg_path_fif = f"{emeg_path}.fif"
     if isfile(emeg_path_npy) and (not need_names) and (inverse_operator is None):
@@ -19,16 +17,12 @@ def load_single_emeg(emeg_path, need_names=False, inverse_operator=None, snr=4):
         evoked = read_evokeds(emeg_path_fif, verbose=False)  # should be len 1 list
         if inverse_operator is not None:
             lh_emeg, rh_emeg, ch_names = inverse_operate(evoked[0], inverse_operator, snr)
-            # TODO: I think ch_names here is the wrong thing
-
-            emeg = None  # np.concatenate((lh_emeg, rh_emeg), axis=0)
 
             # TODO: currently this goes OOM (node-h04 atleast):
             #       looks like this will be faster when split up anyway
             #       note, don't run the inv_op twice for rh and lh!
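+            # NB: hemispheres are stacked left-then-right below; do_gridsearch relies
+            # on this ordering when splitting source-space results back into lh/rh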
             # TODO: move inverse operator to run after EMEG channel combination
-
-            emeg = lh_emeg
+            emeg = np.concatenate((lh_emeg, rh_emeg), axis=0)
             del lh_emeg, rh_emeg
         else:
             emeg = evoked[0].get_data()  # numpy array shape (sensor_num, N) = (370, 403_001)
@@ -39,12 +33,15 @@ def load_single_emeg(emeg_path, need_names=False, inverse_operator=None, snr=4):
         del evoked
     return emeg, ch_names
 
+
 def inverse_operate(evoked, inverse_operator, snr=4):
     lambda2 = 1.0 / snr ** 2
     inverse_operator = minimum_norm.read_inverse_operator(inverse_operator, verbose=False)
     set_eeg_reference(evoked, projection=True, verbose=False)
     stc = minimum_norm.apply_inverse(evoked, inverse_operator, lambda2, 'MNE', pick_ori='normal', verbose=False)
-    return stc.lh_data, stc.rh_data, evoked.ch_names
+    print("Inverse operator applied")
+    return stc.lh_data, stc.rh_data, stc.vertices
+
 
 def load_emeg_pack(emeg_paths, need_names=False, ave_mode=None, inverse_operator=None, p_tshift=None, snr=4):
     # TODO: FIX PRE-AVE-NORMALISATION
     if p_tshift is None:
diff --git a/kymata/plot/plot.py b/kymata/plot/plot.py
index a46a25e3..bd75820c 100644
--- a/kymata/plot/plot.py
+++ b/kymata/plot/plot.py
@@ -4,6 +4,7 @@ from typing import Optional, Sequence, Dict, NamedTuple
 
 import numpy as np
+from numpy.typing import NDArray
 from matplotlib import pyplot, colors
 from matplotlib.lines import Line2D
 from pandas import DataFrame
@@ -11,6 +12,7 @@
 from kymata.entities.expression import HexelExpressionSet, ExpressionSet, SensorExpressionSet, DIM_SENSOR, DIM_FUNCTION, \
     p_to_logp
+from kymata.entities.functions import Function
 from kymata.plot.layouts import get_meg_sensor_xy, eeg_sensors
 
 # log scale: 10 ** -this will be the ytick interval and also the resolution to which the ylims will be rounded
@@ -75,6 +77,7 @@ def expression_plot(
         hidden_functions_in_legend: bool = True,
         # I/O args
         save_to: Optional[Path] = None,
+        overwrite: bool = True,
 ):
     """
     Generates an expression plot
@@ -274,7 +277,12 @@ def expression_plot(
 
     if save_to is not None:
         pyplot.rcParams['savefig.dpi'] = 300
-        pyplot.savefig(Path(save_to), bbox_inches='tight')
+        save_to = Path(save_to)
+
+        if overwrite or not save_to.exists():
+            pyplot.savefig(Path(save_to), bbox_inches='tight')
+        else:
+            raise FileExistsError(save_to)
 
     pyplot.show()
     pyplot.close()
@@ -308,3 +316,78 @@ def _get_yticks(ylim):
     n_major_ticks = int(ylim / _MAJOR_TICK_SIZE) * -1
     last_major_tick = -1 * n_major_ticks * _MAJOR_TICK_SIZE
     return np.linspace(start=0, stop=last_major_tick, num=n_major_ticks + 1)
+
+
+def plot_top_five_channels_of_gridsearch(
+        latencies: NDArray[any],
+        corrs: NDArray[any],
+        function: Function,
+        n_samples_per_split: int,
+        n_reps: int,
+        n_splits: int,
+        auto_corrs: NDArray[any],
+        log_pvalues: any,
+        # I/O args
+        save_to: Optional[Path] = None,
+        overwrite: bool = True,
+):
+    """
+    Generates correlation and p-value plots showing the top five channels of the gridsearch
+
+    latencies: ...
+    function: ...
+    etc..
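+
+    corrs: the gridsearch correlations; index 0 on the second axis is the true
+           (non-deranged) condition, later indices are the derangement nulls
+    log_pvalues: log p-values, indexed (channel, latency)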
+ """ + + figure, axis = pyplot.subplots(1, 2, figsize=(15, 7)) + figure.suptitle(f'{function.name}: Plotting corrs and pvalues for top five channels') + + corr_avrs = np.mean(corrs[:, 0], axis=-2) ** 2 # (n_chans, n_derangs, n_splits, t_steps) -> (n_chans, t_steps) + maxs = np.max(corr_avrs, axis=1) + n_amaxs = 5 + amaxs = np.argpartition(maxs, -n_amaxs)[-n_amaxs:] + amax = np.argmax(corr_avrs) // (n_samples_per_split // 2) + amaxs = [i for i in amaxs if i != amax] # + [209] + + axis[0].plot(latencies, np.mean(corrs[amax, 0], axis=-2).T, 'r-', label=amax) + axis[0].plot(latencies, np.mean(corrs[amaxs, 0], axis=-2).T, label=amaxs) + std_null = np.mean(np.std(corrs[:, 1], axis=-2), axis=0).T * 3 / np.sqrt(n_reps * n_splits) # 3 pop std.s + std_real = np.std(corrs[amax, 0], axis=-2).T * 3 / np.sqrt(n_reps * n_splits) + av_real = np.mean(corrs[amax, 0], axis=-2).T + + axis[0].fill_between(latencies, -std_null, std_null, alpha=0.5, color='grey') + axis[0].fill_between(latencies, av_real - std_real, av_real + std_real, alpha=0.25, color='red') + + peak_lat_ind = np.argmax(corr_avrs) % (n_samples_per_split // 2) + peak_lat = latencies[peak_lat_ind] + peak_corr = np.mean(corrs[amax, 0], axis=-2)[peak_lat_ind] + print(f'{function.name}: peak lat: {peak_lat:.1f}, peak corr: {peak_corr:.4f} [sensor] ind: {amax}, -log(pval): {-log_pvalues[amax][peak_lat_ind]:.4f}') + + auto_corrs = np.mean(auto_corrs, axis=0) + axis[0].plot(latencies, np.roll(auto_corrs, peak_lat_ind) * peak_corr / np.max(auto_corrs), 'k--', + label='func auto-corr') + + axis[0].axvline(0, color='k') + axis[0].legend() + axis[0].set_title("Corr coef.") + axis[0].set_xlabel('latencies (ms)') + axis[0].set_ylabel('Corr coef.') + + axis[1].plot(latencies, -log_pvalues[amax].T, 'r-', label=amax) + axis[1].plot(latencies, -log_pvalues[amaxs].T, label=amaxs) + axis[1].axvline(0, color='k') + axis[1].legend() + axis[1].set_title("p-values") + axis[1].set_xlabel('latencies (ms)') + axis[1].set_ylabel('p-values') + + if save_to is not None: + pyplot.rcParams['savefig.dpi'] = 300 + save_to = Path(save_to, function.name + '_gridsearch_top_five_channels.png') + + if overwrite or not save_to.exists(): + pyplot.savefig(Path(save_to)) + else: + raise FileExistsError(save_to) + + pyplot.clf() + pyplot.close() \ No newline at end of file diff --git a/submit_gridsearch.sh b/submit_gridsearch.sh index 41763242..228317e0 100755 --- a/submit_gridsearch.sh +++ b/submit_gridsearch.sh @@ -11,16 +11,26 @@ #SBATCH --error=slurm_log.txt #SBATCH --ntasks=1 #SBATCH --time=05:00:00 -#SBATCH --mem=1000 +#SBATCH --mem=240G #SBATCH --array=1-1 #SBATCH --exclusive -conda activate mne_venv - args=(5) # 2 3 4 5 6 7 8 9 10) ARG=${args[$SLURM_ARRAY_TASK_ID - 1]} -python invokers/run_gridsearch.py - # --snr $ARG # >> result3.txt - -conda deactivate +module load apptainer +apptainer exec \ + -B /imaging/projects/cbu/kymata/ \ + /imaging/local/software/singularity_images/python/python_3.11.7-slim.sif \ + bash -c \ + " cd /imaging/projects/cbu/kymata/analyses/andy/kymata-toolbox/ ; \ + export VENV_PATH=~/poetry/ ; \ + \$VENV_PATH/bin/poetry run python -m invokers.run_gridsearch \ + --base-dir '/imaging/projects/cbu/kymata/data/dataset_4-english-narratives/' \ + --function-path 'predicted_function_contours/GMSloudness/stimulisig' \ + --function-name 'IL' \ + --emeg-file 'participant_01-ave' \ + --overwrite \ + --inverse-operator-dir '/imaging/projects/cbu/kymata/data/dataset_4-english-narratives/intrim_preprocessing_files/4_hexel_current_reconstruction/inverse-operators/' + " 
+  # --snr $ARG # >> result3.txt