Skip to content

Commit

Permalink
Add Bayesian VAR(2) example script #1658 (#1915)
Browse files Browse the repository at this point in the history
* Add Bayesian VAR(2) example script

* Added index with thumbnail

* Fix Linting issues

* Apply ruff formatting to examples/var2.py

* Added header

* Added Header

* Added Header

* Added header

* shape fixed and added event

* Added dim and event
  • Loading branch information
aibit0111 authored Dec 4, 2024
1 parent d8bc7f7 commit 244ff38
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 0 deletions.
Binary file added docs/source/_static/img/examples/var2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ NumPyro documentation
examples/zero_inflated_poisson
examples/cvae
tutorials/tbip
examples/var2

.. nbgallery::
:maxdepth: 1
Expand Down
216 changes: 216 additions & 0 deletions examples/var2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

r"""
Example: VAR(2) process
=======================
In this example, we demonstrate how to implement and perform Bayesian inference for a
Vector Autoregressive process of order 2 (VAR(2)). VAR models are widely used in
time series analysis, especially for capturing the dynamics between multiple variables.
A VAR(2) process for a multivariate time series :math:`y_t` with :math:`K` variables is defined as:
.. math::
y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t
Here, :math:`c` is a constant vector, :math:`\Phi_1` and :math:`\Phi_2` are coefficient matrices for lag 1
and lag 2, respectively, and :math:`\epsilon_t` is a Gaussian noise term with zero mean and a
covariance matrix :math:`\Sigma`.
This example uses NumPyro's `scan` utility to efficiently model the temporal dependencies without
explicit Python loops.
For more general time series forecasting techniques and examples, refer to the
`Time Series Forecasting` tutorial:
https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html#Forecasting
Reference
---------
For more information on Vector Autoregressive models, see:
https://otexts.com/fpp2/VAR.html
.. image:: ../_static/img/examples/var2.png
:align: center
"""

import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist


def var2_scan(y):
T, K = y.shape # Number of time steps and number of variables

# Priors for constants and coefficients
c = numpyro.sample("c", dist.Normal(0, 1).expand([K])) # Constants vector of size K
Phi1 = numpyro.sample(
"Phi1", dist.Normal(0, 1).expand([K, K]).to_event(2)
) # Coefficients for lag 1
Phi2 = numpyro.sample(
"Phi2", dist.Normal(0, 1).expand([K, K]).to_event(2)
) # Coefficients for lag 2

# Priors for error terms
sigma = numpyro.sample("sigma", dist.HalfNormal(1.0).expand([K]).to_event(1))
L_omega = numpyro.sample(
"L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)
)
L_Sigma = (
sigma[..., None] * L_omega
) # Alternative: jnp.einsum("...i,...ij->...ij", sigma, L_omega)

def transition(carry, t):
y_prev1, y_prev2, y_obs = carry # Previous two observations and observed data
m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2) # Mean prediction
# Conditioned on observed y
y_t = numpyro.sample(
f"y_{t}",
dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma),
obs=y_obs[t],
)
new_carry = (y_t, y_prev1, y_obs)
return new_carry, m_t

# Initial carry: observations at time steps 1 and 0
init_carry = (y[1], y[0], y[2:])

# Time indices starting from time step 2
time_indices = jnp.arange(T - 2)

# Run the scan
_, mu = scan(transition, init_carry, time_indices)

# Store the mean trajectory as a deterministic variable
numpyro.deterministic("mu", mu)


def generate_var2_data(T, K, c, Phi1, Phi2, sigma):
"""
Generate time series data from a VAR(2) process.
Args:
T (int): Number of time steps.
K (int): Number of variables in the time series.
c (array): Constants (shape: (K,)).
Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
sigma (array): Covariance matrix for the noise (shape: (K, K)).
Returns:
np.ndarray: Generated time series data (shape: (T, K)).
"""
# Initialize time series with random values
y = np.zeros((T, K))
y[:2] = np.random.multivariate_normal(mean=np.zeros(K), cov=sigma, size=2)

# Generate the time series
for t in range(2, T):
y[t] = (
c
+ Phi1 @ y[t - 1]
+ Phi2 @ y[t - 2]
+ np.random.multivariate_normal(mean=np.zeros(K), cov=sigma)
)

return y


def run_inference(model, args, rng_key, y):
"""
Run MCMC inference for the given model.
Args:
model: The probabilistic model to infer.
args: Command-line arguments.
rng_key: PRNG key for randomness.
y: Observed time series data.
"""
start = time.time()
sampler = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
sampler,
num_warmup=args.num_warmup,
num_samples=args.num_samples,
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(rng_key, y=y)
mcmc.print_summary()
print("\nMCMC elapsed time:", time.time() - start)
return mcmc.get_samples()


def main(args):
# Generate artificial dataset
T = args.num_data # Number of time steps
K = 2 # Number of variables
c_true = jnp.array([0.5, -0.3]) # Constants
Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]]) # Coefficients for lag 1
Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]]) # Coefficients for lag 2
sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]]) # Covariance matrix

rng_key = random.PRNGKey(0)
y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true)

# Perform inference
samples = run_inference(var2_scan, args, rng_key, y)

# Prediction
mean_prediction = samples["mu"].mean(axis=0)
lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0) # 2.5th percentile
upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0) # 97.5th percentile

# Plot results
fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True)
time_steps = jnp.arange(T)

for i in range(K):
# True values
axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue")
# Posterior mean prediction
axes[i].plot(
time_steps[2:],
mean_prediction[:, i],
label=f"Predicted Mean Variable {i + 1}",
color="orange",
)
# 95% confidence interval
axes[i].fill_between(
time_steps[2:],
lower_bound[:, i],
upper_bound[:, i],
color="orange",
alpha=0.2,
label="95% CI",
)
axes[i].set_title(f"Variable {i + 1}")
axes[i].legend()
axes[i].grid(True)

plt.xlabel("Time Steps")
plt.tight_layout()
plt.savefig("var2.png")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="VAR(2) example")
parser.add_argument("--num-data", nargs="?", default=100, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
parser.add_argument("--num-chains", nargs="?", default=1, type=int)
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()

numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

main(args)

0 comments on commit 244ff38

Please sign in to comment.