diff --git a/docs/source/examples/01_image.rst b/docs/source/examples/01_image.rst index 9b43a86a4..4f3c42698 100644 --- a/docs/source/examples/01_image.rst +++ b/docs/source/examples/01_image.rst @@ -21,7 +21,6 @@ NeRFs), or images to render as 3D textures. import imageio.v3 as iio import numpy as onp - import viser diff --git a/docs/source/examples/02_gui.rst b/docs/source/examples/02_gui.rst index 6b15a2a74..9de3cfb17 100644 --- a/docs/source/examples/02_gui.rst +++ b/docs/source/examples/02_gui.rst @@ -16,7 +16,6 @@ Examples of basic GUI elements that we can create, read from, and write to. import time import numpy as onp - import viser @@ -69,6 +68,21 @@ Examples of basic GUI elements that we can create, read from, and write to. "Color", initial_value=(255, 255, 0), ) + gui_multi_slider = server.add_gui_multi_slider( + "Multi slider", + min=0, + max=100, + step=1, + initial_value=(0, 30, 100), + ) + gui_slider_positions = server.add_gui_slider( + "# sliders", + min=0, + max=10, + step=1, + initial_value=3, + marks=((0, "0"), (5, "5"), (7, "7"), 10), + ) # Pre-generate a point cloud to send. point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) @@ -92,6 +106,7 @@ Examples of basic GUI elements that we can create, read from, and write to. * color_coeffs[:, None] ).astype(onp.uint8), position=gui_vector2.value + (0,), + point_shape="circle", ) # We can use `.visible` and `.disabled` to toggle GUI elements. @@ -99,6 +114,12 @@ Examples of basic GUI elements that we can create, read from, and write to. gui_button.visible = not gui_checkbox_hide.value gui_rgb.disabled = gui_checkbox_disable.value + # Update the number of handles in the multi-slider. + if gui_slider_positions.value != len(gui_multi_slider.value): + gui_multi_slider.value = onp.linspace( + 0, 100, gui_slider_positions.value, dtype=onp.int64 + ) + counter += 1 time.sleep(0.01) diff --git a/docs/source/examples/03_gui_callbacks.rst b/docs/source/examples/03_gui_callbacks.rst index 94d1b09ee..a6b2311da 100644 --- a/docs/source/examples/03_gui_callbacks.rst +++ b/docs/source/examples/03_gui_callbacks.rst @@ -17,9 +17,8 @@ we get updates. import time import numpy as onp - from typing_extensions import assert_never - import viser + from typing_extensions import assert_never def main() -> None: diff --git a/docs/source/examples/05_camera_commands.rst b/docs/source/examples/05_camera_commands.rst index 48101cd00..636c6002e 100644 --- a/docs/source/examples/05_camera_commands.rst +++ b/docs/source/examples/05_camera_commands.rst @@ -17,7 +17,6 @@ corresponding client automatically. import time import numpy as onp - import viser import viser.transforms as tf diff --git a/docs/source/examples/06_mesh.rst b/docs/source/examples/06_mesh.rst index 2549cf4ad..918a45f98 100644 --- a/docs/source/examples/06_mesh.rst +++ b/docs/source/examples/06_mesh.rst @@ -18,7 +18,6 @@ Visualize a mesh. To get the demo data, see ``./assets/download_dragon_mesh.sh`` import numpy as onp import trimesh - import viser import viser.transforms as tf diff --git a/docs/source/examples/07_record3d_visualizer.rst b/docs/source/examples/07_record3d_visualizer.rst index c185fbede..9b0d85cdf 100644 --- a/docs/source/examples/07_record3d_visualizer.rst +++ b/docs/source/examples/07_record3d_visualizer.rst @@ -19,11 +19,10 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa import numpy as onp import tyro - from tqdm.auto import tqdm - import viser import viser.extras import viser.transforms as tf + from tqdm.auto import tqdm def main( diff --git a/docs/source/examples/08_smpl_visualizer.rst b/docs/source/examples/08_smpl_visualizer.rst new file mode 100644 index 000000000..86224e269 --- /dev/null +++ b/docs/source/examples/08_smpl_visualizer.rst @@ -0,0 +1,272 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Visualizer for SMPL human body models. Requires a .npz model file. +========================================== + + +See here for download instructions: + https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model + + + +.. code-block:: python + :linenos: + + + from __future__ import annotations + + import time + from dataclasses import dataclass + from pathlib import Path + from typing import List, Tuple + + import numpy as np + import numpy as onp + import tyro + import viser + import viser.transforms as tf + + + @dataclass(frozen=True) + class SmplOutputs: + vertices: np.ndarray + faces: np.ndarray + T_world_joint: np.ndarray # (num_joints, 4, 4) + T_parent_joint: np.ndarray # (num_joints, 4, 4) + + + class SmplHelper: + """Helper for models in the SMPL family, implemented in numpy.""" + + def __init__(self, model_path: Path) -> None: + assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" + body_dict = dict(**onp.load(model_path, allow_pickle=True)) + + self._J_regressor = body_dict["J_regressor"] + self._weights = body_dict["weights"] + self._v_template = body_dict["v_template"] + self._posedirs = body_dict["posedirs"] + self._shapedirs = body_dict["shapedirs"] + self._faces = body_dict["f"] + + self.num_joints: int = self._weights.shape[-1] + self.num_betas: int = self._shapedirs.shape[-1] + self.parent_idx: np.ndarray = body_dict["kintree_table"][0] + + def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: + # Get shaped vertices + joint positions, when all local poses are identity. + v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) + j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) + + # Local SE(3) transforms. + T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) + T_parent_joint[:, :3, :3] = joint_rotmats + T_parent_joint[0, :3, 3] = j_tpose[0] + T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] + + # Forward kinematics. + T_world_joint = T_parent_joint.copy() + for i in range(1, self.num_joints): + T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] + + # Linear blend skinning. + pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() + v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) + v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) + v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] + v_posed = np.einsum( + "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta + ) + return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) + + + def main(model_path: Path) -> None: + server = viser.ViserServer() + server.set_up_direction("+y") + server.configure_theme(control_layout="collapsible") + + # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, + # and then send the updated mesh in a loop. + model = SmplHelper(model_path) + gui_elements = make_gui_elements( + server, + num_betas=model.num_betas, + num_joints=model.num_joints, + parent_idx=model.parent_idx, + ) + while True: + # Do nothing if no change. + time.sleep(0.02) + if not gui_elements.changed: + continue + + gui_elements.changed = False + + # Compute SMPL outputs. + smpl_outputs = model.get_outputs( + betas=np.array([x.value for x in gui_elements.gui_betas]), + joint_rotmats=np.stack( + [ + tf.SO3.exp(np.array(x.value)).as_matrix() + for x in gui_elements.gui_joints + ], + axis=0, + ), + ) + server.add_mesh_simple( + "/human", + smpl_outputs.vertices, + smpl_outputs.faces, + wireframe=gui_elements.gui_wireframe.value, + color=gui_elements.gui_rgb.value, + ) + + # Match transform control gizmos to joint positions. + for i, control in enumerate(gui_elements.transform_controls): + control.position = smpl_outputs.T_parent_joint[i, :3, 3] + + + @dataclass + class GuiElements: + """Structure containing handles for reading from GUI elements.""" + + gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] + gui_wireframe: viser.GuiInputHandle[bool] + gui_betas: List[viser.GuiInputHandle[float]] + gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] + transform_controls: List[viser.TransformControlsHandle] + + changed: bool + """This flag will be flipped to True whenever the mesh needs to be re-generated.""" + + + def make_gui_elements( + server: viser.ViserServer, + num_betas: int, + num_joints: int, + parent_idx: np.ndarray, + ) -> GuiElements: + """Make GUI elements for interacting with the model.""" + + tab_group = server.add_gui_tab_group() + + def set_changed(_) -> None: + out.changed = True # out is define later! + + # GUI elements: mesh settings + visibility. + with tab_group.add_tab("View", viser.Icon.VIEWFINDER): + gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255)) + gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False) + gui_show_controls = server.add_gui_checkbox("Handles", initial_value=False) + + gui_rgb.on_update(set_changed) + gui_wireframe.on_update(set_changed) + + @gui_show_controls.on_update + def _(_): + for control in transform_controls: + control.visible = gui_show_controls.value + + # GUI elements: shape parameters. + with tab_group.add_tab("Shape", viser.Icon.BOX): + gui_reset_shape = server.add_gui_button("Reset Shape") + gui_random_shape = server.add_gui_button("Random Shape") + + @gui_reset_shape.on_click + def _(_): + for beta in gui_betas: + beta.value = 0.0 + + @gui_random_shape.on_click + def _(_): + for beta in gui_betas: + beta.value = onp.random.normal(loc=0.0, scale=1.0) + + gui_betas = [] + for i in range(num_betas): + beta = server.add_gui_slider( + f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 + ) + gui_betas.append(beta) + beta.on_update(set_changed) + + # GUI elements: joint angles. + with tab_group.add_tab("Joints", viser.Icon.ANGLE): + gui_reset_joints = server.add_gui_button("Reset Joints") + gui_random_joints = server.add_gui_button("Random Joints") + + @gui_reset_joints.on_click + def _(_): + for joint in gui_joints: + joint.value = (0.0, 0.0, 0.0) + + @gui_random_joints.on_click + def _(_): + for joint in gui_joints: + # It's hard to uniformly sample orientations directly in so(3), so we + # first sample on S^3 and then convert. + quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) + quat /= onp.linalg.norm(quat) + joint.value = tf.SO3(wxyz=quat).log() + + gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] + for i in range(num_joints): + gui_joint = server.add_gui_vector3( + label=f"Joint {i}", + initial_value=(0.0, 0.0, 0.0), + step=0.05, + ) + gui_joints.append(gui_joint) + + def set_callback_in_closure(i: int) -> None: + @gui_joints[i].on_update + def _(_): + transform_controls[i].wxyz = tf.SO3.exp( + np.array(gui_joints[i].value) + ).wxyz + out.changed = True + + set_callback_in_closure(i) + + # Transform control gizmos on joints. + transform_controls: List[viser.TransformControlsHandle] = [] + prefixed_joint_names = [] # Joint names, but prefixed with parents. + for i in range(num_joints): + prefixed_joint_name = f"joint_{i}" + if i > 0: + prefixed_joint_name = ( + prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name + ) + prefixed_joint_names.append(prefixed_joint_name) + controls = server.add_transform_controls( + f"/smpl/{prefixed_joint_name}", + depth_test=False, + scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), + disable_axes=True, + disable_sliders=True, + visible=gui_show_controls.value, + ) + transform_controls.append(controls) + + def set_callback_in_closure(i: int) -> None: + @transform_controls[i].on_update + def _(_) -> None: + axisangle = tf.SO3(transform_controls[i].wxyz).log() + gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) + + set_callback_in_closure(i) + + out = GuiElements( + gui_rgb, + gui_wireframe, + gui_betas, + gui_joints, + transform_controls=transform_controls, + changed=True, + ) + return out + + + if __name__ == "__main__": + tyro.cli(main, description=__doc__) diff --git a/docs/source/examples/08_smplx_visualizer.rst b/docs/source/examples/08_smplx_visualizer.rst deleted file mode 100644 index dafc82aee..000000000 --- a/docs/source/examples/08_smplx_visualizer.rst +++ /dev/null @@ -1,277 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -SMPL-X visualizer -========================================== - - -We need to install the smplx package and download a corresponding set of model -parameters to run this script: - - -* https://github.com/vchoutas/smplx - - - -.. code-block:: python - :linenos: - - - import dataclasses - import time - from pathlib import Path - from typing import List, Tuple - - import numpy as onp - import smplx - import smplx.joint_names - import smplx.lbs - import torch - import tyro - from typing_extensions import Literal - - import viser - import viser.transforms as tf - - - def main( - model_path: Path, - model_type: Literal["smpl", "smplh", "smplx", "mano"] = "smplx", - gender: Literal["male", "female", "neutral"] = "neutral", - num_betas: int = 10, - num_expression_coeffs: int = 10, - ext: Literal["npz", "pkl"] = "npz", - share: bool = False, - ) -> None: - server = viser.ViserServer() - server.set_up_direction("+y") - if share: - server.request_share_url() - - server.configure_theme(control_layout="collapsible") - model = smplx.create( - model_path=str(model_path), - model_type=model_type, - gender=gender, - num_betas=num_betas, - num_expression_coeffs=num_expression_coeffs, - ext=ext, - ) - - # Main loop. We'll just keep read from the joints, deform the mesh, then sending the - # updated mesh in a loop. This could be made a lot more efficient. - gui_elements = make_gui_elements( - server, num_betas=model.num_betas, num_body_joints=int(model.NUM_BODY_JOINTS) - ) - while True: - # Do nothing if no change. - if not gui_elements.changed: - time.sleep(0.01) - continue - gui_elements.changed = False - - full_pose = torch.from_numpy( - onp.array( - [j.value for j in gui_elements.gui_joints[1:]], dtype=onp.float32 - )[None, ...] # type: ignore - ) - - # Get deformed mesh. - output = model.forward( - betas=torch.from_numpy( # type: ignore - onp.array([b.value for b in gui_elements.gui_betas], dtype=onp.float32)[ - None, ... - ] - ), - expression=None, - return_verts=True, - body_pose=full_pose[:, : model.NUM_BODY_JOINTS], # type: ignore - global_orient=torch.from_numpy( - onp.array(gui_elements.gui_joints[0].value, dtype=onp.float32)[ - None, ... - ] - ), # type: ignore - return_full_pose=True, - ) - joint_positions = output.joints.squeeze(axis=0).detach().cpu().numpy() # type: ignore - joint_transforms, parents = joint_transforms_and_parents_from_smpl( - model, output - ) - - # Send mesh to visualizer. - server.add_mesh_simple( - "/smpl", - vertices=output.vertices.squeeze(axis=0).detach().cpu().numpy(), # type: ignore - faces=model.faces, - wireframe=gui_elements.gui_wireframe.value, - color=gui_elements.gui_rgb.value, - flat_shading=False, - ) - - # Update per-joint frames, which are used for transform controls. - for i in range(model.NUM_BODY_JOINTS + 1): - R = joint_transforms[parents[i], :3, :3] - server.add_frame( - f"/smpl/joint_{i}", - wxyz=((1.0, 0.0, 0.0, 0.0) if i == 0 else tf.SO3.from_matrix(R).wxyz), - position=joint_positions[i], - show_axes=False, - ) - - - @dataclasses.dataclass - class GuiElements: - """Structure containing handles for reading from GUI elements.""" - - gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] - gui_wireframe: viser.GuiInputHandle[bool] - gui_betas: List[viser.GuiInputHandle[float]] - gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] - - changed: bool - """This flag will be flipped to True whenever the mesh needs to be re-generated.""" - - - def make_gui_elements( - server: viser.ViserServer, num_betas: int, num_body_joints: int - ) -> GuiElements: - """Make GUI elements for interacting with the model.""" - - tab_group = server.add_gui_tab_group() - - # GUI elements: mesh settings + visibility. - with tab_group.add_tab("View", viser.Icon.VIEWFINDER): - gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255)) - gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False) - gui_show_controls = server.add_gui_checkbox("Handles", initial_value=False) - - @gui_rgb.on_update - def _(_): - out.changed = True - - @gui_wireframe.on_update - def _(_): - out.changed = True - - @gui_show_controls.on_update - def _(_): - add_transform_controls(enabled=gui_show_controls.value) - - # GUI elements: shape parameters. - with tab_group.add_tab("Shape", viser.Icon.BOX): - gui_reset_shape = server.add_gui_button("Reset Shape") - gui_random_shape = server.add_gui_button("Random Shape") - - @gui_reset_shape.on_click - def _(_): - for beta in gui_betas: - beta.value = 0.0 - - @gui_random_shape.on_click - def _(_): - for beta in gui_betas: - beta.value = onp.random.normal(loc=0.0, scale=1.0) - - gui_betas = [] - for i in range(num_betas): - beta = server.add_gui_slider( - f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 - ) - gui_betas.append(beta) - - @beta.on_update - def _(_): - out.changed = True - - # GUI elements: joint angles. - with tab_group.add_tab("Joints", viser.Icon.ANGLE): - gui_reset_joints = server.add_gui_button("Reset Joints") - gui_random_joints = server.add_gui_button("Random Joints") - - @gui_reset_joints.on_click - def _(_): - for joint in gui_joints: - joint.value = (0.0, 0.0, 0.0) - sync_transform_controls() - - @gui_random_joints.on_click - def _(_): - for joint in gui_joints: - # It's hard to uniformly sample orientations directly in so(3), so we - # first sample on S^3 and then convert. - quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) - quat /= onp.linalg.norm(quat) - - # xyzw => wxyz => so(3) - joint.value = tf.SO3(wxyz=quat).log() - sync_transform_controls() - - gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] - for i in range(num_body_joints + 1): - gui_joint = server.add_gui_vector3( - label=smplx.joint_names.JOINT_NAMES[i], - initial_value=(0.0, 0.0, 0.0), - step=0.05, - ) - gui_joints.append(gui_joint) - - @gui_joint.on_update - def _(_): - sync_transform_controls() - out.changed = True - - # Transform control gizmos on joints. - transform_controls: List[viser.TransformControlsHandle] = [] - - def add_transform_controls(enabled: bool) -> List[viser.TransformControlsHandle]: - for i in range(1 + num_body_joints): - controls = server.add_transform_controls( - f"/smpl/joint_{i}/controls", - depth_test=False, - line_width=3.5 if i == 0 else 2.0, - scale=0.2 if i == 0 else 0.1, - disable_axes=True, - disable_sliders=True, - disable_rotations=not enabled, - ) - transform_controls.append(controls) - - def curry_callback(i: int) -> None: - @controls.on_update - def _(controls: viser.TransformControlsHandle) -> None: - axisangle = tf.SO3(controls.wxyz).log() - gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) - - curry_callback(i) - - return transform_controls - - def sync_transform_controls() -> None: - """Sync transform controls when a joint angle changes.""" - for t, j in zip(transform_controls, gui_joints): - t.wxyz = tf.SO3.exp(onp.array(j.value)).wxyz - - add_transform_controls(enabled=False) - - out = GuiElements(gui_rgb, gui_wireframe, gui_betas, gui_joints, changed=True) - return out - - - def joint_transforms_and_parents_from_smpl(model, output): - """Hack at SMPL internals to get coordinate frames corresponding to each joint.""" - v_shaped = model.v_template + smplx.lbs.blend_shapes( # type: ignore - model.betas, - model.shapedirs, # type: ignore - ) - J = smplx.lbs.vertices2joints(model.J_regressor, v_shaped) # type: ignore - rot_mats = smplx.lbs.batch_rodrigues(output.full_pose.view(-1, 3)).view( # type: ignore - [1, -1, 3, 3] - ) - J_posed, A = smplx.lbs.batch_rigid_transform(rot_mats, J, model.parents) # type: ignore - transforms = A.detach().cpu().numpy().squeeze(axis=0) # type: ignore - parents = model.parents.detach().cpu().numpy() # type: ignore - return transforms, parents - - - if __name__ == "__main__": - tyro.cli(main, description=__doc__) diff --git a/docs/source/examples/09_urdf_visualizer.rst b/docs/source/examples/09_urdf_visualizer.rst index a679113fc..5d05b8a75 100644 --- a/docs/source/examples/09_urdf_visualizer.rst +++ b/docs/source/examples/09_urdf_visualizer.rst @@ -27,7 +27,6 @@ Examples: import numpy as onp import tyro - import viser from viser.extras import ViserUrdf diff --git a/docs/source/examples/10_realsense.rst b/docs/source/examples/10_realsense.rst index e0a2d9e13..9eb748560 100644 --- a/docs/source/examples/10_realsense.rst +++ b/docs/source/examples/10_realsense.rst @@ -20,9 +20,8 @@ pyrealsense2. import numpy as np import numpy.typing as npt import pyrealsense2 as rs # type: ignore - from tqdm.auto import tqdm - import viser + from tqdm.auto import tqdm @contextlib.contextmanager diff --git a/docs/source/examples/11_colmap_visualizer.rst b/docs/source/examples/11_colmap_visualizer.rst index dc1da2c0f..b26d04bc4 100644 --- a/docs/source/examples/11_colmap_visualizer.rst +++ b/docs/source/examples/11_colmap_visualizer.rst @@ -20,10 +20,9 @@ Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets import imageio.v3 as iio import numpy as onp import tyro - from tqdm.auto import tqdm - import viser import viser.transforms as tf + from tqdm.auto import tqdm from viser.extras.colmap import ( read_cameras_binary, read_images_binary, diff --git a/docs/source/examples/12_click_meshes.rst b/docs/source/examples/12_click_meshes.rst index 30ec69e1c..062270f04 100644 --- a/docs/source/examples/12_click_meshes.rst +++ b/docs/source/examples/12_click_meshes.rst @@ -16,7 +16,6 @@ Click on meshes to select them. The index of the last clicked mesh is displayed import time import matplotlib - import viser diff --git a/docs/source/examples/13_theming.rst b/docs/source/examples/13_theming.rst index c2b0bf868..c25d1e1cf 100644 --- a/docs/source/examples/13_theming.rst +++ b/docs/source/examples/13_theming.rst @@ -20,7 +20,7 @@ Viser includes support for light theming. def main(): - server = viser.ViserServer() + server = viser.ViserServer(label="Viser Theming") buttons = ( TitlebarButton( diff --git a/docs/source/examples/15_gui_in_scene.rst b/docs/source/examples/15_gui_in_scene.rst index a8dc4a549..84ef4c83f 100644 --- a/docs/source/examples/15_gui_in_scene.rst +++ b/docs/source/examples/15_gui_in_scene.rst @@ -19,7 +19,6 @@ performed on them. from typing import Optional import numpy as onp - import viser import viser.transforms as tf diff --git a/docs/source/examples/17_background_composite.rst b/docs/source/examples/17_background_composite.rst index 0a923f68c..8af08a2a7 100644 --- a/docs/source/examples/17_background_composite.rst +++ b/docs/source/examples/17_background_composite.rst @@ -19,7 +19,6 @@ be useful when we want a 2D image to occlude 3D geometry, such as for NeRF rende import numpy as onp import trimesh import trimesh.creation - import viser server = viser.ViserServer() diff --git a/docs/source/examples/18_splines.rst b/docs/source/examples/18_splines.rst index 39d408cd5..1ff845edd 100644 --- a/docs/source/examples/18_splines.rst +++ b/docs/source/examples/18_splines.rst @@ -16,7 +16,6 @@ Make a ball with some random splines. import time import numpy as onp - import viser diff --git a/docs/source/examples/19_get_renders.rst b/docs/source/examples/19_get_renders.rst index 5abd22217..5683acf4c 100644 --- a/docs/source/examples/19_get_renders.rst +++ b/docs/source/examples/19_get_renders.rst @@ -17,7 +17,6 @@ Example for getting renders from a client's viewport to the Python API. import imageio.v3 as iio import numpy as onp - import viser diff --git a/docs/source/examples/20_scene_click.rst b/docs/source/examples/20_scene_click.rst index 8dea7292e..3ccf628eb 100644 --- a/docs/source/examples/20_scene_click.rst +++ b/docs/source/examples/20_scene_click.rst @@ -23,7 +23,6 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. import numpy as onp import trimesh.creation import trimesh.ray - import viser import viser.transforms as tf diff --git a/docs/source/examples/22_games.rst b/docs/source/examples/22_games.rst index aa82b4b6f..455224ce5 100644 --- a/docs/source/examples/22_games.rst +++ b/docs/source/examples/22_games.rst @@ -18,10 +18,9 @@ Some two-player games implemented using scene click events. import numpy as onp import trimesh.creation - from typing_extensions import assert_never - import viser import viser.transforms as tf + from typing_extensions import assert_never def main() -> None: diff --git a/examples/08_smplx_visualizer.py b/examples/08_smplx_visualizer.py deleted file mode 100644 index 6e9ba6da0..000000000 --- a/examples/08_smplx_visualizer.py +++ /dev/null @@ -1,269 +0,0 @@ -# mypy: disable-error-code="assignment" -# -# Asymmetric properties are supported in Pyright, but not yet in mypy. -# - https://github.com/python/mypy/issues/3004 -# - https://github.com/python/mypy/pull/11643 -"""SMPL-X visualizer - -We need to install the smplx package and download a corresponding set of model -parameters to run this script: -- https://github.com/vchoutas/smplx -""" - -import dataclasses -import time -from pathlib import Path -from typing import List, Tuple - -import numpy as onp -import smplx -import smplx.joint_names -import smplx.lbs -import torch -import tyro -import viser -import viser.transforms as tf -from typing_extensions import Literal - - -def main( - model_path: Path, - model_type: Literal["smpl", "smplh", "smplx", "mano"] = "smplx", - gender: Literal["male", "female", "neutral"] = "neutral", - num_betas: int = 10, - num_expression_coeffs: int = 10, - ext: Literal["npz", "pkl"] = "npz", - share: bool = False, -) -> None: - server = viser.ViserServer() - server.set_up_direction("+y") - if share: - server.request_share_url() - - server.configure_theme(control_layout="collapsible") - model = smplx.create( - model_path=str(model_path), - model_type=model_type, - gender=gender, - num_betas=num_betas, - num_expression_coeffs=num_expression_coeffs, - ext=ext, - ) - - # Main loop. We'll just keep read from the joints, deform the mesh, then sending the - # updated mesh in a loop. This could be made a lot more efficient. - gui_elements = make_gui_elements( - server, num_betas=model.num_betas, num_body_joints=int(model.NUM_BODY_JOINTS) - ) - while True: - # Do nothing if no change. - if not gui_elements.changed: - time.sleep(0.01) - continue - gui_elements.changed = False - - full_pose = torch.from_numpy( - onp.array( - [j.value for j in gui_elements.gui_joints[1:]], dtype=onp.float32 - )[None, ...] # type: ignore - ) - - # Get deformed mesh. - output = model.forward( - betas=torch.from_numpy( # type: ignore - onp.array([b.value for b in gui_elements.gui_betas], dtype=onp.float32)[ - None, ... - ] - ), - expression=None, - return_verts=True, - body_pose=full_pose[:, : model.NUM_BODY_JOINTS], # type: ignore - global_orient=torch.from_numpy( - onp.array(gui_elements.gui_joints[0].value, dtype=onp.float32)[ - None, ... - ] - ), # type: ignore - return_full_pose=True, - ) - joint_positions = output.joints.squeeze(axis=0).detach().cpu().numpy() # type: ignore - joint_transforms, parents = joint_transforms_and_parents_from_smpl( - model, output - ) - - # Send mesh to visualizer. - server.add_mesh_simple( - "/smpl", - vertices=output.vertices.squeeze(axis=0).detach().cpu().numpy(), # type: ignore - faces=model.faces, - wireframe=gui_elements.gui_wireframe.value, - color=gui_elements.gui_rgb.value, - flat_shading=False, - ) - - # Update per-joint frames, which are used for transform controls. - for i in range(model.NUM_BODY_JOINTS + 1): - R = joint_transforms[parents[i], :3, :3] - server.add_frame( - f"/smpl/joint_{i}", - wxyz=((1.0, 0.0, 0.0, 0.0) if i == 0 else tf.SO3.from_matrix(R).wxyz), - position=joint_positions[i], - show_axes=False, - ) - - -@dataclasses.dataclass -class GuiElements: - """Structure containing handles for reading from GUI elements.""" - - gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] - gui_wireframe: viser.GuiInputHandle[bool] - gui_betas: List[viser.GuiInputHandle[float]] - gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] - - changed: bool - """This flag will be flipped to True whenever the mesh needs to be re-generated.""" - - -def make_gui_elements( - server: viser.ViserServer, num_betas: int, num_body_joints: int -) -> GuiElements: - """Make GUI elements for interacting with the model.""" - - tab_group = server.add_gui_tab_group() - - # GUI elements: mesh settings + visibility. - with tab_group.add_tab("View", viser.Icon.VIEWFINDER): - gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255)) - gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False) - gui_show_controls = server.add_gui_checkbox("Handles", initial_value=False) - - @gui_rgb.on_update - def _(_): - out.changed = True - - @gui_wireframe.on_update - def _(_): - out.changed = True - - @gui_show_controls.on_update - def _(_): - add_transform_controls(enabled=gui_show_controls.value) - - # GUI elements: shape parameters. - with tab_group.add_tab("Shape", viser.Icon.BOX): - gui_reset_shape = server.add_gui_button("Reset Shape") - gui_random_shape = server.add_gui_button("Random Shape") - - @gui_reset_shape.on_click - def _(_): - for beta in gui_betas: - beta.value = 0.0 - - @gui_random_shape.on_click - def _(_): - for beta in gui_betas: - beta.value = onp.random.normal(loc=0.0, scale=1.0) - - gui_betas = [] - for i in range(num_betas): - beta = server.add_gui_slider( - f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 - ) - gui_betas.append(beta) - - @beta.on_update - def _(_): - out.changed = True - - # GUI elements: joint angles. - with tab_group.add_tab("Joints", viser.Icon.ANGLE): - gui_reset_joints = server.add_gui_button("Reset Joints") - gui_random_joints = server.add_gui_button("Random Joints") - - @gui_reset_joints.on_click - def _(_): - for joint in gui_joints: - joint.value = (0.0, 0.0, 0.0) - sync_transform_controls() - - @gui_random_joints.on_click - def _(_): - for joint in gui_joints: - # It's hard to uniformly sample orientations directly in so(3), so we - # first sample on S^3 and then convert. - quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) - quat /= onp.linalg.norm(quat) - - # xyzw => wxyz => so(3) - joint.value = tf.SO3(wxyz=quat).log() - sync_transform_controls() - - gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] - for i in range(num_body_joints + 1): - gui_joint = server.add_gui_vector3( - label=smplx.joint_names.JOINT_NAMES[i], - initial_value=(0.0, 0.0, 0.0), - step=0.05, - ) - gui_joints.append(gui_joint) - - @gui_joint.on_update - def _(_): - sync_transform_controls() - out.changed = True - - # Transform control gizmos on joints. - transform_controls: List[viser.TransformControlsHandle] = [] - - def add_transform_controls(enabled: bool) -> List[viser.TransformControlsHandle]: - for i in range(1 + num_body_joints): - controls = server.add_transform_controls( - f"/smpl/joint_{i}/controls", - depth_test=False, - line_width=3.5 if i == 0 else 2.0, - scale=0.2 if i == 0 else 0.1, - disable_axes=True, - disable_sliders=True, - disable_rotations=not enabled, - ) - transform_controls.append(controls) - - def curry_callback(i: int) -> None: - @controls.on_update - def _(controls: viser.TransformControlsHandle) -> None: - axisangle = tf.SO3(controls.wxyz).log() - gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) - - curry_callback(i) - - return transform_controls - - def sync_transform_controls() -> None: - """Sync transform controls when a joint angle changes.""" - for t, j in zip(transform_controls, gui_joints): - t.wxyz = tf.SO3.exp(onp.array(j.value)).wxyz - - add_transform_controls(enabled=False) - - out = GuiElements(gui_rgb, gui_wireframe, gui_betas, gui_joints, changed=True) - return out - - -def joint_transforms_and_parents_from_smpl(model, output): - """Hack at SMPL internals to get coordinate frames corresponding to each joint.""" - v_shaped = model.v_template + smplx.lbs.blend_shapes( # type: ignore - model.betas, - model.shapedirs, # type: ignore - ) - J = smplx.lbs.vertices2joints(model.J_regressor, v_shaped) # type: ignore - rot_mats = smplx.lbs.batch_rodrigues(output.full_pose.view(-1, 3)).view( # type: ignore - [1, -1, 3, 3] - ) - J_posed, A = smplx.lbs.batch_rigid_transform(rot_mats, J, model.parents) # type: ignore - transforms = A.detach().cpu().numpy().squeeze(axis=0) # type: ignore - parents = model.parents.detach().cpu().numpy() # type: ignore - return transforms, parents - - -if __name__ == "__main__": - tyro.cli(main, description=__doc__) diff --git a/examples/09_urdf_visualizer.py b/examples/09_urdf_visualizer.py index 3cb475034..da99c257a 100644 --- a/examples/09_urdf_visualizer.py +++ b/examples/09_urdf_visualizer.py @@ -6,6 +6,7 @@ - https://github.com/OrebroUniversity/yumi/blob/master/yumi_description/urdf/yumi.urdf - https://github.com/ankurhanda/robot-assets """ + from __future__ import annotations import time diff --git a/examples/10_realsense.py b/examples/10_realsense.py index 9b54a2950..ae153ec28 100644 --- a/examples/10_realsense.py +++ b/examples/10_realsense.py @@ -3,6 +3,7 @@ Connect to a RealSense camera, then visualize RGB-D readings as a point clouds. Requires pyrealsense2. """ + import contextlib from typing import Tuple diff --git a/examples/22_games.py b/examples/22_games.py index 8383ac12f..a8b37de62 100644 --- a/examples/22_games.py +++ b/examples/22_games.py @@ -7,7 +7,6 @@ Some two-player games implemented using scene click events.""" - import time from typing import Literal diff --git a/pyproject.toml b/pyproject.toml index d09760716..d4114199f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ dev = [ "pyright>=1.1.308", "mypy>=1.4.1", - "ruff==0.1.13", + "ruff==0.3.3", "pre-commit==3.3.2", ] examples = [ diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index bf7ceec9b..53d23e83e 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -427,8 +427,7 @@ def add_gui_button_group( disabled: bool = False, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiButtonGroupHandle[TLiteralString]: - ... + ) -> GuiButtonGroupHandle[TLiteralString]: ... @overload def add_gui_button_group( @@ -439,8 +438,7 @@ def add_gui_button_group( disabled: bool = False, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiButtonGroupHandle[TString]: - ... + ) -> GuiButtonGroupHandle[TString]: ... def add_gui_button_group( self, @@ -768,8 +766,7 @@ def add_gui_dropdown( visible: bool = True, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiDropdownHandle[TLiteralString]: - ... + ) -> GuiDropdownHandle[TLiteralString]: ... @overload def add_gui_dropdown( @@ -781,8 +778,7 @@ def add_gui_dropdown( visible: bool = True, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiDropdownHandle[TString]: - ... + ) -> GuiDropdownHandle[TString]: ... def add_gui_dropdown( self, diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 7be24fd05..3ca868eca 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -2,7 +2,6 @@ import dataclasses import re -import threading import time import urllib.parse import uuid @@ -53,8 +52,7 @@ class GuiContainerProtocol(Protocol): class SupportsRemoveProtocol(Protocol): - def remove(self) -> None: - ... + def remove(self) -> None: ... @dataclasses.dataclass @@ -142,15 +140,15 @@ def value(self, value: T | onp.ndarray) -> None: for cb in self._impl.update_cb: # Pushing callbacks into separate threads helps prevent deadlocks when we # have a lock in a callback. TODO: revisit other callbacks. - threading.Thread( - target=lambda: cb( + self._impl.gui_api._get_api()._thread_executor.submit( + lambda: cb( GuiEvent( client_id=None, client=None, target=self, ) ) - ).start() + ) @property def update_timestamp(self) -> float: diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index 0f3ef935c..2166e4567 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -17,6 +17,7 @@ import threading import time import warnings +from concurrent.futures import ThreadPoolExecutor from typing import ( TYPE_CHECKING, Callable, @@ -147,7 +148,9 @@ class MessageApi(abc.ABC): _locked_thread_id: int # Appeasing mypy 1.5.1, not sure why this is needed. - def __init__(self, handler: infra.MessageHandler) -> None: + def __init__( + self, handler: infra.MessageHandler, thread_executor: ThreadPoolExecutor + ) -> None: self._message_handler = handler super().__init__() @@ -177,6 +180,7 @@ def __init__(self, handler: infra.MessageHandler) -> None: self._atomic_lock = threading.Lock() self._queued_messages: queue.Queue = queue.Queue() self._locked_thread_id = -1 + self._thread_executor = thread_executor def set_gui_panel_label(self, label: Optional[str]) -> None: """Set the main label that appears in the GUI panel. @@ -1181,7 +1185,7 @@ def try_again() -> None: with self._atomic_lock: self._queue_unsafe(self._queued_messages.get()) - threading.Thread(target=try_again).start() + self._thread_executor.submit(try_again) @abc.abstractmethod def _queue_unsafe(self, message: _messages.Message) -> None: @@ -1216,9 +1220,11 @@ def _handle_transform_controls_updates( return # Update state. + wxyz = onp.array(message.wxyz) + position = onp.array(message.position) with self.atomic(): - handle._impl.wxyz = onp.array(message.wxyz) - handle._impl.position = onp.array(message.position) + handle._impl.wxyz = wxyz + handle._impl.position = position handle._impl_aux.last_updated = time.time() # Trigger callbacks. diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 3580f583f..23c0178dc 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -259,7 +259,7 @@ class ClientHandle(MessageApi, GuiApi): _state: _ClientHandleState def __post_init__(self): - super().__init__(self._state.connection) + super().__init__(self._state.connection, self._state.server._thread_executor) @override def _get_api(self) -> MessageApi: @@ -341,8 +341,7 @@ class ViserServer(MessageApi, GuiApi): # Hide deprecated arguments from docstring and type checkers. def __init__( self, host: str = "0.0.0.0", port: int = 8080, label: Optional[str] = None - ): - ... + ): ... def _actual_init( self, @@ -360,7 +359,7 @@ def _actual_init( client_api_version=1, ) self._server = server - super().__init__(server) + super().__init__(server, server._thread_executor) _client_autobuild.ensure_client_is_built() diff --git a/src/viser/client/src/CameraControls.tsx b/src/viser/client/src/CameraControls.tsx index 32c055ea9..7c499437a 100644 --- a/src/viser/client/src/CameraControls.tsx +++ b/src/viser/client/src/CameraControls.tsx @@ -114,10 +114,12 @@ export function SynchronizedCameraControls() { }, [connected, sendCamera]); // Send camera for 3D viewport changes. - const canvas = viewer.canvasRef.current!; // R3F canvas. + const canvas = viewer.canvasRef.current!; // R3F canvas. React.useEffect(() => { // Create a resize observer to resize the CSS canvas when the window is resized. - const resizeObserver = new ResizeObserver(() => { sendCamera() }); + const resizeObserver = new ResizeObserver(() => { + sendCamera(); + }); resizeObserver.observe(canvas); // Cleanup. diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index 9d7019b8d..48f7dba92 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -201,9 +201,12 @@ export function SceneNodeThreeObject(props: { // Update attributes on a per-frame basis. Currently does redundant work, // although this shouldn't be a bottleneck. useFrame(() => { + const attrs = viewer.nodeAttributesFromName.current[props.name]; if (unmountWhenInvisible) { const displayed = isDisplayed(); if (displayed && unmount) { + // Need to re-set attributes after remounting, eg for transform controls. + if (attrs !== undefined) attrs.poseUpdateState = "needsUpdate"; setUnmount(false); } if (!displayed && !unmount) { @@ -212,8 +215,6 @@ export function SceneNodeThreeObject(props: { } if (obj === null) return; - - const attrs = viewer.nodeAttributesFromName.current[props.name]; if (attrs === undefined) return; const visibility = diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 408f93c64..98861dd57 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -298,7 +298,7 @@ function useMessageHandler() { return texture; }; const standardArgs = { - color: message.color || undefined, + color: message.color ?? undefined, vertexColors: message.vertex_colors !== null, wireframe: message.wireframe, transparent: message.opacity !== null, @@ -416,45 +416,52 @@ function useMessageHandler() { 50, ); addSceneNodeMakeParents( - new SceneNode(message.name, (ref) => ( - e.stopPropagation()}> - { - const attrs = viewer.nodeAttributesFromName.current; - if (attrs[message.name] === undefined) { - attrs[message.name] = {}; - } - - const wxyz = new THREE.Quaternion(); - wxyz.setFromRotationMatrix(l); - const position = new THREE.Vector3().setFromMatrixPosition(l); - - const nodeAttributes = attrs[message.name]!; - nodeAttributes.wxyz = [wxyz.w, wxyz.x, wxyz.y, wxyz.z]; - nodeAttributes.position = position.toArray(); - sendDragMessage({ - type: "TransformControlsUpdateMessage", - name: name, - wxyz: nodeAttributes.wxyz, - position: nodeAttributes.position, - }); - }} - /> - - )), + new SceneNode( + message.name, + (ref) => ( + e.stopPropagation()}> + { + const attrs = viewer.nodeAttributesFromName.current; + if (attrs[message.name] === undefined) { + attrs[message.name] = {}; + } + + const wxyz = new THREE.Quaternion(); + wxyz.setFromRotationMatrix(l); + const position = new THREE.Vector3().setFromMatrixPosition( + l, + ); + + const nodeAttributes = attrs[message.name]!; + nodeAttributes.wxyz = [wxyz.w, wxyz.x, wxyz.y, wxyz.z]; + nodeAttributes.position = position.toArray(); + sendDragMessage({ + type: "TransformControlsUpdateMessage", + name: name, + wxyz: nodeAttributes.wxyz, + position: nodeAttributes.position, + }); + }} + /> + + ), + undefined, + true, // unmountWhenInvisible + ), ); return; } diff --git a/src/viser/scripts/dev_checks.py b/src/viser/scripts/dev_checks.py index c56a5db30..02f4f3a4b 100644 --- a/src/viser/scripts/dev_checks.py +++ b/src/viser/scripts/dev_checks.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Runs formatting, linting, and type checking tests.""" + import subprocess import sys diff --git a/src/viser/transforms/_base.py b/src/viser/transforms/_base.py index c0e07ec74..f78b52b4a 100644 --- a/src/viser/transforms/_base.py +++ b/src/viser/transforms/_base.py @@ -36,12 +36,10 @@ def __init__(self, parameters: onpt.NDArray[onp.floating], /): # Shared implementations. @overload - def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: - ... + def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: ... @overload - def __matmul__(self: GroupType, other: GroupType) -> GroupType: - ... + def __matmul__(self: GroupType, other: GroupType) -> GroupType: ... def __matmul__( self: GroupType, other: Union[GroupType, hints.Array]