diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 85ea89532..a352e9b73 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -28,11 +28,10 @@ from contextlib import contextmanager import functools import inspect -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable import warnings import numpy as np -from typyng import Any import jax from jax import lax, tree_util