
Commit

Fix torch.optim
kks32 committed Jul 15, 2024
1 parent 45a1bdf commit b7fc8c5
Showing 1 changed file with 3 additions and 1 deletion.
gns/train.py — 4 changes: 3 additions & 1 deletion
@@ -12,8 +12,10 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP

from tqdm import tqdm

import hydra
@@ -815,7 +817,7 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
# Identify the material property and use the corresponding encoder
material_id = material_property[0].item() # Assuming all materials in the batch are the same
task_encoder = task_encoders[material_id]
- task_optimizer = optim.Adam(task_encoder.parameters(), lr=inner_lr)
+ task_optimizer = torch.optim.Adam(task_encoder.parameters(), lr=inner_lr)

# Perform a few steps of gradient descent in the inner loop
for _ in range(inner_steps):
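The commit switches the inner-loop optimizer construction to the fully qualified torch.optim.Adam. Below is a minimal sketch of the surrounding pattern, a fresh per-task optimizer driving a few inner-loop gradient steps on a task-specific encoder. TaskEncoder, inner_lr, and inner_steps here are illustrative stand-ins, not the actual definitions in gns/train.py.

    # Minimal sketch: a fresh Adam optimizer for a task-specific encoder
    # inside a MAML-style inner loop, using the fully qualified
    # torch.optim.Adam as in the commit.
    import torch
    import torch.nn as nn

    class TaskEncoder(nn.Module):
        """Toy per-material encoder standing in for the real task encoder."""
        def __init__(self, dim: int = 16):
            super().__init__()
            self.net = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    inner_lr = 1e-3
    inner_steps = 5
    task_encoder = TaskEncoder()

    # Fully qualified call; works even if "import torch.optim as optim"
    # is not in scope, which is the failure mode a fix like this addresses.
    task_optimizer = torch.optim.Adam(task_encoder.parameters(), lr=inner_lr)

    x = torch.randn(8, 16)
    target = torch.randn(8, 16)
    loss_fn = nn.MSELoss()

    # A few steps of gradient descent in the inner loop.
    for _ in range(inner_steps):
        task_optimizer.zero_grad()
        loss = loss_fn(task_encoder(x), target)
        loss.backward()
        task_optimizer.step()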
