-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
import argparse | ||
import os | ||
import time | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import jax | ||
from jax import random | ||
import jax.numpy as jnp | ||
|
||
import numpyro | ||
from numpyro.contrib.control_flow import scan | ||
import numpyro.distributions as dist | ||
|
||
|
||
def var2_scan(y): | ||
T, K = y.shape # Number of time steps and number of variables | ||
|
||
# Priors for constants and coefficients | ||
c = numpyro.sample("c", dist.Normal(0, 1).expand([K])) # Constants vector of size K | ||
Phi1 = numpyro.sample("Phi1", dist.Normal(0, 1).expand([K, K])) # Coefficients for lag 1 | ||
Phi2 = numpyro.sample("Phi2", dist.Normal(0, 1).expand([K, K])) # Coefficients for lag 2 | ||
|
||
# Priors for error terms | ||
with numpyro.plate("K", K): | ||
sigma = numpyro.sample("sigma", dist.HalfNormal(1.0)) # Standard deviations | ||
L_omega = numpyro.sample("L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)) | ||
L_Sigma = jnp.matmul(jnp.diag(sigma), L_omega) | ||
|
||
def transition(carry, t): | ||
y_prev1, y_prev2, y_obs = carry # Previous two observations and observed data | ||
m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2) # Mean prediction | ||
y_t = numpyro.sample(f"y_{t}", dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma), obs=y_obs[t]) # Conditioned on observed y | ||
new_carry = (y_t, y_prev1, y_obs) | ||
return new_carry, m_t | ||
|
||
# Initial carry: observations at time steps 1 and 0 | ||
init_carry = (y[1], y[0], y[2:]) | ||
|
||
# Time indices starting from time step 2 | ||
time_indices = jnp.arange(T - 2) | ||
|
||
# Run the scan | ||
_, mu = scan(transition, init_carry, time_indices) | ||
|
||
# Store the mean trajectory as a deterministic variable | ||
numpyro.deterministic("mu", mu) | ||
|
||
|
||
def generate_var2_data(T, K, c, Phi1, Phi2, sigma): | ||
""" | ||
Generate time series data from a VAR(2) process. | ||
Args: | ||
T (int): Number of time steps. | ||
K (int): Number of variables in the time series. | ||
c (array): Constants (shape: (K,)). | ||
Phi1 (array): Coefficients for lag 1 (shape: (K, K)). | ||
Phi2 (array): Coefficients for lag 2 (shape: (K, K)). | ||
sigma (array): Covariance matrix for the noise (shape: (K, K)). | ||
Returns: | ||
np.ndarray: Generated time series data (shape: (T, K)). | ||
""" | ||
# Initialize time series with random values | ||
y = np.zeros((T, K)) | ||
y[:2] = np.random.multivariate_normal(mean=np.zeros(K), cov=sigma, size=2) | ||
|
||
# Generate the time series | ||
for t in range(2, T): | ||
y[t] = c + Phi1 @ y[t - 1] + Phi2 @ y[t - 2] + np.random.multivariate_normal(mean=np.zeros(K), cov=sigma) | ||
|
||
return y | ||
|
||
|
||
|
||
def run_inference(model, args, rng_key, y): | ||
""" | ||
Run MCMC inference for the given model. | ||
Args: | ||
model: The probabilistic model to infer. | ||
args: Command-line arguments. | ||
rng_key: PRNG key for randomness. | ||
y: Observed time series data. | ||
""" | ||
start = time.time() | ||
sampler = numpyro.infer.NUTS(model) | ||
mcmc = numpyro.infer.MCMC( | ||
sampler, | ||
num_warmup=args.num_warmup, | ||
num_samples=args.num_samples, | ||
num_chains=args.num_chains, | ||
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, | ||
) | ||
mcmc.run(rng_key, y=y) | ||
mcmc.print_summary() | ||
print("\nMCMC elapsed time:", time.time() - start) | ||
return mcmc.get_samples() | ||
|
||
|
||
def main(args): | ||
# Generate artificial dataset | ||
T = args.num_data # Number of time steps | ||
K = 2 # Number of variables | ||
c_true = jnp.array([0.5, -0.3]) # Constants | ||
Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]]) # Coefficients for lag 1 | ||
Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]]) # Coefficients for lag 2 | ||
sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]]) # Covariance matrix | ||
|
||
rng_key = random.PRNGKey(0) | ||
y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true) | ||
|
||
# Perform inference | ||
samples = run_inference(var2_scan, args, rng_key, y) | ||
|
||
# Prediction | ||
mean_prediction = samples["mu"].mean(axis=0) | ||
lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0) # 2.5th percentile | ||
upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0) # 97.5th percentile | ||
|
||
# Plot results | ||
fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True) | ||
time_steps = jnp.arange(T) | ||
|
||
for i in range(K): | ||
# True values | ||
axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue") | ||
# Posterior mean prediction | ||
axes[i].plot(time_steps[2:], mean_prediction[:, i], label=f"Predicted Mean Variable {i + 1}", color="orange") | ||
# 95% confidence interval | ||
axes[i].fill_between( | ||
time_steps[2:], | ||
lower_bound[:, i], | ||
upper_bound[:, i], | ||
color="orange", | ||
alpha=0.2, | ||
label="95% CI" | ||
) | ||
axes[i].set_title(f"Variable {i + 1}") | ||
axes[i].legend() | ||
axes[i].grid(True) | ||
|
||
plt.xlabel("Time Steps") | ||
plt.tight_layout() | ||
plt.savefig("var2_with_confidence_interval.png") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="VAR(2) example") | ||
parser.add_argument("--num-data", nargs="?", default=100, type=int) | ||
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) | ||
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) | ||
parser.add_argument("--num-chains", nargs="?", default=1, type=int) | ||
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".') | ||
args = parser.parse_args() | ||
|
||
numpyro.set_platform(args.device) | ||
numpyro.set_host_device_count(args.num_chains) | ||
|
||
main(args) |