diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index a3e1420db..85ea89532 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -32,6 +32,7 @@ import warnings import numpy as np +from typyng import Any import jax from jax import lax, tree_util @@ -572,7 +573,7 @@ class DistributionLike(Protocol): or tensorflow_probability.distributions.Distribution. """ - def __call__(self, *args: functools.Any, **kwds: functools.Any) -> functools.Any: + def __call__(self, *args: Any, **kwds: Any) -> Any: return super().__call__(*args, **kwds) @property