Skip to content

Commit

Permalink
Use proper attr in Encoder class
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 14, 2024
1 parent e930ad5 commit b8f3a9f
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
train_loss_hist = []
valid_loss_hist = []

# MAML hyperparameters
# MAML hyperparameters (hardcoded)
inner_lr = 1e-3
num_inner_steps = 10

Expand Down Expand Up @@ -854,18 +854,22 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):

for material in unique_materials:
# Clone the Encoder for this material
material_encoder = type(main_encoder)(
nnode_in_features=main_encoder.node_fn[0].in_features,
nnode_out_features=main_encoder.node_fn[
material_encoder = Encoder(
nnode_in_features=main_encoder.node_fn[0][
0
].in_features,
nnode_out_features=main_encoder.node_fn[0][
-1
].out_features,
nedge_in_features=main_encoder.edge_fn[0].in_features,
nedge_out_features=main_encoder.edge_fn[
nedge_in_features=main_encoder.edge_fn[0][
0
].in_features,
nedge_out_features=main_encoder.edge_fn[0][
-1
].out_features,
nmlp_layers=len(main_encoder.node_fn)
nmlp_layers=len(main_encoder.node_fn[0])
- 2, # Subtract input and output layers
mlp_hidden_dim=main_encoder.node_fn[1].out_features,
mlp_hidden_dim=main_encoder.node_fn[0][1].out_features,
).to(device_id)
material_encoder.load_state_dict(main_encoder.state_dict())
material_encoder_optimizer = torch.optim.SGD(
Expand Down

0 comments on commit b8f3a9f

Please sign in to comment.