-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support scan for Trace_ELBO #1693
Conversation
Nice support, @deoxyribose! Could you add tests for this change? |
I've added a test that does SVI with an AutoNormal guide, and checks that results are relatively accurate. I've kept model and data size to a minimum, the test takes around 3-4 seconds on my machine. There already was a test combining scan, SVI and AutoNormal (test_subsample_guide in test/infer/test_autoguide.py), but it doesn't run inference and checks results, so the current version of scan passes it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thanks, @deoxyribose !
numpyro/contrib/control_flow/scan.py
Outdated
site["value"] = site["value"][i] | ||
return site | ||
|
||
return {k: get_ith_value(v.copy()) for k, v in replay_trace.items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it is better to do the copy inside git_ith_value
* scan replay handling * fix off by one error * cleaned a bit * lint * use substack * handle custom guides * handle empty value shape * a * test * move copy inside get_ith_value --------- Co-authored-by: frans <[email protected]> Co-authored-by: OlaRonning <[email protected]>
Issue #1685
Reuse the
substitute_stack
to store the replay trace, and replay the sites in the scan'ed function one iteration at a time.