Batchnorm #13

Open
wants to merge 12 commits into main
4 changes: 3 additions & 1 deletion README.md
@@ -435,4 +435,6 @@ If you use the code in a publication, please cite our ICLR 2020 paper:

##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2016.* Sergey Zagoruyko, Nikos Komodakis

##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
##### [15] [Tensor Programs I: Wide Feedforward or Recurrent Neural Networks of Any Architecture are Gaussian Processes.](https://arxiv.org/abs/1910.12478) *NeurIPS 2019.* Greg Yang.

##### [16] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee
55 changes: 23 additions & 32 deletions examples/infinite_fcn.py
100644 → 100755
@@ -23,8 +23,7 @@
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from examples import util
from jax import random


flags.DEFINE_integer('train_size', 1000,
@@ -37,43 +36,35 @@

FLAGS = flags.FLAGS


def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)
  key = random.PRNGKey(0)
  key, split = random.split(key)
  x_train = random.normal(key=key, shape=[2, 3, 4, 5])
  x_train2 = random.normal(key=split, shape=[1, 3, 4, 5])

  # Build the infinite network.
  _, _, kernel_fn = stax.serial(
      stax.Dense(1, 2., 0.05),
      stax.Relu(),
      stax.Dense(1, 2., 0.05)
  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(256, (3, 3), padding='SAME'),
      stax.BatchNormRelu((0, 1, 2)),
      stax.GlobalAvgPool(),
      stax.Dense(256, 2., 0.05)
  )

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                      x_train,
                                                      y_train,
                                                      x_test,
                                                      get=('nngp', 'ntk'),
                                                      diag_reg=1e-3)
  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
  # Compare the analytic NNGP kernel against a Monte Carlo estimate computed
  # from 10000 random finite-width instantiations of the network.
  theory_ker = kernel_fn(x_train, x_train2, get='nngp')
  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 10000)
  diff = theory_ker - mc_kernel_fn(x_train, x_train2, get='nngp')
  print(diff)
  # Relative error of the analytic kernel; observed value ~0.0032839081.
  print(np.linalg.norm(diff) / np.linalg.norm(theory_ker))
  return


if __name__ == '__main__':
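For reference, a minimal self-contained sketch of the same comparison, stripped of the absl/flags harness, is below. It assumes the `stax.BatchNormRelu` layer added on this branch (not part of the released `neural_tangents` API); everything else uses only calls that appear in the diff above.

```python
# Standalone sketch: compare the analytic NNGP kernel of a Conv + BatchNormRelu
# network with a Monte Carlo estimate over many random finite-width networks.
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from jax import random

key, split = random.split(random.PRNGKey(0))
x1 = random.normal(key, (2, 3, 4, 5))   # Small random NHWC inputs.
x2 = random.normal(split, (1, 3, 4, 5))

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(256, (3, 3), padding='SAME'),
    stax.BatchNormRelu((0, 1, 2)),  # Layer added on this branch (assumed API).
    stax.GlobalAvgPool(),
    stax.Dense(256, 2., 0.05))

# Analytic NNGP kernel vs. an empirical estimate over 10000 random networks.
theory = kernel_fn(x1, x2, get='nngp')
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 10000)
empirical = mc_kernel_fn(x1, x2, get='nngp')

# Relative error between the two kernels; the example above reports ~3e-3.
print(np.linalg.norm(theory - empirical) / np.linalg.norm(theory))
```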