
Gencast - num_steps_per_chunk > 1 breaks rollout.chunked_prediction_generator_multiple_runs #112

Open
efesurekli opened this issue Dec 20, 2024 · 7 comments


@efesurekli

efesurekli commented Dec 20, 2024

Running the example notebooks with the 1p0 model and num_steps_per_chunk = 2 results in the following error:

ValueError: 'grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0/w' with retrieved shape (267, 512) does not match shape=[355, 512] dtype=dtype('float32')

Is the step chunking working?

Thanks!

@alvarosg
Collaborator

Hi! Could you clarify what you are trying to achieve by setting that parameter to 2, so we can advise appropriately?

You should be able to get it to run by adding the wrapper from autoregressive.py to the construct_wrapped_gencast() function, but it may run out of device memory if you do that.

@efesurekli
Author

I am trying to run 50-member ensembles at 0.25° resolution and have encountered the following:

  1. I am currently running on a TPU v5p-8, as advised in the docs. Producing an 8-member, 30-step forecast takes ~30 minutes, which is also consistent with the docs. However, when I run another 30 steps with 8 members, it still takes 30 minutes rather than the 8 minutes reported in the paper and the docs.
  2. When I run a 30-step 0.25° forecast with 50 members instead of 8, it takes a very long time, and after more than an hour the process appears to be killed for some reason.

My understanding was that this parameter could help resolve these issues and speed up inference. I would really appreciate any other suggestions!

@alvarosg
Collaborator

Thanks for explaining.

I don't think that argument will help you with those.

With respect to 1, could you confirm that you are working past this commit?

With respect to 2, I suspect you are running out of host memory. When you generate a large number of ensemble members, you probably want to write the chunks to disk as they are generated rather than appending them to the list. Writing to disk has an associated time cost, so you may want to set it up to write asynchronously, or to write only a subset of the variables.

Could you confirm what number you get when you print len(jax.local_devices())?

Thanks!
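A minimal sketch of the suggestion above, in plain Python: each chunk is handed to a background writer thread as soon as the generator yields it, and at most `max_pending` chunks are queued at once, so host memory stays bounded instead of growing with the ensemble size. The helper `write_chunks_async` and its parameters are hypothetical, not part of the graphcast package. It assumes only that predictions arrive as an iterable of chunks (for example from rollout.chunked_prediction_generator_multiple_runs) and that you supply a `write_fn(chunk, path)`.

```python
import concurrent.futures
import os
from typing import Any, Callable, Iterable, List


def write_chunks_async(
    chunks: Iterable[Any],
    write_fn: Callable[[Any, str], None],
    out_dir: str,
    max_pending: int = 2,
    suffix: str = ".nc",
) -> List[str]:
    """Writes each chunk to its own file as soon as it is produced (hypothetical helper).

    Instead of appending every chunk to a list (which keeps all of them in
    host memory), chunks are passed to a single background writer thread.
    At most `max_pending` chunks are queued for writing at any time.
    Returns the file paths in generation order.
    """
    if max_pending < 1:
        raise ValueError("max_pending must be >= 1")
    os.makedirs(out_dir, exist_ok=True)
    paths: List[str] = []
    pending: List[concurrent.futures.Future] = []
    # One worker keeps writes sequential while the main thread keeps producing chunks.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        for i, chunk in enumerate(chunks):
            path = os.path.join(out_dir, f"chunk_{i:04d}{suffix}")
            pending.append(pool.submit(write_fn, chunk, path))
            paths.append(path)
            # Block on the oldest write once the queue is full, bounding memory use.
            while len(pending) >= max_pending:
                pending.pop(0).result()  # re-raises any error from write_fn
        # Wait for the remaining writes to finish before returning.
        for future in pending:
            future.result()
    return paths
```

If each chunk is an xarray.Dataset (an assumption here), a `write_fn` that keeps only a subset of the variables could be `lambda chunk, path: chunk[["2m_temperature"]].to_netcdf(path)`. A single worker keeps the files in order and lets disk I/O overlap with producing the next chunk; raising `max_pending` trades more host memory for fewer stalls.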

@efesurekli
Author

(1) I think that should be the case; I was installing it as in the notebooks:
%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

(2) I see, thanks for the heads-up. I'll try to manage the memory more efficiently.

(3) len(jax.local_devices()) outputs 4. Shouldn't that be 8 for a v5p-8?

Thanks again!

@andrewlkd
Collaborator

andrewlkd commented Dec 21, 2024

Hey!

Sure, but which version of the notebook are you using? Could you confirm it was the one past this commit? Note the change in that commit, which separates out the line that pmaps the run_forward method in the notebooks.

Regarding the number of devices, that's indeed bizarre. Did you say you were following these instructions? If so, can you confirm how you requested the TPU VM? In any case, running 8 samples when you have 4 devices is going to double the inference time, because they will be produced sequentially in two batches of 4.

In the meantime, you should be able to reproduce the reported inference speed by generating just 4 samples (maximising parallelism across the available devices).

Andrew
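A back-of-the-envelope sketch of the point above, assuming (as described in the comment) that each pmap call produces one ensemble member per local device and that batches of members run one after another. The function names and the `minutes_per_batch` input are illustrative placeholders, not values measured in this thread.

```python
import math


def num_sequential_batches(num_samples: int, num_devices: int) -> int:
    """Number of back-to-back pmap calls when each device produces one sample per call."""
    if num_samples < 1 or num_devices < 1:
        raise ValueError("num_samples and num_devices must be positive")
    return math.ceil(num_samples / num_devices)


def estimated_minutes(num_samples: int, num_devices: int, minutes_per_batch: float) -> float:
    """Wall time grows linearly with the number of sequential batches."""
    return num_sequential_batches(num_samples, num_devices) * minutes_per_batch


# 8 members on 4 devices need 2 sequential batches: twice the time of 4 members.
print(num_sequential_batches(8, 4))  # 2
print(num_sequential_batches(4, 4))  # 1
```

On the TPU VM, `jax.local_device_count()` (equivalently `len(jax.local_devices())`) gives the `num_devices` to plug in. The estimate ignores one-off JIT compilation time.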

@google-deepmind google-deepmind deleted a comment Dec 21, 2024
@efesurekli
Author

Hey Andrew, you were right: I had the pmap as in the previous version. I've fixed it and will test again. Thanks a lot!

Also, for running the 1deg version with different ERA5 conditions, do you use a known regridder, or a custom one that goes from 0.25 degree to 1 degree? If so, would it be possible for you to share the script that generated the 1deg ERA5 datasets in dm_graphcast/gencast/dataset/? 🙏

@alvarosg
Copy link
Collaborator

alvarosg commented Dec 21, 2024

The 1 deg data is simply the 0.25 deg data subsampled by taking 1 of every 4 points along each of the spatial axes. We do it this way so that the distribution of the data does not change and we can more easily compare models across resolutions. We follow this approach because we usually use the 1 deg models only as a baseline for the 0.25 deg models; for other use cases of 1 deg models, it may be better to train on data subsampled in a different way.
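The subsampling described above can be sketched with NumPy as plain strided slicing along the latitude and longitude axes. The grid sizes are an assumption for a global 0.25° grid with both poles included (721 × 1440), and the random field stands in for a real ERA5 variable. Every 1° value is an original 0.25° value rather than an average of neighbours, so the values are not smoothed; block averaging, by contrast, would damp extremes.

```python
import numpy as np


def subsample(field: np.ndarray, stride: int = 4) -> np.ndarray:
    """Keeps 1 of every `stride` points along the last two (lat, lon) axes; no averaging."""
    return field[..., ::stride, ::stride]


# Hypothetical 0.25-degree global grid: latitudes -90..90 inclusive, longitudes 0..359.75.
lat_025 = np.linspace(-90.0, 90.0, 721)
lon_025 = np.arange(1440) * 0.25
field_025 = np.random.default_rng(0).standard_normal((721, 1440)).astype(np.float32)

# Taking every 4th point lands exactly on a 1-degree grid: 181 x 360.
field_1p0 = subsample(field_025)
lat_1p0 = lat_025[::4]
lon_1p0 = lon_025[::4]
print(field_1p0.shape)  # (181, 360)
```

With xarray, and assuming the dataset's spatial dimensions are named `lat` and `lon`, the equivalent would be `ds.isel(lat=slice(None, None, 4), lon=slice(None, None, 4))`.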


3 participants