Skip to content

Commit

Permalink
Material encoder properties
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 14, 2024
1 parent 07918f1 commit aaa921e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,9 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
# Clone the Encoder for this material
material_encoder = Encoder(
nnode_in_features=main_encoder.node_fn[0][0].in_features,
nnode_out_features=main_encoder.node_fn[0][-1].out_features,
nnode_out_features=main_encoder.node_fn[-1].normalized_shape[0],
nedge_in_features=main_encoder.edge_fn[0][0].in_features,
nedge_out_features=main_encoder.edge_fn[0][-1].out_features,
nedge_out_features=main_encoder.edge_fn[-1].normalized_shape[0],
nmlp_layers=len(main_encoder.node_fn[0]) - 2, # Subtract input and output layers
mlp_hidden_dim=main_encoder.node_fn[0][1].out_features
).to(device_id)
Expand Down

0 comments on commit aaa921e

Please sign in to comment.