Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forecasting Interface #241

Merged
merged 55 commits into from
Jul 12, 2024
Merged

Forecasting Interface #241

merged 55 commits into from
Jul 12, 2024

Conversation

damonbayer
Copy link
Collaborator

@damonbayer damonbayer commented Jul 2, 2024

Closes #169
Partially addresses #249.

This PR

  1. Upgrades minimum version ofjax and numpyro in light of Inference for gaussian_hmm is broken on latest jax version (0.4.30) pyro-ppl/numpyro#1827.
  2. Modifies models to have a clearer definition of some time-related concepts.
  • n_timepoints_to_simulate is always the length of the "data" (whether or not it is observed). I would like to rename it to n_datapoints in a separate PR.
  • Observations begin after the seeding period and the padding period. Thus the total length of the latent infections is n_timepoints_to_simulate + seeding_period_length + padding
  1. Cleans up hospital_admissions_model.qmd in light of (2) and to reduce code repetition.
  2. Simplifies admissionsmodel.py and rtinfectionsrenewalmodel.py to avoid unnecessary if-else paths.
  3. Modifies simplerandomwalk.py to use numpyro.contrib.control_flow.scan rather than an expanded distribution, which allows for forecasting to "just work."
  4. Adds a forecasting example to hospital_admissions_model.qmd.

Original PR text:

Currently working on refactoring the SimpleRandomWalkProcess using scan, which should make forecasting possible. It works in forward simulation, but is broken in inference.

This works:

from pyrenew.process import SimpleRandomWalkProcess
from numpyro import distributions as dist
import numpyro
srwp = SimpleRandomWalkProcess(dist.Normal(0, 1))
with numpyro.handlers.seed(rng_seed=0):
    x = srwp.sample(n_timepoints=10, init=0.0)
x

But pyrenew_demo.qmd fails during inference.

It says:

TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[33])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])'], []).

@damonbayer damonbayer changed the title Forecast Interfact Forecasting Interface Jul 2, 2024
@damonbayer
Copy link
Collaborator Author

@dylanhmorris @gvegayon I could use your eyes on this.

Copy link

codecov bot commented Jul 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 100.00%. Comparing base (a8e4474) to head (45879b1).
Report is 1 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff             @@
##             main      #241      +/-   ##
===========================================
+ Coverage   92.71%   100.00%   +7.28%     
===========================================
  Files          40         2      -38     
  Lines         906         7     -899     
===========================================
- Hits          840         7     -833     
+ Misses         66         0      -66     
Flag Coverage Δ
unittests 100.00% <ø> (+7.28%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@damonbayer
Copy link
Collaborator Author

damonbayer commented Jul 2, 2024

Note, I get the same error when trying to perform inference on the example given in the scan documentation:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def gaussian_hmm(y=None, T=10):
    def transition(x_prev, y_curr):
        x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
        y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
        return x_curr, (x_curr, y_curr)

    x0 = numpyro.sample('x_0', dist.Normal(0, 1))
    _, (x, y) = scan(transition, x0, y, length=T)
    return (x, y)


with numpyro.handlers.seed(rng_seed=0):  # generative
    x, y = gaussian_hmm()

from jax import random
from numpyro.infer import MCMC, NUTS
nuts_kernel = NUTS(gaussian_hmm)

mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)

mcmc.run(rng_key, y=y)

However, they have inference working on their ar2 example, which also uses scan.

@dylanhmorris
Copy link
Collaborator

dylanhmorris commented Jul 2, 2024

Note, I get the same error when trying to perform inference on the example given in the scan documentation:

EDIT: I am able to run your code block without a problem with numpyro 0.15.0 and jax/jaxlib 0.4.28 but can reproduce the error with jax=0.4.30/jaxlib=0.4.30 on macOS and (WSL2) Ubuntu 22.04

EDIT2: can run the code block with numpyro 0.15.0 and jax/jaxlib 0.4.29 as well.

@damonbayer
Copy link
Collaborator Author

damonbayer commented Jul 2, 2024

@dylanhmorris and I have determined that this works with jax 0.4.29, but not jax 0.4.30. I have opened an issue: pyro-ppl/numpyro#1827

@damonbayer
Copy link
Collaborator Author

This is now working with downgraded jax. Will continue on the PR.

@damonbayer
Copy link
Collaborator Author

@dylanhmorris @gvegayon This is ready for review. I have added a summary of all the changes in the at the top of this page.

@damonbayer damonbayer requested a review from a team July 11, 2024 02:55
@AFg6K7h4fhy2
Copy link
Collaborator

@dylanhmorris @gvegayon This is ready for review. I have added a summary of all the changes in the at the top of this page.

Would appreciate this as it will alter some aspects of #162

Copy link
Collaborator

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks @damonbayer. A couple questions and one major request.

Request: can we have a forecasting unit test that confirms stochastic processes are projected forward as we would expect (e.g. that forecast-period steps in a random walk are generated conditional on the posterior sampled value of the step s.d.)

docs/source/tutorials/basic_renewal_model.qmd Show resolved Hide resolved
model/src/pyrenew/process/simplerandomwalk.py Show resolved Hide resolved
docs/source/tutorials/hospital_admissions_model.qmd Outdated Show resolved Hide resolved
@damonbayer
Copy link
Collaborator Author

damonbayer commented Jul 11, 2024

Request: can we have a forecasting unit test that confirms stochastic processes are projected forward as we would expect (e.g. that forecast-period steps in a random walk are generated conditional on the posterior sampled value of the step s.d.)

We do not currently support posterior updates of the step s.d., as far as I can tell. We definitely should, but I think that is out of scope for this PR.

That said, I will implement a test to make sure the forecast length works as expected.

docs/pyproject.toml Outdated Show resolved Hide resolved
@damonbayer damonbayer requested a review from dylanhmorris July 11, 2024 23:39
Copy link
Collaborator

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @damonbayer!

@damonbayer damonbayer merged commit d0b93be into main Jul 12, 2024
8 checks passed
@damonbayer damonbayer deleted the dmb_forecast branch July 12, 2024 04:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Create forecasting interface
5 participants