diff --git a/examples/00_coordinate_frames.py b/examples/00_coordinate_frames.py index a24daef0c..2e7bb6b49 100644 --- a/examples/00_coordinate_frames.py +++ b/examples/00_coordinate_frames.py @@ -15,17 +15,17 @@ while True: # Add some coordinate frames to the scene. These will be visualized in the viewer. - server.add_frame( + server.scene.add_frame( "/tree", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) - server.add_frame( + server.scene.add_frame( "/tree/branch", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) - leaf = server.add_frame( + leaf = server.scene.add_frame( "/tree/branch/leaf", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), diff --git a/examples/01_image.py b/examples/01_image.py index d2e2f421a..05568b679 100644 --- a/examples/01_image.py +++ b/examples/01_image.py @@ -18,13 +18,13 @@ def main() -> None: server = viser.ViserServer() # Add a background image. - server.set_background_image( + server.scene.set_background_image( iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), format="png", ) # Add main image. - server.add_image( + server.scene.add_image( "/img", iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), 4.0, @@ -34,7 +34,7 @@ def main() -> None: position=(2.0, 2.0, 0.0), ) while True: - server.add_image( + server.scene.add_image( "/noise", onp.random.randint( 0, diff --git a/examples/02_gui.py b/examples/02_gui.py index 72976934c..5688ff92e 100644 --- a/examples/02_gui.py +++ b/examples/02_gui.py @@ -17,14 +17,14 @@ def main() -> None: server = viser.ViserServer() # Add some common GUI elements: number inputs, sliders, vectors, checkboxes. - with server.add_gui_folder("Read-only"): - gui_counter = server.add_gui_number( + with server.gui.add_folder("Read-only"): + gui_counter = server.gui.add_number( "Counter", initial_value=0, disabled=True, ) - gui_slider = server.add_gui_slider( + gui_slider = server.gui.add_slider( "Slider", min=0, max=100, @@ -33,36 +33,36 @@ def main() -> None: disabled=True, ) - with server.add_gui_folder("Editable"): - gui_vector2 = server.add_gui_vector2( + with server.gui.add_folder("Editable"): + gui_vector2 = server.gui.add_vector2( "Position", initial_value=(0.0, 0.0), step=0.1, ) - gui_vector3 = server.add_gui_vector3( + gui_vector3 = server.gui.add_vector3( "Size", initial_value=(1.0, 1.0, 1.0), step=0.25, ) - with server.add_gui_folder("Text toggle"): - gui_checkbox_hide = server.add_gui_checkbox( + with server.gui.add_folder("Text toggle"): + gui_checkbox_hide = server.gui.add_checkbox( "Hide", initial_value=False, ) - gui_text = server.add_gui_text( + gui_text = server.gui.add_text( "Text", initial_value="Hello world", ) - gui_button = server.add_gui_button("Button") - gui_checkbox_disable = server.add_gui_checkbox( + gui_button = server.gui.add_button("Button") + gui_checkbox_disable = server.gui.add_checkbox( "Disable", initial_value=False, ) - gui_rgb = server.add_gui_rgb( + gui_rgb = server.gui.add_rgb( "Color", initial_value=(255, 255, 0), ) - gui_multi_slider = server.add_gui_multi_slider( + gui_multi_slider = server.gui.add_multi_slider( "Multi slider", min=0, max=100, @@ -70,7 +70,7 @@ def main() -> None: initial_value=(0, 30, 100), marks=((0, "0"), (50, "5"), (70, "7"), 99), ) - gui_slider_positions = server.add_gui_slider( + gui_slider_positions = server.gui.add_slider( "# sliders", min=0, max=10, @@ -78,7 +78,7 @@ def main() -> None: initial_value=3, marks=((0, "0"), (5, "5"), (7, "7"), 10), ) - gui_upload_button = server.add_gui_upload_button( + gui_upload_button = server.gui.add_upload_button( "Upload", icon=viser.Icon.UPLOAD ) @@ -102,7 +102,7 @@ def _(_) -> None: # We can set the position of a scene node with `.position`, and read the value # of a gui element with `.value`. Changes are automatically reflected in # connected clients. - server.add_point_cloud( + server.scene.add_point_cloud( "/point_cloud", points=point_positions * onp.array(gui_vector3.value, dtype=onp.float32), colors=( diff --git a/examples/03_gui_callbacks.py b/examples/03_gui_callbacks.py index 54923da59..6528fb2fc 100644 --- a/examples/03_gui_callbacks.py +++ b/examples/03_gui_callbacks.py @@ -17,14 +17,14 @@ def main() -> None: server = viser.ViserServer() - gui_reset_scene = server.add_gui_button("Reset Scene") + gui_reset_scene = server.gui.add_button("Reset Scene") - gui_plane = server.add_gui_dropdown( + gui_plane = server.gui.add_dropdown( "Grid plane", ("xz", "xy", "yx", "yz", "zx", "zy") ) def update_plane() -> None: - server.add_grid( + server.scene.add_grid( "/grid", width=10.0, height=20.0, @@ -35,23 +35,23 @@ def update_plane() -> None: gui_plane.on_update(lambda _: update_plane()) - with server.add_gui_folder("Control"): - gui_show_frame = server.add_gui_checkbox("Show Frame", initial_value=True) - gui_show_everything = server.add_gui_checkbox( + with server.gui.add_folder("Control"): + gui_show_frame = server.gui.add_checkbox("Show Frame", initial_value=True) + gui_show_everything = server.gui.add_checkbox( "Show Everything", initial_value=True ) - gui_axis = server.add_gui_dropdown("Axis", ("x", "y", "z")) - gui_include_z = server.add_gui_checkbox("Z in dropdown", initial_value=True) + gui_axis = server.gui.add_dropdown("Axis", ("x", "y", "z")) + gui_include_z = server.gui.add_checkbox("Z in dropdown", initial_value=True) @gui_include_z.on_update def _(_) -> None: gui_axis.options = ("x", "y", "z") if gui_include_z.value else ("x", "y") - with server.add_gui_folder("Sliders"): - gui_location = server.add_gui_slider( + with server.gui.add_folder("Sliders"): + gui_location = server.gui.add_slider( "Location", min=-5.0, max=5.0, step=0.05, initial_value=0.0 ) - gui_num_points = server.add_gui_slider( + gui_num_points = server.gui.add_slider( "# Points", min=1000, max=200_000, step=1000, initial_value=10_000 ) @@ -66,7 +66,7 @@ def draw_frame() -> None: else: assert_never(axis) - server.add_frame( + server.scene.add_frame( "/frame", wxyz=(1.0, 0.0, 0.0, 0.0), position=pos, @@ -76,7 +76,7 @@ def draw_frame() -> None: def draw_points() -> None: num_points = gui_num_points.value - server.add_point_cloud( + server.scene.add_point_cloud( "/frame/point_cloud", points=onp.random.normal(size=(num_points, 3)), colors=onp.random.randint(0, 256, size=(num_points, 3)), @@ -86,7 +86,9 @@ def draw_points() -> None: # Here, we update the point clouds + frames whenever any of the GUI items are updated. gui_show_frame.on_update(lambda _: draw_frame()) gui_show_everything.on_update( - lambda _: server.set_global_scene_node_visibility(gui_show_everything.value) + lambda _: server.scene.set_global_scene_node_visibility( + gui_show_everything.value + ) ) gui_axis.on_update(lambda _: draw_frame()) gui_location.on_update(lambda _: draw_frame()) diff --git a/examples/04_camera_poses.py b/examples/04_camera_poses.py index 90b9e4ffb..4f07566cc 100644 --- a/examples/04_camera_poses.py +++ b/examples/04_camera_poses.py @@ -8,7 +8,7 @@ import viser server = viser.ViserServer() -server.world_axes.visible = True +server.scene.world_axes.visible = True @server.on_client_connect @@ -21,7 +21,7 @@ def _(_: viser.CameraHandle) -> None: print(f"New camera on client {client.client_id}!") # Show the client ID in the GUI. - gui_info = client.add_gui_text("Client ID", initial_value=str(client.client_id)) + gui_info = client.gui.add_text("Client ID", initial_value=str(client.client_id)) gui_info.disabled = True diff --git a/examples/05_camera_commands.py b/examples/05_camera_commands.py index 0962ef9cc..bb059af63 100644 --- a/examples/05_camera_commands.py +++ b/examples/05_camera_commands.py @@ -35,8 +35,8 @@ def make_frame(i: int) -> None: position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. - frame = client.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) - client.add_label(f"/frame_{i}/label", text=f"Frame {i}") + frame = client.scene.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) + client.scene.add_label(f"/frame_{i}/label", text=f"Frame {i}") # Move the camera when we click a frame. @frame.on_click diff --git a/examples/06_mesh.py b/examples/06_mesh.py index 1bfd42ba0..eeeaf783b 100644 --- a/examples/06_mesh.py +++ b/examples/06_mesh.py @@ -20,14 +20,14 @@ print(f"Loaded mesh with {vertices.shape} vertices, {faces.shape} faces") server = viser.ViserServer() -server.add_mesh_simple( +server.scene.add_mesh_simple( name="/simple", vertices=vertices, faces=faces, wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) -server.add_mesh_trimesh( +server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh.smoothed(), wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, diff --git a/examples/07_record3d_visualizer.py b/examples/07_record3d_visualizer.py index 9ff5c5d81..10332262f 100644 --- a/examples/07_record3d_visualizer.py +++ b/examples/07_record3d_visualizer.py @@ -30,8 +30,8 @@ def main( num_frames = min(max_frames, loader.num_frames()) # Add playback UI. - with server.add_gui_folder("Playback"): - gui_timestep = server.add_gui_slider( + with server.gui.add_folder("Playback"): + gui_timestep = server.gui.add_slider( "Timestep", min=0, max=num_frames - 1, @@ -39,13 +39,13 @@ def main( initial_value=0, disabled=True, ) - gui_next_frame = server.add_gui_button("Next Frame", disabled=True) - gui_prev_frame = server.add_gui_button("Prev Frame", disabled=True) - gui_playing = server.add_gui_checkbox("Playing", True) - gui_framerate = server.add_gui_slider( + gui_next_frame = server.gui.add_button("Next Frame", disabled=True) + gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True) + gui_playing = server.gui.add_checkbox("Playing", True) + gui_framerate = server.gui.add_slider( "FPS", min=1, max=60, step=0.1, initial_value=loader.fps ) - gui_framerate_options = server.add_gui_button_group( + gui_framerate_options = server.gui.add_button_group( "FPS options", ("10", "20", "30", "60") ) @@ -84,7 +84,7 @@ def _(_) -> None: server.flush() # Optional! # Load in frames. - server.add_frame( + server.scene.add_frame( "/frames", wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), @@ -96,10 +96,10 @@ def _(_) -> None: position, color = frame.get_point_cloud(downsample_factor) # Add base frame. - frame_nodes.append(server.add_frame(f"/frames/t{i}", show_axes=False)) + frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False)) # Place the point cloud in the frame. - server.add_point_cloud( + server.scene.add_point_cloud( name=f"/frames/t{i}/point_cloud", points=position, colors=color, @@ -110,7 +110,7 @@ def _(_) -> None: # Place the frustum. fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] - server.add_camera_frustum( + server.scene.add_camera_frustum( f"/frames/t{i}/frustum", fov=fov, aspect=aspect, @@ -121,7 +121,7 @@ def _(_) -> None: ) # Add some axes. - server.add_frame( + server.scene.add_frame( f"/frames/t{i}/frustum/axes", axes_length=0.05, axes_radius=0.005, diff --git a/examples/08_smpl_visualizer.py b/examples/08_smpl_visualizer.py index 2d3377735..46a5e9ec8 100644 --- a/examples/08_smpl_visualizer.py +++ b/examples/08_smpl_visualizer.py @@ -78,8 +78,8 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu def main(model_path: Path) -> None: server = viser.ViserServer() - server.set_up_direction("+y") - server.configure_theme(control_layout="collapsible") + server.scene.set_up_direction("+y") + server.gui.configure_theme(control_layout="collapsible") # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, # and then send the updated mesh in a loop. @@ -106,7 +106,7 @@ def main(model_path: Path) -> None: np.array([x.value for x in gui_elements.gui_joints]) ).as_matrix(), ) - server.add_mesh_simple( + server.scene.add_mesh_simple( "/human", smpl_outputs.vertices, smpl_outputs.faces, @@ -141,16 +141,16 @@ def make_gui_elements( ) -> GuiElements: """Make GUI elements for interacting with the model.""" - tab_group = server.add_gui_tab_group() + tab_group = server.gui.add_tab_group() def set_changed(_) -> None: out.changed = True # out is define later! # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): - gui_rgb = server.add_gui_rgb("Color", initial_value=(90, 200, 255)) - gui_wireframe = server.add_gui_checkbox("Wireframe", initial_value=False) - gui_show_controls = server.add_gui_checkbox("Handles", initial_value=False) + gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) + gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) + gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) @@ -162,8 +162,8 @@ def _(_): # GUI elements: shape parameters. with tab_group.add_tab("Shape", viser.Icon.BOX): - gui_reset_shape = server.add_gui_button("Reset Shape") - gui_random_shape = server.add_gui_button("Random Shape") + gui_reset_shape = server.gui.add_button("Reset Shape") + gui_random_shape = server.gui.add_button("Random Shape") @gui_reset_shape.on_click def _(_): @@ -177,7 +177,7 @@ def _(_): gui_betas = [] for i in range(num_betas): - beta = server.add_gui_slider( + beta = server.gui.add_slider( f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) @@ -185,8 +185,8 @@ def _(_): # GUI elements: joint angles. with tab_group.add_tab("Joints", viser.Icon.ANGLE): - gui_reset_joints = server.add_gui_button("Reset Joints") - gui_random_joints = server.add_gui_button("Random Joints") + gui_reset_joints = server.gui.add_button("Reset Joints") + gui_random_joints = server.gui.add_button("Random Joints") @gui_reset_joints.on_click def _(_): @@ -204,7 +204,7 @@ def _(_): gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] for i in range(num_joints): - gui_joint = server.add_gui_vector3( + gui_joint = server.gui.add_vector3( label=f"Joint {i}", initial_value=(0.0, 0.0, 0.0), step=0.05, @@ -231,7 +231,7 @@ def _(_): prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name ) prefixed_joint_names.append(prefixed_joint_name) - controls = server.add_transform_controls( + controls = server.scene.add_transform_controls( f"/smpl/{prefixed_joint_name}", depth_test=False, scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), diff --git a/examples/09_urdf_visualizer.py b/examples/09_urdf_visualizer.py index da99c257a..91fd0f8eb 100644 --- a/examples/09_urdf_visualizer.py +++ b/examples/09_urdf_visualizer.py @@ -34,7 +34,7 @@ def main(urdf_path: Path) -> None: upper = upper if upper is not None else onp.pi initial_angle = 0.0 if lower < 0 and upper > 0 else (lower + upper) / 2.0 - slider = server.add_gui_slider( + slider = server.gui.add_slider( label=joint_name, min=lower, max=upper, @@ -49,7 +49,7 @@ def main(urdf_path: Path) -> None: initial_angles.append(initial_angle) # Create joint reset button. - reset_button = server.add_gui_button("Reset") + reset_button = server.gui.add_button("Reset") @reset_button.on_click def _(_): diff --git a/examples/10_realsense.py b/examples/10_realsense.py index ae153ec28..b19b38129 100644 --- a/examples/10_realsense.py +++ b/examples/10_realsense.py @@ -91,7 +91,7 @@ def point_cloud_arrays_from_frames( def main(): # Start visualization server. - viser_server = viser.ViserServer() + server = viser.ViserServer() with realsense_pipeline() as pipeline: for i in tqdm(range(10000000)): @@ -114,7 +114,7 @@ def main(): positions = positions @ R.T # Visualize. - viser_server.add_point_cloud( + server.scene.add_point_cloud( "/realsense", points=positions * 10.0, colors=colors, diff --git a/examples/11_colmap_visualizer.py b/examples/11_colmap_visualizer.py index f655a17de..493131b69 100644 --- a/examples/11_colmap_visualizer.py +++ b/examples/11_colmap_visualizer.py @@ -33,13 +33,13 @@ def main( downsample_factor: Downsample factor for the images. """ server = viser.ViserServer() - server.configure_theme(titlebar_content=None, control_layout="collapsible") + server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") # Load the colmap info. cameras = read_cameras_binary(colmap_path / "cameras.bin") images = read_images_binary(colmap_path / "images.bin") points3d = read_points3d_binary(colmap_path / "points3D.bin") - gui_reset_up = server.add_gui_button( + gui_reset_up = server.gui.add_button( "Reset up direction", hint="Set the camera control 'up' direction to the current camera's 'up'.", ) @@ -52,21 +52,21 @@ def _(event: viser.GuiEvent) -> None: [0.0, -1.0, 0.0] ) - gui_points = server.add_gui_slider( + gui_points = server.gui.add_slider( "Max points", min=1, max=len(points3d), step=1, initial_value=min(len(points3d), 50_000), ) - gui_frames = server.add_gui_slider( + gui_frames = server.gui.add_slider( "Max frames", min=1, max=len(images), step=1, initial_value=min(len(images), 100), ) - gui_point_size = server.add_gui_number("Point size", initial_value=0.05) + gui_point_size = server.gui.add_number("Point size", initial_value=0.05) def visualize_colmap() -> None: """Send all COLMAP elements to viser for visualization. This could be optimized @@ -80,7 +80,7 @@ def visualize_colmap() -> None: points = points[points_selection] colors = colors[points_selection] - server.add_point_cloud( + server.scene.add_point_cloud( name="/colmap/pcd", points=points, colors=colors, @@ -113,7 +113,7 @@ def _(_) -> None: T_world_camera = tf.SE3.from_rotation_and_translation( tf.SO3(img.qvec), img.tvec ).inverse() - frame = server.add_frame( + frame = server.scene.add_frame( f"/colmap/frame_{img_id}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(), @@ -129,7 +129,7 @@ def _(_) -> None: fy = cam.params[1] image = iio.imread(image_filename) image = image[::downsample_factor, ::downsample_factor] - frustum = server.add_camera_frustum( + frustum = server.scene.add_camera_frustum( f"/colmap/frame_{img_id}/frustum", fov=2 * onp.arctan2(H / 2, fy), aspect=W / H, @@ -159,7 +159,7 @@ def _(_) -> None: if need_update: need_update = False - server.reset_scene() + server.scene.reset() visualize_colmap() time.sleep(1e-3) diff --git a/examples/12_click_meshes.py b/examples/12_click_meshes.py index e5d1dcb9f..ddd705840 100644 --- a/examples/12_click_meshes.py +++ b/examples/12_click_meshes.py @@ -13,14 +13,14 @@ def main() -> None: grid_shape = (4, 5) server = viser.ViserServer() - with server.add_gui_folder("Last clicked"): - x_value = server.add_gui_number( + with server.gui.add_folder("Last clicked"): + x_value = server.gui.add_number( label="x", initial_value=0, disabled=True, hint="x coordinate of the last clicked mesh", ) - y_value = server.add_gui_number( + y_value = server.gui.add_number( label="y", initial_value=0, disabled=True, @@ -46,14 +46,14 @@ def create_mesh(counter: int) -> None: color = colormap(index)[:3] if counter in (0, 1): - handle = server.add_box( + handle = server.scene.add_box( name=f"/sphere_{i}_{j}", position=(i, j, 0.0), color=color, dimensions=(0.5, 0.5, 0.5), ) else: - handle = server.add_icosphere( + handle = server.scene.add_icosphere( name=f"/sphere_{i}_{j}", radius=0.4, color=color, diff --git a/examples/13_theming.py b/examples/13_theming.py index 1a7dac6e3..057a41af2 100644 --- a/examples/13_theming.py +++ b/examples/13_theming.py @@ -40,28 +40,28 @@ def main(): ) titlebar_theme = TitlebarConfig(buttons=buttons, image=image) - server.add_gui_markdown( + server.gui.add_markdown( "Viser includes support for light theming via the `.configure_theme()` method." ) - gui_theme_code = server.add_gui_markdown("no theme applied yet") + gui_theme_code = server.gui.add_markdown("no theme applied yet") # GUI elements for controllable values. - titlebar = server.add_gui_checkbox("Titlebar", initial_value=True) - dark_mode = server.add_gui_checkbox("Dark mode", initial_value=True) - show_logo = server.add_gui_checkbox("Show logo", initial_value=True) - show_share_button = server.add_gui_checkbox("Show share button", initial_value=True) - brand_color = server.add_gui_rgb("Brand color", (230, 180, 30)) - control_layout = server.add_gui_dropdown( + titlebar = server.gui.add_checkbox("Titlebar", initial_value=True) + dark_mode = server.gui.add_checkbox("Dark mode", initial_value=True) + show_logo = server.gui.add_checkbox("Show logo", initial_value=True) + show_share_button = server.gui.add_checkbox("Show share button", initial_value=True) + brand_color = server.gui.add_rgb("Brand color", (230, 180, 30)) + control_layout = server.gui.add_dropdown( "Control layout", ("floating", "fixed", "collapsible") ) - control_width = server.add_gui_dropdown( + control_width = server.gui.add_dropdown( "Control width", ("small", "medium", "large"), initial_value="medium" ) - synchronize = server.add_gui_button("Apply theme", icon=viser.Icon.CHECK) + synchronize = server.gui.add_button("Apply theme", icon=viser.Icon.CHECK) def synchronize_theme() -> None: - server.configure_theme( + server.gui.configure_theme( titlebar_content=titlebar_theme if titlebar.value else None, control_layout=control_layout.value, control_width=control_width.value, @@ -73,7 +73,7 @@ def synchronize_theme() -> None: gui_theme_code.content = f""" ### Current applied theme ``` - server.configure_theme( + server.gui.configure_theme( titlebar_content={"titlebar_content" if titlebar.value else None}, control_layout="{control_layout.value}", control_width="{control_width.value}", diff --git a/examples/14_markdown.py b/examples/14_markdown.py index 4d85ac6cc..97b93b8a0 100644 --- a/examples/14_markdown.py +++ b/examples/14_markdown.py @@ -9,17 +9,17 @@ import viser server = viser.ViserServer() -server.world_axes.visible = True +server.scene.world_axes.visible = True -markdown_counter = server.add_gui_markdown("Counter: 0") +markdown_counter = server.gui.add_markdown("Counter: 0") here = Path(__file__).absolute().parent -button = server.add_gui_button("Remove blurb") -checkbox = server.add_gui_checkbox("Visibility", initial_value=True) +button = server.gui.add_button("Remove blurb") +checkbox = server.gui.add_checkbox("Visibility", initial_value=True) markdown_source = (here / "./assets/mdx_example.mdx").read_text() -markdown_blurb = server.add_gui_markdown( +markdown_blurb = server.gui.add_markdown( content=markdown_source, image_root=here, ) diff --git a/examples/15_gui_in_scene.py b/examples/15_gui_in_scene.py index 82ba9c360..81728ec7e 100644 --- a/examples/15_gui_in_scene.py +++ b/examples/15_gui_in_scene.py @@ -39,7 +39,7 @@ def make_frame(i: int) -> None: position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. - frame = client.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) + frame = client.scene.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) # Move the camera when we click a frame. @frame.on_click @@ -50,11 +50,11 @@ def _(_): if displayed_3d_container is not None: displayed_3d_container.remove() - displayed_3d_container = client.add_3d_gui_container(f"/frame_{i}/gui") + displayed_3d_container = client.scene.add_3d_gui_container(f"/frame_{i}/gui") with displayed_3d_container: - go_to = client.add_gui_button("Go to") - randomize_orientation = client.add_gui_button("Randomize orientation") - close = client.add_gui_button("Close GUI") + go_to = client.gui.add_button("Go to") + randomize_orientation = client.gui.add_button("Randomize orientation") + close = client.gui.add_button("Close GUI") @go_to.on_click def _(_) -> None: diff --git a/examples/16_modal.py b/examples/16_modal.py index 88afc612a..0858e55ff 100644 --- a/examples/16_modal.py +++ b/examples/16_modal.py @@ -12,23 +12,23 @@ def main(): @server.on_client_connect def _(client: viser.ClientHandle) -> None: - with client.add_gui_modal("Modal example"): - client.add_gui_markdown( + with client.gui.add_modal("Modal example"): + client.gui.add_markdown( "**The input below determines the title of the modal...**" ) - gui_title = client.add_gui_text( + gui_title = client.gui.add_text( "Title", initial_value="My Modal", ) - modal_button = client.add_gui_button("Show more modals") + modal_button = client.gui.add_button("Show more modals") @modal_button.on_click def _(_) -> None: - with client.add_gui_modal(gui_title.value) as modal: - client.add_gui_markdown("This is content inside the modal!") - client.add_gui_button("Close").on_click(lambda _: modal.close()) + with client.gui.add_modal(gui_title.value) as modal: + client.gui.add_markdown("This is content inside the modal!") + client.gui.add_button("Close").on_click(lambda _: modal.close()) while True: time.sleep(0.15) diff --git a/examples/17_background_composite.py b/examples/17_background_composite.py index 1bf4f481c..3904d442e 100644 --- a/examples/17_background_composite.py +++ b/examples/17_background_composite.py @@ -22,12 +22,12 @@ img[250:750, 250:750, :] = 255 mesh = trimesh.creation.box((0.5, 0.5, 0.5)) -server.add_mesh_trimesh( +server.scene.add_mesh_trimesh( name="/cube", mesh=mesh, position=(0, 0, 0.0), ) -server.set_background_image(img, depth=depth) +server.scene.set_background_image(img, depth=depth) while True: diff --git a/examples/18_splines.py b/examples/18_splines.py index c0174b1ca..3c939c9d0 100644 --- a/examples/18_splines.py +++ b/examples/18_splines.py @@ -13,7 +13,7 @@ def main() -> None: server = viser.ViserServer() for i in range(10): positions = onp.random.normal(size=(30, 3)) * 3.0 - server.add_spline_catmull_rom( + server.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, @@ -23,7 +23,7 @@ def main() -> None: ) control_points = onp.random.normal(size=(30 * 2 - 2, 3)) * 3.0 - server.add_spline_cubic_bezier( + server.scene.add_spline_cubic_bezier( f"/cubic_bezier_{i}", positions, control_points, diff --git a/examples/19_get_renders.py b/examples/19_get_renders.py index 4da13dda3..f235730bb 100644 --- a/examples/19_get_renders.py +++ b/examples/19_get_renders.py @@ -12,20 +12,20 @@ def main(): server = viser.ViserServer() - button = server.add_gui_button("Render a GIF") + button = server.gui.add_button("Render a GIF") @button.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - client.reset_scene() + client.scene.reset() images = [] for i in range(20): positions = onp.random.normal(size=(30, 3)) * 3.0 - client.add_spline_catmull_rom( + client.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, diff --git a/examples/20_scene_pointer.py b/examples/20_scene_pointer.py index caf18a638..e828f5e68 100644 --- a/examples/20_scene_pointer.py +++ b/examples/20_scene_pointer.py @@ -18,15 +18,15 @@ import viser.transforms as tf server = viser.ViserServer() -server.configure_theme(brand_color=(130, 0, 150)) -server.set_up_direction("+y") +server.gui.configure_theme(brand_color=(130, 0, 150)) +server.scene.set_up_direction("+y") mesh = cast( trimesh.Trimesh, trimesh.load_mesh(str(Path(__file__).parent / "assets/dragon.obj")) ) mesh.apply_scale(0.05) -mesh_handle = server.add_mesh_trimesh( +mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), @@ -43,13 +43,13 @@ def _(client: viser.ClientHandle) -> None: client.camera.wxyz = onp.array([0.0, 0.0, 0.0, 1.0]) # Tests "click" scenepointerevent. - click_button_handle = client.add_gui_button("Add sphere", icon=viser.Icon.POINTER) + click_button_handle = client.gui.add_button("Add sphere", icon=viser.Icon.POINTER) @click_button_handle.on_click def _(_): click_button_handle.disabled = True - @client.on_scene_pointer(event_type="click") + @client.scene.on_pointer_event(event_type="click") def _(event: viser.ScenePointerEvent) -> None: # Check for intersection with the mesh, using trimesh's ray-mesh intersection. # Note that mesh is in the mesh frame, so we need to transform the ray. @@ -62,7 +62,7 @@ def _(event: viser.ScenePointerEvent) -> None: if len(hit_pos) == 0: return - client.remove_scene_pointer_callback() + client.scene.remove_pointer_callback() # Get the first hit position (based on distance from the ray origin). hit_pos = min(hit_pos, key=lambda x: onp.linalg.norm(x - origin)) @@ -71,25 +71,25 @@ def _(event: viser.ScenePointerEvent) -> None: hit_pos_mesh = trimesh.creation.icosphere(radius=0.1) hit_pos_mesh.vertices += R_world_mesh @ hit_pos hit_pos_mesh.visual.vertex_colors = (0.5, 0.0, 0.7, 1.0) # type: ignore - hit_pos_handle = server.add_mesh_trimesh( + hit_pos_handle = server.scene.add_mesh_trimesh( name=f"/hit_pos_{len(hit_pos_handles)}", mesh=hit_pos_mesh ) hit_pos_handles.append(hit_pos_handle) - @client.on_scene_pointer_removed + @client.scene.on_pointer_callback_removed def _(): click_button_handle.disabled = False # Tests "rect-select" scenepointerevent. - paint_button_handle = client.add_gui_button("Paint mesh", icon=viser.Icon.PAINT) + paint_button_handle = client.gui.add_button("Paint mesh", icon=viser.Icon.PAINT) @paint_button_handle.on_click def _(_): paint_button_handle.disabled = True - @client.on_scene_pointer(event_type="rect-select") + @client.scene.on_pointer_event(event_type="rect-select") def _(message: viser.ScenePointerEvent) -> None: - client.remove_scene_pointer_callback() + client.scene.remove_pointer_callback() global mesh_handle camera = message.client.camera @@ -127,18 +127,18 @@ def _(message: viser.ScenePointerEvent) -> None: mesh.visual.vertex_colors = onp.where( # type: ignore mask, (0.5, 0.0, 0.7, 1.0), (0.9, 0.9, 0.9, 1.0) ) - mesh_handle = server.add_mesh_trimesh( + mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) - @client.on_scene_pointer_removed + @client.scene.on_pointer_callback_removed def _(): paint_button_handle.disabled = False # Button to clear spheres. - clear_button_handle = client.add_gui_button("Clear scene", icon=viser.Icon.X) + clear_button_handle = client.gui.add_button("Clear scene", icon=viser.Icon.X) @clear_button_handle.on_click def _(_): @@ -148,7 +148,7 @@ def _(_): handle.remove() hit_pos_handles.clear() mesh.visual.vertex_colors = (0.9, 0.9, 0.9, 1.0) # type: ignore - mesh_handle = server.add_mesh_trimesh( + mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), diff --git a/examples/21_set_up_direction.py b/examples/21_set_up_direction.py index 98c988d8d..3c2a01a92 100644 --- a/examples/21_set_up_direction.py +++ b/examples/21_set_up_direction.py @@ -9,8 +9,8 @@ def main() -> None: server = viser.ViserServer() - server.world_axes.visible = True - gui_up = server.add_gui_vector3( + server.scene.world_axes.visible = True + gui_up = server.gui.add_vector3( "Up Direction", initial_value=(0.0, 0.0, 1.0), step=0.01, @@ -18,7 +18,7 @@ def main() -> None: @gui_up.on_update def _(_) -> None: - server.set_up_direction(gui_up.value) + server.scene.set_up_direction(gui_up.value) while True: time.sleep(1.0) diff --git a/examples/22_games.py b/examples/22_games.py index a8b37de62..2840ac0cd 100644 --- a/examples/22_games.py +++ b/examples/22_games.py @@ -19,11 +19,11 @@ def main() -> None: server = viser.ViserServer() - server.configure_theme(dark_mode=True) + server.gui.configure_theme(dark_mode=True) play_connect_4(server) - server.add_gui_button("Tic-Tac-Toe").on_click(lambda _: play_tic_tac_toe(server)) - server.add_gui_button("Connect 4").on_click(lambda _: play_connect_4(server)) + server.gui.add_button("Tic-Tac-Toe").on_click(lambda _: play_tic_tac_toe(server)) + server.gui.add_button("Connect 4").on_click(lambda _: play_connect_4(server)) while True: time.sleep(10.0) @@ -31,7 +31,7 @@ def main() -> None: def play_connect_4(server: viser.ViserServer) -> None: """Play a game of Connect 4.""" - server.reset_scene() + server.scene.reset() num_rows = 6 num_cols = 7 @@ -42,7 +42,7 @@ def play_connect_4(server: viser.ViserServer) -> None: # Create the board frame. for col in range(num_cols): for row in range(num_rows): - server.add_mesh_trimesh( + server.scene.add_mesh_trimesh( f"/structure/{row}_{col}", trimesh.creation.annulus(0.45, 0.55, 0.125), position=(0.0, col, row), @@ -51,7 +51,7 @@ def play_connect_4(server: viser.ViserServer) -> None: # Create a sphere to click on for each column. def setup_column(col: int) -> None: - sphere = server.add_icosphere( + sphere = server.scene.add_icosphere( f"/spheres/{col}", radius=0.25, position=(0, col, num_rows - 0.25), @@ -70,7 +70,7 @@ def _(_) -> None: pieces_in_col[col] += 1 cylinder = trimesh.creation.cylinder(radius=0.4, height=0.125) - piece = server.add_mesh_simple( + piece = server.scene.add_mesh_simple( f"/game_pieces/{row}_{col}", cylinder.vertices, cylinder.faces, @@ -91,12 +91,12 @@ def _(_) -> None: def play_tic_tac_toe(server: viser.ViserServer) -> None: """Play a game of tic-tac-toe.""" - server.reset_scene() + server.scene.reset() whose_turn: Literal["x", "o"] = "x" for i in range(4): - server.add_spline_catmull_rom( + server.scene.add_spline_catmull_rom( f"/gridlines/{i}", ((-0.5, -1.5, 0), (-0.5, 1.5, 0)), color=(127, 127, 127), @@ -109,7 +109,7 @@ def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: for scale in onp.linspace(0.01, 1.0, 5): if symbol == "x": for k in range(2): - server.add_box( + server.scene.add_box( f"/symbols/{i}_{j}/{k}", dimensions=(0.7 * scale, 0.125 * scale, 0.125), position=(i, j, 0), @@ -120,7 +120,7 @@ def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: ) elif symbol == "o": mesh = trimesh.creation.annulus(0.25 * scale, 0.35 * scale, 0.125) - server.add_mesh_simple( + server.scene.add_mesh_simple( f"/symbols/{i}_{j}", mesh.vertices, mesh.faces, @@ -134,7 +134,7 @@ def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: def setup_cell(i: int, j: int) -> None: """Create a clickable sphere in a given cell.""" - sphere = server.add_icosphere( + sphere = server.scene.add_icosphere( f"/spheres/{i}_{j}", radius=0.25, position=(i, j, 0), diff --git a/examples/23_plotly.py b/examples/23_plotly.py index 54f5f52a1..48cfd36b5 100644 --- a/examples/23_plotly.py +++ b/examples/23_plotly.py @@ -37,14 +37,14 @@ def main() -> None: # Plot type 1: Line plot. line_plot_time = 0.0 - line_plot = server.add_gui_plotly(figure=create_sinusoidal_wave(line_plot_time)) + line_plot = server.gui.add_plotly(figure=create_sinusoidal_wave(line_plot_time)) # Plot type 2: Image plot. fig = px.imshow(Image.open("assets/Cal_logo.png")) fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) - server.add_gui_plotly(figure=fig, aspect=1.0) + server.gui.add_plotly(figure=fig, aspect=1.0) # Plot type 3: 3D Scatter plot. fig = px.scatter_3d( @@ -58,7 +58,7 @@ def main() -> None: fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) - server.add_gui_plotly(figure=fig, aspect=1.0) + server.gui.add_plotly(figure=fig, aspect=1.0) while True: # Update the line plot. diff --git a/pyproject.toml b/pyproject.toml index eaf78bfce..e2e975b58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "viser" -version = "0.1.29" +version = "0.1.30" description = "3D visualization + Python" readme = "README.md" license = { text="MIT" } diff --git a/src/viser/__init__.py b/src/viser/__init__.py index 803ae3802..e78d641b8 100644 --- a/src/viser/__init__.py +++ b/src/viser/__init__.py @@ -27,8 +27,3 @@ from ._viser import CameraHandle as CameraHandle from ._viser import ClientHandle as ClientHandle from ._viser import ViserServer as ViserServer - -if not TYPE_CHECKING: - # Backwards compatibility. - GuiHandle = GuiInputHandle - ClickEvent = SceneNodePointerEvent diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index f89942a57..1adb7cff3 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -4,7 +4,7 @@ # - https://github.com/python/mypy/issues/12554 from __future__ import annotations -import abc +import colorsys import dataclasses import functools import threading @@ -35,6 +35,8 @@ get_type_hints, ) +from viser import theme + from . import _messages from ._gui_handles import ( GuiButtonGroupHandle, @@ -57,14 +59,16 @@ ) from ._icons import svg_from_icon from ._icons_enum import IconName -from ._message_api import MessageApi, cast_vector from ._messages import FileTransferPartAck +from ._scene_api import cast_vector if TYPE_CHECKING: import plotly.graph_objects as go + from ._viser import ClientHandle, ViserServer from .infra import ClientId + IntOrFloat = TypeVar("IntOrFloat", int, float) TString = TypeVar("TString", bound=str) TLiteralString = TypeVar("TLiteralString", bound=LiteralString) @@ -180,12 +184,25 @@ class FileUploadState(TypedDict): lock: threading.Lock -class GuiApi(abc.ABC): +class GuiApi: _target_container_from_thread_id: Dict[int, str] = {} """ID of container to put GUI elements into.""" - def __init__(self) -> None: - super().__init__() + def __init__( + self, + owner: ViserServer | ClientHandle, # Who do I belong to? + ) -> None: + from ._viser import ViserServer + + self._owner = owner + """Entity that owns this API.""" + + self._websock_interface = ( + owner._websock_server + if isinstance(owner, ViserServer) + else owner._websock_connection + ) + """Interface for sending and listening to messages.""" self._gui_handle_from_id: Dict[str, _GuiInputHandle[Any]] = {} self._container_handle_from_id: Dict[str, GuiContainerProtocol] = { @@ -196,13 +213,13 @@ def __init__(self) -> None: # Set to True when plotly.min.js has been sent to client. self._setup_plotly_js: bool = False - self._get_api()._message_handler.register_handler( + self._websock_interface.register_handler( _messages.GuiUpdateMessage, self._handle_gui_updates ) - self._get_api()._message_handler.register_handler( + self._websock_interface.register_handler( _messages.FileTransferStart, self._handle_file_transfer_start ) - self._get_api()._message_handler.register_handler( + self._websock_interface.register_handler( _messages.FileTransferPart, self._handle_file_transfer_part, ) @@ -245,11 +262,10 @@ def _handle_gui_updates( from ._viser import ClientHandle, ViserServer # Get the handle of the client that triggered this event. - api = self._get_api() - if isinstance(api, ClientHandle): - client = api - elif isinstance(api, ViserServer): - client = api.get_clients()[client_id] + if isinstance(self._owner, ClientHandle): + client = self._owner + elif isinstance(self._owner, ViserServer): + client = self._owner.get_clients()[client_id] else: assert False @@ -288,7 +304,7 @@ def _handle_file_transfer_part( state["transferred_bytes"] += len(message.content) # Send ack to the server. - self._get_api()._queue( + self._websock_interface.queue_message( FileTransferPartAck( source_component_id=message.source_component_id, transfer_uuid=message.transfer_uuid, @@ -316,7 +332,7 @@ def _handle_file_transfer_part( ) # Update state. - with self._get_api()._atomic_lock: + with self._owner.atomic(): handle_state.value = value handle_state.update_timestamp = time.time() @@ -325,11 +341,10 @@ def _handle_file_transfer_part( from ._viser import ClientHandle, ViserServer # Get the handle of the client that triggered this event. - api = self._get_api() - if isinstance(api, ClientHandle): - client = api - elif isinstance(api, ViserServer): - client = api.get_clients()[client_id] + if isinstance(self._owner, ClientHandle): + client = self._owner + elif isinstance(self._owner, ViserServer): + client = self._owner.get_clients()[client_id] else: assert False @@ -343,22 +358,99 @@ def _set_container_id(self, container_id: str) -> None: """Set container ID associated with the current thread.""" self._target_container_from_thread_id[threading.get_ident()] = container_id - @abc.abstractmethod - def _get_api(self) -> MessageApi: - """Message API to use.""" - ... - if not TYPE_CHECKING: def gui_folder(self, label: str) -> GuiFolderHandle: """Deprecated.""" warnings.warn( - "gui_folder() is deprecated. Use add_gui_folder() instead!", + "gui_folder() is deprecated. Use add_folder() instead!", stacklevel=2, ) - return self.add_gui_folder(label) + return self.add_folder(label) + + def set_gui_panel_label(self, label: Optional[str]) -> None: + """Set the main label that appears in the GUI panel. + + Args: + label: The new label. + """ + self._websock_interface.queue_message(_messages.SetGuiPanelLabelMessage(label)) + + def configure_theme( + self, + *, + titlebar_content: Optional[theme.TitlebarConfig] = None, + control_layout: Literal["floating", "collapsible", "fixed"] = "floating", + control_width: Literal["small", "medium", "large"] = "medium", + dark_mode: bool = False, + show_logo: bool = True, + show_share_button: bool = True, + brand_color: Optional[Tuple[int, int, int]] = None, + ) -> None: + """Configures the visual appearance of the viser front-end. + + Args: + titlebar_content: Optional configuration for the title bar. + control_layout: The layout of control elements, options are "floating", + "collapsible", or "fixed". + control_width: The width of control elements, options are "small", + "medium", or "large". + dark_mode: A boolean indicating if dark mode should be enabled. + show_logo: A boolean indicating if the logo should be displayed. + show_share_button: A boolean indicating if the share button should be displayed. + brand_color: An optional tuple of integers (RGB) representing the brand color. + """ - def add_gui_folder( + colors_cast: Optional[ + Tuple[str, str, str, str, str, str, str, str, str, str] + ] = None + + if brand_color is not None: + assert len(brand_color) in (3, 10) + if len(brand_color) == 3: + assert all( + map(lambda val: isinstance(val, int), brand_color) + ), "All channels should be integers." + + # RGB => HLS. + h, l, s = colorsys.rgb_to_hls( + brand_color[0] / 255.0, + brand_color[1] / 255.0, + brand_color[2] / 255.0, + ) + + # Automatically generate a 10-color palette. + min_l = max(l - 0.08, 0.0) + max_l = min(0.8 + 0.5, 0.9) + l = max(min_l, min(max_l, l)) + + primary_index = 8 + ls = tuple( + onp.interp( + x=onp.arange(10), + xp=onp.array([0, primary_index, 9]), + fp=onp.array([max_l, l, min_l]), + ) + ) + colors_cast = tuple(_hex_from_hls(h, ls[i], s) for i in range(10)) # type: ignore + + assert colors_cast is None or all( + [isinstance(val, str) and val.startswith("#") for val in colors_cast] + ), "All string colors should be in hexadecimal + prefixed with #, eg #ffffff." + + self._websock_interface.queue_message( + _messages.ThemeConfigurationMessage( + titlebar_content=titlebar_content, + control_layout=control_layout, + control_width=control_width, + dark_mode=dark_mode, + show_logo=show_logo, + show_share_button=show_share_button, + colors=colors_cast, + ), + ) + + def add_folder( self, label: str, order: Optional[float] = None, @@ -379,7 +471,7 @@ def add_gui_folder( """ folder_container_id = _make_unique_id() order = _apply_default_order(order) - self._get_api()._queue( + self._websock_interface.queue_message( _messages.GuiAddFolderMessage( order=order, id=folder_container_id, @@ -396,7 +488,7 @@ def add_gui_folder( _order=order, ) - def add_gui_modal( + def add_modal( self, title: str, order: Optional[float] = None, @@ -413,7 +505,7 @@ def add_gui_modal( """ modal_container_id = _make_unique_id() order = _apply_default_order(order) - self._get_api()._queue( + self._websock_interface.queue_message( _messages.GuiModalMessage( order=order, id=modal_container_id, @@ -425,7 +517,7 @@ def add_gui_modal( _id=modal_container_id, ) - def add_gui_tab_group( + def add_tab_group( self, order: Optional[float] = None, visible: bool = True, @@ -442,7 +534,7 @@ def add_gui_tab_group( tab_group_id = _make_unique_id() order = _apply_default_order(order) - self._get_api()._queue( + self._websock_interface.queue_message( _messages.GuiAddTabGroupMessage( order=order, id=tab_group_id, @@ -462,7 +554,7 @@ def add_gui_tab_group( _order=order, ) - def add_gui_markdown( + def add_markdown( self, content: str, image_root: Optional[Path] = None, @@ -489,7 +581,7 @@ def add_gui_markdown( _image_root=image_root, _content=None, ) - self._get_api()._queue( + self._websock_interface.queue_message( _messages.GuiAddMarkdownMessage( order=handle._order, id=handle._id, @@ -504,7 +596,7 @@ def add_gui_markdown( handle.content = content return handle - def add_gui_plotly( + def add_plotly( self, figure: go.Figure, aspect: float = 1.0, @@ -554,14 +646,16 @@ def add_gui_plotly( # Send it over! plotly_js = plotly_path.read_text(encoding="utf-8") - self._get_api()._queue(_messages.RunJavascriptMessage(source=plotly_js)) + self._websock_interface.queue_message( + _messages.RunJavascriptMessage(source=plotly_js) + ) # Update the flag so we don't send it again. self._setup_plotly_js = True # After plotly.min.js has been sent, we can send the plotly figure. # Empty string for `plotly_json_str` is a signal to the client to render nothing. - self._get_api()._queue( + self._websock_interface.queue_message( _messages.GuiAddPlotlyMessage( order=handle._order, id=handle._id, @@ -578,7 +672,7 @@ def add_gui_plotly( return handle - def add_gui_button( + def add_button( self, label: str, disabled: bool = False, @@ -643,7 +737,7 @@ def add_gui_button( )._impl ) - def add_gui_upload_button( + def add_upload_button( self, label: str, disabled: bool = False, @@ -717,7 +811,7 @@ def add_gui_upload_button( # TString is helpful when the input types are generic (could be str, could be # Literal). @overload - def add_gui_button_group( + def add_button_group( self, label: str, options: Sequence[TLiteralString], @@ -725,10 +819,11 @@ def add_gui_button_group( disabled: bool = False, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiButtonGroupHandle[TLiteralString]: ... + ) -> GuiButtonGroupHandle[TLiteralString]: + ... @overload - def add_gui_button_group( + def add_button_group( self, label: str, options: Sequence[TString], @@ -736,9 +831,10 @@ def add_gui_button_group( disabled: bool = False, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiButtonGroupHandle[TString]: ... + ) -> GuiButtonGroupHandle[TString]: + ... - def add_gui_button_group( + def add_button_group( self, label: str, options: Sequence[TLiteralString] | Sequence[TString], @@ -780,7 +876,7 @@ def add_gui_button_group( )._impl, ) - def add_gui_checkbox( + def add_checkbox( self, label: str, initial_value: bool, @@ -820,7 +916,7 @@ def add_gui_checkbox( ), ) - def add_gui_text( + def add_text( self, label: str, initial_value: str, @@ -860,7 +956,7 @@ def add_gui_text( ), ) - def add_gui_number( + def add_number( self, label: str, initial_value: IntOrFloat, @@ -929,7 +1025,7 @@ def add_gui_number( is_button=False, ) - def add_gui_vector2( + def add_vector2( self, label: str, initial_value: Tuple[float, float] | onp.ndarray, @@ -991,7 +1087,7 @@ def add_gui_vector2( ), ) - def add_gui_vector3( + def add_vector3( self, label: str, initial_value: Tuple[float, float, float] | onp.ndarray, @@ -1053,9 +1149,9 @@ def add_gui_vector3( ), ) - # See add_gui_dropdown for notes on overloads. + # See add_dropdown for notes on overloads. @overload - def add_gui_dropdown( + def add_dropdown( self, label: str, options: Sequence[TLiteralString], @@ -1064,10 +1160,11 @@ def add_gui_dropdown( visible: bool = True, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiDropdownHandle[TLiteralString]: ... + ) -> GuiDropdownHandle[TLiteralString]: + ... @overload - def add_gui_dropdown( + def add_dropdown( self, label: str, options: Sequence[TString], @@ -1076,9 +1173,10 @@ def add_gui_dropdown( visible: bool = True, hint: Optional[str] = None, order: Optional[float] = None, - ) -> GuiDropdownHandle[TString]: ... + ) -> GuiDropdownHandle[TString]: + ... - def add_gui_dropdown( + def add_dropdown( self, label: str, options: Sequence[TLiteralString] | Sequence[TString], @@ -1125,7 +1223,7 @@ def add_gui_dropdown( _impl_options=tuple(options), ) - def add_gui_slider( + def add_slider( self, label: str, min: IntOrFloat, @@ -1206,7 +1304,7 @@ def add_gui_slider( is_button=False, ) - def add_gui_multi_slider( + def add_multi_slider( self, label: str, min: IntOrFloat, @@ -1290,7 +1388,7 @@ def add_gui_multi_slider( is_button=False, ) - def add_gui_rgb( + def add_rgb( self, label: str, initial_value: Tuple[int, int, int], @@ -1330,7 +1428,7 @@ def add_gui_rgb( ), ) - def add_gui_rgba( + def add_rgba( self, label: str, initial_value: Tuple[int, int, int, int], @@ -1378,7 +1476,7 @@ def _create_gui_input( """Private helper for adding a simple GUI element.""" # Send add GUI input message. - self._get_api()._queue(message) + self._websock_interface.queue_message(message) # Construct handle. handle_state = _GuiHandleState( @@ -1408,7 +1506,7 @@ def sync_other_clients( ) -> None: message = _messages.GuiUpdateMessage(handle_state.id, updates) message.excluded_self_client = client_id - self._get_api()._queue(message) + self._websock_interface.queue_message(message) handle_state.sync_cb = sync_other_clients diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 200049d50..579c42421 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -27,8 +27,8 @@ from ._icons import svg_from_icon from ._icons_enum import IconName -from ._message_api import _encode_image_base64 from ._messages import GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, Message +from ._scene_api import _encode_image_base64 from .infra import ClientId if TYPE_CHECKING: @@ -54,7 +54,8 @@ class GuiContainerProtocol(Protocol): class SupportsRemoveProtocol(Protocol): - def remove(self) -> None: ... + def remove(self) -> None: + ... @dataclasses.dataclass diff --git a/src/viser/_message_api.py b/src/viser/_scene_api.py similarity index 84% rename from src/viser/_message_api.py rename to src/viser/_scene_api.py index ef86bf759..5bc7d6528 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_scene_api.py @@ -7,14 +7,9 @@ from __future__ import annotations -import abc import base64 import colorsys -import contextlib import io -import mimetypes -import queue -import threading import time import warnings from concurrent.futures import ThreadPoolExecutor @@ -22,7 +17,6 @@ TYPE_CHECKING, Callable, Dict, - Generator, Optional, Tuple, TypeVar, @@ -36,7 +30,7 @@ import numpy.typing as onpt from typing_extensions import Literal, ParamSpec, TypeAlias, assert_never -from . import _messages, infra, theme +from . import _messages from . import transforms as tf from ._scene_handles import ( BatchedAxesHandle, @@ -52,13 +46,14 @@ SceneNodePointerEvent, ScenePointerEvent, TransformControlsHandle, + _SceneNodeHandleState, _TransformControlsState, ) if TYPE_CHECKING: import trimesh - from ._viser import ClientHandle + from ._viser import ClientHandle, ViserServer from .infra import ClientId @@ -140,20 +135,42 @@ def cast_vector(vector: TVector | onp.ndarray, length: int) -> TVector: return cast(TVector, tuple(map(float, vector))) -class MessageApi(abc.ABC): +class SceneApi: """Interface for all commands we can use to send messages over a websocket connection. Should be implemented by both our global server object (for broadcasting) and by invidividual clients.""" - _locked_thread_id: int # Appeasing mypy 1.5.1, not sure why this is needed. - def __init__( - self, handler: infra.MessageHandler, thread_executor: ThreadPoolExecutor + self, + owner: ViserServer | ClientHandle, # Who do I belong to? + thread_executor: ThreadPoolExecutor, ) -> None: - self._message_handler = handler + from ._viser import ViserServer + + self._owner = owner + """Entity that owns this API.""" + + self._websock_interface = ( + owner._websock_server + if isinstance(owner, ViserServer) + else owner._websock_connection + ) + """Interface for sending and listening to messages.""" + + self.world_axes = FrameHandle( + _SceneNodeHandleState( + "/WorldAxes", + self, + wxyz=onp.array([1.0, 0.0, 0.0, 0.0]), + position=onp.zeros(3), + ) + ) + """Handle for the world axes, which are hardcoded to exist.""" - super().__init__() + # Hide world axes on initialization. + if isinstance(owner, ViserServer): + self.world_axes.visible = False self._handle_from_transform_controls_name: Dict[ str, TransformControlsHandle @@ -164,106 +181,21 @@ def __init__( self._scene_pointer_done_cb: Callable[[], None] = lambda: None self._scene_pointer_event_type: Optional[_messages.ScenePointerEventType] = None - handler.register_handler( + self._websock_interface.register_handler( _messages.TransformControlsUpdateMessage, self._handle_transform_controls_updates, ) - handler.register_handler( + self._websock_interface.register_handler( _messages.SceneNodeClickMessage, self._handle_node_click_updates, ) - handler.register_handler( + self._websock_interface.register_handler( _messages.ScenePointerMessage, self._handle_scene_pointer_updates, ) - self._atomic_lock = threading.Lock() - self._queued_messages: queue.Queue = queue.Queue() - self._locked_thread_id = -1 self._thread_executor = thread_executor - def set_gui_panel_label(self, label: Optional[str]) -> None: - """Set the main label that appears in the GUI panel. - - Args: - label: The new label. - """ - self._queue(_messages.SetGuiPanelLabelMessage(label)) - - def configure_theme( - self, - *, - titlebar_content: Optional[theme.TitlebarConfig] = None, - control_layout: Literal["floating", "collapsible", "fixed"] = "floating", - control_width: Literal["small", "medium", "large"] = "medium", - dark_mode: bool = False, - show_logo: bool = True, - show_share_button: bool = True, - brand_color: Optional[Tuple[int, int, int]] = None, - ) -> None: - """Configures the visual appearance of the viser front-end. - - Args: - titlebar_content: Optional configuration for the title bar. - control_layout: The layout of control elements, options are "floating", - "collapsible", or "fixed". - control_width: The width of control elements, options are "small", - "medium", or "large". - dark_mode: A boolean indicating if dark mode should be enabled. - show_logo: A boolean indicating if the logo should be displayed. - show_share_button: A boolean indicating if the share button should be displayed. - brand_color: An optional tuple of integers (RGB) representing the brand color. - """ - - colors_cast: Optional[ - Tuple[str, str, str, str, str, str, str, str, str, str] - ] = None - - if brand_color is not None: - assert len(brand_color) in (3, 10) - if len(brand_color) == 3: - assert all( - map(lambda val: isinstance(val, int), brand_color) - ), "All channels should be integers." - - # RGB => HLS. - h, l, s = colorsys.rgb_to_hls( - brand_color[0] / 255.0, - brand_color[1] / 255.0, - brand_color[2] / 255.0, - ) - - # Automatically generate a 10-color palette. - min_l = max(l - 0.08, 0.0) - max_l = min(0.8 + 0.5, 0.9) - l = max(min_l, min(max_l, l)) - - primary_index = 8 - ls = tuple( - onp.interp( - x=onp.arange(10), - xp=onp.array([0, primary_index, 9]), - fp=onp.array([max_l, l, min_l]), - ) - ) - colors_cast = tuple(_hex_from_hls(h, ls[i], s) for i in range(10)) # type: ignore - - assert colors_cast is None or all( - [isinstance(val, str) and val.startswith("#") for val in colors_cast] - ), "All string colors should be in hexadecimal + prefixed with #, eg #ffffff." - - self._queue( - _messages.ThemeConfigurationMessage( - titlebar_content=titlebar_content, - control_layout=control_layout, - control_width=control_width, - dark_mode=dark_mode, - show_logo=show_logo, - show_share_button=show_share_button, - colors=colors_cast, - ), - ) - def set_up_direction( self, direction: Literal["+x", "+y", "+z", "-x", "-y", "-z"] @@ -328,7 +260,7 @@ def rotate_between(before: onp.ndarray, after: onp.ndarray) -> tf.SO3: if not onp.any(onp.isnan(R_threeworld_world.wxyz)): # Set the orientation of the root node. - self._queue( + self._websock_interface.queue_message( _messages.SetOrientationMessage( "", cast_vector(R_threeworld_world.wxyz, 4) ) @@ -340,7 +272,9 @@ def set_global_scene_node_visibility(self, visible: bool) -> None: Args: visible: Whether or not all scene nodes should be visible. """ - self._queue(_messages.SetSceneNodeVisibilityMessage("", visible)) + self._websock_interface.queue_message( + _messages.SetSceneNodeVisibilityMessage("", visible) + ) def add_glb( self, @@ -371,7 +305,9 @@ def add_glb( Returns: Handle for manipulating scene node. """ - self._queue(_messages.GlbMessage(name, glb_data, scale)) + self._websock_interface.queue_message( + _messages.GlbMessage(name, glb_data, scale) + ) return GlbHandle._make(self, name, wxyz, position, visible) def add_spline_catmull_rom( @@ -415,7 +351,7 @@ def add_spline_catmull_rom( positions = tuple(map(tuple, positions)) # type: ignore assert len(positions[0]) == 3 assert isinstance(positions, tuple) - self._queue( + self._websock_interface.queue_message( _messages.CatmullRomSplineMessage( name, positions, @@ -473,7 +409,7 @@ def add_spline_cubic_bezier( assert isinstance(positions, tuple) assert isinstance(control_points, tuple) assert len(control_points) == (2 * len(positions) - 2) - self._queue( + self._websock_interface.queue_message( _messages.CubicBezierSplineMessage( name, positions, @@ -534,7 +470,7 @@ def add_camera_frustum( media_type = None base64_data = None - self._queue( + self._websock_interface.queue_message( _messages.CameraFrustumMessage( name=name, fov=fov, @@ -586,7 +522,7 @@ def add_frame( """ if origin_radius is None: origin_radius = axes_radius * 2 - self._queue( + self._websock_interface.queue_message( _messages.FrameMessage( name=name, show_axes=show_axes, @@ -641,7 +577,7 @@ def add_batched_axes( num_axes = batched_wxyzs.shape[0] assert batched_wxyzs.shape == (num_axes, 4) assert batched_positions.shape == (num_axes, 3) - self._queue( + self._websock_interface.queue_message( _messages.BatchedAxesMessage( name=name, wxyzs_batched=batched_wxyzs.astype(onp.float32), @@ -694,7 +630,7 @@ def add_grid( Returns: Handle for manipulating scene node. """ - self._queue( + self._websock_interface.queue_message( _messages.GridMessage( name=name, width=width, @@ -735,7 +671,7 @@ def add_label( Returns: Handle for manipulating scene node. """ - self._queue(_messages.LabelMessage(name, text)) + self._websock_interface.queue_message(_messages.LabelMessage(name, text)) return LabelHandle._make(self, name, wxyz, position, visible=visible) def add_point_cloud( @@ -778,7 +714,7 @@ def add_point_cloud( if colors_cast.shape == (3,): colors_cast = onp.tile(colors_cast[None, :], reps=(points.shape[0], 1)) - self._queue( + self._websock_interface.queue_message( _messages.PointCloudMessage( name=name, points=points.astype(onp.float32), @@ -848,7 +784,7 @@ def add_mesh_simple( stacklevel=2, ) - self._queue( + self._websock_interface.queue_message( _messages.MeshMessage( name, vertices.astype(onp.float32), @@ -1021,7 +957,7 @@ def set_background_image( "ascii" ) - self._queue( + self._websock_interface.queue_message( _messages.BackgroundImageMessage( media_type=media_type, base64_rgb=base64_data, @@ -1062,7 +998,7 @@ def add_image( media_type, base64_data = _encode_image_base64( image, format, jpeg_quality=jpeg_quality ) - self._queue( + self._websock_interface.queue_message( _messages.ImageMessage( name=name, media_type=media_type, @@ -1123,7 +1059,7 @@ def add_transform_controls( Returns: Handle for manipulating (and reading state of) scene node. """ - self._queue( + self._websock_interface.queue_message( _messages.TransformControlsMessage( name=name, scale=scale, @@ -1147,14 +1083,14 @@ def sync_cb(client_id: ClientId, state: TransformControlsHandle) -> None: wxyz=tuple(map(float, state._impl.wxyz)), # type: ignore ) message_orientation.excluded_self_client = client_id - self._queue(message_orientation) + self._websock_interface.queue_message(message_orientation) message_position = _messages.SetPositionMessage( name=name, position=tuple(map(float, state._impl.position)), # type: ignore ) message_position.excluded_self_client = client_id - self._queue(message_position) + self._websock_interface.queue_message(message_position) node_handle = SceneNodeHandle._make(self, name, wxyz, position, visible) state_aux = _TransformControlsState( @@ -1166,31 +1102,9 @@ def sync_cb(client_id: ClientId, state: TransformControlsHandle) -> None: self._handle_from_transform_controls_name[name] = handle return handle - def reset_scene(self) -> None: + def reset(self) -> None: """Reset the scene.""" - self._queue(_messages.ResetSceneMessage()) - - def _queue(self, message: _messages.Message) -> None: - """Wrapped method for sending messages safely.""" - got_lock = self._atomic_lock.acquire(blocking=False) - if got_lock: - self._queue_unsafe(message) - self._atomic_lock.release() - else: - # Send when lock is acquirable, while retaining message order. - # This could be optimized! - self._queued_messages.put(message) - - def try_again() -> None: - with self._atomic_lock: - self._queue_unsafe(self._queued_messages.get()) - - self._thread_executor.submit(try_again) - - @abc.abstractmethod - def _queue_unsafe(self, message: _messages.Message) -> None: - """Abstract method for sending messages.""" - ... + self._websock_interface.queue_message(_messages.ResetSceneMessage()) def _get_client_handle(self, client_id: ClientId) -> ClientHandle: """Private helper for getting a client handle from its ID.""" @@ -1222,6 +1136,7 @@ def _handle_transform_controls_updates( # Update state. wxyz = onp.array(message.wxyz) position = onp.array(message.position) + assert False, "TODO implement atomic()" with self.atomic(): handle._impl.wxyz = wxyz handle._impl.position = position @@ -1268,20 +1183,7 @@ def _handle_scene_pointer_updates( return self._scene_pointer_cb(event) - def on_scene_click( - self, - func: Callable[[ScenePointerEvent], None], - ) -> Callable[[ScenePointerEvent], None]: - """Deprecated. Use `on_scene_pointer` instead. - - Registers a callback for scene click events. (event_type == "click") - - Args: - func: The callback function to add. - """ - return self.on_scene_pointer(event_type="click")(func) - - def on_scene_pointer( + def on_pointer_event( self, event_type: Literal["click", "rect-select"] ) -> Callable[ [Callable[[ScenePointerEvent], None]], Callable[[ScenePointerEvent], None] @@ -1296,44 +1198,43 @@ def on_scene_pointer( from ._viser import ClientHandle, ViserServer - def cleanup_previous_event(msg_api: MessageApi): + def cleanup_previous_event(target: ViserServer | ClientHandle): # If the server or client does not have a scene pointer callback, return. - if msg_api._scene_pointer_cb is None: + if target.scene._scene_pointer_cb is None: return # Remove callback. - msg_api.remove_scene_pointer_callback() + target.scene.remove_pointer_callback() def decorator( func: Callable[[ScenePointerEvent], None], ) -> Callable[[ScenePointerEvent], None]: # Check if another scene pointer event was previously registered. # If so, we need to clear the previous event and register the new one. - cleanup_previous_event(self) + cleanup_previous_event(self._owner) # If called on the server handle, remove all clients' callbacks. - if isinstance(self, ViserServer): - clients = list(self.get_clients().values()) - for client in clients: + if isinstance(self._owner, ViserServer): + for client in self._owner.get_clients().values(): cleanup_previous_event(client) # If called on the client handle, and server handle has a callback, remove the server's callback. # (If the server has a callback, none of the clients should have callbacks.) - elif isinstance(self, ClientHandle): - server = self._state.viser_server + elif isinstance(self._owner, ClientHandle): + server = self._owner._viser_server cleanup_previous_event(server) self._scene_pointer_cb = func self._scene_pointer_event_type = event_type - self._queue( + self._websock_interface.queue_message( _messages.ScenePointerEnableMessage(enable=True, event_type=event_type) ) return func return decorator - def on_scene_pointer_removed( + def on_pointer_callback_removed( self, func: Callable[[], None], ) -> Callable[[], None]: @@ -1348,7 +1249,7 @@ def on_scene_pointer_removed( self._scene_pointer_done_cb = func return func - def remove_scene_pointer_callback( + def remove_pointer_callback( self, ) -> None: """Remove the currently attached scene pointer event. This will trigger @@ -1364,10 +1265,11 @@ def remove_scene_pointer_callback( # Notify client that the listener has been removed. event_type = self._scene_pointer_event_type assert event_type is not None - self._queue( + self._websock_interface.queue_message( _messages.ScenePointerEnableMessage(enable=False, event_type=event_type) ) - self.flush() + assert False, "TODO implement flush" + # self.flush() # Run cleanup callback. self._scene_pointer_done_cb() @@ -1413,7 +1315,7 @@ def add_3d_gui_container( gui_api._handle_from_node_name[name].remove() container_id = _make_unique_id() - self._queue( + self._websock_interface.queue_message( _messages.Gui3DMessage( order=time.time(), name=name, @@ -1422,83 +1324,3 @@ def add_3d_gui_container( ) node_handle = SceneNodeHandle._make(self, name, wxyz, position, visible=visible) return Gui3dContainerHandle(node_handle._impl, gui_api, container_id) - - def send_file_download( - self, filename: str, content: bytes, chunk_size: int = 1024 * 1024 - ) -> None: - """Send a file for a client or clients to download. - - Args: - filename: Name of the file to send. Used to infer MIME type. - content: Content of the file. - chunk_size: Number of bytes to send at a time. - """ - mime_type = mimetypes.guess_type(filename, strict=False)[0] - assert ( - mime_type is not None - ), f"Could not guess MIME type from filename {filename}!" - - from ._gui_api import _make_unique_id - - parts = [ - content[i * chunk_size : (i + 1) * chunk_size] - for i in range(int(onp.ceil(len(content) / chunk_size))) - ] - - uuid = _make_unique_id() - - from ._viser import ClientHandle, ViserServer - - # If called on the server handle, send the file to each client. - # If called on the client handle, send the file to just that client. - # - # We avoid calling ViserServer._queue() here because it will create a - # "persistent" message, which is saved and sent to all new clients in - # the future. While this makes sense for things like GUI components or - # 3D assets, this produces unintuitive behavior for file downloads. - if isinstance(self, ViserServer): - clients = list(self.get_clients().values()) - elif isinstance(self, ClientHandle): - clients = [self] - else: - assert False - - for client in clients: - client._queue( - _messages.FileTransferStart( - source_component_id=None, - transfer_uuid=uuid, - filename=filename, - mime_type=mime_type, - part_count=len(parts), - size_bytes=len(content), - ) - ) - - for i, part in enumerate(parts): - client._queue( - _messages.FileTransferPart( - None, - transfer_uuid=uuid, - part=i, - content=part, - ) - ) - client.flush() - - @abc.abstractmethod - def flush(self) -> None: - """Flush the outgoing message buffer. Any buffered messages will immediately be - sent. (by default they are windowed)""" - raise NotImplementedError() - - @contextlib.contextmanager - @abc.abstractmethod - def atomic(self) -> Generator[None, None, None]: - """Returns a context where: all outgoing messages are grouped and applied by - clients atomically. - - This can be helpful for things like animations, or when we want position and - orientation updates to happen synchronously. - """ - raise NotImplementedError() diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index a2022d818..35a6f0471 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -13,7 +13,6 @@ Generic, List, Literal, - Optional, Tuple, Type, TypeVar, @@ -26,8 +25,9 @@ if TYPE_CHECKING: from ._gui_api import GuiApi from ._gui_handles import SupportsRemoveProtocol - from ._message_api import ClientId, MessageApi + from ._scene_api import SceneApi from ._viser import ClientHandle + from .infra import ClientId @dataclasses.dataclass(frozen=True) @@ -40,9 +40,9 @@ class ScenePointerEvent: """ID of client that triggered this event.""" event_type: _messages.ScenePointerEventType """Type of event that was triggered. Currently we only support clicks and box selections.""" - ray_origin: Optional[Tuple[float, float, float]] + ray_origin: Tuple[float, float, float] | None """Origin of 3D ray corresponding to this click, in world coordinates.""" - ray_direction: Optional[Tuple[float, float, float]] + ray_direction: Tuple[float, float, float] | None """Direction of 3D ray corresponding to this click, in world coordinates.""" screen_pos: List[Tuple[float, float]] """Screen position of the click on the screen (OpenCV image coordinates, 0 to 1). @@ -61,7 +61,7 @@ def event(self): @dataclasses.dataclass class _SceneNodeHandleState: name: str - api: MessageApi + api: SceneApi wxyz: onp.ndarray = dataclasses.field( default_factory=lambda: onp.array([1.0, 0.0, 0.0, 0.0]) ) @@ -70,9 +70,9 @@ class _SceneNodeHandleState: ) visible: bool = True # TODO: we should remove SceneNodeHandle as an argument here. - click_cb: Optional[ - List[Callable[[SceneNodePointerEvent[SceneNodeHandle]], None]] - ] = None + click_cb: List[ + Callable[[SceneNodePointerEvent[SceneNodeHandle]], None] + ] | None = None @dataclasses.dataclass @@ -84,7 +84,7 @@ class SceneNodeHandle: @classmethod def _make( cls: Type[TSceneNodeHandle], - api: MessageApi, + api: SceneApi, name: str, wxyz: Tuple[float, float, float, float] | onp.ndarray, position: Tuple[float, float, float] | onp.ndarray, @@ -111,11 +111,11 @@ def wxyz(self) -> onp.ndarray: @wxyz.setter def wxyz(self, wxyz: Tuple[float, float, float, float] | onp.ndarray) -> None: - from ._message_api import cast_vector + from ._scene_api import cast_vector wxyz_cast = cast_vector(wxyz, 4) self._impl.wxyz = onp.asarray(wxyz) - self._impl.api._queue( + self._impl.api._websock_interface.queue_message( _messages.SetOrientationMessage(self._impl.name, wxyz_cast) ) @@ -128,11 +128,11 @@ def position(self) -> onp.ndarray: @position.setter def position(self, position: Tuple[float, float, float] | onp.ndarray) -> None: - from ._message_api import cast_vector + from ._scene_api import cast_vector position_cast = cast_vector(position, 3) self._impl.position = onp.asarray(position) - self._impl.api._queue( + self._impl.api._websock_interface.queue_message( _messages.SetPositionMessage(self._impl.name, position_cast) ) @@ -145,14 +145,16 @@ def visible(self) -> bool: def visible(self, visible: bool) -> None: if visible == self._impl.visible: return - self._impl.api._queue( + self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeVisibilityMessage(self._impl.name, visible) ) self._impl.visible = visible def remove(self) -> None: """Remove the node from the scene.""" - self._impl.api._queue(_messages.RemoveSceneNodeMessage(self._impl.name)) + self._impl.api._websock_interface.queue_message( + _messages.RemoveSceneNodeMessage(self._impl.name) + ) @dataclasses.dataclass(frozen=True) @@ -180,7 +182,7 @@ def on_click( func: Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None], ) -> Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None]: """Attach a callback for when a scene node is clicked.""" - self._impl.api._queue( + self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeClickableMessage(self._impl.name, True) ) if self._impl.click_cb is None: @@ -233,7 +235,7 @@ class LabelHandle(SceneNodeHandle): class _TransformControlsState: last_updated: float update_cb: List[Callable[[TransformControlsHandle], None]] - sync_cb: Optional[Callable[[ClientId, TransformControlsHandle], None]] = None + sync_cb: None | Callable[[ClientId, TransformControlsHandle], None] = None @dataclasses.dataclass @@ -260,7 +262,7 @@ class Gui3dContainerHandle(SceneNodeHandle): _gui_api: GuiApi _container_id: str - _container_id_restore: Optional[str] = None + _container_id_restore: str | None = None _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 23c0178dc..a14a109e8 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -1,12 +1,12 @@ from __future__ import annotations -import contextlib import dataclasses import io +import mimetypes import threading import time from pathlib import Path -from typing import Callable, Dict, Generator, List, Optional, Tuple +from typing import Callable, ContextManager, Dict, List, Optional, Tuple import imageio.v3 as iio import numpy as onp @@ -15,13 +15,12 @@ from rich import box, style from rich.panel import Panel from rich.table import Table -from typing_extensions import Literal, override +from typing_extensions import Literal from . import _client_autobuild, _messages, infra from . import transforms as tf from ._gui_api import GuiApi -from ._message_api import MessageApi, cast_vector -from ._scene_handles import FrameHandle, _SceneNodeHandleState +from ._scene_api import SceneApi, cast_vector from ._tunnel import ViserTunnel @@ -111,7 +110,7 @@ def position(self, position: Tuple[float, float, float] | onp.ndarray) -> None: self._state.position = onp.asarray(position) self.look_at = onp.array(self.look_at) + offset self._state.update_timestamp = time.time() - self._state.client._queue( + self._state.client._websock_connection.queue_message( _messages.SetCameraPositionMessage(cast_vector(position, 3)) ) @@ -136,7 +135,9 @@ def fov(self) -> float: def fov(self, fov: float) -> None: self._state.fov = fov self._state.update_timestamp = time.time() - self._state.client._queue(_messages.SetCameraFovMessage(fov)) + self._state.client._websock_connection.queue_message( + _messages.SetCameraFovMessage(fov) + ) @property def aspect(self) -> float: @@ -160,7 +161,7 @@ def look_at(self, look_at: Tuple[float, float, float] | onp.ndarray) -> None: self._state.look_at = onp.asarray(look_at) self._state.update_timestamp = time.time() self._update_wxyz() - self._state.client._queue( + self._state.client._websock_connection.queue_message( _messages.SetCameraLookAtMessage(cast_vector(look_at, 3)) ) @@ -177,7 +178,7 @@ def up_direction( self._state.up_direction = onp.asarray(up_direction) self._update_wxyz() self._state.update_timestamp = time.time() - self._state.client._queue( + self._state.client._websock_connection.queue_message( _messages.SetCameraUpDirectionMessage(cast_vector(up_direction, 3)) ) @@ -207,7 +208,7 @@ def get_render( render_ready_event = threading.Event() out: Optional[onp.ndarray] = None - connection = self.client._state.connection + connection = self.client._websock_connection def got_render_cb( client_id: int, message: _messages.GetRenderResponseMessage @@ -224,7 +225,7 @@ def got_render_cb( render_ready_event.set() connection.register_handler(_messages.GetRenderResponseMessage, got_render_cb) - self.client._queue( + self.client._websock_connection.queue_message( _messages.GetRenderRequestMessage( "image/jpeg" if transport_format == "jpeg" else "image/png", height=height, @@ -240,40 +241,48 @@ def got_render_cb( return out -@dataclasses.dataclass -class _ClientHandleState: - viser_server: ViserServer - server: infra.Server - connection: infra.ClientConnection - - -@dataclasses.dataclass -class ClientHandle(MessageApi, GuiApi): +class ClientHandle: """Handle for interacting with a specific client. Can be used to send messages to individual clients and read/write camera information.""" - client_id: int - """Unique ID for this client.""" - camera: CameraHandle - """Handle for reading from and manipulating the client's viewport camera.""" - _state: _ClientHandleState + def __init__( + self, conn: infra.WebsockClientConnection, server: ViserServer + ) -> None: + self._websock_connection = conn + self._viser_server = server + + self.scene = SceneApi( + self, thread_executor=server._websock_server._thread_executor + ) + """Handle for interacting with the 3D scene.""" - def __post_init__(self): - super().__init__(self._state.connection, self._state.server._thread_executor) + self.gui = GuiApi(self) + """Handle for interacting with the GUI.""" - @override - def _get_api(self) -> MessageApi: - """Message API to use.""" - return self + self.client_id = conn.client_id + """Unique ID for this client.""" - @override - def _queue_unsafe(self, message: _messages.Message) -> None: - """Define how the message API should send messages.""" - self._state.connection.send(message) + self.camera = CameraHandle( + _CameraHandleState( + self, + wxyz=onp.zeros(4), + position=onp.zeros(3), + fov=0.0, + aspect=0.0, + look_at=onp.zeros(3), + up_direction=onp.zeros(3), + update_timestamp=0.0, + camera_cb=[], + ) + ) + """Handle for reading from and manipulating the client's viewport camera.""" - @override - @contextlib.contextmanager - def atomic(self) -> Generator[None, None, None]: + def flush(self) -> None: + """Flush the outgoing message buffer. Any buffered messages will immediately be + sent. (by default they are windowed)""" + self._viser_server._websock_server.flush_client(self.client_id) + + def atomic(self) -> ContextManager[None]: """Returns a context where: all outgoing messages are grouped and applied by clients atomically. @@ -284,26 +293,51 @@ def atomic(self) -> Generator[None, None, None]: Returns: Context manager. """ - # If called multiple times in the same thread, we ignore inner calls. - thread_id = threading.get_ident() - if thread_id == self._locked_thread_id: - got_lock = False - else: - self._atomic_lock.acquire() - self._locked_thread_id = thread_id - got_lock = True + return self._websock_connection.atomic() - yield - - if got_lock: - self._atomic_lock.release() - self._locked_thread_id = -1 + def send_file_download( + self, filename: str, content: bytes, chunk_size: int = 1024 * 1024 + ) -> None: + """Send a file for a client or clients to download. - @override - def flush(self) -> None: - """Flush the outgoing message buffer. Any buffered messages will immediately be - sent. (by default they are windowed)""" - self._state.server.flush_client(self.client_id) + Args: + filename: Name of the file to send. Used to infer MIME type. + content: Content of the file. + chunk_size: Number of bytes to send at a time. + """ + mime_type = mimetypes.guess_type(filename, strict=False)[0] + assert ( + mime_type is not None + ), f"Could not guess MIME type from filename {filename}!" + + from ._gui_api import _make_unique_id + + parts = [ + content[i * chunk_size : (i + 1) * chunk_size] + for i in range(int(onp.ceil(len(content) / chunk_size))) + ] + + uuid = _make_unique_id() + self._websock_connection.queue_message( + _messages.FileTransferStart( + source_component_id=None, + transfer_uuid=uuid, + filename=filename, + mime_type=mime_type, + part_count=len(parts), + size_bytes=len(content), + ) + ) + for i, part in enumerate(parts): + self._websock_connection.queue_message( + _messages.FileTransferPart( + None, + transfer_uuid=uuid, + part=i, + content=part, + ) + ) + self.flush() # We can serialize the state of a ViserServer via a tuple of @@ -311,18 +345,14 @@ def flush(self) -> None: SerializedServerState = Tuple[Tuple[bytes, float], ...] -def dummy_process() -> None: - pass - - @dataclasses.dataclass class _ViserServerState: - connection: infra.Server + connection: infra.WebsockServer connected_clients: Dict[int, ClientHandle] = dataclasses.field(default_factory=dict) client_lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) -class ViserServer(MessageApi, GuiApi): +class ViserServer: """Viser server class. The primary interface for functionality in `viser`. Commands on a server object (`add_frame`, `add_gui_*`, ...) will be sent to all @@ -334,16 +364,14 @@ class ViserServer(MessageApi, GuiApi): label: Label shown at the top of the GUI panel. """ - world_axes: FrameHandle - """Handle for manipulating the world frame axes (/WorldAxes), which is instantiated - and then hidden by default.""" + scene: SceneApi + """Handle for interacting with the 3D scene.""" + + gui: GuiApi + """Handle for interacting with the GUI.""" # Hide deprecated arguments from docstring and type checkers. def __init__( - self, host: str = "0.0.0.0", port: int = 8080, label: Optional[str] = None - ): ... - - def _actual_init( self, host: str = "0.0.0.0", port: int = 8080, @@ -351,15 +379,14 @@ def _actual_init( **_deprecated_kwargs, ): # Create server. - server = infra.Server( + server = infra.WebsockServer( host=host, port=port, message_class=_messages.Message, http_server_root=Path(__file__).absolute().parent / "client" / "build", client_api_version=1, ) - self._server = server - super().__init__(server, server._thread_executor) + self._websock_server = server _client_autobuild.ensure_client_is_built() @@ -370,27 +397,8 @@ def _actual_init( # For new clients, register and add a handler for camera messages. @server.on_client_connect - def _(conn: infra.ClientConnection) -> None: - camera = CameraHandle( - _CameraHandleState( - # TODO: values are initially not valid. - client=None, # type: ignore - wxyz=onp.zeros(4), - position=onp.zeros(3), - fov=0.0, - aspect=0.0, - look_at=onp.zeros(3), - up_direction=onp.zeros(3), - update_timestamp=0.0, - camera_cb=[], - ) - ) - client = ClientHandle( - conn.client_id, - camera, - _ClientHandleState(self, server, conn), - ) - camera._state.client = client + def _(conn: infra.WebsockClientConnection) -> None: + client = ClientHandle(conn, server=self) first = True def handle_camera_message( @@ -430,7 +438,7 @@ def handle_camera_message( # Remove clients when they disconnect. @server.on_client_disconnect - def _(conn: infra.ClientConnection) -> None: + def _(conn: infra.WebsockClientConnection) -> None: with self._state.client_lock: if conn.client_id not in state.connected_clients: return @@ -442,6 +450,9 @@ def _(conn: infra.ClientConnection) -> None: # Start the server. server.start() + self.scene = SceneApi(owner=self, thread_executor=server._thread_executor) + self.gui = GuiApi(self) + server.register_handler( _messages.ShareUrlDisconnect, lambda client_id, msg: self.disconnect_share_url(), @@ -472,19 +483,8 @@ def _(conn: infra.ClientConnection) -> None: if share: self.request_share_url() - self.reset_scene() - self.set_gui_panel_label(label) - - # Create a handle for the world axes, which are hardcoded to exist in the client. - self.world_axes = FrameHandle( - _SceneNodeHandleState( - "/WorldAxes", - self, - wxyz=onp.array([1.0, 0.0, 0.0, 0.0]), - position=onp.zeros(3), - ) - ) - self.world_axes.visible = False + self.scene.reset() + self.gui.set_gui_panel_label(label) def get_host(self) -> str: """Returns the host address of the Viser server. @@ -492,7 +492,7 @@ def get_host(self) -> str: Returns: Host address as string. """ - return self._server._host + return self._websock_server._host def get_port(self) -> int: """Returns the port of the Viser server. This could be different from the @@ -501,7 +501,7 @@ def get_port(self) -> int: Returns: Port as integer. """ - return self._server._port + return self._websock_server._port def request_share_url(self, verbose: bool = True) -> Optional[str]: """Request a share URL for the Viser server, which allows for public access. @@ -526,13 +526,17 @@ def request_share_url(self, verbose: bool = True) -> Optional[str]: connect_event = threading.Event() - self._share_tunnel = ViserTunnel("share.viser.studio", self._server._port) + self._share_tunnel = ViserTunnel( + "share.viser.studio", self._websock_server._port + ) @self._share_tunnel.on_disconnect def _() -> None: rich.print("[bold](viser)[/bold] Disconnected from share URL") self._share_tunnel = None - self._server.broadcast(_messages.ShareUrlUpdated(None)) + self._websock_server.unsafe_send_message( + _messages.ShareUrlUpdated(None) + ) @self._share_tunnel.on_connect def _(max_clients: int) -> None: @@ -545,7 +549,9 @@ def _(max_clients: int) -> None: rich.print( f"[bold](viser)[/bold] Generated share URL (expires in 24 hours, max {max_clients} clients): {share_url}" ) - self._server.broadcast(_messages.ShareUrlUpdated(share_url)) + self._websock_server.unsafe_send_message( + _messages.ShareUrlUpdated(share_url) + ) connect_event.set() connect_event.wait() @@ -564,7 +570,7 @@ def disconnect_share_url(self) -> None: def stop(self) -> None: """Stop the Viser server and associated threads and tunnels.""" - self._server.stop() + self._websock_server.stop() if self._share_tunnel is not None: self._share_tunnel.close() @@ -605,9 +611,12 @@ def on_client_disconnect( self._client_disconnect_cb.append(cb) return cb - @override - @contextlib.contextmanager - def atomic(self) -> Generator[None, None, None]: + def flush(self) -> None: + """Flush the outgoing message buffer. Any buffered messages will immediately be + sent. (by default they are windowed)""" + self._websock_server.flush() + + def atomic(self) -> ContextManager[None]: """Returns a context where: all outgoing messages are grouped and applied by clients atomically. @@ -618,44 +627,17 @@ def atomic(self) -> Generator[None, None, None]: Returns: Context manager. """ - # Acquire the global atomic lock. - # If called multiple times in the same thread, we ignore inner calls. - thread_id = threading.get_ident() - if thread_id == self._locked_thread_id: - got_lock = False - else: - self._atomic_lock.acquire() - self._locked_thread_id = thread_id - got_lock = True - - with contextlib.ExitStack() as stack: - if got_lock: - # Grab each client's atomic lock. - # We don't need to do anything with `client._locked_thread_id`. - for client in self.get_clients().values(): - stack.enter_context(client._atomic_lock) - - yield - - if got_lock: - self._atomic_lock.release() - self._locked_thread_id = -1 - - @override - def flush(self) -> None: - """Flush the outgoing message buffer. Any buffered messages will immediately be - sent. (by default they are windowed)""" - self._server.flush() - - @override - def _get_api(self) -> MessageApi: - """Message API to use.""" - return self - - @override - def _queue_unsafe(self, message: _messages.Message) -> None: - """Define how the message API should send messages.""" - self._server.broadcast(message) + return self._websock_server.atomic() + def send_file_download( + self, filename: str, content: bytes, chunk_size: int = 1024 * 1024 + ) -> None: + """Send a file for a client or clients to download. -ViserServer.__init__ = ViserServer._actual_init # type: ignore + Args: + filename: Name of the file to send. Used to infer MIME type. + content: Content of the file. + chunk_size: Number of bytes to send at a time. + """ + for client in self.get_clients().values(): + client.send_file_download(filename, content, chunk_size) diff --git a/src/viser/infra/__init__.py b/src/viser/infra/__init__.py index dc25d9a8c..57171dfa3 100644 --- a/src/viser/infra/__init__.py +++ b/src/viser/infra/__init__.py @@ -11,10 +11,10 @@ you're building a web-based application from scratch. """ -from ._infra import ClientConnection as ClientConnection from ._infra import ClientId as ClientId -from ._infra import MessageHandler as MessageHandler -from ._infra import Server as Server +from ._infra import WebsockClientConnection as WebsockClientConnection +from ._infra import WebsockMessageHandler as WebsockMessageHandler +from ._infra import WebsockServer as WebsockServer from ._messages import Message as Message from ._typescript_interface_gen import ( TypeScriptAnnotationOverride as TypeScriptAnnotationOverride, diff --git a/src/viser/infra/_infra.py b/src/viser/infra/_infra.py index c817ceda1..ab2d1ad55 100644 --- a/src/viser/infra/_infra.py +++ b/src/viser/infra/_infra.py @@ -1,14 +1,28 @@ from __future__ import annotations +import abc import asyncio +import contextlib import dataclasses import http import mimetypes +import queue import threading from asyncio.events import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Type, TypeVar +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + NewType, + Optional, + Tuple, + Type, + TypeVar, +) import msgpack import rich @@ -16,7 +30,7 @@ import websockets.datastructures import websockets.exceptions import websockets.server -from typing_extensions import Literal, assert_never +from typing_extensions import Literal, assert_never, override from websockets.legacy.server import WebSocketServerProtocol from ._async_message_buffer import AsyncMessageBuffer @@ -35,13 +49,17 @@ class _ClientHandleState: TMessage = TypeVar("TMessage", bound=Message) -class MessageHandler: +class WebsockMessageHandler: """Mix-in for adding message handling to a class.""" - def __init__(self) -> None: + def __init__(self, thread_executor: ThreadPoolExecutor) -> None: + self._thread_executor = thread_executor self._incoming_handlers: Dict[ Type[Message], List[Callable[[ClientId, Message], None]] ] = {} + self._atomic_lock = threading.Lock() + self._queued_messages: queue.Queue = queue.Queue() + self._locked_thread_id = -1 def register_handler( self, @@ -73,25 +91,76 @@ def _handle_incoming_message(self, client_id: ClientId, message: Message) -> Non for cb in self._incoming_handlers[type(message)]: cb(client_id, message) + @abc.abstractmethod + def unsafe_send_message(self, message: Message) -> None: + ... -@dataclasses.dataclass -class ClientConnection(MessageHandler): - """Handle for interacting with a single connected client. + def queue_message(self, message: Message) -> None: + """Wrapped method for sending messages safely.""" + got_lock = self._atomic_lock.acquire(blocking=False) + if got_lock: + self.unsafe_send_message(message) + self._atomic_lock.release() + else: + # Send when lock is acquirable, while retaining message order. + # This could be optimized! + self._queued_messages.put(message) + + def try_again() -> None: + with self._atomic_lock: + self.unsafe_send_message(self._queued_messages.get()) + + self._thread_executor.submit(try_again) + + @contextlib.contextmanager + def atomic(self) -> Generator[None, None, None]: + """Returns a context where: all outgoing messages are grouped and applied by + clients atomically. + + This should be treated as a soft constraint that's helpful for things + like animations, or when we want position and orientation updates to + happen synchronously. + + Returns: + Context manager. + """ + # If called multiple times in the same thread, we ignore inner calls. + thread_id = threading.get_ident() + if thread_id == self._locked_thread_id: + got_lock = False + else: + self._atomic_lock.acquire() + self._locked_thread_id = thread_id + got_lock = True - We can use this to read the camera state or send client-specific messages.""" + yield - client_id: ClientId - _state: _ClientHandleState + if got_lock: + self._atomic_lock.release() + self._locked_thread_id = -1 - def __post_init__(self) -> None: - super().__init__() - def send(self, message: Message) -> None: +class WebsockClientConnection(WebsockMessageHandler): + """Handle for sending messages to and listening to messages from a single + connected client.""" + + def __init__( + self, + client_id: int, + thread_executor: ThreadPoolExecutor, + client_state: _ClientHandleState, + ) -> None: + self.client_id = client_id + self._state = client_state + super().__init__(thread_executor) + + @override + def unsafe_send_message(self, message: Message) -> None: """Send a message to a specific client.""" self._state.message_buffer.push(message) -class Server(MessageHandler): +class WebsockServer(WebsockMessageHandler): """Websocket server abstraction. Communicates asynchronously with client applications. @@ -121,11 +190,11 @@ def __init__( verbose: bool = True, client_api_version: Literal[0, 1] = 0, ): - super().__init__() + super().__init__(thread_executor=ThreadPoolExecutor(max_workers=32)) # Track connected clients. - self._client_connect_cb: List[Callable[[ClientConnection], None]] = [] - self._client_disconnect_cb: List[Callable[[ClientConnection], None]] = [] + self._client_connect_cb: List[Callable[[WebsockClientConnection], None]] = [] + self._client_disconnect_cb: List[Callable[[WebsockClientConnection], None]] = [] self._host = host self._port = port @@ -133,8 +202,6 @@ def __init__( self._http_server_root = http_server_root self._verbose = verbose self._client_api_version: Literal[0, 1] = client_api_version - - self._thread_executor = ThreadPoolExecutor(max_workers=32) self._shutdown_event = threading.Event() self._client_state_from_id: Dict[int, _ClientHandleState] = {} @@ -161,15 +228,18 @@ def stop(self) -> None: self._thread_executor.shutdown(wait=True) self._event_loop.stop() - def on_client_connect(self, cb: Callable[[ClientConnection], Any]) -> None: + def on_client_connect(self, cb: Callable[[WebsockClientConnection], Any]) -> None: """Attach a callback to run for newly connected clients.""" self._client_connect_cb.append(cb) - def on_client_disconnect(self, cb: Callable[[ClientConnection], Any]) -> None: + def on_client_disconnect( + self, cb: Callable[[WebsockClientConnection], Any] + ) -> None: """Attach a callback to run when clients disconnect.""" self._client_disconnect_cb.append(cb) - def broadcast(self, message: Message) -> None: + @override + def unsafe_send_message(self, message: Message) -> None: """Pushes a message onto the broadcast queue. Message will be sent to all clients. Broadcasted messages are persistent: if a new client connects to the server, @@ -229,7 +299,9 @@ async def serve(websocket: WebSocketServerProtocol) -> None: AsyncMessageBuffer(event_loop, persistent_messages=False), event_loop, ) - client_connection = ClientConnection(client_id, client_state) + client_connection = WebsockClientConnection( + client_id, self._thread_executor, client_state + ) self._client_state_from_id[client_id] = client_state def handle_incoming(message: Message) -> None: