
Commit

Training using same example for meta
kks32 committed Jul 15, 2024
1 parent b7fc8c5 commit 48f6e32
Showing 1 changed file with 27 additions and 47 deletions.
74 changes: 27 additions & 47 deletions gns/train.py
@@ -709,6 +709,12 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
     if use_dist:
         distribute.cleanup()
 
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tqdm import tqdm
+import copy
+from collections import defaultdict
 
 def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     """Train the model using MAML.
@@ -817,7 +823,7 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
             # Identify the material property and use the corresponding encoder
             material_id = material_property[0].item()  # Assuming all materials in the batch are the same
             task_encoder = task_encoders[material_id]
-            task_optimizer = torch.optim.Adam(task_encoder.parameters(), lr=inner_lr)
+            task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr)
 
             # Perform a few steps of gradient descent in the inner loop
             for _ in range(inner_steps):
@@ -851,56 +857,30 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
 
             # Compute the meta-loss using the task-specific encoder parameters
             meta_loss = 0.0
-            for example in valid_dl:
-                (
-                    position,
-                    particle_type,
-                    material_property,
-                    n_particles_per_example,
-                    labels,
-                ) = prepare_data(example, device_id)
-
-                n_particles_per_example = n_particles_per_example.to(device_id)
-                labels = labels.to(device_id)
-
-                sampled_noise = (
-                    noise_utils.get_random_walk_noise_for_position_sequence(
-                        position, noise_std_last_step=cfg.data.noise_std
-                    ).to(device_id)
-                )
-                non_kinematic_mask = (
-                    (particle_type != cfg.data.kinematic_particle_id)
-                    .clone()
-                    .detach()
-                    .to(device_id)
-                )
-                sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
-
-                # Replace the simulator's encoder with the task-specific encoder
-                simulator._encode_process_decode._encoder = task_encoder
+            # Replace the simulator's encoder with the task-specific encoder
+            simulator._encode_process_decode._encoder = task_encoder
 
-                pred_acc, target_acc = predict_fn(
-                    next_positions=labels.to(device_or_rank),
-                    position_sequence_noise=sampled_noise.to(device_or_rank),
-                    position_sequence=position.to(device_or_rank),
-                    nparticles_per_example=n_particles_per_example.to(
-                        device_or_rank
-                    ),
-                    particle_types=particle_type.to(device_or_rank),
-                    material_property=(
-                        material_property.to(device_or_rank)
-                        if n_features == 3
-                        else None
-                    ),
-                )
-
-                valid_loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
-                meta_loss += valid_loss
+            pred_acc, target_acc = predict_fn(
+                next_positions=labels.to(device_or_rank),
+                position_sequence_noise=sampled_noise.to(device_or_rank),
+                position_sequence=position.to(device_or_rank),
+                nparticles_per_example=n_particles_per_example.to(
+                    device_or_rank
+                ),
+                particle_types=particle_type.to(device_or_rank),
+                material_property=(
+                    material_property.to(device_or_rank)
+                    if n_features == 3
+                    else None
+                ),
+            )
 
-                # Restore the original encoder
-                simulator._encode_process_decode._encoder = original_encoder
-
-            meta_loss /= len(valid_dl)
+            valid_loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
+            meta_loss += valid_loss
 
+            # Restore the original encoder
+            simulator._encode_process_decode._encoder = original_encoder
 
             # Outer loop: Update the main encoder using the meta-loss
             optimizer.zero_grad()
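For readers unfamiliar with the pattern, the sketch below condenses the update scheme this commit moves to: adapt a per-material task encoder with a few inner-loop Adam steps on a batch, compute the meta-loss on the same batch with the adapted encoder in place (rather than averaging over a separate validation loader), restore the shared encoder, and apply the outer update. This is a minimal sketch under stated assumptions, not the gns API: ToyEncoder, ToySimulator, and make_task_batch are hypothetical stand-ins, and the inner loop is a first-order approximation rather than full second-order MAML.

import copy
import torch
import torch.nn as nn
import torch.optim as optim

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the simulator's encoder."""
    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Linear(dim, dim)
    def forward(self, x):
        return self.net(x)

class ToySimulator(nn.Module):
    """Hypothetical stand-in for simulator._encode_process_decode."""
    def __init__(self, dim=8):
        super().__init__()
        self._encoder = ToyEncoder(dim)
        self._decoder = nn.Linear(dim, dim)
    def forward(self, x):
        return self._decoder(self._encoder(x))

def make_task_batch(dim=8, n=32):
    """Hypothetical data source: one (input, target) batch per task."""
    return torch.randn(n, dim), torch.randn(n, dim)

simulator = ToySimulator()
meta_optimizer = optim.Adam(simulator.parameters(), lr=1e-3)
task_encoders = {0: copy.deepcopy(simulator._encoder)}  # one encoder per material id
inner_lr, inner_steps = 1e-2, 3
loss_fn = nn.MSELoss()

for step in range(10):
    x, y = make_task_batch()
    material_id = 0  # assume all particles in the batch share one material
    original_encoder = simulator._encoder

    # Inner loop: adapt the task-specific encoder on this example.
    task_encoder = task_encoders[material_id]
    task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr)
    simulator._encoder = task_encoder
    for _ in range(inner_steps):
        task_optimizer.zero_grad()
        loss_fn(simulator(x), y).backward()
        task_optimizer.step()

    # Meta-loss on the SAME example with the adapted encoder in place
    # (the commit's change: no pass over valid_dl, no meta_loss /= len(valid_dl)).
    meta_loss = loss_fn(simulator(x), y)

    # Restore the shared encoder before the outer update.
    simulator._encoder = original_encoder

    # Outer loop: the meta-loss graph was built with the adapted encoder, so
    # in this sketch the step updates the decoder; the restored encoder
    # receives no gradient from the meta-loss.
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

Because adaptation and meta-evaluation use the same example, this behaves more like a Reptile-style first-order scheme than canonical MAML, which evaluates the adapted parameters on held-out query data.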
