diff --git a/gns/train.py b/gns/train.py
index b8c7cb3..164c420 100644
--- a/gns/train.py
+++ b/gns/train.py
@@ -727,6 +727,8 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     )
 
     # Extract the Encoder from the simulator
+    inner_lr = 1e-3
+    inner_steps = 5
     main_encoder = simulator._encode_process_decode._encoder
 
     # Initialize training state
@@ -814,11 +816,11 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
             ].item()  # Assuming all materials in the batch are the same
             task_encoder = task_encoders[material_id]
             task_optimizer = torch.optim.Adam(
-                task_encoder.parameters(), lr=cfg.training.inner_lr
+                task_encoder.parameters(), lr=inner_lr
             )
 
             # Perform a few steps of gradient descent in the inner loop
-            for _ in range(cfg.training.inner_steps):
+            for _ in range(inner_steps):
                 node_features, edge_features = task_encoder(
                     position, material_property
                 )
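
For context, the change pins the inner-loop hyperparameters to local constants instead of reading `cfg.training.inner_lr` and `cfg.training.inner_steps`. Below is a minimal sketch of the per-task adaptation pattern these hunks participate in, under stated assumptions: `make_loss` is a hypothetical stand-in for the processor/decoder and loss computation that follow the encoder in `train_maml`, and the caller supplies a task's support data directly. It is not the repository's implementation, only an illustration of the inner loop.

```python
import copy
import torch

inner_lr = 1e-3   # hardcoded inner-loop learning rate (was cfg.training.inner_lr)
inner_steps = 5   # hardcoded inner-loop step count (was cfg.training.inner_steps)

task_encoders = {}  # material_id -> task-specific copy of the shared encoder


def adapt(main_encoder, material_id, position, material_property, make_loss):
    # Lazily clone the shared encoder so each material (task) gets its own
    # fast weights; only the clone is updated in the inner loop.
    if material_id not in task_encoders:
        task_encoders[material_id] = copy.deepcopy(main_encoder)
    task_encoder = task_encoders[material_id]

    # Fresh inner-loop optimizer per adaptation, as in the diff.
    task_optimizer = torch.optim.Adam(task_encoder.parameters(), lr=inner_lr)

    # A few steps of gradient descent on the task's support data.
    for _ in range(inner_steps):
        node_features, edge_features = task_encoder(position, material_property)
        loss = make_loss(node_features, edge_features)  # placeholder for the real loss pipeline
        task_optimizer.zero_grad()
        loss.backward()
        task_optimizer.step()

    return task_encoder
```

One consequence of this change worth noting: after it lands, the `cfg.training.inner_lr` and `cfg.training.inner_steps` fields are no longer read anywhere in this path, so values set in the config will be silently ignored.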