Skip to content

Commit

Permalink
Rename and add functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Aug 16, 2024
1 parent 2fc4fcd commit cc27acd
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 62 deletions.
10 changes: 5 additions & 5 deletions src/npc_sessions_cache/figures/paper2/fig1c.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import numpy.typing as npt
import polars as pl
import unit_utils
import utils


def plot(session_id: str) -> plt.Figure:
Expand All @@ -17,10 +17,10 @@ def plot(session_id: str) -> plt.Figure:
except (AttributeError, TypeError):
session_id = npc_session.SessionRecord(session_id).id

licks_all_sessions = unit_utils.get_component_zarr('licks')
trials_all_sessions = unit_utils.get_component_df('trials')
all_sessions = unit_utils.get_component_df('session')
performance_all_sessions = unit_utils.get_component_df('performance')
licks_all_sessions = utils.get_component_zarr('licks')
trials_all_sessions = utils.get_component_df('trials')
all_sessions = utils.get_component_df('session')
performance_all_sessions = utils.get_component_df('performance')

performance = performance_all_sessions.filter(pl.col('session_id') == session_id)
trials = trials_all_sessions.filter(pl.col('session_id') == session_id)
Expand Down
74 changes: 18 additions & 56 deletions src/npc_sessions_cache/figures/paper2/fig1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.typing as npt
import polars as pl

import unit_utils
import utils


def get_rate_expr(stim: str, is_target: bool):
Expand All @@ -17,55 +17,9 @@ def get_rate_expr(stim: str, is_target: bool):
return (response_trials / total_trials).over(['session_id', 'block_index'])

def plot() -> plt.Figure:

trials_all_sessions = unit_utils.get_component_df('trials')
all_sessions = unit_utils.get_component_df('session')
performance_all_sessions = unit_utils.get_component_df('performance')


