Skip to content

Commit

Permalink
render rollout merge (#87)
Browse files Browse the repository at this point in the history
* render rollout merge

* Update config.yaml

* Added unit test for render-rollout merge

* deleted debugging fixtures

* update config

* add test for VTK rendering

* bug fix for NoneType material property

* remove test_rendering and temp directory

* add test for vtk rendering

* modify config for render_rollout merge

* update to merge render-rollout

* set default mode to gif

* rewrite 'rendering' function in an extensible way

* improve readability and consistency

* update rendering options

* run black

* minor fix on viewpoint_rotation type

* improve logging and reformat with black

* refactor: move n_files function to a separate count_n_files.py in utils directory

* rename count_n_files.py to file_utils.py

* minor fix on module import

* run black

* minor fix on raising error

* add package for reading vtk files

---------

Co-authored-by: Naveen Raj Manoharan <[email protected]>
Co-authored-by: Naveen Raj Manoharan <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2024
1 parent 01d815a commit 7190e4a
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 7 deletions.
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@ python3 -m gns.train mode="train" training.resume=True
python3 -m meshnet.train mode="train" training.resume=True
```

> Rollout prediction
> Rollout prediction and render
```shell
# For particulate domain,
python3 -m gns.train mode="rollout"
# For mesh-based domain,
python3 -m meshnet.train mode="rollout"
```
### Rendering Options
Set rendering mode with `rendering.mode=<option>`
`null`: Disables rendering
`gif`: Creates a .gif file (default)
`vtk`: Writes .vtu files for ParaView visualization
Example: `rendering.mode=null` to disable rendering

> Render
```shell
Expand All @@ -50,8 +56,6 @@ python3 -m gns.render_rollout --output_mode="gif" --rollout_dir="<path-containin
python3 -m gns.render --rollout_dir="<path-containing-rollout-file>" --rollout_name="<name-of-rollout-file>"
```

In particulate domain, the renderer also writes `.vtu` files to visualize in ParaView.

![Sand rollout](docs/img/rollout_0.gif)
> GNS prediction of Sand rollout after training for 2 million steps.
Expand Down Expand Up @@ -118,8 +122,16 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/
```

# Rendering configuration
rendering:
mode: gif

gif:
step_stride: 3
vertical_camera_angle: 20
viewpoint_rotation: 0.3
change_yz: False
</details>


Expand Down Expand Up @@ -180,7 +192,6 @@ The total number of training steps to execute before stopping.
**nsave_steps (Integer)**

Interval at which the model and training state are saved.

</details>

## Datasets
Expand Down
10 changes: 10 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/

# Rendering configuration
rendering:
mode: gif

gif:
step_stride: 3
vertical_camera_angle: 20
viewpoint_rotation: 0.3
change_yz: False
15 changes: 15 additions & 0 deletions gns/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ class LoggingConfig:
tensorboard_dir: str = "logs/"


@dataclass
class GifConfig:
step_stride: int = 3
vertical_camera_angle: int = 20
viewpoint_rotation: float = 0.3
change_yz: bool = False


@dataclass
class RenderingConfig:
mode: Optional[str] = field(default="gif")
gif: GifConfig = field(default_factory=GifConfig)


@dataclass
class Config:
mode: str = "train"
Expand All @@ -62,6 +76,7 @@ class Config:
training: TrainingConfig = field(default_factory=TrainingConfig)
hardware: HardwareConfig = field(default_factory=HardwareConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
rendering: RenderingConfig = field(default_factory=RenderingConfig)


# Hydra configuration
Expand Down
5 changes: 4 additions & 1 deletion gns/render_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ def write_vtk(self):
}

# Check if material property exists and add it to data if it does
if "material_property" in self.rollout_data:
if (
"material_property" in self.rollout_data
and self.rollout_data["material_property"] is not None
):
material_property = self.rollout_data["material_property"]
data["material_property"] = material_property

Expand Down
41 changes: 40 additions & 1 deletion gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gns import reading_utils
from gns import particle_data_loader as pdl
from gns import distribute
from gns import render_rollout
from gns.args import Config

Stats = collections.namedtuple("Stats", ["mean", "std"])
Expand Down Expand Up @@ -200,15 +201,53 @@ def predict(device: str, cfg: DictConfig):
example_rollout["metadata"] = metadata
example_rollout["loss"] = loss.mean()
filename = f"{cfg.output.filename}_ex{example_i}.pkl"
filename = os.path.join(cfg.output.path, filename)
filename_render = f"{cfg.output.filename}_ex{example_i}"
filename = os.path.join(cfg.output.path, filename_render)
with open(filename, "wb") as f:
pickle.dump(example_rollout, f)
if cfg.rendering.mode:
rendering(cfg.output.path, filename_render, cfg)

print(
"Mean loss on rollout prediction: {}".format(torch.mean(torch.cat(eval_loss)))
)


def rendering(input_dir, input_name, cfg: DictConfig):
"""
Render output based on the specified configuration and input parameters.
It supports rendering in both GIF and VTK formats.
Args:
input_dir (str): The directory containing the input files for rendering.
input_name (str): The base name of the input file to be rendered.
cfg (DictConfig): The configuration dictionary that specifies rendering options,
including the rendering mode and relevant parameters for that mode.
Raises:
ValueError: If the specified rendering mode is not supported.
"""
render = render_rollout.Render(input_dir, input_name)

rendering_modes = {
"gif": lambda: render.render_gif_animation(
point_size=1,
timestep_stride=cfg.rendering.gif.step_stride,
vertical_camera_angle=cfg.rendering.gif.vertical_camera_angle,
viewpoint_rotation=cfg.rendering.gif.viewpoint_rotation,
change_yz=cfg.rendering.gif.change_yz,
),
"vtk": lambda: render.write_vtk(),
}

if cfg.rendering.mode in ["gif", "vtk"]:
rendering_mode = rendering_modes.get(cfg.rendering.mode)
rendering_mode()

else:
raise ValueError(f"Unsupported rendering mode: {cfg.rendering.mode}")


def optimizer_to(optim, device):
for param in optim.state.values():
# Not sure there are any global tensors in the state dict
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ torch_scatter
torch-cluster
tqdm
toml
pyvista
Loading

0 comments on commit 7190e4a

Please sign in to comment.