MAML encoder
kks32 committed Jul 14, 2024
1 parent b8f3a9f commit 07918f1
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions gns/train.py
@@ -22,6 +22,7 @@
 from gns import particle_data_loader as pdl
 from gns import distribute
 from gns.args import Config
+from gns.graph_network import Encoder, build_mlp
 
 Stats = collections.namedtuple("Stats", ["mean", "std"])

@@ -735,7 +736,7 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
     # MAML hyperparameters (hardcoded)
     inner_lr = 1e-3
-    num_inner_steps = 10
+    num_inner_steps = 5
 
     # If model_path does exist and model_file and train_state_file exist continue training.
     if cfg.model.file is not None and cfg.training.resume:
@@ -855,21 +856,12 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
         for material in unique_materials:
             # Clone the Encoder for this material
             material_encoder = Encoder(
-                nnode_in_features=main_encoder.node_fn[0][
-                    0
-                ].in_features,
-                nnode_out_features=main_encoder.node_fn[0][
-                    -1
-                ].out_features,
-                nedge_in_features=main_encoder.edge_fn[0][
-                    0
-                ].in_features,
-                nedge_out_features=main_encoder.edge_fn[0][
-                    -1
-                ].out_features,
-                nmlp_layers=len(main_encoder.node_fn[0])
-                - 2,  # Subtract input and output layers
-                mlp_hidden_dim=main_encoder.node_fn[0][1].out_features,
+                nnode_in_features=main_encoder.node_fn[0][0].in_features,
+                nnode_out_features=main_encoder.node_fn[0][-1].out_features,
+                nedge_in_features=main_encoder.edge_fn[0][0].in_features,
+                nedge_out_features=main_encoder.edge_fn[0][-1].out_features,
+                nmlp_layers=len(main_encoder.node_fn[0]) - 2,  # Subtract input and output layers
+                mlp_hidden_dim=main_encoder.node_fn[0][1].out_features
             ).to(device_id)
             material_encoder.load_state_dict(main_encoder.state_dict())
             material_encoder_optimizer = torch.optim.SGD(
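For context, a minimal first-order MAML sketch of how a cloned per-material encoder and its SGD inner-loop optimizer (using the hardcoded inner_lr and num_inner_steps above) are typically used. compute_loss, support_batches, and query_batches are hypothetical placeholders rather than functions in this repository, and the actual inner/outer update in gns/train.py may differ.

import copy
import torch

for material in unique_materials:
    # Start each task (material) from the shared encoder weights, as in the diff above.
    # (The diff rebuilds an Encoder and loads main_encoder's state_dict; deepcopy is an
    # equivalent shortcut for this sketch.)
    material_encoder = copy.deepcopy(main_encoder).to(device_id)
    inner_optimizer = torch.optim.SGD(material_encoder.parameters(), lr=inner_lr)

    # Inner loop: adapt the cloned encoder on this material's support data.
    for _ in range(num_inner_steps):
        loss = compute_loss(material_encoder, support_batches[material])  # hypothetical
        inner_optimizer.zero_grad()
        loss.backward()
        inner_optimizer.step()

    # Outer (meta) step, first-order approximation: evaluate the adapted copy on the
    # query data and copy its gradients back onto the shared encoder's parameters.
    query_loss = compute_loss(material_encoder, query_batches[material])  # hypothetical
    material_encoder.zero_grad()
    query_loss.backward()
    for meta_p, task_p in zip(main_encoder.parameters(), material_encoder.parameters()):
        if task_p.grad is not None:
            meta_p.grad = task_p.grad.clone() if meta_p.grad is None else meta_p.grad + task_p.grad

# After the loop, a meta optimizer (e.g. Adam over main_encoder.parameters()) would take a step.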
