Skip to content

Commit

Permalink
type improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 14, 2024
1 parent e2267ba commit 520f524
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 7 additions & 4 deletions numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from numpyro.handlers import condition, seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoGuide, AutoNormal
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import init_to_value, log_density

DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])
Expand Down Expand Up @@ -127,7 +128,9 @@ def _combine_inferences(
) -> DCCResult | SDVIResult:
raise NotImplementedError

def run(self, rng_key, *args, **kwargs):
def run(
self, rng_key: ArrayImpl, *args: Any, **kwargs: Any
) -> DCCResult | SDVIResult:
"""
Run inference on each SLP separately and combine the results.
Expand Down Expand Up @@ -193,8 +196,8 @@ def __init__(
self,
model: Callable,
mcmc_kwargs: dict[str, Any],
kernel_cls: Callable = NUTS,
num_slp_samples: int = 1000,
kernel_cls: type[MCMCKernel] = NUTS,
num_slp_samples: int = 1_000,
max_slps: int = 124,
proposal_scale: float = 1.0,
) -> None:
Expand All @@ -216,7 +219,7 @@ def _run_inference(
Run MCMC on the model conditioned on the given branching trace.
"""
slp_model = condition(self.model, data=branching_trace)
kernel = self.kernel_cls(slp_model)
kernel = self.kernel_cls(slp_model) # type: ignore[call-arg]
mcmc = MCMC(kernel, **self.mcmc_kwargs)
mcmc.run(rng_key, *args, **kwargs)

Expand Down
8 changes: 4 additions & 4 deletions numpyro/contrib/stochastic_support/sdvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ def __init__(
self,
model: Callable,
optimizer,
svi_num_steps: int = 1000,
combine_elbo_particles: int = 1000,
svi_num_steps: int = 1_000,
combine_elbo_particles: int = 1_000,
guide_init: Callable = AutoNormal,
loss: ELBO = Trace_ELBO(),
svi_progress_bar: bool = False,
num_slp_samples=1000,
max_slps=124,
num_slp_samples: int = 1_000,
max_slps: int = 124,
):
self.guide_init = guide_init
self.optimizer = optimizer
Expand Down

0 comments on commit 520f524

Please sign in to comment.