Skip to content

Commit

Permalink
Check dim again
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 14, 2024
1 parent 8e08a70 commit d5cac98
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
main_encoder = simulator._encode_process_decode._encoder

# Extract the correct dimensions from the main encoder
nnode_in_features=main_encoder.node_fn[0][0].in_features,
nnode_in_features=31 #main_encoder.node_fn[0][0].in_features,
nnode_out_features=main_encoder.node_fn[-1].normalized_shape[0],
nedge_in_features=main_encoder.edge_fn[0][0].in_features,
nedge_in_features=3 # main_encoder.edge_fn[0][0].in_features,
nedge_out_features=main_encoder.edge_fn[-1].normalized_shape[0],
nmlp_layers = len([m for m in main_encoder.node_fn if isinstance(m, nn.Linear)]) - 1
mlp_hidden_dim = 128 # main_encoder.node_fn[1].out_features
Expand Down

0 comments on commit d5cac98

Please sign in to comment.