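Changing `self.discrete_sigmas[timestep - 1].to(t.device)` to `self.discrete_sigmas.to(t.device)[timestep - 1]` on line 251 of `sde_lib.py` (in `VESDE.discretize`) seems to fix the problem. The original expression indexes `self.discrete_sigmas`, which is on the CPU, with `timestep`, which is derived from `t` and therefore on the GPU, and only moves the result to `t.device` afterwards. The fix moves the tensor to `t.device` first and then indexes it, the same way line 249 already does for `sigma`.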
```
RuntimeError                              Traceback (most recent call last)
Cell In[29], line 1
----> 1 x, n = sampling_fn(score_model)
      2 show_samples(x)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:407, in get_pc_sampler.<locals>.pc_sampler(model)
    405 vec_t = torch.ones(shape[0], device=t.device) * t
    406 x, x_mean = corrector_update_fn(x, vec_t, model=model)
--> 407 x, x_mean = predictor_update_fn(x, vec_t, model=model)
    409 return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:341, in shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous)
    339 else:
    340   predictor_obj = predictor(sde, score_fn, probability_flow)
--> 341 return predictor_obj.update_fn(x, t)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:196, in ReverseDiffusionPredictor.update_fn(self, x, t)
    195 def update_fn(self, x, t):
--> 196   f, G = self.rsde.discretize(x, t)
    197   z = torch.randn_like(x)
    198   x_mean = x - f

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:104, in SDE.reverse.<locals>.RSDE.discretize(self, x, t)
    102 def discretize(self, x, t):
    103   """Create discretized iteration rules for the reverse diffusion sampler."""
--> 104   f, G = discretize_fn(x, t)
    105   rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
    106   rev_G = torch.zeros_like(G) if self.probability_flow else G

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:251, in VESDE.discretize(self, x, t)
    248 timestep = (t * (self.N - 1) / self.T).long()
    249 sigma = self.discrete_sigmas.to(t.device)[timestep]
    250 adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
--> 251                              self.discrete_sigmas[timestep - 1].to(t.device))
    252 f = torch.zeros_like(x)
    253 G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
```
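For reference, here is a minimal standalone sketch of `VESDE.discretize` with the fix applied. It is not the repository's code: the function name `ve_discretize`, its signature, and the geometric schedule with `sigma_min=0.01` and `sigma_max=50` used in the demo are illustrative assumptions. The sketch moves the sigma table to `t.device` once and reuses it for both lookups, so neither indexing operation mixes devices.

```python
import math

import torch


def ve_discretize(x, t, discrete_sigmas, N, T=1.0):
    """Hypothetical standalone version of VESDE.discretize with the device fix.

    Returns (f, G): f is zero for the VE SDE, and
    G = sqrt(sigma_i^2 - sigma_{i-1}^2), with sigma_{-1} taken as 0.
    """
    timestep = (t * (N - 1) / T).long()
    # Move the sigma table to the index tensor's device *before* indexing;
    # indexing a CPU tensor with CUDA indices raises the RuntimeError above.
    sigmas = discrete_sigmas.to(t.device)
    sigma = sigmas[timestep]
    # torch.where evaluates both branches; at timestep 0, sigmas[-1] is a valid
    # (wrapped-around) lookup whose value is discarded in favour of zero.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])
    f = torch.zeros_like(x)
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    N = 10
    # Assumed geometric noise schedule, deliberately left on the CPU as in the failing run.
    discrete_sigmas = torch.exp(torch.linspace(math.log(0.01), math.log(50.0), N))
    x = torch.randn(3, 1, 4, 4, device=device)
    t = torch.tensor([0.0, 0.5, 1.0], device=device)
    f, G = ve_discretize(x, t, discrete_sigmas, N)
    print(f.shape, G)
```

On a CUDA machine, the demo runs with `t` on the GPU and `discrete_sigmas` on the CPU, the same mismatch as in the traceback. Without the early `.to(t.device)`, that mismatch would trigger the error.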