diff --git a/gns/train.py b/gns/train.py index 91a1768..026f8b6 100644 --- a/gns/train.py +++ b/gns/train.py @@ -709,6 +709,12 @@ def train(rank, cfg, world_size, device, verbose, use_dist): if use_dist: distribute.cleanup() +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +import copy +from collections import defaultdict def train_maml(rank, cfg, world_size, device, verbose, use_dist): """Train the model using MAML. @@ -817,7 +823,7 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist): # Identify the material property and use the corresponding encoder material_id = material_property[0].item() # Assuming all materials in the batch are the same task_encoder = task_encoders[material_id] - task_optimizer = torch.optim.Adam(task_encoder.parameters(), lr=inner_lr) + task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr) # Perform a few steps of gradient descent in the inner loop for _ in range(inner_steps): @@ -851,56 +857,30 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist): # Compute the meta-loss using the task-specific encoder parameters meta_loss = 0.0 - for example in valid_dl: - ( - position, - particle_type, - material_property, - n_particles_per_example, - labels, - ) = prepare_data(example, device_id) - - n_particles_per_example = n_particles_per_example.to(device_id) - labels = labels.to(device_id) - - sampled_noise = ( - noise_utils.get_random_walk_noise_for_position_sequence( - position, noise_std_last_step=cfg.data.noise_std - ).to(device_id) - ) - non_kinematic_mask = ( - (particle_type != cfg.data.kinematic_particle_id) - .clone() - .detach() - .to(device_id) - ) - sampled_noise *= non_kinematic_mask.view(-1, 1, 1) - # Replace the simulator's encoder with the task-specific encoder - simulator._encode_process_decode._encoder = task_encoder + # Replace the simulator's encoder with the task-specific encoder + simulator._encode_process_decode._encoder = task_encoder - pred_acc, target_acc = predict_fn( - next_positions=labels.to(device_or_rank), - position_sequence_noise=sampled_noise.to(device_or_rank), - position_sequence=position.to(device_or_rank), - nparticles_per_example=n_particles_per_example.to( - device_or_rank - ), - particle_types=particle_type.to(device_or_rank), - material_property=( - material_property.to(device_or_rank) - if n_features == 3 - else None - ), - ) - - valid_loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask) - meta_loss += valid_loss + pred_acc, target_acc = predict_fn( + next_positions=labels.to(device_or_rank), + position_sequence_noise=sampled_noise.to(device_or_rank), + position_sequence=position.to(device_or_rank), + nparticles_per_example=n_particles_per_example.to( + device_or_rank + ), + particle_types=particle_type.to(device_or_rank), + material_property=( + material_property.to(device_or_rank) + if n_features == 3 + else None + ), + ) - # Restore the original encoder - simulator._encode_process_decode._encoder = original_encoder + valid_loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask) + meta_loss += valid_loss - meta_loss /= len(valid_dl) + # Restore the original encoder + simulator._encode_process_decode._encoder = original_encoder # Outer loop: Update the main encoder using the meta-loss optimizer.zero_grad()