
Questions about memory consumption of infinitely wide NTK #166

Open
jasonli0707 opened this issue Sep 7, 2022 · 6 comments
Labels
bug Something isn't working


@jasonli0707

I am working on a simple MNIST example. I found that I could not compute the NTK for the entire dataset without running out of memory. Below is the code snippet I used:

import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from jax import random, jit

def FC(depth=1, num_classes=10, W_std=1.0, b_std=0.0):
    # Fully connected ReLU network. kernel_fn describes the infinite-width limit,
    # so the hidden width (1) only matters for apply_fn.
    layers = [stax.Flatten()]
    for _ in range(depth):
        layers += [stax.Dense(1, W_std, b_std), stax.Relu()]
    layers += [stax.Dense(num_classes, W_std, b_std)]
    return stax.serial(*layers)

x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', data_dir="./data", permute_train=True)

key = random.PRNGKey(0)
init_fn, apply_fn, kernel_fn = FC()
_, params = init_fn(key, (-1, 784))

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))  # `get` (e.g. 'ntk') must be static

# Compute the kernel in 1000 x 1000 blocks and keep the result off the GPU.
batched_kernel_fn = nt.batch(kernel_fn, 1000, store_on_device=False)

k_train_train = batched_kernel_fn(x_train, None, 'ntk')
k_test_train = batched_kernel_fn(x_test, x_train, 'ntk')
predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)
# With t omitted, predict_fn returns the outputs after infinite training time.
fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)

I am running this on two RTX 3090 GPUs, each with 24 GB of memory.
Is there something I'm doing wrong, or is it normal for NTK to consume so much memory?
Thank you!

@romanngg
Contributor

romanngg commented Sep 7, 2022

Thanks for the report, your code is correct!

This actually looks like two bugs on our side:

  1. The store_on_device argument isn't working, and the kernel is stored on the GPU (I'm assuming you have enough CPU RAM, so you're not running out of it).
  2. Even with store_on_device=True, 24 GB of GPU RAM should be enough for the 50k x 50k kernel, but somehow it's not. I suspect JAX and TensorFlow might be competing for GPU memory. Could you try running this version of the code on your machine?
    https://colab.research.google.com/gist/romanngg/96421af87f4cc1e13a78454d8bfb4ee9/memory_repro.ipynb
    The part that hopefully helps is:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
import tensorflow_datasets as tfds

(and I'm using tfds instead of neural_tangents.examples)
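For a quick sense of the sizes involved, here is a small back-of-the-envelope sketch. It assumes float32 kernel entries (4 bytes each, JAX's default when 64-bit mode is off) and a 10k-image MNIST test set; both are assumptions, not measurements from this thread. Under them, the 50k x 50k kernel is 10^10 bytes, about 9.31 GiB, which matches the 10000000000-byte allocations in the OOM messages posted in this thread.

```python
# Back-of-the-envelope memory footprint of dense kernel matrices.
# Assumption: float32 entries (4 bytes each); float64 would double every number.
BYTES_PER_FLOAT32 = 4


def kernel_bytes(n_rows: int, n_cols: int, bytes_per_entry: int = BYTES_PER_FLOAT32) -> int:
    """Bytes needed to store a dense n_rows x n_cols matrix."""
    return n_rows * n_cols * bytes_per_entry


def to_gib(num_bytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes)."""
    return num_bytes / 2**30


train_train = kernel_bytes(50_000, 50_000)  # k_train_train
test_train = kernel_bytes(10_000, 50_000)   # k_test_train (assumes 10k test images)
print(f"k_train_train: {train_train} bytes = {to_gib(train_train):.2f} GiB")
print(f"k_test_train:  {test_train} bytes = {to_gib(test_train):.2f} GiB")
```

Any intermediate of the same shape (for example a copy made while regularizing or factorizing the kernel) costs the same amount again, so peak usage can be a small multiple of these figures.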

Another idea is to binary-search over smaller training set sizes to figure out whether we're really hitting the memory limit (e.g. it works for 40K, but not 50K), or whether the GPU memory is just not available for some reason (e.g. it doesn't work even for 5K).
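A minimal sketch of that search, assuming a hypothetical predicate `fits(k)` that runs the pipeline on the first `k` thousand training points and returns `False` when it hits an out-of-memory error. It also assumes failures are monotone in the size (if k thousand points fit, fewer points fit too). Searching in units of 1000 samples over 1 to 50 needs at most about six (slow) trial runs.

```python
from typing import Callable


def largest_fitting_size(fits: Callable[[int], bool], lo: int, hi: int) -> int:
    """Return the largest k in [lo, hi] for which fits(k) is True.

    Assumes fits is monotone: if k fits, every smaller size fits too.
    Returns lo - 1 if not even lo fits.
    """
    best = lo - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        if fits(mid):
            best = mid    # mid works, so look for something larger
            lo = mid + 1
        else:
            hi = mid - 1  # mid fails, so look for something smaller
    return best
```

For example, with a stand-in predicate `lambda k: k <= 36` (sizes in thousands of samples), `largest_fitting_size(lambda k: k <= 36, 1, 50)` returns 36. In a real run, `fits` would wrap the kernel and prediction code in a `try`/`except` around the out-of-memory error.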

Also, could you please post the whole error message?

@romanngg romanngg added the bug Something isn't working label Sep 7, 2022
@jasonli0707
Author

jasonli0707 commented Sep 8, 2022

Thank you so much for the detailed reply!

I tried your code but still hit the same issue. Here is the complete error message for reference:

2022-09-08 13:20:36.044808: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.31GiB (rounded to 10000000000)requested by op
2022-09-08 13:20:36.044942: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ***************************************************************************************************_
2022-09-08 13:20:36.045005: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 9.31GiB
constant allocation: 0B
maybe_live_out allocation: 9.31GiB
preallocated temp allocation: 0B
total allocation: 18.63GiB
total fragmentation: 0B (0.00%)
Peak buffers:
    Buffer 1:
        Size: 9.31GiB
        Entry Parameter Subshape: s32[50000,50000]

    Buffer 2:
        Size: 9.31GiB
        Operator: op_name="jit(add)/jit(main)/add" source_file="/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py" source_line=1222
        XLA Label: fusion
        Shape: s32[50000,50000]

    Buffer 3:
        Size: 4B
        Entry Parameter Subshape: s32[]

Traceback (most recent call last):
  File "mnist.py", line 68, in <module>
    fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 270, in predict_fn
    return get_predict_fn_inf()(fx_train_0, fx_test_0, k_test_train)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 163, in get_predict_fn_inf
    solve = _get_cho_solve(k_train_train, diag_reg, diag_reg_absolute_scale)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 1232, in _get_cho_solve
    A = _add_diagonal_regularizer(A, diag_reg, diag_reg_absolute_scale)
  File "/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py", line 1222, in _add_diagonal_regularizer
    return A + diag_reg * np.eye(dimension)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2103, in eye
    return lax_internal._eye(_jnp_dtype(dtype), (N, M), k)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 1203, in _eye
    bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 444, in add
    return add_p.bind(x, y)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/core.py", line 325, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/core.py", line 328, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/core.py", line 686, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/dispatch.py", line 113, in apply_primitive
    return compiled_fun(*args)
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/dispatch.py", line 198, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/jason/miniconda3/envs/ntk/lib/python3.8/site-packages/jax/_src/dispatch.py", line 837, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 9.31GiB
constant allocation: 0B
maybe_live_out allocation: 9.31GiB
preallocated temp allocation: 0B
total allocation: 18.63GiB
total fragmentation: 0B (0.00%)
Peak buffers:
    Buffer 1:
        Size: 9.31GiB
        Entry Parameter Subshape: s32[50000,50000]

    Buffer 2:
        Size: 9.31GiB
        Operator: op_name="jit(add)/jit(main)/add" source_file="/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py" source_line=1222
        XLA Label: fusion
        Shape: s32[50000,50000]

    Buffer 3:
        Size: 4B
        Entry Parameter Subshape: s32[]

@jasonli0707
Author

I also searched for the maximum number of samples that runs without hitting the memory issue; in my case it turned out to be 36000:

num_samples = 36000
x_train = x_train[:num_samples]
y_train = y_train[:num_samples]

@romanngg
Contributor

romanngg commented Sep 8, 2022

Oh, thanks for the error message! I realized that what's actually failing is

fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)

and not the kernel computation. Indeed, 24 GB is not enough to run the Cholesky solver on the 50k x 50k matrix, so you'd need to do it on the CPU.
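The traceback above shows where the allocation happens: `_add_diagonal_regularizer` runs `return A + diag_reg * np.eye(dimension)`, and the two 9.31 GiB `s32[50000,50000]` buffers come from the `broadcasted_iota` comparison that `jnp.eye` uses to build the identity. One way to add a scalar to the diagonal without materializing an n x n identity is a scatter-add with `.at[...].add`. This is a hypothetical helper, not what neural_tangents does: it ignores the `diag_reg_absolute_scale` option visible in the traceback, and outside `jit` it still allocates a new n x n result (JAX arrays are immutable), so on its own it would not make the Cholesky solve fit on a 24 GB GPU.

```python
import jax.numpy as jnp


def add_diagonal_regularizer_no_eye(A, diag_reg):
    """Hypothetical helper: return A + diag_reg * I without building the identity.

    Only handles an absolute regularizer; the diag_reg_absolute_scale option of the
    library function in the traceback is not reproduced here.
    """
    idx = jnp.arange(A.shape[0])
    # Scatter-add diag_reg onto the diagonal entries (i, i) only.
    return A.at[idx, idx].add(diag_reg)


A = jnp.arange(16.0).reshape(4, 4)
print(add_diagonal_regularizer_no_eye(A, 0.5))
```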

To run it on the CPU, I think the easiest way is to add predict_fn = jit(predict_fn, backend='cpu') right after you define it (it's good to jit this function anyway).

Alternatively (though hopefully this isn't necessary), you can pin the input tensors to the CPU, so that the function called with them as inputs is executed on the CPU:

import jax

fx_train_0 = jax.device_put(fx_train_0, jax.devices('cpu')[0])
fx_test_0 = jax.device_put(fx_test_0, jax.devices('cpu')[0])
k_test_train = jax.device_put(k_test_train, jax.devices('cpu')[0])

and/or

k_train_train = jax.device_put(k_train_train, jax.devices('cpu')[0])
y_train = jax.device_put(y_train, jax.devices('cpu')[0])

before defining predict_fn. In general, you can print x.device_buffer.device() in various places to see which device each tensor x is stored on, and figure out what runs on the CPU vs. the GPU (you want your last line to be executed on the CPU).
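Putting the CPU-pinning advice together on toy data: a minimal sketch, assuming a recent JAX release in which `Array.devices()` reports placement (the `x.device_buffer.device()` call mentioned above may not exist in newer versions). The matrices are small synthetic stand-ins, not the MNIST kernels, and instead of calling `predict_fn` the sketch runs a Cholesky solve with `jax.scipy.linalg.cho_factor`/`cho_solve`, mimicking the `_get_cho_solve` step from the traceback. Because both inputs are committed to the CPU with `jax.device_put`, the solve runs there too.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

cpu = jax.devices("cpu")[0]

# Hypothetical small stand-ins for k_train_train (n x n, symmetric positive definite)
# and y_train (n x num_classes); the real ones come from kernel_fn and the dataset.
n, num_classes = 256, 10
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
features = jax.random.normal(key_x, (n, 512))
k_train_train = features @ features.T / 512 + jnp.eye(n)
y_train = jax.random.normal(key_y, (n, num_classes))

# Commit both arrays to the CPU; computations on committed arrays run on that device.
k_train_train = jax.device_put(k_train_train, cpu)
y_train = jax.device_put(y_train, cpu)

# Cholesky-based solve of k_train_train @ alpha = y_train.
alpha = cho_solve(cho_factor(k_train_train, lower=True), y_train)

# Check where the inputs and the result live.
print(k_train_train.devices(), alpha.devices())
```

On a GPU machine, the printed sets should contain only the CPU device; if one of them shows a GPU, some input was not committed to the CPU before the solve.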

@jasonli0707
Author

Thank you so much for the detailed follow-up!

As you suggested, I moved everything to the CPU before defining predict_fn and verified that the arrays were indeed stored on the CPU. However, after a few minutes the program is killed by signal SIGSEGV (Address boundary error). Does this mean I'm also out of CPU RAM? If so, is there anything I can do?

@romanngg
Contributor

romanngg commented Sep 9, 2022

How much RAM do you have? Does it work (on the CPU, after your modifications) if you use 36k points? I suspect you'd need at least ~64 GB of RAM, but I've only ever tried it on a machine with >128 GB, so I'm not sure what the exact requirement is.

To debug this further, you could try running the piece of code from
#152 (comment)
first with numpy/scipy and then with jax.numpy and jax.scipy, to get a smaller repro. You could then post it to https://github.com/google/jax and ask what they think. I also occasionally get these low-level errors when doing level-3 linear algebra on large matrices, and don't know how to debug them myself... (e.g. jax-ml/jax#10411, jax-ml/jax#10420)
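The code from #152 is not reproduced in this thread, so here is a hypothetical stand-alone repro in the same spirit: the same Cholesky solve done first with numpy/scipy and then with jax.numpy/jax.scipy, on a synthetic symmetric positive-definite float32 matrix (`make_spd` is an illustrative helper, not an NTK). Growing `n` step by step shows at which size, and in which library, the crash appears. Assuming float32, at `n = 50_000` the matrix alone takes about 9.31 GiB.

```python
import numpy as np
import scipy.linalg
import jax.numpy as jnp
import jax.scipy.linalg


def make_spd(n, rank=64, seed=0):
    """Hypothetical stand-in for an n x n kernel: float32, symmetric positive definite."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, rank)).astype(np.float32)
    k = x @ x.T / rank
    k[np.diag_indices_from(k)] += 1.0  # add the identity in place, without np.eye(n)
    return k


def solve_scipy(k, y):
    return scipy.linalg.cho_solve(scipy.linalg.cho_factor(k, lower=True), y)


def solve_jax(k, y):
    return jax.scipy.linalg.cho_solve(jax.scipy.linalg.cho_factor(k, lower=True), y)


n = 1_000  # grow step by step (e.g. towards 36_000 and 50_000) to find where it breaks
k = make_spd(n)
y = np.random.default_rng(1).standard_normal((n, 10)).astype(np.float32)

x_scipy = solve_scipy(k, y)
x_jax = np.asarray(solve_jax(jnp.asarray(k), jnp.asarray(y)))
print("max |scipy - jax|:", np.max(np.abs(x_scipy - x_jax)))
```

If the scipy version survives at a given size but the JAX one segfaults, that points at JAX/XLA rather than at neural_tangents.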
