From 520f5247b930be8526659a4aabde364e16862580 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 14 Nov 2024 09:35:40 +0100 Subject: [PATCH] type improvements --- numpyro/contrib/stochastic_support/dcc.py | 11 +++++++---- numpyro/contrib/stochastic_support/sdvi.py | 8 ++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py index 2e574853c..72f6471f0 100644 --- a/numpyro/contrib/stochastic_support/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -15,6 +15,7 @@ from numpyro.handlers import condition, seed, trace from numpyro.infer import MCMC, NUTS from numpyro.infer.autoguide import AutoGuide, AutoNormal +from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import init_to_value, log_density DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"]) @@ -127,7 +128,9 @@ def _combine_inferences( ) -> DCCResult | SDVIResult: raise NotImplementedError - def run(self, rng_key, *args, **kwargs): + def run( + self, rng_key: ArrayImpl, *args: Any, **kwargs: Any + ) -> DCCResult | SDVIResult: """ Run inference on each SLP separately and combine the results. @@ -193,8 +196,8 @@ def __init__( self, model: Callable, mcmc_kwargs: dict[str, Any], - kernel_cls: Callable = NUTS, - num_slp_samples: int = 1000, + kernel_cls: type[MCMCKernel] = NUTS, + num_slp_samples: int = 1_000, max_slps: int = 124, proposal_scale: float = 1.0, ) -> None: @@ -216,7 +219,7 @@ def _run_inference( Run MCMC on the model conditioned on the given branching trace. """ slp_model = condition(self.model, data=branching_trace) - kernel = self.kernel_cls(slp_model) + kernel = self.kernel_cls(slp_model) # type: ignore[call-arg] mcmc = MCMC(kernel, **self.mcmc_kwargs) mcmc.run(rng_key, *args, **kwargs) diff --git a/numpyro/contrib/stochastic_support/sdvi.py b/numpyro/contrib/stochastic_support/sdvi.py index 2a390964e..81d50c3e1 100644 --- a/numpyro/contrib/stochastic_support/sdvi.py +++ b/numpyro/contrib/stochastic_support/sdvi.py @@ -76,13 +76,13 @@ def __init__( self, model: Callable, optimizer, - svi_num_steps: int = 1000, - combine_elbo_particles: int = 1000, + svi_num_steps: int = 1_000, + combine_elbo_particles: int = 1_000, guide_init: Callable = AutoNormal, loss: ELBO = Trace_ELBO(), svi_progress_bar: bool = False, - num_slp_samples=1000, - max_slps=124, + num_slp_samples: int = 1_000, + max_slps: int = 124, ): self.guide_init = guide_init self.optimizer = optimizer