From 4dfee4382887ed90b24102397daf36471c84f1ca Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 15 Nov 2024 19:34:49 +0100 Subject: [PATCH] rsample --- src/gluonts/torch/distributions/generalized_pareto.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/gluonts/torch/distributions/generalized_pareto.py b/src/gluonts/torch/distributions/generalized_pareto.py index 702dc0ca4c..5e868f7aa5 100644 --- a/src/gluonts/torch/distributions/generalized_pareto.py +++ b/src/gluonts/torch/distributions/generalized_pareto.py @@ -52,7 +52,7 @@ class GeneralizedPareto(Distribution): "scale": constraints.positive, "concentration": constraints.real, } - has_rsample = False + has_rsample = True def __init__(self, loc, scale, concentration, validate_args=None): self.loc, self.scale, self.concentration = broadcast_all( @@ -80,11 +80,10 @@ def expand(self, batch_shape, _instance=None): new._validate_args = self._validate_args return new - def sample(self, sample_shape=torch.Size()): + def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - with torch.no_grad(): - u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) - return self.icdf(u) + u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.icdf(u) def log_prob(self, value): if self._validate_args: