From 653d0b41bcc07a709089a1218637f4c46cd7e341 Mon Sep 17 00:00:00 2001 From: Nicolas Chartier <47000650+CompiledAtBirth@users.noreply.github.com> Date: Tue, 12 Nov 2024 21:45:55 +0900 Subject: [PATCH] Fix type checks for EnsemblePosterior weights (#1299) * minor fix for EnsemblePosterior weights.setter * Update sbi/inference/posteriors/ensemble_posterior.py Co-authored-by: Jan --------- Co-authored-by: Jan --- sbi/inference/posteriors/ensemble_posterior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index b182559cb..1b251439f 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -137,7 +137,7 @@ def weights(self, weights: Optional[Union[List[float], Tensor]]) -> None: self._weights = torch.tensor([ 1.0 / self.num_components for _ in range(self.num_components) ]) - elif weights is Tensor or weights is List: + elif isinstance(weights, (Tensor, List)): self._weights = torch.tensor(weights) / sum(weights) else: raise TypeError