df = (
trials_all_sessions
# exclude templeton sessions
.join(
other=all_sessions.filter(~pl.col('stimulus_notes').str.to_lowercase().str.contains('templeton')),
on='session_id',
how='semi',
)
# exclude sessions based on task performance:
.join(
other=(
performance_all_sessions
.filter(
# pl.col('same_modal_dprime') > 1.5,
pl.col('cross_modal_dprime') > 1.5,
)
.with_columns(
pl.col('block_index').count().over('session_id').alias('n_passing_blocks'),
)
.filter(
pl.col('n_passing_blocks') > 3,
)
),
on='session_id',
how='semi',
)
# .lazy()
# basic filtering on trial type: exclude autoreward trials:
.filter(
~pl.col('is_reward_scheduled'),
)
# filter blocks with too few trials:
.filter(
pl.col('trial_index_in_block').max().over('session_id', 'block_index') > 20,
)
# filter sessions with too few blocks:
.filter(
pl.col('block_index').n_unique().over('session_id') == 6,
pl.col('block_index').max().over('session_id') == 5,
)
# add a column that indicates if the first block in a session is aud context:
.with_columns(
(pl.col('context_name').first() == 'aud').over('session_id').alias('is_first_block_aud'),
)
utils.get_prod_trials(cross_modal_dprime_threshold=1.5)
# calculate response rates in each block:
.with_columns(
get_rate_expr(stim='vis', is_target=True).alias(a := 'vis_target_response_rate'),
Expand All @@ -88,8 +42,9 @@ def plot() -> plt.Figure:

is_first_block_aud = False
is_boxplot = False
is_mean_marker = True
is_mean_line = True
is_ci_lines = True
is_median_marker = False
is_median_line = True

fig, axes = plt.subplots(1, df.n_unique('block_index'), figsize=(df.n_unique('block_index') * 2, 3), sharey=True)

Expand Down Expand Up @@ -122,10 +77,10 @@ def plot() -> plt.Figure:
for target in ('nontarget', 'target'):
y = block_df.select(*[f"{modality}_{target}_response_rate" for modality in modalitites]).to_numpy().T
ax.plot(xpos, y, **line_params[target])
if is_mean_marker:
ax.plot(xpos, np.mean(y, axis=-1), '+', c=-0.1 + np.array(line_params[target]['c']), ms=6, zorder=99)
if is_mean_line:
ax.plot(xpos, np.mean(y, axis=-1), c=-0.1 + np.array(line_params[target]['c']), lw=1.5, zorder=99)
if is_median_marker:
ax.plot(xpos, np.median(y, axis=-1), '+', c=-0.1 + np.array(line_params[target]['c']), ms=6, zorder=99)
if is_median_line:
ax.plot(xpos, np.median(y, axis=-1), c=-0.1 + np.array(line_params[target]['c']), lw=1.5, zorder=99)

box_data.extend([y[0, :].flatten(), y[1, :].flatten()])

Expand All @@ -146,7 +101,14 @@ def plot() -> plt.Figure:
)
for box in boxplot_objects['boxes']:
box.set_edgecolor([0.4]*3)


if is_ci_lines:
import matplotlib.cbook
stats = matplotlib.cbook.boxplot_stats(box_data, bootstrap=10_000)
for i, stat in enumerate(stats):
ax.plot([xpos[i % 2]]*2, [stat['cilo'], stat['cihi']], c=[0.4]*3, lw=1.5, zorder=98)
#TODO save stats

context = ('vis', 'aud')[(block_index + is_first_block_aud) % 2]
ax.set_title('V' if context == 'vis' else 'A')
ax.set_aspect(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

CCF_MIDLINE_ML = 5700

CACHE_VERSION = 'v0.0.231'
CACHE_VERSION = 'v0.0.232'

@functools.cache
def get_component_lf(nwb_component: npc_lims.NWBComponentStr) -> pl.LazyFrame:
Expand Down Expand Up @@ -115,6 +115,66 @@ def get_good_units_df() -> pl.DataFrame:
).collect()
logger.info(f"Fetched {len(good_units)} good units")
return good_units

def get_prod_trials(cross_modal_dprime_threshold: float = 1.0) -> pl.DataFrame:
return (
get_component_df('trials')
.join(
other=(
get_component_df('session')
.filter(
pl.col('keywords').list.contains('production'),
~pl.col('keywords').list.contains('opto_perturbation'),
~pl.col('keywords').list.contains('injection_perturbation'),
~pl.col('keywords').list.contains('hab'),
~pl.col('keywords').list.contains('training'),
~pl.col('keywords').list.contains('context_naive'),
~pl.col('keywords').list.contains('templeton'),
)
),
on='session_id',
how='semi',
)
# exclude sessions based on task performance:
.join(
other=(
get_component_df('performance')
.filter(
# pl.col('same_modal_dprime') > 1.0,
pl.col('cross_modal_dprime') > cross_modal_dprime_threshold,
)
.with_columns(
pl.col('block_index').count().over('session_id').alias('n_passing_blocks'),
)
.filter(
pl.col('n_passing_blocks') > 3,
)
),
on='session_id',
how='semi',
)
# .lazy()
# basic filtering on trial type: exclude autoreward trials:
.filter(
~pl.col('is_reward_scheduled'),
)
# filter blocks with too few trials:
.with_columns(
pl.col('trial_index_in_block').max().over('session_id', 'block_index').alias('n_trials_in_block'),
)
.filter(
pl.col('n_trials_in_block') > 10,
)
# filter sessions with too few blocks:
.filter(
pl.col('block_index').n_unique().over('session_id') == 6,
pl.col('block_index').max().over('session_id') == 5,
)
# add a column that indicates if the first block in a session is aud context:
.with_columns(
(pl.col('context_name').first() == 'aud').over('session_id').alias('is_first_block_aud'),
)
)

def copy_parquet_files_to_home() -> None:
for component in ('units', 'session', 'subject', 'trials', 'epochs', 'performance', 'devices', 'electrode_groups', 'electrodes'):
Expand Down

0 comments on commit cc27acd

Please sign in to comment.