diff --git a/gns/train.py b/gns/train.py index b498d23..1ce329d 100644 --- a/gns/train.py +++ b/gns/train.py @@ -413,24 +413,26 @@ def setup_tensorboard(cfg, metadata): def prepare_data(example, device_id): - """Prepare data for training or validation.""" - position = example[0][0].to(device_id) - particle_type = example[0][1].to(device_id) - - if len(example[0]) == 4: # if data loader includes material_property - material_property = example[0][2].to(device_id) - n_particles_per_example = example[0][3].to(device_id) - elif len(example[0]) == 3: - material_property = None - n_particles_per_example = example[0][2].to(device_id) + features, labels = example + + if len(features) == 4: # If material property is present + position, particle_type, material_property, n_particles_per_example = features else: - raise ValueError("Unexpected number of elements in the data loader") - - labels = example[1].to(device_id) - + position, particle_type, n_particles_per_example = features + material_property = None + + # Convert numpy arrays to tensors + position = torch.from_numpy(position).float().to(device_id) + particle_type = torch.from_numpy(particle_type).long().to(device_id) + if material_property is not None: + material_property = torch.from_numpy(np.array(material_property)).float().to(device_id) + n_particles_per_example = torch.tensor([n_particles_per_example], device=device_id).long() + labels = torch.from_numpy(labels).float().to(device_id) + return position, particle_type, material_property, n_particles_per_example, labels + def train(rank, cfg, world_size, device, verbose, use_dist): """Train the model.