diff --git a/README.md b/README.md index 2b4970d4..0cd2037a 100644 --- a/README.md +++ b/README.md @@ -435,4 +435,6 @@ If you use the code in a publication, please cite our ICLR 2020 paper: ##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2018.* Sergey Zagoruyko, Nikos Komodakis -##### [15] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee +##### [15] [Tensor Programs I: Wide Feedforward or Recurrent Neural Networks of Any Architecture are Gaussian Processes.](https://arxiv.org/abs/1910.12478) *NeurIPS 2019.* Greg Yang. + +##### [16] [On the Infinite Width Limit of Neural Networks with a Standard Parameterization.](https://arxiv.org/pdf/2001.07301.pdf) *arXiv 2020.* Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee diff --git a/examples/infinite_fcn.py b/examples/infinite_fcn.py old mode 100644 new mode 100755 index db90c6b9..e7d2dc9c --- a/examples/infinite_fcn.py +++ b/examples/infinite_fcn.py @@ -23,8 +23,7 @@ import jax.numpy as np import neural_tangents as nt from neural_tangents import stax -from examples import datasets -from examples import util +from jax import random flags.DEFINE_integer('train_size', 1000, @@ -37,43 +36,35 @@ FLAGS = flags.FLAGS +import pdb +from jax.experimental import callback +from functools import partial def main(unused_argv): # Build data pipelines. print('Loading data.') - x_train, y_train, x_test, y_test = \ - datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size) + key = random.PRNGKey(0) + key, split = random.split(key) + x_train = random.normal(key=key, shape=[2, 3, 4, 5]) + x_train2 = random.normal(key=split, shape=[1, 3, 4, 5]) # Build the infinite network. - _, _, kernel_fn = stax.serial( - stax.Dense(1, 2., 0.05), - stax.Relu(), - stax.Dense(1, 2., 0.05) + init_fn, apply_fn, kernel_fn = stax.serial( + stax.Conv(256, (3, 3), padding='SAME'), + stax.BatchNormRelu((0, 1, 2)), + stax.GlobalAvgPool(), + stax.Dense(256, 2., 0.05) ) - - # Optionally, compute the kernel in batches, in parallel. - kernel_fn = nt.batch(kernel_fn, - device_count=0, - batch_size=FLAGS.batch_size) - - start = time.time() - # Bayesian and infinite-time gradient descent inference with infinite network. - fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn, - x_train, - y_train, - x_test, - get=('nngp', 'ntk'), - diag_reg=1e-3) - fx_test_nngp.block_until_ready() - fx_test_ntk.block_until_ready() - - duration = time.time() - start - print('Kernel construction and inference done in %s seconds.' % duration) - - # Print out accuracy and loss for infinite network predictions. - loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2) - util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss) - util.print_summary('NTK test', y_test, fx_test_ntk, None, loss) + # kernel_fn = callback.find_by_value(partial(kernel_fn, get='nngp'), np.nan) + kerobj = kernel_fn(x_train, x_train2, get='nngp') + theory_ker = kerobj + mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 10000) + diff = theory_ker - mc_kernel_fn(x_train, x_train2, get='nngp') + print(diff) + # print(kerobj.cov1 - kerobj.nngp) + print(np.linalg.norm(diff) / np.linalg.norm(theory_ker)) + # 0.0032839081 + return if __name__ == '__main__': diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py old mode 100644 new mode 100755 index 24c6acdb..741cff8b --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -67,12 +67,14 @@ import enum import functools import operator as op +from functools import reduce import string from typing import Tuple, List, Optional, Iterable, Callable, Union import warnings import frozendict from jax import lax +from jax.api import vmap from jax import linear_util as lu from jax import numpy as np from jax import ops @@ -86,6 +88,9 @@ from neural_tangents.utils import utils from neural_tangents.utils.kernel import Kernel +from jax.nn.initializers import ones, zeros +import quadpy as qp +from scipy.special import roots_laguerre class Padding(enum.Enum): CIRCULAR = 'CIRCULAR' @@ -2915,3 +2920,401 @@ def _pool_mask(mask: np.ndarray, f'please submit a bug to ' f'https://github.com/google/neural-tangents/issues/new.') return mask + + +def _batchnorm_relu_kernel_fn(cov1, cov2, nngp): + # apply the below code to cov1 and cov2 + # write new cross batch code for (nngp, cov1, cov2) + + # cov1 += np.eye(cov1.shape[0]) * 1e-5 + # if cov2 is not None: + # cov2 += np.eye(cov2.shape[0]) * 1e-5 + iu = ops.index_update + ix = ops.index + + def Gmatrix(batch_size): + return np.eye(batch_size) - np.ones((batch_size, batch_size)) / batch_size + + def singlebatchker(ker): + batch_size = ker.shape[0] + G = Gmatrix(batch_size) + eigvals, eigvecs = np.linalg.eigh(G @ ker @ G) + # NOTE: 0 eigvals can appear as negative, so explicitly zero out + eigvals = np.where(eigvals < 0, 0, eigvals) + # eigvals[eigvals < 0] = 0 + logeigvals = np.log(eigvals[1:]) + eigvecs = eigvecs[:, 1:] + + # NOTE: Likely the same as _get_ab_relu_kernel. + def VReLU(cov, eps=1e-5): + indices = list(range(cov.shape[-1])) + d = np.sqrt(cov[..., indices, indices]) + dd = d[..., np.newaxis] * d[..., np.newaxis, :] + c = dd ** (-1) * cov + c = np.where(c > 1 - eps, 1 - eps, c) + c = np.where(c < -1 + eps, -1 + eps, c) + c = (np.sqrt(1 - c ** 2) + (np.pi - np.arccos(c)) * c) / np.pi + return np.nan_to_num(0.5 * dd * c) + + # NOTE: eps is a stability factor for the integral, not for batch norm. + def integrand(log_s, logmultifactor=0, eps=1e-10): + logUeigvals = np.logaddexp(0, np.log(2) + log_s[..., None] + logeigvals) + loginteigvals = logeigvals - logUeigvals + + loginteigvals -= 0.5 * np.sum(logUeigvals, axis=-1, keepdims=True) + loginteigvals += logmultifactor + + inteigvals = np.exp(loginteigvals) + + intvals = np.einsum( + 'ij,...j,jk->...ik', eigvecs, inteigvals, eigvecs.T, optimize=True) + return VReLU(intvals) + + # TODO(Greg): Compute the split point correctly. + intargmax = -np.log(batch_size - 1) + npos = 10 # TODO(Greg): Make these parameters, maybe? + nneg = 10 + alpha = 1 / 8. + + schemepospoints, schemeposweights = roots_laguerre(npos) + schemenegpoints, schemenegweights = roots_laguerre(nneg) + # schemepospoints + # schemepos = qp.e1r.gauss_laguerre(npos) + # schemeneg = qp.e1r.gauss_laguerre(nneg) + + # schemepospoints = schemepos.points + # schemenegpoints = schemeneg.points + # schemeposweights = schemepos.weights + # schemenegweights = schemeneg.weights + + integrandpos = lambda xs: np.moveaxis( + np.moveaxis( + alpha * integrand(intargmax + alpha * xs, + logmultifactor=( + intargmax + (1 + alpha) * xs)[..., np.newaxis]), + 2, 0), 3, 1) + + integrandneg = lambda xs: np.moveaxis( + np.exp(intargmax) * np.moveaxis(integrand(intargmax - xs), 2, 0), 3, 1) + + new_nngp = batch_size * ( + np.einsum('...i,i->...', + integrandpos(np.array([schemepospoints.T])), + schemeposweights) + + + np.einsum('...i,i->...', + integrandneg(np.array([schemenegpoints.T])), + schemenegweights)).squeeze(-1) + return new_nngp + + def J1(c, eps=1e-10): + c = np.clip(c, -1+eps, 1-eps) + return (np.sqrt(1-c**2) + (np.pi - np.arccos(c)) * c) / np.pi + + def VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2): + '''Computes the off diagonal block of the BN+ReLU kernel over 2 batches + Input: + Xi: covariance between batch1 and batch2 + Sigma1: autocovariance of batch1 + Sigma2: autocovariance of batch2 + Output: + f: integrand function in the integral for computing cross batch VBNReLU + ''' + # import pdb; pdb.set_trace() + myblock = np.block([[Sigma1, Xi], [Xi.T, Sigma2]]) + # print(np.linalg.norm(myblock - myblock.T)) + eigval1, eigvec1 = np.linalg.eigh(Sigma1) + eigval2, eigvec2 = np.linalg.eigh(Sigma2) + print('eigval1\n', eigval1) + print('eigval2\n', eigval2) + eigvals, eigvecs = np.linalg.eigh(myblock) + print('block eigvals\n', eigvals) + # import sys; sys.exit() + + n1 = Sigma1.shape[0] + n2 = Sigma2.shape[0] + G1 = Gmatrix(n1) + G2 = Gmatrix(n2) + Delta1, A1 = np.linalg.eigh(G1 @ Sigma1 @ G1) + Delta2, A2 = np.linalg.eigh(G2 @ Sigma2 @ G2) + # NOTE: 0 eigvals can appear as negative, so explicitly zero out + Delta1 = np.where(Delta1 < 0, 0, Delta1) + Delta2 = np.where(Delta2 < 0, 0, Delta2) + # kill first 0 eigenval + Delta1 = Delta1[1:] + Delta2 = Delta2[1:] + A1 = A1[:, 1:] + A2 = A2[:, 1:] + + Xidot = A1.T @ Xi @ A2 + Omegadot = np.block([[np.diag(Delta1), Xidot], [Xidot.T, np.diag(Delta2)]]) + Omegadotinv = np.linalg.inv(Omegadot) + # import pdb; pdb.set_trace() + + def f(s, t, multfactor=1): + # Ddot.shape = (..., n1+n2-2, n1+n2-2) + Ddot = s[..., None, None] * np.eye(n1-1+n2-1) + # Ddot[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)] = t[..., None] + Ddot = iu(Ddot, ix[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)], + t[..., None]) + + ## Compute off-diagonal block of VReLU(Pi) + Pitilde = Omegadotinv + 2 * Ddot + Pitilde = np.linalg.inv(Pitilde) + Pi11diag = np.einsum('ij,...jk,ki->...i', + A1, + Pitilde[..., :n1-1, :n1-1], + A1.T) + Pi22diag = np.einsum('ij,...jk,ki->...i', + A2, + Pitilde[..., n1-1:, n1-1:], + A2.T) + Pi12 = np.einsum('ij,...jk,kl->...il', + A1, + Pitilde[..., :n1-1, n1-1:], + A2.T) + C = J1(np.einsum('...i,...ij,...j->...ij', + Pi11diag**-0.5, + Pi12, + Pi22diag**-0.5)) + VReLUPi12 = 0.5 * np.einsum('...i,...ij,...j->...ij', + Pi11diag**0.5, + C, + Pi22diag**0.5) + print('Cnorm', np.linalg.norm(C)) + print('Pitildenorm', np.linalg.norm(Pitilde)) + print('Omegadotinv', np.linalg.norm(Omegadotinv)) + # import pdb; pdb.set_trace() + ## Compute determinant + ind = np.arange(n1+n2-2) + # Ddot <- matrix inverse of Ddot + # Ddot[..., ind, ind] = Ddot[..., ind, ind]**-1 + Ddot = iu(Ddot, ix[..., ind, ind], Ddot[..., ind, ind]**-1) + logdet = np.linalg.slogdet(Ddot + 2 * Omegadot)[1] + # print('logdet', np.linalg.norm(logdet)) + return np.exp( + (np.log(multfactor) + + (-n1/2) * np.log(s) + + (-n2/2) * np.log(t) + - 1/2 * logdet)[..., None, None] + np.log(VReLUPi12)) + return f + + def VBNReLUCrossBatch(Xi, Sigma1, Sigma2, npos=10, nneg=5, + alphapos1=1/3, alphaneg1=1, + alphapos2=1/3, alphaneg2=1): + '''Compute VBNReLU for two batches. + + Inputs: + Xi: covariance between batch1 and batch2 + Sigma1: autocovariance of batch1 + Sigma2: autocovariance of batch2 + npos: number of points for integrating the big s side of the VBNReLU integral + (effective for both dimensions of integration) + nneg: number of points for integrating the small s side of the VBNReLU integral + (effective for both dimensions of integration) + alphapos1: reparametrize the large s integral by s = exp(alpha r) in the 1st dimension + alphaneg1: reparametrize the small s integral by s = exp(alpha r) in the 1st dimension + alphapos2: reparametrize the large s integral by s = exp(alpha r) in the 2nd dimension + alphaneg2: reparametrize the small s integral by s = exp(alpha r) in the 2nd dimension + By tuning the `alpha` parameters, the integrand is closer to being well-approximated by + low-degree Laguerre polynomials, which makes the quadrature more accurate in approximating the integral. + Outputs: + The (batch1, batch2) block of block matrix obtained by + applying VBNReLU^{\oplus 2} to the kernel of batch1 and batch2 + ''' + # We will do the integration explicitly ourselves: + # We obtain sample points and weights via `quadpy`'s Gauss Laguerre quadrature + # and do the sum ourselves + print('B') + blockeig(Sigma1, Sigma2, Xi) + # schemepos = qp.e1r.gauss_laguerre(npos, alpha=0) + # schemeneg = qp.e1r.gauss_laguerre(nneg, alpha=0) + schemepospoints, schemeposweights = roots_laguerre(npos) + schemenegpoints, schemenegweights = roots_laguerre(nneg) + dim1 = Sigma1.shape[0] + dim2 = Sigma2.shape[0] + intargmax = (-np.log(2*(dim1-1)), -np.log(2*(dim2-1))) + f = VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2) + # Get the points manually for each dimension + # scheme1dpoints = np.concatenate([schemepos.points, -schemeneg.points]) + scheme1dpoints = np.concatenate([schemepospoints, -schemenegpoints]) + # Get the weights manually for each dimension + # scheme1dwts = np.concatenate([schemepos.weights, schemeneg.weights]) + scheme1dwts = np.concatenate([schemeposweights, schemenegweights]) + # Obtain the points for the whole 2d integration + scheme2dpoints = np.meshgrid(scheme1dpoints, scheme1dpoints) + # Obtain the weights for the whole 2d integration + scheme2dwts = scheme1dwts[:, None] * scheme1dwts[None, :] + + def applyalpha(x, alphapos, alphaneg): + x = iu(x, ix[x > 0], x[x > 0] * alphapos) + x = iu(x, ix[x <= 0], x[x <= 0] * alphaneg) + # xx[xx > 0] *= alphapos + # xx[xx <= 0] *= alphaneg + return x + + # def iu(x, mask, ) + def alphafactor(x, y): + a = np.zeros_like(x) + a = iu(a, ix[(x > 0) & (y > 0)], alphapos1 * alphapos2) + a = iu(a, ix[(x > 0) & (y <= 0)], alphapos1 * alphaneg2) + a = iu(a, ix[(x <= 0) & (y > 0)], alphaneg1 * alphapos2) + a = iu(a, ix[(x <= 0) & (y <= 0)], alphaneg1 * alphaneg2) + return a + + integrand = lambda inp: \ + f(np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0]), + np.exp(applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1]), + multfactor=alphafactor(inp[0], inp[1]) + * np.pi**-1 + * np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0] + + applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1] + + np.abs(inp[0]) + np.abs(inp[1]) + ) + ) + + return np.sqrt(dim1 * dim2) * np.einsum('ij...,ij->...', + integrand(scheme2dpoints), + scheme2dwts + ) + + + new_cov1 = singlebatchker(cov1) + if cov2 is None: + return cov1, None, cov1 + + new_cov2 = singlebatchker(cov2) + + print('C') + blockeig(cov1, cov2, nngp) + print('new C') + blockeig(new_cov1, new_cov2, nngp) + new_nngp = VBNReLUCrossBatch(nngp, cov1, cov2) + + return new_cov1, new_cov2, new_nngp + +def blockeig(cov1, cov2, nngp): + myblock = np.block([[cov1, nngp], + [nngp.T, cov2]]) + eigvals, _ = np.linalg.eigh(myblock) + print(eigvals) + +@layer +def BatchNormRelu(axis, channel_axis=-1): + """Layer construction function for a batch normalization layer. + + See the papers below for the derivation. + https://arxiv.org/abs/1902.08129 + https://arxiv.org/abs/1910.12478 + + The implementation here follows the reference implementation in + https://github.com/thegregyang/GP4A + + Args: + :axis: integer or a tuple, specifies dimensions over which to normalize. + :channel_axis: integer, channel axis. Defaults to `-1`, the trailing axis. + For `kernel_fn`, channel size is considered to be infinite. + """ + epsilon = 1e-8 + center=True + scale=True + beta_init=zeros + gamma_init=ones + + _beta_init = lambda rng, shape: beta_init(rng, shape) if center else () + _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else () + axis = (axis,) if np.isscalar(axis) else axis + + init_fn, bn_apply_fn = ostax.BatchNorm( + axis, epsilon, center, scale, beta_init, gamma_init) + if channel_axis >= 0: + axis = tuple(a if a < channel_axis else a-1 for a in axis) + + def apply_fn(params, xs, **kwargs): + xs = bn_apply_fn(params, xs) + return _ab_relu(xs, 0, 1) + + def rotate_greg(cov1, axis, flatten=False): + cov1 = utils.unzip_axes(cov1) + ndim = len(cov1.shape) // 2 + naxes = len(axis) + unnorm_size1 = reduce(op.mul, (cov1.shape[i] for i in range(ndim) if i not in axis), 1) + unnorm_size2 = reduce(op.mul, (cov1.shape[i+ndim] for i in range(ndim) if i not in axis), 1) + norm_size1 = reduce(op.mul, (cov1.shape[i] for i in axis), 1) + norm_size2 = reduce(op.mul, (cov1.shape[i+ndim] for i in axis), 1) + + source_axes = list(axis) + list(np.array(axis) + ndim) + _negidx = np.array(list(range(-2*naxes, 0))) + dest_axes = list(2*ndim + _negidx) + cov1 = np.moveaxis(cov1, source_axes, dest_axes) + old_shape = cov1.shape + if not flatten: + return cov1 + cov1 = cov1.reshape(unnorm_size1, unnorm_size2, norm_size1, norm_size2) + + def unrotate(cov): + assert cov.shape == (unnorm_size1, unnorm_size2, norm_size1, norm_size2), str((unnorm_size1, unnorm_size2, norm_size1, norm_size2)) + cov = cov.reshape(*old_shape) + cov = np.moveaxis(cov, dest_axes, source_axes) + cov = utils.zip_axes(cov) + return cov + return cov1, unrotate + + @_requires(diagonal_batch=False) + def kernel_fn(kernels): + if not kernels.is_gaussian: + raise NotImplementedError('`BatchNormRelu` is only implemented for the ' + 'case if the input layer is guaranteed to be mean' + '-zero Gaussian, i.e. having `is_gaussian` ' + 'set to `True`.') + if kernels.ntk is not None: + raise NotImplementedError('NTK is currently not supported by `BatchNormRelu`.') + + cov1, cov2, nngp = kernels.cov1, kernels.cov2, kernels.nngp + # myblock = np.block([[cov1, nngp], [nngp.T, cov2]]) + + cov1_flatten, cov1_unrotate = rotate_greg(cov1, axis, flatten=True) + # import pdb; pdb.set_trace() + ll = list(range(cov1_flatten.shape[0])) + cov1diag = cov1_flatten[ll, ll] + + ##### compute kernel ###### + # loop over pairs of non-normalized coordinates + # extract blocks + # cov1.shape = (batch, spatial, batch, spatial) + def fn_sam(cov1, cov2, nngp): + return lax.cond(np.allclose(cov1, cov2), + (cov1, None, cov1), lambda x: _batchnorm_relu_kernel_fn(*x)[2], + (cov1, cov2, nngp), lambda x: _batchnorm_relu_kernel_fn(*x)[2]) + _vmapped_bnrelu_self = vmap(vmap(fn_sam, (None, 0, 0)), (0, None, 0)) + _vmapped_bnrelu_other = vmap(vmap(lambda *x: _batchnorm_relu_kernel_fn(*x)[2], (None, 0, 0)), (0, None, 0)) + # import pdb; pdb.set_trace() + cov1 = _vmapped_bnrelu_self(cov1diag, cov1diag, cov1_flatten) + + # TODO(Greg): bypass to single batch case if no spatial dimension and cov2 is None + if cov2 is None: + cov1 = cov1_unrotate(cov1) + nngp = cov1 + else: + cov2_flatten, cov2_unrotate = rotate_greg(cov2, axis, flatten=True) + # import pdb; pdb.set_trace() + nngp_flatten, nngp_unrotate = rotate_greg(nngp, axis, flatten=True) + cov2diag = cov2_flatten[ll, ll] + print('A') + blockeig(cov1diag[0], cov2diag[0], nngp_flatten[0, 0]) + cov2 = _vmapped_bnrelu_self(cov2diag, cov2diag, cov2_flatten) + # print('cov1', cov1) + # print('cov2', cov2) + # import pdb; pdb.set_trace() + # xxx = _batchnorm_relu_kernel_fn(cov1diag[0], cov2diag[0], nngp_flatten[0, 0]) + # print(xxx) + nngp = _vmapped_bnrelu_other(cov1diag, cov2diag, nngp_flatten) + # import pdb; pdb.set_trace() + cov1 = cov1_unrotate(cov1) + cov2 = cov2_unrotate(cov2) + nngp = nngp_unrotate(nngp) + + return kernels.replace(cov1=cov1, nngp=nngp, cov2=cov2, is_gaussian=False) + + + return init_fn, apply_fn, kernel_fn \ No newline at end of file diff --git a/neural_tangents/tests/stax_test.py b/neural_tangents/tests/stax_test.py old mode 100644 new mode 100755 index 58009451..00ad3eb8 --- a/neural_tangents/tests/stax_test.py +++ b/neural_tangents/tests/stax_test.py @@ -26,11 +26,15 @@ from jax.config import config as jax_config from jax.lib import xla_bridge import jax.numpy as np +import numpy as onp import jax.random as random from neural_tangents import stax from neural_tangents.utils import monte_carlo from neural_tangents.utils import test_utils +from absl.testing import absltest +from jax import disable_jit +disable_jit() jax_config.parse_flags_with_absl() @@ -42,7 +46,7 @@ BATCH_SIZE = 2 -INPUT_SHAPE = (BATCH_SIZE, 7, 6, 3) +INPUT_SHAPE = (BATCH_SIZE, 7, 6, 16) WIDTHS = [2**11] @@ -85,7 +89,9 @@ 'CHW', 'NC', 'NWC', - 'NCHW' + 'NCHW', + 'N', + 'NHW' ] POOL_TYPES = [ @@ -101,13 +107,19 @@ test_utils.update_test_tolerance() -def _get_inputs(key, is_conv, same_inputs, input_shape, fn=np.cos): +def _get_inputs(key, is_conv, same_inputs, input_shape, new_batch_size=None, fn=np.cos): key, split = random.split(key) shape = input_shape if is_conv else (input_shape[0], np.prod(input_shape[1:])) - x1 = fn(random.normal(key, shape)) batch_axis = shape.index(BATCH_SIZE) - shape = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:] - x2 = None if same_inputs else 2 * fn(random.normal(split, shape)) + if new_batch_size is None: + shape1 = shape + shape2 = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:] + else: + shape1 = shape[:batch_axis] + (new_batch_size,) + shape[batch_axis + 1:] + shape2 = shape[:batch_axis] + (2 * new_batch_size,) + shape[batch_axis + 1:] + x1 = fn(random.normal(key, shape1)) + x2 = None if same_inputs else 2 * fn(random.normal(split, shape2)) + print(shape1, shape2) return x1, x2 @@ -167,7 +179,7 @@ def conv(out_chan): return stax.GeneralConv( ) affine = conv(width) if is_conv else fc(width) - rate = np.onp.random.uniform(0.5, 0.9) + rate = onp.random.uniform(0.5, 0.9) dropout = stax.Dropout(rate, mode='train') if pool_type == 'AVG': @@ -188,10 +200,16 @@ def conv(out_chan): return stax.GeneralConv( else: pool_or_identity = stax.Identity() dropout_or_identity = dropout if use_dropout else stax.Identity() - layer_norm_or_identity = (stax.Identity() if layer_norm is None - else stax.LayerNorm(axis=layer_norm, - batch_axis=batch_axis, - channel_axis=channel_axis)) + if layer_norm is None: + norm_phi = phi + elif channel_axis not in layer_norm: + norm_phi = stax.BatchNormRelu(layer_norm, channel_axis=channel_axis) + else: + norm_phi = stax.serial( + stax.LayerNorm(axis=layer_norm, + batch_axis=batch_axis, + channel_axis=channel_axis), + phi) res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity) if is_res: block = stax.serial( @@ -200,14 +218,12 @@ def conv(out_chan): return stax.GeneralConv( stax.parallel(stax.Identity(), res_unit), stax.FanInSum(), - layer_norm_or_identity, - phi) + norm_phi) else: block = stax.serial( affine, res_unit, - layer_norm_or_identity, - phi) + norm_phi) if proj_into_2d == 'FLAT': proj_layer = stax.Flatten(batch_axis, 0) @@ -337,7 +353,7 @@ def test_exact(self, model, width, strides, padding, phi, same_inputs, pool_type = 'AVG' W_std, b_std = 2.**0.5, 0.5**0.5 - layer_norm = None + norm_axis = None parameterization = 'ntk' use_dropout = False @@ -430,7 +446,7 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, for is_ntk in [False, True] for proj_into_2d in PROJECTIONS[:2] for layer_norm in LAYER_NORM)) - def test_layernorm(self, + def test_normalization(self, model, width, same_inputs, @@ -440,9 +456,12 @@ def test_layernorm(self, is_conv = 'conv' in model # Check for duplicate / incorrectly-shaped NN configs / wrong backend. if is_conv: - if xla_bridge.get_backend().platform == 'cpu': - raise jtu.SkipTest('Not running CNN models on CPU to save time.') - elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'): + pass + # if xla_bridge.get_backend().platform == 'cpu': + # raise jtu.SkipTest('Not running CNN models on CPU to save time.') + # if layer_norm == 'N': + # raise jtu.SkipTest('Skipping batchnorm test for convolutional networks.') + elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC', 'N'): raise jtu.SkipTest('FC models do not have these parameters.') W_std, b_std = 2.**0.5, 0.5**0.5 @@ -458,8 +477,17 @@ def test_layernorm(self, net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, use_dropout) - self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout, - is_ntk, proj_into_2d) + # # when testing batchnorm, use batch size 5 (instead of 2) + new_batch_size = None if layer_norm != 'N' else 5 + if layer_norm != 'N' or not is_ntk: + self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout, + is_ntk, proj_into_2d, + new_batch_size=new_batch_size) + else: + with self.assertRaises(ValueError): + self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout, + is_ntk, proj_into_2d, + new_batch_size=new_batch_size) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -649,11 +677,11 @@ def test_sparse_inputs(self, act, kernel): samples = N_SAMPLES if xla_bridge.get_backend().platform == 'gpu': - jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4 + jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-4 samples = 100 * N_SAMPLES else: - jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2 - jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3 + jtu._default_tolerance[onp.dtype(onp.float32)] = 5e-2 + jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-3 # a batch of dense inputs x_dense = random.normal(key, (input_count, input_size)) @@ -731,13 +759,17 @@ def test_composition_conv(self, avg_pool, same_inputs): self.assertAllClose(composed_ker_out, ker_out_marg, True) def _check_agreement_with_empirical(self, net, same_inputs, is_conv, - use_dropout, is_ntk, proj_into_2d): + use_dropout, is_ntk, proj_into_2d, + new_batch_size=None): (init_fn, apply_fn, kernel_fn), input_shape, device_count = net + # print(use_dropout) num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES + # num_samples *= 10 if new_batch_size is not None else 1 key = random.PRNGKey(1) - x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape) - + x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape, + new_batch_size=new_batch_size) + # print(np.linalg.norm(x1), np.linalg.norm(x2)) x1_out_shape, params = init_fn(key, x1.shape) if same_inputs: assert (x2 is None) @@ -772,7 +804,9 @@ def _get_empirical(n_samples, get): empirical = np.reshape(_get_empirical(num_samples, 'ntk'), exact.shape) else: exact, shape1, shape2 = kernel_fn(x1, x2, ('nngp', 'shape1', 'shape2')) + print('getting empirical') empirical = _get_empirical(num_samples, 'nngp') + print(empirical) test_utils.assert_close_matrices(self, exact, empirical, rtol) self.assertEqual(shape1, x1_out_shape) self.assertEqual(shape2, x2_out_shape) @@ -1753,6 +1787,5 @@ def get_attn(): empirical = empirical.reshape(exact.shape) test_utils.assert_close_matrices(self, empirical, exact, tol) - if __name__ == '__main__': - jtu.absltest.main() + absltest.main() diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 index d1cdbc76..c3377eec --- a/setup.py +++ b/setup.py @@ -25,7 +25,14 @@ long_description = f.read() -INSTALL_REQUIRES = ['jax>=0.1.58', 'frozendict', 'dataclasses'] +INSTALL_REQUIRES = [ + 'jaxlib>=0.1.47', + 'jax>=0.1.55', + 'frozendict', + 'dataclasses', + 'quadpy==0.13.2' +] + setuptools.setup( name='neural-tangents',