diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index b182559cb..1b251439f 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -137,7 +137,7 @@ def weights(self, weights: Optional[Union[List[float], Tensor]]) -> None: self._weights = torch.tensor([ 1.0 / self.num_components for _ in range(self.num_components) ]) - elif weights is Tensor or weights is List: + elif isinstance(weights, (Tensor, List)): self._weights = torch.tensor(weights) / sum(weights) else: raise TypeError