diff --git a/examples/stein_bnn.py b/examples/stein_bnn.py index 952b0be49..155eb3fd3 100644 --- a/examples/stein_bnn.py +++ b/examples/stein_bnn.py @@ -92,6 +92,7 @@ def model(x, y=None, hidden_dim=50, sub_size=100): # precision prior on observations prec_obs = sample("prec_obs", Gamma(1.0, 0.1)) + with plate("data", x.shape[0], subsample_size=sub_size, dim=-1): batch_x = subsample(x, event_dim=1) if y is not None: @@ -99,7 +100,7 @@ def model(x, y=None, hidden_dim=50, sub_size=100): else: batch_y = y - loc_y = deterministic("y_pred", nn.relu(batch_x @ w1 + b1) @ w2 + b2) + loc_y = deterministic("y_bnn", nn.relu(batch_x @ w1 + b1) @ w2 + b2) sample( "y", @@ -156,28 +157,27 @@ def main(args): data.xte, xtr_mean, xtr_std ) # Use train data statistics when accessing generalization. n = xte.shape[0] - y_preds = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y_pred"] - mean_pred = y_preds.mean(0) - rmse = jnp.sqrt(jnp.mean((mean_pred - data.yte) ** 2)) + pred_y = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y"] + rmse = jnp.sqrt(jnp.mean((pred_y.mean(0) - data.yte) ** 2)) print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}") print(rf"RMSE: {rmse:.2f}") - # compute mean prediction and confidence interval around median - percentiles = jnp.percentile(y_preds, jnp.array([5.0, 95.0]), axis=0) + # Compute mean prediction and confidence interval around median + percentiles = jnp.percentile(pred_y, jnp.array([5.0, 95.0]), axis=0) - # make plots + # Make plots fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True) - ran = np.arange(mean_pred.shape[0]) + ran = np.arange(pred_y.shape[1]) ax.add_collection( LineCollection( zip(zip(ran, percentiles[0]), zip(ran, percentiles[1])), colors="lightblue" ) ) ax.plot(data.yte, "kx", label="y true") - ax.plot(mean_pred, "ko", label="y pred") - ax.set(xlabel="example", ylabel="y", title="Mean predictions with 90% CI") + ax.plot(pred_y.mean(0), "ko", label="y pred") + ax.set(xlabel="example", ylabel="y", title="Mean Predictions with 90% CI") ax.legend() fig.savefig("stein_bnn.pdf")