Fix resume training and README
kks32 committed Jun 25, 2024
1 parent 5329b78 commit a97cd5e
Showing 5 changed files with 71 additions and 108 deletions.
8 changes: 0 additions & 8 deletions .circleci/config.yml
@@ -8,16 +8,8 @@ jobs:
- run:
name: Train & Test
command: |
TMP_DIR="../gns-sample"
DATASET_NAME="WaterDropSample"
git clone https://github.com/geoelements/gns-sample ../gns-sample
mkdir -p ${TMP_DIR}/${DATASET_NAME}/models/
mkdir -p ${TMP_DIR}/${DATASET_NAME}/rollout/
DATA_PATH="${TMP_DIR}/${DATASET_NAME}/dataset/"
MODEL_PATH="${TMP_DIR}/${DATASET_NAME}/models/"
ROLLOUT_PATH="${TMP_DIR}/${DATASET_NAME}/rollout/"
pytest test/
echo "Test paths: ${DATA_PATH} ${MODEL_PATH}"
python -m gns.train
ls ../gns-sample/WaterDropSample/models/
6 changes: 0 additions & 6 deletions .github/workflows/train.yml
@@ -23,10 +23,4 @@ jobs:
TMP_DIR="../gns-sample"
DATASET_NAME="WaterDropSample"
git clone https://github.com/geoelements/gns-sample ../gns-sample
mkdir -p ${TMP_DIR}/${DATASET_NAME}/models/
mkdir -p ${TMP_DIR}/${DATASET_NAME}/rollout/
DATA_PATH="${TMP_DIR}/${DATASET_NAME}/dataset/"
MODEL_PATH="${TMP_DIR}/${DATASET_NAME}/models/"
ROLLOUT_PATH="${TMP_DIR}/${DATASET_NAME}/rollout/"
echo "Test paths: ${DATA_PATH} ${MODEL_PATH}"
python -m gns.train
151 changes: 63 additions & 88 deletions README.md
@@ -18,7 +18,7 @@ MeshNet is a scalable surrogate simulator for any mesh-based models like Finite
> Training GNS/MeshNet on simulation data
```shell
# For particulate domain,
python3 -m gns.train --data_path="<input-training-data-path>" --model_path="<path-to-load-save-model-file>" --ntraining_steps=100
python3 -m gns.train mode="train" --config-path ./ --config-name config.yaml
# For mesh-based domain,
python3 -m meshnet.train --data_path="<input-training-data-path>" --model_path="<path-to-load-save-model-file>" --ntraining_steps=100
```
@@ -29,15 +29,15 @@ To resume training specify `model_file` and `train_state_file`:

```shell
# For particulate domain,
python3 -m gns.train --data_path="<input-training-data-path>" --model_path="<path-to-load-save-model-file>" --model_file="model.pt" --train_state_file="train_state.pt" --ntraining_steps=100
python3 -m gns.train mode="train" training.resume=True
# For mesh-based domain,
python3 -m meshnet.train --data_path="<input-training-data-path>" --model_path="<path-to-load-save-model-file>" --model_file="model.pt" --train_state_file="train_state.pt" --ntraining_steps=100
```

> Rollout prediction
```shell
# For particulate domain,
python3 -m gns.train --mode="rollout" --data_path="<input-data-path>" --model_path="<path-to-load-save-model-file>" --output_path="<path-to-save-output>" --model_file="model.pt" --train_state_file="train_state.pt"
python3 -m gns.train mode="rollout"
# For mesh-based domain,
python3 -m meshnet.train --mode="rollout" --data_path="<input-data-path>" --model_path="<path-to-load-save-model-file>" --output_path="<path-to-save-output>" --model_file="model.pt" --train_state_file="train_state.pt"
```
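The `gns.train` invocations in this hunk now differ only in the `key=value` overrides they pass (`mode="train"`, `training.resume=True`, `mode="rollout"`). As a rough illustration of how such dotted overrides map onto the nested keys of `config.yaml`, here is a minimal stdlib-only sketch. `apply_overrides` and `parse_value` are hypothetical helpers written for this example; Hydra's actual override parser is considerably more elaborate and is not reproduced here.

```python
import copy


def parse_value(raw):
    """Convert an override value string to bool, None, int, float, or str.

    Simplified assumption for illustration; not Hydra's real grammar.
    """
    text = raw.strip().strip("\"'")
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            continue
    return text


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk (or create) each intermediate mapping.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw_value)
    return result


if __name__ == "__main__":
    base = {"mode": "train", "training": {"steps": 2000, "resume": False}}
    print(apply_overrides(base, ["mode=rollout", "training.resume=True"]))
```

The base mapping is copied rather than mutated, so the defaults stay available for comparison after the overrides are applied.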
@@ -61,91 +61,66 @@ In mesh-based domain, the renderer writes `.gif` animation.
> Meshnet GNS prediction of cylinder flow after training for 1 million steps.

## Command line arguments details
## Configuration file
<details>
<summary>`train.py` in GNS (particulate domain) </summary>

**mode (Enum)**

This flag is used to set the operation mode for the script. It can take one of three values; 'train', 'valid', or 'rollout'.

**batch_size (Integer)**

Batch size for training.

**noise_std (Float)**

Standard deviation of the noise when training.

**data_path (String)**

Specifies the directory path where the dataset is located.
The dataset is expected to be in a specific format (e.g., .npz files).
It should contain `metadata.json`.
If `--mode` is training, the directory should contain `train.npz`.
If `--mode` is testing (rollout), the directory should contain `test.npz`.
If `--mode` is valid, the directory should contain `valid.npz`.

**model_path (String)**

The directory path where the trained model checkpoints are saved during training or loaded from during validation/rollout.

**output_path (String)**

Defines the directory where the outputs (e.g., rollouts) are saved,
when the `--mode` is set to rollout.
This is particularly relevant in the rollout mode where the predictions of the model are stored.

**output_filename (String)**

Base filename to use when saving outputs during rollout.
Default is "rollout", and the output will be saved as `rollout.pkl` in `output_path`.
It is not intended to include the file extension.

**model_file (String)**

The filename of the model checkpoint to load for validation or rollout (e.g., model-10000.pt).
It supports a special value "latest" to automatically select the newest checkpoint file.
This flexibility facilitates the evaluation of models at different stages of training.

**train_state_file (String)**

Similar to model_file, but for loading the training state (e.g., optimizer state).
It supports a special value "latest" to automatically select the newest checkpoint file.
(e.g., training_state-10000.pt)

**ntraining_steps (Integer)**

The total number of training steps to execute before stopping.

**nsave_steps (Integer)**

Interval at which the model and training state are saved.

**lr_init (Float)**

Initial learning rate.

**lr_decay (Float)**

How much the learning rate should decay over time.

**lr_decay_steps (Integer)**

Steps at which learning rate should decay.

**cuda_device_number (Integer)**

Base CUDA device (zero indexed).
Default is None so default CUDA device will be used.

**n_gpus (Integer)**

Number of GPUs to use for training.

**tensorboard_log_dir (String)**

Path to log info on training and validation and visualize via tensorboard.
<summary>GNS (particulate domain) </summary>

```yaml
defaults:
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

hydra:
  output_subdir: null
  run:
    dir: .

# Top-level configuration
mode: train

# Data configuration
data:
  path: ../gns-sample/WaterDropSample/dataset/
  batch_size: 2
  noise_std: 6.7e-4

# Model configuration
model:
  path: ../gns-sample/WaterDropSample/models/
  file: null
  train_state_file: null

# Output configuration
output:
  path: ../gns-sample/WaterDropSample/rollouts/
  filename: rollout

# Training configuration
training:
  steps: 2000
  validation_interval: null
  save_steps: 500
  resume: False
  learning_rate:
    initial: 1e-4
    decay: 0.1
    decay_steps: 50000

# Hardware configuration
hardware:
  cuda_device_number: null
  n_gpus: 1

# Logging configuration
logging:
  tensorboard_dir: logs/

constants:
  input_sequence_length: 6
  num_particle_types: 9
  kinematic_particle_id: 3
```
</details>
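
The `training.learning_rate` block carries three values (`initial`, `decay`, `decay_steps`), which the removed README text described as the initial rate, how much it decays, and the steps over which it decays. This commit does not show how `train.py` consumes them, so the following is only a sketch under the assumption of a continuous exponential schedule; the function name `learning_rate` is hypothetical.

```python
def learning_rate(step, initial=1e-4, decay=0.1, decay_steps=50000):
    """Exponentially decayed learning rate (assumed schedule, for illustration).

    The rate is multiplied by `decay` once every `decay_steps` steps,
    interpolated smoothly in between.
    """
    return initial * decay ** (step / decay_steps)


if __name__ == "__main__":
    for step in (0, 2000, 50000):
        print(step, learning_rate(step))
```

Under this assumption the rate drops by a factor of ten after 50000 steps, so a 2000-step run like the one configured here would finish with a rate only about 9% below the initial value.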
@@ -254,7 +229,7 @@ The dataset is shared on [DesignSafe DataDepot](https://doi.org/10.17603/ds2-fzg
GNS uses [pytorch geometric](https://www.pyg.org/) and [CUDA](https://developer.nvidia.com/cuda-downloads). These packages have specific requirements, please see [PyG installation]((https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) for details.
> CPU-only installation on Linux
> CPU-only installation on Linux (Conda)
```shell
conda install -y pytorch torchvision torchaudio cpuonly -c pytorch
10 changes: 5 additions & 5 deletions config.yaml
@@ -19,18 +19,18 @@ data:

# Model configuration
model:
path: models/
file: None
train_state_file: None
path: ../gns-sample/WaterDropSample/models/
file: model-1500.pt
train_state_file: train_state-1500.pt

# Output configuration
output:
path: rollouts/
path: ../gns-sample/WaterDropSample/rollouts/
filename: rollout

# Training configuration
training:
steps: 10
steps: 2000
validation_interval: null
save_steps: 500
resume: False
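
The new defaults point at `model-1500.pt` and `train_state-1500.pt`, which suggests checkpoints are named `<prefix>-<step>.pt` and, with `save_steps: 500`, written every 500 steps. When filling in these values by hand, the newest checkpoint can be picked by comparing the step numbers numerically. The helper below is a hypothetical sketch based only on that naming pattern; it is not code from `gns/train.py`.

```python
import re


def latest_checkpoint(filenames, prefix):
    """Return the `<prefix>-<step>.pt` name with the highest step, or None."""
    pattern = re.compile(rf"^{re.escape(prefix)}-(\d+)\.pt$")
    best_name, best_step = None, -1
    for name in filenames:
        match = pattern.match(name)
        # Compare steps as integers so that 1000 sorts after 500.
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name


if __name__ == "__main__":
    files = ["model-500.pt", "model-1000.pt", "model-1500.pt",
             "train_state-500.pt", "train_state-1500.pt"]
    print(latest_checkpoint(files, "model"))
    print(latest_checkpoint(files, "train_state"))
```

Comparing integers rather than strings matters here because a lexicographic sort would rank `model-500.pt` above `model-1500.pt`.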
4 changes: 3 additions & 1 deletion gns/train.py
@@ -429,7 +429,9 @@ def train(rank, cfg, world_size, device):
writer.add_hparams(hparam_dict, metric_dict)

try:
num_epochs = max(1, cfg.training.steps // len(dl)) # Calculate total epochs
num_epochs = max(
1, (cfg.training.steps + len(dl) - 1) // len(dl)
) # Calculate total epochs
print(f"Total epochs = {num_epochs}")
for epoch in tqdm(range(epoch, num_epochs), desc="Training", unit="epoch"):
if device == torch.device("cuda"):
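
The change to `num_epochs` swaps floor division for ceiling division. Assuming `len(dl)` is the number of batches per epoch, the old expression could stop short of `cfg.training.steps` whenever the step count was not an exact multiple of the epoch length. The standalone sketch below compares the two formulas; `epochs_floor` and `epochs_ceil` are hypothetical names used only for this illustration.

```python
def epochs_floor(steps, batches_per_epoch):
    """Old behaviour: round down, but run at least one epoch."""
    return max(1, steps // batches_per_epoch)


def epochs_ceil(steps, batches_per_epoch):
    """New behaviour: round up so the epochs cover every requested step."""
    return max(1, (steps + batches_per_epoch - 1) // batches_per_epoch)


if __name__ == "__main__":
    steps, batches = 2000, 300
    print(epochs_floor(steps, batches))  # 6 epochs -> only 1800 steps
    print(epochs_ceil(steps, batches))   # 7 epochs -> 2100 steps >= 2000
```

When the step count divides evenly, both formulas agree, so the change only affects runs that previously ended one partial epoch early.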