Skip to content

Commit

Permalink
Refactor constants to data
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jun 25, 2024
1 parent 170037b commit 40b4919
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 41 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ data:
path: ../gns-sample/WaterDropSample/dataset/
batch_size: 2
noise_std: 6.7e-4
input_sequence_length: 6
num_particle_types: 9
kinematic_particle_id: 3

# Model configuration
model:
Expand Down Expand Up @@ -115,11 +118,6 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/

constants:
input_sequence_length: 6
num_particle_types: 9
kinematic_particle_id: 3
```
</details>
Expand Down
8 changes: 3 additions & 5 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ data:
path: ../gns-sample/WaterDropSample/dataset/
batch_size: 2
noise_std: 6.7e-4
input_sequence_length: 6
num_particle_types: 9
kinematic_particle_id: 3

# Model configuration
model:
Expand Down Expand Up @@ -47,8 +50,3 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/

constants:
input_sequence_length: 6
num_particle_types: 9
kinematic_particle_id: 3
14 changes: 4 additions & 10 deletions gns/args.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING
from hydra.core.config_store import ConfigStore


@dataclass
class DataConfig:
path: str = MISSING
batch_size: int = 2
noise_std: float = 6.7e-4
input_sequence_length: int = 6
num_particle_types: int = 9
kinematic_particle_id: int = 3


@dataclass
Expand Down Expand Up @@ -50,13 +54,6 @@ class LoggingConfig:
tensorboard_dir: str = "logs/"


@dataclass
class ConstantsConfig:
input_sequence_length: int = 6
num_particle_types: int = 9
kinematic_particle_id: int = 3


@dataclass
class Config:
mode: str = "train"
Expand All @@ -66,11 +63,8 @@ class Config:
training: TrainingConfig = field(default_factory=TrainingConfig)
hardware: HardwareConfig = field(default_factory=HardwareConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
constants: ConstantsConfig = field(default_factory=ConstantsConfig)


# Hydra configuration
from hydra.core.config_store import ConfigStore

cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
58 changes: 37 additions & 21 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from absl import flags
from absl import app

import hydra
from omegaconf import DictConfig, OmegaConf

Expand All @@ -28,13 +25,10 @@

Stats = collections.namedtuple("Stats", ["mean", "std"])

INPUT_SEQUENCE_LENGTH = 6 # So we can calculate the last 5 velocities.
NUM_PARTICLE_TYPES = 9
KINEMATIC_PARTICLE_ID = 3


def rollout(
simulator: learned_simulator.LearnedSimulator,
cfg: DictConfig,
position: torch.tensor,
particle_types: torch.tensor,
material_property: torch.tensor,
Expand All @@ -55,8 +49,8 @@ def rollout(
device: torch device.
"""

initial_positions = position[:, :INPUT_SEQUENCE_LENGTH]
ground_truth_positions = position[:, INPUT_SEQUENCE_LENGTH:]
initial_positions = position[:, : cfg.data.input_sequence_length]
ground_truth_positions = position[:, cfg.data.input_sequence_length :]

current_positions = initial_positions
predictions = []
Expand All @@ -72,7 +66,10 @@ def rollout(

# Update kinematic particles from prescribed trajectory.
kinematic_mask = (
(particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
(particle_types == cfg.data.kinematic_particle_id)
.clone()
.detach()
.to(device)
)
next_position_ground_truth = ground_truth_positions[:, step]
kinematic_mask = kinematic_mask.bool()[:, None].expand(
Expand Down Expand Up @@ -118,7 +115,13 @@ def predict(device: str, cfg: DictConfig):
"""
# Read metadata
metadata = reading_utils.read_metadata(cfg.data.path, "rollout")
simulator = _get_simulator(metadata, cfg.data.noise_std, cfg.data.noise_std, device)
simulator = _get_simulator(
metadata,
cfg.data.num_particle_types,
cfg.data.noise_std,
cfg.data.noise_std,
device,
)

# Load simulator
if os.path.exists(cfg.model.path + cfg.model.file):
Expand Down Expand Up @@ -159,11 +162,11 @@ def predict(device: str, cfg: DictConfig):
positions = features[0].to(device)
if metadata["sequence_length"] is not None:
# If `sequence_length` is predefined in metadata,
nsteps = metadata["sequence_length"] - INPUT_SEQUENCE_LENGTH
nsteps = metadata["sequence_length"] - cfg.data.input_sequence_length
else:
# If no predefined `sequence_length`, then get the sequence length
sequence_length = positions.shape[1]
nsteps = sequence_length - INPUT_SEQUENCE_LENGTH
nsteps = sequence_length - cfg.data.input_sequence_length
particle_type = features[1].to(device)
if material_property_as_feature:
material_property = features[2].to(device)
Expand All @@ -179,6 +182,7 @@ def predict(device: str, cfg: DictConfig):
# Predict example rollout
example_rollout, loss = rollout(
simulator,
cfg,
positions,
particle_type,
material_property,
Expand Down Expand Up @@ -304,7 +308,11 @@ def train(rank, cfg, world_size, device):
# Get simulator and optimizer
if device == torch.device("cuda"):
serial_simulator = _get_simulator(
metadata, cfg.data.noise_std, cfg.data.noise_std, rank
metadata,
cfg.data.num_particle_types,
cfg.data.noise_std,
cfg.data.noise_std,
rank,
)
simulator = DDP(
serial_simulator.to(rank), device_ids=[rank], output_device=rank
Expand All @@ -314,7 +322,11 @@ def train(rank, cfg, world_size, device):
)
else:
simulator = _get_simulator(
metadata, cfg.data.noise_std, cfg.data.noise_std, device
metadata,
cfg.data.num_particle_types,
cfg.data.noise_std,
cfg.data.noise_std,
device,
)
optimizer = torch.optim.Adam(
simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
Expand Down Expand Up @@ -391,7 +403,7 @@ def train(rank, cfg, world_size, device):
# Load training data
dl = get_data_loader(
path=f"{cfg.data.path}train.npz",
input_length_sequence=INPUT_SEQUENCE_LENGTH,
input_length_sequence=cfg.data.input_sequence_length,
batch_size=cfg.data.batch_size,
)
n_features = len(dl.dataset._data[0])
Expand All @@ -400,7 +412,7 @@ def train(rank, cfg, world_size, device):
if cfg.training.validation_interval is not None:
dl_valid = get_data_loader(
path=f"{cfg.data.path}valid.npz",
input_length_sequence=INPUT_SEQUENCE_LENGTH,
input_length_sequence=cfg.data.input_sequence_length,
batch_size=cfg.data.batch_size,
)
if len(dl_valid.dataset._data[0]) != n_features:
Expand Down Expand Up @@ -464,7 +476,7 @@ def train(rank, cfg, world_size, device):
).to(device_id)
)
non_kinematic_mask = (
(particle_type != KINEMATIC_PARTICLE_ID)
(particle_type != cfg.data.kinematic_particle_id)
.clone()
.detach()
.to(device_id)
Expand Down Expand Up @@ -619,7 +631,11 @@ def train(rank, cfg, world_size, device):


def _get_simulator(
metadata: json, acc_noise_std: float, vel_noise_std: float, device: torch.device
metadata: json,
num_particle_types: int,
acc_noise_std: float,
vel_noise_std: float,
device: torch.device,
) -> learned_simulator.LearnedSimulator:
"""Instantiates the simulator.
Expand Down Expand Up @@ -668,7 +684,7 @@ def _get_simulator(
connectivity_radius=metadata["default_connectivity_radius"],
boundaries=np.array(metadata["bounds"]),
normalization_stats=normalization_stats,
nparticle_types=NUM_PARTICLE_TYPES,
nparticle_types=num_particle_types,
particle_type_embedding_size=16,
boundary_clamp_limit=metadata["boundary_augment"]
if "boundary_augment" in metadata
Expand Down Expand Up @@ -696,7 +712,7 @@ def validation(simulator, example, n_features, cfg, rank, device_id):
position, noise_std_last_step=cfg.data.noise_std
).to(device_id)
non_kinematic_mask = (
(particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
(particle_type != cfg.data.kinematic_particle_id).clone().detach().to(device_id)
)
sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

Expand Down

0 comments on commit 40b4919

Please sign in to comment.