Commit: Initial run with condition

jykr committed Nov 25, 2023
1 parent 5656d92 commit ee96e71
Showing 7 changed files with 144 additions and 51 deletions.

6 changes: 4 additions & 2 deletions bean/framework/ReporterScreen.py
@@ -99,6 +99,8 @@ def __init__(
             self.layers["X_bcmatch"] = X_bcmatch
         for k, df in self.uns.items():
             if not isinstance(df, pd.DataFrame):
+                if k == "sample_covariates" and not isinstance(df, list):
+                    self.uns[k] = df.tolist()
                 continue
             if "guide" in df.columns and len(df) > 0:
                 if (
@@ -325,13 +327,13 @@ def __getitem__(self, index):
             if k.startswith("repguide_mask"):
                 if "sample_covariates" in adata.uns:
                     adata.var["_rc"] = adata.var[
-                        ["rep"] + adata.uns["sample_covariates"]
+                        ["rep"] + list(adata.uns["sample_covariates"])
                     ].values.tolist()
                     adata.var["_rc"] = adata.var["_rc"].map(
                         lambda slist: ".".join(slist)
                     )
                     new_uns[k] = df.loc[guides_include, adata.var._rc.unique()]
-                    adata.var.pop("_rc")
+                    # adata.var.pop("_rc")
                 else:
                     new_uns[k] = df.loc[guides_include, adata.var.rep.unique()]
             if not isinstance(df, pd.DataFrame):
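
Context for the ReporterScreen changes above: uns["sample_covariates"] is now normalized to a plain Python list, and subsetting builds a composite "_rc" key by joining the replicate label with each covariate value. A minimal sketch of that key construction, using an illustrative covariate column "batch" that is not taken from the package:

    import pandas as pd

    # Illustrative frame standing in for the samples/var table.
    var = pd.DataFrame({"rep": ["rep1", "rep1", "rep2"], "batch": ["A", "B", "A"]})

    # Same construction as in the diff: list of values per row, joined on ".".
    var["_rc"] = var[["rep", "batch"]].values.tolist()
    var["_rc"] = var["_rc"].map(lambda slist: ".".join(slist))
    print(var["_rc"].tolist())  # ['rep1.A', 'rep1.B', 'rep2.A']
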
73 changes: 53 additions & 20 deletions bean/model/model.py
@@ -45,26 +45,51 @@ def NormalModel(
         sd = sd_alleles
         sd = torch.repeat_interleave(sd, data.target_lengths, dim=0)
         assert sd.shape == (data.n_guides, 1)
+
+    if data.sample_covariates is not None:
+        with pyro.plate("cov_place", data.n_sample_covariates):
+            mu_cov = pyro.sample("mu_cov", dist.Normal(0, 1))
+        assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape
     with replicate_plate:
         with bin_plate as b:
             uq = data.upper_bounds[b]
             lq = data.lower_bounds[b]
             assert uq.shape == lq.shape == (data.n_condits,)
             # with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)):
             with guide_plate:
+                mu = mu.unsqueeze(0).unsqueeze(0).expand(
+                    (data.n_reps, data.n_condits, -1, -1)
+                ) + (data.rep_by_cov * mu_cov)[:, 0].unsqueeze(-1).unsqueeze(
+                    -1
+                ).unsqueeze(
+                    -1
+                ).expand(
+                    (-1, data.n_condits, data.n_guides, 1)
+                )
+                sd = torch.sqrt(
+                    (
+                        sd.unsqueeze(0)
+                        .unsqueeze(0)
+                        .expand((data.n_reps, data.n_condits, -1, -1))
+                    )
+                )
                 alleles_p_bin = get_std_normal_prob(
-                    uq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)),
-                    lq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)),
-                    mu.unsqueeze(0).expand((data.n_condits, -1, -1)),
-                    sd.unsqueeze(0).expand((data.n_condits, -1, -1)),
+                    uq.unsqueeze(0)
+                    .unsqueeze(-1)
+                    .unsqueeze(-1)
+                    .expand((data.n_reps, -1, data.n_guides, 1)),
+                    lq.unsqueeze(0)
+                    .unsqueeze(-1)
+                    .unsqueeze(-1)
+                    .expand((data.n_reps, -1, data.n_guides, 1)),
+                    mu,
+                    sd,
                 )
-                assert alleles_p_bin.shape == (data.n_condits, data.n_guides, 1)
-
-                expected_allele_p = alleles_p_bin.unsqueeze(0).expand(
-                    data.n_reps, -1, -1, -1
-                )
-                expected_guide_p = expected_allele_p.sum(axis=-1)
+                assert alleles_p_bin.shape == (
+                    data.n_reps,
+                    data.n_condits,
+                    data.n_guides,
+                    1,
+                )
+                expected_guide_p = alleles_p_bin.sum(axis=-1)
     assert expected_guide_p.shape == (
         data.n_reps,
         data.n_condits,
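
For orientation on the NormalModel hunk above: assuming get_std_normal_prob(upper, lower, mu, sd) returns the Gaussian probability mass between two sorting-bin quantile bounds, each guide's chance of landing in bin $b$ is

$$P(l_b < y \le u_b \mid \mu, \sigma) = \Phi\!\left(\frac{u_b - \mu}{\sigma}\right) - \Phi\!\left(\frac{l_b - \mu}{\sigma}\right),$$

and the new code shifts $\mu$ per replicate by the sampled covariate effect, roughly $\mu_r = \mu + (Rc)_r$ with $R$ the replicate-by-covariate matrix (rep_by_cov) and $c$ the mu_cov sample. That per-replicate shift is why the bounds, mean, and sd are all broadcast up to a (n_reps, n_condits, n_guides, 1) shape instead of the old replicate-free (n_condits, n_guides, 1).
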
@@ -158,14 +183,10 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True):
     with pyro.plate("guide_plate3", data.n_guides, dim=-1):
         a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0)

-    assert (
-        data.X.shape
-        == data.X_bcmatch.shape
-        == (
-            data.n_reps,
-            data.n_condits,
-            data.n_guides,
-        )
+    assert data.X.shape == (
+        data.n_reps,
+        data.n_condits,
+        data.n_guides,
     )
     with poutine.mask(
         mask=torch.logical_and(
@@ -490,6 +511,18 @@ def NormalGuide(data):
             constraint=constraints.positive,
         )
         pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale))
+    if data.sample_covariates is not None:
+        with pyro.plate("cov_place", data.n_sample_covariates):
+            mu_cov_loc = pyro.param(
+                "mu_cov_loc", torch.zeros((data.n_sample_covariates,))
+            )
+            mu_cov_scale = pyro.param(
+                "mu_cov_scale",
+                torch.ones((data.n_sample_covariates,)),
+                constraint=constraints.positive,
+            )
+            mu_cov = pyro.sample("mu_cov", dist.Normal(mu_cov_loc, mu_cov_scale))
+        assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape


 def MixtureNormalGuide(
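
The NormalGuide addition mirrors the model's new "mu_cov" sample site with learnable mu_cov_loc/mu_cov_scale parameters: in mean-field variational inference, every latent sampled in the model needs a counterpart in the guide. A minimal sketch of how such a model/guide pair is fit with Pyro's SVI; run_inference in this repo presumably wraps a loop of roughly this shape (the wrapper itself is not part of this diff):

    import pyro
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def fit(model, guide, data, num_steps=2000, lr=0.01):
        pyro.clear_param_store()  # start from fresh variational parameters
        svi = SVI(model, guide, Adam({"lr": lr}), loss=Trace_ELBO())
        losses = []
        for _ in range(num_steps):
            losses.append(svi.step(data))  # one gradient step on the ELBO
        return losses
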
28 changes: 17 additions & 11 deletions bean/model/utils.py
@@ -8,17 +8,27 @@
 def get_alpha(
     expected_guide_p, size_factor, sample_mask, a0, epsilon=1e-5, normalize_by_a0=True
 ):
-    p = (
-        expected_guide_p.permute(0, 2, 1) * size_factor[:, None, :]
-    )  # (n_reps, n_guides, n_bins)
-    if normalize_by_a0:
-        a = (
-            (p + epsilon / p.shape[-1])
-            / (p.sum(axis=-1)[:, :, None] + epsilon)
-            * a0[None, :, None]
-        )
-        a = (a * sample_mask[:, None, :]).clamp(min=epsilon)
-        return a
+    p = (
+        expected_guide_p.permute(0, 2, 1) * size_factor[:, None, :]
+    )  # (n_reps, n_guides, n_bins)
+
+    if normalize_by_a0:
+        try:
+            a = (
+                (p + epsilon / p.shape[-1])
+                / (p.sum(axis=-1)[:, :, None] + epsilon)
+                * a0[None, :, None]
+            )
+            a = (a * sample_mask[:, None, :]).clamp(min=epsilon)
+            return a
+        except RuntimeError:
+            # Broadcasting failed: report the offending shapes, then fall
+            # back to the unnormalized proportions below.
+            print(size_factor.shape)
+            print(expected_guide_p.shape)
+            print(a0.shape)
+    a = (p * sample_mask[:, None, :]).clamp(min=epsilon)
+    return a

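
For context on what get_alpha returns: expected per-bin proportions are scaled by sample size factors, renormalized to sum to one per guide, and multiplied by the per-guide concentration a0, which is the standard way to turn expected proportions into Dirichlet-Multinomial concentration parameters. A toy check of the normalization, epsilon terms omitted:

    import torch

    # One replicate, one guide, three bins; p holds expected bin proportions.
    p = torch.tensor([[[0.2, 0.3, 0.5]]])  # (n_reps, n_guides, n_bins)
    a0 = torch.tensor([10.0])              # per-guide total concentration

    alpha = p / p.sum(-1, keepdim=True) * a0[None, :, None]
    print(alpha)  # tensor([[[2., 3., 5.]]]): proportions scaled to total a0
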
43 changes: 38 additions & 5 deletions bean/preprocessing/data_class.py
@@ -60,17 +60,34 @@ def __init__(
         self.device = device
         screen.samples["size_factor"] = self.get_size_factor(screen.X)
         if not (
-            "rep" in screen.samples.columns
+            replicate_column in screen.samples.columns
             and condition_column in screen.samples.columns
         ):
-            screen.samples["rep"], screen.samples[condition_column] = zip(
+            screen.samples[replicate_column], screen.samples[condition_column] = zip(
                 *screen.samples.index.map(lambda s: s.rsplit("_", 1))
             )
         if condition_column not in screen.samples.columns:
             screen.samples[condition_column] = screen.samples["index"].map(
                 lambda s: s.split("_")[-1]
             )
+
+        if "sample_covariates" in screen.uns:
+            self.sample_covariates = screen.uns["sample_covariates"]
+            self.n_sample_covariates = len(self.sample_covariates)
+            screen.samples["_rc"] = screen.samples[
+                [replicate_column] + self.sample_covariates
+            ].values.tolist()
+            screen.samples["_rc"] = screen.samples["_rc"].map(
+                lambda slist: ".".join(slist)
+            )
+            self.rep_by_cov = torch.as_tensor(
+                (
+                    screen.samples[["_rc"] + self.sample_covariates]
+                    .drop_duplicates()
+                    .set_index("_rc")
+                    .values.astype(int)
+                )
+            )
+            replicate_column = "_rc"
         self.screen = screen
         if not control_can_be_selected:
             self.screen_selected = screen[
@@ -146,7 +163,7 @@ def _post_init(
         ).all()
         assert (
             self.screen_selected.uns[self.repguide_mask].columns
-            == self.screen_selected.samples.rep.unique()
+            == self.screen_selected.samples[self.replicate_column].unique()
         ).all()
         self.repguide_mask = (
             torch.as_tensor(self.screen_selected.uns[self.repguide_mask].values.T)
@@ -182,6 +199,7 @@ def __getitem__(self, guide_idx):
         ndata.X_masked = ndata.X_masked[:, :, guide_idx]
         ndata.X_control = ndata.X_control[:, :, guide_idx]
         ndata.repguide_mask = ndata.repguide_mask[:, guide_idx]
+        ndata.a0 = ndata.a0[guide_idx]
         return ndata

     def transform_data(self, X, n_bins=None):
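
The one-line __getitem__ fix above matters because bean-run (further down) slices the dataset to negative-control guides before the control fit: every per-guide tensor must be sliced with the same index, or shapes drift apart downstream. A trivial sketch of the invariant, with made-up sizes:

    import torch

    a0 = torch.arange(5.0)               # per-guide concentration, n_guides = 5
    guide_idx = torch.tensor([0, 2, 4])  # e.g. indices of negative-control guides
    assert a0[guide_idx].shape[0] == guide_idx.shape[0]  # stays aligned with X[:, :, guide_idx]
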
@@ -905,9 +923,20 @@ def _pre_init(
             self.screen.samples.loc[
                 self.screen_selected.samples.index, f"{self.condition_column}_id"
             ] = self.screen_selected.samples[f"{self.condition_column}_id"]
+            print(self.screen.samples.columns)
             self.screen = _assign_rep_ids_and_sort(
                 self.screen, self.replicate_column, self.condition_column
             )
+            print(self.screen.samples.columns)
+            if self.sample_covariates is not None:
+                self.rep_by_cov = torch.as_tensor(
+                    (
+                        self.screen.samples[["_rc"] + self.sample_covariates]
+                        .drop_duplicates()
+                        .set_index("_rc")
+                        .values.astype(int)
+                    )
+                )
             self.screen_selected = _assign_rep_ids_and_sort(
                 self.screen_selected, self.replicate_column, self.condition_column
             )
@@ -986,8 +1015,14 @@ def _post_init(
             self.screen = _assign_rep_ids_and_sort(
                 self.screen, self.replicate_column, self.time_column
             )
+            if self.sample_covariates is not None:
+                self.rep_by_cov = torch.as_tensor(
+                    self.screen.samples.groupby(self.replicate_column)[
+                        self.sample_covariates
+                    ].first().values.astype(int)
+                )
             self.screen_selected = _assign_rep_ids_and_sort(
-                self.screen_selected, self.replicate_column, self.time_column
+                self.screen_selected, self.replicate_column, self.condition_column
             )
             self.screen_control = _assign_rep_ids_and_sort(
                 self.screen_control,
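
For orientation, the rep_by_cov tensor built in the hunks above maps each composite replicate key to its covariate values, one row per unique "_rc". An illustrative reconstruction with a made-up 0/1 covariate column "batch", not the package's data:

    import pandas as pd
    import torch

    samples = pd.DataFrame({"_rc": ["r1.0", "r1.0", "r2.1"], "batch": [0, 0, 1]})

    # One row per unique replicate key; columns are the covariate values.
    rep_by_cov = torch.as_tensor(
        samples[["_rc", "batch"]].drop_duplicates().set_index("_rc").values.astype(int)
    )
    print(rep_by_cov)  # tensor([[0], [1]])
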
8 changes: 4 additions & 4 deletions bean/preprocessing/utils.py
@@ -219,10 +219,10 @@ def _assign_rep_ids_and_sort(
         sort_key = f"{rep_col}_id"
     else:
         sort_key = [f"{rep_col}_id", f"{condition_column}_id"]
-        screen = screen[
-            :,
-            screen.samples.sort_values(sort_key).index,
-        ]
+    screen = screen[
+        :,
+        screen.samples.sort_values(sort_key).index,
+    ]
     return screen


20 changes: 17 additions & 3 deletions bin/bean-run
@@ -2,6 +2,8 @@
 import os
 import sys
 import logging
+import warnings
+from functools import partial
 from copy import deepcopy
 import numpy as np
 import pandas as pd
@@ -43,6 +45,11 @@ warn = logging.warning
 debug = logging.debug
 info = logging.info
 pyro.set_rng_seed(101)
+warnings.filterwarnings(
+    "ignore",
+    category=FutureWarning,
+    message=r".*is_categorical_dtype is deprecated and will be removed in a future version.*",
+)


 def main(args, bdata):
@@ -127,8 +134,15 @@ def main(args, bdata):
                 run_inference(model, guide, ndata, num_steps=args.n_iter)
             )
         if args.fit_negctrl:
-            negctrl_model = m.ControlNormalModel
-            negctrl_guide = m.ControlNormalGuide
+            negctrl_model = partial(
+                m.ControlNormalModel,
+                use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers),
+            )
+            print((not args.ignore_bcmatch and "X_bcmatch" in bdata.layers))
+            negctrl_guide = partial(
+                m.ControlNormalGuide,
+                use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers),
+            )
             negctrl_idx = np.where(
                 guide_info_df[args.negctrl_col].map(lambda s: s.lower())
                 == args.negctrl_col_value.lower()
@@ -137,7 +151,7 @@ def main(args, bdata):
             print(negctrl_idx.shape)
             ndata_negctrl = ndata[negctrl_idx]
             param_history_dict["negctrl"] = run_inference(
-                negctrl_model, negctrl_guide, ndata_negctrl
+                negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter
             )

outfile_path = (
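
A note on the functools.partial pattern introduced above: it pins use_bcmatch as a keyword argument so the negative-control model and guide keep the positional call signature run_inference expects. A self-contained sketch with illustrative names:

    from functools import partial

    def control_model(data, mask_thres=10, use_bcmatch=True):
        return (data, mask_thres, use_bcmatch)

    negctrl_model = partial(control_model, use_bcmatch=False)
    print(negctrl_model("ndata"))  # ('ndata', 10, False)
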
17 changes: 11 additions & 6 deletions notebooks/sample_quality_report.ipynb
@@ -76,9 +76,10 @@
    "outputs": [],
    "source": [
     "if tiling is not None:\n",
-    "    bdata.uns['tiling'] = tiling\n",
+    "    bdata.uns[\"tiling\"] = tiling\n",
     "if not isinstance(replicate_label, str):\n",
-    "    bdata.uns['sample_covariates'] = replicate_label[1:]"
+    "    bdata.uns[\"sample_covariates\"] = replicate_label[1:]\n",
+    "bdata.samples[replicate_label] = bdata.samples[replicate_label].astype(str)"
    ]
},
{
@@ -352,11 +353,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bdata.samples['mask'] = 1\n",
-    "bdata.samples.loc[bdata.samples.median_corr_X < corr_X_thres, 'mask'] = 0\n",
+    "bdata.samples[\"mask\"] = 1\n",
+    "bdata.samples.loc[\n",
+    "    bdata.samples.median_corr_X.isnull() | (bdata.samples.median_corr_X < corr_X_thres), \"mask\"\n",
+    "] = 0\n",
     "if \"median_editing_rate\" in bdata.samples.columns.tolist():\n",
-    "    bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, 'mask'] = 0\n",
-    "bdata_filtered = bdata[:, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres]"
+    "    bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, \"mask\"] = 0\n",
+    "bdata_filtered = bdata[\n",
+    "    :, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres\n",
+    "]"
    ]
},
{
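
The isnull() guard added in the masking cell closes a real gap: NaN comparisons evaluate to False in pandas, so a sample with a missing median_corr_X would previously have escaped the mask. A small demonstration:

    import numpy as np
    import pandas as pd

    s = pd.Series([0.2, np.nan, 0.9])
    print((s < 0.5).tolist())                 # [True, False, False]: the NaN slips through
    print((s.isnull() | (s < 0.5)).tolist())  # [True, True, False]: NaN is now masked too
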
