Skip to content

Commit

Permalink
Specify data sample rate, and warn user if function will be resampled
Browse files Browse the repository at this point in the history
  • Loading branch information
caiw committed Oct 18, 2024
1 parent f3cc0ff commit 74569d7
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion dataset_config/dataset3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ participants: [
]
number_of_runs: 2 # number of runs <- 15-minute recording block from scanner
repetitions_per_runs: 2 # number of repetitions per run <- repeated stimulus presentations per run
stimulus_length: 400 # seconds
stimulus_length: 400 # seconds
mri_structural_type: "T1" # T1 | flash
sample_rate: 1000 # Hz

# Preprocessing pipeline
emeg_machine_used_to_record_data: 'vectorview'
Expand Down
1 change: 1 addition & 0 deletions dataset_config/dataset4.1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ number_of_runs: 4 # number of runs
repetitions_per_runs: 2 # number of repetitions per run
stimulus_length: 400 # seconds
mri_structural_type: "T1" # T1 | flash
sample_rate: 1000 # Hz

# Preprocessing pipeline
emeg_machine_used_to_record_data: 'triux'
Expand Down
1 change: 1 addition & 0 deletions dataset_config/dataset4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ number_of_runs: 4 # number of runs
repetitions_per_runs: 2 # number of repetitions per run
stimulus_length: 400 # seconds
mri_structural_type: "T1" # T1 | flash
sample_rate: 1000 # Hz

# Preprocessing pipeline
emeg_machine_used_to_record_data: 'triux'
Expand Down
8 changes: 3 additions & 5 deletions kymata/gridsearch/plain.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def do_gridsearch(
emeg_t_start: float, # ms
stimulus_shift_correction: float, # seconds/second
stimulus_delivery_latency: float, # seconds
emeg_sample_rate: float, # Hertz
plot_location: Optional[Path] = None,
emeg_sample_rate: int = 1000, # Hertz
n_derangements: int = 1,
seconds_per_split: float = 1,
n_splits: int = 400,
Expand Down Expand Up @@ -60,7 +60,7 @@ def do_gridsearch(
stimulus_delivery_latency (float): Correction offset for stimulus delivery in seconds.
plot_location (Optional[Path], optional): Path to save the plot of the top five channels of the
grid search. If None, plotting is skipped. Default is None.
emeg_sample_rate (int, optional): The sample rate of the EMEG data in Hertz. Default is 1000 Hz.
emeg_sample_rate (float, optional): The sample rate of the EMEG data in Hertz.
n_derangements (int, optional): Number of derangements (random permutations) used to create the
null distribution. Default is 1.
seconds_per_split (float, optional): Duration of each split in seconds. Default is 0.5 seconds.
Expand Down Expand Up @@ -142,9 +142,7 @@ def do_gridsearch(
* seconds_per_split
* (1 + stimulus_shift_correction)
) # splits, stretched by the shift correction
+ round(
stimulus_delivery_latency * emeg_sample_rate
) # correct for stimulus delivery latency delay
+ round(stimulus_delivery_latency * emeg_sample_rate) # correct for stimulus delivery latency delay
)
for i in range(n_splits)
]
Expand Down
12 changes: 10 additions & 2 deletions kymata/invokers/run_gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def main():
function_path = Path(base_dir, args.function_path)
_logger.info(f"Loading functions from {str(function_path)}")

emeg_sample_rate = float(dataset_config.get("sample_rate", 1000))

for function_name in args.function_name:
_logger.info(f"Running gridsearch on {function_name}")
function = load_function(
Expand All @@ -223,8 +225,13 @@ def main():
bruce_neurons=(5, 10),
sample_rate=args.function_sample_rate,
)
if args.resample is not None and args.function_sample_rate != args.resample:
function = function.resampled(args.resample)

# Resample function to match target sample rate if specified, else emeg sample rate
function_resample_rate = args.resample if args.resample is not None else emeg_sample_rate
if function.sample_rate != function_resample_rate:
_logger.info(f"Function sample rate ({function.sample_rate} Hz) doesn't match target sample rate "
f"({function_resample_rate} Hz). Function will be resampled to match.")
function.resampled(function_resample_rate)

es = do_gridsearch(
emeg_values=emeg_values,
Expand All @@ -235,6 +242,7 @@ def main():
n_derangements=args.n_derangements,
n_splits=args.n_splits,
n_reps=n_reps,
emeg_sample_rate=emeg_sample_rate,
start_latency=args.start_latency,
plot_location=args.save_plot_location,
emeg_t_start=args.emeg_t_start,
Expand Down

0 comments on commit 74569d7

Please sign in to comment.