Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor for SMPL example, reduce number of spawned threads #189

Merged
merged 2 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/examples/01_image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ NeRFs), or images to render as 3D textures.

import imageio.v3 as iio
import numpy as onp

import viser


Expand Down
23 changes: 22 additions & 1 deletion docs/source/examples/02_gui.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ Examples of basic GUI elements that we can create, read from, and write to.
import time

import numpy as onp

import viser


Expand Down Expand Up @@ -69,6 +68,21 @@ Examples of basic GUI elements that we can create, read from, and write to.
"Color",
initial_value=(255, 255, 0),
)
gui_multi_slider = server.add_gui_multi_slider(
"Multi slider",
min=0,
max=100,
step=1,
initial_value=(0, 30, 100),
)
gui_slider_positions = server.add_gui_slider(
"# sliders",
min=0,
max=10,
step=1,
initial_value=3,
marks=((0, "0"), (5, "5"), (7, "7"), 10),
)

# Pre-generate a point cloud to send.
point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3))
Expand All @@ -92,13 +106,20 @@ Examples of basic GUI elements that we can create, read from, and write to.
* color_coeffs[:, None]
).astype(onp.uint8),
position=gui_vector2.value + (0,),
point_shape="circle",
)

# We can use `.visible` and `.disabled` to toggle GUI elements.
gui_text.visible = not gui_checkbox_hide.value
gui_button.visible = not gui_checkbox_hide.value
gui_rgb.disabled = gui_checkbox_disable.value

# Update the number of handles in the multi-slider.
if gui_slider_positions.value != len(gui_multi_slider.value):
gui_multi_slider.value = onp.linspace(
0, 100, gui_slider_positions.value, dtype=onp.int64
)

counter += 1
time.sleep(0.01)

Expand Down
3 changes: 1 addition & 2 deletions docs/source/examples/03_gui_callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ we get updates.
import time

import numpy as onp
from typing_extensions import assert_never

import viser
from typing_extensions import assert_never


def main() -> None:
Expand Down
1 change: 0 additions & 1 deletion docs/source/examples/05_camera_commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ corresponding client automatically.
import time

import numpy as onp

import viser
import viser.transforms as tf

Expand Down
1 change: 0 additions & 1 deletion docs/source/examples/06_mesh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ Visualize a mesh. To get the demo data, see ``./assets/download_dragon_mesh.sh``

import numpy as onp
import trimesh

import viser
import viser.transforms as tf

Expand Down
3 changes: 1 addition & 2 deletions docs/source/examples/07_record3d_visualizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa

import numpy as onp
import tyro
from tqdm.auto import tqdm

import viser
import viser.extras
import viser.transforms as tf
from tqdm.auto import tqdm


def main(
Expand Down
272 changes: 272 additions & 0 deletions docs/source/examples/08_smpl_visualizer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
.. Comment: this file is automatically generated by `update_example_docs.py`.
It should not be modified manually.

Visualizer for SMPL human body models. Requires a .npz model file.
==========================================


See here for download instructions:
https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model



.. code-block:: python
:linenos:


from __future__ import annotations

import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple

import numpy as np
import numpy as onp
import tyro
import viser
import viser.transforms as tf


@dataclass(frozen=True)
class SmplOutputs:
vertices: np.ndarray
faces: np.ndarray
T_world_joint: np.ndarray # (num_joints, 4, 4)
T_parent_joint: np.ndarray # (num_joints, 4, 4)


class SmplHelper:
"""Helper for models in the SMPL family, implemented in numpy."""

def __init__(self, model_path: Path) -> None:
assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
body_dict = dict(**onp.load(model_path, allow_pickle=True))

self._J_regressor = body_dict["J_regressor"]
self._weights = body_dict["weights"]
self._v_template = body_dict["v_template"]
self._posedirs = body_dict["posedirs"]
self._shapedirs = body_dict["shapedirs"]
self._faces = body_dict["f"]

self.num_joints: int = self._weights.shape[-1]
self.num_betas: int = self._shapedirs.shape[-1]
self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)

# Local SE(3) transforms.
T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
T_parent_joint[:, :3, :3] = joint_rotmats
T_parent_joint[0, :3, 3] = j_tpose[0]
T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]

# Forward kinematics.
T_world_joint = T_parent_joint.copy()
for i in range(1, self.num_joints):
T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]

# Linear blend skinning.
pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)
v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
v_posed = np.einsum(
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
)
return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)


def main(model_path: Path) -> None:
server = viser.ViserServer()
server.set_up_direction("+y")
server.configure_theme(control_layout="collapsible")

# Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
# and then send the updated mesh in a loop.
model = SmplHelper(model_path)
gui_elements = make_gui_elements(
server,
num_betas=model.num_betas,
num_joints=model.num_joints,
parent_idx=model.parent_idx,
)
while True:
# Do nothing if no change.
time.sleep(0.02)
if not gui_elements.changed:
continue

gui_elements.changed = False

# Compute SMPL outputs.
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=np.stack(
[
tf.SO3.exp(np.array(x.value)).as_matrix()
for x in gui_elements.gui_joints
],
axis=0,
),
)
server.add_mesh_simple(
"/human",
smpl_outputs.vertices,
smpl_outputs.faces,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)

# Match transform control gizmos to joint positions.
for i, control in enumerate(gui_elements.transform_controls):
control.position = smpl_outputs.T_parent_joint[i, :3, 3]


@dataclass
class GuiElements:
"""Structure containing handles for reading from GUI elements."""

gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]
gui_wireframe: viser.GuiInputHandle[bool]
gui_betas: List[viser.GuiInputHandle[float]]
gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]
transform_controls: List[viser.TransformControlsHandle]

changed: bool
"""This flag will be flipped to True whenever the mesh needs to be re-generated."""


def make_gui_elements(
server: viser.ViserServer,
num_betas: int,
num_joints: int,
parent_idx: np.ndarray,
) -> GuiElements:
"""Make GUI elements for interacting with the model."""

tab_group = server.add_gui_tab_group()

def set_changed(_) -> None:
out.changed = True # out is define later!

# GUI elements: mesh settings + visibility.
with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255))
gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False)
gui_show_controls = server.add_gui_checkbox("Handles", initial_value=False)

gui_rgb.on_update(set_changed)
gui_wireframe.on_update(set_changed)

@gui_show_controls.on_update
def _(_):
for control in transform_controls:
control.visible = gui_show_controls.value

# GUI elements: shape parameters.
with tab_group.add_tab("Shape", viser.Icon.BOX):
gui_reset_shape = server.add_gui_button("Reset Shape")
gui_random_shape = server.add_gui_button("Random Shape")

@gui_reset_shape.on_click
def _(_):
for beta in gui_betas:
beta.value = 0.0

@gui_random_shape.on_click
def _(_):
for beta in gui_betas:
beta.value = onp.random.normal(loc=0.0, scale=1.0)

gui_betas = []
for i in range(num_betas):
beta = server.add_gui_slider(
f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
)
gui_betas.append(beta)
beta.on_update(set_changed)

# GUI elements: joint angles.
with tab_group.add_tab("Joints", viser.Icon.ANGLE):
gui_reset_joints = server.add_gui_button("Reset Joints")
gui_random_joints = server.add_gui_button("Random Joints")

@gui_reset_joints.on_click
def _(_):
for joint in gui_joints:
joint.value = (0.0, 0.0, 0.0)

@gui_random_joints.on_click
def _(_):
for joint in gui_joints:
# It's hard to uniformly sample orientations directly in so(3), so we
# first sample on S^3 and then convert.
quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,))
quat /= onp.linalg.norm(quat)
joint.value = tf.SO3(wxyz=quat).log()

gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
for i in range(num_joints):
gui_joint = server.add_gui_vector3(
label=f"Joint {i}",
initial_value=(0.0, 0.0, 0.0),
step=0.05,
)
gui_joints.append(gui_joint)

def set_callback_in_closure(i: int) -> None:
@gui_joints[i].on_update
def _(_):
transform_controls[i].wxyz = tf.SO3.exp(
np.array(gui_joints[i].value)
).wxyz
out.changed = True

set_callback_in_closure(i)

# Transform control gizmos on joints.
transform_controls: List[viser.TransformControlsHandle] = []
prefixed_joint_names = [] # Joint names, but prefixed with parents.
for i in range(num_joints):
prefixed_joint_name = f"joint_{i}"
if i > 0:
prefixed_joint_name = (
prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
)
prefixed_joint_names.append(prefixed_joint_name)
controls = server.add_transform_controls(
f"/smpl/{prefixed_joint_name}",
depth_test=False,
scale=0.2 * (0.75 ** prefixed_joint_name.count("/")),
disable_axes=True,
disable_sliders=True,
visible=gui_show_controls.value,
)
transform_controls.append(controls)

def set_callback_in_closure(i: int) -> None:
@transform_controls[i].on_update
def _(_) -> None:
axisangle = tf.SO3(transform_controls[i].wxyz).log()
gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])

set_callback_in_closure(i)

out = GuiElements(
gui_rgb,
gui_wireframe,
gui_betas,
gui_joints,
transform_controls=transform_controls,
changed=True,
)
return out


if __name__ == "__main__":
tyro.cli(main, description=__doc__)
Loading
Loading