-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add type hints to the primitives module #1940
base: master
Are you sure you want to change the base?
Changes from 9 commits
1a2c51f
1f6e471
f03da11
75abe66
02e149a
ebc510c
ba9a07a
a14adaa
a810776
b4dacc0
ee126ef
32df593
d29507a
cc6eb4b
5945fa8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
from typing import Union | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
import jax | ||
from jax import device_get | ||
|
@@ -25,7 +26,7 @@ | |
] | ||
|
||
|
||
def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]: | ||
def _compute_chain_variance_stats(x: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is NDArray used instead of ndarray? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. MyPy started complaining while I was adding hints, so I decided to follow the typing guidelines from NumPy (https://numpy.org/doc/stable/reference/typing.html) :) After that, all errors were gone. |
||
# compute within-chain variance and variance estimator | ||
# input has shape C x N x sample_shape | ||
C, N = x.shape[:2] | ||
|
@@ -41,7 +42,7 @@ def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray | |
return var_within, var_estimator | ||
|
||
|
||
def gelman_rubin(x: np.ndarray) -> np.ndarray: | ||
def gelman_rubin(x: npt.NDArray) -> npt.NDArray: | ||
""" | ||
Computes R-hat over chains of samples ``x``, where the first dimension of | ||
``x`` is chain dimension and the second dimension of ``x`` is draw dimension. | ||
|
@@ -60,7 +61,7 @@ def gelman_rubin(x: np.ndarray) -> np.ndarray: | |
return rhat | ||
|
||
|
||
def split_gelman_rubin(x: np.ndarray) -> np.ndarray: | ||
def split_gelman_rubin(x: npt.NDArray) -> npt.NDArray: | ||
""" | ||
Computes split R-hat over chains of samples ``x``, where the first dimension | ||
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension. | ||
|
@@ -97,7 +98,7 @@ def _fft_next_fast_len(target: int) -> int: | |
target += 1 | ||
|
||
|
||
def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray: | ||
def autocorrelation(x: npt.NDArray, axis: int = 0, bias: bool = True) -> npt.NDArray: | ||
""" | ||
Computes the autocorrelation of samples at dimension ``axis``. | ||
|
||
|
@@ -137,11 +138,11 @@ def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarr | |
autocorr = autocorr / np.arange(N, 0.0, -1) | ||
|
||
with np.errstate(invalid="ignore", divide="ignore"): | ||
autocorr = autocorr / autocorr[..., :1] | ||
autocorr = (autocorr / autocorr[..., :1]).astype(np.float64) | ||
return np.swapaxes(autocorr, axis, -1) | ||
|
||
|
||
def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray: | ||
def autocovariance(x: npt.NDArray, axis: int = 0, bias: bool = True) -> npt.NDArray: | ||
""" | ||
Computes the autocovariance of samples at dimension ``axis``. | ||
|
||
|
@@ -154,7 +155,7 @@ def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarra | |
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True) | ||
|
||
|
||
def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray: | ||
def effective_sample_size(x: npt.NDArray, bias: bool = True) -> npt.NDArray: | ||
""" | ||
Computes effective sample size of input ``x``, where the first dimension of | ||
``x`` is chain dimension and the second dimension of ``x`` is draw dimension. | ||
|
@@ -202,7 +203,7 @@ def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray: | |
return n_eff | ||
|
||
|
||
def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray: | ||
def hpdi(x: npt.NDArray, prob: float = 0.90, axis: int = 0) -> npt.NDArray: | ||
""" | ||
Computes "highest posterior density interval" (HPDI) which is the narrowest | ||
interval with probability mass ``prob``. | ||
|
@@ -285,7 +286,7 @@ def summary( | |
|
||
|
||
def print_summary( | ||
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True | ||
samples: Union[dict, npt.NDArray], prob: float = 0.90, group_by_chain: bool = True | ||
) -> None: | ||
""" | ||
Prints a summary table displaying diagnostics of ``samples`` from the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,19 +24,20 @@ | |
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
# POSSIBILITY OF SUCH DAMAGE. | ||
|
||
from collections import OrderedDict | ||
from contextlib import contextmanager | ||
import functools | ||
import inspect | ||
from typing import Protocol, runtime_checkable | ||
import warnings | ||
|
||
import numpy as np | ||
|
||
import jax | ||
from jax import lax, tree_util | ||
from jax import Array, lax, tree_util | ||
import jax.numpy as jnp | ||
from jax.scipy.special import logsumexp | ||
from jax.typing import ArrayLike | ||
|
||
from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform | ||
from numpyro.distributions.util import ( | ||
|
@@ -270,7 +271,7 @@ def validate_args(self, strict: bool = True) -> None: | |
raise RuntimeError("Cannot validate arguments inside jitted code.") | ||
|
||
@property | ||
def batch_shape(self): | ||
def batch_shape(self) -> tuple[int, ...]: | ||
""" | ||
Returns the shape over which the distribution parameters are batched. | ||
|
||
|
@@ -280,7 +281,7 @@ def batch_shape(self): | |
return self._batch_shape | ||
|
||
@property | ||
def event_shape(self): | ||
def event_shape(self) -> tuple[int, ...]: | ||
""" | ||
Returns the shape of a single sample from the distribution without | ||
batching. | ||
|
@@ -291,15 +292,15 @@ def event_shape(self): | |
return self._event_shape | ||
|
||
@property | ||
def event_dim(self): | ||
def event_dim(self) -> int: | ||
""" | ||
:return: Number of dimensions of individual events. | ||
:rtype: int | ||
""" | ||
return len(self.event_shape) | ||
|
||
@property | ||
def has_rsample(self): | ||
def has_rsample(self) -> bool: | ||
return set(self.reparametrized_params) == set(self.arg_constraints) | ||
|
||
def rsample(self, key, sample_shape=()): | ||
|
@@ -563,6 +564,38 @@ def is_discrete(self): | |
return self.support.is_discrete | ||
|
||
|
||
@runtime_checkable | ||
class DistributionLike(Protocol): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is Protocol needed? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This protocol is needed to provide static type checking for functions that accept
Rather than checking for each specific type, this protocol defines the common interface. If we remove it, we will get many MyPy errors 😅 . I have seen this approach in other projects, which is why I suggest it :) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (And tensorflow_probability and funsor are optional dependencies, so we cannot use Union to create a common type 🙃) |
||
"""A protocol for typing distributions. | ||
|
||
Used to type object of type numpyro.distributions.Distribution, funsor.Funsor | ||
or tensorflow_probability.distributions.Distribution. | ||
""" | ||
|
||
@property | ||
def batch_shape(self) -> tuple[int, ...]: ... | ||
|
||
@property | ||
def event_shape(self) -> tuple[int, ...]: ... | ||
|
||
@property | ||
def event_dim(self) -> int: ... | ||
|
||
def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ... | ||
|
||
def log_prob(self, value: ArrayLike) -> ArrayLike: ... | ||
|
||
@property | ||
def mean(self) -> ArrayLike: ... | ||
|
||
@property | ||
def variance(self) -> ArrayLike: ... | ||
|
||
def cdf(self, value: ArrayLike) -> ArrayLike: ... | ||
|
||
def icdf(self, q: ArrayLike) -> ArrayLike: ... | ||
|
||
|
||
class ExpandedDistribution(Distribution): | ||
arg_constraints = {} | ||
pytree_data_fields = ("base_dist",) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is Array used instead of ArrayLike?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because ArrayLike can be a scalar (see https://jax.readthedocs.io/en/latest/jax.typing.html), but we want them to be true arrays.