Merge branch 'main' into big-preprocessing-refactor-cov
neukym committed Jan 12, 2024
2 parents d32c0cd + b999e49 commit 7cf2e78
Showing 22 changed files with 725 additions and 64 deletions.
2 changes: 1 addition & 1 deletion demos/demo_plotting.ipynb
@@ -15,7 +15,7 @@
"source": [
"from kymata.datasets.sample import KymataMirror2023Q3Dataset, TVLDeltaInsTC1LoudnessOnlySensorsDataset\n",
"from kymata.entities.expression import HexelExpressionSet, SensorExpressionSet\n",
"from kymata.plot.plotting import expression_plot"
"from kymata.plot.plot import expression_plot"
]
},
{
95 changes: 95 additions & 0 deletions invokers/run_gridsearch.py
@@ -0,0 +1,95 @@
from pathlib import Path
import argparse
from kymata.gridsearch.plain import do_gridsearch
from kymata.io.functions import load_function
from kymata.io.mne import load_emeg_pack


def main():

    parser = argparse.ArgumentParser(description='Gridsearch Params')
    parser.add_argument('--emeg_sample_rate', type=int, default=1000,
                        help='sampling rate of the emeg machine (not implemented yet)')
    parser.add_argument('--snr', type=float, default=3,
                        help='inverse solution snr')
    parser.add_argument('--downsample_rate', type=int, default=5,
                        help='factor by which to downsample the function')
    parser.add_argument('--base_dir', type=str, default="/imaging/projects/cbu/kymata/data/dataset_4-english-narratives/",
                        help='base data directory')
    parser.add_argument('--data_path', type=str, default="intrim_preprocessing_files/3_trialwise_sensorspace/evoked_data",
                        help='data path after base dir')
    parser.add_argument('--function_path', type=str, default="predicted_function_contours/GMSloudness/stimulisig",
                        help='function path after base dir')
    parser.add_argument('--function_name', type=str, default="d_IL2",
                        help='function name in stimulisig')
    parser.add_argument('--emeg_file', type=str, default="participant_01-ave",
                        help='emeg file name')
    parser.add_argument('--ave_mode', type=str, default="ave",
                        help="either 'ave' or 'add': average over the list of repetitions, or treat them as extra data")
    parser.add_argument('--inverse_operator', type=str, default="intrim_preprocessing_files/4_hexel_current_reconstruction/inverse-operators",
                        help='inverse solution path')
    parser.add_argument('--seconds_per_split', type=float, default=0.5,
                        help='seconds in each split of the recording; also the maximum range of latencies being checked')
    parser.add_argument('--n_splits', type=int, default=800,
                        help='number of splits to divide the recording into (set to 400/seconds_per_split for the full file)')
    parser.add_argument('--n_derangements', type=int, default=1,
                        help='number of derangements used for the null distribution')
    parser.add_argument('--start_latency', type=float, default=-100,
                        help='earliest latency to check in the cross-correlation')
    parser.add_argument('--emeg_t_start', type=float, default=-200,
                        help='start of the emeg evoked files relative to the start of the function')
    parser.add_argument('--audio_shift_correction', type=float, default=0.000_537_5,
                        help='audio shift correction: seconds added to the start of each emeg split per second of emeg seen')
    args = parser.parse_args()
    args.base_dir = Path(args.base_dir)

    emeg_dir = Path(args.base_dir, args.data_path)
    emeg_paths = [Path(emeg_dir, args.emeg_file)]

    participants = ['participant_01',
                    'participant_01b',
                    'participant_02',
                    'participant_03',
                    'participant_04',
                    'participant_05',
                    'pilot_01',
                    'pilot_02']

    reps = [f'_rep{i}' for i in range(8)] + ['-ave']

    # emeg_paths = [Path(emeg_dir, p + r) for p in participants[:2] for r in reps[-1:]]

    inverse_operator = Path(args.base_dir, args.inverse_operator, f"{participants[0]}_ico5-3L-loose02-cps-nodepth.fif")

    # Load data
    emeg, ch_names = load_emeg_pack(emeg_paths,
                                    need_names=False,
                                    ave_mode=args.ave_mode,
                                    inverse_operator=None,  # None → sensor space; inverse_operator → source space
                                    p_tshift=None,
                                    snr=args.snr)

    func = load_function(Path(args.base_dir, args.function_path),
                         func_name=args.function_name,
                         bruce_neurons=(5, 10))
    func = func.downsampled(args.downsample_rate)

    es = do_gridsearch(
        emeg_values=emeg,
        sensor_names=ch_names,
        function=func,
        seconds_per_split=args.seconds_per_split,
        n_derangements=args.n_derangements,
        n_splits=args.n_splits,
        start_latency=args.start_latency,
        emeg_t_start=args.emeg_t_start,
        emeg_sample_rate=args.emeg_sample_rate,
        audio_shift_correction=args.audio_shift_correction,
        ave_mode=args.ave_mode,
    )

    # expression_plot(es)


if __name__ == '__main__':
    main()
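A note on the --ave_mode flag above: a minimal numpy sketch (hypothetical shapes, not the toolbox's actual loader) of the difference between the 'ave' and 'add' modes as the help string describes them:

import numpy as np

# Hypothetical: 8 repetitions of the same stimulus, each (channels, samples)
reps = [np.random.randn(306, 1_000) for _ in range(8)]

ave = np.mean(reps, axis=0)           # 'ave': average over repetitions → (306, 1_000)
add = np.concatenate(reps, axis=-1)   # 'add': repetitions kept as extra data → (306, 8_000)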
33 changes: 33 additions & 0 deletions invokers/run_preprocessing_dataset4.py
@@ -0,0 +1,33 @@
from pathlib import Path

import sys
sys.path.append('/imaging/projects/cbu/kymata/analyses/ollie/kymata-toolbox')

from kymata.io.yaml import load_config
from kymata.preproc.data_cleansing import run_preprocessing, create_trials, create_trialwise_data


# noinspection DuplicatedCode
def main():
    config = load_config(str(Path(Path(__file__).parent.parent, "kymata", "config", "dataset4.yaml")))

    create_trialwise_data(
        dataset_directory_name=config['dataset_directory_name'],
        list_of_participants=config['list_of_participants'],
        repetitions_per_runs=config['repetitions_per_runs'],
        number_of_runs=config['number_of_runs'],
        number_of_trials=config['number_of_trials'],
        input_streams=config['input_streams'],
        eeg_thresh=float(config['eeg_thresh']),
        grad_thresh=float(config['grad_thresh']),
        mag_thresh=float(config['mag_thresh']),
        visual_delivery_latency=config['visual_delivery_latency'],
        audio_delivery_latency=config['audio_delivery_latency'],
        audio_delivery_shift_correction=config['audio_delivery_shift_correction'],
        tmin=config['tmin'],
        tmax=config['tmax'],
    )


if __name__ == '__main__':
    main()
6 changes: 3 additions & 3 deletions kymata/config/dataset3.yaml
@@ -45,9 +45,9 @@ supress_excessive_plots_and_prompts: True
audio_delivery_latency: 26 # in milliseconds
visual_delivery_latency: 17 # in milliseconds
tactile_delivery_latency: 0 # in milliseconds
-audio_delivery_shift_correction: 0 # in milliseconds per second
-tmin: -0.2
-tmax: 1.8
+audio_delivery_shift_correction: 0 # in seconds per second
+tmin: -0.2 # seconds
+tmax: 1.8 # seconds

eeg_thresh: 1 #200e-6
grad_thresh: 1 #4000e-13
6 changes: 3 additions & 3 deletions kymata/config/dataset4.yaml
100644 → 100755
@@ -39,9 +39,9 @@ cov_method: 'grand_ave' # grand_ave | empty_room | run_start
audio_delivery_latency: 26 # in milliseconds
visual_delivery_latency: 17 # in milliseconds
tactile_delivery_latency: 0 # in milliseconds
-audio_delivery_shift_correction: 0.5375 # in milliseconds per second
-tmin: -0.2
-tmax: 1.8
+audio_delivery_shift_correction: 0.0005375 # in seconds per second
+tmin: -0.2 # seconds
+tmax: 1.8 # seconds

eeg_thresh: 1 #200e-6
grad_thresh: 1 #4000e-13
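The change above re-expresses the drift correction in seconds per second rather than milliseconds per second, so the dataset-4 value is rescaled from 0.5375 to 0.0005375. A quick arithmetic check of the equivalence:

ms_per_s = 0.5375                   # old units: ms of audio drift per second of recording
s_per_s = ms_per_s / 1000           # new units: 0.0005375 s/s, as in the yaml
drift_after_400_s = s_per_s * 400   # ≈ 0.215 s of cumulative drift over 400 s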
42 changes: 21 additions & 21 deletions kymata/entities/expression.py
@@ -21,10 +21,10 @@
_InputDataArray = Union[ndarray, SparseArray] # Type alias for data which can be accepted

# Data dimension labels
_HEXEL = "hexel"
_SENSOR = "sensor"
_LATENCY = "latency"
_FUNCTION = "function"
DIM_HEXEL = "hexel"
DIM_SENSOR = "sensor"
DIM_LATENCY = "latency"
DIM_FUNCTION = "function"

# Layer (e.g. hemisphere)
LAYER_LEFT = "left"
@@ -59,7 +59,7 @@ def __init__(self,
self._layers: list[str] = list(data_layers.keys())

self._channel_coord_name = channel_coord_name
-self._dims = (channel_coord_name, _LATENCY, _FUNCTION) # Canonical order of dimensions
+self._dims = (channel_coord_name, DIM_LATENCY, DIM_FUNCTION) # Canonical order of dimensions

# Validate arguments
_length_mismatch_message = ("Argument length mismatch, please supply one function name and accompanying data, "
@@ -100,9 +100,9 @@ def __init__(self,
)
datasets.append(
Dataset(dataset_dict,
-coords={channel_coord_name: channels, _LATENCY: latencies, _FUNCTION: [f]})
+coords={channel_coord_name: channels, DIM_LATENCY: latencies, DIM_FUNCTION: [f]})
)
-self._data = concat(datasets, dim=_FUNCTION)
+self._data = concat(datasets, dim=DIM_FUNCTION)

@classmethod
def _init_prep_data(cls, data: _InputDataArray) -> COO:
@@ -120,12 +120,12 @@ def _channels(self) -> NDArray:
@property
def functions(self) -> list[FunctionNameDType]:
"""Function names."""
-return self._data.coords[_FUNCTION].values.tolist()
+return self._data.coords[DIM_FUNCTION].values.tolist()

@property
def latencies(self) -> NDArray[LatencyDType]:
"""Latencies, in seconds."""
-return self._data.coords[_LATENCY].values
+return self._data.coords[DIM_LATENCY].values

@abstractmethod
def __getitem__(self, functions: str | Sequence[str]) -> ExpressionSet:
@@ -167,27 +167,27 @@ def _best_functions_for_layer(self, layer: str) -> DataFrame:
data = self._data.copy()
densify_dataset(data)

-best_latency = data.idxmin(dim=_LATENCY) # (channel, function) → l, the best latency
-logp_at_best_latency = data.min(dim=_LATENCY) # (channel, function) → log p of best latency for each function
+best_latency = data.idxmin(dim=DIM_LATENCY) # (channel, function) → l, the best latency
+logp_at_best_latency = data.min(dim=DIM_LATENCY) # (channel, function) → log p of best latency for each function

-logp_at_best_function = logp_at_best_latency.min(dim=_FUNCTION) # (channel) → log p of best function (at best latency)
-best_function = logp_at_best_latency.idxmin(dim=_FUNCTION) # (channel) → f, the best function
+logp_at_best_function = logp_at_best_latency.min(dim=DIM_FUNCTION) # (channel) → log p of best function (at best latency)
+best_function = logp_at_best_latency.idxmin(dim=DIM_FUNCTION) # (channel) → f, the best function

# TODO: shame I have to break into the layer structure here,
# but I can't think of a better way to do it
logp_vals = logp_at_best_function[layer].data

best_functions = best_function[layer].data

-best_latencies = best_latency[layer].sel({self._channel_coord_name: self._channels, _FUNCTION: best_function[layer]}).data
+best_latencies = best_latency[layer].sel({self._channel_coord_name: self._channels, DIM_FUNCTION: best_function[layer]}).data

# Cut out channels which have a best log p-val of 1
idxs = logp_vals < 1

return DataFrame.from_dict({
self._channel_coord_name: self._channels[idxs],
-_FUNCTION: best_functions[idxs],
-_LATENCY: best_latencies[idxs],
+DIM_FUNCTION: best_functions[idxs],
+DIM_LATENCY: best_latencies[idxs],
"value": logp_vals[idxs],
})

@@ -221,7 +221,7 @@ def __init__(self,
LAYER_LEFT: data_lh,
LAYER_RIGHT: data_rh,
},
-channel_coord_name=_HEXEL,
+channel_coord_name=DIM_HEXEL,
channel_coord_dtype=HexelDType,
channel_coord_values=hexels,
)
@@ -256,8 +256,8 @@ def __getitem__(self, functions: str | Sequence[str]) -> HexelExpressionSet:
functions=functions,
hexels=self.hexels,
latencies=self.latencies,
-data_lh=[self._data[LAYER_LEFT].sel({_FUNCTION: function}).data for function in functions],
-data_rh=[self._data[LAYER_RIGHT].sel({_FUNCTION: function}).data for function in functions],
+data_lh=[self._data[LAYER_LEFT].sel({DIM_FUNCTION: function}).data for function in functions],
+data_rh=[self._data[LAYER_RIGHT].sel({DIM_FUNCTION: function}).data for function in functions],
)

def __copy__(self):
@@ -331,7 +331,7 @@ def __init__(self,
data_layers={
LAYER_SCALP: data
},
-channel_coord_name=_SENSOR,
+channel_coord_name=DIM_SENSOR,
channel_coord_dtype=SensorDType,
channel_coord_values=sensors,
)
@@ -392,7 +392,7 @@ def __getitem__(self, functions: str | Sequence[str]) -> SensorExpressionSet:
functions=functions,
sensors=self.sensors,
latencies=self.latencies,
-data=[self._data[LAYER_SCALP].sel({_FUNCTION: function}).data for function in functions],
+data=[self._data[LAYER_SCALP].sel({DIM_FUNCTION: function}).data for function in functions],
)

def best_functions(self) -> DataFrame:
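For readers unfamiliar with the idxmin/min reductions used in _best_functions_for_layer above, here is a minimal standalone xarray sketch with hypothetical data (the toolbox's real data is a layered, sparse Dataset, not a plain DataArray):

import numpy as np
import xarray as xr

# Hypothetical log p-values over (sensor, latency, function)
rng = np.random.default_rng(0)
logp = xr.DataArray(
    rng.uniform(-10, 0, size=(3, 5, 2)),
    dims=("sensor", "latency", "function"),
    coords={"sensor": ["MEG0111", "MEG0112", "MEG0113"],
            "latency": [-0.1, 0.0, 0.1, 0.2, 0.3],
            "function": ["f1", "f2"]},
)

best_latency = logp.idxmin(dim="latency")                         # (sensor, function) → best latency
logp_at_best_latency = logp.min(dim="latency")                    # (sensor, function) → log p at that latency
best_function = logp_at_best_latency.idxmin(dim="function")       # (sensor,) → best function
logp_at_best_function = logp_at_best_latency.min(dim="function")  # (sensor,) → best log p overall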
21 changes: 21 additions & 0 deletions kymata/entities/functions.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from numpy.typing import NDArray


@dataclass
class Function:
    name: str
    values: NDArray
    sample_rate: float  # Hertz

    def downsampled(self, rate: int):
        return Function(
            name=self.name,
            values=self.values[::rate],
            sample_rate=self.sample_rate / rate,
        )

    @property
    def time_step(self) -> float:
        return 1 / self.sample_rate
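A quick usage sketch of the new Function dataclass (values are hypothetical):

import numpy as np
from kymata.entities.functions import Function

f = Function(name="d_IL2", values=np.arange(10_000, dtype=float), sample_rate=1000.0)
g = f.downsampled(5)            # keep every 5th sample
assert len(g.values) == 2_000
assert g.sample_rate == 200.0   # 1000 Hz / 5
assert g.time_step == 0.005     # 1 / 200 Hz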
5 changes: 3 additions & 2 deletions kymata/entities/iterables.py
@@ -6,6 +6,7 @@
from typing import Sequence

from numpy import ndarray
+from numpy.typing import NDArray


def all_equal(sequence: Sequence) -> bool:
@@ -20,10 +20,10 @@ def all_equal(sequence: Sequence) -> bool:
# numpy arrays deal with equality weirdly

# Check first two items are equal, and equal to the rest
-first: ndarray = sequence[0]
+first: NDArray = sequence[0]
if not isinstance(sequence[1], ndarray):
return False
-second: ndarray = sequence[1]
+second: NDArray = sequence[1]
try:
# noinspection PyUnresolvedReferences
return (first == second).all() and all_equal(sequence[1:])
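And a short usage sketch for all_equal on the arrays the NDArray annotation now covers (assuming, per the usual base case not shown in this hunk, that length-1 sequences count as all-equal):

import numpy as np
from kymata.entities.iterables import all_equal

assert all_equal([np.array([1, 2]), np.array([1, 2])])
assert not all_equal([np.array([1, 2]), np.array([1, 3])])
assert not all_equal([np.array([1, 2]), "not an array"])  # non-array second item → False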
Empty file modified kymata/entities/sparse_data.py
100644 → 100755
Empty file.
Empty file added kymata/gridsearch/__init__.py
Empty file.
(Diff for the remaining changed files not shown.)
