MAML training with correct replacement of encoder in simulator
kks32 committed Jul 15, 2024
1 parent b336176 commit 45a1bdf
Showing 1 changed file with 30 additions and 39 deletions.
gns/train.py: 69 changes (30 additions & 39 deletions)
@@ -710,7 +710,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):

 def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     """Train the model using MAML.
     Args:
         rank: local rank
         cfg: configuration dictionary
@@ -727,8 +727,6 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     )
 
     # Extract the Encoder from the simulator
-    inner_lr = 1e-3
-    inner_steps = 5
     main_encoder = simulator._encode_process_decode._encoder
 
     # Initialize training state
@@ -756,6 +754,10 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     # Initialize encoders for each material property
     task_encoders = defaultdict(lambda: copy.deepcopy(main_encoder))
 
+    # Hardcoded values
+    inner_steps = 5
+    inner_lr = 1e-3
+
     try:
         num_epochs = max(1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl))
         if verbose:
@@ -811,19 +813,16 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                 )
 
                 # Identify the material property and use the corresponding encoder
-                material_id = material_property[
-                    0
-                ].item()  # Assuming all materials in the batch are the same
+                material_id = material_property[0].item()  # Assuming all materials in the batch are the same
                 task_encoder = task_encoders[material_id]
-                task_optimizer = torch.optim.Adam(
-                    task_encoder.parameters(), lr=inner_lr
-                )
+                task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr)
 
                 # Perform a few steps of gradient descent in the inner loop
                 for _ in range(inner_steps):
-                    node_features, edge_features = task_encoder(
-                        position, material_property
-                    )
+                    # Replace the simulator's encoder with the task-specific encoder
+                    original_encoder = simulator._encode_process_decode._encoder
+                    simulator._encode_process_decode._encoder = task_encoder
 
                     pred_acc, target_acc = predict_fn(
                         next_positions=labels.to(device_or_rank),
                         position_sequence_noise=sampled_noise.to(device_or_rank),
@@ -832,22 +831,21 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                             device_or_rank
                         ),
                         particle_types=particle_type.to(device_or_rank),
-                        node_features=node_features.to(device_or_rank),
-                        edge_features=edge_features.to(device_or_rank),
+                        material_property=(
+                            material_property.to(device_or_rank)
+                            if n_features == 3
+                            else None
+                        ),
                     )
 
-                    loss = acceleration_loss(
-                        pred_acc, target_acc, non_kinematic_mask
-                    )
+                    loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
 
                     task_optimizer.zero_grad()
                     loss.backward()
                     task_optimizer.step()
 
-                # Save the task-specific encoder parameters
-                task_encoder_params = [
-                    p.clone().detach() for p in task_encoder.parameters()
-                ]
+                # Restore the original encoder
+                simulator._encode_process_decode._encoder = original_encoder
 
                 # Compute the meta-loss using the task-specific encoder parameters
                 meta_loss = 0.0
@@ -876,20 +874,9 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                     )
                     sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
 
-                    # Load task-specific encoder parameters
-                    task_encoder.load_state_dict(
-                        {
-                            k: v
-                            for k, v in zip(
-                                task_encoder.state_dict().keys(),
-                                task_encoder_params,
-                            )
-                        }
-                    )
+                    # Replace the simulator's encoder with the task-specific encoder
+                    simulator._encode_process_decode._encoder = task_encoder
 
-                    node_features, edge_features = task_encoder(
-                        position, material_property
-                    )
                     pred_acc, target_acc = predict_fn(
                         next_positions=labels.to(device_or_rank),
                         position_sequence_noise=sampled_noise.to(device_or_rank),
@@ -898,15 +885,19 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
                             device_or_rank
                         ),
                         particle_types=particle_type.to(device_or_rank),
-                        node_features=node_features.to(device_or_rank),
-                        edge_features=edge_features.to(device_or_rank),
+                        material_property=(
+                            material_property.to(device_or_rank)
+                            if n_features == 3
+                            else None
+                        ),
                     )
 
-                    valid_loss = acceleration_loss(
-                        pred_acc, target_acc, non_kinematic_mask
-                    )
+                    valid_loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
                     meta_loss += valid_loss
 
+                # Restore the original encoder
+                simulator._encode_process_decode._encoder = original_encoder
+
                 meta_loss /= len(valid_dl)
 
                 # Outer loop: Update the main encoder using the meta-loss
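The heart of this commit is a swap/adapt/restore dance on simulator._encode_process_decode._encoder: each material id gets a deep copy of the main encoder, that copy is installed in the simulator for a few Adam steps, and the original encoder is put back afterward. The sketch below mirrors that control flow against a toy stand-in (the Simulator class, the tensors, and the MSE loss are illustrative, not the repository's API). It also hoists the original_encoder capture above the inner loop and guards the restore with try/finally; if the capture sat inside the loop, the second iteration would save the task encoder itself and the later restore would re-install it.

import copy
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim


class Simulator(nn.Module):
    """Toy stand-in for the GNS encode-process-decode stack."""

    def __init__(self):
        super().__init__()
        self._encoder = nn.Linear(8, 8)

    def forward(self, x):
        return self._encoder(x)


simulator = Simulator()
main_encoder = simulator._encoder

# One adapted encoder per material id, each initialized from the main encoder.
task_encoders = defaultdict(lambda: copy.deepcopy(main_encoder))

inner_steps, inner_lr = 5, 1e-3  # the commit's hardcoded values
x, target = torch.randn(4, 8), torch.randn(4, 8)  # stand-in batch

material_id = 0  # in the diff: material_property[0].item()
task_encoder = task_encoders[material_id]
task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr)

# Capture the original encoder once, before any swap, so the restore below
# cannot accidentally re-install the task encoder.
original_encoder = simulator._encoder
simulator._encoder = task_encoder
try:
    for _ in range(inner_steps):
        loss = nn.functional.mse_loss(simulator(x), target)
        task_optimizer.zero_grad()
        loss.backward()
        task_optimizer.step()
finally:
    # Always put the main encoder back, even if a forward pass raises.
    simulator._encoder = original_encoder

Swapping the encoder in place, rather than passing precomputed node_features/edge_features into predict_fn as the old code did, means the task encoder runs inside the simulator's own forward pass, so its gradients and downstream processing stay consistent with ordinary training.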

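The outer-loop update itself is collapsed behind the diff's fold, so what follows is a generic first-order (FOMAML-style) sketch of how a meta-loss computed through an adapted task encoder can update the main encoder: an assumption about the shape of that code, not the repository's implementation. It continues from the toy setup above; the 1e-4 meta learning rate is likewise assumed.

meta_optimizer = optim.Adam(main_encoder.parameters(), lr=1e-4)  # assumed meta-lr

# Validation-style forward pass with the adapted encoder swapped in.
simulator._encoder = task_encoder
meta_loss = nn.functional.mse_loss(simulator(x), target)
simulator._encoder = original_encoder

task_encoder.zero_grad()
meta_loss.backward()  # gradients land on the adapted task encoder

meta_optimizer.zero_grad()
with torch.no_grad():
    for main_p, task_p in zip(main_encoder.parameters(), task_encoder.parameters()):
        # First-order approximation: treat the task-encoder gradients as if
        # they were taken with respect to the main encoder's parameters.
        main_p.grad = task_p.grad.clone()
meta_optimizer.step()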