Skip to content

Commit

Permalink
all basic hints
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent 1f6e471 commit f03da11
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import namedtuple
from contextlib import ExitStack, contextmanager
import functools
from typing import Callable, Dict, List, Optional, Tuple, cast
from typing import Callable, Dict, Generator, List, Optional, Tuple, cast
import warnings

import jax
Expand Down Expand Up @@ -532,14 +532,14 @@ def __enter__(self):
return self._indices

@staticmethod
def _get_batch_shape(cond_indep_stack):
def _get_batch_shape(cond_indep_stack: List[CondIndepStackFrame]) -> Tuple:
n_dims = max(-f.dim for f in cond_indep_stack)
batch_shape = [1] * n_dims
for f in cond_indep_stack:
batch_shape[f.dim] = f.size
return tuple(batch_shape)

def process_message(self, msg):
def process_message(self, msg: Dict) -> None:
if msg["type"] not in ("param", "sample", "plate", "deterministic"):
if msg["type"] == "control_flow":
raise NotImplementedError(
Expand All @@ -555,7 +555,7 @@ def process_message(self, msg):
):
return

cond_indep_stack = msg["cond_indep_stack"]
cond_indep_stack: List[CondIndepStackFrame] = msg["cond_indep_stack"]
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
if msg["type"] == "deterministic":
Expand All @@ -579,7 +579,7 @@ def process_message(self, msg):
self.size / self.subsample_size if self.subsample_size else 1
)

def postprocess_message(self, msg):
def postprocess_message(self, msg: Dict) -> None:
if msg["type"] in ("subsample", "param") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
Expand Down Expand Up @@ -608,7 +608,9 @@ def postprocess_message(self, msg):


@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
def plate_stack(
prefix: str, sizes: List[int], rightmost_dim: int = -1
) -> Generator[None, None, None]:
"""
Create a contiguous stack of :class:`plate` s with dimensions::
Expand All @@ -626,7 +628,7 @@ def plate_stack(prefix, sizes, rightmost_dim=-1):
yield


def factor(name, log_factor):
def factor(name: str, log_factor: ArrayLike) -> None:
"""
Factor statement to add arbitrary log probability factor to a
probabilistic model.
Expand All @@ -639,7 +641,7 @@ def factor(name, log_factor):
sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True})


def prng_key():
def prng_key() -> Array | None:
"""
A statement to draw a pseudo-random number generator key
:func:`~jax.random.PRNGKey` under :class:`~numpyro.handlers.seed` handler.
Expand All @@ -651,7 +653,7 @@ def prng_key():
"Cannot generate JAX PRNG key outside of `seed` handler.",
stacklevel=find_stack_level(),
)
return
return None

initial_msg = {
"type": "prng_key",
Expand All @@ -665,7 +667,7 @@ def prng_key():
return msg["value"]


def subsample(data, event_dim):
def subsample(data: ArrayLike, event_dim: int) -> Array:
"""
EXPERIMENTAL Subsampling statement to subsample data based on enclosing
:class:`~numpyro.primitives.plate` s.
Expand Down Expand Up @@ -693,7 +695,7 @@ def model(data):
:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
return data
return cast(Array, data)

assert isinstance(event_dim, int) and event_dim >= 0
initial_msg = {
Expand Down

0 comments on commit f03da11

Please sign in to comment.