diff --git a/gns/train.py b/gns/train.py
index 164c420..6e647f0 100644
--- a/gns/train.py
+++ b/gns/train.py
@@ -710,7 +710,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
 
 def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     """Train the model using MAML.
-    
+
     Args:
       rank: local rank
       cfg: configuration dictionary
@@ -727,8 +727,6 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     )
 
     # Extract the Encoder from the simulator
-    inner_lr = 1e-3
-    inner_steps = 5
     main_encoder = simulator._encode_process_decode._encoder
 
     # Initialize training state
@@ -756,6 +754,10 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     # Initialize encoders for each material property
     task_encoders = defaultdict(lambda: copy.deepcopy(main_encoder))
 
+    # Hardcoded inner-loop hyperparameters
+    inner_steps = 5
+    inner_lr = 1e-3
+
     try:
         num_epochs = max(1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl))
         if verbose:
@@ -811,19 +813,16 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                 )
 
                 # Identify the material property and use the corresponding encoder
-                material_id = material_property[
-                    0
-                ].item()  # Assuming all materials in the batch are the same
+                material_id = material_property[0].item()  # Assuming all materials in the batch are the same
                 task_encoder = task_encoders[material_id]
-                task_optimizer = torch.optim.Adam(
-                    task_encoder.parameters(), lr=inner_lr
-                )
+                task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr)
 
+                # Swap in the task-specific encoder, saving the original exactly once
+                original_encoder = simulator._encode_process_decode._encoder
+                simulator._encode_process_decode._encoder = task_encoder
+
                 # Perform a few steps of gradient descent in the inner loop
                 for _ in range(inner_steps):
-                    node_features, edge_features = task_encoder(
-                        position, material_property
-                    )
                     pred_acc, target_acc = predict_fn(
                         next_positions=labels.to(device_or_rank),
                         position_sequence_noise=sampled_noise.to(device_or_rank),
@@ -832,22 +831,21 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                             device_or_rank
                         ),
                         particle_types=particle_type.to(device_or_rank),
-                        node_features=node_features.to(device_or_rank),
-                        edge_features=edge_features.to(device_or_rank),
+                        material_property=(
+                            material_property.to(device_or_rank)
+                            if n_features == 3
+                            else None
+                        ),
                     )
 
-                    loss = acceleration_loss(
-                        pred_acc, target_acc, non_kinematic_mask
-                    )
+                    loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
 
                     task_optimizer.zero_grad()
                     loss.backward()
                     task_optimizer.step()
 
-                # Save the task-specific encoder parameters
-                task_encoder_params = [
-                    p.clone().detach() for p in task_encoder.parameters()
-                ]
+                # Restore the original encoder
+                simulator._encode_process_decode._encoder = original_encoder
 
                 # Compute the meta-loss using the task-specific encoder parameters
                 meta_loss = 0.0
@@ -876,20 +874,9 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                     )
                     sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
 
-                    # Load task-specific encoder parameters
-                    task_encoder.load_state_dict(
-                        {
-                            k: v
-                            for k, v in zip(
-                                task_encoder.state_dict().keys(),
-                                task_encoder_params,
-                            )
-                        }
-                    )
+                    # Replace the simulator's encoder with the task-specific encoder
+                    simulator._encode_process_decode._encoder = task_encoder
 
-                    node_features, edge_features = task_encoder(
-                        position, material_property
-                    )
                     pred_acc, target_acc = predict_fn(
                         next_positions=labels.to(device_or_rank),
                         position_sequence_noise=sampled_noise.to(device_or_rank),
@@ -898,15 +885,19 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                             device_or_rank
                         ),
                         particle_types=particle_type.to(device_or_rank),
-                        node_features=node_features.to(device_or_rank),
-                        edge_features=edge_features.to(device_or_rank),
+                        material_property=(
+                            material_property.to(device_or_rank)
+                            if n_features == 3
+                            else None
+                        ),
                     )
 
-                    valid_loss = acceleration_loss(
-                        pred_acc, target_acc, non_kinematic_mask
-                    )
+                    valid_loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
                     meta_loss += valid_loss
 
+                # Restore the original encoder
+                simulator._encode_process_decode._encoder = original_encoder
+
                 meta_loss /= len(valid_dl)
 
                 # Outer loop: Update the main encoder using the meta-loss
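
Note: the save/swap/restore pattern introduced above is easy to misuse, since the original encoder must be captured exactly once per batch, before the first swap, and the restore must also run if the loss computation raises. A minimal sketch of a safer alternative (not part of the patch), assuming `simulator._encode_process_decode._encoder` is a plain attribute as in the hunks above:

    from contextlib import contextmanager

    @contextmanager
    def swapped_encoder(simulator, task_encoder):
        # Temporarily install the task-specific encoder on the simulator and
        # guarantee the original is restored, even if the body raises.
        original = simulator._encode_process_decode._encoder
        simulator._encode_process_decode._encoder = task_encoder
        try:
            yield simulator
        finally:
            simulator._encode_process_decode._encoder = original

Both the inner loop and the meta-loss loop could then wrap their predict_fn calls in `with swapped_encoder(simulator, task_encoder): ...`, which would make the explicit restore lines added in the hunks above unnecessary.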