From d5cac98eda75e4a4e1e98bfa22356697feb9e0a1 Mon Sep 17 00:00:00 2001 From: Krishna Kumar Date: Sun, 14 Jul 2024 17:39:39 -0600 Subject: [PATCH] Check dim again --- gns/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gns/train.py b/gns/train.py index 07a7e5b..b42634b 100644 --- a/gns/train.py +++ b/gns/train.py @@ -790,9 +790,9 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist): main_encoder = simulator._encode_process_decode._encoder # Extract the correct dimensions from the main encoder - nnode_in_features=main_encoder.node_fn[0][0].in_features, + nnode_in_features=31 #main_encoder.node_fn[0][0].in_features, nnode_out_features=main_encoder.node_fn[-1].normalized_shape[0], - nedge_in_features=main_encoder.edge_fn[0][0].in_features, + nedge_in_features=3 # main_encoder.edge_fn[0][0].in_features, nedge_out_features=main_encoder.edge_fn[-1].normalized_shape[0], nmlp_layers = len([m for m in main_encoder.node_fn if isinstance(m, nn.Linear)]) - 1 mlp_hidden_dim = 128 # main_encoder.node_fn[1].out_features