Skip to content

Commit

Permalink
Hardcode dim for 2D
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 14, 2024
1 parent aaa921e commit 0667c96
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,12 +856,12 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
for material in unique_materials:
# Clone the Encoder for this material
material_encoder = Encoder(
nnode_in_features=main_encoder.node_fn[0][0].in_features,
nnode_out_features=main_encoder.node_fn[-1].normalized_shape[0],
nedge_in_features=main_encoder.edge_fn[0][0].in_features,
nedge_out_features=main_encoder.edge_fn[-1].normalized_shape[0],
nmlp_layers=len(main_encoder.node_fn[0]) - 2, # Subtract input and output layers
mlp_hidden_dim=main_encoder.node_fn[0][1].out_features
nnode_in_features=30,
nnode_out_features=128,
nedge_in_features=3,
nedge_out_features=128,
nmlp_layers=2, # Subtract input and output layers
mlp_hidden_dim=128
).to(device_id)
material_encoder.load_state_dict(main_encoder.state_dict())
material_encoder_optimizer = torch.optim.SGD(
Expand Down

0 comments on commit 0667c96

Please sign in to comment.