
MAML approach
kks32 committed Jul 15, 2024
1 parent 385c116 commit 2f504d5
Showing 1 changed file with 103 additions and 159 deletions.
262 changes: 103 additions & 159 deletions gns/train.py
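In broad strokes, this commit swaps the earlier clone-per-batch-and-average scheme for persistent per-material task encoders: an inner loop adapts the task encoder on the current batch, and an outer loop updates the shared weights through a meta-loss computed on held-out data. For orientation, a minimal, generic first-order MAML step in plain PyTorch; the model, task structure, and learning rates are illustrative stand-ins, not this repository's API:

    import copy
    import torch

    def fomaml_step(model, tasks, loss_fn, meta_lr=1e-4, inner_lr=1e-3, inner_steps=5):
        """One first-order MAML update over a list of ((support), (query)) tasks."""
        meta_grads = [torch.zeros_like(p) for p in model.parameters()]
        for (sx, sy), (qx, qy) in tasks:
            # Inner loop: adapt a throwaway copy of the model on the support set.
            adapted = copy.deepcopy(model)
            opt = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
            for _ in range(inner_steps):
                opt.zero_grad()
                loss_fn(adapted(sx), sy).backward()
                opt.step()
            # Outer loop: query-set gradient of the adapted copy, accumulated
            # onto the shared initialisation (the first-order approximation).
            adapted.zero_grad()
            loss_fn(adapted(qx), qy).backward()
            for g, p in zip(meta_grads, adapted.parameters()):
                g += p.grad / len(tasks)
        with torch.no_grad():
            for p, g in zip(model.parameters(), meta_grads):
                p -= meta_lr * g

The first-order variant ignores second-order terms, which matches the spirit of code in this commit that detaches adapted parameters with p.clone().detach().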
@@ -6,6 +6,9 @@
 import re
 import sys
 
+import copy
+from collections import defaultdict
+
 import numpy as np
 import torch
 import torch.nn as nn
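copy and defaultdict support the per-material encoder cache introduced below (task_encoders = defaultdict(lambda: copy.deepcopy(main_encoder))). One behaviour worth noting with this pattern: a defaultdict creates the entry on first read, so merely indexing it materialises a new deep copy. A small illustration, with a stand-in for the GNS Encoder:

    import copy
    from collections import defaultdict

    import torch.nn as nn

    template = nn.Linear(4, 2)  # stand-in for the GNS Encoder
    task_encoders = defaultdict(lambda: copy.deepcopy(template))

    enc_a = task_encoders[0]   # first access: deep-copies the template
    enc_b = task_encoders[0]   # same object as enc_a from now on
    assert enc_a is enc_b
    assert enc_a is not template   # adapted weights never touch the template
    print(len(task_encoders))      # 1: reads create entries, so probe with `in`, not `[]`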
@@ -706,7 +709,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
 
 
 def train_maml(rank, cfg, world_size, device, verbose, use_dist):
-    """Train the model using Model Agnostic Meta Learning (MAML).
+    """Train the model using MAML.
 
     Args:
       rank: local rank
@@ -723,6 +726,9 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
         cfg, rank, world_size, device, use_dist
     )
 
+    # Extract the Encoder from the simulator
+    main_encoder = simulator._encode_process_decode._encoder
+
     # Initialize training state
     step = 0
     epoch = 0
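One caveat with grabbing the encoder this early: under DistributedDataParallel the underlying module hangs off simulator.module, which is why the removed resume logic below branches on use_dist. A sketch of an unwrap helper in that spirit (the attribute chain mirrors the line above; the helper itself is ours, not the repository's):

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def unwrap(model: nn.Module) -> nn.Module:
        """Return the underlying module whether or not it is DDP-wrapped."""
        return model.module if isinstance(model, DistributedDataParallel) else model

    # main_encoder = unwrap(simulator)._encode_process_decode._encoder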
@@ -735,75 +741,19 @@
     train_loss_hist = []
     valid_loss_hist = []
 
-    # MAML hyperparameters (hardcoded)
-    inner_lr = 1e-3
-    num_inner_steps = 5
-
-    # If model_path does exist and model_file and train_state_file exist continue training.
-    if cfg.model.file is not None and cfg.training.resume:
-        if cfg.model.file == "latest" and cfg.model.train_state_file == "latest":
-            # find the latest model, assumes model and train_state files are in step.
-            fnames = glob.glob(f"{cfg.model.path}*model*pt")
-            max_model_number = 0
-            expr = re.compile(".*model-(\d+).pt")
-            for fname in fnames:
-                model_num = int(expr.search(fname).groups()[0])
-                if model_num > max_model_number:
-                    max_model_number = model_num
-            # reset names to point to the latest.
-            cfg.model.file = f"model-{max_model_number}.pt"
-            cfg.model.train_state_file = f"train_state-{max_model_number}.pt"
-
-        if os.path.exists(cfg.model.path + cfg.model.file) and os.path.exists(
-            cfg.model.path + cfg.model.train_state_file
-        ):
-            # load model
-            if use_dist:
-                simulator.module.load(cfg.model.path + cfg.model.file)
-            else:
-                simulator.load(cfg.model.path + cfg.model.file)
-
-            # load train state
-            train_state = torch.load(cfg.model.path + cfg.model.train_state_file)
-
-            # set optimizer state
-            optimizer = torch.optim.Adam(
-                simulator.module.parameters() if use_dist else simulator.parameters()
-            )
-            optimizer.load_state_dict(train_state["optimizer_state"])
-            optimizer_to(optimizer, device_id)
-
-            # set global train state
-            step = train_state["global_train_state"]["step"]
-            epoch = train_state["global_train_state"]["epoch"]
-            train_loss_hist = train_state["loss_history"]["train"]
-            valid_loss_hist = train_state["loss_history"]["valid"]
-
-        else:
-            msg = f"Specified model_file {cfg.model.path + cfg.model.file} and train_state_file {cfg.model.path + cfg.model.train_state_file} not found."
-            raise FileNotFoundError(msg)
-
     simulator.train()
     simulator.to(device_id)
 
-    # Extract the Encoder from the simulator
-    main_encoder = simulator._encode_process_decode._encoder
-
-    # Extract the correct dimensions from the main encoder
-    nnode_in_features = 31  # main_encoder.node_fn[0][0].in_features
-    nnode_out_features = 128  # main_encoder.node_fn[-1].normalized_shape[0]
-    nedge_in_features = 3  # main_encoder.edge_fn[0][0].in_features
-    nedge_out_features = 128  # main_encoder.edge_fn[-1].normalized_shape[0]
-    nmlp_layers = 2  # main_encoder.node_fn[1].mlp_layers
-    mlp_hidden_dim = 128  # main_encoder.node_fn[1].out_features
-
     # Load datasets
     train_dl, valid_dl, n_features = load_datasets(cfg, use_dist)
 
     print(f"rank = {rank}, cuda = {torch.cuda.is_available()}")
 
     writer = setup_tensorboard(cfg, metadata) if verbose else None
 
+    # Initialize encoders for each material property
+    task_encoders = defaultdict(lambda: copy.deepcopy(main_encoder))
+
     try:
         num_epochs = max(1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl))
         if verbose:
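An aside on the resume logic this hunk removes: it compiled the pattern as a regular string, ".*model-(\d+).pt", which works only because Python tolerates the unknown \d escape (a SyntaxWarning on recent versions), and the unescaped dot matches any character. If the logic is ever restored, a raw-string version is safer; a sketch (the helper name is ours):

    import glob
    import re

    def latest_checkpoint(model_path: str) -> int:
        """Return the highest step number among model-<step>.pt files."""
        expr = re.compile(r".*model-(\d+)\.pt$")  # raw string; dot escaped
        steps = [
            int(m.group(1))
            for fname in glob.glob(f"{model_path}*model*pt")
            if (m := expr.search(fname))
        ]
        return max(steps, default=0)

    # cfg.model.file = f"model-{latest_checkpoint(cfg.model.path)}.pt"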
@@ -817,6 +767,7 @@ def train_maml(rank, cfg, world_size, device, verbose, use_dist):
             epoch_loss = 0.0
             steps_this_epoch = 0
 
+            # Create a tqdm progress bar for each epoch
             with tqdm(
                 range(step % len(train_dl) + 1, len(train_dl)),
                 desc=f"Epoch {epoch}",
@@ -857,121 +808,114 @@
                     else simulator.predict_accelerations
                 )
 
-                # MAML inner loop
-                if material_property is not None and n_features == 3:
-                    unique_materials = material_property.unique()
-                    adapted_encoders = {}
-
-                    for material in unique_materials:
-                        # Clone the Encoder for this material with correct dimensions
-                        material_encoder = Encoder(
-                            nnode_in_features=nnode_in_features,
-                            nnode_out_features=nnode_out_features,
-                            nedge_in_features=nedge_in_features,
-                            nedge_out_features=nedge_out_features,
-                            nmlp_layers=nmlp_layers,
-                            mlp_hidden_dim=mlp_hidden_dim
-                        ).to(device_id)
-                        material_encoder.load_state_dict(main_encoder.state_dict())
-                        material_encoder_optimizer = torch.optim.SGD(
-                            material_encoder.parameters(), lr=inner_lr
-                        )
-
-                        # Select data for this material
-                        material_mask = material_property == material
-                        material_position = position[material_mask]
-                        material_particle_type = particle_type[material_mask]
-                        material_labels = labels[material_mask]
-                        material_n_particles = torch.sum(material_mask).unsqueeze(0)
-                        material_noise = sampled_noise[material_mask]
-
-                        for _ in range(num_inner_steps):
-                            # Temporarily replace the simulator's encoder
-                            original_encoder = simulator._encode_process_decode._encoder
-                            simulator._encode_process_decode._encoder = material_encoder
-
-                            pred_acc, target_acc = predict_fn(
-                                next_positions=material_labels.to(device_or_rank),
-                                position_sequence_noise=material_noise.to(device_or_rank),
-                                position_sequence=material_position.to(device_or_rank),
-                                nparticles_per_example=material_n_particles.to(device_or_rank),
-                                particle_types=material_particle_type.to(device_or_rank),
-                                material_property=material.to(device_or_rank) if n_features == 3 else None
-                            )
-
-                            # Restore the original encoder
-                            simulator._encode_process_decode._encoder = original_encoder
+                # Identify the material property and use the corresponding encoder
+                material_id = material_property[
+                    0
+                ].item()  # Assuming all materials in the batch are the same
+                task_encoder = task_encoders[material_id]
+                task_optimizer = torch.optim.Adam(
+                    task_encoder.parameters(), lr=cfg.training.inner_lr
+                )
 
-                            inner_loss = acceleration_loss(
-                                pred_acc, target_acc, non_kinematic_mask[material_mask]
-                            )
+                # Perform a few steps of gradient descent in the inner loop
+                for _ in range(cfg.training.inner_steps):
+                    node_features, edge_features = task_encoder(
+                        position, material_property
+                    )
+                    pred_acc, target_acc = predict_fn(
+                        next_positions=labels.to(device_or_rank),
+                        position_sequence_noise=sampled_noise.to(device_or_rank),
+                        position_sequence=position.to(device_or_rank),
+                        nparticles_per_example=n_particles_per_example.to(
+                            device_or_rank
+                        ),
+                        particle_types=particle_type.to(device_or_rank),
+                        node_features=node_features.to(device_or_rank),
+                        edge_features=edge_features.to(device_or_rank),
+                    )
 
-                            material_encoder_optimizer.zero_grad()
-                            inner_loss.backward()
-                            material_encoder_optimizer.step()
+                    loss = acceleration_loss(
+                        pred_acc, target_acc, non_kinematic_mask
+                    )
 
-                        adapted_encoders[material] = material_encoder
+                    task_optimizer.zero_grad()
+                    loss.backward()
+                    task_optimizer.step()
 
-                    # Update the main Encoder with the average of adapted Encoders
-                    with torch.no_grad():
-                        for name, param in main_encoder.named_parameters():
-                            param.data = torch.mean(torch.stack([
-                                adapted_encoders[m].state_dict()[name]
-                                for m in unique_materials
-                            ]), dim=0)
+                # Save the task-specific encoder parameters
+                task_encoder_params = [
+                    p.clone().detach() for p in task_encoder.parameters()
+                ]
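For contrast with the removed branch above: overwriting the main encoder with the element-wise mean of independently adapted copies is closer to Reptile than to MAML, since no gradient flows from the adapted weights back to the shared initialisation. If that averaging is ever wanted again, doing it over state_dict() entries also covers buffers; a sketch with illustrative names:

    import torch
    import torch.nn as nn

    def average_into(main: nn.Module, adapted: list[nn.Module]) -> None:
        """Overwrite main's state with the element-wise mean of the adapted copies."""
        with torch.no_grad():
            merged = {}
            for key, value in main.state_dict().items():
                if value.is_floating_point():
                    merged[key] = torch.stack(
                        [m.state_dict()[key] for m in adapted]
                    ).mean(dim=0)
                else:
                    # Integer buffers (e.g. BatchNorm counters) cannot be averaged.
                    merged[key] = value.clone()
            main.load_state_dict(merged)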

-                else:
-                    # If no material property, just use the main encoder
-                    node_features, edge_index, edge_features = simulator._encoder_preprocessor(
+                # Compute the meta-loss using the task-specific encoder parameters
+                meta_loss = 0.0
+                for example in valid_dl:
+                    (
                         position,
-                        n_particles_per_example,
                         particle_type,
-                        None
-                    )
-                    encoded_nodes, encoded_edges = main_encoder(node_features, edge_features)
+                        material_property,
+                        n_particles_per_example,
+                        labels,
+                    ) = prepare_data(example, device_id)

-                # Outer loop (meta-update)
-                (
-                    node_features,
-                    edge_index,
-                    edge_features,
-                ) = simulator._encoder_preprocessor(
-                    position,
-                    n_particles_per_example,
-                    particle_type,
-                    material_property if n_features == 3 else None,
-                )
+                    n_particles_per_example = n_particles_per_example.to(device_id)
+                    labels = labels.to(device_id)
 
-                encoded_nodes, encoded_edges = main_encoder(
-                    node_features, edge_features
-                )
+                    sampled_noise = (
+                        noise_utils.get_random_walk_noise_for_position_sequence(
+                            position, noise_std_last_step=cfg.data.noise_std
+                        ).to(device_id)
+                    )
+                    non_kinematic_mask = (
+                        (particle_type != cfg.data.kinematic_particle_id)
+                        .clone()
+                        .detach()
+                        .to(device_id)
+                    )
+                    sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
 
+                    # Load task-specific encoder parameters
+                    task_encoder.load_state_dict(
+                        {
+                            k: v
+                            for k, v in zip(
+                                task_encoder.state_dict().keys(),
+                                task_encoder_params,
+                            )
+                        }
+                    )
 
-                pred_acc, target_acc = predict_fn(
-                    next_positions=labels.to(device_or_rank),
-                    position_sequence_noise=sampled_noise.to(device_or_rank),
-                    position_sequence=position.to(device_or_rank),
-                    nparticles_per_example=n_particles_per_example.to(
-                        device_or_rank
-                    ),
-                    particle_types=particle_type.to(device_or_rank),
-                    material_property=(
-                        material_property.to(device_or_rank)
-                        if n_features == 3
-                        else None
-                    ),
-                    encoder_output=(encoded_nodes, encoded_edges, edge_index),
-                )
+                    node_features, edge_features = task_encoder(
+                        position, material_property
+                    )
+                    pred_acc, target_acc = predict_fn(
+                        next_positions=labels.to(device_or_rank),
+                        position_sequence_noise=sampled_noise.to(device_or_rank),
+                        position_sequence=position.to(device_or_rank),
+                        nparticles_per_example=n_particles_per_example.to(
+                            device_or_rank
+                        ),
+                        particle_types=particle_type.to(device_or_rank),
+                        node_features=node_features.to(device_or_rank),
+                        edge_features=edge_features.to(device_or_rank),
+                    )
 
-                loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
+                    valid_loss = acceleration_loss(
+                        pred_acc, target_acc, non_kinematic_mask
+                    )
+                    meta_loss += valid_loss
 
-                train_loss = loss.item()
-                epoch_loss += train_loss
-                steps_this_epoch += 1
+                meta_loss /= len(valid_dl)
 
+                # Outer loop: Update the main encoder using the meta-loss
                 optimizer.zero_grad()
-                loss.backward()
+                meta_loss.backward()
                 optimizer.step()
 
+                train_loss = meta_loss.item()
+                epoch_loss += train_loss
+                steps_this_epoch += 1
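Two hedged observations on the added outer loop. First, meta_loss += valid_loss keeps the autograd graph of every validation batch alive until the single backward() call, so memory grows with len(valid_dl); accumulating gradients batch by batch yields the same average gradient at constant memory. Second, restoring weights by zipping state_dict().keys() against a list built from parameters() silently assumes the two orderings agree and that the module has no buffers; keying on named_parameters() avoids that. A sketch of the gradient-accumulation variant, where compute_valid_loss stands in for the loop body above:

    # Equivalent average gradient, without holding every batch's graph in memory.
    optimizer.zero_grad()
    meta_loss = 0.0
    for example in valid_dl:
        valid_loss = compute_valid_loss(example)   # hypothetical stand-in
        (valid_loss / len(valid_dl)).backward()    # frees this batch's graph now
        meta_loss += valid_loss.item()
    optimizer.step()
    train_loss = meta_loss / len(valid_dl)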

                 lr_new = (
                     cfg.training.learning_rate.initial
                     * (
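The hunk is truncated here mid-expression; from the surrounding names, lr_new scales cfg.training.learning_rate.initial by a step-dependent factor. A common shape for such a schedule in graph-network simulator training (an assumption on our part, since the factor itself is cut off) is step-based exponential decay:

    # Hypothetical reconstruction of a step-based exponential decay schedule.
    def decayed_lr(initial: float, decay: float, step: int, decay_steps: float) -> float:
        return initial * decay ** (step / decay_steps)

    for group in optimizer.param_groups:
        group["lr"] = decayed_lr(1e-4, 0.1, step, 5e6)  # illustrative constants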
