Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

render rollout merge #87

Merged
merged 26 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
18f3b7c
render rollout merge
Naveen-Raj-M Sep 9, 2024
e22406a
Update config.yaml
Naveen-Raj-M Sep 24, 2024
23bfe0c
Merge branch 'v2' of https://github.com/geoelements/gns into v2
Naveen-Raj-M Oct 1, 2024
8d5d7b4
Added unit test for render-rollout merge
Naveen-Raj-M Oct 1, 2024
2f9ab01
Merge branch 'v2' of https://github.com/Naveen-Raj-M/gns into v2
Naveen-Raj-M Oct 1, 2024
5146e57
deleted debugging fixtures
Naveen-Raj-M Oct 2, 2024
a115cb0
update config
Naveen-Raj-M Oct 4, 2024
2bcf615
add test for VTK rendering
Naveen-Raj-M Oct 4, 2024
cdf6f45
bug fix for NoneType material property
Naveen-Raj-M Oct 4, 2024
f9a2a9e
remove test_rendering and temp directory
Naveen-Raj-M Oct 4, 2024
57fad31
add test for vtk rendering
Naveen-Raj-M Oct 5, 2024
b387599
modify config for render_rollout merge
Naveen-Raj-M Oct 5, 2024
db3aec3
update to merge render-rollout
Naveen-Raj-M Oct 5, 2024
74b1e43
set default mode to gif
Naveen-Raj-M Oct 5, 2024
e08fcf5
rewrite 'rendering' function in an extensible way
Naveen-Raj-M Oct 10, 2024
3b88f65
improve readability and consistency
Naveen-Raj-M Oct 10, 2024
9fbafbb
update rendering options
Naveen-Raj-M Oct 10, 2024
81bc3f0
run black
Oct 11, 2024
f82c674
minor fix on viewpoint_rotation type
Oct 12, 2024
a861ec4
improve logging and reformat with black
Oct 12, 2024
7110352
refactor: move n_files function to a separate count_n_files.py in uti…
Oct 12, 2024
a86e58f
rename count_n_files.py to file_utils.py
Naveen-Raj-M Oct 13, 2024
b72770a
minor fix on module import
Naveen-Raj-M Oct 14, 2024
94c8519
run black
Naveen-Raj-M Oct 20, 2024
037b020
minor fix on raising error
Naveen-Raj-M Oct 22, 2024
ef41a87
add package for reading vtk files
Naveen-Raj-M Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@ python3 -m gns.train mode="train" training.resume=True
python3 -m meshnet.train mode="train" training.resume=True
```

> Rollout prediction
> Rollout prediction and render
```shell
# For particulate domain,
python3 -m gns.train mode="rollout"
# For mesh-based domain,
python3 -m meshnet.train mode="rollout"
```
### Rendering Options
Set rendering mode with `rendering.mode=<option>`
`null`: Disables rendering
`gif`: Creates a .gif file (default)
`vtk`: Writes .vtu files for ParaView visualization
Example: `rendering.mode=null` to disable rendering

> Render
```shell
Expand All @@ -50,8 +56,6 @@ python3 -m gns.render_rollout --output_mode="gif" --rollout_dir="<path-containin
python3 -m gns.render --rollout_dir="<path-containing-rollout-file>" --rollout_name="<name-of-rollout-file>"
```

In particulate domain, the renderer also writes `.vtu` files to visualize in ParaView.

![Sand rollout](docs/img/rollout_0.gif)
> GNS prediction of Sand rollout after training for 2 million steps.

Expand Down Expand Up @@ -118,8 +122,16 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/
```

# Rendering configuration
rendering:
mode: gif

gif:
step_stride: 3
vertical_camera_angle: 20
viewpoint_rotation: 0.3
change_yz: False
</details>


Expand Down Expand Up @@ -180,7 +192,6 @@ The total number of training steps to execute before stopping.
**nsave_steps (Integer)**

Interval at which the model and training state are saved.

</details>

