From b48fa2f16a1e778d0cbc44e86f2deddf017aa00e Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 3 Jun 2024 20:48:57 -0700 Subject: [PATCH 01/33] change render button to run ns-render directly --- nerfstudio/viewer/render_panel.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 4cfe380d9e..44cfcbf2a1 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -31,6 +31,7 @@ from scipy import interpolate from nerfstudio.viewer.control_panel import ControlPanel +from nerfstudio.utils.scripts import run_command @dataclasses.dataclass @@ -1019,13 +1020,21 @@ def _(_) -> None: initial_value=now.strftime("%Y-%m-%d-%H-%M-%S"), hint="Name of the render", ) + render_button = server.add_gui_button( - "Generate Command", + "Render", color="green", icon=viser.Icon.FILE_EXPORT, - hint="Generate the ns-render command for rendering the camera path.", + hint="Render the camera path and save video as mp4 file.", ) + # generate_render_button = server.add_gui_button( + # "Generate Command", + # color="green", + # icon=viser.Icon.FILE_EXPORT, + # hint="Generate the ns-render command for rendering the camera path.", + # ) + reset_up_button = server.add_gui_button( "Reset Up Direction", icon=viser.Icon.ARROW_BIG_UP_LINES, @@ -1148,14 +1157,15 @@ def _(event: viser.GuiEvent) -> None: event.client.add_gui_markdown( "\n".join( [ - "To render the trajectory, run the following from the command line:", + "Rendering trajectory and saving file as", "", "```", - command, + f"renders/{dataname}/{render_name_text.value}.mp4", "```", ] ) ) + run_command(command, verbose=False) close_button = event.client.add_gui_button("Close") @close_button.on_click From 2f5e5ca9f81222c3668509a47d7e550440909471 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 3 Jun 2024 20:51:07 -0700 Subject: [PATCH 02/33] nit --- nerfstudio/viewer/render_panel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 44cfcbf2a1..a3afdab508 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1144,7 +1144,7 @@ def _(event: viser.GuiEvent) -> None: with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) # now show the command - with event.client.add_gui_modal("Render Command") as modal: + with event.client.add_gui_modal("Render") as modal: dataname = datapath.name command = " ".join( [ From 0933c82c1c840eab151d3272737a1323b5c4bded Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Tue, 4 Jun 2024 06:13:38 -0700 Subject: [PATCH 03/33] nits --- nerfstudio/viewer/export_panel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 16201ba299..752306f9f3 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -194,6 +194,7 @@ def populate_splat_tab( server.add_gui_markdown("Generate ply export of Gaussian Splat") output_directory = server.add_gui_text("Output Directory", initial_value="exports/splat/") + # TODO: change to export directly generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) @generate_command.on_click From c43e9e06712011fa791a4e09d7f5201b7fc15bd8 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Wed, 12 Jun 2024 21:31:02 -0700 Subject: [PATCH 04/33] add direct render via viewer and notifications --- 
nerfstudio/scripts/render.py | 4 ++- nerfstudio/viewer/render_panel.py | 53 +++++++++++++++---------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index c2d6d83ce6..29b2dff3b6 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -439,6 +439,8 @@ class RenderCameraPath(BaseRender): """Filename of the camera path to render.""" output_format: Literal["images", "video"] = "video" """How to save output data.""" + complete: bool = True + """Whether rendering is complete""" def main(self) -> None: """Main function.""" @@ -575,7 +577,7 @@ def main(self) -> None: if str(left_eye_path.parent)[-5:] == "_temp": shutil.rmtree(left_eye_path.parent, ignore_errors=True) CONSOLE.print("[bold green]Final VR180 Render Complete") - + self.complete = True @dataclass class RenderInterpolated(BaseRender): diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index a3afdab508..76d1aa7148 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1050,8 +1050,17 @@ def _(event: viser.GuiEvent) -> None: @render_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None + render_path = f"""renders/{datapath.name}/{render_name_text.value}.mp4""" + server.add_notification( + title="Rendering trajectory", + body="Saving rendered video as " + render_path, + withCloseButton=True, + loading=True, + autoClose=False, + ) num_frames = int(framerate_number.value * duration_number.value) json_data = {} + # json data has the properties: # keyframes: list of keyframes with # matrix : flattened 4x4 matrix @@ -1070,7 +1079,7 @@ def _(event: viser.GuiEvent) -> None: # aspect: float # first populate the keyframes: keyframes = [] - for keyframe, dummy in camera_path._keyframes.values(): + for keyframe, _ in camera_path._keyframes.values(): pose = tf.SE3.from_rotation_and_translation( tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi), keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO, @@ -1143,34 +1152,24 @@ def _(event: viser.GuiEvent) -> None: json_outfile.parent.mkdir(parents=True, exist_ok=True) with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) - # now show the command - with event.client.add_gui_modal("Render") as modal: - dataname = datapath.name - command = " ".join( - [ - "ns-render camera-path", - f"--load-config {config_path}", - f"--camera-path-filename {json_outfile.absolute()}", - f"--output-path renders/{dataname}/{render_name_text.value}.mp4", - ] - ) - event.client.add_gui_markdown( - "\n".join( - [ - "Rendering trajectory and saving file as", - "", - "```", - f"renders/{dataname}/{render_name_text.value}.mp4", - "```", - ] + + # rendering + from nerfstudio.scripts.render import RenderCameraPath + render = RenderCameraPath( + load_config=config_path, + camera_path_filename=json_outfile.absolute(), + output_path=Path(render_path) ) + render.main() + + if render.complete: + server.add_notification( + title="Render complete!", + body="Video saved as " + render_path, + withCloseButton=True, + loading=False, + autoClose=5000, ) - run_command(command, verbose=False) - close_button = event.client.add_gui_button("Close") - - @close_button.on_click - def _(_) -> None: - modal.close() if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) From 7e92ca8ff67063df04f895a87252fba33714e62e Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Wed, 12 Jun 2024 
22:18:10 -0700 Subject: [PATCH 05/33] add direct export button and notifications --- nerfstudio/scripts/exporter.py | 158 +++++++++++++++++++----- nerfstudio/scripts/render.py | 184 +++++++++++++++++++++------ nerfstudio/viewer/export_panel.py | 198 ++++++++++++++++++++++-------- nerfstudio/viewer/render_panel.py | 182 +++++++++++++++++++-------- 4 files changed, 546 insertions(+), 176 deletions(-) diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index 7d90e0708d..548f146e3c 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -38,11 +38,19 @@ from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager -from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManager +from nerfstudio.data.datamanagers.random_cameras_datamanager import ( + RandomCamerasDataManager, +) from nerfstudio.data.scene_box import OrientedBox from nerfstudio.exporter import texture_utils, tsdf_utils -from nerfstudio.exporter.exporter_utils import collect_camera_poses, generate_point_cloud, get_mesh_from_filename -from nerfstudio.exporter.marching_cubes import generate_mesh_with_multires_marching_cubes +from nerfstudio.exporter.exporter_utils import ( + collect_camera_poses, + generate_point_cloud, + get_mesh_from_filename, +) +from nerfstudio.exporter.marching_cubes import ( + generate_mesh_with_multires_marching_cubes, +) from nerfstudio.fields.sdf_field import SDFField # noqa from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline @@ -58,9 +66,13 @@ class Exporter: """Path to the config YAML file.""" output_dir: Path """Path to the output directory.""" + complete: bool = False + """Set to True when export is finished.""" -def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None: +def validate_pipeline( + normal_method: str, normal_output_name: str, pipeline: Pipeline +) -> None: """Check that the pipeline is valid for this exporter. Args: @@ -75,11 +87,16 @@ def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pip pixel_area = torch.ones_like(origins[..., :1]) camera_indices = torch.zeros_like(origins[..., :1]) ray_bundle = RayBundle( - origins=origins, directions=directions, pixel_area=pixel_area, camera_indices=camera_indices + origins=origins, + directions=directions, + pixel_area=pixel_area, + camera_indices=camera_indices, ) outputs = pipeline.model(ray_bundle) if normal_output_name not in outputs: - CONSOLE.print(f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs.") + CONSOLE.print( + f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs." + ) CONSOLE.print(f"Available outputs: {list(outputs.keys())}") CONSOLE.print( "[bold yellow]Warning: Please train a model with normals " @@ -136,16 +153,29 @@ def main(self) -> None: # Increase the batchsize to speed up the evaluation. 
assert isinstance( pipeline.datamanager, - (VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager), + ( + VanillaDataManager, + ParallelDataManager, + FullImageDatamanager, + RandomCamerasDataManager, + ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( + self.num_rays_per_batch + ) # Whether the normals should be estimated based on the point cloud. estimate_normals = self.normal_method == "open3d" crop_obb = None - if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: - crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) + if ( + self.obb_center is not None + and self.obb_rotation is not None + and self.obb_scale is not None + ): + crop_obb = OrientedBox.from_params( + self.obb_center, self.obb_rotation, self.obb_scale + ) pcd = generate_point_cloud( pipeline=pipeline, num_points=self.num_points, @@ -154,14 +184,18 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, + normal_output_name=self.normal_output_name + if self.normal_method == "model_output" + else None, crop_obb=crop_obb, std_ratio=self.std_ratio, ) if self.save_world_frame: # apply the inverse dataparser transform to the point cloud points = np.asarray(pcd.points) - poses = np.eye(4, dtype=np.float32)[None, ...].repeat(points.shape[0], axis=0)[:, :3, :] + poses = np.eye(4, dtype=np.float32)[None, ...].repeat( + points.shape[0], axis=0 + )[:, :3, :] poses[:, :3, 3] = points poses = pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space( torch.from_numpy(poses) @@ -181,6 +215,8 @@ def main(self) -> None: print("\033[A\033[A") CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") + self.complete = True + @dataclass class ExportTSDFMesh(Exporter): @@ -242,18 +278,23 @@ def main(self) -> None: if self.texture_method == "nerf": # load the mesh from the tsdf export mesh = get_mesh_from_filename( - str(self.output_dir / "tsdf_mesh.ply"), target_num_faces=self.target_num_faces + str(self.output_dir / "tsdf_mesh.ply"), + target_num_faces=self.target_num_faces, ) CONSOLE.print("Texturing mesh with NeRF") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=self.px_per_uv_triangle + if self.unwrap_method == "custom" + else None, unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) + self.complete = True + @dataclass class ExportPoissonMesh(Exporter): @@ -317,15 +358,28 @@ def main(self) -> None: # Increase the batchsize to speed up the evaluation. assert isinstance( pipeline.datamanager, - (VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager), + ( + VanillaDataManager, + ParallelDataManager, + FullImageDatamanager, + RandomCamerasDataManager, + ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( + self.num_rays_per_batch + ) # Whether the normals should be estimated based on the point cloud. 
estimate_normals = self.normal_method == "open3d" - if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: - crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) + if ( + self.obb_center is not None + and self.obb_rotation is not None + and self.obb_scale is not None + ): + crop_obb = OrientedBox.from_params( + self.obb_center, self.obb_rotation, self.obb_scale + ) else: crop_obb = None @@ -337,7 +391,9 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, + normal_output_name=self.normal_output_name + if self.normal_method == "model_output" + else None, crop_obb=crop_obb, std_ratio=self.std_ratio, ) @@ -351,7 +407,9 @@ def main(self) -> None: CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") CONSOLE.print("Computing Mesh... this may take a while.") - mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( + pcd, depth=9 + ) vertices_to_remove = densities < np.quantile(densities, 0.1) mesh.remove_vertices_by_mask(vertices_to_remove) print("\033[A\033[A") @@ -367,18 +425,23 @@ def main(self) -> None: if self.texture_method == "nerf": # load the mesh from the poisson reconstruction mesh = get_mesh_from_filename( - str(self.output_dir / "poisson_mesh.ply"), target_num_faces=self.target_num_faces + str(self.output_dir / "poisson_mesh.ply"), + target_num_faces=self.target_num_faces, ) CONSOLE.print("Texturing mesh with NeRF") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=self.px_per_uv_triangle + if self.unwrap_method == "custom" + else None, unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) + self.complete = True + @dataclass class ExportMarchingCubesMesh(Exporter): @@ -411,11 +474,15 @@ def main(self) -> None: _, pipeline, _, _ = eval_setup(self.load_config) # TODO: Make this work with Density Field - assert hasattr(pipeline.model.config, "sdf_field"), "Model must have an SDF field." + assert hasattr( + pipeline.model.config, "sdf_field" + ), "Model must have an SDF field." CONSOLE.print("Extracting mesh with marching cubes... which may take a while") - assert self.resolution % 512 == 0, f"""resolution must be divisible by 512, got {self.resolution}. + assert ( + self.resolution % 512 == 0 + ), f"""resolution must be divisible by 512, got {self.resolution}. 
This is important because the algorithm uses a multi-resolution approach to evaluate the SDF where the minimum resolution is 512.""" @@ -434,17 +501,23 @@ def main(self) -> None: multi_res_mesh.export(filename) # load the mesh from the marching cubes export - mesh = get_mesh_from_filename(str(filename), target_num_faces=self.target_num_faces) + mesh = get_mesh_from_filename( + str(filename), target_num_faces=self.target_num_faces + ) CONSOLE.print("Texturing mesh with NeRF...") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=self.px_per_uv_triangle + if self.unwrap_method == "custom" + else None, unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) + self.complete = True + @dataclass class ExportCameraPoses(Exporter): @@ -461,9 +534,14 @@ def main(self) -> None: assert isinstance(pipeline, VanillaPipeline) train_frames, eval_frames = collect_camera_poses(pipeline) - for file_name, frames in [("transforms_train.json", train_frames), ("transforms_eval.json", eval_frames)]: + for file_name, frames in [ + ("transforms_train.json", train_frames), + ("transforms_eval.json", eval_frames), + ]: if len(frames) == 0: - CONSOLE.print(f"[bold yellow]No frames found for {file_name}. Skipping.") + CONSOLE.print( + f"[bold yellow]No frames found for {file_name}. Skipping." + ) continue output_file_path = os.path.join(self.output_dir, file_name) @@ -471,7 +549,9 @@ def main(self) -> None: with open(output_file_path, "w", encoding="UTF-8") as f: json.dump(frames, f, indent=4) - CONSOLE.print(f"[bold green]:white_check_mark: Saved poses to {output_file_path}") + CONSOLE.print( + f"[bold green]:white_check_mark: Saved poses to {output_file_path}" + ) @dataclass @@ -515,7 +595,9 @@ def write_ply( and tensor.size > 0 for tensor in map_to_tensors.values() ): - raise ValueError("All tensors must be numpy arrays of float or uint8 type and not empty") + raise ValueError( + "All tensors must be numpy arrays of float or uint8 type and not empty" + ) with open(filename, "wb") as ply_file: # Write PLY header @@ -591,8 +673,14 @@ def main(self) -> None: for i in range(4): map_to_tensors[f"rot_{i}"] = quats[:, i, None] - if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: - crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) + if ( + self.obb_center is not None + and self.obb_rotation is not None + and self.obb_scale is not None + ): + crop_obb = OrientedBox.from_params( + self.obb_center, self.obb_rotation, self.obb_scale + ) assert crop_obb is not None mask = crop_obb.within(torch.from_numpy(positions)).numpy() for k, t in map_to_tensors.items(): @@ -612,13 +700,17 @@ def main(self) -> None: CONSOLE.print(f"{n_before - n_after} NaN/Inf elements in {k}") if np.sum(select) < n: - CONSOLE.print(f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}") + CONSOLE.print( + f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}" + ) for k, t in map_to_tensors.items(): map_to_tensors[k] = map_to_tensors[k][select] count = np.sum(select) ExportGaussianSplat.write_ply(str(filename), count, map_to_tensors) + self.complete = True + Commands = tyro.conf.FlagConversionOff[ Union[ diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 29b2dff3b6..628e50c975 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ 
-37,17 +37,35 @@ from jaxtyping import Float from rich import box, style from rich.panel import Panel -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.progress import ( + BarColumn, + Progress, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) from rich.table import Table from torch import Tensor from typing_extensions import Annotated -from nerfstudio.cameras.camera_paths import get_interpolated_camera_path, get_path_from_json, get_spiral_path +from nerfstudio.cameras.camera_paths import ( + get_interpolated_camera_path, + get_path_from_json, + get_spiral_path, +) from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle -from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig -from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig +from nerfstudio.data.datamanagers.base_datamanager import ( + VanillaDataManager, + VanillaDataManagerConfig, +) +from nerfstudio.data.datamanagers.full_images_datamanager import ( + FullImageDatamanagerConfig, +) from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager -from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManager +from nerfstudio.data.datamanagers.random_cameras_datamanager import ( + RandomCamerasDataManager, +) from nerfstudio.data.datasets.base_dataset import Dataset from nerfstudio.data.scene_box import OrientedBox from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader @@ -148,14 +166,19 @@ def _render_trajectory_video( assert train_dataset is not None assert train_cameras is not None cam_pos = cameras[camera_idx].camera_to_worlds[:, 3].cpu() - cam_quat = tf.SO3.from_matrix(cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True)).wxyz + cam_quat = tf.SO3.from_matrix( + cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True) + ).wxyz for i in range(len(train_cameras)): train_cam_pos = train_cameras[i].camera_to_worlds[:, 3].cpu() # Make sure the line of sight from rendered cam to training cam is not blocked by any object bundle = RayBundle( origins=cam_pos.view(1, 3), - directions=((cam_pos - train_cam_pos) / (cam_pos - train_cam_pos).norm()).view(1, 3), + directions=( + (cam_pos - train_cam_pos) + / (cam_pos - train_cam_pos).norm() + ).view(1, 3), pixel_area=torch.tensor(1).view(1, 1), nears=torch.tensor(0.05).view(1, 1), fars=torch.tensor(100).view(1, 1), @@ -164,7 +187,9 @@ def _render_trajectory_video( ).to(pipeline.device) outputs = pipeline.model.get_outputs(bundle) - q = tf.SO3.from_matrix(train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True)).wxyz + q = tf.SO3.from_matrix( + train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True) + ).wxyz # calculate distance between two quaternions rot_dist = 1 - np.dot(q, cam_quat) ** 2 pos_dist = torch.norm(train_cam_pos - cam_pos) @@ -174,7 +199,10 @@ def _render_trajectory_video( true_max_dist = dist true_max_idx = i - if outputs["depth"][0] < torch.norm(cam_pos - train_cam_pos).item(): + if ( + outputs["depth"][0] + < torch.norm(cam_pos - train_cam_pos).item() + ): continue if check_occlusions and (max_dist == -1 or dist < max_dist): @@ -201,9 +229,13 @@ def _render_trajectory_video( for rendered_output_name in rendered_output_names: if rendered_output_name not in outputs: CONSOLE.rule("Error", style="red") - CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", 
justify="center") CONSOLE.print( - f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center" + f"Could not find {rendered_output_name} in the model outputs", + justify="center", + ) + CONSOLE.print( + f"Please set --rendered_output_name to one of: {outputs.keys()}", + justify="center", ) sys.exit(1) output_image = outputs[rendered_output_name] @@ -255,10 +287,17 @@ def _render_trajectory_video( render_image = np.concatenate(render_image, axis=1) if output_format == "images": if image_format == "png": - media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png") + media.write_image( + output_image_dir / f"{camera_idx:05d}.png", + render_image, + fmt="png", + ) if image_format == "jpeg": media.write_image( - output_image_dir / f"{camera_idx:05d}.jpg", render_image, fmt="jpeg", quality=jpeg_quality + output_image_dir / f"{camera_idx:05d}.jpg", + render_image, + fmt="jpeg", + quality=jpeg_quality, ) if output_format == "video": if writer is None: @@ -286,7 +325,13 @@ def _render_trajectory_video( table.add_row("Video", str(output_filename)) else: table.add_row("Images", str(output_image_dir)) - CONSOLE.print(Panel(table, title="[bold][green]:tada: Render Complete :tada:[/bold]", expand=False)) + CONSOLE.print( + Panel( + table, + title="[bold][green]:tada: Render Complete :tada:[/bold]", + expand=False, + ) + ) def insert_spherical_metadata_into_file( @@ -365,7 +410,11 @@ class CropData: background_color: Float[Tensor, "3"] = torch.Tensor([0.0, 0.0, 0.0]) """background color""" - obb: OrientedBox = field(default_factory=lambda: OrientedBox(R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2)) + obb: OrientedBox = field( + default_factory=lambda: OrientedBox( + R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2 + ) + ) """Oriented box representing the crop region""" # properties for backwards-compatibility interface @@ -391,12 +440,18 @@ def get_crop_from_json(camera_json: Dict[str, Any]) -> Optional[CropData]: bg_color = camera_json["crop"]["crop_bg_color"] center = camera_json["crop"]["crop_center"] scale = camera_json["crop"]["crop_scale"] - rot = (0.0, 0.0, 0.0) if "crop_rot" not in camera_json["crop"] else tuple(camera_json["crop"]["crop_rot"]) + rot = ( + (0.0, 0.0, 0.0) + if "crop_rot" not in camera_json["crop"] + else tuple(camera_json["crop"]["crop_rot"]) + ) assert len(center) == 3 assert len(scale) == 3 assert len(rot) == 3 return CropData( - background_color=torch.Tensor([bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0]), + background_color=torch.Tensor( + [bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0] + ), obb=OrientedBox.from_params(center, rot, scale), ) @@ -463,7 +518,9 @@ def main(self) -> None: or camera_path.camera_type[0] == CameraType.VR180_L.value ): # temp folder for writing left and right view renders - temp_folder_path = self.output_path.parent / (self.output_path.stem + "_temp") + temp_folder_path = self.output_path.parent / ( + self.output_path.stem + "_temp" + ) Path(temp_folder_path).mkdir(parents=True, exist_ok=True) left_eye_path = temp_folder_path / "render_left.mp4" @@ -471,7 +528,9 @@ def main(self) -> None: self.output_path = left_eye_path if camera_path.camera_type[0] == CameraType.OMNIDIRECTIONALSTEREO_L.value: - CONSOLE.print("[bold green]:goggles: Omni-directional Stereo VR :goggles:") + CONSOLE.print( + "[bold green]:goggles: Omni-directional Stereo VR :goggles:" + ) else: CONSOLE.print("[bold green]:goggles: VR180 :goggles:") @@ -579,6 +638,7 @@ def 
main(self) -> None: CONSOLE.print("[bold green]Final VR180 Render Complete") self.complete = True + @dataclass class RenderInterpolated(BaseRender): """Render a trajectory that interpolates between training or eval dataset images.""" @@ -721,7 +781,10 @@ def main(self): def update_config(config: TrainerConfig) -> TrainerConfig: data_manager_config = config.pipeline.datamanager - assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) + assert isinstance( + data_manager_config, + (VanillaDataManagerConfig, FullImageDatamanagerConfig), + ) data_manager_config.eval_num_images_to_sample_from = -1 data_manager_config.eval_num_times_to_repeat_images = -1 if isinstance(data_manager_config, VanillaDataManagerConfig): @@ -731,7 +794,11 @@ def update_config(config: TrainerConfig) -> TrainerConfig: data_manager_config.data = self.data if self.downscale_factor is not None: assert hasattr(data_manager_config.dataparser, "downscale_factor") - setattr(data_manager_config.dataparser, "downscale_factor", self.downscale_factor) + setattr( + data_manager_config.dataparser, + "downscale_factor", + self.downscale_factor, + ) return config config, pipeline, _, _ = eval_setup( @@ -741,25 +808,39 @@ def update_config(config: TrainerConfig) -> TrainerConfig: update_config_callback=update_config, ) data_manager_config = config.pipeline.datamanager - assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) + assert isinstance( + data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig) + ) for split in self.split.split("+"): datamanager: VanillaDataManager dataset: Dataset if split == "train": - with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access - datamanager = data_manager_config.setup(test_mode="test", device=pipeline.device) + with _disable_datamanager_setup( + data_manager_config._target + ): # pylint: disable=protected-access + datamanager = data_manager_config.setup( + test_mode="test", device=pipeline.device + ) dataset = datamanager.train_dataset - dataparser_outputs = getattr(dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs) + dataparser_outputs = getattr( + dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs + ) else: - with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access - datamanager = data_manager_config.setup(test_mode=split, device=pipeline.device) + with _disable_datamanager_setup( + data_manager_config._target + ): # pylint: disable=protected-access + datamanager = data_manager_config.setup( + test_mode=split, device=pipeline.device + ) dataset = datamanager.eval_dataset dataparser_outputs = getattr(dataset, "_dataparser_outputs", None) if dataparser_outputs is None: - dataparser_outputs = datamanager.dataparser.get_dataparser_outputs(split=datamanager.test_split) + dataparser_outputs = datamanager.dataparser.get_dataparser_outputs( + split=datamanager.test_split + ) dataloader = FixedIndicesEvalDataloader( input_dataset=dataset, device=datamanager.device, @@ -777,7 +858,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: TimeRemainingColumn(elapsed_when_finished=False, compact=False), TimeElapsedColumn(), ) as progress: - for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))): + for camera_idx, (camera, batch) in enumerate( + progress.track(dataloader, total=len(dataset)) + ): with torch.no_grad(): outputs = 
pipeline.model.get_outputs_for_camera(camera) @@ -796,10 +879,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig: if rendered_output_name not in all_outputs: CONSOLE.rule("Error", style="red") CONSOLE.print( - f"Could not find {rendered_output_name} in the model outputs", justify="center" + f"Could not find {rendered_output_name} in the model outputs", + justify="center", ) CONSOLE.print( - f"Please set --rendered-output-name to one of: {all_outputs}", justify="center" + f"Please set --rendered-output-name to one of: {all_outputs}", + justify="center", ) sys.exit(1) @@ -808,9 +893,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig: image_name = f"{camera_idx:05d}" # Try to get the original filename - image_name = dataparser_outputs.image_filenames[camera_idx].relative_to(images_root) + image_name = dataparser_outputs.image_filenames[ + camera_idx + ].relative_to(images_root) - output_path = self.output_path / split / rendered_output_name / image_name + output_path = ( + self.output_path / split / rendered_output_name / image_name + ) output_path.parent.mkdir(exist_ok=True, parents=True) output_name = rendered_output_name @@ -824,7 +913,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: output_image = outputs[output_name] if is_depth: # Divide by the dataparser scale factor - output_image.div_(dataparser_outputs.dataparser_scale) + output_image.div_( + dataparser_outputs.dataparser_scale + ) else: if output_name.startswith("gt-"): output_name = output_name[3:] @@ -860,16 +951,25 @@ def update_config(config: TrainerConfig) -> TrainerConfig: # Save to file if is_raw: - with gzip.open(output_path.with_suffix(".npy.gz"), "wb") as f: + with gzip.open( + output_path.with_suffix(".npy.gz"), "wb" + ) as f: np.save(f, output_image) elif self.image_format == "png": - media.write_image(output_path.with_suffix(".png"), output_image, fmt="png") + media.write_image( + output_path.with_suffix(".png"), output_image, fmt="png" + ) elif self.image_format == "jpeg": media.write_image( - output_path.with_suffix(".jpg"), output_image, fmt="jpeg", quality=self.jpeg_quality + output_path.with_suffix(".jpg"), + output_image, + fmt="jpeg", + quality=self.jpeg_quality, ) else: - raise ValueError(f"Unknown image format {self.image_format}") + raise ValueError( + f"Unknown image format {self.image_format}" + ) table = Table( title=None, @@ -879,7 +979,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig: ) for split in self.split.split("+"): table.add_row(f"Outputs {split}", str(self.output_path / split)) - CONSOLE.print(Panel(table, title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", expand=False)) + CONSOLE.print( + Panel( + table, + title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", + expand=False, + ) + ) Commands = tyro.conf.FlagConversionOff[ diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 752306f9f3..76e054d5e4 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -48,7 +48,11 @@ def _(_) -> None: populate_mesh_tab(server, control_panel, config_path, viewing_gsplat) -def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point cloud", "splat"], command: str) -> None: +def show_command_modal( + client: viser.ClientHandle, + what: Literal["mesh", "point cloud", "splat"], + command: str, +) -> None: """Show a modal to each currently connected client. 
In the future, we should only show the modal to the client that pushes the @@ -73,7 +77,7 @@ def _(_) -> None: modal.close() -def get_crop_string(obb: OrientedBox, crop_viewport: bool): +def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> List[str]: """Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale} and each arg formatted with spaces around it """ @@ -85,7 +89,7 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool): rpystring = " ".join([f"{x:.10f}" for x in rpy]) posstring = " ".join([f"{x:.10f}" for x in pos]) scalestring = " ".join([f"{x:.10f}" for x in scale]) - return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" + return [posstring, rpystring, scalestring] def populate_point_cloud_tab( @@ -95,8 +99,12 @@ def populate_point_cloud_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.add_gui_markdown("Render depth, project to an oriented point cloud, and filter ") - num_points = server.add_gui_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + server.add_gui_markdown( + "Render depth, project to an oriented point cloud, and filter " + ) + num_points = server.add_gui_number( + "# Points", initial_value=1_000_000, min=1, max=None, step=1 + ) world_frame = server.add_gui_checkbox( "Save in world frame", False, @@ -113,28 +121,54 @@ def populate_point_cloud_tab( initial_value="open3d", hint="Normal map source.", ) - output_dir = server.add_gui_text("Output Directory", initial_value="exports/pcd/") - generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) + output_dir = server.add_gui_text( + "Output Directory", initial_value="exports/pcd/" + ) + export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) - @generate_command.on_click + @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - command = " ".join( - [ - "ns-export pointcloud", - f"--load-config {config_path}", - f"--output-dir {output_dir.value}", - f"--num-points {num_points.value}", - f"--remove-outliers {remove_outliers.value}", - f"--normal-method {normals.value}", - f"--save-world-frame {world_frame.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), - ] + server.add_notification( + title="Exporting point cloud", + body="File will be saved under " + str(output_dir.value), + withCloseButton=True, + loading=True, + autoClose=False, ) - show_command_modal(event.client, "point cloud", command) + + from nerfstudio.scripts.exporter import ExportPointCloud + + posstring, rpystring, scalestring = get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ) + export = ExportPointCloud( + load_config=config_path, + output_dir=output_dir.value, + num_points=num_points.value, + remove_outliers=remove_outliers.value, + normal_method=normals.value, + save_world_frame=world_frame.value, + obb_center=posstring, + obb_rotation=rpystring, + obb_scale=scalestring, + ) + export.main() + + if export.complete: + server.clear_notification() + server.add_notification( + title="Export complete!", + body="File saved under " + str(output_dir.value), + withCloseButton=True, + loading=False, + autoClose=5000, + ) else: - server.add_gui_markdown("Point cloud export is not currently supported with Gaussian Splatting") + server.add_gui_markdown( + "Point cloud export is not currently supported with Gaussian Splatting" + ) def populate_mesh_tab( @@ -155,33 +189,63 @@ def populate_mesh_tab( 
hint="Source for normal maps.", ) num_faces = server.add_gui_number("# Faces", initial_value=50_000, min=1) - texture_resolution = server.add_gui_number("Texture Resolution", min=8, initial_value=2048) - output_directory = server.add_gui_text("Output Directory", initial_value="exports/mesh/") - num_points = server.add_gui_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + texture_resolution = server.add_gui_number( + "Texture Resolution", min=8, initial_value=2048 + ) + output_dir = server.add_gui_text( + "Output Directory", initial_value="exports/mesh/" + ) + num_points = server.add_gui_number( + "# Points", initial_value=1_000_000, min=1, max=None, step=1 + ) remove_outliers = server.add_gui_checkbox("Remove outliers", True) - generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) + export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) - @generate_command.on_click + @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - command = " ".join( - [ - "ns-export poisson", - f"--load-config {config_path}", - f"--output-dir {output_directory.value}", - f"--target-num-faces {num_faces.value}", - f"--num-pixels-per-side {texture_resolution.value}", - f"--num-points {num_points.value}", - f"--remove-outliers {remove_outliers.value}", - f"--normal-method {normals.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), - ] + server.add_notification( + title="Exporting poisson mesh", + body="File will be saved under " + str(output_dir.value), + withCloseButton=True, + loading=True, + autoClose=False, ) - show_command_modal(event.client, "mesh", command) + + from nerfstudio.scripts.exporter import ExportPoissonMesh + + posstring, rpystring, scalestring = get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ) + export = ExportPoissonMesh( + load_config=config_path, + output_dir=output_dir.value, + target_num_faces=num_faces.value, + num_pixels_per_side=texture_resolution.value, + num_points=num_points.value, + remove_outliers=remove_outliers.value, + normal_method=normals.value, + obb_center=posstring, + obb_rotation=rpystring, + obb_scale=scalestring, + ) + export.main() + + if export.complete: + server.clear_notification() + server.add_notification( + title="Export complete!", + body="File saved under " + str(output_dir.value), + withCloseButton=True, + loading=False, + autoClose=5000, + ) else: - server.add_gui_markdown("Mesh export is not currently supported with Gaussian Splatting") + server.add_gui_markdown( + "Mesh export is not currently supported with Gaussian Splatting" + ) def populate_splat_tab( @@ -191,24 +255,50 @@ def populate_splat_tab( viewing_gsplat: bool, ) -> None: if viewing_gsplat: - server.add_gui_markdown("Generate ply export of Gaussian Splat") + server.add_gui_markdown("Export ply of Gaussian Splat") + + output_dir = server.add_gui_text( + "Output Directory", initial_value="exports/splat/" + ) - output_directory = server.add_gui_text("Output Directory", initial_value="exports/splat/") - # TODO: change to export directly - generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) + export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) - @generate_command.on_click + @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - command = " ".join( - [ - "ns-export gaussian-splat", - f"--load-config {config_path}", - f"--output-dir 
{output_directory.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), - ] + server.add_notification( + title="Exporting gaussian splat", + body="File will be saved under " + str(output_dir.value), + withCloseButton=True, + loading=True, + autoClose=False, ) - show_command_modal(event.client, "splat", command) + + from nerfstudio.scripts.exporter import ExportGaussianSplat + + posstring, rpystring, scalestring = get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ) + export = ExportGaussianSplat( + load_config=config_path, + output_dir=output_dir.value, + obb_center=posstring, + obb_rotation=rpystring, + obb_scale=scalestring, + ) + export.main() + + if export.complete: + server.clear_notification() + server.add_notification( + title="Export complete!", + body="File saved under " + str(output_dir.value), + withCloseButton=True, + loading=False, + autoClose=5000, + ) else: - server.add_gui_markdown("Splat export is only supported with Gaussian Splatting methods") + server.add_gui_markdown( + "Splat export is only supported with Gaussian Splatting methods" + ) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 76d1aa7148..c5862b903e 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -63,7 +63,10 @@ def from_camera(camera: viser.CameraHandle, aspect: float) -> Keyframe: class CameraPath: def __init__( - self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float], time_enabled: bool = False + self, + server: viser.ViserServer, + duration_element: viser.GuiInputHandle[float], + time_enabled: bool = False, ): self._server = server self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {} @@ -93,7 +96,9 @@ def set_keyframes_visible(self, visible: bool) -> None: for keyframe in self._keyframes.values(): keyframe[1].visible = visible - def add_camera(self, keyframe: Keyframe, keyframe_index: Optional[int] = None) -> None: + def add_camera( + self, keyframe: Keyframe, keyframe_index: Optional[int] = None + ) -> None: """Add a new camera, or replace an old one if `keyframe_index` is passed in.""" server = self._server @@ -104,7 +109,9 @@ def add_camera(self, keyframe: Keyframe, keyframe_index: Optional[int] = None) - frustum_handle = server.add_camera_frustum( f"/render_cameras/{keyframe_index}", - fov=keyframe.override_fov_rad if keyframe.override_fov_enabled else self.default_fov, + fov=keyframe.override_fov_rad + if keyframe.override_fov_enabled + else self.default_fov, aspect=keyframe.aspect, scale=0.1, color=(200, 10, 30), @@ -129,7 +136,9 @@ def _(_) -> None: position=keyframe.position, ) as camera_edit_panel: self._camera_edit_panel = camera_edit_panel - override_fov = server.add_gui_checkbox("Override FOV", initial_value=keyframe.override_fov_enabled) + override_fov = server.add_gui_checkbox( + "Override FOV", initial_value=keyframe.override_fov_enabled + ) override_fov_degrees = server.add_gui_slider( "Override FOV (degrees)", 5.0, @@ -162,7 +171,9 @@ def _(_) -> None: keyframe.override_time_val = override_time_val.value self.add_camera(keyframe, keyframe_index) - delete_button = server.add_gui_button("Delete", color="red", icon=viser.Icon.TRASH) + delete_button = server.add_gui_button( + "Delete", color="red", icon=viser.Icon.TRASH + ) go_to_button = server.add_gui_button("Go to") close_button = server.add_gui_button("Close") @@ -182,7 +193,9 @@ def _(event: viser.GuiEvent) -> None: assert event.client is not None with 
event.client.add_gui_modal("Confirm") as modal: event.client.add_gui_markdown("Delete keyframe?") - confirm_button = event.client.add_gui_button("Yes", color="red", icon=viser.Icon.TRASH) + confirm_button = event.client.add_gui_button( + "Yes", color="red", icon=viser.Icon.TRASH + ) exit_button = event.client.add_gui_button("Cancel") @confirm_button.on_click @@ -221,7 +234,9 @@ def _(event: viser.GuiEvent) -> None: T_current_target = T_world_current.inverse() @ T_world_target for j in range(10): - T_world_set = T_world_current @ tf.SE3.exp(T_current_target.log() * j / 9.0) + T_world_set = T_world_current @ tf.SE3.exp( + T_current_target.log() * j / 9.0 + ) # Important bit: we atomically set both the orientation and the position # of the camera. @@ -276,10 +291,14 @@ def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray: ], axis=0, ), - y=np.concatenate([[-1], spline_indices, [spline_indices[-1] + 1]], axis=0), + y=np.concatenate( + [[-1], spline_indices, [spline_indices[-1] + 1]], axis=0 + ), ) else: - interpolator = interpolate.PchipInterpolator(x=transition_times_cumsum, y=spline_indices) + interpolator = interpolate.PchipInterpolator( + x=transition_times_cumsum, y=spline_indices + ) # Clip to account for floating point error. return np.clip(interpolator(time), 0, spline_indices[-1]) @@ -292,7 +311,9 @@ def interpolate_pose_and_fov_rad( self._fov_spline = splines.KochanekBartels( [ - keyframe[0].override_fov_rad if keyframe[0].override_fov_enabled else self.default_fov + keyframe[0].override_fov_rad + if keyframe[0].override_fov_enabled + else self.default_fov for keyframe in self._keyframes.values() ], tcb=(self.tension, 0.0, 0.0), @@ -301,7 +322,9 @@ def interpolate_pose_and_fov_rad( self._time_spline = splines.KochanekBartels( [ - keyframe[0].override_time_val if keyframe[0].override_time_enabled else self.default_render_time + keyframe[0].override_time_val + if keyframe[0].override_time_enabled + else self.default_render_time for keyframe in self._keyframes.values() ], tcb=(self.tension, 0.0, 0.0), @@ -351,7 +374,9 @@ def update_spline(self) -> None: self._orientation_spline = splines.quaternion.KochanekBartels( [ - splines.quaternion.UnitQuaternion.from_unit_xyzw(np.roll(keyframe[0].wxyz, shift=-1)) + splines.quaternion.UnitQuaternion.from_unit_xyzw( + np.roll(keyframe[0].wxyz, shift=-1) + ) for keyframe in keyframes ], tcb=(self.tension, 0.0, 0.0), @@ -365,9 +390,16 @@ def update_spline(self) -> None: # Update visualized spline. points_array = self._position_spline.evaluate( - self.spline_t_from_t_sec(np.linspace(0, transition_times_cumsum[-1], num_frames)) + self.spline_t_from_t_sec( + np.linspace(0, transition_times_cumsum[-1], num_frames) + ) + ) + colors_array = np.array( + [ + colorsys.hls_to_rgb(h, 0.5, 1.0) + for h in np.linspace(0.0, 1.0, len(points_array)) + ] ) - colors_array = np.array([colorsys.hls_to_rgb(h, 0.5, 1.0) for h in np.linspace(0.0, 1.0, len(points_array))]) # Clear prior spline nodes. 
for node in self._spline_nodes: @@ -398,7 +430,8 @@ def make_transition_handle(i: int) -> None: transition_pos = self._position_spline.evaluate( float( self.spline_t_from_t_sec( - (transition_times_cumsum[i] + transition_times_cumsum[i + 1]) / 2.0, + (transition_times_cumsum[i] + transition_times_cumsum[i + 1]) + / 2.0, ) ) ) @@ -444,8 +477,12 @@ def _(_) -> None: @override_transition_enabled.on_update def _(_) -> None: - keyframe.override_transition_enabled = override_transition_enabled.value - override_transition_sec.disabled = not override_transition_enabled.value + keyframe.override_transition_enabled = ( + override_transition_enabled.value + ) + override_transition_sec.disabled = ( + not override_transition_enabled.value + ) self._duration_element.value = self.compute_duration() @override_transition_sec.on_update @@ -474,7 +511,8 @@ def compute_duration(self) -> float: del frustum total += ( keyframe.override_transition_sec - if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None + if keyframe.override_transition_enabled + and keyframe.override_transition_sec is not None else self.default_transition_sec ) return total @@ -489,7 +527,8 @@ def compute_transition_times_cumsum(self) -> np.ndarray: del frustum total += ( keyframe.override_transition_sec - if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None + if keyframe.override_transition_enabled + and keyframe.override_transition_sec is not None else self.default_transition_sec ) out.append(total) @@ -498,7 +537,8 @@ def compute_transition_times_cumsum(self) -> np.ndarray: keyframe = next(iter(self._keyframes.values()))[0] total += ( keyframe.override_transition_sec - if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None + if keyframe.override_transition_enabled + and keyframe.override_transition_sec is not None else self.default_transition_sec ) out.append(total) @@ -622,7 +662,9 @@ def _(event: viser.GuiEvent) -> None: client = server.get_clients()[event.client_id] with client.atomic(), client.add_gui_modal("Confirm") as modal: client.add_gui_markdown("Clear all keyframes?") - confirm_button = client.add_gui_button("Yes", color="red", icon=viser.Icon.TRASH) + confirm_button = client.add_gui_button( + "Yes", color="red", icon=viser.Icon.TRASH + ) exit_button = client.add_gui_button("Cancel") @confirm_button.on_click @@ -643,7 +685,9 @@ def _(_) -> None: def _(_) -> None: modal.close() - loop = server.add_gui_checkbox("Loop", False, hint="Add a segment between the first and last keyframes.") + loop = server.add_gui_checkbox( + "Loop", False, hint="Add a segment between the first and last keyframes." + ) @loop.on_update def _(_) -> None: @@ -731,11 +775,15 @@ def _(_) -> None: playback_folder = server.add_gui_folder("Playback") with playback_folder: play_button = server.add_gui_button("Play", icon=viser.Icon.PLAYER_PLAY) - pause_button = server.add_gui_button("Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False) + pause_button = server.add_gui_button( + "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False + ) preview_render_button = server.add_gui_button( "Preview Render", hint="Show a preview of the render in the viewport." 
) - preview_render_stop_button = server.add_gui_button("Exit Render Preview", color="red", visible=False) + preview_render_stop_button = server.add_gui_button( + "Exit Render Preview", color="red", visible=False + ) transition_sec_number = server.add_gui_number( "Transition (sec)", @@ -745,7 +793,9 @@ def _(_) -> None: initial_value=2.0, hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.", ) - framerate_number = server.add_gui_number("FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0) + framerate_number = server.add_gui_number( + "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0 + ) framerate_buttons = server.add_gui_button_group("", ("24", "30", "60")) duration_number = server.add_gui_number( "Duration (sec)", @@ -776,7 +826,9 @@ def remove_preview_camera() -> None: preview_camera_handle.remove() preview_camera_handle = None - def compute_and_update_preview_camera_state() -> Optional[Union[Tuple[tf.SE3, float], Tuple[tf.SE3, float, float]]]: + def compute_and_update_preview_camera_state() -> ( + Optional[Union[Tuple[tf.SE3, float], Tuple[tf.SE3, float, float]]] + ): """Update the render tab state with the current preview camera pose. Returns current camera pose + FOV if available.""" @@ -890,7 +942,9 @@ def _(_) -> None: for client in server.get_clients().values(): if client.client_id not in camera_pose_backup_from_id: continue - cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(client.client_id) + cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop( + client.client_id + ) client.camera.position = cam_position client.camera.look_at = cam_look_at client.camera.up_direction = cam_up @@ -931,7 +985,9 @@ def play() -> None: max_frame = int(framerate_number.value * duration_number.value) if max_frame > 0: assert preview_frame_slider is not None - preview_frame_slider.value = (preview_frame_slider.value + 1) % max_frame + preview_frame_slider.value = ( + preview_frame_slider.value + 1 + ) % max_frame time.sleep(1.0 / framerate_number.value) threading.Thread(target=play).start() @@ -978,7 +1034,9 @@ def _(_) -> None: camera_path.reset() for i in range(len(keyframes)): frame = keyframes[i] - pose = tf.SE3.from_matrix(np.array(frame["matrix"]).reshape(4, 4)) + pose = tf.SE3.from_matrix( + np.array(frame["matrix"]).reshape(4, 4) + ) # apply the x rotation by 180 deg pose = tf.SE3.from_rotation_and_translation( pose.rotation() @ tf.SO3.from_x_radians(np.pi), @@ -986,21 +1044,33 @@ def _(_) -> None: ) camera_path.add_camera( Keyframe( - position=pose.translation() * VISER_NERFSTUDIO_SCALE_RATIO, + position=pose.translation() + * VISER_NERFSTUDIO_SCALE_RATIO, wxyz=pose.rotation().wxyz, # There are some floating point conversions between degrees and radians, so the fov and # default_Fov values will not be exactly matched. 
- override_fov_enabled=abs(frame["fov"] - json_data.get("default_fov", 0.0)) > 1e-3, + override_fov_enabled=abs( + frame["fov"] - json_data.get("default_fov", 0.0) + ) + > 1e-3, override_fov_rad=frame["fov"] / 180.0 * np.pi, - override_time_enabled=frame.get("override_time_enabled", False), + override_time_enabled=frame.get( + "override_time_enabled", False + ), override_time_val=frame.get("render_time", None), aspect=frame["aspect"], - override_transition_enabled=frame.get("override_transition_enabled", None), - override_transition_sec=frame.get("override_transition_sec", None), + override_transition_enabled=frame.get( + "override_transition_enabled", None + ), + override_transition_sec=frame.get( + "override_transition_sec", None + ), ), ) - transition_sec_number.value = json_data.get("default_transition_sec", 0.5) + transition_sec_number.value = json_data.get( + "default_transition_sec", 0.5 + ) # update the render name render_name_text.value = json_path.stem @@ -1045,7 +1115,9 @@ def _(_) -> None: @reset_up_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) + event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array( + [0.0, -1.0, 0.0] + ) @render_button.on_click def _(event: viser.GuiEvent) -> None: @@ -1086,20 +1158,26 @@ def _(event: viser.GuiEvent) -> None: ) keyframe_dict = { "matrix": pose.as_matrix().flatten().tolist(), - "fov": np.rad2deg(keyframe.override_fov_rad) if keyframe.override_fov_enabled else fov_degrees.value, + "fov": np.rad2deg(keyframe.override_fov_rad) + if keyframe.override_fov_enabled + else fov_degrees.value, "aspect": keyframe.aspect, "override_transition_enabled": keyframe.override_transition_enabled, "override_transition_sec": keyframe.override_transition_sec, } if render_time is not None: keyframe_dict["render_time"] = ( - keyframe.override_time_val if keyframe.override_time_enabled else render_time.value + keyframe.override_time_val + if keyframe.override_time_enabled + else render_time.value ) keyframe_dict["override_time_enabled"] = keyframe.override_time_enabled keyframes.append(keyframe_dict) json_data["default_fov"] = fov_degrees.value if render_time is not None: - json_data["default_time"] = render_time.value if render_time is not None else None + json_data["default_time"] = ( + render_time.value if render_time is not None else None + ) json_data["default_transition_sec"] = transition_sec_number.value json_data["keyframes"] = keyframes json_data["camera_type"] = camera_type.value.lower() @@ -1112,7 +1190,9 @@ def _(event: viser.GuiEvent) -> None: # now populate the camera path: camera_path_list = [] for i in range(num_frames): - maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad(i / num_frames) + maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad( + i / num_frames + ) if maybe_pose_and_fov is None: return time = None @@ -1155,20 +1235,22 @@ def _(event: viser.GuiEvent) -> None: # rendering from nerfstudio.scripts.render import RenderCameraPath + render = RenderCameraPath( - load_config=config_path, - camera_path_filename=json_outfile.absolute(), - output_path=Path(render_path) - ) + load_config=config_path, + camera_path_filename=json_outfile.absolute(), + output_path=Path(render_path), + ) render.main() - + if render.complete: + server.clear_notification() server.add_notification( - title="Render complete!", - body="Video saved as " + render_path, - withCloseButton=True, - 
loading=False, - autoClose=5000, + title="Render complete!", + body="Video saved as " + render_path, + withCloseButton=True, + loading=False, + autoClose=5000, ) if control_panel is not None: From 44ae194ba9a6f4bb2054d6d3b37cb231e4594af1 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 13 Jun 2024 14:51:50 -0700 Subject: [PATCH 06/33] viser notification api changes --- nerfstudio/viewer/export_panel.py | 90 ++++++++++++++++--------------- nerfstudio/viewer/render_panel.py | 30 ++++++----- 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 76e054d5e4..62a506fe44 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -129,13 +129,14 @@ def populate_point_cloud_tab( @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - server.add_notification( - title="Exporting point cloud", - body="File will be saved under " + str(output_dir.value), - withCloseButton=True, - loading=True, - autoClose=False, - ) + notif = server.add_notification( + title="Exporting point cloud", + body="File will be saved under " + str(output_dir.value), + withCloseButton=True, + loading=True, + autoClose=False, + ) + notif.show() from nerfstudio.scripts.exporter import ExportPointCloud @@ -157,13 +158,14 @@ def _(event: viser.GuiEvent) -> None: if export.complete: server.clear_notification() - server.add_notification( - title="Export complete!", - body="File saved under " + str(output_dir.value), - withCloseButton=True, - loading=False, - autoClose=5000, - ) + notif = server.add_notification( + title="Export complete!", + body="File saved under " + str(output_dir.value), + withCloseButton=True, + loading=False, + autoClose=5000, + ) + notif.show() else: server.add_gui_markdown( @@ -205,13 +207,14 @@ def populate_mesh_tab( @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - server.add_notification( - title="Exporting poisson mesh", - body="File will be saved under " + str(output_dir.value), - withCloseButton=True, - loading=True, - autoClose=False, - ) + notif = server.add_notification( + title="Exporting poisson mesh", + body="File will be saved under " + str(output_dir.value), + withCloseButton=True, + loading=True, + autoClose=False, + ) + notif.show() from nerfstudio.scripts.exporter import ExportPoissonMesh @@ -234,13 +237,14 @@ def _(event: viser.GuiEvent) -> None: if export.complete: server.clear_notification() - server.add_notification( - title="Export complete!", - body="File saved under " + str(output_dir.value), - withCloseButton=True, - loading=False, - autoClose=5000, - ) + notif = server.add_notification( + title="Export complete!", + body="File saved under " + str(output_dir.value), + withCloseButton=True, + loading=False, + autoClose=5000, + ) + notif.show() else: server.add_gui_markdown( @@ -266,13 +270,14 @@ def populate_splat_tab( @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - server.add_notification( - title="Exporting gaussian splat", - body="File will be saved under " + str(output_dir.value), - withCloseButton=True, - loading=True, - autoClose=False, - ) + notif = server.add_notification( + title="Exporting gaussian splat", + body="File will be saved under " + str(output_dir.value), + withCloseButton=True, + loading=True, + autoClose=False, + ) + notif.show() from nerfstudio.scripts.exporter import ExportGaussianSplat @@ -290,13 +295,14 @@ def 
_(event: viser.GuiEvent) -> None: if export.complete: server.clear_notification() - server.add_notification( - title="Export complete!", - body="File saved under " + str(output_dir.value), - withCloseButton=True, - loading=False, - autoClose=5000, - ) + notif = server.add_notification( + title="Export complete!", + body="File saved under " + str(output_dir.value), + withCloseButton=True, + loading=False, + autoClose=5000, + ) + notif.show() else: server.add_gui_markdown( diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index c5862b903e..1c159dfa37 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1123,13 +1123,14 @@ def _(event: viser.GuiEvent) -> None: def _(event: viser.GuiEvent) -> None: assert event.client is not None render_path = f"""renders/{datapath.name}/{render_name_text.value}.mp4""" - server.add_notification( - title="Rendering trajectory", - body="Saving rendered video as " + render_path, - withCloseButton=True, - loading=True, - autoClose=False, - ) + notif = server.add_notification( + title="Rendering trajectory", + body="Saving rendered video as " + render_path, + withCloseButton=True, + loading=True, + autoClose=False, + ) + notif.show() num_frames = int(framerate_number.value * duration_number.value) json_data = {} @@ -1245,13 +1246,14 @@ def _(event: viser.GuiEvent) -> None: if render.complete: server.clear_notification() - server.add_notification( - title="Render complete!", - body="Video saved as " + render_path, - withCloseButton=True, - loading=False, - autoClose=5000, - ) + notif = server.add_notification( + title="Render complete!", + body="Video saved as " + render_path, + withCloseButton=True, + loading=False, + autoClose=5000, + ) + notif.show() if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) From ec044216f2d1def46f0e05db7bad34be2a082f53 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 13 Jun 2024 18:43:52 -0700 Subject: [PATCH 07/33] export bug fixes --- nerfstudio/viewer/export_panel.py | 42 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 62a506fe44..5c6192c4d1 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -18,7 +18,7 @@ import viser import viser.transforms as vtf -from typing_extensions import Literal +from typing_extensions import Literal, List from nerfstudio.data.scene_box import OrientedBox from nerfstudio.models.base_model import Model @@ -138,14 +138,18 @@ def _(event: viser.GuiEvent) -> None: ) notif.show() - from nerfstudio.scripts.exporter import ExportPointCloud + if control_panel.crop_obb is not None and control_panel.crop_viewport: + posstring, rpystring, scalestring = get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ) + else: + posstring = rpystring = scalestring = None - posstring, rpystring, scalestring = get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ) + from nerfstudio.scripts.exporter import ExportPointCloud + export = ExportPointCloud( load_config=config_path, - output_dir=output_dir.value, + output_dir=Path(output_dir.value), num_points=num_points.value, remove_outliers=remove_outliers.value, normal_method=normals.value, @@ -216,14 +220,18 @@ def _(event: viser.GuiEvent) -> None: ) notif.show() - from nerfstudio.scripts.exporter import ExportPoissonMesh + if control_panel.crop_obb is not None 
and control_panel.crop_viewport: + posstring, rpystring, scalestring = get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ) + else: + posstring = rpystring = scalestring = None - posstring, rpystring, scalestring = get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ) + from nerfstudio.scripts.exporter import ExportPoissonMesh + export = ExportPoissonMesh( load_config=config_path, - output_dir=output_dir.value, + output_dir=Path(output_dir.value), target_num_faces=num_faces.value, num_pixels_per_side=texture_resolution.value, num_points=num_points.value, @@ -279,14 +287,18 @@ def _(event: viser.GuiEvent) -> None: ) notif.show() + if control_panel.crop_obb is not None and control_panel.crop_viewport: + posstring, rpystring, scalestring = get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ) + else: + posstring = rpystring = scalestring = None + from nerfstudio.scripts.exporter import ExportGaussianSplat - posstring, rpystring, scalestring = get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ) export = ExportGaussianSplat( load_config=config_path, - output_dir=output_dir.value, + output_dir=Path(output_dir.value), obb_center=posstring, obb_rotation=rpystring, obb_scale=scalestring, From 1aad57b3d8fc5946faaf46277ac98edf972f2e93 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 13 Jun 2024 20:01:53 -0700 Subject: [PATCH 08/33] add (kind of?) error message for checkpoint isn't found during export --- nerfstudio/viewer/export_panel.py | 63 ++++++++++++++++--------------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 5c6192c4d1..5cc407de54 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -91,6 +91,30 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> List[str]: scalestring = " ".join([f"{x:.10f}" for x in scale]) return [posstring, rpystring, scalestring] +def show_notification( + export_complete: bool, + server: viser.ViserServer, + output_dir: str) -> None: + if export_complete: + server.clear_notification() + notif = server.add_notification( + title="Export complete!", + body="File saved under " + output_dir, + withCloseButton=True, + loading=False, + autoClose=5000, + ) + notif.show() + else: + server.clear_notification() + notif = server.add_notification( + title="Export error!", + body="Please try again after a checkpoint is saved.", + withCloseButton=True, + loading=False, + autoClose=5000, + ) + def populate_point_cloud_tab( server: viser.ViserServer, @@ -160,16 +184,9 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - if export.complete: - server.clear_notification() - notif = server.add_notification( - title="Export complete!", - body="File saved under " + str(output_dir.value), - withCloseButton=True, - loading=False, - autoClose=5000, - ) - notif.show() + show_notification(export_complete=export.complete, + server=server, + output_dir=str(output_dir.value)) else: server.add_gui_markdown( @@ -243,16 +260,9 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - if export.complete: - server.clear_notification() - notif = server.add_notification( - title="Export complete!", - body="File saved under " + str(output_dir.value), - withCloseButton=True, - loading=False, - autoClose=5000, - ) - notif.show() + show_notification(export_complete=export.complete, + server=server, + output_dir=str(output_dir.value)) else: server.add_gui_markdown( @@ -305,16 
+315,9 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - if export.complete: - server.clear_notification() - notif = server.add_notification( - title="Export complete!", - body="File saved under " + str(output_dir.value), - withCloseButton=True, - loading=False, - autoClose=5000, - ) - notif.show() + show_notification(export_complete=export.complete, + server=server, + output_dir=str(output_dir.value)) else: server.add_gui_markdown( From 14839ddbff5ba3b1ea9357b615704b64884a7f4d Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 13 Jun 2024 21:51:59 -0700 Subject: [PATCH 09/33] message --- nerfstudio/viewer/export_panel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 5cc407de54..216eb5e779 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -109,7 +109,7 @@ def show_notification( server.clear_notification() notif = server.add_notification( title="Export error!", - body="Please try again after a checkpoint is saved.", + body="Please try again after a checkpoint is saved after 2000 steps.", withCloseButton=True, loading=False, autoClose=5000, From df134724ee89e47b98e6b90a8b14f29168df0cf1 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Sat, 15 Jun 2024 12:01:25 -0700 Subject: [PATCH 10/33] add export warning message --- nerfstudio/viewer/export_panel.py | 40 ++----------------------------- 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 216eb5e779..9b79d86620 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -40,6 +40,7 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value + server.add_gui_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") with server.add_gui_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.add_gui_folder("Point Cloud"): @@ -48,35 +49,6 @@ def _(_) -> None: populate_mesh_tab(server, control_panel, config_path, viewing_gsplat) -def show_command_modal( - client: viser.ClientHandle, - what: Literal["mesh", "point cloud", "splat"], - command: str, -) -> None: - """Show a modal to each currently connected client. - - In the future, we should only show the modal to the client that pushes the - generation button. 
- """ - with client.add_gui_modal(what.title() + " Export") as modal: - client.add_gui_markdown( - "\n".join( - [ - f"To export a {what}, run the following from the command line:", - "", - "```", - command, - "```", - ] - ) - ) - close_button = client.add_gui_button("Close") - - @close_button.on_click - def _(_) -> None: - modal.close() - - def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> List[str]: """Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale} and each arg formatted with spaces around it @@ -91,6 +63,7 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> List[str]: scalestring = " ".join([f"{x:.10f}" for x in scale]) return [posstring, rpystring, scalestring] + def show_notification( export_complete: bool, server: viser.ViserServer, @@ -105,15 +78,6 @@ def show_notification( autoClose=5000, ) notif.show() - else: - server.clear_notification() - notif = server.add_notification( - title="Export error!", - body="Please try again after a checkpoint is saved after 2000 steps.", - withCloseButton=True, - loading=False, - autoClose=5000, - ) def populate_point_cloud_tab( From 6fbccf16036cf5a919926135e8cdbd163c10277b Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Sun, 16 Jun 2024 16:03:27 -0700 Subject: [PATCH 11/33] wip --- nerfstudio/viewer/render_panel.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 1c159dfa37..7579afaffc 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1122,7 +1122,7 @@ def _(event: viser.GuiEvent) -> None: @render_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - render_path = f"""renders/{datapath.name}/{render_name_text.value}.mp4""" + render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" notif = server.add_notification( title="Rendering trajectory", body="Saving rendered video as " + render_path, @@ -1255,6 +1255,25 @@ def _(event: viser.GuiEvent) -> None: ) notif.show() + with server.gui.add_modal("Render download") as modal: + server.gui.add_markdown("Download the rendered video to your local machine:") + + download_button = server.gui.add_button("Download") + + @downlaod_button.on_click + def _(_) -> None: + import imageio.v3 as iio + + server.send_file_download( + f"{render_name_text.value}.mp4", iio.imwrite("", images, extension=".gif") + ) + + close_button = server.gui.add_button("Close") + + @close_button.on_click + def _(_) -> None: + modal.close() + if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) else: From 006f5f3aa069a499546b5027e2176db84a1a1551 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Fri, 21 Jun 2024 16:58:49 -0700 Subject: [PATCH 12/33] update to match viser notification api --- nerfstudio/viewer/export_panel.py | 53 ++++++++++--------------------- nerfstudio/viewer/render_panel.py | 37 +++------------------ 2 files changed, 22 insertions(+), 68 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 9b79d86620..9e721ae1f4 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -63,23 +63,6 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> List[str]: scalestring = " ".join([f"{x:.10f}" for x in scale]) return [posstring, rpystring, scalestring] - -def show_notification( - export_complete: bool, - server: 
viser.ViserServer, - output_dir: str) -> None: - if export_complete: - server.clear_notification() - notif = server.add_notification( - title="Export complete!", - body="File saved under " + output_dir, - withCloseButton=True, - loading=False, - autoClose=5000, - ) - notif.show() - - def populate_point_cloud_tab( server: viser.ViserServer, control_panel: ControlPanel, @@ -117,14 +100,11 @@ def populate_point_cloud_tab( @export_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - notif = server.add_notification( + notif = server.gui.add_notification( title="Exporting point cloud", body="File will be saved under " + str(output_dir.value), - withCloseButton=True, loading=True, - autoClose=False, ) - notif.show() if control_panel.crop_obb is not None and control_panel.crop_viewport: posstring, rpystring, scalestring = get_crop_string( @@ -147,10 +127,12 @@ def _(event: viser.GuiEvent) -> None: obb_scale=scalestring, ) export.main() - - show_notification(export_complete=export.complete, - server=server, - output_dir=str(output_dir.value)) + + if export.complete: + notif.update( + title="Export complete!", + body="File saved under " + str(output_dir.value), + ) else: server.add_gui_markdown( @@ -195,11 +177,8 @@ def _(event: viser.GuiEvent) -> None: notif = server.add_notification( title="Exporting poisson mesh", body="File will be saved under " + str(output_dir.value), - withCloseButton=True, loading=True, - autoClose=False, ) - notif.show() if control_panel.crop_obb is not None and control_panel.crop_viewport: posstring, rpystring, scalestring = get_crop_string( @@ -224,9 +203,11 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - show_notification(export_complete=export.complete, - server=server, - output_dir=str(output_dir.value)) + if export.complete: + notif.update( + title="Export complete!", + body="File saved under " + str(output_dir.value), + ) else: server.add_gui_markdown( @@ -255,9 +236,7 @@ def _(event: viser.GuiEvent) -> None: notif = server.add_notification( title="Exporting gaussian splat", body="File will be saved under " + str(output_dir.value), - withCloseButton=True, loading=True, - autoClose=False, ) notif.show() @@ -279,9 +258,11 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - show_notification(export_complete=export.complete, - server=server, - output_dir=str(output_dir.value)) + if export.complete: + notif.update( + title="Export complete!", + body="File saved under " + str(output_dir.value), + ) else: server.add_gui_markdown( diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 7579afaffc..4cccc76229 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1123,14 +1123,11 @@ def _(event: viser.GuiEvent) -> None: def _(event: viser.GuiEvent) -> None: assert event.client is not None render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" - notif = server.add_notification( + notif = server.gui.add_notification( title="Rendering trajectory", body="Saving rendered video as " + render_path, - withCloseButton=True, loading=True, - autoClose=False, ) - notif.show() num_frames = int(framerate_number.value * duration_number.value) json_data = {} @@ -1245,34 +1242,10 @@ def _(event: viser.GuiEvent) -> None: render.main() if render.complete: - server.clear_notification() - notif = server.add_notification( - title="Render complete!", - body="Video saved as " + render_path, - withCloseButton=True, - loading=False, - autoClose=5000, - ) - 
notif.show() - - with server.gui.add_modal("Render download") as modal: - server.gui.add_markdown("Download the rendered video to your local machine:") - - download_button = server.gui.add_button("Download") - - @downlaod_button.on_click - def _(_) -> None: - import imageio.v3 as iio - - server.send_file_download( - f"{render_name_text.value}.mp4", iio.imwrite("", images, extension=".gif") - ) - - close_button = server.gui.add_button("Close") - - @close_button.on_click - def _(_) -> None: - modal.close() + notif.update( + title="Render complete!", + body="Video saved as " + render_path, + ) if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) From 56e61397a8bcb24c823b9ac6e0a3c1f910b492a8 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Fri, 21 Jun 2024 17:54:25 -0700 Subject: [PATCH 13/33] add local render download --- nerfstudio/viewer/render_panel.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 4cccc76229..7a34560ef1 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1091,6 +1091,8 @@ def _(_) -> None: hint="Name of the render", ) + server.add_gui_markdown("Render available after a checkpoint is saved (default minimum 2000 steps)") + render_button = server.add_gui_button( "Render", color="green", @@ -1098,12 +1100,12 @@ def _(_) -> None: hint="Render the camera path and save video as mp4 file.", ) - # generate_render_button = server.add_gui_button( - # "Generate Command", - # color="green", - # icon=viser.Icon.FILE_EXPORT, - # hint="Generate the ns-render command for rendering the camera path.", - # ) + download_render_button = server.gui.add_button( + "Download Render", + color="green", + icon=viser.Icon.DOWNLOAD, + hint="Download the latest render locally as mp4 file." 
+ ) reset_up_button = server.add_gui_button( "Reset Up Direction", @@ -1247,6 +1249,20 @@ def _(event: viser.GuiEvent) -> None: body="Video saved as " + render_path, ) + @download_render_button.on_click + def _(event: viser.GuiEvent) -> None: + render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" + + client = event.client + assert client is not None + + with open(render_path, 'rb') as file: + video_bytes = file.read() + + client.send_file_download( + "render.mp4", video_bytes + ) + if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) else: From d39a7292331d9c525922991c4211f78bb7fb07a4 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Fri, 21 Jun 2024 18:35:50 -0700 Subject: [PATCH 14/33] wip for export local download --- nerfstudio/viewer/export_panel.py | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 9e721ae1f4..2131e7604d 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -96,6 +96,7 @@ def populate_point_cloud_tab( "Output Directory", initial_value="exports/pcd/" ) export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -134,6 +135,18 @@ def _(event: viser.GuiEvent) -> None: body="File saved under " + str(output_dir.value), ) + @download_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + with open(str(output_dir.value) + "point_cloud.ply", 'rb') as ply_file: + ply_bytes = ply_file.read() + + client.send_file_download( + "point_cloud.ply", ply_bytes + ) + else: server.add_gui_markdown( "Point cloud export is not currently supported with Gaussian Splatting" @@ -170,6 +183,7 @@ def populate_mesh_tab( remove_outliers = server.add_gui_checkbox("Remove outliers", True) export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -208,6 +222,19 @@ def _(event: viser.GuiEvent) -> None: title="Export complete!", body="File saved under " + str(output_dir.value), ) + + @download_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + with open(str(output_dir.value) + "poisson_mesh.ply", 'rb') as ply_file: + ply_bytes = ply_file.read() + + client.send_file_download( + "poisson_mesh.ply", ply_bytes + ) + else: server.add_gui_markdown( @@ -229,6 +256,7 @@ def populate_splat_tab( ) export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -263,6 +291,18 @@ def _(event: viser.GuiEvent) -> None: title="Export complete!", body="File saved under " + str(output_dir.value), ) + + @download_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + with open(str(output_dir.value) + "splat.ply", 'rb') as ply_file: + ply_bytes = ply_file.read() + + client.send_file_download( + "splat.ply", ply_bytes + ) else: server.add_gui_markdown( From b6210fd8cdbbf681fff1ed491780172c4b2556df Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Wed, 
26 Jun 2024 19:19:09 -0700 Subject: [PATCH 15/33] update icons --- nerfstudio/viewer/export_panel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 2131e7604d..2ebb2e6601 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -95,7 +95,7 @@ def populate_point_cloud_tab( output_dir = server.add_gui_text( "Output Directory", initial_value="exports/pcd/" ) - export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) + export_button = server.add_gui_button("Export", icon=viser.Icon.FILE_EXPORT) download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD) @export_button.on_click @@ -182,7 +182,7 @@ def populate_mesh_tab( ) remove_outliers = server.add_gui_checkbox("Remove outliers", True) - export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) + export_button = server.add_gui_button("Export", icon=viser.Icon.FILE_EXPORT) download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD) @export_button.on_click @@ -255,7 +255,7 @@ def populate_splat_tab( "Output Directory", initial_value="exports/splat/" ) - export_button = server.add_gui_button("Export", icon=viser.Icon.TERMINAL_2) + export_button = server.add_gui_button("Export", icon=viser.Icon.FILE_EXPORT) download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD) @export_button.on_click From 0ddabed68c4f34c5da5d29ee46e318447709194e Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 16:24:53 -0700 Subject: [PATCH 16/33] better rendering functionality and command palette --- nerfstudio/viewer/export_panel.py | 30 ++++-- nerfstudio/viewer/render_panel.py | 149 ++++++++++++++++++++++-------- 2 files changed, 130 insertions(+), 49 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index ba050fe257..a86586c451 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -114,13 +114,18 @@ def populate_point_cloud_tab( initial_value="open3d", hint="Normal map source.", ) + + export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) + download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD) output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: - assert event.client is not None - notif = server.gui.add_notification( + client = event.client + assert client is not None + + notif = client.add_notification( title="Exporting point cloud", body="File will be saved under " + str(output_dir.value), loading=True, @@ -189,16 +194,20 @@ def populate_mesh_tab( ) num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) - output_directory = server.gui.add_text("Output Directory", initial_value="exports/mesh/") + output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) remove_outliers = server.gui.add_checkbox("Remove outliers", True) + export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) + download_button = server.gui.add_button("Download Mesh", 
icon=viser.Icon.DOWNLOAD) generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: - assert event.client is not None - notif = server.add_notification( + client = event.client + assert client is not None + + notif = client.add_notification( title="Exporting poisson mesh", body="File will be saved under " + str(output_dir.value), loading=True, @@ -259,13 +268,18 @@ def populate_splat_tab( if viewing_gsplat: server.gui.add_markdown("Generate ply export of Gaussian Splat") - output_directory = server.gui.add_text("Output Directory", initial_value="exports/splat/") + export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) + download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD) + + output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/") generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: - assert event.client is not None - notif = server.add_notification( + client = event.client + assert client is not None + + notif = client.add_notification( title="Exporting gaussian splat", body="File will be saved under " + str(output_dir.value), loading=True, diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 032b164a1f..329e1ed28b 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -31,7 +31,6 @@ from scipy import interpolate from nerfstudio.viewer.control_panel import ControlPanel -from nerfstudio.utils.scripts import run_command @dataclasses.dataclass @@ -877,6 +876,7 @@ def _(_) -> None: position=pose.translation(), color=(10, 200, 30), ) + if render_tab_state.preview_render: for client in server.get_clients().values(): client.camera.wxyz = pose.rotation().wxyz @@ -1021,11 +1021,13 @@ def _(_) -> None: pose = tf.SE3.from_matrix( np.array(frame["matrix"]).reshape(4, 4) ) + # apply the x rotation by 180 deg pose = tf.SE3.from_rotation_and_translation( pose.rotation() @ tf.SO3.from_x_radians(np.pi), pose.translation(), ) + camera_path.add_camera( Keyframe( position=pose.translation() @@ -1075,45 +1077,29 @@ def _(_) -> None: hint="Name of the render", ) - server.add_gui_markdown("Render available after a checkpoint is saved (default minimum 2000 steps)") - - render_button = server.gui.add_button( - "Render", - color="green", - icon=viser.Icon.FILE_EXPORT, - hint="Render the camera path and save video as mp4 file.", - ) - - download_render_button = server.gui.add_button( - "Download Render", - color="green", - icon=viser.Icon.DOWNLOAD, - hint="Download the latest render locally as mp4 file." 
- ) + render_folder = server.gui.add_folder("Render") + with render_folder: + server.gui.add_markdown("Render available after a checkpoint is saved (default minimum 2000 steps)") - reset_up_button = server.gui.add_button( - "Reset Up Direction", - icon=viser.Icon.ARROW_BIG_UP_LINES, - color="gray", - hint="Set the up direction of the camera orbit controls to the camera's current up direction.", - ) + render_button = server.gui.add_button( + "Render", + icon=viser.Icon.VIDEO, + hint="Render the camera path and save video as mp4 file.", + ) - @reset_up_button.on_click - def _(event: viser.GuiEvent) -> None: - assert event.client is not None - event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array( - [0.0, -1.0, 0.0] + download_render_button = server.gui.add_button( + "Download Render", + icon=viser.Icon.DOWNLOAD, + hint="Download the latest render locally as mp4 file." ) - @render_button.on_click - def _(event: viser.GuiEvent) -> None: - assert event.client is not None - render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" - notif = server.gui.add_notification( - title="Rendering trajectory", - body="Saving rendered video as " + render_path, - loading=True, - ) + generate_command_render_button = server.gui.add_button( + "Generate Command", + icon=viser.Icon.FILE_EXPORT, + hint="Generate the ns-render command for rendering the camera path instead of directly rendering.", + ) + + def _write_json() -> json: num_frames = int(framerate_number.value * duration_number.value) json_data = {} @@ -1133,6 +1119,7 @@ def _(event: viser.GuiEvent) -> None: # camera_to_world: flattened 4x4 matrix # fov: float in degrees # aspect: float + # first populate the keyframes: keyframes = [] for keyframe, _ in camera_path._keyframes.values(): @@ -1140,6 +1127,7 @@ def _(event: viser.GuiEvent) -> None: tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi), keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO, ) + keyframe_dict = { "matrix": pose.as_matrix().flatten().tolist(), "fov": np.rad2deg(keyframe.override_fov_rad) @@ -1149,6 +1137,7 @@ def _(event: viser.GuiEvent) -> None: "override_transition_enabled": keyframe.override_transition_enabled, "override_transition_sec": keyframe.override_transition_sec, } + if render_time is not None: keyframe_dict["render_time"] = ( keyframe.override_time_val @@ -1156,12 +1145,16 @@ def _(event: viser.GuiEvent) -> None: else render_time.value ) keyframe_dict["override_time_enabled"] = keyframe.override_time_enabled + keyframes.append(keyframe_dict) + json_data["default_fov"] = fov_degrees.value + if render_time is not None: json_data["default_time"] = ( render_time.value if render_time is not None else None ) + json_data["default_transition_sec"] = transition_sec_number.value json_data["keyframes"] = keyframes json_data["camera_type"] = camera_type.value.lower() @@ -1171,6 +1164,7 @@ def _(event: viser.GuiEvent) -> None: json_data["seconds"] = duration_number.value json_data["is_cycle"] = loop.value json_data["smoothness_value"] = tension_slider.value + # now populate the camera path: camera_path_list = [] for i in range(num_frames): @@ -1184,20 +1178,24 @@ def _(event: viser.GuiEvent) -> None: pose, fov, time = maybe_pose_and_fov else: pose, fov = maybe_pose_and_fov + # rotate the axis of the camera 180 about x axis pose = tf.SE3.from_rotation_and_translation( pose.rotation() @ tf.SO3.from_x_radians(np.pi), pose.translation() / VISER_NERFSTUDIO_SCALE_RATIO, ) + camera_path_list_dict = { "camera_to_world": pose.as_matrix().flatten().tolist(), 
"fov": np.rad2deg(fov), "aspect": resolution.value[0] / resolution.value[1], } + if time is not None: camera_path_list_dict["render_time"] = time camera_path_list.append(camera_path_list_dict) json_data["camera_path"] = camera_path_list + # finally add crop data if crop is enabled if control_panel is not None: if control_panel.crop_viewport: @@ -1216,8 +1214,27 @@ def _(event: viser.GuiEvent) -> None: json_outfile.parent.mkdir(parents=True, exist_ok=True) with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) + + return json_outfile + + # TODO: disable render button while rendering/only enable download if file exists + # TODO: progress bar while rendering + + @render_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" + json_outfile = _write_json() + + notif = client.add_notification( + title="Rendering trajectory", + body="Saving rendered video as " + render_path, + loading=True, + with_close_button=False, + ) - # rendering from nerfstudio.scripts.render import RenderCameraPath render = RenderCameraPath( @@ -1228,10 +1245,10 @@ def _(event: viser.GuiEvent) -> None: render.main() if render.complete: - notif.update( - title="Render complete!", - body="Video saved as " + render_path, - ) + notif.title = "Render complete!" + notif.body = "Video saved as " + render_path + notif.loading = False + notif.with_close_button = True @download_render_button.on_click def _(event: viser.GuiEvent) -> None: @@ -1246,11 +1263,61 @@ def _(event: viser.GuiEvent) -> None: client.send_file_download( "render.mp4", video_bytes ) + + @generate_command_render_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" + json_outfile = _write_json() + + with event.client.gui.add_modal("Render Command") as modal: + dataname = datapath.name + command = " ".join( + [ + "ns-render camera-path", + f"--load-config {config_path}", + f"--camera-path-filename {json_outfile.absolute()}", + f"--output-path renders/{dataname}/{render_name_text.value}.mp4", + ] + ) + event.client.gui.add_markdown( + "\n".join( + [ + "To render the trajectory, run the following from the command line:", + "", + "```", + command, + "```", + ] + ) + ) + close_button = event.client.gui.add_button("Close") + + @close_button.on_click + def _(_) -> None: + modal.close() + + reset_up_button = server.gui.add_button( + "Reset Up Direction", + icon=viser.Icon.ARROW_BIG_UP_LINES, + color="gray", + hint="Set the up direction of the camera orbit controls to the camera's current up direction.", + ) + + @reset_up_button.on_click + def _(event: viser.GuiEvent) -> None: + assert event.client is not None + event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array( + [0.0, -1.0, 0.0] + ) if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) else: camera_path = CameraPath(server, duration_number) + camera_path.tension = tension_slider.value camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value From 8a5bd0b5bf1a3ea18a5386a4a7d2af69a1e5587a Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 17:00:57 -0700 Subject: [PATCH 17/33] add button disable settings for render requirements and in-progress --- nerfstudio/scripts/render.py | 2 +- 
nerfstudio/viewer/render_panel.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 0db00aab69..5402227e03 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -498,7 +498,7 @@ class RenderCameraPath(BaseRender): output_format: Literal["images", "video"] = "video" """How to save output data.""" complete: bool = True - """Whether rendering is complete""" + """Set to True when render is finished.""" def main(self) -> None: """Main function.""" diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 329e1ed28b..3538910da3 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1217,14 +1217,18 @@ def _write_json() -> json: return json_outfile - # TODO: disable render button while rendering/only enable download if file exists - # TODO: progress bar while rendering + # TODO: disable render button while before checkpoints/keyframes placed down + + render_complete = False + download_render_button.disabled = not render_complete @render_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None + render_button.disabled = True + render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" json_outfile = _write_json() @@ -1250,6 +1254,11 @@ def _(event: viser.GuiEvent) -> None: notif.loading = False notif.with_close_button = True + nonlocal render_complete + render_complete = True + + render_button.disabled = False + @download_render_button.on_click def _(event: viser.GuiEvent) -> None: render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" @@ -1322,6 +1331,10 @@ def _(event: viser.GuiEvent) -> None: camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value + # disable render option if no keyframes are added + render_button.disabled = len(camera_path._keyframes) <= 0 + generate_command_render_button.disabled = len(camera_path._keyframes) <= 0 + return render_tab_state From 8027a492d7bb42e2a162923ea7c1a336d4520dfa Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 17:30:50 -0700 Subject: [PATCH 18/33] fix button disabling for renders --- nerfstudio/viewer/render_panel.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 3538910da3..48c62de1c9 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -643,6 +643,10 @@ def _(event: viser.GuiEvent) -> None: duration_number.value = camera_path.compute_duration() camera_path.update_spline() + # Enable render only if there are two or more keyframes + render_button.disabled = len(camera_path._keyframes) < 2 + generate_command_render_button.disabled = len(camera_path._keyframes) < 2 + clear_keyframes_button = server.gui.add_button( "Clear Keyframes", icon=viser.Icon.TRASH, @@ -663,6 +667,9 @@ def _(_) -> None: camera_path.reset() modal.close() + render_button.disabled = True + generate_command_render_button.disabled = True + duration_number.value = camera_path.compute_duration() # Clear move handles. 
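The gating added in patches 17 and 18 reduces to a small viser pattern: keep the render buttons disabled until the camera path has at least two keyframes, and disable them again while a render is in flight. A minimal standalone sketch, using only viser calls that already appear in these patches (the keyframe list and refresh helper are illustrative stand-ins, not the panel's real state):

import viser

server = viser.ViserServer()
keyframes: list = []  # stand-in for camera_path._keyframes

render_button = server.gui.add_button("Render", icon=viser.Icon.VIDEO, disabled=True)

def refresh_render_button() -> None:
    # Enable rendering only once the path has at least two keyframes.
    render_button.disabled = len(keyframes) < 2

@render_button.on_click
def _(event: viser.GuiEvent) -> None:
    render_button.disabled = True  # block re-entry while a render is in flight
    try:
        ...  # long-running render happens here
    finally:
        render_button.disabled = False

refresh_render_button()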
@@ -1085,18 +1092,21 @@ def _(_) -> None: "Render", icon=viser.Icon.VIDEO, hint="Render the camera path and save video as mp4 file.", + disabled=True, ) download_render_button = server.gui.add_button( "Download Render", icon=viser.Icon.DOWNLOAD, - hint="Download the latest render locally as mp4 file." + hint="Download the latest render locally as mp4 file.", + disabled=True, ) generate_command_render_button = server.gui.add_button( "Generate Command", icon=viser.Icon.FILE_EXPORT, hint="Generate the ns-render command for rendering the camera path instead of directly rendering.", + disabled=True, ) def _write_json() -> json: @@ -1217,11 +1227,6 @@ def _write_json() -> json: return json_outfile - # TODO: disable render button while before checkpoints/keyframes placed down - - render_complete = False - download_render_button.disabled = not render_complete - @render_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client @@ -1254,11 +1259,9 @@ def _(event: viser.GuiEvent) -> None: notif.loading = False notif.with_close_button = True - nonlocal render_complete - render_complete = True - render_button.disabled = False - + download_render_button.disabled = False + @download_render_button.on_click def _(event: viser.GuiEvent) -> None: render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" @@ -1331,10 +1334,6 @@ def _(event: viser.GuiEvent) -> None: camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value - # disable render option if no keyframes are added - render_button.disabled = len(camera_path._keyframes) <= 0 - generate_command_render_button.disabled = len(camera_path._keyframes) <= 0 - return render_tab_state From b5e9f8c533ad6cfc5a9354463b8397fe8945eab4 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 18:38:10 -0700 Subject: [PATCH 19/33] add render cancellation option --- nerfstudio/scripts/render.py | 22 ++++++++++++++++++---- nerfstudio/viewer/render_panel.py | 24 ++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 5402227e03..3da70d304c 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -93,9 +93,10 @@ def _render_trajectory_video( depth_near_plane: Optional[float] = None, depth_far_plane: Optional[float] = None, colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(), - render_nearest_camera=False, + render_nearest_camera: bool = False, check_occlusions: bool = False, -) -> None: + kill_flag: List[bool] = [False], +) -> bool: """Helper function to create a video of the spiral trajectory. 
Args: @@ -155,6 +156,9 @@ def _render_trajectory_video( with progress: for camera_idx in progress.track(range(cameras.size), description=""): + if kill_flag[0]: + return False + obb_box = None if crop_data is not None: obb_box = crop_data.obb @@ -334,6 +338,8 @@ def _render_trajectory_video( ) ) + return True + def insert_spherical_metadata_into_file( output_filename: Path, @@ -487,6 +493,11 @@ class BaseRender: """If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other""" camera_idx: Optional[int] = None """Index of the training camera to render.""" + kill_flag: Optional[List[bool]] = field(default_factory=lambda: [False]) + """Stop execution of render if set to True.""" + + def kill(self) -> None: + self.kill_flag[0] = True @dataclass @@ -546,7 +557,7 @@ def main(self) -> None: if self.camera_idx is not None: camera_path.metadata = {"cam_idx": self.camera_idx} - _render_trajectory_video( + self.complete = _render_trajectory_video( pipeline, camera_path, output_filename=self.output_path, @@ -562,6 +573,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + kill_flag=self.kill_flag, ) if ( @@ -597,6 +609,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + kill_flag=self.kill_flag, ) self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4") @@ -642,7 +655,6 @@ def main(self) -> None: if str(left_eye_path.parent)[-5:] == "_temp": shutil.rmtree(left_eye_path.parent, ignore_errors=True) CONSOLE.print("[bold green]Final VR180 Render Complete") - self.complete = True @dataclass @@ -701,6 +713,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + kill_flag=self.kill_flag, ) @@ -756,6 +769,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + kill_flag=self.kill_flag, ) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 48c62de1c9..9f63921b96 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1095,6 +1095,13 @@ def _(_) -> None: disabled=True, ) + cancel_render_button = server.gui.add_button( + "Cancel Render", + icon=viser.Icon.CIRCLE_X, + hint="Cancel current render in progress.", + disabled=True, + ) + download_render_button = server.gui.add_button( "Download Render", icon=viser.Icon.DOWNLOAD, @@ -1233,6 +1240,7 @@ def _(event: viser.GuiEvent) -> None: assert client is not None render_button.disabled = True + cancel_render_button.disabled = False render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" json_outfile = _write_json() @@ -1251,6 +1259,18 @@ def _(event: viser.GuiEvent) -> None: camera_path_filename=json_outfile.absolute(), output_path=Path(render_path), ) + + @cancel_render_button.on_click + def _(event: viser.GuiEvent) -> None: + render.kill() + render_button.disabled = False + cancel_render_button.disabled = True + + notif.title = "Render canceled" + notif.body = "The render in progress has been canceled." 
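+ # Cancellation is cooperative: render.kill() only sets kill_flag[0]; the per-camera loop in
+ # _render_trajectory_video checks that flag and returns False, so render.complete stays False
+ # and the completion update below is skipped when a render is canceled.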
+ notif.loading = False + notif.with_close_button = True + render.main() if render.complete: @@ -1259,8 +1279,8 @@ def _(event: viser.GuiEvent) -> None: notif.loading = False notif.with_close_button = True - render_button.disabled = False - download_render_button.disabled = False + render_button.disabled = False + download_render_button.disabled = False @download_render_button.on_click def _(event: viser.GuiEvent) -> None: From fef8461f4481ae76d5dbecd8588ff1049853ef5f Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 20:26:29 -0700 Subject: [PATCH 20/33] disable button feature on exports and clean export panel --- nerfstudio/viewer/export_panel.py | 95 ++++++++++++++++++++++++------- 1 file changed, 75 insertions(+), 20 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index a86586c451..c19c29b546 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -115,9 +115,9 @@ def populate_point_cloud_tab( hint="Normal map source.", ) - export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD) output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") + export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) + download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True) generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click @@ -129,6 +129,7 @@ def _(event: viser.GuiEvent) -> None: title="Exporting point cloud", body="File will be saved under " + str(output_dir.value), loading=True, + with_close_button=False, ) if control_panel.crop_obb is not None and control_panel.crop_viewport: @@ -154,10 +155,12 @@ def _(event: viser.GuiEvent) -> None: export.main() if export.complete: - notif.update( - title="Export complete!", - body="File saved under " + str(output_dir.value), - ) + notif.title = "Export complete!" 
+ notif.body = "File saved under " + str(output_dir.value) + notif.loading = False + notif.with_close_button = True + + download_button.disabled = False @download_button.on_click def _(event: viser.GuiEvent) -> None: @@ -170,6 +173,23 @@ def _(event: viser.GuiEvent) -> None: client.send_file_download( "point_cloud.ply", ply_bytes ) + + @generate_command.on_click + def _(event: viser.GuiEvent) -> None: + assert event.client is not None + command = " ".join( + [ + "ns-export pointcloud", + f"--load-config {config_path}", + f"--output-dir {output_dir.value}", + f"--num-points {num_points.value}", + f"--remove-outliers {remove_outliers.value}", + f"--normal-method {normals.value}", + f"--save-world-frame {world_frame.value}", + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + ] + ) + show_command_modal(event.client, "point cloud", command) else: server.gui.add_markdown("Point cloud export is not currently supported with Gaussian Splatting") @@ -194,12 +214,12 @@ def populate_mesh_tab( ) num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) - output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) remove_outliers = server.gui.add_checkbox("Remove outliers", True) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD) + download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True) generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click @@ -211,6 +231,7 @@ def _(event: viser.GuiEvent) -> None: title="Exporting poisson mesh", body="File will be saved under " + str(output_dir.value), loading=True, + with_close_button=False, ) if control_panel.crop_obb is not None and control_panel.crop_viewport: @@ -237,10 +258,12 @@ def _(event: viser.GuiEvent) -> None: export.main() if export.complete: - notif.update( - title="Export complete!", - body="File saved under " + str(output_dir.value), - ) + notif.title = "Export complete!" 
+ notif.body = "File saved under " + str(output_dir.value) + notif.loading = False + notif.with_close_button = True + + download_button.disabled = False @download_button.on_click def _(event: viser.GuiEvent) -> None: @@ -253,6 +276,24 @@ def _(event: viser.GuiEvent) -> None: client.send_file_download( "poisson_mesh.ply", ply_bytes ) + + @generate_command.on_click + def _(event: viser.GuiEvent) -> None: + assert event.client is not None + command = " ".join( + [ + "ns-export poisson", + f"--load-config {config_path}", + f"--output-dir {output_dir.value}", + f"--target-num-faces {num_faces.value}", + f"--num-pixels-per-side {texture_resolution.value}", + f"--num-points {num_points.value}", + f"--remove-outliers {remove_outliers.value}", + f"--normal-method {normals.value}", + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + ] + ) + show_command_modal(event.client, "mesh", command) else: @@ -267,11 +308,9 @@ def populate_splat_tab( ) -> None: if viewing_gsplat: server.gui.add_markdown("Generate ply export of Gaussian Splat") - - export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD) - output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/") + export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) + download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True) generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click @@ -283,6 +322,7 @@ def _(event: viser.GuiEvent) -> None: title="Exporting gaussian splat", body="File will be saved under " + str(output_dir.value), loading=True, + with_close_button=False, ) notif.show() @@ -305,10 +345,12 @@ def _(event: viser.GuiEvent) -> None: export.main() if export.complete: - notif.update( - title="Export complete!", - body="File saved under " + str(output_dir.value), - ) + notif.title = "Export complete!" 
+ notif.body = "File saved under " + str(output_dir.value) + notif.loading = False + notif.with_close_button = True + + download_button.disabled = False @download_button.on_click def _(event: viser.GuiEvent) -> None: @@ -321,6 +363,19 @@ def _(event: viser.GuiEvent) -> None: client.send_file_download( "splat.ply", ply_bytes ) + + @generate_command.on_click + def _(event: viser.GuiEvent) -> None: + assert event.client is not None + command = " ".join( + [ + "ns-export gaussian-splat", + f"--load-config {config_path}", + f"--output-dir {output_directory.value}", + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + ] + ) + show_command_modal(event.client, "splat", command) else: server.gui.add_markdown("Splat export is only supported with Gaussian Splatting methods") From 24235109e6a70ae6405b1c632266a63de223c0f7 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 20:33:38 -0700 Subject: [PATCH 21/33] ruff format --- nerfstudio/scripts/exporter.py | 102 +++++--------------- nerfstudio/scripts/render.py | 100 +++++-------------- nerfstudio/viewer/export_panel.py | 84 +++++++--------- nerfstudio/viewer/render_panel.py | 153 +++++++++--------------------- 4 files changed, 130 insertions(+), 309 deletions(-) diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index e4203b1db7..bfe6367784 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -69,9 +69,7 @@ class Exporter: """Set to True when export is finished.""" -def validate_pipeline( - normal_method: str, normal_output_name: str, pipeline: Pipeline -) -> None: +def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None: """Check that the pipeline is valid for this exporter. Args: @@ -93,9 +91,7 @@ def validate_pipeline( ) outputs = pipeline.model(ray_bundle) if normal_output_name not in outputs: - CONSOLE.print( - f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs." - ) + CONSOLE.print(f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs.") CONSOLE.print(f"Available outputs: {list(outputs.keys())}") CONSOLE.print( "[bold yellow]Warning: Please train a model with normals " @@ -160,21 +156,13 @@ def main(self) -> None: ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( - self.num_rays_per_batch - ) + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch # Whether the normals should be estimated based on the point cloud. 
estimate_normals = self.normal_method == "open3d" crop_obb = None - if ( - self.obb_center is not None - and self.obb_rotation is not None - and self.obb_scale is not None - ): - crop_obb = OrientedBox.from_params( - self.obb_center, self.obb_rotation, self.obb_scale - ) + if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: + crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) pcd = generate_point_cloud( pipeline=pipeline, num_points=self.num_points, @@ -183,18 +171,14 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name - if self.normal_method == "model_output" - else None, + normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, crop_obb=crop_obb, std_ratio=self.std_ratio, ) if self.save_world_frame: # apply the inverse dataparser transform to the point cloud points = np.asarray(pcd.points) - poses = np.eye(4, dtype=np.float32)[None, ...].repeat( - points.shape[0], axis=0 - )[:, :3, :] + poses = np.eye(4, dtype=np.float32)[None, ...].repeat(points.shape[0], axis=0)[:, :3, :] poses[:, :3, 3] = points poses = pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space( torch.from_numpy(poses) @@ -285,9 +269,7 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle - if self.unwrap_method == "custom" - else None, + px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -365,20 +347,12 @@ def main(self) -> None: ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( - self.num_rays_per_batch - ) + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch # Whether the normals should be estimated based on the point cloud. estimate_normals = self.normal_method == "open3d" - if ( - self.obb_center is not None - and self.obb_rotation is not None - and self.obb_scale is not None - ): - crop_obb = OrientedBox.from_params( - self.obb_center, self.obb_rotation, self.obb_scale - ) + if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: + crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) else: crop_obb = None @@ -390,9 +364,7 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name - if self.normal_method == "model_output" - else None, + normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, crop_obb=crop_obb, std_ratio=self.std_ratio, ) @@ -406,9 +378,7 @@ def main(self) -> None: CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") CONSOLE.print("Computing Mesh... 
this may take a while.") - mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( - pcd, depth=9 - ) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9) vertices_to_remove = densities < np.quantile(densities, 0.1) mesh.remove_vertices_by_mask(vertices_to_remove) print("\033[A\033[A") @@ -432,9 +402,7 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle - if self.unwrap_method == "custom" - else None, + px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -473,15 +441,11 @@ def main(self) -> None: _, pipeline, _, _ = eval_setup(self.load_config) # TODO: Make this work with Density Field - assert hasattr( - pipeline.model.config, "sdf_field" - ), "Model must have an SDF field." + assert hasattr(pipeline.model.config, "sdf_field"), "Model must have an SDF field." CONSOLE.print("Extracting mesh with marching cubes... which may take a while") - assert ( - self.resolution % 512 == 0 - ), f"""resolution must be divisible by 512, got {self.resolution}. + assert self.resolution % 512 == 0, f"""resolution must be divisible by 512, got {self.resolution}. This is important because the algorithm uses a multi-resolution approach to evaluate the SDF where the minimum resolution is 512.""" @@ -500,17 +464,13 @@ def main(self) -> None: multi_res_mesh.export(filename) # load the mesh from the marching cubes export - mesh = get_mesh_from_filename( - str(filename), target_num_faces=self.target_num_faces - ) + mesh = get_mesh_from_filename(str(filename), target_num_faces=self.target_num_faces) CONSOLE.print("Texturing mesh with NeRF...") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle - if self.unwrap_method == "custom" - else None, + px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -538,9 +498,7 @@ def main(self) -> None: ("transforms_eval.json", eval_frames), ]: if len(frames) == 0: - CONSOLE.print( - f"[bold yellow]No frames found for {file_name}. Skipping." - ) + CONSOLE.print(f"[bold yellow]No frames found for {file_name}. 
Skipping.") continue output_file_path = os.path.join(self.output_dir, file_name) @@ -548,9 +506,7 @@ def main(self) -> None: with open(output_file_path, "w", encoding="UTF-8") as f: json.dump(frames, f, indent=4) - CONSOLE.print( - f"[bold green]:white_check_mark: Saved poses to {output_file_path}" - ) + CONSOLE.print(f"[bold green]:white_check_mark: Saved poses to {output_file_path}") @dataclass @@ -594,9 +550,7 @@ def write_ply( and tensor.size > 0 for tensor in map_to_tensors.values() ): - raise ValueError( - "All tensors must be numpy arrays of float or uint8 type and not empty" - ) + raise ValueError("All tensors must be numpy arrays of float or uint8 type and not empty") with open(filename, "wb") as ply_file: # Write PLY header @@ -672,14 +626,8 @@ def main(self) -> None: for i in range(4): map_to_tensors[f"rot_{i}"] = quats[:, i, None] - if ( - self.obb_center is not None - and self.obb_rotation is not None - and self.obb_scale is not None - ): - crop_obb = OrientedBox.from_params( - self.obb_center, self.obb_rotation, self.obb_scale - ) + if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: + crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) assert crop_obb is not None mask = crop_obb.within(torch.from_numpy(positions)).numpy() for k, t in map_to_tensors.items(): @@ -699,9 +647,7 @@ def main(self) -> None: CONSOLE.print(f"{n_before - n_after} NaN/Inf elements in {k}") if np.sum(select) < n: - CONSOLE.print( - f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}" - ) + CONSOLE.print(f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}") for k, t in map_to_tensors.items(): map_to_tensors[k] = map_to_tensors[k][select] count = np.sum(select) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 3da70d304c..b61e734341 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -158,7 +158,7 @@ def _render_trajectory_video( for camera_idx in progress.track(range(cameras.size), description=""): if kill_flag[0]: return False - + obb_box = None if crop_data is not None: obb_box = crop_data.obb @@ -171,19 +171,14 @@ def _render_trajectory_video( assert train_dataset is not None assert train_cameras is not None cam_pos = cameras[camera_idx].camera_to_worlds[:, 3].cpu() - cam_quat = tf.SO3.from_matrix( - cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True) - ).wxyz + cam_quat = tf.SO3.from_matrix(cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True)).wxyz for i in range(len(train_cameras)): train_cam_pos = train_cameras[i].camera_to_worlds[:, 3].cpu() # Make sure the line of sight from rendered cam to training cam is not blocked by any object bundle = RayBundle( origins=cam_pos.view(1, 3), - directions=( - (cam_pos - train_cam_pos) - / (cam_pos - train_cam_pos).norm() - ).view(1, 3), + directions=((cam_pos - train_cam_pos) / (cam_pos - train_cam_pos).norm()).view(1, 3), pixel_area=torch.tensor(1).view(1, 1), nears=torch.tensor(0.05).view(1, 1), fars=torch.tensor(100).view(1, 1), @@ -192,9 +187,7 @@ def _render_trajectory_video( ).to(pipeline.device) outputs = pipeline.model.get_outputs(bundle) - q = tf.SO3.from_matrix( - train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True) - ).wxyz + q = tf.SO3.from_matrix(train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True)).wxyz # calculate distance between two quaternions rot_dist = 1 - np.dot(q, cam_quat) ** 2 pos_dist = torch.norm(train_cam_pos - cam_pos) 
@@ -204,10 +197,7 @@ def _render_trajectory_video( true_max_dist = dist true_max_idx = i - if ( - outputs["depth"][0] - < torch.norm(cam_pos - train_cam_pos).item() - ): + if outputs["depth"][0] < torch.norm(cam_pos - train_cam_pos).item(): continue if check_occlusions and (max_dist == -1 or dist < max_dist): @@ -417,11 +407,7 @@ class CropData: background_color: Float[Tensor, "3"] = torch.Tensor([0.0, 0.0, 0.0]) """background color""" - obb: OrientedBox = field( - default_factory=lambda: OrientedBox( - R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2 - ) - ) + obb: OrientedBox = field(default_factory=lambda: OrientedBox(R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2)) """Oriented box representing the crop region""" # properties for backwards-compatibility interface @@ -447,18 +433,12 @@ def get_crop_from_json(camera_json: Dict[str, Any]) -> Optional[CropData]: bg_color = camera_json["crop"]["crop_bg_color"] center = camera_json["crop"]["crop_center"] scale = camera_json["crop"]["crop_scale"] - rot = ( - (0.0, 0.0, 0.0) - if "crop_rot" not in camera_json["crop"] - else tuple(camera_json["crop"]["crop_rot"]) - ) + rot = (0.0, 0.0, 0.0) if "crop_rot" not in camera_json["crop"] else tuple(camera_json["crop"]["crop_rot"]) assert len(center) == 3 assert len(scale) == 3 assert len(rot) == 3 return CropData( - background_color=torch.Tensor( - [bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0] - ), + background_color=torch.Tensor([bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0]), obb=OrientedBox.from_params(center, rot, scale), ) @@ -532,9 +512,7 @@ def main(self) -> None: or camera_path.camera_type[0] == CameraType.VR180_L.value ): # temp folder for writing left and right view renders - temp_folder_path = self.output_path.parent / ( - self.output_path.stem + "_temp" - ) + temp_folder_path = self.output_path.parent / (self.output_path.stem + "_temp") Path(temp_folder_path).mkdir(parents=True, exist_ok=True) left_eye_path = temp_folder_path / "render_left.mp4" @@ -542,9 +520,7 @@ def main(self) -> None: self.output_path = left_eye_path if camera_path.camera_type[0] == CameraType.OMNIDIRECTIONALSTEREO_L.value: - CONSOLE.print( - "[bold green]:goggles: Omni-directional Stereo VR :goggles:" - ) + CONSOLE.print("[bold green]:goggles: Omni-directional Stereo VR :goggles:") else: CONSOLE.print("[bold green]:goggles: VR180 :goggles:") @@ -834,39 +810,25 @@ def update_config(config: TrainerConfig) -> TrainerConfig: update_config_callback=update_config, ) data_manager_config = config.pipeline.datamanager - assert isinstance( - data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig) - ) + assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) for split in self.split.split("+"): datamanager: VanillaDataManager dataset: Dataset if split == "train": - with _disable_datamanager_setup( - data_manager_config._target - ): # pylint: disable=protected-access - datamanager = data_manager_config.setup( - test_mode="test", device=pipeline.device - ) + with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access + datamanager = data_manager_config.setup(test_mode="test", device=pipeline.device) dataset = datamanager.train_dataset - dataparser_outputs = getattr( - dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs - ) + dataparser_outputs = getattr(dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs) else: - with _disable_datamanager_setup( - 
data_manager_config._target - ): # pylint: disable=protected-access - datamanager = data_manager_config.setup( - test_mode=split, device=pipeline.device - ) + with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access + datamanager = data_manager_config.setup(test_mode=split, device=pipeline.device) dataset = datamanager.eval_dataset dataparser_outputs = getattr(dataset, "_dataparser_outputs", None) if dataparser_outputs is None: - dataparser_outputs = datamanager.dataparser.get_dataparser_outputs( - split=datamanager.test_split - ) + dataparser_outputs = datamanager.dataparser.get_dataparser_outputs(split=datamanager.test_split) dataloader = FixedIndicesEvalDataloader( input_dataset=dataset, device=datamanager.device, @@ -884,9 +846,7 @@ def update_config(config: TrainerConfig) -> TrainerConfig: TimeRemainingColumn(elapsed_when_finished=False, compact=False), TimeElapsedColumn(), ) as progress: - for camera_idx, (camera, batch) in enumerate( - progress.track(dataloader, total=len(dataset)) - ): + for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))): with torch.no_grad(): outputs = pipeline.model.get_outputs_for_camera(camera) @@ -919,13 +879,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: image_name = f"{camera_idx:05d}" # Try to get the original filename - image_name = dataparser_outputs.image_filenames[ - camera_idx - ].relative_to(images_root) + image_name = dataparser_outputs.image_filenames[camera_idx].relative_to(images_root) - output_path = ( - self.output_path / split / rendered_output_name / image_name - ) + output_path = self.output_path / split / rendered_output_name / image_name output_path.parent.mkdir(exist_ok=True, parents=True) output_name = rendered_output_name @@ -939,9 +895,7 @@ def update_config(config: TrainerConfig) -> TrainerConfig: output_image = outputs[output_name] if is_depth: # Divide by the dataparser scale factor - output_image.div_( - dataparser_outputs.dataparser_scale - ) + output_image.div_(dataparser_outputs.dataparser_scale) else: if output_name.startswith("gt-"): output_name = output_name[3:] @@ -977,14 +931,10 @@ def update_config(config: TrainerConfig) -> TrainerConfig: # Save to file if is_raw: - with gzip.open( - output_path.with_suffix(".npy.gz"), "wb" - ) as f: + with gzip.open(output_path.with_suffix(".npy.gz"), "wb") as f: np.save(f, output_image) elif self.image_format == "png": - media.write_image( - output_path.with_suffix(".png"), output_image, fmt="png" - ) + media.write_image(output_path.with_suffix(".png"), output_image, fmt="png") elif self.image_format == "jpeg": media.write_image( output_path.with_suffix(".jpg"), @@ -993,9 +943,7 @@ def update_config(config: TrainerConfig) -> TrainerConfig: quality=self.jpeg_quality, ) else: - raise ValueError( - f"Unknown image format {self.image_format}" - ) + raise ValueError(f"Unknown image format {self.image_format}") table = Table( title=None, diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index c19c29b546..38257f5bf8 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -89,6 +89,7 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool): scalestring = " ".join([f"{x:.10f}" for x in scale]) return [posstring, rpystring, scalestring] + def populate_point_cloud_tab( server: viser.ViserServer, control_panel: ControlPanel, @@ -126,21 +127,19 @@ def _(event: viser.GuiEvent) -> None: assert client is not None notif = 
client.add_notification( - title="Exporting point cloud", - body="File will be saved under " + str(output_dir.value), - loading=True, - with_close_button=False, - ) + title="Exporting point cloud", + body="File will be saved under " + str(output_dir.value), + loading=True, + with_close_button=False, + ) if control_panel.crop_obb is not None and control_panel.crop_viewport: - posstring, rpystring, scalestring = get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ) - else: + posstring, rpystring, scalestring = get_crop_string(control_panel.crop_obb, control_panel.crop_viewport) + else: posstring = rpystring = scalestring = None from nerfstudio.scripts.exporter import ExportPointCloud - + export = ExportPointCloud( load_config=config_path, output_dir=Path(output_dir.value), @@ -153,7 +152,7 @@ def _(event: viser.GuiEvent) -> None: obb_scale=scalestring, ) export.main() - + if export.complete: notif.title = "Export complete!" notif.body = "File saved under " + str(output_dir.value) @@ -167,13 +166,11 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - with open(str(output_dir.value) + "point_cloud.ply", 'rb') as ply_file: + with open(str(output_dir.value) + "point_cloud.ply", "rb") as ply_file: ply_bytes = ply_file.read() - client.send_file_download( - "point_cloud.ply", ply_bytes - ) - + client.send_file_download("point_cloud.ply", ply_bytes) + @generate_command.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None @@ -228,21 +225,19 @@ def _(event: viser.GuiEvent) -> None: assert client is not None notif = client.add_notification( - title="Exporting poisson mesh", - body="File will be saved under " + str(output_dir.value), - loading=True, - with_close_button=False, - ) + title="Exporting poisson mesh", + body="File will be saved under " + str(output_dir.value), + loading=True, + with_close_button=False, + ) if control_panel.crop_obb is not None and control_panel.crop_viewport: - posstring, rpystring, scalestring = get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ) - else: + posstring, rpystring, scalestring = get_crop_string(control_panel.crop_obb, control_panel.crop_viewport) + else: posstring = rpystring = scalestring = None from nerfstudio.scripts.exporter import ExportPoissonMesh - + export = ExportPoissonMesh( load_config=config_path, output_dir=Path(output_dir.value), @@ -264,19 +259,17 @@ def _(event: viser.GuiEvent) -> None: notif.with_close_button = True download_button.disabled = False - + @download_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - with open(str(output_dir.value) + "poisson_mesh.ply", 'rb') as ply_file: + with open(str(output_dir.value) + "poisson_mesh.ply", "rb") as ply_file: ply_bytes = ply_file.read() - client.send_file_download( - "poisson_mesh.ply", ply_bytes - ) - + client.send_file_download("poisson_mesh.ply", ply_bytes) + @generate_command.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None @@ -294,7 +287,6 @@ def _(event: viser.GuiEvent) -> None: ] ) show_command_modal(event.client, "mesh", command) - else: server.gui.add_markdown("Mesh export is not currently supported with Gaussian Splatting") @@ -319,18 +311,16 @@ def _(event: viser.GuiEvent) -> None: assert client is not None notif = client.add_notification( - title="Exporting gaussian splat", - body="File will be saved under " + str(output_dir.value), - loading=True, - with_close_button=False, - ) + 
title="Exporting gaussian splat", + body="File will be saved under " + str(output_dir.value), + loading=True, + with_close_button=False, + ) notif.show() if control_panel.crop_obb is not None and control_panel.crop_viewport: - posstring, rpystring, scalestring = get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ) - else: + posstring, rpystring, scalestring = get_crop_string(control_panel.crop_obb, control_panel.crop_viewport) + else: posstring = rpystring = scalestring = None from nerfstudio.scripts.exporter import ExportGaussianSplat @@ -351,19 +341,17 @@ def _(event: viser.GuiEvent) -> None: notif.with_close_button = True download_button.disabled = False - + @download_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - with open(str(output_dir.value) + "splat.ply", 'rb') as ply_file: + with open(str(output_dir.value) + "splat.ply", "rb") as ply_file: ply_bytes = ply_file.read() - client.send_file_download( - "splat.ply", ply_bytes - ) - + client.send_file_download("splat.ply", ply_bytes) + @generate_command.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 9f63921b96..f773c0cb35 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -95,9 +95,7 @@ def set_keyframes_visible(self, visible: bool) -> None: for keyframe in self._keyframes.values(): keyframe[1].visible = visible - def add_camera( - self, keyframe: Keyframe, keyframe_index: Optional[int] = None - ) -> None: + def add_camera(self, keyframe: Keyframe, keyframe_index: Optional[int] = None) -> None: """Add a new camera, or replace an old one if `keyframe_index` is passed in.""" server = self._server @@ -108,9 +106,7 @@ def add_camera( frustum_handle = server.scene.add_camera_frustum( f"/render_cameras/{keyframe_index}", - fov=keyframe.override_fov_rad - if keyframe.override_fov_enabled - else self.default_fov, + fov=keyframe.override_fov_rad if keyframe.override_fov_enabled else self.default_fov, aspect=keyframe.aspect, scale=0.1, color=(200, 10, 30), @@ -227,9 +223,7 @@ def _(event: viser.GuiEvent) -> None: T_current_target = T_world_current.inverse() @ T_world_target for j in range(10): - T_world_set = T_world_current @ tf.SE3.exp( - T_current_target.log() * j / 9.0 - ) + T_world_set = T_world_current @ tf.SE3.exp(T_current_target.log() * j / 9.0) # Important bit: we atomically set both the orientation and the position # of the camera. @@ -284,14 +278,10 @@ def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray: ], axis=0, ), - y=np.concatenate( - [[-1], spline_indices, [spline_indices[-1] + 1]], axis=0 - ), + y=np.concatenate([[-1], spline_indices, [spline_indices[-1] + 1]], axis=0), ) else: - interpolator = interpolate.PchipInterpolator( - x=transition_times_cumsum, y=spline_indices - ) + interpolator = interpolate.PchipInterpolator(x=transition_times_cumsum, y=spline_indices) # Clip to account for floating point error. 
return np.clip(interpolator(time), 0, spline_indices[-1]) @@ -304,9 +294,7 @@ def interpolate_pose_and_fov_rad( self._fov_spline = splines.KochanekBartels( [ - keyframe[0].override_fov_rad - if keyframe[0].override_fov_enabled - else self.default_fov + keyframe[0].override_fov_rad if keyframe[0].override_fov_enabled else self.default_fov for keyframe in self._keyframes.values() ], tcb=(self.tension, 0.0, 0.0), @@ -315,9 +303,7 @@ def interpolate_pose_and_fov_rad( self._time_spline = splines.KochanekBartels( [ - keyframe[0].override_time_val - if keyframe[0].override_time_enabled - else self.default_render_time + keyframe[0].override_time_val if keyframe[0].override_time_enabled else self.default_render_time for keyframe in self._keyframes.values() ], tcb=(self.tension, 0.0, 0.0), @@ -367,9 +353,7 @@ def update_spline(self) -> None: self._orientation_spline = splines.quaternion.KochanekBartels( [ - splines.quaternion.UnitQuaternion.from_unit_xyzw( - np.roll(keyframe[0].wxyz, shift=-1) - ) + splines.quaternion.UnitQuaternion.from_unit_xyzw(np.roll(keyframe[0].wxyz, shift=-1)) for keyframe in keyframes ], tcb=(self.tension, 0.0, 0.0), @@ -383,16 +367,9 @@ def update_spline(self) -> None: # Update visualized spline. points_array = self._position_spline.evaluate( - self.spline_t_from_t_sec( - np.linspace(0, transition_times_cumsum[-1], num_frames) - ) - ) - colors_array = np.array( - [ - colorsys.hls_to_rgb(h, 0.5, 1.0) - for h in np.linspace(0.0, 1.0, len(points_array)) - ] + self.spline_t_from_t_sec(np.linspace(0, transition_times_cumsum[-1], num_frames)) ) + colors_array = np.array([colorsys.hls_to_rgb(h, 0.5, 1.0) for h in np.linspace(0.0, 1.0, len(points_array))]) # Clear prior spline nodes. for node in self._spline_nodes: @@ -423,8 +400,7 @@ def make_transition_handle(i: int) -> None: transition_pos = self._position_spline.evaluate( float( self.spline_t_from_t_sec( - (transition_times_cumsum[i] + transition_times_cumsum[i + 1]) - / 2.0, + (transition_times_cumsum[i] + transition_times_cumsum[i + 1]) / 2.0, ) ) ) @@ -470,12 +446,8 @@ def _(_) -> None: @override_transition_enabled.on_update def _(_) -> None: - keyframe.override_transition_enabled = ( - override_transition_enabled.value - ) - override_transition_sec.disabled = ( - not override_transition_enabled.value - ) + keyframe.override_transition_enabled = override_transition_enabled.value + override_transition_sec.disabled = not override_transition_enabled.value self._duration_element.value = self.compute_duration() @override_transition_sec.on_update @@ -504,8 +476,7 @@ def compute_duration(self) -> float: del frustum total += ( keyframe.override_transition_sec - if keyframe.override_transition_enabled - and keyframe.override_transition_sec is not None + if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None else self.default_transition_sec ) return total @@ -520,8 +491,7 @@ def compute_transition_times_cumsum(self) -> np.ndarray: del frustum total += ( keyframe.override_transition_sec - if keyframe.override_transition_enabled - and keyframe.override_transition_sec is not None + if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None else self.default_transition_sec ) out.append(total) @@ -530,8 +500,7 @@ def compute_transition_times_cumsum(self) -> np.ndarray: keyframe = next(iter(self._keyframes.values()))[0] total += ( keyframe.override_transition_sec - if keyframe.override_transition_enabled - and keyframe.override_transition_sec is not None + if 
keyframe.override_transition_enabled and keyframe.override_transition_sec is not None else self.default_transition_sec ) out.append(total) @@ -816,9 +785,7 @@ def remove_preview_camera() -> None: preview_camera_handle.remove() preview_camera_handle = None - def compute_and_update_preview_camera_state() -> ( - Optional[Union[Tuple[tf.SE3, float], Tuple[tf.SE3, float, float]]] - ): + def compute_and_update_preview_camera_state() -> Optional[Union[Tuple[tf.SE3, float], Tuple[tf.SE3, float, float]]]: """Update the render tab state with the current preview camera pose. Returns current camera pose + FOV if available.""" @@ -933,9 +900,7 @@ def _(_) -> None: for client in server.get_clients().values(): if client.client_id not in camera_pose_backup_from_id: continue - cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop( - client.client_id - ) + cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(client.client_id) client.camera.position = cam_position client.camera.look_at = cam_look_at client.camera.up_direction = cam_up @@ -976,9 +941,7 @@ def play() -> None: max_frame = int(framerate_number.value * duration_number.value) if max_frame > 0: assert preview_frame_slider is not None - preview_frame_slider.value = ( - preview_frame_slider.value + 1 - ) % max_frame + preview_frame_slider.value = (preview_frame_slider.value + 1) % max_frame time.sleep(1.0 / framerate_number.value) threading.Thread(target=play).start() @@ -1025,9 +988,7 @@ def _(_) -> None: camera_path.reset() for i in range(len(keyframes)): frame = keyframes[i] - pose = tf.SE3.from_matrix( - np.array(frame["matrix"]).reshape(4, 4) - ) + pose = tf.SE3.from_matrix(np.array(frame["matrix"]).reshape(4, 4)) # apply the x rotation by 180 deg pose = tf.SE3.from_rotation_and_translation( @@ -1037,33 +998,21 @@ def _(_) -> None: camera_path.add_camera( Keyframe( - position=pose.translation() - * VISER_NERFSTUDIO_SCALE_RATIO, + position=pose.translation() * VISER_NERFSTUDIO_SCALE_RATIO, wxyz=pose.rotation().wxyz, # There are some floating point conversions between degrees and radians, so the fov and # default_Fov values will not be exactly matched. 
- override_fov_enabled=abs( - frame["fov"] - json_data.get("default_fov", 0.0) - ) - > 1e-3, + override_fov_enabled=abs(frame["fov"] - json_data.get("default_fov", 0.0)) > 1e-3, override_fov_rad=frame["fov"] / 180.0 * np.pi, - override_time_enabled=frame.get( - "override_time_enabled", False - ), + override_time_enabled=frame.get("override_time_enabled", False), override_time_val=frame.get("render_time", None), aspect=frame["aspect"], - override_transition_enabled=frame.get( - "override_transition_enabled", None - ), - override_transition_sec=frame.get( - "override_transition_sec", None - ), + override_transition_enabled=frame.get("override_transition_enabled", None), + override_transition_sec=frame.get("override_transition_sec", None), ), ) - transition_sec_number.value = json_data.get( - "default_transition_sec", 0.5 - ) + transition_sec_number.value = json_data.get("default_transition_sec", 0.5) # update the render name render_name_text.value = json_path.stem @@ -1086,7 +1035,9 @@ def _(_) -> None: render_folder = server.gui.add_folder("Render") with render_folder: - server.gui.add_markdown("Render available after a checkpoint is saved (default minimum 2000 steps)") + server.gui.add_markdown( + "Render available after a checkpoint is saved (default minimum 2000 steps)" + ) render_button = server.gui.add_button( "Render", @@ -1115,7 +1066,7 @@ def _(_) -> None: hint="Generate the ns-render command for rendering the camera path instead of directly rendering.", disabled=True, ) - + def _write_json() -> json: num_frames = int(framerate_number.value * duration_number.value) json_data = {} @@ -1147,9 +1098,7 @@ def _write_json() -> json: keyframe_dict = { "matrix": pose.as_matrix().flatten().tolist(), - "fov": np.rad2deg(keyframe.override_fov_rad) - if keyframe.override_fov_enabled - else fov_degrees.value, + "fov": np.rad2deg(keyframe.override_fov_rad) if keyframe.override_fov_enabled else fov_degrees.value, "aspect": keyframe.aspect, "override_transition_enabled": keyframe.override_transition_enabled, "override_transition_sec": keyframe.override_transition_sec, @@ -1157,9 +1106,7 @@ def _write_json() -> json: if render_time is not None: keyframe_dict["render_time"] = ( - keyframe.override_time_val - if keyframe.override_time_enabled - else render_time.value + keyframe.override_time_val if keyframe.override_time_enabled else render_time.value ) keyframe_dict["override_time_enabled"] = keyframe.override_time_enabled @@ -1168,9 +1115,7 @@ def _write_json() -> json: json_data["default_fov"] = fov_degrees.value if render_time is not None: - json_data["default_time"] = ( - render_time.value if render_time is not None else None - ) + json_data["default_time"] = render_time.value if render_time is not None else None json_data["default_transition_sec"] = transition_sec_number.value json_data["keyframes"] = keyframes @@ -1185,9 +1130,7 @@ def _write_json() -> json: # now populate the camera path: camera_path_list = [] for i in range(num_frames): - maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad( - i / num_frames - ) + maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad(i / num_frames) if maybe_pose_and_fov is None: return time = None @@ -1231,9 +1174,9 @@ def _write_json() -> json: json_outfile.parent.mkdir(parents=True, exist_ok=True) with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) - + return json_outfile - + @render_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client @@ -1246,11 +1189,11 @@ def _(event: 
viser.GuiEvent) -> None: json_outfile = _write_json() notif = client.add_notification( - title="Rendering trajectory", - body="Saving rendered video as " + render_path, - loading=True, - with_close_button=False, - ) + title="Rendering trajectory", + body="Saving rendered video as " + render_path, + loading=True, + with_close_button=False, + ) from nerfstudio.scripts.render import RenderCameraPath @@ -1289,13 +1232,11 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - with open(render_path, 'rb') as file: + with open(render_path, "rb") as file: video_bytes = file.read() - client.send_file_download( - "render.mp4", video_bytes - ) - + client.send_file_download("render.mp4", video_bytes) + @generate_command_render_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client @@ -1330,7 +1271,7 @@ def _(event: viser.GuiEvent) -> None: @close_button.on_click def _(_) -> None: modal.close() - + reset_up_button = server.gui.add_button( "Reset Up Direction", icon=viser.Icon.ARROW_BIG_UP_LINES, @@ -1341,9 +1282,7 @@ def _(_) -> None: @reset_up_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array( - [0.0, -1.0, 0.0] - ) + event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) From f5f9fdb31b348df707e801bed0bce497de71109b Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 20:33:45 -0700 Subject: [PATCH 22/33] ruff format --- nerfstudio/viewer/render_panel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index f773c0cb35..2b181a4afb 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1242,7 +1242,6 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4" json_outfile = _write_json() with event.client.gui.add_modal("Render Command") as modal: From 2ff7da4888338e5e5e3abc00493d047ec42e4e6a Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 20:40:41 -0700 Subject: [PATCH 23/33] ruff format --- nerfstudio/scripts/exporter.py | 120 +++++++++++++++++++--------- nerfstudio/scripts/render.py | 126 +++++++++++++++++++----------- nerfstudio/viewer/export_panel.py | 2 +- 3 files changed, 166 insertions(+), 82 deletions(-) diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index bfe6367784..84ee4c5b68 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -37,19 +37,11 @@ from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager -from nerfstudio.data.datamanagers.random_cameras_datamanager import ( - RandomCamerasDataManager, -) +from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManager from nerfstudio.data.scene_box import OrientedBox from nerfstudio.exporter import texture_utils, tsdf_utils -from nerfstudio.exporter.exporter_utils import ( - collect_camera_poses, - generate_point_cloud, - get_mesh_from_filename, -) -from nerfstudio.exporter.marching_cubes import ( - 
generate_mesh_with_multires_marching_cubes, -) +from nerfstudio.exporter.exporter_utils import collect_camera_poses, generate_point_cloud, get_mesh_from_filename +from nerfstudio.exporter.marching_cubes import generate_mesh_with_multires_marching_cubes from nerfstudio.fields.sdf_field import SDFField # noqa from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline @@ -69,7 +61,9 @@ class Exporter: """Set to True when export is finished.""" -def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None: +def validate_pipeline( + normal_method: str, normal_output_name: str, pipeline: Pipeline +) -> None: """Check that the pipeline is valid for this exporter. Args: @@ -91,7 +85,9 @@ def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pip ) outputs = pipeline.model(ray_bundle) if normal_output_name not in outputs: - CONSOLE.print(f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs.") + CONSOLE.print( + f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs." + ) CONSOLE.print(f"Available outputs: {list(outputs.keys())}") CONSOLE.print( "[bold yellow]Warning: Please train a model with normals " @@ -156,13 +152,21 @@ def main(self) -> None: ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( + self.num_rays_per_batch + ) # Whether the normals should be estimated based on the point cloud. estimate_normals = self.normal_method == "open3d" crop_obb = None - if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: - crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) + if ( + self.obb_center is not None + and self.obb_rotation is not None + and self.obb_scale is not None + ): + crop_obb = OrientedBox.from_params( + self.obb_center, self.obb_rotation, self.obb_scale + ) pcd = generate_point_cloud( pipeline=pipeline, num_points=self.num_points, @@ -171,14 +175,20 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, + normal_output_name=( + self.normal_output_name + if self.normal_method == "model_output" + else None + ), crop_obb=crop_obb, std_ratio=self.std_ratio, ) if self.save_world_frame: # apply the inverse dataparser transform to the point cloud points = np.asarray(pcd.points) - poses = np.eye(4, dtype=np.float32)[None, ...].repeat(points.shape[0], axis=0)[:, :3, :] + poses = np.eye(4, dtype=np.float32)[None, ...].repeat( + points.shape[0], axis=0 + )[:, :3, :] poses[:, :3, 3] = points poses = pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space( torch.from_numpy(poses) @@ -269,7 +279,9 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=( + self.px_per_uv_triangle if self.unwrap_method == "custom" else None + ), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -347,12 +359,20 @@ def main(self) -> None: ), ) assert pipeline.datamanager.train_pixel_sampler is not None - 
pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( + self.num_rays_per_batch + ) # Whether the normals should be estimated based on the point cloud. estimate_normals = self.normal_method == "open3d" - if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: - crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) + if ( + self.obb_center is not None + and self.obb_rotation is not None + and self.obb_scale is not None + ): + crop_obb = OrientedBox.from_params( + self.obb_center, self.obb_rotation, self.obb_scale + ) else: crop_obb = None @@ -364,7 +384,11 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, + normal_output_name=( + self.normal_output_name + if self.normal_method == "model_output" + else None + ), crop_obb=crop_obb, std_ratio=self.std_ratio, ) @@ -378,7 +402,9 @@ def main(self) -> None: CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") CONSOLE.print("Computing Mesh... this may take a while.") - mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( + pcd, depth=9 + ) vertices_to_remove = densities < np.quantile(densities, 0.1) mesh.remove_vertices_by_mask(vertices_to_remove) print("\033[A\033[A") @@ -402,7 +428,9 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=( + self.px_per_uv_triangle if self.unwrap_method == "custom" else None + ), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -441,11 +469,15 @@ def main(self) -> None: _, pipeline, _, _ = eval_setup(self.load_config) # TODO: Make this work with Density Field - assert hasattr(pipeline.model.config, "sdf_field"), "Model must have an SDF field." + assert hasattr( + pipeline.model.config, "sdf_field" + ), "Model must have an SDF field." CONSOLE.print("Extracting mesh with marching cubes... which may take a while") - assert self.resolution % 512 == 0, f"""resolution must be divisible by 512, got {self.resolution}. + assert ( + self.resolution % 512 == 0 + ), f"""resolution must be divisible by 512, got {self.resolution}. 
This is important because the algorithm uses a multi-resolution approach to evaluate the SDF where the minimum resolution is 512.""" @@ -464,13 +496,17 @@ def main(self) -> None: multi_res_mesh.export(filename) # load the mesh from the marching cubes export - mesh = get_mesh_from_filename(str(filename), target_num_faces=self.target_num_faces) + mesh = get_mesh_from_filename( + str(filename), target_num_faces=self.target_num_faces + ) CONSOLE.print("Texturing mesh with NeRF...") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=( + self.px_per_uv_triangle if self.unwrap_method == "custom" else None + ), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -498,7 +534,9 @@ def main(self) -> None: ("transforms_eval.json", eval_frames), ]: if len(frames) == 0: - CONSOLE.print(f"[bold yellow]No frames found for {file_name}. Skipping.") + CONSOLE.print( + f"[bold yellow]No frames found for {file_name}. Skipping." + ) continue output_file_path = os.path.join(self.output_dir, file_name) @@ -506,7 +544,9 @@ def main(self) -> None: with open(output_file_path, "w", encoding="UTF-8") as f: json.dump(frames, f, indent=4) - CONSOLE.print(f"[bold green]:white_check_mark: Saved poses to {output_file_path}") + CONSOLE.print( + f"[bold green]:white_check_mark: Saved poses to {output_file_path}" + ) @dataclass @@ -550,7 +590,9 @@ def write_ply( and tensor.size > 0 for tensor in map_to_tensors.values() ): - raise ValueError("All tensors must be numpy arrays of float or uint8 type and not empty") + raise ValueError( + "All tensors must be numpy arrays of float or uint8 type and not empty" + ) with open(filename, "wb") as ply_file: # Write PLY header @@ -626,8 +668,14 @@ def main(self) -> None: for i in range(4): map_to_tensors[f"rot_{i}"] = quats[:, i, None] - if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: - crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) + if ( + self.obb_center is not None + and self.obb_rotation is not None + and self.obb_scale is not None + ): + crop_obb = OrientedBox.from_params( + self.obb_center, self.obb_rotation, self.obb_scale + ) assert crop_obb is not None mask = crop_obb.within(torch.from_numpy(positions)).numpy() for k, t in map_to_tensors.items(): @@ -647,7 +695,9 @@ def main(self) -> None: CONSOLE.print(f"{n_before - n_after} NaN/Inf elements in {k}") if np.sum(select) < n: - CONSOLE.print(f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}") + CONSOLE.print( + f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}" + ) for k, t in map_to_tensors.items(): map_to_tensors[k] = map_to_tensors[k][select] count = np.sum(select) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index b61e734341..9668e01711 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -38,35 +38,17 @@ from jaxtyping import Float from rich import box, style from rich.panel import Panel -from rich.progress import ( - BarColumn, - Progress, - TaskProgressColumn, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) +from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.table import Table from torch import Tensor from typing_extensions import Annotated -from nerfstudio.cameras.camera_paths import ( 
- get_interpolated_camera_path, - get_path_from_json, - get_spiral_path, -) +from nerfstudio.cameras.camera_paths import get_interpolated_camera_path, get_path_from_json, get_spiral_path from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle -from nerfstudio.data.datamanagers.base_datamanager import ( - VanillaDataManager, - VanillaDataManagerConfig, -) -from nerfstudio.data.datamanagers.full_images_datamanager import ( - FullImageDatamanagerConfig, -) +from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig +from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager -from nerfstudio.data.datamanagers.random_cameras_datamanager import ( - RandomCamerasDataManager, -) +from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManager from nerfstudio.data.datasets.base_dataset import Dataset from nerfstudio.data.scene_box import OrientedBox from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader @@ -171,14 +153,19 @@ def _render_trajectory_video( assert train_dataset is not None assert train_cameras is not None cam_pos = cameras[camera_idx].camera_to_worlds[:, 3].cpu() - cam_quat = tf.SO3.from_matrix(cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True)).wxyz + cam_quat = tf.SO3.from_matrix( + cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True) + ).wxyz for i in range(len(train_cameras)): train_cam_pos = train_cameras[i].camera_to_worlds[:, 3].cpu() # Make sure the line of sight from rendered cam to training cam is not blocked by any object bundle = RayBundle( origins=cam_pos.view(1, 3), - directions=((cam_pos - train_cam_pos) / (cam_pos - train_cam_pos).norm()).view(1, 3), + directions=( + (cam_pos - train_cam_pos) + / (cam_pos - train_cam_pos).norm() + ).view(1, 3), pixel_area=torch.tensor(1).view(1, 1), nears=torch.tensor(0.05).view(1, 1), fars=torch.tensor(100).view(1, 1), @@ -187,7 +174,9 @@ def _render_trajectory_video( ).to(pipeline.device) outputs = pipeline.model.get_outputs(bundle) - q = tf.SO3.from_matrix(train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True)).wxyz + q = tf.SO3.from_matrix( + train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True) + ).wxyz # calculate distance between two quaternions rot_dist = 1 - np.dot(q, cam_quat) ** 2 pos_dist = torch.norm(train_cam_pos - cam_pos) @@ -197,7 +186,10 @@ def _render_trajectory_video( true_max_dist = dist true_max_idx = i - if outputs["depth"][0] < torch.norm(cam_pos - train_cam_pos).item(): + if ( + outputs["depth"][0] + < torch.norm(cam_pos - train_cam_pos).item() + ): continue if check_occlusions and (max_dist == -1 or dist < max_dist): @@ -407,7 +399,11 @@ class CropData: background_color: Float[Tensor, "3"] = torch.Tensor([0.0, 0.0, 0.0]) """background color""" - obb: OrientedBox = field(default_factory=lambda: OrientedBox(R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2)) + obb: OrientedBox = field( + default_factory=lambda: OrientedBox( + R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2 + ) + ) """Oriented box representing the crop region""" # properties for backwards-compatibility interface @@ -433,12 +429,18 @@ def get_crop_from_json(camera_json: Dict[str, Any]) -> Optional[CropData]: bg_color = camera_json["crop"]["crop_bg_color"] center = camera_json["crop"]["crop_center"] scale = camera_json["crop"]["crop_scale"] - rot = (0.0, 0.0, 0.0) if "crop_rot" not in 
camera_json["crop"] else tuple(camera_json["crop"]["crop_rot"]) + rot = ( + (0.0, 0.0, 0.0) + if "crop_rot" not in camera_json["crop"] + else tuple(camera_json["crop"]["crop_rot"]) + ) assert len(center) == 3 assert len(scale) == 3 assert len(rot) == 3 return CropData( - background_color=torch.Tensor([bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0]), + background_color=torch.Tensor( + [bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0] + ), obb=OrientedBox.from_params(center, rot, scale), ) @@ -512,7 +514,9 @@ def main(self) -> None: or camera_path.camera_type[0] == CameraType.VR180_L.value ): # temp folder for writing left and right view renders - temp_folder_path = self.output_path.parent / (self.output_path.stem + "_temp") + temp_folder_path = self.output_path.parent / ( + self.output_path.stem + "_temp" + ) Path(temp_folder_path).mkdir(parents=True, exist_ok=True) left_eye_path = temp_folder_path / "render_left.mp4" @@ -520,7 +524,9 @@ def main(self) -> None: self.output_path = left_eye_path if camera_path.camera_type[0] == CameraType.OMNIDIRECTIONALSTEREO_L.value: - CONSOLE.print("[bold green]:goggles: Omni-directional Stereo VR :goggles:") + CONSOLE.print( + "[bold green]:goggles: Omni-directional Stereo VR :goggles:" + ) else: CONSOLE.print("[bold green]:goggles: VR180 :goggles:") @@ -810,25 +816,39 @@ def update_config(config: TrainerConfig) -> TrainerConfig: update_config_callback=update_config, ) data_manager_config = config.pipeline.datamanager - assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) + assert isinstance( + data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig) + ) for split in self.split.split("+"): datamanager: VanillaDataManager dataset: Dataset if split == "train": - with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access - datamanager = data_manager_config.setup(test_mode="test", device=pipeline.device) + with _disable_datamanager_setup( + data_manager_config._target + ): # pylint: disable=protected-access + datamanager = data_manager_config.setup( + test_mode="test", device=pipeline.device + ) dataset = datamanager.train_dataset - dataparser_outputs = getattr(dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs) + dataparser_outputs = getattr( + dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs + ) else: - with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access - datamanager = data_manager_config.setup(test_mode=split, device=pipeline.device) + with _disable_datamanager_setup( + data_manager_config._target + ): # pylint: disable=protected-access + datamanager = data_manager_config.setup( + test_mode=split, device=pipeline.device + ) dataset = datamanager.eval_dataset dataparser_outputs = getattr(dataset, "_dataparser_outputs", None) if dataparser_outputs is None: - dataparser_outputs = datamanager.dataparser.get_dataparser_outputs(split=datamanager.test_split) + dataparser_outputs = datamanager.dataparser.get_dataparser_outputs( + split=datamanager.test_split + ) dataloader = FixedIndicesEvalDataloader( input_dataset=dataset, device=datamanager.device, @@ -846,7 +866,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: TimeRemainingColumn(elapsed_when_finished=False, compact=False), TimeElapsedColumn(), ) as progress: - for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))): + for camera_idx, 
(camera, batch) in enumerate( + progress.track(dataloader, total=len(dataset)) + ): with torch.no_grad(): outputs = pipeline.model.get_outputs_for_camera(camera) @@ -879,9 +901,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig: image_name = f"{camera_idx:05d}" # Try to get the original filename - image_name = dataparser_outputs.image_filenames[camera_idx].relative_to(images_root) + image_name = dataparser_outputs.image_filenames[ + camera_idx + ].relative_to(images_root) - output_path = self.output_path / split / rendered_output_name / image_name + output_path = ( + self.output_path / split / rendered_output_name / image_name + ) output_path.parent.mkdir(exist_ok=True, parents=True) output_name = rendered_output_name @@ -895,7 +921,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: output_image = outputs[output_name] if is_depth: # Divide by the dataparser scale factor - output_image.div_(dataparser_outputs.dataparser_scale) + output_image.div_( + dataparser_outputs.dataparser_scale + ) else: if output_name.startswith("gt-"): output_name = output_name[3:] @@ -931,10 +959,14 @@ def update_config(config: TrainerConfig) -> TrainerConfig: # Save to file if is_raw: - with gzip.open(output_path.with_suffix(".npy.gz"), "wb") as f: + with gzip.open( + output_path.with_suffix(".npy.gz"), "wb" + ) as f: np.save(f, output_image) elif self.image_format == "png": - media.write_image(output_path.with_suffix(".png"), output_image, fmt="png") + media.write_image( + output_path.with_suffix(".png"), output_image, fmt="png" + ) elif self.image_format == "jpeg": media.write_image( output_path.with_suffix(".jpg"), @@ -943,7 +975,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: quality=self.jpeg_quality, ) else: - raise ValueError(f"Unknown image format {self.image_format}") + raise ValueError( + f"Unknown image format {self.image_format}" + ) table = Table( title=None, diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 38257f5bf8..63751d7f27 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -18,7 +18,7 @@ import viser import viser.transforms as vtf -from typing_extensions import Literal, List +from typing_extensions import Literal from nerfstudio.data.scene_box import OrientedBox from nerfstudio.models.base_model import Model From 12c454a07403a60ce6c172208aa13424f86e8ef2 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Mon, 26 Aug 2024 20:43:45 -0700 Subject: [PATCH 24/33] RUFF FORMAT --- nerfstudio/scripts/exporter.py | 106 ++++++++------------------------- nerfstudio/scripts/render.py | 98 +++++++----------------------- 2 files changed, 47 insertions(+), 157 deletions(-) diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index 84ee4c5b68..c359efcaf2 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -61,9 +61,7 @@ class Exporter: """Set to True when export is finished.""" -def validate_pipeline( - normal_method: str, normal_output_name: str, pipeline: Pipeline -) -> None: +def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None: """Check that the pipeline is valid for this exporter. Args: @@ -85,9 +83,7 @@ def validate_pipeline( ) outputs = pipeline.model(ray_bundle) if normal_output_name not in outputs: - CONSOLE.print( - f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs." 
- ) + CONSOLE.print(f"[bold yellow]Warning: Normal output '{normal_output_name}' not found in pipeline outputs.") CONSOLE.print(f"Available outputs: {list(outputs.keys())}") CONSOLE.print( "[bold yellow]Warning: Please train a model with normals " @@ -152,21 +148,13 @@ def main(self) -> None: ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( - self.num_rays_per_batch - ) + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch # Whether the normals should be estimated based on the point cloud. estimate_normals = self.normal_method == "open3d" crop_obb = None - if ( - self.obb_center is not None - and self.obb_rotation is not None - and self.obb_scale is not None - ): - crop_obb = OrientedBox.from_params( - self.obb_center, self.obb_rotation, self.obb_scale - ) + if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: + crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) pcd = generate_point_cloud( pipeline=pipeline, num_points=self.num_points, @@ -175,20 +163,14 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=( - self.normal_output_name - if self.normal_method == "model_output" - else None - ), + normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None), crop_obb=crop_obb, std_ratio=self.std_ratio, ) if self.save_world_frame: # apply the inverse dataparser transform to the point cloud points = np.asarray(pcd.points) - poses = np.eye(4, dtype=np.float32)[None, ...].repeat( - points.shape[0], axis=0 - )[:, :3, :] + poses = np.eye(4, dtype=np.float32)[None, ...].repeat(points.shape[0], axis=0)[:, :3, :] poses[:, :3, 3] = points poses = pipeline.datamanager.train_dataparser_outputs.transform_poses_to_original_space( torch.from_numpy(poses) @@ -279,9 +261,7 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=( - self.px_per_uv_triangle if self.unwrap_method == "custom" else None - ), + px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -359,20 +339,12 @@ def main(self) -> None: ), ) assert pipeline.datamanager.train_pixel_sampler is not None - pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = ( - self.num_rays_per_batch - ) + pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch # Whether the normals should be estimated based on the point cloud. 
estimate_normals = self.normal_method == "open3d" - if ( - self.obb_center is not None - and self.obb_rotation is not None - and self.obb_scale is not None - ): - crop_obb = OrientedBox.from_params( - self.obb_center, self.obb_rotation, self.obb_scale - ) + if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: + crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) else: crop_obb = None @@ -384,11 +356,7 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=( - self.normal_output_name - if self.normal_method == "model_output" - else None - ), + normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None), crop_obb=crop_obb, std_ratio=self.std_ratio, ) @@ -402,9 +370,7 @@ def main(self) -> None: CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") CONSOLE.print("Computing Mesh... this may take a while.") - mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( - pcd, depth=9 - ) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9) vertices_to_remove = densities < np.quantile(densities, 0.1) mesh.remove_vertices_by_mask(vertices_to_remove) print("\033[A\033[A") @@ -428,9 +394,7 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=( - self.px_per_uv_triangle if self.unwrap_method == "custom" else None - ), + px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -469,15 +433,11 @@ def main(self) -> None: _, pipeline, _, _ = eval_setup(self.load_config) # TODO: Make this work with Density Field - assert hasattr( - pipeline.model.config, "sdf_field" - ), "Model must have an SDF field." + assert hasattr(pipeline.model.config, "sdf_field"), "Model must have an SDF field." CONSOLE.print("Extracting mesh with marching cubes... which may take a while") - assert ( - self.resolution % 512 == 0 - ), f"""resolution must be divisible by 512, got {self.resolution}. + assert self.resolution % 512 == 0, f"""resolution must be divisible by 512, got {self.resolution}. This is important because the algorithm uses a multi-resolution approach to evaluate the SDF where the minimum resolution is 512.""" @@ -496,17 +456,13 @@ def main(self) -> None: multi_res_mesh.export(filename) # load the mesh from the marching cubes export - mesh = get_mesh_from_filename( - str(filename), target_num_faces=self.target_num_faces - ) + mesh = get_mesh_from_filename(str(filename), target_num_faces=self.target_num_faces) CONSOLE.print("Texturing mesh with NeRF...") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=( - self.px_per_uv_triangle if self.unwrap_method == "custom" else None - ), + px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) @@ -534,9 +490,7 @@ def main(self) -> None: ("transforms_eval.json", eval_frames), ]: if len(frames) == 0: - CONSOLE.print( - f"[bold yellow]No frames found for {file_name}. Skipping." - ) + CONSOLE.print(f"[bold yellow]No frames found for {file_name}. 
Skipping.") continue output_file_path = os.path.join(self.output_dir, file_name) @@ -544,9 +498,7 @@ def main(self) -> None: with open(output_file_path, "w", encoding="UTF-8") as f: json.dump(frames, f, indent=4) - CONSOLE.print( - f"[bold green]:white_check_mark: Saved poses to {output_file_path}" - ) + CONSOLE.print(f"[bold green]:white_check_mark: Saved poses to {output_file_path}") @dataclass @@ -590,9 +542,7 @@ def write_ply( and tensor.size > 0 for tensor in map_to_tensors.values() ): - raise ValueError( - "All tensors must be numpy arrays of float or uint8 type and not empty" - ) + raise ValueError("All tensors must be numpy arrays of float or uint8 type and not empty") with open(filename, "wb") as ply_file: # Write PLY header @@ -668,14 +618,8 @@ def main(self) -> None: for i in range(4): map_to_tensors[f"rot_{i}"] = quats[:, i, None] - if ( - self.obb_center is not None - and self.obb_rotation is not None - and self.obb_scale is not None - ): - crop_obb = OrientedBox.from_params( - self.obb_center, self.obb_rotation, self.obb_scale - ) + if self.obb_center is not None and self.obb_rotation is not None and self.obb_scale is not None: + crop_obb = OrientedBox.from_params(self.obb_center, self.obb_rotation, self.obb_scale) assert crop_obb is not None mask = crop_obb.within(torch.from_numpy(positions)).numpy() for k, t in map_to_tensors.items(): @@ -695,9 +639,7 @@ def main(self) -> None: CONSOLE.print(f"{n_before - n_after} NaN/Inf elements in {k}") if np.sum(select) < n: - CONSOLE.print( - f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}" - ) + CONSOLE.print(f"values have NaN/Inf in map_to_tensors, only export {np.sum(select)}/{n}") for k, t in map_to_tensors.items(): map_to_tensors[k] = map_to_tensors[k][select] count = np.sum(select) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 9668e01711..d71bb7f238 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -153,19 +153,14 @@ def _render_trajectory_video( assert train_dataset is not None assert train_cameras is not None cam_pos = cameras[camera_idx].camera_to_worlds[:, 3].cpu() - cam_quat = tf.SO3.from_matrix( - cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True) - ).wxyz + cam_quat = tf.SO3.from_matrix(cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True)).wxyz for i in range(len(train_cameras)): train_cam_pos = train_cameras[i].camera_to_worlds[:, 3].cpu() # Make sure the line of sight from rendered cam to training cam is not blocked by any object bundle = RayBundle( origins=cam_pos.view(1, 3), - directions=( - (cam_pos - train_cam_pos) - / (cam_pos - train_cam_pos).norm() - ).view(1, 3), + directions=((cam_pos - train_cam_pos) / (cam_pos - train_cam_pos).norm()).view(1, 3), pixel_area=torch.tensor(1).view(1, 1), nears=torch.tensor(0.05).view(1, 1), fars=torch.tensor(100).view(1, 1), @@ -174,9 +169,7 @@ def _render_trajectory_video( ).to(pipeline.device) outputs = pipeline.model.get_outputs(bundle) - q = tf.SO3.from_matrix( - train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True) - ).wxyz + q = tf.SO3.from_matrix(train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True)).wxyz # calculate distance between two quaternions rot_dist = 1 - np.dot(q, cam_quat) ** 2 pos_dist = torch.norm(train_cam_pos - cam_pos) @@ -186,10 +179,7 @@ def _render_trajectory_video( true_max_dist = dist true_max_idx = i - if ( - outputs["depth"][0] - < torch.norm(cam_pos - train_cam_pos).item() - ): + if outputs["depth"][0] < torch.norm(cam_pos - 
train_cam_pos).item(): continue if check_occlusions and (max_dist == -1 or dist < max_dist): @@ -399,11 +389,7 @@ class CropData: background_color: Float[Tensor, "3"] = torch.Tensor([0.0, 0.0, 0.0]) """background color""" - obb: OrientedBox = field( - default_factory=lambda: OrientedBox( - R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2 - ) - ) + obb: OrientedBox = field(default_factory=lambda: OrientedBox(R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2)) """Oriented box representing the crop region""" # properties for backwards-compatibility interface @@ -429,18 +415,12 @@ def get_crop_from_json(camera_json: Dict[str, Any]) -> Optional[CropData]: bg_color = camera_json["crop"]["crop_bg_color"] center = camera_json["crop"]["crop_center"] scale = camera_json["crop"]["crop_scale"] - rot = ( - (0.0, 0.0, 0.0) - if "crop_rot" not in camera_json["crop"] - else tuple(camera_json["crop"]["crop_rot"]) - ) + rot = (0.0, 0.0, 0.0) if "crop_rot" not in camera_json["crop"] else tuple(camera_json["crop"]["crop_rot"]) assert len(center) == 3 assert len(scale) == 3 assert len(rot) == 3 return CropData( - background_color=torch.Tensor( - [bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0] - ), + background_color=torch.Tensor([bg_color["r"] / 255.0, bg_color["g"] / 255.0, bg_color["b"] / 255.0]), obb=OrientedBox.from_params(center, rot, scale), ) @@ -514,9 +494,7 @@ def main(self) -> None: or camera_path.camera_type[0] == CameraType.VR180_L.value ): # temp folder for writing left and right view renders - temp_folder_path = self.output_path.parent / ( - self.output_path.stem + "_temp" - ) + temp_folder_path = self.output_path.parent / (self.output_path.stem + "_temp") Path(temp_folder_path).mkdir(parents=True, exist_ok=True) left_eye_path = temp_folder_path / "render_left.mp4" @@ -524,9 +502,7 @@ def main(self) -> None: self.output_path = left_eye_path if camera_path.camera_type[0] == CameraType.OMNIDIRECTIONALSTEREO_L.value: - CONSOLE.print( - "[bold green]:goggles: Omni-directional Stereo VR :goggles:" - ) + CONSOLE.print("[bold green]:goggles: Omni-directional Stereo VR :goggles:") else: CONSOLE.print("[bold green]:goggles: VR180 :goggles:") @@ -816,39 +792,25 @@ def update_config(config: TrainerConfig) -> TrainerConfig: update_config_callback=update_config, ) data_manager_config = config.pipeline.datamanager - assert isinstance( - data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig) - ) + assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) for split in self.split.split("+"): datamanager: VanillaDataManager dataset: Dataset if split == "train": - with _disable_datamanager_setup( - data_manager_config._target - ): # pylint: disable=protected-access - datamanager = data_manager_config.setup( - test_mode="test", device=pipeline.device - ) + with _disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access + datamanager = data_manager_config.setup(test_mode="test", device=pipeline.device) dataset = datamanager.train_dataset - dataparser_outputs = getattr( - dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs - ) + dataparser_outputs = getattr(dataset, "_dataparser_outputs", datamanager.train_dataparser_outputs) else: - with _disable_datamanager_setup( - data_manager_config._target - ): # pylint: disable=protected-access - datamanager = data_manager_config.setup( - test_mode=split, device=pipeline.device - ) + with 
_disable_datamanager_setup(data_manager_config._target): # pylint: disable=protected-access + datamanager = data_manager_config.setup(test_mode=split, device=pipeline.device) dataset = datamanager.eval_dataset dataparser_outputs = getattr(dataset, "_dataparser_outputs", None) if dataparser_outputs is None: - dataparser_outputs = datamanager.dataparser.get_dataparser_outputs( - split=datamanager.test_split - ) + dataparser_outputs = datamanager.dataparser.get_dataparser_outputs(split=datamanager.test_split) dataloader = FixedIndicesEvalDataloader( input_dataset=dataset, device=datamanager.device, @@ -866,9 +828,7 @@ def update_config(config: TrainerConfig) -> TrainerConfig: TimeRemainingColumn(elapsed_when_finished=False, compact=False), TimeElapsedColumn(), ) as progress: - for camera_idx, (camera, batch) in enumerate( - progress.track(dataloader, total=len(dataset)) - ): + for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))): with torch.no_grad(): outputs = pipeline.model.get_outputs_for_camera(camera) @@ -901,13 +861,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig: image_name = f"{camera_idx:05d}" # Try to get the original filename - image_name = dataparser_outputs.image_filenames[ - camera_idx - ].relative_to(images_root) + image_name = dataparser_outputs.image_filenames[camera_idx].relative_to(images_root) - output_path = ( - self.output_path / split / rendered_output_name / image_name - ) + output_path = self.output_path / split / rendered_output_name / image_name output_path.parent.mkdir(exist_ok=True, parents=True) output_name = rendered_output_name @@ -921,9 +877,7 @@ def update_config(config: TrainerConfig) -> TrainerConfig: output_image = outputs[output_name] if is_depth: # Divide by the dataparser scale factor - output_image.div_( - dataparser_outputs.dataparser_scale - ) + output_image.div_(dataparser_outputs.dataparser_scale) else: if output_name.startswith("gt-"): output_name = output_name[3:] @@ -959,14 +913,10 @@ def update_config(config: TrainerConfig) -> TrainerConfig: # Save to file if is_raw: - with gzip.open( - output_path.with_suffix(".npy.gz"), "wb" - ) as f: + with gzip.open(output_path.with_suffix(".npy.gz"), "wb") as f: np.save(f, output_image) elif self.image_format == "png": - media.write_image( - output_path.with_suffix(".png"), output_image, fmt="png" - ) + media.write_image(output_path.with_suffix(".png"), output_image, fmt="png") elif self.image_format == "jpeg": media.write_image( output_path.with_suffix(".jpg"), @@ -975,9 +925,7 @@ def update_config(config: TrainerConfig) -> TrainerConfig: quality=self.jpeg_quality, ) else: - raise ValueError( - f"Unknown image format {self.image_format}" - ) + raise ValueError(f"Unknown image format {self.image_format}") table = Table( title=None, From e1ac0e56cd28ce5897d6a6edc8e5dc71c56aa012 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Wed, 28 Aug 2024 14:45:39 -0700 Subject: [PATCH 25/33] some pyright type fixes --- nerfstudio/scripts/render.py | 2 +- nerfstudio/viewer/export_panel.py | 2 +- nerfstudio/viewer/render_panel.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index d71bb7f238..94a3c64d0a 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -455,7 +455,7 @@ class BaseRender: """If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other""" camera_idx: Optional[int] = None """Index of the 
training camera to render.""" - kill_flag: Optional[List[bool]] = field(default_factory=lambda: [False]) + kill_flag: List[bool] = field(default_factory=lambda: [False]) """Stop execution of render if set to True.""" def kill(self) -> None: diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 63751d7f27..abf25ea8b3 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -40,7 +40,7 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value - server.add_gui_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") + server.gui.add_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.gui.add_folder("Point Cloud"): diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 2b181a4afb..e615b2ab92 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1067,7 +1067,7 @@ def _(_) -> None: disabled=True, ) - def _write_json() -> json: + def _write_json() -> Path: num_frames = int(framerate_number.value * duration_number.value) json_data = {} From 2308450d3922c34e059b8703331bfb54a1243076 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 29 Aug 2024 11:50:15 -0700 Subject: [PATCH 26/33] pyright --- nerfstudio/viewer/export_panel.py | 46 +++++++++++++++++++------------ nerfstudio/viewer/render_panel.py | 2 +- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index abf25ea8b3..a66d1588c5 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -73,7 +73,6 @@ def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point def _(_) -> None: modal.close() - def get_crop_string(obb: OrientedBox, crop_viewport: bool): """Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale} and each arg formatted with spaces around it @@ -87,8 +86,19 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool): rpystring = " ".join([f"{x:.10f}" for x in rpy]) posstring = " ".join([f"{x:.10f}" for x in pos]) scalestring = " ".join([f"{x:.10f}" for x in scale]) - return [posstring, rpystring, scalestring] + return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" +def get_crop_tuple(obb: OrientedBox, crop_viewport: bool): + """Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale} + and each arg formatted with spaces around it + """ + if not crop_viewport: + return "" + rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians() + obb_rotation = [rpy.roll, rpy.pitch, rpy.yaw] + obb_center = obb.T.squeeze().tolist() + obb_scale = obb.S.squeeze().tolist() + return obb_center, obb_rotation, obb_scale def populate_point_cloud_tab( server: viser.ViserServer, @@ -134,9 +144,9 @@ def _(event: viser.GuiEvent) -> None: ) if control_panel.crop_obb is not None and control_panel.crop_viewport: - posstring, rpystring, scalestring = get_crop_string(control_panel.crop_obb, control_panel.crop_viewport) + obb_center, obb_rotation, obb_scale = get_crop_tuple(control_panel.crop_obb, control_panel.crop_viewport) else: - posstring = rpystring = scalestring = None + obb_center, obb_rotation, obb_scale = None from nerfstudio.scripts.exporter 
import ExportPointCloud @@ -147,9 +157,9 @@ def _(event: viser.GuiEvent) -> None: remove_outliers=remove_outliers.value, normal_method=normals.value, save_world_frame=world_frame.value, - obb_center=posstring, - obb_rotation=rpystring, - obb_scale=scalestring, + obb_center=obb_center, + obb_rotation=obb_rotation, + obb_scale=obb_scale, ) export.main() @@ -232,9 +242,9 @@ def _(event: viser.GuiEvent) -> None: ) if control_panel.crop_obb is not None and control_panel.crop_viewport: - posstring, rpystring, scalestring = get_crop_string(control_panel.crop_obb, control_panel.crop_viewport) + obb_center, obb_rotation, obb_scale = get_crop_tuple(control_panel.crop_obb, control_panel.crop_viewport) else: - posstring = rpystring = scalestring = None + obb_center, obb_rotation, obb_scale = None from nerfstudio.scripts.exporter import ExportPoissonMesh @@ -246,9 +256,9 @@ def _(event: viser.GuiEvent) -> None: num_points=num_points.value, remove_outliers=remove_outliers.value, normal_method=normals.value, - obb_center=posstring, - obb_rotation=rpystring, - obb_scale=scalestring, + obb_center=obb_center, + obb_rotation=obb_rotation, + obb_scale=obb_scale, ) export.main() @@ -319,18 +329,18 @@ def _(event: viser.GuiEvent) -> None: notif.show() if control_panel.crop_obb is not None and control_panel.crop_viewport: - posstring, rpystring, scalestring = get_crop_string(control_panel.crop_obb, control_panel.crop_viewport) + obb_center, obb_rotation, obb_scale = get_crop_tuple(control_panel.crop_obb, control_panel.crop_viewport) else: - posstring = rpystring = scalestring = None + obb_center, obb_rotation, obb_scale = None from nerfstudio.scripts.exporter import ExportGaussianSplat export = ExportGaussianSplat( load_config=config_path, output_dir=Path(output_dir.value), - obb_center=posstring, - obb_rotation=rpystring, - obb_scale=scalestring, + obb_center=obb_center, + obb_rotation=obb_rotation, + obb_scale=obb_scale, ) export.main() @@ -359,7 +369,7 @@ def _(event: viser.GuiEvent) -> None: [ "ns-export gaussian-splat", f"--load-config {config_path}", - f"--output-dir {output_directory.value}", + f"--output-dir {output_dir.value}", get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index e615b2ab92..ad444c3775 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1240,7 +1240,7 @@ def _(event: viser.GuiEvent) -> None: @generate_command_render_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client - assert client is not None + assert client is not None, client.gui is not None json_outfile = _write_json() From 58a4c3ad05bf66503c5ff2b82c01ff18c8d62303 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 29 Aug 2024 15:30:30 -0700 Subject: [PATCH 27/33] pyright --- nerfstudio/viewer/export_panel.py | 127 +++++++++++++++++++++--------- nerfstudio/viewer/render_panel.py | 21 ++--- 2 files changed, 102 insertions(+), 46 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index a66d1588c5..be9e0b7968 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -18,8 +18,8 @@ import viser import viser.transforms as vtf -from typing_extensions import Literal - +from typing_extensions import Literal, Tuple, List +from typing import cast from nerfstudio.data.scene_box import OrientedBox from nerfstudio.models.base_model import Model from nerfstudio.models.splatfacto import SplatfactoModel 
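
Later hunks in this patch split the crop handling into two helpers with different consumers: `get_crop_string` builds the `--obb_center/--obb_rotation/--obb_scale` flags shown in the "Generate Command" modal, while `get_crop_tuple` hands the same box to the exporter dataclasses as plain tuples. A self-contained sketch of both conversions, using only calls that already appear in this series (the identity box is just an example value):

```python
# Example values only; the conversions mirror get_crop_string / get_crop_tuple below.
import torch
import viser.transforms as vtf

from nerfstudio.data.scene_box import OrientedBox

obb = OrientedBox(R=torch.eye(3), T=torch.zeros(3), S=torch.ones(3) * 2)
rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians()

# CLI form, pasted into the generated ns-export command:
center_s = " ".join(f"{x:.10f}" for x in obb.T.squeeze().tolist())
rot_s = " ".join(f"{x:.10f}" for x in (rpy.roll, rpy.pitch, rpy.yaw))
scale_s = " ".join(f"{x:.10f}" for x in obb.S.squeeze().tolist())
flags = f"--obb_center {center_s} --obb_rotation {rot_s} --obb_scale {scale_s}"

# In-process form, passed straight to ExportPointCloud / ExportPoissonMesh / ExportGaussianSplat:
obb_center = tuple(obb.T.squeeze().tolist())
obb_rotation = (rpy.roll, rpy.pitch, rpy.yaw)
obb_scale = tuple(obb.S.squeeze().tolist())
```
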
@@ -40,7 +40,9 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value - server.gui.add_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") + server.gui.add_markdown( + "Export available after a checkpoint is saved (default minimum 2000 steps)" + ) with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.gui.add_folder("Point Cloud"): @@ -49,7 +51,11 @@ def _(_) -> None: populate_mesh_tab(server, control_panel, config_path, viewing_gsplat) -def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point cloud", "splat"], command: str) -> None: +def show_command_modal( + client: viser.ClientHandle, + what: Literal["mesh", "point cloud", "splat"], + command: str, +) -> None: """Show a modal to each currently connected client. In the future, we should only show the modal to the client that pushes the @@ -73,7 +79,8 @@ def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point def _(_) -> None: modal.close() -def get_crop_string(obb: OrientedBox, crop_viewport: bool): + +def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> str: """Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale} and each arg formatted with spaces around it """ @@ -86,19 +93,24 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool): rpystring = " ".join([f"{x:.10f}" for x in rpy]) posstring = " ".join([f"{x:.10f}" for x in pos]) scalestring = " ".join([f"{x:.10f}" for x in scale]) - return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" + return ( + f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" + ) + + +Vec3f = Tuple[float, float, float] def get_crop_tuple(obb: OrientedBox, crop_viewport: bool): - """Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale} - and each arg formatted with spaces around it + """Takes in an oriented bounding box and returns tuples for obb_{center,rotation,scale}. 
""" if not crop_viewport: - return "" + return None, None, None rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians() obb_rotation = [rpy.roll, rpy.pitch, rpy.yaw] obb_center = obb.T.squeeze().tolist() obb_scale = obb.S.squeeze().tolist() - return obb_center, obb_rotation, obb_scale + return cast(Vec3f, tuple(obb_center)), cast(Vec3f, tuple(obb_rotation)), cast(Vec3f, tuple(obb_scale)) + def populate_point_cloud_tab( server: viser.ViserServer, @@ -107,8 +119,12 @@ def populate_point_cloud_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.gui.add_markdown("Render depth, project to an oriented point cloud, and filter ") - num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + server.gui.add_markdown( + "Render depth, project to an oriented point cloud, and filter " + ) + num_points = server.gui.add_number( + "# Points", initial_value=1_000_000, min=1, max=None, step=1 + ) world_frame = server.gui.add_checkbox( "Save in world frame", False, @@ -126,10 +142,16 @@ def populate_point_cloud_tab( hint="Normal map source.", ) - output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") + output_dir = server.gui.add_text( + "Output Directory", initial_value="exports/pcd/" + ) export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True) - generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button( + "Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True + ) + generate_command = server.gui.add_button( + "Generate Command", icon=viser.Icon.TERMINAL_2 + ) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -144,9 +166,11 @@ def _(event: viser.GuiEvent) -> None: ) if control_panel.crop_obb is not None and control_panel.crop_viewport: - obb_center, obb_rotation, obb_scale = get_crop_tuple(control_panel.crop_obb, control_panel.crop_viewport) + obb_center, obb_rotation, obb_scale = get_crop_tuple( + control_panel.crop_obb, control_panel.crop_viewport + ) else: - obb_center, obb_rotation, obb_scale = None + obb_center, obb_rotation, obb_scale = None, None, None from nerfstudio.scripts.exporter import ExportPointCloud @@ -193,13 +217,17 @@ def _(event: viser.GuiEvent) -> None: f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", f"--save-world-frame {world_frame.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ), ] ) show_command_modal(event.client, "point cloud", command) else: - server.gui.add_markdown("Point cloud export is not currently supported with Gaussian Splatting") + server.gui.add_markdown( + "Point cloud export is not currently supported with Gaussian Splatting" + ) def populate_mesh_tab( @@ -220,14 +248,24 @@ def populate_mesh_tab( hint="Source for normal maps.", ) num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) - texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) - num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + texture_resolution = server.gui.add_number( + "Texture Resolution", min=8, initial_value=2048 + ) + num_points = server.gui.add_number( + "# Points", initial_value=1_000_000, min=1, max=None, step=1 + ) remove_outliers = 
server.gui.add_checkbox("Remove outliers", True) - output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") + output_dir = server.gui.add_text( + "Output Directory", initial_value="exports/mesh/" + ) export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True) - generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button( + "Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True + ) + generate_command = server.gui.add_button( + "Generate Command", icon=viser.Icon.TERMINAL_2 + ) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -242,9 +280,11 @@ def _(event: viser.GuiEvent) -> None: ) if control_panel.crop_obb is not None and control_panel.crop_viewport: - obb_center, obb_rotation, obb_scale = get_crop_tuple(control_panel.crop_obb, control_panel.crop_viewport) + obb_center, obb_rotation, obb_scale = get_crop_tuple( + control_panel.crop_obb, control_panel.crop_viewport + ) else: - obb_center, obb_rotation, obb_scale = None + obb_center, obb_rotation, obb_scale = None, None, None from nerfstudio.scripts.exporter import ExportPoissonMesh @@ -293,13 +333,17 @@ def _(event: viser.GuiEvent) -> None: f"--num-points {num_points.value}", f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ), ] ) show_command_modal(event.client, "mesh", command) else: - server.gui.add_markdown("Mesh export is not currently supported with Gaussian Splatting") + server.gui.add_markdown( + "Mesh export is not currently supported with Gaussian Splatting" + ) def populate_splat_tab( @@ -310,10 +354,16 @@ def populate_splat_tab( ) -> None: if viewing_gsplat: server.gui.add_markdown("Generate ply export of Gaussian Splat") - output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/") + output_dir = server.gui.add_text( + "Output Directory", initial_value="exports/splat/" + ) export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True) - generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button( + "Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True + ) + generate_command = server.gui.add_button( + "Generate Command", icon=viser.Icon.TERMINAL_2 + ) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -326,12 +376,13 @@ def _(event: viser.GuiEvent) -> None: loading=True, with_close_button=False, ) - notif.show() if control_panel.crop_obb is not None and control_panel.crop_viewport: - obb_center, obb_rotation, obb_scale = get_crop_tuple(control_panel.crop_obb, control_panel.crop_viewport) + obb_center, obb_rotation, obb_scale = get_crop_tuple( + control_panel.crop_obb, control_panel.crop_viewport + ) else: - obb_center, obb_rotation, obb_scale = None + obb_center, obb_rotation, obb_scale = None, None, None from nerfstudio.scripts.exporter import ExportGaussianSplat @@ -370,10 +421,14 @@ def _(event: viser.GuiEvent) -> None: "ns-export gaussian-splat", f"--load-config {config_path}", f"--output-dir {output_dir.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + 
get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ), ] ) show_command_modal(event.client, "splat", command) else: - server.gui.add_markdown("Splat export is only supported with Gaussian Splatting methods") + server.gui.add_markdown( + "Splat export is only supported with Gaussian Splatting methods" + ) diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index ad444c3775..d288cc54b1 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -1132,7 +1132,7 @@ def _write_json() -> Path: for i in range(num_frames): maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad(i / num_frames) if maybe_pose_and_fov is None: - return + return Path() time = None if len(maybe_pose_and_fov) == 3: # Time is enabled. pose, fov, time = maybe_pose_and_fov @@ -1175,7 +1175,7 @@ def _write_json() -> Path: with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) - return json_outfile + return json_outfile.absolute() @render_button.on_click def _(event: viser.GuiEvent) -> None: @@ -1199,7 +1199,7 @@ def _(event: viser.GuiEvent) -> None: render = RenderCameraPath( load_config=config_path, - camera_path_filename=json_outfile.absolute(), + camera_path_filename=json_outfile, output_path=Path(render_path), ) @@ -1240,21 +1240,21 @@ def _(event: viser.GuiEvent) -> None: @generate_command_render_button.on_click def _(event: viser.GuiEvent) -> None: client = event.client - assert client is not None, client.gui is not None + assert client is not None json_outfile = _write_json() - with event.client.gui.add_modal("Render Command") as modal: + with client.gui.add_modal("Render Command") as modal: dataname = datapath.name command = " ".join( [ "ns-render camera-path", f"--load-config {config_path}", - f"--camera-path-filename {json_outfile.absolute()}", + f"--camera-path-filename {json_outfile}", f"--output-path renders/{dataname}/{render_name_text.value}.mp4", ] ) - event.client.gui.add_markdown( + client.gui.add_markdown( "\n".join( [ "To render the trajectory, run the following from the command line:", @@ -1265,7 +1265,7 @@ def _(event: viser.GuiEvent) -> None: ] ) ) - close_button = event.client.gui.add_button("Close") + close_button = client.gui.add_button("Close") @close_button.on_click def _(_) -> None: @@ -1280,8 +1280,9 @@ def _(_) -> None: @reset_up_button.on_click def _(event: viser.GuiEvent) -> None: - assert event.client is not None - event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) + client = event.client + assert client is not None + client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) From 8aa49c369963f294bbf6580b6fa0aaedbad73723 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 29 Aug 2024 16:05:21 -0700 Subject: [PATCH 28/33] ruff format --- nerfstudio/viewer/export_panel.py | 88 ++++++++----------------------- 1 file changed, 23 insertions(+), 65 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index be9e0b7968..516e72425f 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -40,9 +40,7 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value - server.gui.add_markdown( - "Export available after a checkpoint is saved (default minimum 2000 steps)" - ) + server.gui.add_markdown("Export 
available after a checkpoint is saved (default minimum 2000 steps)") with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.gui.add_folder("Point Cloud"): @@ -93,16 +91,14 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> str: rpystring = " ".join([f"{x:.10f}" for x in rpy]) posstring = " ".join([f"{x:.10f}" for x in pos]) scalestring = " ".join([f"{x:.10f}" for x in scale]) - return ( - f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" - ) + return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" Vec3f = Tuple[float, float, float] + def get_crop_tuple(obb: OrientedBox, crop_viewport: bool): - """Takes in an oriented bounding box and returns tuples for obb_{center,rotation,scale}. - """ + """Takes in an oriented bounding box and returns tuples for obb_{center,rotation,scale}.""" if not crop_viewport: return None, None, None rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians() @@ -119,12 +115,8 @@ def populate_point_cloud_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.gui.add_markdown( - "Render depth, project to an oriented point cloud, and filter " - ) - num_points = server.gui.add_number( - "# Points", initial_value=1_000_000, min=1, max=None, step=1 - ) + server.gui.add_markdown("Render depth, project to an oriented point cloud, and filter ") + num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) world_frame = server.gui.add_checkbox( "Save in world frame", False, @@ -142,16 +134,10 @@ def populate_point_cloud_tab( hint="Normal map source.", ) - output_dir = server.gui.add_text( - "Output Directory", initial_value="exports/pcd/" - ) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button( - "Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True - ) - generate_command = server.gui.add_button( - "Generate Command", icon=viser.Icon.TERMINAL_2 - ) + download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -217,17 +203,13 @@ def _(event: viser.GuiEvent) -> None: f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", f"--save-world-frame {world_frame.value}", - get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ), + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) show_command_modal(event.client, "point cloud", command) else: - server.gui.add_markdown( - "Point cloud export is not currently supported with Gaussian Splatting" - ) + server.gui.add_markdown("Point cloud export is not currently supported with Gaussian Splatting") def populate_mesh_tab( @@ -248,24 +230,14 @@ def populate_mesh_tab( hint="Source for normal maps.", ) num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) - texture_resolution = server.gui.add_number( - "Texture Resolution", min=8, initial_value=2048 - ) - num_points = server.gui.add_number( - "# Points", initial_value=1_000_000, min=1, max=None, step=1 - ) + texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) + num_points = server.gui.add_number("# Points", 
initial_value=1_000_000, min=1, max=None, step=1) remove_outliers = server.gui.add_checkbox("Remove outliers", True) - output_dir = server.gui.add_text( - "Output Directory", initial_value="exports/mesh/" - ) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button( - "Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True - ) - generate_command = server.gui.add_button( - "Generate Command", icon=viser.Icon.TERMINAL_2 - ) + download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -333,17 +305,13 @@ def _(event: viser.GuiEvent) -> None: f"--num-points {num_points.value}", f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", - get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ), + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) show_command_modal(event.client, "mesh", command) else: - server.gui.add_markdown( - "Mesh export is not currently supported with Gaussian Splatting" - ) + server.gui.add_markdown("Mesh export is not currently supported with Gaussian Splatting") def populate_splat_tab( @@ -354,16 +322,10 @@ def populate_splat_tab( ) -> None: if viewing_gsplat: server.gui.add_markdown("Generate ply export of Gaussian Splat") - output_dir = server.gui.add_text( - "Output Directory", initial_value="exports/splat/" - ) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button( - "Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True - ) - generate_command = server.gui.add_button( - "Generate Command", icon=viser.Icon.TERMINAL_2 - ) + download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -421,14 +383,10 @@ def _(event: viser.GuiEvent) -> None: "ns-export gaussian-splat", f"--load-config {config_path}", f"--output-dir {output_dir.value}", - get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ), + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) show_command_modal(event.client, "splat", command) else: - server.gui.add_markdown( - "Splat export is only supported with Gaussian Splatting methods" - ) + server.gui.add_markdown("Splat export is only supported with Gaussian Splatting methods") From 6b4afed0a9a623346a2b18d056bcc65464ba6817 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 29 Aug 2024 16:08:35 -0700 Subject: [PATCH 29/33] ruff format exports --- nerfstudio/viewer/export_panel.py | 95 +++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 24 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 516e72425f..91b61886cc 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -15,11 +15,12 @@ from __future__ import annotations from pathlib import Path +from typing import cast import viser import viser.transforms as vtf -from typing_extensions import Literal, Tuple, List -from typing import 
cast +from typing_extensions import Literal, Tuple + from nerfstudio.data.scene_box import OrientedBox from nerfstudio.models.base_model import Model from nerfstudio.models.splatfacto import SplatfactoModel @@ -40,7 +41,9 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value - server.gui.add_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") + server.gui.add_markdown( + "Export available after a checkpoint is saved (default minimum 2000 steps)" + ) with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.gui.add_folder("Point Cloud"): @@ -91,7 +94,9 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> str: rpystring = " ".join([f"{x:.10f}" for x in rpy]) posstring = " ".join([f"{x:.10f}" for x in pos]) scalestring = " ".join([f"{x:.10f}" for x in scale]) - return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" + return ( + f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" + ) Vec3f = Tuple[float, float, float] @@ -105,7 +110,11 @@ def get_crop_tuple(obb: OrientedBox, crop_viewport: bool): obb_rotation = [rpy.roll, rpy.pitch, rpy.yaw] obb_center = obb.T.squeeze().tolist() obb_scale = obb.S.squeeze().tolist() - return cast(Vec3f, tuple(obb_center)), cast(Vec3f, tuple(obb_rotation)), cast(Vec3f, tuple(obb_scale)) + return ( + cast(Vec3f, tuple(obb_center)), + cast(Vec3f, tuple(obb_rotation)), + cast(Vec3f, tuple(obb_scale)), + ) def populate_point_cloud_tab( @@ -115,8 +124,12 @@ def populate_point_cloud_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.gui.add_markdown("Render depth, project to an oriented point cloud, and filter ") - num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + server.gui.add_markdown( + "Render depth, project to an oriented point cloud, and filter " + ) + num_points = server.gui.add_number( + "# Points", initial_value=1_000_000, min=1, max=None, step=1 + ) world_frame = server.gui.add_checkbox( "Save in world frame", False, @@ -134,10 +147,16 @@ def populate_point_cloud_tab( hint="Normal map source.", ) - output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") + output_dir = server.gui.add_text( + "Output Directory", initial_value="exports/pcd/" + ) export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True) - generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button( + "Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True + ) + generate_command = server.gui.add_button( + "Generate Command", icon=viser.Icon.TERMINAL_2 + ) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -203,13 +222,17 @@ def _(event: viser.GuiEvent) -> None: f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", f"--save-world-frame {world_frame.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ), ] ) show_command_modal(event.client, "point cloud", command) else: - server.gui.add_markdown("Point cloud export is not currently supported with Gaussian Splatting") + server.gui.add_markdown( + "Point cloud export is not currently supported with Gaussian 
Splatting" + ) def populate_mesh_tab( @@ -230,14 +253,24 @@ def populate_mesh_tab( hint="Source for normal maps.", ) num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) - texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) - num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + texture_resolution = server.gui.add_number( + "Texture Resolution", min=8, initial_value=2048 + ) + num_points = server.gui.add_number( + "# Points", initial_value=1_000_000, min=1, max=None, step=1 + ) remove_outliers = server.gui.add_checkbox("Remove outliers", True) - output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") + output_dir = server.gui.add_text( + "Output Directory", initial_value="exports/mesh/" + ) export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True) - generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button( + "Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True + ) + generate_command = server.gui.add_button( + "Generate Command", icon=viser.Icon.TERMINAL_2 + ) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -305,13 +338,17 @@ def _(event: viser.GuiEvent) -> None: f"--num-points {num_points.value}", f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ), ] ) show_command_modal(event.client, "mesh", command) else: - server.gui.add_markdown("Mesh export is not currently supported with Gaussian Splatting") + server.gui.add_markdown( + "Mesh export is not currently supported with Gaussian Splatting" + ) def populate_splat_tab( @@ -322,10 +359,16 @@ def populate_splat_tab( ) -> None: if viewing_gsplat: server.gui.add_markdown("Generate ply export of Gaussian Splat") - output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/") + output_dir = server.gui.add_text( + "Output Directory", initial_value="exports/splat/" + ) export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True) - generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + download_button = server.gui.add_button( + "Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True + ) + generate_command = server.gui.add_button( + "Generate Command", icon=viser.Icon.TERMINAL_2 + ) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -383,10 +426,14 @@ def _(event: viser.GuiEvent) -> None: "ns-export gaussian-splat", f"--load-config {config_path}", f"--output-dir {output_dir.value}", - get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), + get_crop_string( + control_panel.crop_obb, control_panel.crop_viewport + ), ] ) show_command_modal(event.client, "splat", command) else: - server.gui.add_markdown("Splat export is only supported with Gaussian Splatting methods") + server.gui.add_markdown( + "Splat export is only supported with Gaussian Splatting methods" + ) From 5f4ba8dd1021cae75426459e91f3704822622be7 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Thu, 29 Aug 2024 16:13:28 -0700 Subject: [PATCH 30/33] ruff again --- 
nerfstudio/viewer/export_panel.py | 84 ++++++++----------------------- 1 file changed, 21 insertions(+), 63 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 91b61886cc..ffe1c7a097 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -41,9 +41,7 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value - server.gui.add_markdown( - "Export available after a checkpoint is saved (default minimum 2000 steps)" - ) + server.gui.add_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.gui.add_folder("Point Cloud"): @@ -94,9 +92,7 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> str: rpystring = " ".join([f"{x:.10f}" for x in rpy]) posstring = " ".join([f"{x:.10f}" for x in pos]) scalestring = " ".join([f"{x:.10f}" for x in scale]) - return ( - f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" - ) + return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}" Vec3f = Tuple[float, float, float] @@ -124,12 +120,8 @@ def populate_point_cloud_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.gui.add_markdown( - "Render depth, project to an oriented point cloud, and filter " - ) - num_points = server.gui.add_number( - "# Points", initial_value=1_000_000, min=1, max=None, step=1 - ) + server.gui.add_markdown("Render depth, project to an oriented point cloud, and filter ") + num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) world_frame = server.gui.add_checkbox( "Save in world frame", False, @@ -147,16 +139,10 @@ def populate_point_cloud_tab( hint="Normal map source.", ) - output_dir = server.gui.add_text( - "Output Directory", initial_value="exports/pcd/" - ) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button( - "Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True - ) - generate_command = server.gui.add_button( - "Generate Command", icon=viser.Icon.TERMINAL_2 - ) + download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -222,17 +208,13 @@ def _(event: viser.GuiEvent) -> None: f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", f"--save-world-frame {world_frame.value}", - get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ), + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) show_command_modal(event.client, "point cloud", command) else: - server.gui.add_markdown( - "Point cloud export is not currently supported with Gaussian Splatting" - ) + server.gui.add_markdown("Point cloud export is not currently supported with Gaussian Splatting") def populate_mesh_tab( @@ -253,24 +235,14 @@ def populate_mesh_tab( hint="Source for normal maps.", ) num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) - texture_resolution = server.gui.add_number( - "Texture Resolution", min=8, initial_value=2048 - ) - num_points = server.gui.add_number( - "# 
Points", initial_value=1_000_000, min=1, max=None, step=1 - ) + texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) + num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) remove_outliers = server.gui.add_checkbox("Remove outliers", True) - output_dir = server.gui.add_text( - "Output Directory", initial_value="exports/mesh/" - ) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button( - "Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True - ) - generate_command = server.gui.add_button( - "Generate Command", icon=viser.Icon.TERMINAL_2 - ) + download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -338,17 +310,13 @@ def _(event: viser.GuiEvent) -> None: f"--num-points {num_points.value}", f"--remove-outliers {remove_outliers.value}", f"--normal-method {normals.value}", - get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ), + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) show_command_modal(event.client, "mesh", command) else: - server.gui.add_markdown( - "Mesh export is not currently supported with Gaussian Splatting" - ) + server.gui.add_markdown("Mesh export is not currently supported with Gaussian Splatting") def populate_splat_tab( @@ -359,16 +327,10 @@ def populate_splat_tab( ) -> None: if viewing_gsplat: server.gui.add_markdown("Generate ply export of Gaussian Splat") - output_dir = server.gui.add_text( - "Output Directory", initial_value="exports/splat/" - ) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/") export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) - download_button = server.gui.add_button( - "Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True - ) - generate_command = server.gui.add_button( - "Generate Command", icon=viser.Icon.TERMINAL_2 - ) + download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @export_button.on_click def _(event: viser.GuiEvent) -> None: @@ -426,14 +388,10 @@ def _(event: viser.GuiEvent) -> None: "ns-export gaussian-splat", f"--load-config {config_path}", f"--output-dir {output_dir.value}", - get_crop_string( - control_panel.crop_obb, control_panel.crop_viewport - ), + get_crop_string(control_panel.crop_obb, control_panel.crop_viewport), ] ) show_command_modal(event.client, "splat", command) else: - server.gui.add_markdown( - "Splat export is only supported with Gaussian Splatting methods" - ) + server.gui.add_markdown("Splat export is only supported with Gaussian Splatting methods") From b838598fc9112a6ff40f8a3f119df9b841b5f9ee Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Tue, 5 Nov 2024 15:23:24 -0800 Subject: [PATCH 31/33] nit --- nerfstudio/scripts/render.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 94a3c64d0a..1548ef1077 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -77,7 +77,7 @@ def _render_trajectory_video( colormap_options: 
colormaps.ColormapOptions = colormaps.ColormapOptions(), render_nearest_camera: bool = False, check_occlusions: bool = False, - kill_flag: List[bool] = [False], + _kill_flag: List[bool] = [False], ) -> bool: """Helper function to create a video of the spiral trajectory. @@ -455,11 +455,11 @@ class BaseRender: """If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other""" camera_idx: Optional[int] = None """Index of the training camera to render.""" - kill_flag: List[bool] = field(default_factory=lambda: [False]) + _kill_flag: tyro.conf.Suppress[List[bool]] = field(default_factory=lambda: [False]) """Stop execution of render if set to True.""" def kill(self) -> None: - self.kill_flag[0] = True + self._kill_flag[0] = True @dataclass @@ -531,7 +531,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, - kill_flag=self.kill_flag, + _kill_flag=self._kill_flag, ) if ( @@ -567,7 +567,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, - kill_flag=self.kill_flag, + _kill_flag=self._kill_flag, ) self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4") @@ -671,7 +671,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, - kill_flag=self.kill_flag, + _kill_flag=self._kill_flag, ) @@ -727,7 +727,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, - kill_flag=self.kill_flag, + _kill_flag=self._kill_flag, ) From f51f355f797f0e16e9d0c440aaeea509f6640e15 Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Wed, 6 Nov 2024 15:56:07 -0800 Subject: [PATCH 32/33] error message and load checkpoint fixes --- nerfstudio/scripts/exporter.py | 12 ++++----- nerfstudio/scripts/render.py | 6 ++--- nerfstudio/viewer/export_panel.py | 41 ++++++++++++++++++++++++++++--- nerfstudio/viewer/render_panel.py | 23 ++++++++++++++++- 4 files changed, 69 insertions(+), 13 deletions(-) diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index 7c57d5a3c9..73aee72494 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -58,7 +58,7 @@ class Exporter: """Path to the config YAML file.""" output_dir: Path """Path to the output directory.""" - complete: bool = False + _complete: tyro.conf.Suppress[bool] = False """Set to True when export is finished.""" @@ -193,7 +193,7 @@ def main(self) -> None: print("\033[A\033[A") CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") - self.complete = True + self._complete = True @dataclass @@ -276,7 +276,7 @@ def main(self) -> None: num_pixels_per_side=self.num_pixels_per_side, ) - self.complete = True + self._complete = True @dataclass @@ -409,7 +409,7 @@ def main(self) -> None: num_pixels_per_side=self.num_pixels_per_side, ) - self.complete = True + self._complete = True @dataclass @@ -477,7 +477,7 @@ def main(self) -> None: num_pixels_per_side=self.num_pixels_per_side, ) - self.complete = True + self._complete = True @dataclass @@ -679,7 +679,7 @@ def main(self) -> None: ExportGaussianSplat.write_ply(str(filename), count, map_to_tensors) - self.complete = True + self._complete = True Commands = tyro.conf.FlagConversionOff[ diff --git 
a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 8b7f376b3d..3c3c57d16a 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -138,7 +138,7 @@ def _render_trajectory_video( with progress: for camera_idx in progress.track(range(cameras.size), description=""): - if kill_flag[0]: + if _kill_flag[0]: return False obb_box = None @@ -475,7 +475,7 @@ class RenderCameraPath(BaseRender): """Filename of the camera path to render.""" output_format: Literal["images", "video"] = "video" """How to save output data.""" - complete: bool = True + _complete: tyro.conf.Suppress[bool] = False """Set to True when render is finished.""" def main(self) -> None: @@ -520,7 +520,7 @@ def main(self) -> None: if self.camera_idx is not None: camera_path.metadata = {"cam_idx": self.camera_idx} - self.complete = _render_trajectory_video( + self._complete = _render_trajectory_video( pipeline, camera_path, output_filename=self.output_path, diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index ffe1c7a097..ebaa914f6f 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -17,6 +17,8 @@ from pathlib import Path from typing import cast +import yaml + import viser import viser.transforms as vtf from typing_extensions import Literal, Tuple @@ -149,6 +151,17 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None + config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + if config.load_dir is None: + notif = client.add_notification( + title="Export unsuccessful", + body="Cannot export before 2000 training steps.", + loading=False, + with_close_button=True, + color="red", + ) + return + notif = client.add_notification( title="Exporting point cloud", body="File will be saved under " + str(output_dir.value), @@ -178,7 +191,7 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - if export.complete: + if export._complete: notif.title = "Export complete!" notif.body = "File saved under " + str(output_dir.value) notif.loading = False @@ -249,6 +262,17 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None + config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + if config.load_dir is None: + notif = client.add_notification( + title="Export unsuccessful", + body="Cannot export before 2000 training steps.", + loading=False, + with_close_button=True, + color="red", + ) + return + notif = client.add_notification( title="Exporting poisson mesh", body="File will be saved under " + str(output_dir.value), @@ -279,7 +303,7 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - if export.complete: + if export._complete: notif.title = "Export complete!" notif.body = "File saved under " + str(output_dir.value) notif.loading = False @@ -337,6 +361,17 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None + config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + if config.load_dir is None: + notif = client.add_notification( + title="Export unsuccessful", + body="Cannot export before 2000 training steps.", + loading=False, + with_close_button=True, + color="red", + ) + return + notif = client.add_notification( title="Exporting gaussian splat", body="File will be saved under " + str(output_dir.value), @@ -362,7 +397,7 @@ def _(event: viser.GuiEvent) -> None: ) export.main() - if export.complete: + if export._complete: notif.title = "Export complete!" 
notif.body = "File saved under " + str(output_dir.value) notif.loading = False diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index d288cc54b1..cb9ec56647 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -20,6 +20,8 @@ import json import threading import time +import yaml + from pathlib import Path from typing import Dict, List, Literal, Optional, Tuple, Union @@ -28,6 +30,7 @@ import splines.quaternion import viser import viser.transforms as tf + from scipy import interpolate from nerfstudio.viewer.control_panel import ControlPanel @@ -960,6 +963,7 @@ def _(_) -> None: @load_camera_path_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None + camera_path_dir = datapath / "camera_paths" camera_path_dir.mkdir(parents=True, exist_ok=True) preexisting_camera_paths = list(camera_path_dir.glob("*.json")) @@ -1017,6 +1021,10 @@ def _(_) -> None: # update the render name render_name_text.value = json_path.stem camera_path.update_spline() + + if len(camera_path._keyframes) > 1: + render_button.disabled = False + modal.close() cancel_button = event.client.gui.add_button("Cancel") @@ -1182,6 +1190,19 @@ def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None + config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + if config.load_dir is None: + render_button.disabled = True + + notif = client.add_notification( + title="Render unsuccessful", + body="Cannot render video before 2000 training steps.", + loading=False, + with_close_button=True, + color="red", + ) + return + render_button.disabled = True cancel_render_button.disabled = False @@ -1216,7 +1237,7 @@ def _(event: viser.GuiEvent) -> None: render.main() - if render.complete: + if render._complete: notif.title = "Render complete!" notif.body = "Video saved as " + render_path notif.loading = False From 02a3f7a8060487b809135044e9d1e4f326bb831a Mon Sep 17 00:00:00 2001 From: Gina Wu Date: Wed, 6 Nov 2024 16:05:42 -0800 Subject: [PATCH 33/33] ruff format import blocks --- nerfstudio/viewer/export_panel.py | 3 +-- nerfstudio/viewer/render_panel.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index ebaa914f6f..e600fa5a31 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -17,10 +17,9 @@ from pathlib import Path from typing import cast -import yaml - import viser import viser.transforms as vtf +import yaml from typing_extensions import Literal, Tuple from nerfstudio.data.scene_box import OrientedBox diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index cb9ec56647..b375b3ce73 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -20,8 +20,6 @@ import json import threading import time -import yaml - from pathlib import Path from typing import Dict, List, Literal, Optional, Tuple, Union @@ -30,7 +28,7 @@ import splines.quaternion import viser import viser.transforms as tf - +import yaml from scipy import interpolate from nerfstudio.viewer.control_panel import ControlPanel