Hardcode inner LR and steps

kks32 committed Jul 15, 2024 (parent: 2f504d5, commit: b336176)
Showing 1 changed file with 4 additions and 2 deletions.
gns/train.py (6 changes: 4 additions & 2 deletions)
@@ -727,6 +727,8 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     )
 
     # Extract the Encoder from the simulator
+    inner_lr = 1e-3
+    inner_steps = 5
     main_encoder = simulator._encode_process_decode._encoder
 
     # Initialize training state
@@ -814,11 +816,11 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
             ].item()  # Assuming all materials in the batch are the same
             task_encoder = task_encoders[material_id]
             task_optimizer = torch.optim.Adam(
-                task_encoder.parameters(), lr=cfg.training.inner_lr
+                task_encoder.parameters(), lr=inner_lr
             )
 
             # Perform a few steps of gradient descent in the inner loop
-            for _ in range(cfg.training.inner_steps):
+            for _ in range(inner_steps):
                 node_features, edge_features = task_encoder(
                     position, material_property
                 )
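
For context, both hunks sit inside the MAML-style inner loop of train_maml: each material type gets its own copy of the encoder, which is adapted for a few gradient steps with a per-task Adam optimizer. The sketch below reproduces that pattern with the now-hardcoded inner_lr and inner_steps; the ToyEncoder, the MSE loss, and the dummy tensors are stand-ins (the real gns encoder and dataloader are not part of this diff), so read it as a minimal illustration of the inner loop, not the project's actual code.

import copy
import torch
from torch import nn

inner_lr = 1e-3      # hardcoded, as in this commit
inner_steps = 5      # hardcoded, as in this commit

class ToyEncoder(nn.Module):
    """Stand-in for the gns encoder: maps particle positions and
    material properties to (node_features, edge_features)."""
    def __init__(self, in_dim=3, out_dim=16):
        super().__init__()
        self.node_mlp = nn.Linear(in_dim, out_dim)
        self.edge_mlp = nn.Linear(in_dim, out_dim)

    def forward(self, position, material_property):
        x = torch.cat([position, material_property], dim=-1)
        return self.node_mlp(x), self.edge_mlp(x)

main_encoder = ToyEncoder()

# One adaptable encoder copy per material type.
task_encoders = {0: copy.deepcopy(main_encoder)}

# Dummy task data (assumptions; the real data comes from the dataloader).
position = torch.randn(8, 2)           # e.g. 8 particles in 2-D
material_property = torch.zeros(8, 1)  # all particles share material 0
target = torch.randn(8, 16)            # stand-in regression target

material_id = 0
task_encoder = task_encoders[material_id]
task_optimizer = torch.optim.Adam(task_encoder.parameters(), lr=inner_lr)

# Perform a few steps of gradient descent in the inner loop.
for _ in range(inner_steps):
    node_features, edge_features = task_encoder(position, material_property)
    loss = nn.functional.mse_loss(node_features, target)  # stand-in loss
    task_optimizer.zero_grad()
    loss.backward()
    task_optimizer.step()

Note that after this change the inner-loop hyperparameters can no longer be set through the training config: cfg.training.inner_lr and cfg.training.inner_steps are no longer read on this code path.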
