From c2c082f947a6d256744e4cb8730635ba2f9c0d30 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 2 Apr 2024 13:49:17 -0700 Subject: [PATCH 1/7] Skinned mesh draft --- src/viser/_message_api.py | 89 +++++++++++ src/viser/_messages.py | 32 ++++ src/viser/client/src/App.tsx | 6 + src/viser/client/src/SceneTree.tsx | 31 +++- src/viser/client/src/WebsocketInterface.tsx | 159 +++++++++++++++++--- src/viser/client/src/WebsocketMessages.tsx | 12 ++ 6 files changed, 299 insertions(+), 30 deletions(-) diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index ef86bf759..438286604 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -799,6 +799,92 @@ def add_mesh(self, *args, **kwargs) -> MeshHandle: """Deprecated alias for `add_mesh_simple()`.""" return self.add_mesh_simple(*args, **kwargs) + def add_mesh_skinned( + self, + name: str, + vertices: onp.ndarray, + faces: onp.ndarray, + bone_wxyzs: Tuple[Tuple[float, float, float, float], ...] | onp.ndarray, + bone_positions: Tuple[Tuple[float, float, float], ...] | onp.ndarray, + skin_weights: onp.ndarray, + color: RgbTupleOrArray = (90, 200, 255), + wireframe: bool = False, + opacity: Optional[float] = None, + material: Literal["standard", "toon3", "toon5"] = "standard", + flat_shading: bool = False, + side: Literal["front", "back", "double"] = "front", + wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), + position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + visible: bool = True, + ) -> MeshHandle: + """Add a mesh to the scene. + + Args: + name: A scene tree name. Names in the format of /parent/child can be used to + define a kinematic tree. + vertices: A numpy array of vertex positions. Should have shape (V, 3). + faces: A numpy array of faces, where each face is represented by indices of + vertices. Should have shape (F,) + bone_handles: Tuple of scene node handles. A bone will be attached to each. + skin_weights: A numpy array of skin weights. Should have shape (V, B) where B + is the number of bones. + color: Color of the mesh as an RGB tuple. + wireframe: Boolean indicating if the mesh should be rendered as a wireframe. + opacity: Opacity of the mesh. None means opaque. + material: Material type of the mesh ('standard', 'toon3', 'toon5'). + This argument is ignored when wireframe=True. + flat_shading: Whether to do flat shading. This argument is ignored + when wireframe=True. + side: Side of the surface to render ('front', 'back', 'double'). + wxyz: Quaternion rotation to parent frame from local frame (R_pl). + position: Translation from parent frame to local frame (t_pl). + visible: Whether or not this mesh is initially visible. + + Returns: + Handle for manipulating scene node. + """ + if wireframe and material != "standard": + warnings.warn( + f"Invalid combination of {wireframe=} and {material=}. Material argument will be ignored.", + stacklevel=2, + ) + if wireframe and flat_shading: + warnings.warn( + f"Invalid combination of {wireframe=} and {flat_shading=}. Flat shading argument will be ignored.", + stacklevel=2, + ) + + assert skin_weights.shape == (vertices.shape[0], len(bone_handles)) + + # Take the four biggest indices. + top4_skin_indices = onp.argsort(skin_weights, axis=-1)[:, -4:] + top4_skin_weights = skin_weights[ + onp.arange(vertices.shape[0])[:, None], top4_skin_indices + ] + assert ( + top4_skin_weights.shape == top4_skin_indices.shape == (vertices.shape[0], 4) + ) + self._queue( + _messages.MeshMessage( + name, + vertices.astype(onp.float32), + faces.astype(onp.uint32), + # (255, 255, 255) => 0xffffff, etc + color=_encode_rgb(color), + vertex_colors=None, + wireframe=wireframe, + opacity=opacity, + flat_shading=flat_shading, + side=side, + material=material, + bone_wxyzs=bone_wxyzs.astype(onp.float32), + bone_positions=bone_positions.astype(onp.float32), + skin_indices=top4_skin_indices.astype(onp.uint16), + skin_weights=top4_skin_weights.astype(onp.float32), + ) + ) + return MeshHandle._make(self, name, wxyz, position, visible) + def add_mesh_simple( self, name: str, @@ -861,6 +947,9 @@ def add_mesh_simple( flat_shading=flat_shading, side=side, material=material, + bone_names=None, + skin_indices=None, + skin_weights=None, ) ) return MeshHandle._make(self, name, wxyz, position, visible) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index c81dfd175..eabac1744 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -223,6 +223,13 @@ def __post_init__(self): assert self.colors.dtype == onp.uint8 +@dataclasses.dataclass +class MeshBoneMessage(Message): + """Message for a bone of a skinned mesh.""" + + name: str + + @dataclasses.dataclass class MeshMessage(Message): """Mesh message. @@ -248,6 +255,31 @@ def __post_init__(self): assert self.faces.shape[-1] == 3 +@dataclasses.dataclass +class SkinnedMeshMessage(MeshMessage): + """Mesh message. + + Vertices are internally canonicalized to float32, faces to uint32.""" + + bone_wxyzs: onpt.NDArray[onp.float32] + bone_positions: onpt.NDArray[onp.float32] + skin_indices: onpt.NDArray[onp.uint32] + skin_weights: onpt.NDArray[onp.float32] + + def __post_init__(self): + # Check shapes. + assert self.vertices.shape[-1] == 3 + assert self.faces.shape[-1] == 3 + assert self.skin_weights is not None + assert ( + self.skin_indices.shape + == self.skin_weights.shape + == (self.vertices.shape[0], 4) + ) + assert self.bone_wxyzs.shape[-1] == 4 + assert self.bone_positions.shape[-1] == 3 + + @dataclasses.dataclass class TransformControlsMessage(Message): """Message for transform gizmos.""" diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 699473ed0..0552a1275 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -54,6 +54,8 @@ export type ViewerContextContents = { useSceneTree: UseSceneTree; useGui: UseGui; // Useful references. + // TODO: there's really no reason these all need to be their own ref objects. + // We could have just one ref to a global mutable struct. websocketRef: React.MutableRefObject; canvasRef: React.MutableRefObject; sceneRef: React.MutableRefObject; @@ -75,6 +77,9 @@ export type ViewerContextContents = { overrideVisibility?: boolean; // Override from the GUI. }; }>; + nodeRefFromName: React.MutableRefObject<{ + [name: string]: undefined | THREE.Object3D; + }>; messageQueueRef: React.MutableRefObject; // Requested a render. getRenderRequestState: React.MutableRefObject< @@ -137,6 +142,7 @@ function ViewerRoot() { })(), }, }), + nodeRefFromName: React.useRef({}), messageQueueRef: React.useRef([]), getRenderRequestState: React.useRef("ready"), getRenderRequest: React.useRef(null), diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index 4f20aa818..5ac47efbb 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -32,7 +32,8 @@ export class SceneNode { * * https://github.com/pmndrs/drei/issues/1323 */ - public readonly unmountWhenInvisible?: true, + public readonly unmountWhenInvisible?: boolean, + public readonly everyFrameCallback?: () => void, ) { this.children = []; this.clickable = false; @@ -136,17 +137,20 @@ export function SceneNodeThreeObject(props: { const unmountWhenInvisible = viewer.useSceneTree( (state) => state.nodeFromName[props.name]?.unmountWhenInvisible, ); + const everyFrameCallback = viewer.useSceneTree( + (state) => state.nodeFromName[props.name]?.everyFrameCallback, + ); const [unmount, setUnmount] = React.useState(false); const clickable = viewer.useSceneTree((state) => state.nodeFromName[props.name]?.clickable) ?? false; const [obj, setRef] = React.useState(null); - const dragInfo = React.useRef({ - dragging: false, - startClientX: 0, - startClientY: 0, - }); + // Update global registry of node objects. + // This is used for updating bone transforms in skinned meshes. + React.useEffect(() => { + if (obj !== null) viewer.nodeRefFromName.current[props.name] = obj; + }, [obj]); // Create object + children. // @@ -206,6 +210,7 @@ export function SceneNodeThreeObject(props: { // although this shouldn't be a bottleneck. useFrame(() => { const attrs = viewer.nodeAttributesFromName.current[props.name]; + everyFrameCallback && everyFrameCallback(); // Unmount when invisible. // Examples: components, PivotControls. @@ -252,7 +257,12 @@ export function SceneNodeThreeObject(props: { }); // Clean up when done. - React.useEffect(() => cleanup); + React.useEffect(() => { + return () => { + cleanup && cleanup(); + delete viewer.nodeRefFromName.current[props.name]; + }; + }); // Clicking logic. const sendClicksThrottled = makeThrottledMessageSender( @@ -264,6 +274,12 @@ export function SceneNodeThreeObject(props: { const hoveredRef = React.useRef(false); if (!clickable && hovered) setHovered(false); + const dragInfo = React.useRef({ + dragging: false, + startClientX: 0, + startClientY: 0, + }); + if (objNode === undefined || unmount) { return <>{children}; } else if (clickable) { @@ -327,7 +343,6 @@ export function SceneNodeThreeObject(props: { }); }} onPointerOver={(e) => { - console.log("over"); if (!isDisplayed()) return; e.stopPropagation(); setHovered(true); diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index c6840e2aa..403ee2278 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -30,7 +30,7 @@ import { sendWebsocketMessage, } from "./WebsocketFunctions"; import { isGuiConfig } from "./ControlPanel/GuiState"; -import { useFrame } from "@react-three/fiber"; +import { useFrame, useThree } from "@react-three/fiber"; import GeneratedGuiContainer from "./ControlPanel/Generated"; import { MantineProvider, Paper, Progress } from "@mantine/core"; import { IconCheck } from "@tabler/icons-react"; @@ -71,6 +71,8 @@ function useMessageHandler() { const setClickable = viewer.useSceneTree((state) => state.setClickable); const updateUploadState = viewer.useGui((state) => state.updateUploadState); + const scene = useThree((state) => state.scene); + // Same as addSceneNode, but make a parent in the form of a dummy coordinate // frame if it doesn't exist yet. function addSceneNodeMakeParents(node: SceneNode) { @@ -85,10 +87,10 @@ function useMessageHandler() { // Make sure parents exists. const nodeFromName = viewer.useSceneTree.getState().nodeFromName; - const parent_name = node.name.split("/").slice(0, -1).join("/"); - if (!(parent_name in nodeFromName)) { + const parentName = node.name.split("/").slice(0, -1).join("/"); + if (!(parentName in nodeFromName)) { addSceneNodeMakeParents( - new SceneNode(parent_name, (ref) => ( + new SceneNode(parentName, (ref) => ( )), ); @@ -363,25 +365,138 @@ function useMessageHandler() { ); geometry.computeVertexNormals(); geometry.computeBoundingSphere(); - addSceneNodeMakeParents( - new SceneNode( - message.name, - (ref) => { - return ( - - - + const cleanupMesh = () => { + // TODO: we can switch to the react-three-fiber , + // , etc components to avoid manual + // disposal. + geometry.dispose(); + material.dispose(); + }; + if (message.skin_indices === null) + // Normal mesh. + addSceneNodeMakeParents( + new SceneNode( + message.name, + (ref) => { + return ( + + + + ); + }, + cleanupMesh, + ), + ); + else { + const getT_world_local: (name: string) => THREE.Matrix4 = ( + name: string, + ) => { + const T_current_local = new THREE.Matrix4().identity(); + const T_parent_current = new THREE.Matrix4().identity(); + let done = false; + while (!done) { + const attrs = viewer.nodeAttributesFromName.current[name]; + let wxyz = attrs?.wxyz; + if (wxyz === undefined) wxyz = [1, 0, 0, 0]; + T_parent_current.makeRotationFromQuaternion( + new THREE.Quaternion(wxyz[1], wxyz[2], wxyz[3], wxyz[0]), ); - }, - () => { - // TODO: we can switch to the react-three-fiber , - // , etc components to avoid manual - // disposal. - geometry.dispose(); - material.dispose(); - }, - ), - ); + let position = attrs?.position; + if (position === undefined) position = [0, 0, 0]; + T_parent_current.setPosition( + new THREE.Vector3(position[0], position[1], position[2]), + ); + + T_current_local.premultiply(T_parent_current); + if (name === "") break; + name = name.split("/").slice(0, -1).join("/"); + console.log(name); + } + return T_current_local; + }; + // Skinned mesh. + const bones: THREE.Bone[] = []; + for (let i = 0; i < message.bone_names!.length; i++) { + bones.push(new THREE.Bone()); + } + bones.forEach((bone, i) => { + scene.add(bone); + bone.matrix.copy(getT_world_local(message.bone_names![i])); + bone.matrixWorld.copy(getT_world_local(message.bone_names![i])); + // We'll manage the bone matrices manually. + bone.matrixAutoUpdate = false; + bone.matrixWorldAutoUpdate = false; + }); + const skeleton = new THREE.Skeleton(bones); + + geometry.setAttribute( + "skinIndex", + new THREE.Uint16BufferAttribute( + new Uint16Array( + message.skin_indices.buffer.slice( + message.skin_indices.byteOffset, + message.skin_indices.byteOffset + + message.skin_indices.byteLength, + ), + ), + 4, + ), + ); + geometry.setAttribute( + "skinWeight", + new THREE.Float32BufferAttribute( + new Float32Array( + message.skin_weights!.buffer.slice( + message.skin_weights!.byteOffset, + message.skin_weights!.byteOffset + + message.skin_weights!.byteLength, + ), + ), + 4, + ), + ); + addSceneNodeMakeParents( + new SceneNode( + message.name, + (ref) => { + return ( + + + + ); + }, + () => { + bones.forEach((bone) => { + bone.remove(); + }); + skeleton.dispose(); + cleanupMesh(); + }, + false, + // everyFrameCallback: update bone transforms. + () => { + bones.forEach((bone, i) => { + const nodeRef = + viewer.nodeRefFromName.current[message.bone_names![i]]; + if (nodeRef !== undefined) { + // Our bone objects are placed in the scene root! + // bone.matrix.copy(nodeRef?.matrixWorld); + // bone.matrixWorld.copy(nodeRef?.matrixWorld); + bone.matrix.copy(getT_world_local(message.bone_names![i])); + bone.matrixWorld.copy( + getT_world_local(message.bone_names![i]), + ); + } + }); + }, + ), + ); + } return; } // Add a camera frustum. diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index dbc670457..6a954fac6 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -145,6 +145,14 @@ export interface PointCloudMessage { point_size: number; point_ball_norm: number; } +/** Message for a bone of a skinned mesh. + * + * (automatically generated) + */ +export interface MeshBoneMessage { + type: "MeshBoneMessage"; + name: string; +} /** Mesh message. * * Vertices are internally canonicalized to float32, faces to uint32. @@ -163,6 +171,9 @@ export interface MeshMessage { flat_shading: boolean; side: "front" | "back" | "double"; material: "standard" | "toon3" | "toon5"; + bone_names: string[] | null; + skin_indices: Uint8Array | null; + skin_weights: Uint8Array | null; } /** Message for transform gizmos. * @@ -826,6 +837,7 @@ export type Message = | LabelMessage | Gui3DMessage | PointCloudMessage + | MeshBoneMessage | MeshMessage | TransformControlsMessage | SetCameraPositionMessage From d3a78d2d2cac43778e173013fc976fdfb23ab17d Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 24 Apr 2024 23:30:20 -0700 Subject: [PATCH 2/7] Changes --- src/viser/_message_api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index 438286604..2c70974cb 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -817,7 +817,7 @@ def add_mesh_skinned( position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: - """Add a mesh to the scene. + """Add a skinned mesh to the scene, which we can deform using a set of bone transformations. Args: name: A scene tree name. Names in the format of /parent/child can be used to @@ -854,7 +854,7 @@ def add_mesh_skinned( stacklevel=2, ) - assert skin_weights.shape == (vertices.shape[0], len(bone_handles)) + assert skin_weights.shape == (vertices.shape[0], len(bone_wxyzs)) # Take the four biggest indices. top4_skin_indices = onp.argsort(skin_weights, axis=-1)[:, -4:] @@ -864,8 +864,11 @@ def add_mesh_skinned( assert ( top4_skin_weights.shape == top4_skin_indices.shape == (vertices.shape[0], 4) ) + + bone_wxyzs = onp.asarray(bone_wxyzs) + bone_positions = onp.asarray(bone_positions) self._queue( - _messages.MeshMessage( + _messages.SkinnedMeshMessage( name, vertices.astype(onp.float32), faces.astype(onp.uint32), @@ -947,9 +950,6 @@ def add_mesh_simple( flat_shading=flat_shading, side=side, material=material, - bone_names=None, - skin_indices=None, - skin_weights=None, ) ) return MeshHandle._make(self, name, wxyz, position, visible) From 1932561122f31ee6d239e7f3759a031df7614eb4 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 28 Apr 2024 19:55:52 -0700 Subject: [PATCH 3/7] Sync --- src/viser/_message_api.py | 14 +++++++---- src/viser/_scene_handles.py | 5 ++++ src/viser/client/src/WebsocketInterface.tsx | 5 ++-- src/viser/client/src/WebsocketMessages.tsx | 27 ++++++++++++++++++--- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index 2c70974cb..52710f61e 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -51,6 +51,7 @@ SceneNodeHandle, SceneNodePointerEvent, ScenePointerEvent, + SkinnedMeshHandle, TransformControlsHandle, _TransformControlsState, ) @@ -816,8 +817,9 @@ def add_mesh_skinned( wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, - ) -> MeshHandle: - """Add a skinned mesh to the scene, which we can deform using a set of bone transformations. + ) -> SkinnedMeshHandle: + """Add a skinned mesh to the scene, which we can deform using a set of + bone transformations. Args: name: A scene tree name. Names in the format of /parent/child can be used to @@ -825,9 +827,11 @@ def add_mesh_skinned( vertices: A numpy array of vertex positions. Should have shape (V, 3). faces: A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F,) - bone_handles: Tuple of scene node handles. A bone will be attached to each. + bone_wxyzs: Nested tuple or array of initial bone orientations. + bone_positions: Nested tuple or array of initial bone positions. skin_weights: A numpy array of skin weights. Should have shape (V, B) where B - is the number of bones. + is the number of bones. Only the top 4 bone weights for each + vertex will be used. color: Color of the mesh as an RGB tuple. wireframe: Boolean indicating if the mesh should be rendered as a wireframe. opacity: Opacity of the mesh. None means opaque. @@ -886,7 +890,7 @@ def add_mesh_skinned( skin_weights=top4_skin_weights.astype(onp.float32), ) ) - return MeshHandle._make(self, name, wxyz, position, visible) + return SkinnedMeshHandle._make(self, name, wxyz, position, visible) def add_mesh_simple( self, diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index a2022d818..c557d29fd 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -214,6 +214,11 @@ class MeshHandle(_ClickableSceneNodeHandle): """Handle for mesh objects.""" +@dataclasses.dataclass +class SkinnedMeshHandle(_ClickableSceneNodeHandle): + """Handle for skinned mesh objects.""" + + @dataclasses.dataclass class GlbHandle(_ClickableSceneNodeHandle): """Handle for GLB objects.""" diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 1d8c4dba3..d3f4277ba 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -281,6 +281,7 @@ function useMessageHandler() { } // Add mesh + case "SkinnedMeshMessage": case "MeshMessage": { const geometry = new THREE.BufferGeometry(); @@ -372,7 +373,7 @@ function useMessageHandler() { geometry.dispose(); material.dispose(); }; - if (message.skin_indices === null) + if (message.type === "MeshMessage") // Normal mesh. addSceneNodeMakeParents( new SceneNode( @@ -387,7 +388,7 @@ function useMessageHandler() { cleanupMesh, ), ); - else { + else if (message.type === "SkinMeshMessage") { const getT_world_local: (name: string) => THREE.Matrix4 = ( name: string, ) => { diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index 6a954fac6..275657926 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -171,9 +171,29 @@ export interface MeshMessage { flat_shading: boolean; side: "front" | "back" | "double"; material: "standard" | "toon3" | "toon5"; - bone_names: string[] | null; - skin_indices: Uint8Array | null; - skin_weights: Uint8Array | null; +} +/** Mesh message. + * + * Vertices are internally canonicalized to float32, faces to uint32. + * + * (automatically generated) + */ +export interface SkinnedMeshMessage { + type: "SkinnedMeshMessage"; + name: string; + vertices: Uint8Array; + faces: Uint8Array; + color: number | null; + vertex_colors: Uint8Array | null; + wireframe: boolean; + opacity: number | null; + flat_shading: boolean; + side: "front" | "back" | "double"; + material: "standard" | "toon3" | "toon5"; + bone_wxyzs: Uint8Array; + bone_positions: Uint8Array; + skin_indices: Uint8Array; + skin_weights: Uint8Array; } /** Message for transform gizmos. * @@ -839,6 +859,7 @@ export type Message = | PointCloudMessage | MeshBoneMessage | MeshMessage + | SkinnedMeshMessage | TransformControlsMessage | SetCameraPositionMessage | SetCameraUpDirectionMessage From 6ae8e0c914427ea29cadae276b3e0a12fd56e163 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Mon, 29 Apr 2024 01:16:32 -0700 Subject: [PATCH 4/7] working again --- src/viser/_message_api.py | 39 ++++- src/viser/_messages.py | 36 ++++- src/viser/_scene_handles.py | 58 ++++++++ src/viser/client/src/App.tsx | 11 ++ src/viser/client/src/WebsocketInterface.tsx | 152 ++++++++++++-------- src/viser/client/src/WebsocketMessages.tsx | 30 +++- 6 files changed, 253 insertions(+), 73 deletions(-) diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index 52710f61e..31c5faea9 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -40,6 +40,8 @@ from . import transforms as tf from ._scene_handles import ( BatchedAxesHandle, + BoneHandle, + BoneState, CameraFrustumHandle, FrameHandle, GlbHandle, @@ -858,7 +860,8 @@ def add_mesh_skinned( stacklevel=2, ) - assert skin_weights.shape == (vertices.shape[0], len(bone_wxyzs)) + num_bones = len(bone_wxyzs) + assert skin_weights.shape == (vertices.shape[0], num_bones) # Take the four biggest indices. top4_skin_indices = onp.argsort(skin_weights, axis=-1)[:, -4:] @@ -871,6 +874,8 @@ def add_mesh_skinned( bone_wxyzs = onp.asarray(bone_wxyzs) bone_positions = onp.asarray(bone_positions) + assert bone_wxyzs.shape == (num_bones, 4) + assert bone_positions.shape == (num_bones, 3) self._queue( _messages.SkinnedMeshMessage( name, @@ -884,13 +889,39 @@ def add_mesh_skinned( flat_shading=flat_shading, side=side, material=material, - bone_wxyzs=bone_wxyzs.astype(onp.float32), - bone_positions=bone_positions.astype(onp.float32), + bone_wxyzs=tuple( + ( + float(wxyz[0]), + float(wxyz[1]), + float(wxyz[2]), + float(wxyz[3]), + ) + for wxyz in bone_wxyzs.astype(onp.float32) + ), + bone_positions=tuple( + (float(xyz[0]), float(xyz[1]), float(xyz[2])) + for xyz in bone_positions.astype(onp.float32) + ), skin_indices=top4_skin_indices.astype(onp.uint16), skin_weights=top4_skin_weights.astype(onp.float32), ) ) - return SkinnedMeshHandle._make(self, name, wxyz, position, visible) + handle = MeshHandle._make(self, name, wxyz, position, visible) + return SkinnedMeshHandle( + handle._impl, + bones=tuple( + BoneHandle( + _impl=BoneState( + name=name, + api=self, + bone_index=i, + wxyz=bone_wxyzs[i], + position=bone_positions[i], + ) + ) + for i in range(num_bones) + ), + ) def add_mesh_simple( self, diff --git a/src/viser/_messages.py b/src/viser/_messages.py index eabac1744..efc17f19b 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -261,8 +261,8 @@ class SkinnedMeshMessage(MeshMessage): Vertices are internally canonicalized to float32, faces to uint32.""" - bone_wxyzs: onpt.NDArray[onp.float32] - bone_positions: onpt.NDArray[onp.float32] + bone_wxyzs: Tuple[Tuple[float, float, float, float], ...] + bone_positions: Tuple[Tuple[float, float, float], ...] skin_indices: onpt.NDArray[onp.uint32] skin_weights: onpt.NDArray[onp.float32] @@ -276,8 +276,36 @@ def __post_init__(self): == self.skin_weights.shape == (self.vertices.shape[0], 4) ) - assert self.bone_wxyzs.shape[-1] == 4 - assert self.bone_positions.shape[-1] == 3 + + +@dataclasses.dataclass +class SetBoneOrientationMessage(Message): + """Server -> client message to set a skinned mesh bone's orientation. + + As with all other messages, transforms take the `T_parent_local` convention.""" + + name: str + bone_index: int + wxyz: Tuple[float, float, float, float] + + @override + def redundancy_key(self) -> str: + return type(self).__name__ + "-" + self.name + "-" + str(self.bone_index) + + +@dataclasses.dataclass +class SetBonePositionMessage(Message): + """Server -> client message to set a skinned mesh bone's position. + + As with all other messages, transforms take the `T_parent_local` convention.""" + + name: str + bone_index: int + position: Tuple[float, float, float] + + @override + def redundancy_key(self) -> str: + return type(self).__name__ + "-" + self.name + "-" + str(self.bone_index) @dataclasses.dataclass diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index c557d29fd..bcb25776b 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -218,6 +218,64 @@ class MeshHandle(_ClickableSceneNodeHandle): class SkinnedMeshHandle(_ClickableSceneNodeHandle): """Handle for skinned mesh objects.""" + bones: Tuple[BoneHandle, ...] + """Bones of the skinned mesh. These handles can be used for reading and + writing poses, which are defined relative to the mesh root.""" + + +@dataclasses.dataclass +class BoneState: + name: str + api: MessageApi + bone_index: int + wxyz: onp.ndarray + position: onp.ndarray + + +@dataclasses.dataclass +class BoneHandle: + """Handle for reading and writing the poses of bones in a skinned mesh.""" + + _impl: BoneState + + @property + def wxyz(self) -> onp.ndarray: + """Orientation of the bone. This is the quaternion representation of the R + in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. + """ + return self._impl.wxyz + + @wxyz.setter + def wxyz(self, wxyz: Tuple[float, float, float, float] | onp.ndarray) -> None: + from ._message_api import cast_vector + + wxyz_cast = cast_vector(wxyz, 4) + self._impl.wxyz = onp.asarray(wxyz) + self._impl.api._queue( + _messages.SetBoneOrientationMessage( + self._impl.name, self._impl.bone_index, wxyz_cast + ) + ) + + @property + def position(self) -> onp.ndarray: + """Position of the bone. This is equivalent to the t in + `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. + """ + return self._impl.position + + @position.setter + def position(self, position: Tuple[float, float, float] | onp.ndarray) -> None: + from ._message_api import cast_vector + + position_cast = cast_vector(position, 3) + self._impl.position = onp.asarray(position) + self._impl.api._queue( + _messages.SetBonePositionMessage( + self._impl.name, self._impl.bone_index, position_cast + ) + ) + @dataclasses.dataclass class GlbHandle(_ClickableSceneNodeHandle): diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 0552a1275..bbe185866 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -94,6 +94,16 @@ export type ViewerContextContents = { }>; // 2D canvas for drawing -- can be used to give feedback on cursor movement, or more. canvas2dRef: React.MutableRefObject; + // Poses for bones in skinned meshes. + skinnedMeshState: React.MutableRefObject<{ + [name: string]: { + initialized: boolean; + poses: { + wxyz: [number, number, number, number]; + position: [number, number, number]; + }[]; + }; + }>; }; export const ViewerContext = React.createContext( null, @@ -152,6 +162,7 @@ function ViewerRoot() { listening: false, }), canvas2dRef: React.useRef(null), + skinnedMeshState: React.useRef({}), }; return ( diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index d3f4277ba..f59712327 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -227,16 +227,20 @@ function useMessageHandler() { message.plane == "xz" ? new THREE.Euler(0.0, 0.0, 0.0) : message.plane == "xy" - ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) - : message.plane == "yx" - ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) - : message.plane == "yz" - ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) - : message.plane == "zx" - ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) - : message.plane == "zy" - ? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0) - : undefined + ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) + : message.plane == "yx" + ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) + : message.plane == "yz" + ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) + : message.plane == "zx" + ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) + : message.plane == "zy" + ? new THREE.Euler( + -Math.PI / 2.0, + 0.0, + -Math.PI / 2.0, + ) + : undefined } /> @@ -324,16 +328,16 @@ function useMessageHandler() { message.material == "standard" || message.wireframe ? new THREE.MeshStandardMaterial(standardArgs) : message.material == "toon3" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(3), - ...standardArgs, - }) - : message.material == "toon5" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(5), - ...standardArgs, - }) - : assertUnreachable(message.material); + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(3), + ...standardArgs, + }) + : message.material == "toon5" + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(5), + ...standardArgs, + }) + : assertUnreachable(message.material); geometry.setAttribute( "position", new THREE.Float32BufferAttribute( @@ -388,47 +392,41 @@ function useMessageHandler() { cleanupMesh, ), ); - else if (message.type === "SkinMeshMessage") { - const getT_world_local: (name: string) => THREE.Matrix4 = ( - name: string, - ) => { - const T_current_local = new THREE.Matrix4().identity(); - const T_parent_current = new THREE.Matrix4().identity(); - let done = false; - while (!done) { - const attrs = viewer.nodeAttributesFromName.current[name]; - let wxyz = attrs?.wxyz; - if (wxyz === undefined) wxyz = [1, 0, 0, 0]; - T_parent_current.makeRotationFromQuaternion( - new THREE.Quaternion(wxyz[1], wxyz[2], wxyz[3], wxyz[0]), - ); - let position = attrs?.position; - if (position === undefined) position = [0, 0, 0]; - T_parent_current.setPosition( - new THREE.Vector3(position[0], position[1], position[2]), - ); - - T_current_local.premultiply(T_parent_current); - if (name === "") break; - name = name.split("/").slice(0, -1).join("/"); - console.log(name); - } - return T_current_local; - }; + else if (message.type === "SkinnedMeshMessage") { // Skinned mesh. const bones: THREE.Bone[] = []; - for (let i = 0; i < message.bone_names!.length; i++) { + for (let i = 0; i < message.bone_wxyzs!.length; i++) { bones.push(new THREE.Bone()); } + + const xyzw_quat = new THREE.Quaternion(); + const boneInverses: THREE.Matrix4[] = []; + viewer.skinnedMeshState.current[message.name] = { + initialized: false, + poses: [], + }; bones.forEach((bone, i) => { - scene.add(bone); - bone.matrix.copy(getT_world_local(message.bone_names![i])); - bone.matrixWorld.copy(getT_world_local(message.bone_names![i])); - // We'll manage the bone matrices manually. + const wxyz = message.bone_wxyzs[i]; + const position = message.bone_positions[i]; + xyzw_quat.set(wxyz[1], wxyz[2], wxyz[3], wxyz[0]); + + const boneInverse = new THREE.Matrix4(); + boneInverse.makeRotationFromQuaternion(xyzw_quat); + boneInverse.setPosition(position[0], position[1], position[2]); + boneInverse.invert(); + boneInverses.push(boneInverse); + + bone.quaternion.copy(xyzw_quat); + bone.position.set(position[0], position[1], position[2]); bone.matrixAutoUpdate = false; bone.matrixWorldAutoUpdate = false; + + viewer.skinnedMeshState.current[message.name].poses.push({ + wxyz: wxyz, + position: position, + }); }); - const skeleton = new THREE.Skeleton(bones); + const skeleton = new THREE.Skeleton(bones, boneInverses); geometry.setAttribute( "skinIndex", @@ -456,6 +454,7 @@ function useMessageHandler() { 4, ), ); + addSceneNodeMakeParents( new SceneNode( message.name, @@ -481,25 +480,52 @@ function useMessageHandler() { false, // everyFrameCallback: update bone transforms. () => { + const parentNode = viewer.nodeRefFromName.current[message.name]; + if (parentNode === undefined) return; + + const state = viewer.skinnedMeshState.current[message.name]; bones.forEach((bone, i) => { - const nodeRef = - viewer.nodeRefFromName.current[message.bone_names![i]]; - if (nodeRef !== undefined) { - // Our bone objects are placed in the scene root! - // bone.matrix.copy(nodeRef?.matrixWorld); - // bone.matrixWorld.copy(nodeRef?.matrixWorld); - bone.matrix.copy(getT_world_local(message.bone_names![i])); - bone.matrixWorld.copy( - getT_world_local(message.bone_names![i]), - ); + if (!state.initialized) { + parentNode.add(bone); } + const wxyz = state.initialized + ? state.poses[i].wxyz + : message.bone_wxyzs[i]; + const position = state.initialized + ? state.poses[i].position + : message.bone_positions[i]; + + xyzw_quat.set(wxyz[1], wxyz[2], wxyz[3], wxyz[0]); + bone.matrix.makeRotationFromQuaternion(xyzw_quat); + bone.matrix.setPosition( + position[0], + position[1], + position[2], + ); + bone.updateMatrixWorld(); }); + + if (!state.initialized) { + state.initialized = true; + } }, ), ); } return; } + // Set the bone poses. + case "SetBoneOrientationMessage": { + const bonePoses = viewer.skinnedMeshState.current; + bonePoses[message.name].poses[message.bone_index].wxyz = message.wxyz; + break; + } + case "SetBonePositionMessage": { + const bonePoses = viewer.skinnedMeshState.current; + bonePoses[message.name].poses[message.bone_index].position = + message.position; + break; + } // Add a camera frustum. case "CameraFrustumMessage": { const texture = diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index 275657926..e6b1b9e00 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -190,11 +190,35 @@ export interface SkinnedMeshMessage { flat_shading: boolean; side: "front" | "back" | "double"; material: "standard" | "toon3" | "toon5"; - bone_wxyzs: Uint8Array; - bone_positions: Uint8Array; + bone_wxyzs: [number, number, number, number][]; + bone_positions: [number, number, number][]; skin_indices: Uint8Array; skin_weights: Uint8Array; } +/** Server -> client message to set a skinned mesh bone's orientation. + * + * As with all other messages, transforms take the `T_parent_local` convention. + * + * (automatically generated) + */ +export interface SetBoneOrientationMessage { + type: "SetBoneOrientationMessage"; + name: string; + bone_index: number; + wxyz: [number, number, number, number]; +} +/** Server -> client message to set a skinned mesh bone's position. + * + * As with all other messages, transforms take the `T_parent_local` convention. + * + * (automatically generated) + */ +export interface SetBonePositionMessage { + type: "SetBonePositionMessage"; + name: string; + bone_index: number; + position: [number, number, number]; +} /** Message for transform gizmos. * * (automatically generated) @@ -860,6 +884,8 @@ export type Message = | MeshBoneMessage | MeshMessage | SkinnedMeshMessage + | SetBoneOrientationMessage + | SetBonePositionMessage | TransformControlsMessage | SetCameraPositionMessage | SetCameraUpDirectionMessage From d7809fc51d3755e7520d636913effbde71ccab58 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 21 May 2024 11:46:26 +0100 Subject: [PATCH 5/7] Add SMPL skinned example --- examples/23_smpl_visualizer_skinned.py | 288 +++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100644 examples/23_smpl_visualizer_skinned.py diff --git a/examples/23_smpl_visualizer_skinned.py b/examples/23_smpl_visualizer_skinned.py new file mode 100644 index 000000000..b5c452d42 --- /dev/null +++ b/examples/23_smpl_visualizer_skinned.py @@ -0,0 +1,288 @@ +# mypy: disable-error-code="assignment" +# +# Asymmetric properties are supported in Pyright, but not yet in mypy. +# - https://github.com/python/mypy/issues/3004 +# - https://github.com/python/mypy/pull/11643 +"""Visualizer for SMPL human body models. Requires a .npz model file. + +See here for download instructions: + https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import numpy as onp +import tyro +import viser +import viser.transforms as tf + + +@dataclass(frozen=True) +class SmplOutputs: + vertices: np.ndarray + faces: np.ndarray + T_world_joint: np.ndarray # (num_joints, 4, 4) + T_parent_joint: np.ndarray # (num_joints, 4, 4) + + +class SmplHelper: + """Helper for models in the SMPL family, implemented in numpy.""" + + def __init__(self, model_path: Path) -> None: + assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" + body_dict = dict(**onp.load(model_path, allow_pickle=True)) + + self._J_regressor = body_dict["J_regressor"] + self._weights = body_dict["weights"] + self._v_template = body_dict["v_template"] + self._posedirs = body_dict["posedirs"] + self._shapedirs = body_dict["shapedirs"] + self._faces = body_dict["f"] + + self.num_joints: int = self._weights.shape[-1] + self.num_betas: int = self._shapedirs.shape[-1] + self.parent_idx: np.ndarray = body_dict["kintree_table"][0] + + def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: + # Get shaped vertices + joint positions, when all local poses are identity. + v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) + j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) + + # Local SE(3) transforms. + T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) + T_parent_joint[:, :3, :3] = joint_rotmats + T_parent_joint[0, :3, 3] = j_tpose[0] + T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] + + # Forward kinematics. + T_world_joint = T_parent_joint.copy() + for i in range(1, self.num_joints): + T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] + + # Linear blend skinning. + pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() + v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) + v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) + v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] + v_posed = np.einsum( + "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta + ) + return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) + + +def main(model_path: Path) -> None: + server = viser.ViserServer() + server.set_up_direction("+y") + server.configure_theme(control_layout="collapsible") + + # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, + # and then send the updated mesh in a loop. + model = SmplHelper(model_path) + gui_elements = make_gui_elements( + server, + num_betas=model.num_betas, + num_joints=model.num_joints, + parent_idx=model.parent_idx, + ) + smpl_outputs = model.get_outputs( + betas=np.array([x.value for x in gui_elements.gui_betas]), + joint_rotmats=onp.zeros((model.num_joints, 3, 3)) + onp.eye(3), + ) + + bone_wxyzs = np.array( + [tf.SO3.from_matrix(R).wxyz for R in smpl_outputs.T_world_joint[:, :3, :3]] + ) + bone_positions = smpl_outputs.T_world_joint[:, :3, 3] + + server.add_transform_controls("/root") + skinned_handle = server.add_mesh_skinned( + "/root/human", + smpl_outputs.vertices, + smpl_outputs.faces, + bone_wxyzs=bone_wxyzs, + bone_positions=bone_positions, + skin_weights=model._weights, + wireframe=gui_elements.gui_wireframe.value, + color=gui_elements.gui_rgb.value, + ) + + while True: + # Do nothing if no change. + time.sleep(0.02) + if not gui_elements.changed: + continue + + gui_elements.changed = False + + # Compute SMPL outputs. + smpl_outputs = model.get_outputs( + betas=np.array([x.value for x in gui_elements.gui_betas]), + joint_rotmats=np.stack( + [ + tf.SO3.exp(np.array(x.value)).as_matrix() + for x in gui_elements.gui_joints + ], + axis=0, + ), + ) + + # Match transform control gizmos to joint positions. + for i, control in enumerate(gui_elements.transform_controls): + control.position = smpl_outputs.T_parent_joint[i, :3, 3] + print(control.position) + + skinned_handle.bones[i].wxyz = tf.SO3.from_matrix( + smpl_outputs.T_world_joint[i, :3, :3] + ).wxyz + skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3] + + +@dataclass +class GuiElements: + """Structure containing handles for reading from GUI elements.""" + + gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] + gui_wireframe: viser.GuiInputHandle[bool] + gui_betas: List[viser.GuiInputHandle[float]] + gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] + transform_controls: List[viser.TransformControlsHandle] + + changed: bool + """This flag will be flipped to True whenever the mesh needs to be re-generated.""" + + +def make_gui_elements( + server: viser.ViserServer, + num_betas: int, + num_joints: int, + parent_idx: np.ndarray, +) -> GuiElements: + """Make GUI elements for interacting with the model.""" + + tab_group = server.add_gui_tab_group() + + def set_changed(_) -> None: + out.changed = True # out is define later! + + # GUI elements: mesh settings + visibility. + with tab_group.add_tab("View", viser.Icon.VIEWFINDER): + gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255)) + gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False) + gui_show_controls = server.add_gui_checkbox("Handles", initial_value=True) + + gui_rgb.on_update(set_changed) + gui_wireframe.on_update(set_changed) + + @gui_show_controls.on_update + def _(_): + for control in transform_controls: + control.visible = gui_show_controls.value + + # GUI elements: shape parameters. + with tab_group.add_tab("Shape", viser.Icon.BOX): + gui_reset_shape = server.add_gui_button("Reset Shape") + gui_random_shape = server.add_gui_button("Random Shape") + + @gui_reset_shape.on_click + def _(_): + for beta in gui_betas: + beta.value = 0.0 + + @gui_random_shape.on_click + def _(_): + for beta in gui_betas: + beta.value = onp.random.normal(loc=0.0, scale=1.0) + + gui_betas = [] + for i in range(num_betas): + beta = server.add_gui_slider( + f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 + ) + gui_betas.append(beta) + beta.on_update(set_changed) + + # GUI elements: joint angles. + with tab_group.add_tab("Joints", viser.Icon.ANGLE): + gui_reset_joints = server.add_gui_button("Reset Joints") + gui_random_joints = server.add_gui_button("Random Joints") + + @gui_reset_joints.on_click + def _(_): + for joint in gui_joints: + joint.value = (0.0, 0.0, 0.0) + + @gui_random_joints.on_click + def _(_): + for joint in gui_joints: + # It's hard to uniformly sample orientations directly in so(3), so we + # first sample on S^3 and then convert. + quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) + quat /= onp.linalg.norm(quat) + joint.value = tf.SO3(wxyz=quat).log() + + gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] + for i in range(num_joints): + gui_joint = server.add_gui_vector3( + label=f"Joint {i}", + initial_value=(0.0, 0.0, 0.0), + step=0.05, + ) + gui_joints.append(gui_joint) + + def set_callback_in_closure(i: int) -> None: + @gui_joint.on_update + def _(_): + transform_controls[i].wxyz = tf.SO3.exp( + np.array(gui_joints[i].value) + ).wxyz + out.changed = True + + set_callback_in_closure(i) + + # Transform control gizmos on joints. + transform_controls: List[viser.TransformControlsHandle] = [] + prefixed_joint_names = [] # Joint names, but prefixed with parents. + for i in range(num_joints): + prefixed_joint_name = f"joint_{i}" + if i > 0: + prefixed_joint_name = ( + prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name + ) + prefixed_joint_names.append(prefixed_joint_name) + controls = server.add_transform_controls( + f"/root/smpl/{prefixed_joint_name}", + depth_test=False, + scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), + disable_axes=True, + disable_sliders=True, + visible=gui_show_controls.value, + ) + transform_controls.append(controls) + + def set_callback_in_closure(i: int) -> None: + @controls.on_update + def _(_) -> None: + axisangle = tf.SO3(transform_controls[i].wxyz).log() + gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) + + set_callback_in_closure(i) + + out = GuiElements( + gui_rgb, + gui_wireframe, + gui_betas, + gui_joints, + transform_controls=transform_controls, + changed=True, + ) + return out + + +if __name__ == "__main__": + tyro.cli(main, description=__doc__) From 310e22d7a364539a0cfd43ee71644eb6b9f2b47c Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 9 Jul 2024 18:06:07 +0900 Subject: [PATCH 6/7] Formatting + eslint --- src/viser/_scene_api.py | 4 ++-- src/viser/_scene_handles.py | 6 +++--- src/viser/client/src/WebsocketInterface.tsx | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index f1d851cd2..724e740d9 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -16,7 +16,6 @@ from . import transforms as tf from ._scene_handles import ( BatchedAxesHandle, - MeshSkinnedBoneHandle, BoneState, CameraFrustumHandle, FrameHandle, @@ -25,11 +24,12 @@ ImageHandle, LabelHandle, MeshHandle, + MeshSkinnedBoneHandle, + MeshSkinnedHandle, PointCloudHandle, SceneNodeHandle, SceneNodePointerEvent, ScenePointerEvent, - MeshSkinnedHandle, TransformControlsHandle, _SceneNodeHandleState, _TransformControlsState, diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index 081457602..bc73ff1d4 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -56,9 +56,9 @@ class _SceneNodeHandleState: ) visible: bool = True # TODO: we should remove SceneNodeHandle as an argument here. - click_cb: list[ - Callable[[SceneNodePointerEvent[SceneNodeHandle]], None] - ] | None = None + click_cb: list[Callable[[SceneNodePointerEvent[SceneNodeHandle]], None]] | None = ( + None + ) @dataclasses.dataclass diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 3682d347f..06722e2ef 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -31,7 +31,7 @@ import { sendWebsocketMessage, } from "./WebsocketFunctions"; import { isGuiConfig } from "./ControlPanel/GuiState"; -import { useFrame, useThree } from "@react-three/fiber"; +import { useFrame } from "@react-three/fiber"; import GeneratedGuiContainer from "./ControlPanel/Generated"; import { Paper, Progress } from "@mantine/core"; import { IconCheck } from "@tabler/icons-react"; From af7d12c79331326a33fc41005fcecf3f8fea9629 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 9 Jul 2024 18:18:10 +0900 Subject: [PATCH 7/7] prettier --- src/viser/client/src/WebsocketInterface.tsx | 44 ++++++++++----------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 06722e2ef..cab41d54b 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -235,20 +235,16 @@ function useMessageHandler() { message.plane == "xz" ? new THREE.Euler(0.0, 0.0, 0.0) : message.plane == "xy" - ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) - : message.plane == "yx" - ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) - : message.plane == "yz" - ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) - : message.plane == "zx" - ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) - : message.plane == "zy" - ? new THREE.Euler( - -Math.PI / 2.0, - 0.0, - -Math.PI / 2.0, - ) - : undefined + ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) + : message.plane == "yx" + ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) + : message.plane == "yz" + ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) + : message.plane == "zx" + ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) + : message.plane == "zy" + ? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0) + : undefined } /> @@ -336,16 +332,16 @@ function useMessageHandler() { message.material == "standard" || message.wireframe ? new THREE.MeshStandardMaterial(standardArgs) : message.material == "toon3" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(3), - ...standardArgs, - }) - : message.material == "toon5" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(5), - ...standardArgs, - }) - : assertUnreachable(message.material); + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(3), + ...standardArgs, + }) + : message.material == "toon5" + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(5), + ...standardArgs, + }) + : assertUnreachable(message.material); geometry.setAttribute( "position", new THREE.Float32BufferAttribute(