From 5359f466a3f0086b3779c9249c37f264e238ebb6 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Thu, 29 Aug 2024 14:37:02 +0000 Subject: [PATCH] set correct mu0 for nonedited --- bean/model/survival_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 118cec3..6956f70 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -266,7 +266,7 @@ def MixtureNormalModel( with pyro.plate("guide_plate1", data.n_targets): mu_targets = pyro.sample("mu_targets", mu_dist) mu_negctrl = pyro.param("mu_negctrl", torch.tensor(0.0)) - mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1) + mu_center = torch.cat([mu_negctrl.expand(data.n_targets, 1), mu_targets], axis=-1) mu_center[data.negctrl_guide_idx, :] = 0.0 mu = torch.repeat_interleave(mu_center + mu_negctrl, data.target_lengths, dim=0) assert mu.shape == (data.n_guides, 2)