def prepare_data(example, device_id):
    """Unpack one data-loader batch and move its elements to ``device_id``.

    Args:
        example: Indexable batch where ``example[0]`` is the feature sequence
            and ``example[1]`` is the labels. The feature sequence is either
            ``(position, particle_type, n_particles_per_example)`` or
            ``(position, particle_type, material_property,
            n_particles_per_example)``.
        device_id: Target passed unchanged to each element's ``.to()``.

    Returns:
        The 5-tuple ``(position, particle_type, material_property,
        n_particles_per_example, labels)``. ``material_property`` is ``None``
        when the feature sequence has only three elements.

    Raises:
        ValueError: If the feature sequence has neither three nor four
            elements.
    """
    features = example[0]
    n_elements = len(features)

    # Validate the layout before touching any element, so a malformed batch
    # always fails with ValueError (never an IndexError from a short
    # sequence) and nothing is transferred to the device for a batch that is
    # about to be rejected.
    if n_elements not in (3, 4):
        raise ValueError("Unexpected number of elements in the data loader")

    position = features[0].to(device_id)
    particle_type = features[1].to(device_id)

    if n_elements == 4:  # data loader includes material_property
        material_property = features[2].to(device_id)
        n_particles_per_example = features[3].to(device_id)
    else:
        material_property = None
        n_particles_per_example = features[2].to(device_id)

    labels = example[1].to(device_id)

    return (
        position,
        particle_type,
        material_property,
        n_particles_per_example,
        labels,
    )
position, + particle_type, + material_property, + n_particles_per_example, + labels, + ) = prepare_data(example, device_id) n_particles_per_example = n_particles_per_example.to(device_id) labels = labels.to(device_id) @@ -739,16 +754,10 @@ def _get_simulator( def validation(simulator, example, n_features, cfg, rank, device_id): - position = example[0][0].to(device_id) - particle_type = example[0][1].to(device_id) - if n_features == 3: # if dl includes material_property - material_property = example[0][2].to(device_id) - n_particles_per_example = example[0][3].to(device_id) - elif n_features == 2: - n_particles_per_example = example[0][2].to(device_id) - else: - raise NotImplementedError - labels = example[1].to(device_id) + + position, particle_type, material_property, n_particles_per_example, labels = ( + prepare_data(example, device_id) + ) # Sample the noise to add to the inputs. sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(