diff --git a/examples/var2.py b/examples/var2.py new file mode 100644 index 000000000..0e1ba06ae --- /dev/null +++ b/examples/var2.py @@ -0,0 +1,158 @@ +import argparse +import os +import time + +import matplotlib.pyplot as plt +import numpy as np +import jax +from jax import random +import jax.numpy as jnp + +import numpyro +from numpyro.contrib.control_flow import scan +import numpyro.distributions as dist + + +def var2_scan(y): + T, K = y.shape # Number of time steps and number of variables + + # Priors for constants and coefficients + c = numpyro.sample("c", dist.Normal(0, 1).expand([K])) # Constants vector of size K + Phi1 = numpyro.sample("Phi1", dist.Normal(0, 1).expand([K, K])) # Coefficients for lag 1 + Phi2 = numpyro.sample("Phi2", dist.Normal(0, 1).expand([K, K])) # Coefficients for lag 2 + + # Priors for error terms + with numpyro.plate("K", K): + sigma = numpyro.sample("sigma", dist.HalfNormal(1.0)) # Standard deviations + L_omega = numpyro.sample("L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)) + L_Sigma = jnp.matmul(jnp.diag(sigma), L_omega) + + def transition(carry, t): + y_prev1, y_prev2, y_obs = carry # Previous two observations and observed data + m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2) # Mean prediction + y_t = numpyro.sample(f"y_{t}", dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma), obs=y_obs[t]) # Conditioned on observed y + new_carry = (y_t, y_prev1, y_obs) + return new_carry, m_t + + # Initial carry: observations at time steps 1 and 0 + init_carry = (y[1], y[0], y[2:]) + + # Time indices starting from time step 2 + time_indices = jnp.arange(T - 2) + + # Run the scan + _, mu = scan(transition, init_carry, time_indices) + + # Store the mean trajectory as a deterministic variable + numpyro.deterministic("mu", mu) + + +def generate_var2_data(T, K, c, Phi1, Phi2, sigma): + """ + Generate time series data from a VAR(2) process. + Args: + T (int): Number of time steps. + K (int): Number of variables in the time series. + c (array): Constants (shape: (K,)). + Phi1 (array): Coefficients for lag 1 (shape: (K, K)). + Phi2 (array): Coefficients for lag 2 (shape: (K, K)). + sigma (array): Covariance matrix for the noise (shape: (K, K)). + Returns: + np.ndarray: Generated time series data (shape: (T, K)). + """ + # Initialize time series with random values + y = np.zeros((T, K)) + y[:2] = np.random.multivariate_normal(mean=np.zeros(K), cov=sigma, size=2) + + # Generate the time series + for t in range(2, T): + y[t] = c + Phi1 @ y[t - 1] + Phi2 @ y[t - 2] + np.random.multivariate_normal(mean=np.zeros(K), cov=sigma) + + return y + + + +def run_inference(model, args, rng_key, y): + """ + Run MCMC inference for the given model. + Args: + model: The probabilistic model to infer. + args: Command-line arguments. + rng_key: PRNG key for randomness. + y: Observed time series data. + """ + start = time.time() + sampler = numpyro.infer.NUTS(model) + mcmc = numpyro.infer.MCMC( + sampler, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, + ) + mcmc.run(rng_key, y=y) + mcmc.print_summary() + print("\nMCMC elapsed time:", time.time() - start) + return mcmc.get_samples() + + +def main(args): + # Generate artificial dataset + T = args.num_data # Number of time steps + K = 2 # Number of variables + c_true = jnp.array([0.5, -0.3]) # Constants + Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]]) # Coefficients for lag 1 + Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]]) # Coefficients for lag 2 + sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]]) # Covariance matrix + + rng_key = random.PRNGKey(0) + y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true) + + # Perform inference + samples = run_inference(var2_scan, args, rng_key, y) + + # Prediction + mean_prediction = samples["mu"].mean(axis=0) + lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0) # 2.5th percentile + upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0) # 97.5th percentile + + # Plot results + fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True) + time_steps = jnp.arange(T) + + for i in range(K): + # True values + axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue") + # Posterior mean prediction + axes[i].plot(time_steps[2:], mean_prediction[:, i], label=f"Predicted Mean Variable {i + 1}", color="orange") + # 95% confidence interval + axes[i].fill_between( + time_steps[2:], + lower_bound[:, i], + upper_bound[:, i], + color="orange", + alpha=0.2, + label="95% CI" + ) + axes[i].set_title(f"Variable {i + 1}") + axes[i].legend() + axes[i].grid(True) + + plt.xlabel("Time Steps") + plt.tight_layout() + plt.savefig("var2_with_confidence_interval.png") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="VAR(2) example") + parser.add_argument("--num-data", nargs="?", default=100, type=int) + parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) + parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) + parser.add_argument("--num-chains", nargs="?", default=1, type=int) + parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".') + args = parser.parse_args() + + numpyro.set_platform(args.device) + numpyro.set_host_device_count(args.num_chains) + + main(args)