Skip to content

Commit

Permalink
improve vtk renderer to handle material properties
Browse files Browse the repository at this point in the history
  • Loading branch information
yjchoi1 committed Sep 23, 2024
1 parent 4950aa9 commit 9d9b116
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions gns/render_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np
import os
from pyevtk.hl import pointsToVTK
from pyevtk.hl import pointsToVTK, gridToVTK

flags.DEFINE_string("rollout_dir", None, help="Directory where rollout.pkl are located")
flags.DEFINE_string("rollout_name", None, help="Name of rollout `.pkl` file")
Expand Down Expand Up @@ -254,8 +254,39 @@ def write_vtk(self):
if not os.path.exists(path):
os.makedirs(path)
initial_position = self.trajectory[rollout_case][0]

# Extract boundary information
[[x_min, x_max], [y_min, y_max]] = self.boundaries

for i, coord in enumerate(self.trajectory[rollout_case]):
disp = np.linalg.norm(coord - initial_position, axis=1)

# Extract particle type
particle_type = self.rollout_data["particle_types"]

# Prepare data dictionary
data = {
"displacement": disp,
"particle_type": particle_type,
}

# Check if material property exists and add it to data if it does
if "material_property" in self.rollout_data:
material_property = self.rollout_data["material_property"]
data["material_property"] = material_property

# Create a color field based on material property and particle type
color_field = np.copy(material_property)
static_particle_value = np.max(
material_property) + 1 # Use a value outside the material property range
color_field[particle_type == 3] = static_particle_value # Assumes static particle type = 3
data["color"] = color_field
else:
# If no material property, use particle type for color
color_field = np.copy(particle_type)
data["color"] = color_field

# Save particle data
pointsToVTK(
f"{path}/points{i}",
np.array(coord[:, 0]),
Expand All @@ -265,8 +296,20 @@ def write_vtk(self):
if self.dims == 2
else np.array(coord[:, 2])
),
data={"displacement": disp},
data=data,
)

# Create and save boundary data
x = np.linspace(x_min, x_max, num=2)
y = np.linspace(y_min, y_max, num=2)
z = np.array([0, 0]) if self.dims == 2 else np.linspace(0, 1, num=2)

gridToVTK(
f"{path}/boundary{i}",
x, y, z,
cellData={"boundary": np.ones((1, 1, 1))}
)

print(f"vtk saved to: {self.output_dir}{self.output_name}...")


Expand Down

0 comments on commit 9d9b116

Please sign in to comment.