diff --git a/python/pipeline/utils/clocktools.py b/python/pipeline/utils/clocktools.py index e704ca37..7a7076ee 100644 --- a/python/pipeline/utils/clocktools.py +++ b/python/pipeline/utils/clocktools.py @@ -80,52 +80,73 @@ def find_time_boundaries(indices, times, drop_single_idx=False): return time_boundaries - -def fetch_timing_data( +def scan_image_correction( scan_key: dict, - source_type: str, - target_type: str, - debug: bool = True, + debug: bool = False ): - """ - Fetches timing data for source and target recordings. Adjusts both timings based on any calculable delays. Returns two - arrays. Converts target recording times on target clock into target recording times on source clock if the two are different. - - Parameters: - - scan_key: A dictionary specifying a single scan and/or field. A single field must be defined if requesting - a source or target from ScanImage. If key specifies a single unit, unit delay will be added to - all timepoints recorded. Single units can be specified via unique mask_id + field or via unit_id. - If only field is specified, average field delay will be added. - - source_type: A string specifying what recording times to fetch for source_times. Both target and source times - will be returned on whatever clock is used for source_type. Fluorescence and deconvolution have - a dash followed by "behavior" or "stimulus" to refer to which clock you are using. - Supported options: - 'fluorescence-stimulus', 'deconvolution-stimulus', ,'fluorescence-behavior', - 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration' + pipe = (fuse.MotionCorrection & scan_key).module + ## Fetch field information if unit_id is defined: + if 'unit_id' in scan_key and 'field' not in scan_key: + scan_key['field'] = (pipe.ScanSet.Unit & scan_key).fetch1('field') + ## Check scan_key defines a unique scan + if len(pipe.ScanInfo & scan_key) != 1: + msg = ( + f"scan_key {scan_key} does not define a unique scan. " + f"Matching scans found: {len(fuse.MotionCorrection & scan_key)}" + ) + raise PipelineException(msg) - target_type: A string specifying what recording times to fetch for target_times. Both target and source times - will be returned on whatever clock is used for source_type. Fluorescence and deconvolution have - a dash followed by "behavior" or "stimulus" to refer to which clock you are using. - Supported options: - 'fluorescence-stimulus', 'deconvolution-stimulus', ,'fluorescence-behavior', - 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration' + ## Check a single field is defined by scan_key + if len(pipe.ScanInfo.Field & scan_key) != 1: + msg = ( + f"scan_key {scan_key} must specify a single field when source or target type is set " + f"to 'scan'. Matching fields found: {len(pipe.ScanInfo.Field & scan_key)}" + ) + raise PipelineException(msg) - debug: Set function to print helpful debug text while running + ## Determine field offset to slice times later on and set ms_delay to field average + scan_restriction = (pipe.ScanInfo & scan_key).fetch("KEY") + all_z = np.unique( + (pipe.ScanInfo.Field & scan_restriction).fetch("z", order_by="field ASC") + ) + slice_num = len(all_z) + field_z = (pipe.ScanInfo.Field & scan_key).fetch1("z") + field_offset = np.where(all_z == field_z)[0][0] + if debug: + print(f"Field offset found as {field_offset} for depths 0-{len(all_z)}") + field_delay_im = (pipe.ScanInfo.Field & scan_key).fetch1("delay_image") + average_field_delay = np.mean(field_delay_im) + ms_delay = average_field_delay + if debug: + print( + f"Average field delay found to be {round(ms_delay,4)}ms. This will be used unless a unit is specified in the key." + ) - Returns: + ## If included, add unit offset + if "unit_id" in scan_key or "mask_id" in scan_key: + if len(pipe.ScanSet.Unit & scan_key) > 0: + unit_key = (pipe.ScanSet.Unit & scan_key).fetch1() + ms_delay = (pipe.ScanSet.UnitInfo & unit_key).fetch1("ms_delay") + if debug: + print( + f"Unit found with delay of {round(ms_delay,4)}ms. Delay added to relevant times." + ) + else: + if debug: + print( + f"Warning: ScanSet.Unit is not populated for the given key! Using field offset minimum instead." + ) + return ms_delay, slice_num, field_offset - source_times: Numpy array of times for source recording on source clock - target_times: Numpy array of times for target recording on source clock - """ - +def fetch_timing_data( + scan_key: dict, + timing_type: str, # should be all lower_cases + debug: bool = True, +): ## Make settings strings lowercase - source_type = source_type.lower() - target_type = target_type.lower() - + timing_type = timing_type.lower() ## ## Set pipe, error check scan_key, and fetch field offset ## @@ -134,11 +155,6 @@ def fetch_timing_data( if len(fuse.MotionCorrection & scan_key) == 0: msg = f"scan_key {scan_key} not found in fuse.MotionCorrection." raise PipelineException(msg) - pipe = (fuse.MotionCorrection & scan_key).module - - ## Make strings lowercase and process indices - source_type = source_type.lower() - target_type = target_type.lower() ## Set default values for later processing field_offset = 0 @@ -152,57 +168,8 @@ def fetch_timing_data( "deconvolution-stimulus", "deconvolution-behavior", ) - if source_type in scan_types or target_type in scan_types: - - ## Check scan_key defines a unique scan - if len(pipe.ScanInfo & scan_key) != 1: - msg = ( - f"scan_key {scan_key} does not define a unique scan. " - f"Matching scans found: {len(fuse.MotionCorrection & scan_key)}" - ) - raise PipelineException(msg) - - ## Check a single field is defined by scan_key - if len(pipe.ScanInfo.Field & scan_key) != 1: - msg = ( - f"scan_key {scan_key} must specify a single field when source or target type is set " - f"to 'scan'. Matching fields found: {len(pipe.ScanInfo.Field & scan_key)}" - ) - raise PipelineException(msg) - - ## Determine field offset to slice times later on and set ms_delay to field average - scan_restriction = (pipe.ScanInfo & scan_key).fetch("KEY") - all_z = np.unique( - (pipe.ScanInfo.Field & scan_restriction).fetch("z", order_by="field ASC") - ) - slice_num = len(all_z) - field_z = (pipe.ScanInfo.Field & scan_key).fetch1("z") - field_offset = np.where(all_z == field_z)[0][0] - if debug: - print(f"Field offset found as {field_offset} for depths 0-{len(all_z)}") - - field_delay_im = (pipe.ScanInfo.Field & scan_key).fetch1("delay_image") - average_field_delay = np.mean(field_delay_im) - ms_delay = average_field_delay - if debug: - print( - f"Average field delay found to be {round(ms_delay,4)}ms. This will be used unless a unit is specified in the key." - ) - - ## If included, add unit offset - if "unit_id" in scan_key or "mask_id" in scan_key: - if len(pipe.ScanSet.Unit & scan_key) > 0: - unit_key = (pipe.ScanSet.Unit & scan_key).fetch1() - ms_delay = (pipe.ScanSet.UnitInfo & unit_key).fetch1("ms_delay") - if debug: - print( - f"Unit found with delay of {round(ms_delay,4)}ms. Delay added to relevant times." - ) - else: - if debug: - print( - f"Warning: ScanSet.Unit is not populated for the given key! Using field offset minimum instead." - ) + if timing_type in scan_types: + ms_delay, slice_num, field_offset = scan_image_correction(scan_key=scan_key, debug=debug) ## ## Fetch source and target sync data @@ -217,33 +184,54 @@ def fetch_timing_data( "treadmill": (treadmill.Treadmill, "treadmill_time"), "pupil": (pupil.Eye, "eye_time"), "respiration": (odor.Respiration * odor.MesoMatch, "times"), + "time-stimulus": None, } ## Error check inputs - if source_type not in data_source_lookup or target_type not in data_source_lookup: + if timing_type not in data_source_lookup: msg = ( - f"Source and target type combination '{source_type}' and '{target_type}' not supported. " + f"Timing type '{timing_type}' not supported. " f"Valid values are 'fluorescence-behavior', 'fluorescence-stimulus', 'deconvolution-behavior', " f"'deconvolution-stimulus', treadmill', 'respiration' or 'pupil'." ) raise PipelineException(msg) ## Fetch source and target times using lookup dictionary - source_table, source_column = data_source_lookup[source_type] - source_times = (source_table & scan_key).fetch1(source_column).squeeze() - - target_table, target_column = data_source_lookup[target_type] - target_times = (target_table & scan_key).fetch1(target_column).squeeze() + if timing_type not in ['time-stimulus',]: + timing_table, timing_column = data_source_lookup[timing_type] + timing = (timing_table & scan_key).fetch1(timing_column).squeeze() + else: + timing = np.concatenate((stimulus.Trial & scan_key).fetch('flip_times', squeeze=True)) ## ## Timing corrections ## ## Slice times if on ScanImage clock and add delay (scan_types defined near top) - if source_type in scan_types: - source_times = source_times[field_offset::slice_num] + ms_delay - if target_type in scan_types: - target_times = target_times[field_offset::slice_num] + ms_delay + if timing_type in scan_types: + timing = timing[field_offset::slice_num] + ms_delay + + return timing + + +def interpolate_timing_data( + scan_key: dict, + source_type: str, + target_type: str, + source_times_source_clock = None, + target_times_target_clock = None, + debug: bool = True, +): + + ## Make settings strings lowercase + source_type = source_type.lower() + target_type = target_type.lower() + + ## Fetch times with delays added if applicable + if source_times_source_clock is None: + source_times_source_clock = fetch_timing_data(scan_key=scan_key, timing_type=source_type, debug=debug) + if target_times_target_clock is None: + target_times_target_clock = fetch_timing_data(scan_key=scan_key, timing_type=target_type, debug=debug) ## ## Interpolate into different clock if necessary @@ -259,6 +247,7 @@ def fetch_timing_data( "treadmill": "behavior", "processed-treadmill": "behavior", "respiration": "odor", + "time-stimulus": "stimulus", } sync_conversion_lookup = { @@ -281,17 +270,18 @@ def fetch_timing_data( target2source_interp = interpolate.interp1d( interp_target, interp_source, fill_value="extrapolate" ) - target_times = target2source_interp(target_times) - - return source_times, target_times + target_times_source_clock = target2source_interp(target_times_target_clock) + else: + target_times_source_clock = target_times_target_clock + return source_times_source_clock, target_times_source_clock def interpolate_signal_data( scan_key: dict, source_type: str, target_type: str, - source_times, - target_times, + source_times_source_clock, + target_times_source_clock, debug: bool = True, ): """ @@ -363,6 +353,8 @@ def interpolate_signal_data( unit_key = (pipe.ScanSet.Unit & scan_key).fetch1() target_signal = (pipe.Activity.Trace & unit_key).fetch1("trace") elif target_type == "pupil": + unique_keys = (set(pupil.FittedPupil.Circle.heading.primary_key) - set(['frame_id'])) + assert len(dj.U(*unique_keys) & (pupil.FittedPupil.Circle & scan_key)) == 1, 'More than 1 pupil method found!' target_signal = (pupil.FittedPupil.Circle & scan_key).fetch("radius") elif target_type == "treadmill": target_signal = (treadmill.Treadmill & scan_key).fetch1("treadmill_vel") @@ -373,23 +365,25 @@ def interpolate_signal_data( raise PipelineException(msg) ## Calculate FPS to determine if lowpass filtering is needed - source_fps = 1 / np.nanmedian(np.diff(source_times)) - target_fps = 1 / np.nanmedian(np.diff(target_times)) + source_fps = 1 / np.nanmedian(np.diff(source_times_source_clock)) + target_fps = 1 / np.nanmedian(np.diff(target_times_source_clock)) ## Fill NaNs to prevent interpolation errors, but store NaNs for later to add back in after interpolating source_replace_nans = None # Use this as a switch to refill things later if sum(np.isnan(target_signal)) > 0: target_nan_indices = np.isnan(target_signal) - time_nan_indices = np.isnan(target_times) + time_nan_indices = np.isnan(target_times_source_clock) target_replace_nans = np.logical_and(target_nan_indices, ~time_nan_indices) if sum(target_replace_nans) > 0: source_replace_nans = convert_clocks( scan_key, np.where(target_replace_nans)[0], "indices", - target_type, + source_type, "indices", source_type, + source_times=target_times_source_clock, + target_times=source_times_source_clock, debug=False, ) nan_filler_func = ( @@ -398,7 +392,7 @@ def interpolate_signal_data( target_signal = nan_filler_func(target_signal) if debug: biggest_time_gap = np.nanmax( - np.diff(target_times[np.where(~target_replace_nans)[0]]) + np.diff(target_times_source_clock[np.where(~target_replace_nans)[0]]) ) msg = ( f"Found NaNs in {sum(target_nan_indices)} locations, which corresponds to " @@ -422,32 +416,32 @@ def interpolate_signal_data( ## Timing and recording array lengths can differ slightly if recording was stopped mid-scan. Timings for ## the next X depths would be recorded, but fluorescence values would be dropped if all depths were not ## recorded. This would mean timings difference shouldn't be more than the number of depths of the scan. - if len(target_times) < len(target_signal): + if len(target_times_source_clock) < len(target_signal): msg = ( f"More recording values than target time values exist! This should not be possible.\n" - f"Target time length: {len(target_times)}\n" + f"Target time length: {len(target_times_source_clock)}\n" f"Target signal length: {len(target_signal)}" ) raise PipelineException(msg) - elif len(target_times) > len(target_signal): + elif len(target_times_source_clock) > len(target_signal): scan_res = pipe.ScanInfo.proj() & scan_key ## To make sure we select all fields z_plane_num = len(dj.U("z") & (pipe.ScanInfo.Field & scan_res)) - if (len(target_times) - len(target_signal)) > z_plane_num: + if (len(target_times_source_clock) - len(target_signal)) > z_plane_num: msg = ( f"Extra timing values exceeds reasonable error bounds. " - f"Error length of {len(target_times) - len(target_signal)} with only {z_plane_num} z-planes." + f"Error length of {len(target_times_source_clock) - len(target_signal)} with only {z_plane_num} z-planes." ) raise PipelineException(msg) else: - shorter_length = np.min((len(target_times), len(target_signal))) - source_times = target_times[:shorter_length] + shorter_length = np.min((len(target_times_source_clock), len(target_signal))) + source_times_source_clock = target_times_source_clock[:shorter_length] source_signal = target_signal[:shorter_length] if debug: - length_diff = np.abs(len(target_times) - len(target_signal)) + length_diff = np.abs(len(target_times_source_clock) - len(target_signal)) msg = ( f"Target times and target signal show length mismatch within acceptable error." f"Difference of {length_diff} within acceptable bounds of {z_plane_num} z-planes." @@ -456,17 +450,37 @@ def interpolate_signal_data( ## Interpolating target signal into source timings signal_interp = interpolate.interp1d( - target_times, target_signal, bounds_error=False + target_times_source_clock, target_signal, bounds_error=False ) with warnings.catch_warnings(): warnings.simplefilter("ignore") - interpolated_signal = signal_interp(source_times) + interpolated_signal = signal_interp(source_times_source_clock) if source_replace_nans is not None: for source_nan_idx in source_replace_nans: interpolated_signal[source_nan_idx] = np.nan return interpolated_signal +def sample_trace( + scan_key: dict, + sample_times, + target_type: str, + debug: bool = True, +): + target_times = fetch_timing_data(scan_key=scan_key, timing_type=target_type, debug=debug) + source_times_source_clock, target_times_source_clock = interpolate_timing_data( + scan_key=scan_key, source_type='time-stimulus', target_type=target_type, + source_times_source_clock=sample_times, target_times_target_clock=target_times, debug=debug + ) + interpolated_signal = interpolate_signal_data( + scan_key=scan_key, + source_type='time-stimulus', + target_type=target_type, + source_times_source_clock=source_times_source_clock, + target_times_source_clock=target_times_source_clock, + debug=debug + ) + return interpolated_signal def convert_clocks( scan_key: dict, @@ -475,6 +489,8 @@ def convert_clocks( source_type: str, target_format: str, target_type: str, + source_times = None, + target_times = None, drop_single_idx: bool = True, debug: bool = True, ): @@ -619,11 +635,11 @@ def convert_clocks( ## Fetch source and target times, along with converting between Stimulus or Behavior clock if needed ## - source_times_source_clock, target_times_source_clock = fetch_timing_data( - scan_key, source_type, target_type, debug + source_times_source_clock, target_times_source_clock = interpolate_timing_data( + scan_key, source_type, target_type, debug=debug, source_times_source_clock=source_times, target_times_target_clock=target_times ) - target_times_target_clock, source_times_target_clock = fetch_timing_data( - scan_key, target_type, source_type, debug + target_times_target_clock, source_times_target_clock = interpolate_timing_data( + scan_key, target_type, source_type, debug=debug, source_times_source_clock=source_times, target_times_target_clock=target_times ) ## @@ -708,14 +724,14 @@ def convert_clocks( source_indices = find_idx_boundaries(input_list, drop_single_idx) elif "times" in source_format: source_indices = convert_clocks( - scan_key, - input_list, - source_format, - source_type, - "indices", - target_type, - drop_single_idx, - False, + scan_key=scan_key, + input_list=input_list, + source_format=source_format, + source_type=source_type, + target_format="indices", + target_type=target_type, + drop_single_idx=drop_single_idx, + debug=False, ) else: msg = ( @@ -733,7 +749,6 @@ def convert_clocks( target_times_source_clock, debug=debug, ) - ## Split indices given into fragments based on which ones are continuous (incrementing by 1) target_signal_fragments = [] for idx_fragment in source_indices: