Skip to content

Commit

Permalink
Clean up the gallery (#526)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed Jan 23, 2020
1 parent 01b34a3 commit e6b2027
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main(args):
reparam_samples = run_inference(reparam_model, args, rng_key)

# make plots
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(6.4, 6.4))
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 8))

ax1.plot(samples['x'][:, 0], samples['y'], "go", alpha=0.3)
ax1.set(xlim=(-20, 20), ylim=(-9, 9), ylabel='y',
Expand Down
31 changes: 17 additions & 14 deletions examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import argparse
from functools import partial
import os

from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -64,7 +65,8 @@ def dual_moon_model():
def main(args):
print("Start vanilla HMC...")
nuts_kernel = NUTS(dual_moon_model)
mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()
vanilla_samples = mcmc.get_samples()['x'].copy()
Expand All @@ -87,7 +89,8 @@ def main(args):

print("\nStart NeuTra HMC...")
nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
init_params = np.zeros(guide.latent_size)
mcmc.run(random.PRNGKey(3), init_params=init_params)
mcmc.print_summary()
Expand All @@ -108,44 +111,44 @@ def main(args):
X1, X2 = np.meshgrid(x1, x2)
P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

fig = plt.figure(figsize=(12, 16), constrained_layout=True)
gs = GridSpec(3, 2, figure=fig)
fig = plt.figure(figsize=(12, 8), constrained_layout=True)
gs = GridSpec(2, 3, figure=fig)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[0, 1])
ax4 = fig.add_subplot(gs[1, 1])
ax5 = fig.add_subplot(gs[2, 0])
ax6 = fig.add_subplot(gs[2, 1])
ax5 = fig.add_subplot(gs[0, 2])
ax6 = fig.add_subplot(gs[1, 2])

ax1.plot(losses[1000:])
ax1.set_title('Autoguide training loss (after 1000 steps)')
ax1.set_title('Autoguide training loss\n(after 1000 steps)')

ax2.contourf(X1, X2, P, cmap='OrRd')
sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
ax2.set(xlim=[-3, 3], ylim=[-3, 3],
xlabel='x0', ylabel='x1', title='Posterior using AutoBNAFNormal guide')
xlabel='x0', ylabel='x1', title='Posterior using\nAutoBNAFNormal guide')

sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1], ax=ax3,
hue=guide_trans_samples[:, 0] < 0.)
ax3.set(xlim=[-3, 3], ylim=[-3, 3],
xlabel='x0', ylabel='x1', title='AutoBNAFNormal base samples (True=left moon; False=right moon)')
xlabel='x0', ylabel='x1', title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)')

ax4.contourf(X1, X2, P, cmap='OrRd')
sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4)
ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
ax4.set(xlim=[-3, 3], ylim=[-3, 3],
xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')
xlabel='x0', ylabel='x1', title='Posterior using\nvanilla HMC sampler')

sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
s=30, alpha=0.5, edgecolor="none")
ax5.set(xlim=[-5, 5], ylim=[-5, 5],
xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)')
xlabel='x0', ylabel='x1', title='Samples from the\nwarped posterior - p(z)')

ax6.contourf(X1, X2, P, cmap='OrRd')
sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
ax6.set(xlim=[-3, 3], ylim=[-3, 3],
xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')
xlabel='x0', ylabel='x1', title='Posterior using\nNeuTra HMC sampler')

plt.savefig("neutra.pdf")

Expand Down
2 changes: 1 addition & 1 deletion examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(args):
mcmc = MCMC(NUTS(model, dense_mass=True),
args.num_warmup, args.num_samples, num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
mcmc.run(PRNGKey(1), N=data.shape[0], y=np.log(data))
mcmc.run(PRNGKey(0), N=data.shape[0], y=np.log(data))
mcmc.print_summary()

# predict populations
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def reconstruct_img(epoch, rng_key):
if __name__ == '__main__':
assert numpyro.__version__.startswith('0.2.3')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=20, type=int, help='number of training epochs')
parser.add_argument('-n', '--num-epochs', default=15, type=int, help='number of training epochs')
parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
parser.add_argument('-batch-size', default=128, type=int, help='batch size')
parser.add_argument('-z-dim', default=50, type=int, help='size of latent')
Expand Down

0 comments on commit e6b2027

Please sign in to comment.