First training attempt #119
Open
dkokron opened this issue Dec 30, 2024 · 2 comments

dkokron commented Dec 30, 2024

I am working with the gencast_mini_demo.ipynb demo. The source data is source-era5_date-2019-03-29_res-1.0_levels-13_steps-01.nc.

I added the optimizer steps and the training loop shown below. Since inference does not work on CPUs (#113), I'll ask here: do the Loss and Mean |grad| values below look reasonable?

import jax
import numpy as np
import optax

# ckpt, state, grads_fn_jitted, train_inputs, train_targets and train_forcings
# all come from the earlier cells of gencast_mini_demo.ipynb.
params = ckpt.params

# Loss and gradients with the pretrained checkpoint params.
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params,
    state=state,
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

# AdamW with a fixed learning rate of 4e-4.
optimizer = optax.adamw(0.0004)
opt_state = optimizer.init(params)
for i in range(5):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        params=params,
        state=state,
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(
        jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

Loss: 4.0923, Mean |grad|: 0.006474
Loss: 83.6939, Mean |grad|: 0.272431
Loss: 20.1864, Mean |grad|: 0.044424
Loss: 15.5214, Mean |grad|: 0.022913
Loss: 10.1925, Mean |grad|: 0.008500
Loss: 8.7714, Mean |grad|: 0.005301
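
For reference, the mean of the per-leaf |grad| means is a fairly indirect diagnostic. A complementary check (a minimal sketch, using optax.global_norm on the same grads pytree and the optax import above) is the global L2 norm of the gradients:

# Global L2 norm across all gradient leaves, complementary to the mean |grad| above.
print(f"Global |grad| norm: {float(optax.global_norm(grads)):.6f}")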

dkokron commented Dec 30, 2024

This version follows the paper (https://arxiv.org/pdf/2312.15796) more closely; see Table D1, "Diffusion model training hyperparameters", on page 27.

params = ckpt.params
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params,
    state=state,
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

# Warmup + cosine-decay learning-rate schedule based on Table D1 of the paper.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.00001,
    peak_value=0.003,
    warmup_steps=1000,
    decay_steps=1000000,
    end_value=0.0,
    exponent=0.1,
)
optimizer = optax.adamw(lr_schedule)
opt_state = optimizer.init(params)
for i in range(5):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    loss, diagnostics, next_state, grads = grads_fn_jitted(
        params=params,
        state=state,
        inputs=train_inputs,
        targets=train_targets,
        forcings=train_forcings)
    mean_grad = np.mean(jax.tree_util.tree_flatten(
        jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
    print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

Loss: 4.0923, Mean |grad|: 0.006474
Loss: 5.3288, Mean |grad|: 0.073276
Loss: 4.4514, Mean |grad|: 0.045505
Loss: 4.5843, Mean |grad|: 0.025861
Loss: 4.5447, Mean |grad|: 0.023869
Loss: 4.3157, Mean |grad|: 0.014323
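
As a sanity check on the schedule itself, optax schedules are plain callables mapping a step count to a learning rate, so the configured warmup/decay can be printed directly (a minimal sketch):

# Learning rate at a few step counts of the warmup + cosine-decay schedule.
for step in (0, 1, 100, 1000, 10000, 100000, 1000000):
    print(step, float(lr_schedule(step)))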


dkokron commented Jan 1, 2025

With the lr_schedule init_value set to zero and training on source-era5_date-2019-03-29_res-1.0_levels-13_steps-01.nc, I get zero change in the params after the first update. That seems odd to me; there ought to be some numerical difference from training on new data. Can someone tell me what I'm doing wrong?

params = ckpt.params
loss, diagnostics, next_state, grads = grads_fn_jitted(
    params=params,
    state=state,
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.6f}, Mean |grad|: {mean_grad:.6f}")

# Same schedule as before, but with init_value set to 0.0.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.003,
    warmup_steps=1000,
    decay_steps=1000000,
    end_value=0.0,
    exponent=0.1,
)
optimizer = optax.adamw(lr_schedule)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# Compare a few of the updated params against the original checkpoint params.
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_0"]["b"], params["fourier_features_mlp/~/mlp/~/linear_0"]["b"]))
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_0"]["w"], params["fourier_features_mlp/~/mlp/~/linear_0"]["w"]))
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_1"]["b"], params["fourier_features_mlp/~/mlp/~/linear_1"]["b"]))
print(np.array_equal(ckpt.params["fourier_features_mlp/~/mlp/~/linear_1"]["w"], params["fourier_features_mlp/~/mlp/~/linear_1"]["w"]))

Loss: 4.092268, Mean |grad|: 0.006474
True
True
True
True
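
For what it's worth, here is a quick way to see what learning rate the very first update actually used (a sketch, assuming optax evaluates the schedule at the optimizer's step count, which starts at 0); with init_value=0.0 the warmup begins at a learning rate of exactly zero, which would scale the first update to zero:

# Learning rate seen by the first few updates; with init_value=0.0 the
# schedule returns 0.0 at step 0, so the first update leaves params unchanged.
print(float(lr_schedule(0)), float(lr_schedule(1)), float(lr_schedule(2)))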
