From a894222640d15d9e558c2f6cd6f731fc7c3a6354 Mon Sep 17 00:00:00 2001
From: Krishna Kumar
Date: Wed, 26 Jun 2024 13:31:58 -0600
Subject: [PATCH] Multinode configuration

---
 gns/train_multinode.py | 712 ++++++++++++++++++-----------------------
 1 file changed, 307 insertions(+), 405 deletions(-)

diff --git a/gns/train_multinode.py b/gns/train_multinode.py
index 389b57a..be660ea 100644
--- a/gns/train_multinode.py
+++ b/gns/train_multinode.py
@@ -10,77 +10,21 @@ import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+import torch.distributed as dist
+import torchvision.models as models
+from torch.utils import collect_env
+from torch.utils.data.distributed import DistributedSampler
 
-from absl import flags
-from absl import app
+import hydra
+from omegaconf import DictConfig, OmegaConf
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
 from gns import learned_simulator
 from gns import noise_utils
 from gns import reading_utils
 from gns import data_loader
 
-import torch
-import torch.distributed as dist
-import torchvision.models as models
-from torch.utils import collect_env
-from torch.utils.data.distributed import DistributedSampler
-
-flags.DEFINE_enum(
-    "mode",
-    "train",
-    ["train", "valid", "rollout"],
-    help="Train model, validation or rollout evaluation.",
-)
-flags.DEFINE_integer("batch_size", 2, help="The batch size.")
-flags.DEFINE_float("noise_std", 6.7e-4, help="The std deviation of the noise.")
-flags.DEFINE_string("data_path", None, help="The dataset directory.")
-flags.DEFINE_string(
-    "model_path", "models/", help=("The path for saving checkpoints of the model.")
-)
-flags.DEFINE_string(
-    "output_path", "rollouts/", help="The path for saving outputs (e.g. rollouts)."
-)
-flags.DEFINE_string(
-    "output_filename", "rollout", help="Base name for saving the rollout"
-)
-flags.DEFINE_string(
-    "model_file",
-    None,
-    help=(
-        'Model filename (.pt) to resume from. Can also use "latest" to default to newest file.'
-    ),
-)
-flags.DEFINE_string(
-    "train_state_file",
-    "train_state.pt",
-    help=(
-        'Train state filename (.pt) to resume from. Can also use "latest" to default to newest file.'
-    ),
-)
-
-flags.DEFINE_integer("ntraining_steps", int(2e7), help="Number of training steps.")
-flags.DEFINE_integer(
-    "nsave_steps", int(5000), help="Number of steps at which to save the model."
-)
-
-# Learning rate parameters
-flags.DEFINE_float("lr_init", 1e-4, help="Initial learning rate.")
-flags.DEFINE_float("lr_decay", 0.1, help="Learning rate decay.")
-flags.DEFINE_integer("lr_decay_steps", int(5e6), help="Learning rate decay steps.")
-
-flags.DEFINE_integer(
-    "cuda_device_number",
-    None,
-    help="CUDA device (zero indexed), default is None so default CUDA device will be used.",
-)
-
-# new argument for multinode training
-flags.DEFINE_integer("local-rank", 0, help="local rank for distributed training")
-
-FLAGS = flags.FLAGS
-
-if "LOCAL_RANK" in os.environ:
-    local_rank = int(os.environ["LOCAL_RANK"])
+from gns.args import Config
 
 Stats = collections.namedtuple("Stats", ["mean", "std"])
 
@@ -91,6 +35,7 @@
 def rollout(
     simulator: learned_simulator.LearnedSimulator,
+    cfg: DictConfig,
     position: torch.tensor,
     particle_types: torch.tensor,
     material_property: torch.tensor,
@@ -103,6 +48,7 @@ def rollout(
 
     Args:
       simulator: Learned simulator.
+      cfg: Configuration object.
      position: Positions of particles (timesteps, nparticles, ndims)
      particle_types: Particles types with shape (nparticles)
      material_property: Friction angle normalized by tan() with shape (nparticles)
      n_particles_per_example
      nsteps: Number of steps.
      device: torch device.
    """
 
@@ -111,8 +57,8 @@ def rollout(
-    initial_positions = position[:, :INPUT_SEQUENCE_LENGTH]
-    ground_truth_positions = position[:, INPUT_SEQUENCE_LENGTH:]
+    initial_positions = position[:, : cfg.data.input_sequence_length]
+    ground_truth_positions = position[:, cfg.data.input_sequence_length :]
 
     current_positions = initial_positions
     predictions = []
 
@@ -128,7 +74,10 @@ def rollout(
 
         # Update kinematic particles from prescribed trajectory.
         kinematic_mask = (
-            (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
+            (particle_types == cfg.data.kinematic_particle_id)
+            .clone()
+            .detach()
+            .to(device)
         )
         next_position_ground_truth = ground_truth_positions[:, step]
         kinematic_mask = kinematic_mask.bool()[:, None].expand(
             -1, current_positions.shape[-1]
         )
@@ -164,217 +113,46 @@ def rollout(
     return output_dict, loss
 
 
-def rollout_par(
-    simulator,
-    position: torch.tensor,
-    particle_types: torch.tensor,
-    material_property: torch.tensor,
-    n_particles_per_example: torch.tensor,
-    nsteps: int,
-    device,
-):
-    """
-    Rolls out a trajectory by applying the model in sequence.
-
-    Args:
-      simulator: Learned simulator.
-      position: Positions of particles (timesteps, nparticles, ndims)
-      particle_types: Particles types with shape (nparticles)
-      material_property: Friction angle normalized by tan() with shape (nparticles)
-      n_particles_per_example
-      nsteps: Number of steps.
-      device: torch device.
-    """
-
-    initial_positions = position[:, :INPUT_SEQUENCE_LENGTH]
-    ground_truth_positions = position[:, INPUT_SEQUENCE_LENGTH:]
-
-    current_positions = initial_positions
-    predictions = []
-
-    for step in tqdm(range(nsteps), total=nsteps):
-        # Get next position with shape (nnodes, dim)
-        next_position = simulator.module.predict_positions(
-            current_positions,
-            nparticles_per_example=[n_particles_per_example],
-            particle_types=particle_types,
-            material_property=material_property,
-        )
-
-        # Update kinematic particles from prescribed trajectory.
-        kinematic_mask = (
-            (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
-        )
-        next_position_ground_truth = ground_truth_positions[:, step]
-        kinematic_mask = kinematic_mask.bool()[:, None].expand(
-            -1, current_positions.shape[-1]
-        )
-        next_position = torch.where(
-            kinematic_mask, next_position_ground_truth, next_position
-        )
-        predictions.append(next_position)
-
-        # Shift `current_positions`, removing the oldest position in the sequence
-        # and appending the next position at the end.
-        current_positions = torch.cat(
-            [current_positions[:, 1:], next_position[:, None, :]], dim=1
-        )
-
-    # Predictions with shape (time, nnodes, dim)
-    predictions = torch.stack(predictions)
-    ground_truth_positions = ground_truth_positions.permute(1, 0, 2)
-
-    loss = (predictions - ground_truth_positions) ** 2
-
-    output_dict = {
-        "initial_positions": initial_positions.permute(1, 0, 2).cpu().numpy(),
-        "predicted_rollout": predictions.cpu().numpy(),
-        "ground_truth_rollout": ground_truth_positions.cpu().numpy(),
-        "particle_types": particle_types.cpu().numpy(),
-        "material_property": material_property.cpu().numpy()
-        if material_property is not None
-        else None,
-    }
-
-    return output_dict, loss
-
-
-def predict(device: str):
+def predict_par(rank, cfg):
     """Predict rollouts.
 
     Args:
-      simulator: Trained simulator if not will undergo training.
-
+      rank: local rank for distributed setup.
+      cfg: configuration object.
     """
     # Read metadata
-    metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
-    simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)
-
-    # Load simulator
-    if os.path.exists(FLAGS.model_path + FLAGS.model_file):
-        simulator.load(FLAGS.model_path + FLAGS.model_file)
-    else:
-        raise Exception(
-            f"Model does not exist at {FLAGS.model_path + FLAGS.model_file}"
-        )
-
-    simulator.to(device)
-    simulator.eval()
-
-    # Output path
-    if not os.path.exists(FLAGS.output_path):
-        os.makedirs(FLAGS.output_path)
-
-    # Use `valid`` set for eval mode if not use `test`
-    split = "test" if FLAGS.mode == "rollout" else "valid"
-
-    # Get dataset
-    ds = data_loader.get_data_loader_by_trajectories(
-        path=f"{FLAGS.data_path}{split}.npz"
-    )
-    # See if our dataset has material property as feature
-    if (
-        len(ds.dataset._data[0]) == 3
-    ):  # `ds` has (positions, particle_type, material_property)
-        material_property_as_feature = True
-    elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
-        material_property_as_feature = False
-    else:
-        raise NotImplementedError
-
-    eval_loss = []
-    with torch.no_grad():
-        torch.distributed.barrier()
-        for example_i, features in enumerate(ds):
-            positions = features[0].to(device)
-            if metadata["sequence_length"] is not None:
-                # If `sequence_length` is predefined in metadata,
-                nsteps = metadata["sequence_length"] - INPUT_SEQUENCE_LENGTH
-            else:
-                # If no predefined `sequence_length`, then get the sequence length
-                sequence_length = positions.shape[1]
-                nsteps = sequence_length - INPUT_SEQUENCE_LENGTH
-            particle_type = features[1].to(device)
-            if material_property_as_feature:
-                material_property = features[2].to(device)
-                n_particles_per_example = torch.tensor(
-                    [int(features[3])], dtype=torch.int32
-                ).to(device)
-            else:
-                material_property = None
-                n_particles_per_example = torch.tensor(
-                    [int(features[2])], dtype=torch.int32
-                ).to(device)
-
-            # Predict example rollout
-            example_rollout, loss = rollout(
-                simulator,
-                positions,
-                particle_type,
-                material_property,
-                n_particles_per_example,
-                nsteps,
-                device,
-            )
-
-            example_rollout["metadata"] = metadata
-            print("Predicting example {} loss: {}".format(example_i, loss.mean()))
-            eval_loss.append(torch.flatten(loss))
-
-            # Save rollout in testing
-            if FLAGS.mode == "rollout":
-                example_rollout["metadata"] = metadata
-                example_rollout["loss"] = loss.mean()
-                filename = f"{FLAGS.output_filename}_ex{example_i}.pkl"
-                filename = os.path.join(FLAGS.output_path, filename)
-                with open(filename, "wb") as f:
-                    pickle.dump(example_rollout, f)
-
-    print(
-        "Mean loss on rollout prediction: {}".format(torch.mean(torch.cat(eval_loss)))
+    metadata = reading_utils.read_metadata(cfg.data.path, "rollout")
+    serial_simulator = _get_simulator(
+        metadata, cfg.data.noise_std, cfg.data.noise_std, rank
     )
-
-
-def predict_par(rank):
-    """Predict rollouts.
-
-    Args:
-      simulator: Trained simulator if not will undergo training.
-
-    """
-    # Read metadata
-    metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
-    serial_simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, rank)
     serial_simulator = serial_simulator.to("cuda")
     device_id = rank
     simulator = DDP(serial_simulator, device_ids=[rank])
 
     # Load simulator
-    if os.path.exists(FLAGS.model_path + FLAGS.model_file):
-        simulator.module.load(FLAGS.model_path + FLAGS.model_file)
+    if os.path.exists(cfg.model.path + cfg.model.file):
+        simulator.module.load(cfg.model.path + cfg.model.file)
     else:
-        raise Exception(
-            f"Model does not exist at {FLAGS.model_path + FLAGS.model_file}"
-        )
+        raise Exception(f"Model does not exist at {cfg.model.path + cfg.model.file}")
 
     simulator.eval()
 
     # Output path
-    if not os.path.exists(FLAGS.output_path):
-        os.makedirs(FLAGS.output_path)
+    if not os.path.exists(cfg.output.path):
+        os.makedirs(cfg.output.path)
 
     # Use the `test` split in rollout mode, otherwise the `valid` split
-    split = "test" if FLAGS.mode == "rollout" else "valid"
+    split = "test" if cfg.mode == "rollout" else "valid"
 
     # Get dataset
-    # ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
-    dataset = data_loader.TrajectoriesDataset(path=f"{FLAGS.data_path}{split}.npz")
+    dataset = data_loader.TrajectoriesDataset(path=f"{cfg.data.path}{split}.npz")
     sampler = DistributedSampler(
         dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False
     )
     ds = torch.utils.data.DataLoader(
         dataset=dataset, sampler=sampler, batch_size=None, pin_memory=True
     )
+
     # See if our dataset has material property as feature
     if (
         len(ds.dataset._data[0]) == 3
@@ -394,11 +172,11 @@ def predict_par(rank):
         positions = features[0].to(rank)
         if metadata["sequence_length"] is not None:
             # If `sequence_length` is predefined in metadata,
-            nsteps = metadata["sequence_length"] - INPUT_SEQUENCE_LENGTH
+            nsteps = metadata["sequence_length"] - cfg.data.input_sequence_length
         else:
             # If no predefined `sequence_length`, then get the sequence length
             sequence_length = positions.shape[1]
-            nsteps = sequence_length - INPUT_SEQUENCE_LENGTH
+            nsteps = sequence_length - cfg.data.input_sequence_length
         particle_type = features[1].to(rank)
         if material_property_as_feature:
             material_property = features[2].to(rank)
@@ -412,8 +190,9 @@ def predict_par(rank):
         ).to(rank)
 
         # Predict example rollout
-        example_rollout, loss = rollout_par(
+        example_rollout, loss = rollout(
             simulator,
+            cfg,
             positions,
             particle_type,
             material_property,
@@ -427,11 +206,11 @@ def predict_par(rank):
         eval_loss.append(torch.flatten(loss))
 
         # Save rollout in testing
-        if FLAGS.mode == "rollout":
+        if cfg.mode == "rollout":
             example_rollout["metadata"] = metadata
             example_rollout["loss"] = loss.mean()
-            filename = f"{FLAGS.output_filename}_ex{example_i}.pkl"
-            filename = os.path.join(FLAGS.output_path, filename)
+            filename = f"{cfg.output.filename}_ex{example_i}.pkl"
+            filename = os.path.join(cfg.output.path, filename)
             with open(filename, "wb") as f:
                 pickle.dump(example_rollout, f)
 
@@ -455,34 +234,101 @@ def optimizer_to(optim, device):
                         subparam._grad.data = subparam._grad.data.to(device)
 
 
-def train(rank, flags, world_size, verbose):
+def acceleration_loss(pred_acc, target_acc, non_kinematic_mask):
+    """
+    Compute the loss between predicted and target accelerations.
+
+    Args:
+        pred_acc: Predicted accelerations.
+        target_acc: Target accelerations.
+        non_kinematic_mask: Boolean mask that is True for non-kinematic
+            particles, whose loss contributions are kept.
+    """
+    loss = (pred_acc - target_acc) ** 2
+    loss = loss.sum(dim=-1)
+    num_non_kinematic = non_kinematic_mask.sum()
+    loss = torch.where(non_kinematic_mask.bool(), loss, torch.zeros_like(loss))
+    loss = loss.sum() / num_non_kinematic
+    return loss
+
+
+def save_model_and_train_state(
+    rank,
+    device,
+    simulator,
+    cfg,
+    step,
+    epoch,
+    optimizer,
+    train_loss,
+    valid_loss,
+    train_loss_hist,
+    valid_loss_hist,
+):
+    """Save the model checkpoint and the optimizer/loss-history train state."""
+    if rank == 0 or device == torch.device("cpu"):
+        if device == torch.device("cpu"):
+            simulator.save(cfg.model.path + "model-" + str(step) + ".pt")
+        else:
+            simulator.module.save(cfg.model.path + "model-" + str(step) + ".pt")
+
+        train_state = dict(
+            optimizer_state=optimizer.state_dict(),
+            global_train_state={
+                "step": step,
+                "epoch": epoch,
+                "train_loss": train_loss,
+                "valid_loss": valid_loss,
+            },
+            loss_history={"train": train_loss_hist, "valid": valid_loss_hist},
+        )
+        torch.save(train_state, f"{cfg.model.path}train_state-{step}.pt")
+
+
+def train(rank, cfg, world_size, device):
     """Train the model.
 
     Args:
       rank: local rank
+      cfg: configuration object
      world_size: total number of ranks
-      device: cuda device local rank
-    #"""
+      device: torch device type
+    """
+    if device == torch.device("cuda"):
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+        device_id = rank
+    else:
+        device_id = device
+
     # Read metadata
-    metadata = reading_utils.read_metadata(flags["data_path"], "train")
+    metadata = reading_utils.read_metadata(cfg.data.path, "train")
 
     # Get simulator and optimizer
     serial_simulator = _get_simulator(
-        metadata, flags["noise_std"], flags["noise_std"], rank
+        metadata, cfg.data.noise_std, cfg.data.noise_std, rank
     )
-    serial_simulator = serial_simulator.to("cuda")
-    device_id = rank
+    # Move the model to this rank's device before wrapping it in DDP; a bare
+    # `.to("cuda")` would place every rank's replica on the default device.
+    serial_simulator = serial_simulator.to(device_id)
     simulator = DDP(serial_simulator, device_ids=[rank])
     optimizer = torch.optim.Adam(
-        simulator.parameters(), lr=flags["lr_init"] * world_size
+        simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
     )
+
+    # Initialize training state
     step = 0
+    epoch = 0
+    steps_per_epoch = 0
+
+    valid_loss = None
+    train_loss = 0
+    epoch_valid_loss = None
+
+    train_loss_hist = []
+    valid_loss_hist = []
 
     # If model_path does exist and model_file and train_state_file exist continue training.
-    if flags["model_file"] is not None:
-        if flags["model_file"] == "latest" and flags["train_state_file"] == "latest":
+    if cfg.model.file is not None and cfg.training.resume:
+        if cfg.model.file == "latest" and cfg.model.train_state_file == "latest":
             # find the latest model, assumes model and train_state files are in step.
-            fnames = glob.glob(f'{flags["model_path"]}*model*pt')
+            fnames = glob.glob(f"{cfg.model.path}*model*pt")
             max_model_number = 0
             expr = re.compile(".*model-(\d+).pt")
             for fname in fnames:
@@ -490,68 +336,102 @@ def train(rank, flags, world_size, verbose):
                 model_num = int(expr.findall(fname)[0])
                 if model_num > max_model_number:
                     max_model_number = model_num
             # reset names to point to the latest.
-            flags["model_file"] = f"model-{max_model_number}.pt"
-            flags["train_state_file"] = f"train_state-{max_model_number}.pt"
+            cfg.model.file = f"model-{max_model_number}.pt"
+            cfg.model.train_state_file = f"train_state-{max_model_number}.pt"
 
-    if os.path.exists(flags["model_path"] + flags["model_file"]) and os.path.exists(
-        flags["model_path"] + flags["train_state_file"]
+    if os.path.exists(cfg.model.path + cfg.model.file) and os.path.exists(
+        cfg.model.path + cfg.model.train_state_file
     ):
         # load model
-        simulator.module.load(flags["model_path"] + flags["model_file"])
+        simulator.module.load(cfg.model.path + cfg.model.file)
 
         # load train state
-        train_state = torch.load(flags["model_path"] + flags["train_state_file"])
+        train_state = torch.load(cfg.model.path + cfg.model.train_state_file)
+
         # set optimizer state
-        optimizer = torch.optim.Adam(
-            simulator.module.parameters()
-            if device == torch.device("cuda")
-            else simulator.parameters()
-        )
+        optimizer = torch.optim.Adam(simulator.module.parameters())
        optimizer.load_state_dict(train_state["optimizer_state"])
        optimizer_to(optimizer, device_id)
+
        # set global train state
-        step = train_state["global_train_state"].pop("step")
+        step = train_state["global_train_state"]["step"]
+        epoch = train_state["global_train_state"]["epoch"]
+        train_loss_hist = train_state["loss_history"]["train"]
+        valid_loss_hist = train_state["loss_history"]["valid"]
     else:
-        msg = f'Specified model_file {flags["model_path"] + flags["model_file"]} and train_state_file {flags["model_path"] + flags["train_state_file"]} not found.'
+        msg = f"Specified model_file {cfg.model.path + cfg.model.file} and train_state_file {cfg.model.path + cfg.model.train_state_file} not found."
         raise FileNotFoundError(msg)
 
     simulator.train()
     simulator.to(device_id)
 
-    # dl = distribute.get_data_distributed_dataloader_by_samples(path=f'{flags["data_path"]}train.npz',
-    #                                                            input_length_sequence=INPUT_SEQUENCE_LENGTH,
-    #                                                            batch_size=flags["batch_size"])
-    path = f'{flags["data_path"]}train.npz'
-    input_length_sequence = INPUT_SEQUENCE_LENGTH
-    batch_size = flags["batch_size"]
-    dataset = data_loader.SamplesDataset(path, input_length_sequence)
-    # for multi node training we need to use pytorch's distributed sampler, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
+    # Get dataset
+    dataset = data_loader.SamplesDataset(
+        cfg.data.path + "train.npz", cfg.data.input_sequence_length
+    )
     sampler = DistributedSampler(
         dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
     )
     dl = torch.utils.data.DataLoader(
         dataset=dataset,
         sampler=sampler,
-        batch_size=batch_size,
+        batch_size=cfg.data.batch_size,
         pin_memory=True,
         collate_fn=data_loader.collate_fn,
     )
 
     n_features = len(dl.dataset._data[0])
 
+    # Load validation data
+    if cfg.training.validation_interval is not None:
+        valid_dataset = data_loader.SamplesDataset(
+            cfg.data.path + "valid.npz", cfg.data.input_sequence_length
+        )
+        valid_sampler = DistributedSampler(
+            valid_dataset,
+            num_replicas=dist.get_world_size(),
+            rank=dist.get_rank(),
+            shuffle=False,
+        )
+        dl_valid = torch.utils.data.DataLoader(
+            dataset=valid_dataset,
+            sampler=valid_sampler,
+            batch_size=cfg.data.batch_size,
+            pin_memory=True,
+            collate_fn=data_loader.collate_fn,
+        )
+        if len(dl_valid.dataset._data[0]) != n_features:
+            raise ValueError(
+                "`n_features` of `valid.npz` and `train.npz` should be the same"
+            )
+
+    if rank == 0:
+        writer = SummaryWriter(log_dir=cfg.logging.tensorboard_dir)
+
+        writer.add_text("Data path", cfg.data.path)
+        writer.add_text("metadata", json.dumps(metadata, indent=4))
+
+        # Log hyperparameters
+        hparam_dict = {
+            "lr_init": cfg.training.learning_rate.initial,
+            "lr_decay": cfg.training.learning_rate.decay,
+            "lr_decay_steps": cfg.training.learning_rate.decay_steps,
+            "batch_size": cfg.data.batch_size,
+            "noise_std": cfg.data.noise_std,
+            "ntraining_steps": cfg.training.steps,
+        }
+        metric_dict = {"train_loss": 0, "valid_loss": 0}  # Initial values
+        writer.add_hparams(hparam_dict, metric_dict)
+
     not_reached_nsteps = True
     try:
         while not_reached_nsteps:
             torch.distributed.barrier()
-            for (
-                example
-            ) in (
-                dl
-            ):  # ((position, particle_type, material_property, n_particles_per_example), labels) are in dl
+            for example in dl:
                 position = example[0][0].to(device_id)
                 particle_type = example[0][1].to(device_id)
-                if n_features == 3:  # if dl includes material_property
+                if n_features == 3:
                     material_property = example[0][2].to(device_id)
                     n_particles_per_example = example[0][3].to(device_id)
                 elif n_features == 2:
@@ -560,23 +440,20 @@ def train(rank, flags, world_size, verbose):
                     raise NotImplementedError
                 labels = example[1].to(device_id)
 
-                n_particles_per_example.to(device_id)
-                labels.to(device_id)
+                n_particles_per_example = n_particles_per_example.to(device_id)
+                labels = labels.to(device_id)
 
-                # TODO (jpv): Move noise addition to data_loader
-                # Sample the noise to add to the inputs to the model during training.
                 sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(
-                    position, noise_std_last_step=flags["noise_std"]
+                    position, noise_std_last_step=cfg.data.noise_std
                ).to(device_id)
                 non_kinematic_mask = (
-                    (particle_type != KINEMATIC_PARTICLE_ID)
+                    (particle_type != cfg.data.kinematic_particle_id)
                     .clone()
                     .detach()
                     .to(device_id)
                 )
                 sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
 
-                # Get the predictions and target accelerations.
@@ -588,67 +465,84 @@ def train(rank, flags, world_size, verbose):
                 pred_acc, target_acc = simulator.module.predict_accelerations(
                     next_positions=labels.to(rank),
                     position_sequence_noise=sampled_noise.to(rank),
                     position_sequence=position.to(rank),
                     nparticles_per_example=n_particles_per_example.to(rank),
                     particle_types=particle_type.to(rank),
                     material_property=material_property.to(rank)
                     if n_features == 3
                     else None,
                 )
 
-                # Calculate the loss and mask out loss on kinematic particles
-                loss = (pred_acc - target_acc) ** 2
-                loss = loss.sum(dim=-1)
-                num_non_kinematic = non_kinematic_mask.sum()
-                loss = torch.where(
-                    non_kinematic_mask.bool(), loss, torch.zeros_like(loss)
-                )
-                loss = loss.sum() / num_non_kinematic
+                if (
+                    cfg.training.validation_interval is not None
+                    and step > 0
+                    and step % cfg.training.validation_interval == 0
+                ):
+                    if rank == 0:
+                        # Evaluate on a single batch drawn from the validation
+                        # loader (a fresh iterator each time).
+                        sampled_valid_example = next(iter(dl_valid))
+                        valid_loss = validation(
+                            simulator,
+                            sampled_valid_example,
+                            n_features,
+                            cfg,
+                            rank,
+                            device_id,
+                        )
+                        writer.add_scalar("Loss/valid", valid_loss.item(), step)
+
+                loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
+
+                train_loss = loss.item()
 
-                # Computes the gradient of loss
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
 
-                # Update learning rate
                 lr_new = (
-                    flags["lr_init"]
-                    * (flags["lr_decay"] ** (step / flags["lr_decay_steps"]))
+                    cfg.training.learning_rate.initial
+                    * (
+                        cfg.training.learning_rate.decay
+                        ** (step / cfg.training.learning_rate.decay_steps)
+                    )
                     * world_size
                )
                 for param in optimizer.param_groups:
                     param["lr"] = lr_new
 
-                # for multi node training we need to use global rank (verbose is true is global rank is 0) instead of local rank to decide whether to print
-                if verbose:
+                if rank == 0:
+                    writer.add_scalar("Loss/train", train_loss, step)
+                    writer.add_scalar("Learning Rate", lr_new, step)
+
                     print(
-                        f'Training step: {step}/{flags["ntraining_steps"]}. Loss: {loss}.',
+                        f"Training step: {step}/{cfg.training.steps}. Loss: {loss}.",
                         flush=True,
                     )
 
-                # Save model state
-                if step % flags["nsave_steps"] == 0:
-                    simulator.module.save(
-                        flags["model_path"] + "model-" + str(step) + ".pt"
-                    )
-                    train_state = dict(
-                        optimizer_state=optimizer.state_dict(),
-                        global_train_state={"step": step},
-                        loss=loss.item(),
-                    )
-                    torch.save(
-                        train_state, f'{flags["model_path"]}train_state-{step}.pt'
+                if step % cfg.training.save_steps == 0:
+                    save_model_and_train_state(
+                        rank,
+                        device,
+                        simulator,
+                        cfg,
+                        step,
+                        epoch,
+                        optimizer,
+                        train_loss,
+                        valid_loss,
+                        train_loss_hist,
+                        valid_loss_hist,
                     )
 
-                # Complete training
-                if step >= flags["ntraining_steps"]:
+                step += 1
+                if step >= cfg.training.steps:
                     not_reached_nsteps = False
                     break
 
-                step += 1
+            epoch += 1
 
     except KeyboardInterrupt:
         pass
 
     if rank == 0:
-        simulator.module.save(flags["model_path"] + "model-" + str(step) + ".pt")
+        simulator.module.save(cfg.model.path + "model-" + str(step) + ".pt")
         train_state = dict(
             optimizer_state=optimizer.state_dict(),
             global_train_state={"step": step},
             loss=loss.item(),
         )
-        torch.save(train_state, f'{flags["model_path"]}train_state-{step}.pt')
+        torch.save(train_state, f"{cfg.model.path}train_state-{step}.pt")
+        writer.close()
 
     if torch.cuda.is_available():
         torch.distributed.destroy_process_group()
 
 
@@ -657,16 +551,7 @@
 def _get_simulator(
     metadata: json, acc_noise_std: float, vel_noise_std: float, device: torch.device
 ) -> learned_simulator.LearnedSimulator:
-    """Instantiates the simulator.
-
-    Args:
-      metadata: JSON object with metadata.
-      acc_noise_std: Acceleration noise std deviation.
-      vel_noise_std: Velocity noise std deviation.
-      device: PyTorch device 'cpu' or 'cuda'.
-    """
-
-    # Normalization stats
+    """Instantiates the simulator."""
     normalization_stats = {
         "acceleration": {
             "mean": torch.FloatTensor(metadata["acc_mean"]).to(device),
@@ -682,17 +567,13 @@ def _get_simulator(
         },
     }
 
-    # Get necessary parameters for loading simulator.
     if "nnode_in" in metadata and "nedge_in" in metadata:
         nnode_in = metadata["nnode_in"]
         nedge_in = metadata["nedge_in"]
     else:
-        # Given that there is no additional node feature (e.g., material_property) except for:
-        # (position (dim), velocity (dim*6), particle_type (16)),
         nnode_in = 37 if metadata["dim"] == 3 else 30
         nedge_in = metadata["dim"] + 1
 
-    # Init simulator.
     simulator = learned_simulator.LearnedSimulator(
         particle_dimensions=metadata["dim"],
         nnode_in=nnode_in,
@@ -715,66 +596,87 @@ def _get_simulator(
     return simulator
 
 
-def main(_):
-    """Train or evaluates the model."""
-    device = torch.device("cuda")
-    myflags = {}
-    myflags["data_path"] = FLAGS.data_path
-    myflags["noise_std"] = FLAGS.noise_std
-    myflags["lr_init"] = FLAGS.lr_init
-    myflags["lr_decay"] = FLAGS.lr_decay
-    myflags["lr_decay_steps"] = FLAGS.lr_decay_steps
-    myflags["batch_size"] = FLAGS.batch_size
-    myflags["ntraining_steps"] = FLAGS.ntraining_steps
-    myflags["nsave_steps"] = FLAGS.nsave_steps
-    myflags["model_file"] = FLAGS.model_file
-    myflags["model_path"] = FLAGS.model_path
-    myflags["train_state_file"] = FLAGS.train_state_file
-
-    torch.multiprocessing.set_start_method("spawn")
-    torch.distributed.init_process_group(
-        backend="nccl",
-        init_method="env://",
+def validation(simulator, example, n_features, cfg, rank, device_id):
+    """Compute the acceleration loss on a single validation batch."""
+    position = example[0][0].to(device_id)
+    particle_type = example[0][1].to(device_id)
+    if n_features == 3:
+        material_property = example[0][2].to(device_id)
+        n_particles_per_example = example[0][3].to(device_id)
+    elif n_features == 2:
+        n_particles_per_example = example[0][2].to(device_id)
+    else:
+        raise NotImplementedError
+    labels = example[1].to(device_id)
+
+    sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(
+        position, noise_std_last_step=cfg.data.noise_std
+    ).to(device_id)
+    non_kinematic_mask = (
+        (particle_type != cfg.data.kinematic_particle_id).clone().detach().to(device_id)
     )
-    torch.distributed.barrier()
-    # instead of torch.cuda.device_count(), we use dist.get_world_size() to get world size
-    world_size = dist.get_world_size()
-
-    torch.cuda.set_device(local_rank)
-    torch.cuda.manual_seed(0)
-
-    # we use global rank to decide the verbose device
-    flags.DEFINE_boolean("verbose", False, "Verbose.")
-    FLAGS.verbose = dist.get_rank() == 0
-
-    if FLAGS.verbose:
-        print("Collecting env info...")
-        print(collect_env.get_pretty_env_info())
-        print()
-
-        for r in range(torch.distributed.get_world_size()):
-            if r == torch.distributed.get_rank():
-                print(
-                    f"Global rank {torch.distributed.get_rank()} initialized: "
-                    f"local_rank = {local_rank}, "
-                    f"world_size = {torch.distributed.get_world_size()}",
                 )
-            torch.distributed.barrier()
+    sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
+
+    with torch.no_grad():
+        pred_acc, target_acc = simulator.module.predict_accelerations(
+            next_positions=labels.to(rank),
+            position_sequence_noise=sampled_noise.to(rank),
+            position_sequence=position.to(rank),
+            nparticles_per_example=n_particles_per_example.to(rank),
+            particle_types=particle_type.to(rank),
+            material_property=material_property.to(rank) if n_features == 3 else None,
        )
+
+    loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)
+
+    return loss
+
+
+@hydra.main(version_base=None, config_path="..", config_name="config")
+def main(cfg: Config):
+    """Trains or evaluates the model."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device == torch.device("cuda"):
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "29500"
+
+    # Print the final configuration
+    print(OmegaConf.to_yaml(cfg))
 
-    if FLAGS.mode == "train":
+    if cfg.mode == "train":
         # If model_path does not exist create new directory.
-        if not os.path.exists(FLAGS.model_path):
-            os.makedirs(FLAGS.model_path)
-        train(local_rank, myflags, world_size, FLAGS.verbose)
+        if not os.path.exists(cfg.model.path):
+            os.makedirs(cfg.model.path)
+        if not os.path.exists(cfg.logging.tensorboard_dir):
+            os.makedirs(cfg.logging.tensorboard_dir)
+
+        if device == torch.device("cuda"):
+            available_gpus = torch.cuda.device_count()
+            print(f"Available GPUs = {available_gpus}")
+
+            if cfg.hardware.n_gpus is None or cfg.hardware.n_gpus > available_gpus:
+                world_size = available_gpus
+                if cfg.hardware.n_gpus is not None:
+                    print(
+                        f"Warning: The number of GPUs specified ({cfg.hardware.n_gpus}) exceeds the available GPUs ({available_gpus})"
+                    )
+            else:
+                world_size = cfg.hardware.n_gpus
+
+            print(f"Using {world_size}/{available_gpus} GPUs")
+
+            torch.multiprocessing.spawn(
+                train, args=(cfg, world_size, device), nprocs=world_size, join=True
+            )
+        else:
+            train(0, cfg, 1, device)
 
-    elif FLAGS.mode in ["valid", "rollout"]:
-        # Set device
-        predict_par(local_rank)
-        # world_size = torch.cuda.device_count()
-        # if FLAGS.cuda_device_number is not None and torch.cuda.is_available():
-        #     device = torch.device(f'cuda:{int(FLAGS.cuda_device_number)}')
-        #     predict(device)
+    elif cfg.mode in ["valid", "rollout"]:
+        world_size = torch.cuda.device_count()
+        if cfg.hardware.cuda_device_number is not None and torch.cuda.is_available():
+            device = torch.device(f"cuda:{int(cfg.hardware.cuda_device_number)}")
+        print(f"device is {device} world size is {world_size}")
+        # predict_par() relies on torch.distributed collectives
+        # (DistributedSampler, barrier), so a process group must exist even
+        # for this single-process evaluation path.
+        dist.init_process_group(backend="nccl", rank=0, world_size=1)
+        predict_par(0, cfg)
 
 
 if __name__ == "__main__":
-    app.run(main)
+    main()
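
Note: the new entry point pulls all of its settings from a Hydra config
(`@hydra.main(version_base=None, config_path="..", config_name="config")`,
with `gns.args.Config` presumably the matching structured-config schema).
The patch itself does not include `config.yaml`; the sketch below is inferred
from the `cfg.*` keys the new code reads. The key names come from the diff,
the values merely mirror the defaults of the removed absl flags, and entries
marked as placeholders are illustrative only.

    # config.yaml (sketch, not part of this patch)
    mode: train                 # train | valid | rollout
    data:
      path: ./datasets/sand/    # placeholder dataset directory
      batch_size: 2
      noise_std: 6.7e-4
      input_sequence_length: 6  # placeholder for the old INPUT_SEQUENCE_LENGTH constant
      kinematic_particle_id: 3  # placeholder for the old KINEMATIC_PARTICLE_ID constant
    model:
      path: models/
      file: null                # or "latest" to pick the newest checkpoint
      train_state_file: train_state.pt
    output:
      path: rollouts/
      filename: rollout
    training:
      steps: 20000000
      save_steps: 5000
      resume: false
      validation_interval: null # e.g. 10000 to validate every 10k steps
      learning_rate:
        initial: 1.0e-4
        decay: 0.1
        decay_steps: 5000000
    logging:
      tensorboard_dir: logs/    # placeholder
    hardware:
      n_gpus: null              # null = use every visible GPU
      cuda_device_number: null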
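
Hydra also replaces the absl CLI, so any key above can be overridden per run
with key=value arguments. Illustrative invocations (paths, step counts, and
checkpoint names are placeholders, not taken from the patch):

    # Train on all visible GPUs of one node
    python gns/train_multinode.py mode=train data.path=./datasets/sand/

    # Resume from the newest checkpoint
    python gns/train_multinode.py mode=train training.resume=true \
        model.file=latest model.train_state_file=latest

    # Rollout evaluation with a saved checkpoint
    python gns/train_multinode.py mode=rollout model.file=model-2000000.pt

As written, main() binds MASTER_ADDR to localhost and spawns one process per
local GPU, so each invocation forms its own single-node process group; a true
multi-node run would need global ranks and a shared rendezvous address (e.g.
via torchrun) rather than the local values passed to init_process_group here.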