-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Forecasting Interface #241
Conversation
@dylanhmorris @gvegayon I could use your eyes on this. |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #241 +/- ##
===========================================
+ Coverage 92.71% 100.00% +7.28%
===========================================
Files 40 2 -38
Lines 906 7 -899
===========================================
- Hits 840 7 -833
+ Misses 66 0 -66
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Note, I get the same error when trying to perform inference on the example given in the import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
def gaussian_hmm(y=None, T=10):
def transition(x_prev, y_curr):
x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
return x_curr, (x_curr, y_curr)
x0 = numpyro.sample('x_0', dist.Normal(0, 1))
_, (x, y) = scan(transition, x0, y, length=T)
return (x, y)
with numpyro.handlers.seed(rng_seed=0): # generative
x, y = gaussian_hmm()
from jax import random
from numpyro.infer import MCMC, NUTS
nuts_kernel = NUTS(gaussian_hmm)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y=y) However, they have inference working on their ar2 example, which also uses |
EDIT: I am able to run your code block without a problem with EDIT2: can run the code block with |
@dylanhmorris and I have determined that this works with jax 0.4.29, but not jax 0.4.30. I have opened an issue: pyro-ppl/numpyro#1827 |
This is now working with downgraded jax. Will continue on the PR. |
…in-tutorials-using-arviz-1
Co-authored-by: Damon Bayer <[email protected]>
Co-authored-by: Damon Bayer <[email protected]>
@dylanhmorris @gvegayon This is ready for review. I have added a summary of all the changes in the at the top of this page. |
Would appreciate this as it will alter some aspects of #162 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks @damonbayer. A couple questions and one major request.
Request: can we have a forecasting unit test that confirms stochastic processes are projected forward as we would expect (e.g. that forecast-period steps in a random walk are generated conditional on the posterior sampled value of the step s.d.)
We do not currently support posterior updates of the step s.d., as far as I can tell. We definitely should, but I think that is out of scope for this PR. That said, I will implement a test to make sure the forecast length works as expected. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @damonbayer!
Co-authored-by: Dylan H. Morris <[email protected]>
Closes #169
Partially addresses #249.
This PR
jax
andnumpyro
in light of Inference forgaussian_hmm
is broken on latest jax version (0.4.30) pyro-ppl/numpyro#1827.n_timepoints_to_simulate
is always the length of the "data" (whether or not it is observed). I would like to rename it ton_datapoints
in a separate PR.n_timepoints_to_simulate
+seeding_period_length
+padding
hospital_admissions_model.qmd
in light of (2) and to reduce code repetition.admissionsmodel.py
andrtinfectionsrenewalmodel.py
to avoid unnecessaryif-else
paths.simplerandomwalk.py
to usenumpyro.contrib.control_flow.scan
rather than anexpand
ed distribution, which allows for forecasting to "just work."hospital_admissions_model.qmd
.Original PR text:
Currently working on refactoring the
SimpleRandomWalkProcess
usingscan
, which should make forecasting possible. It works in forward simulation, but is broken in inference.This works:
But
pyrenew_demo.qmd
fails during inference.It says: