-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Bayesian VAR(2) example script #1658 #1915
Conversation
This looks interesting! As of #1904, there is a Gaussian state space model distribution. Would that be suitable for the VAR here? |
In addition, could you add the example to docs index together with thumbnails (like in #1429)? |
Hi @fehiepsi, added the index and thumbnails. |
Hi @tillahoffmann, awesome work on SSM, do you want me to create example using SSM? |
@fehiepsi , Lint has failed, when I am fixing with ruff, it is changing some more files other that var2.py |
I think using the current implementation in this PR or SSM could both work. The latter might be a bit more concise because you wouldn't have to implement the transition function. |
@aibit0111 Could you merge with master? lint issues are fixed upstream. |
@fehiepsi Done |
Could you add some docstring like other examples to give the readers a take on the content? |
ed19c33
to
c0cf4c4
Compare
Hi @fehiepsi, i was on holiday, updated the code with proper doc and passed all the test. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @aibit0111! The implementation looks great to me. Just have minor comments.
Could you also mention [Forecasting](https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html#Forecasting)
in the description so that users have a pointer?
examples/var2.py
Outdated
L_omega = numpyro.sample( | ||
"L_omega", dist.LKJCholesky(dimension=K, concentration=1.0) | ||
) | ||
L_Sigma = jnp.matmul(jnp.diag(sigma), L_omega) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe jnp.einsum("...i,...ij->...ij", sigma, L_omega)
or sigma[..., None] * L_omega
to improve the performance
examples/var2.py
Outdated
# Priors for constants and coefficients | ||
c = numpyro.sample("c", dist.Normal(0, 1).expand([K])) # Constants vector of size K | ||
Phi1 = numpyro.sample( | ||
"Phi1", dist.Normal(0, 1).expand([K, K]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add .to_event(2)
at the end to make sure that no batch dimension appears here?
examples/var2.py
Outdated
"Phi1", dist.Normal(0, 1).expand([K, K]) | ||
) # Coefficients for lag 1 | ||
Phi2 = numpyro.sample( | ||
"Phi2", dist.Normal(0, 1).expand([K, K]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
examples/var2.py
Outdated
|
||
# Priors for error terms | ||
with numpyro.plate("K", K): | ||
sigma = numpyro.sample("sigma", dist.HalfNormal(1.0)) # Standard deviations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like above, prefer .expand([K]).to_event(1)
than using plate
Thanks @fehiepsi for generous review, I am new to num pyro, learning a lot. Please review once again |
VAR(2) examples:
Reference - https://otexts.com/fpp2/VAR.html
fix #1658