## Datasets
Expand Down
10 changes: 10 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/

# Rendering configuration
rendering:
mode: gif
kks32 marked this conversation as resolved.
Show resolved Hide resolved

gif:
step_stride: 3
vertical_camera_angle: 20
viewpoint_rotation: 0.3
change_yz: False
14 changes: 13 additions & 1 deletion gns/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ class HardwareConfig:
class LoggingConfig:
tensorboard_dir: str = "logs/"

@dataclass
class GifConfig:
step_stride: int = 3
vertical_camera_angle: int = 20
viewpoint_rotation: int = 0.3
change_yz: bool = False

@dataclass
class RenderingConfig:
mode: Optional[str] = field(default='gif')
gif: GifConfig = field(default_factory=GifConfig)

@dataclass
class Config:
Expand All @@ -62,8 +73,9 @@ class Config:
training: TrainingConfig = field(default_factory=TrainingConfig)
hardware: HardwareConfig = field(default_factory=HardwareConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
rendering: RenderingConfig = field(default_factory=RenderingConfig)


# Hydra configuration
cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(name="base_config", node=Config)
kks32 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion gns/render_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def write_vtk(self):
}

# Check if material property exists and add it to data if it does
if "material_property" in self.rollout_data:
if "material_property" in self.rollout_data and self.rollout_data['material_property'] is not None:
material_property = self.rollout_data["material_property"]
data["material_property"] = material_property

Expand Down
42 changes: 40 additions & 2 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gns import reading_utils
from gns import particle_data_loader as pdl
from gns import distribute
from gns import render_rollout
from gns.args import Config

Stats = collections.namedtuple("Stats", ["mean", "std"])
Expand Down Expand Up @@ -200,14 +201,51 @@ def predict(device: str, cfg: DictConfig):
example_rollout["metadata"] = metadata
example_rollout["loss"] = loss.mean()
filename = f"{cfg.output.filename}_ex{example_i}.pkl"
filename = os.path.join(cfg.output.path, filename)
filename_render = f"{cfg.output.filename}_ex{example_i}"
filename = os.path.join(cfg.output.path, filename_render)
with open(filename, "wb") as f:
pickle.dump(example_rollout, f)
if cfg.rendering.mode:
rendering(cfg.output.path, filename_render, cfg)

print(
"Mean loss on rollout prediction: {}".format(torch.mean(torch.cat(eval_loss)))
)

def rendering(input_dir, input_name, cfg: DictConfig):
kks32 marked this conversation as resolved.
Show resolved Hide resolved
"""
Render output based on the specified configuration and input parameters.
It supports rendering in both GIF and VTK formats.

Args:
input_dir (str): The directory containing the input files for rendering.
input_name (str): The base name of the input file to be rendered.
cfg (DictConfig): The configuration dictionary that specifies rendering options,
including the rendering mode and relevant parameters for that mode.

Raises:
ValueError: If the specified rendering mode is not supported.
"""
render = render_rollout.Render(input_dir, input_name)
kks32 marked this conversation as resolved.
Show resolved Hide resolved

rendering_modes = {
'gif' : lambda: render.render_gif_animation(
point_size=1,
timestep_stride=cfg.rendering.gif.step_stride,
vertical_camera_angle=cfg.rendering.gif.vertical_camera_angle,
viewpoint_rotation=cfg.rendering.gif.viewpoint_rotation,
change_yz=cfg.rendering.gif.change_yz,
),
'vtk' : lambda: render.write_vtk()

}

if cfg.rendering.mode in ['gif', 'vtk']:
rendering_mode = rendering_modes.get(cfg.rendering.mode)
rendering_mode()

else:
raise ValueError(f"Unsupported rendering mode: {cfg.rendering.mode}")

def optimizer_to(optim, device):
for param in optim.state.values():
Expand Down Expand Up @@ -850,4 +888,4 @@ def main(cfg: Config):


if __name__ == "__main__":
main()
main()
Loading