Skip to content

Commit

Permalink
set correct mu0 for nonedited
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Aug 29, 2024
1 parent d526512 commit 5359f46
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bean/model/survival_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def MixtureNormalModel(
with pyro.plate("guide_plate1", data.n_targets):
mu_targets = pyro.sample("mu_targets", mu_dist)
mu_negctrl = pyro.param("mu_negctrl", torch.tensor(0.0))
mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1)
mu_center = torch.cat([mu_negctrl.expand(data.n_targets, 1), mu_targets], axis=-1)
mu_center[data.negctrl_guide_idx, :] = 0.0
mu = torch.repeat_interleave(mu_center + mu_negctrl, data.target_lengths, dim=0)
assert mu.shape == (data.n_guides, 2)
Expand Down

0 comments on commit 5359f46

Please sign in to comment.