diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index b6404af376..73aee72494 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -58,6 +58,8 @@ class Exporter: """Path to the config YAML file.""" output_dir: Path """Path to the output directory.""" + _complete: tyro.conf.Suppress[bool] = False + """Set to True when export is finished.""" def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None: @@ -141,7 +143,12 @@ def main(self) -> None: # Increase the batchsize to speed up the evaluation. assert isinstance( pipeline.datamanager, - (VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager), + ( + VanillaDataManager, + ParallelDataManager, + FullImageDatamanager, + RandomCamerasDataManager, + ), ) assert pipeline.datamanager.train_pixel_sampler is not None pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch @@ -159,7 +166,7 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, + normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None), crop_obb=crop_obb, std_ratio=self.std_ratio, ) @@ -186,6 +193,8 @@ def main(self) -> None: print("\033[A\033[A") CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud") + self._complete = True + @dataclass class ExportTSDFMesh(Exporter): @@ -254,18 +263,21 @@ def main(self) -> None: if self.texture_method == "nerf": # load the mesh from the tsdf export mesh = get_mesh_from_filename( - str(self.output_dir / "tsdf_mesh.ply"), target_num_faces=self.target_num_faces + str(self.output_dir / "tsdf_mesh.ply"), + target_num_faces=self.target_num_faces, ) CONSOLE.print("Texturing mesh with NeRF") texture_utils.export_textured_mesh( mesh, pipeline, 
self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) + self._complete = True + @dataclass class ExportPoissonMesh(Exporter): @@ -329,7 +341,12 @@ def main(self) -> None: # Increase the batchsize to speed up the evaluation. assert isinstance( pipeline.datamanager, - (VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager), + ( + VanillaDataManager, + ParallelDataManager, + FullImageDatamanager, + RandomCamerasDataManager, + ), ) assert pipeline.datamanager.train_pixel_sampler is not None pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch @@ -349,7 +366,7 @@ def main(self) -> None: estimate_normals=estimate_normals, rgb_output_name=self.rgb_output_name, depth_output_name=self.depth_output_name, - normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None, + normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None), crop_obb=crop_obb, std_ratio=self.std_ratio, ) @@ -379,18 +396,21 @@ def main(self) -> None: if self.texture_method == "nerf": # load the mesh from the poisson reconstruction mesh = get_mesh_from_filename( - str(self.output_dir / "poisson_mesh.ply"), target_num_faces=self.target_num_faces + str(self.output_dir / "poisson_mesh.ply"), + target_num_faces=self.target_num_faces, ) CONSOLE.print("Texturing mesh with NeRF") texture_utils.export_textured_mesh( mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) + self._complete = True + @dataclass class 
ExportMarchingCubesMesh(Exporter): @@ -452,11 +472,13 @@ def main(self) -> None: mesh, pipeline, self.output_dir, - px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None, + px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None), unwrap_method=self.unwrap_method, num_pixels_per_side=self.num_pixels_per_side, ) + self._complete = True + @dataclass class ExportCameraPoses(Exporter): @@ -473,7 +495,10 @@ def main(self) -> None: assert isinstance(pipeline, VanillaPipeline) train_frames, eval_frames = collect_camera_poses(pipeline) - for file_name, frames in [("transforms_train.json", train_frames), ("transforms_eval.json", eval_frames)]: + for file_name, frames in [ + ("transforms_train.json", train_frames), + ("transforms_eval.json", eval_frames), + ]: if len(frames) == 0: CONSOLE.print(f"[bold yellow]No frames found for {file_name}. Skipping.") continue @@ -654,6 +679,8 @@ def main(self) -> None: ExportGaussianSplat.write_ply(str(filename), count, map_to_tensors) + self._complete = True + Commands = tyro.conf.FlagConversionOff[ Union[ diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 4eb4a71840..3c3c57d16a 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -75,9 +75,10 @@ def _render_trajectory_video( depth_near_plane: Optional[float] = None, depth_far_plane: Optional[float] = None, colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(), - render_nearest_camera=False, + render_nearest_camera: bool = False, check_occlusions: bool = False, -) -> None: + _kill_flag: List[bool] = [False], +) -> bool: """Helper function to create a video of the spiral trajectory. 
Args: @@ -137,6 +138,9 @@ def _render_trajectory_video( with progress: for camera_idx in progress.track(range(cameras.size), description=""): + if _kill_flag[0]: + return False + obb_box = None if crop_data is not None: obb_box = crop_data.obb @@ -205,9 +209,13 @@ def _render_trajectory_video( for rendered_output_name in rendered_output_names: if rendered_output_name not in outputs: CONSOLE.rule("Error", style="red") - CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", justify="center") CONSOLE.print( - f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center" + f"Could not find {rendered_output_name} in the model outputs", + justify="center", + ) + CONSOLE.print( + f"Please set --rendered_output_name to one of: {outputs.keys()}", + justify="center", ) sys.exit(1) output_image = outputs[rendered_output_name] @@ -261,10 +269,17 @@ def _render_trajectory_video( render_image = np.concatenate(render_image, axis=1) if output_format == "images": if image_format == "png": - media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png") + media.write_image( + output_image_dir / f"{camera_idx:05d}.png", + render_image, + fmt="png", + ) if image_format == "jpeg": media.write_image( - output_image_dir / f"{camera_idx:05d}.jpg", render_image, fmt="jpeg", quality=jpeg_quality + output_image_dir / f"{camera_idx:05d}.jpg", + render_image, + fmt="jpeg", + quality=jpeg_quality, ) if output_format == "video": if writer is None: @@ -292,7 +307,15 @@ def _render_trajectory_video( table.add_row("Video", str(output_filename)) else: table.add_row("Images", str(output_image_dir)) - CONSOLE.print(Panel(table, title="[bold][green]:tada: Render Complete :tada:[/bold]", expand=False)) + CONSOLE.print( + Panel( + table, + title="[bold][green]:tada: Render Complete :tada:[/bold]", + expand=False, + ) + ) + + return True def insert_spherical_metadata_into_file( @@ -437,6 +460,11 @@ class BaseRender: """If true, checks 
line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other""" camera_idx: Optional[int] = None """Index of the training camera to render.""" + _kill_flag: tyro.conf.Suppress[List[bool]] = field(default_factory=lambda: [False]) + """Stop execution of render if set to True.""" + + def kill(self) -> None: + self._kill_flag[0] = True @dataclass @@ -447,6 +475,8 @@ class RenderCameraPath(BaseRender): """Filename of the camera path to render.""" output_format: Literal["images", "video"] = "video" """How to save output data.""" + _complete: tyro.conf.Suppress[bool] = False + """Set to True when render is finished.""" def main(self) -> None: """Main function.""" @@ -490,7 +520,7 @@ def main(self) -> None: if self.camera_idx is not None: camera_path.metadata = {"cam_idx": self.camera_idx} - _render_trajectory_video( + self._complete = _render_trajectory_video( pipeline, camera_path, output_filename=self.output_path, @@ -506,6 +536,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + _kill_flag=self._kill_flag, ) if ( @@ -541,6 +572,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + _kill_flag=self._kill_flag, ) self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4") @@ -644,6 +676,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + _kill_flag=self._kill_flag, ) @@ -699,6 +732,7 @@ def main(self) -> None: colormap_options=self.colormap_options, render_nearest_camera=self.render_nearest_camera, check_occlusions=self.check_occlusions, + _kill_flag=self._kill_flag, ) @@ -736,7 +770,10 @@ def main(self): def update_config(config: TrainerConfig) -> TrainerConfig: data_manager_config = 
config.pipeline.datamanager - assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) + assert isinstance( + data_manager_config, + (VanillaDataManagerConfig, FullImageDatamanagerConfig), + ) data_manager_config.eval_num_images_to_sample_from = -1 data_manager_config.eval_num_times_to_repeat_images = -1 if isinstance(data_manager_config, VanillaDataManagerConfig): @@ -746,7 +783,11 @@ def update_config(config: TrainerConfig) -> TrainerConfig: data_manager_config.data = self.data if self.downscale_factor is not None: assert hasattr(data_manager_config.dataparser, "downscale_factor") - setattr(data_manager_config.dataparser, "downscale_factor", self.downscale_factor) + setattr( + data_manager_config.dataparser, + "downscale_factor", + self.downscale_factor, + ) return config config, pipeline, _, _ = eval_setup( @@ -814,10 +855,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig: if rendered_output_name not in all_outputs: CONSOLE.rule("Error", style="red") CONSOLE.print( - f"Could not find {rendered_output_name} in the model outputs", justify="center" + f"Could not find {rendered_output_name} in the model outputs", + justify="center", ) CONSOLE.print( - f"Please set --rendered-output-name to one of: {all_outputs}", justify="center" + f"Please set --rendered-output-name to one of: {all_outputs}", + justify="center", ) sys.exit(1) @@ -885,7 +928,10 @@ def update_config(config: TrainerConfig) -> TrainerConfig: media.write_image(output_path.with_suffix(".png"), output_image, fmt="png") elif self.image_format == "jpeg": media.write_image( - output_path.with_suffix(".jpg"), output_image, fmt="jpeg", quality=self.jpeg_quality + output_path.with_suffix(".jpg"), + output_image, + fmt="jpeg", + quality=self.jpeg_quality, ) else: raise ValueError(f"Unknown image format {self.image_format}") @@ -898,7 +944,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig: ) for split in self.split.split("+"): 
table.add_row(f"Outputs {split}", str(self.output_path / split)) - CONSOLE.print(Panel(table, title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", expand=False)) + CONSOLE.print( + Panel( + table, + title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", + expand=False, + ) + ) Commands = tyro.conf.FlagConversionOff[ diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 2bb3969cd5..e600fa5a31 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -15,10 +15,12 @@ from __future__ import annotations from pathlib import Path +from typing import cast import viser import viser.transforms as vtf -from typing_extensions import Literal +import yaml +from typing_extensions import Literal, Tuple from nerfstudio.data.scene_box import OrientedBox from nerfstudio.models.base_model import Model @@ -40,6 +42,7 @@ def populate_export_tab( def _(_) -> None: control_panel.crop_viewport = crop_output.value + server.gui.add_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)") with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) with server.gui.add_folder("Point Cloud"): @@ -48,7 +51,11 @@ def _(_) -> None: populate_mesh_tab(server, control_panel, config_path, viewing_gsplat) -def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point cloud", "splat"], command: str) -> None: +def show_command_modal( + client: viser.ClientHandle, + what: Literal["mesh", "point cloud", "splat"], + command: str, +) -> None: """Show a modal to each currently connected client. 
Vec3f = Tuple[float, float, float]


def get_crop_tuple(
    obb: OrientedBox, crop_viewport: bool
) -> Tuple[Vec3f | None, Vec3f | None, Vec3f | None]:
    """Convert the viewer crop box into ``(center, rotation, scale)`` triples.

    This is the programmatic counterpart of ``get_crop_string``: instead of
    formatting command-line flags it returns values that can be passed as the
    ``obb_center`` / ``obb_rotation`` / ``obb_scale`` arguments of the exporters.

    Args:
        obb: Oriented bounding box; its ``R``, ``T`` and ``S`` attributes are read.
        crop_viewport: Whether cropping is enabled. When False the box is not
            touched at all.

    Returns:
        ``(None, None, None)`` when cropping is disabled, otherwise
        ``(center, rotation, scale)`` where rotation is (roll, pitch, yaw) in
        radians as produced by ``SO3.as_rpy_radians``.
    """
    if not crop_viewport:
        return None, None, None

    rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians()

    # NOTE(review): the casts assume T and S squeeze down to exactly three
    # values each -- confirm against OrientedBox.
    obb_center = cast(Vec3f, tuple(obb.T.squeeze().tolist()))
    obb_rotation = cast(Vec3f, (rpy.roll, rpy.pitch, rpy.yaw))
    obb_scale = cast(Vec3f, tuple(obb.S.squeeze().tolist()))
    return obb_center, obb_rotation, obb_scale
@download_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Send the exported point cloud to the client that clicked the download button.

    Reads ``point_cloud.ply`` from the directory currently entered in the
    "Output Directory" text field and offers it to the client as a download.
    """
    client = event.client
    assert client is not None

    # Join with pathlib so the directory works with or without a trailing
    # slash; string concatenation turned "exports/pcd" into
    # "exports/pcdpoint_cloud.ply".
    ply_path = Path(output_dir.value) / "point_cloud.ply"
    try:
        ply_bytes = ply_path.read_bytes()
    except FileNotFoundError:
        # The text field may have been edited after the export finished; tell
        # the user instead of failing silently inside the callback.
        client.add_notification(
            title="Download unsuccessful",
            body=f"Could not find {ply_path}.",
            loading=False,
            with_close_button=True,
            color="red",
        )
        return

    client.send_file_download("point_cloud.ply", ply_bytes)
server.gui.add_text("Output Directory", initial_value="exports/mesh/") + export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT) + download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True) generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) + @export_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + if config.load_dir is None: + notif = client.add_notification( + title="Export unsuccessful", + body="Cannot export before 2000 training steps.", + loading=False, + with_close_button=True, + color="red", + ) + return + + notif = client.add_notification( + title="Exporting poisson mesh", + body="File will be saved under " + str(output_dir.value), + loading=True, + with_close_button=False, + ) + + if control_panel.crop_obb is not None and control_panel.crop_viewport: + obb_center, obb_rotation, obb_scale = get_crop_tuple( + control_panel.crop_obb, control_panel.crop_viewport + ) + else: + obb_center, obb_rotation, obb_scale = None, None, None + + from nerfstudio.scripts.exporter import ExportPoissonMesh + + export = ExportPoissonMesh( + load_config=config_path, + output_dir=Path(output_dir.value), + target_num_faces=num_faces.value, + num_pixels_per_side=texture_resolution.value, + num_points=num_points.value, + remove_outliers=remove_outliers.value, + normal_method=normals.value, + obb_center=obb_center, + obb_rotation=obb_rotation, + obb_scale=obb_scale, + ) + export.main() + + if export._complete: + notif.title = "Export complete!" 
@download_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Send the exported Poisson mesh to the client that clicked the download button.

    Reads ``poisson_mesh.ply`` from the directory currently entered in the
    "Output Directory" text field and offers it to the client as a download.
    """
    client = event.client
    assert client is not None

    # Join with pathlib so the directory works with or without a trailing
    # slash; string concatenation turned "exports/mesh" into
    # "exports/meshpoisson_mesh.ply".
    ply_path = Path(output_dir.value) / "poisson_mesh.ply"
    try:
        ply_bytes = ply_path.read_bytes()
    except FileNotFoundError:
        # The text field may have been edited after the export finished; tell
        # the user instead of failing silently inside the callback.
        client.add_notification(
            title="Download unsuccessful",
            body=f"Could not find {ply_path}.",
            loading=False,
            with_close_button=True,
            color="red",
        )
        return

    client.send_file_download("poisson_mesh.ply", ply_bytes)
@download_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Send the exported Gaussian splat to the client that clicked the download button.

    Reads ``splat.ply`` from the directory currently entered in the
    "Output Directory" text field and offers it to the client as a download.
    """
    client = event.client
    assert client is not None

    # Join with pathlib so the directory works with or without a trailing
    # slash; string concatenation turned "exports/splat" into
    # "exports/splatsplat.ply".
    ply_path = Path(output_dir.value) / "splat.ply"
    try:
        ply_bytes = ply_path.read_bytes()
    except FileNotFoundError:
        # The text field may have been edited after the export finished; tell
        # the user instead of failing silently inside the callback.
        client.add_notification(
            title="Download unsuccessful",
            body=f"Could not find {ply_path}.",
            loading=False,
            with_close_button=True,
            color="red",
        )
        return

    client.send_file_download("splat.ply", ply_bytes)
viser.GuiInputHandle[float], time_enabled: bool = False + self, + server: viser.ViserServer, + duration_element: viser.GuiInputHandle[float], + time_enabled: bool = False, ): self._server = server self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {} @@ -609,6 +613,10 @@ def _(event: viser.GuiEvent) -> None: duration_number.value = camera_path.compute_duration() camera_path.update_spline() + # Enable render only if there are two or more keyframes + render_button.disabled = len(camera_path._keyframes) < 2 + generate_command_render_button.disabled = len(camera_path._keyframes) < 2 + clear_keyframes_button = server.gui.add_button( "Clear Keyframes", icon=viser.Icon.TRASH, @@ -629,6 +637,9 @@ def _(_) -> None: camera_path.reset() modal.close() + render_button.disabled = True + generate_command_render_button.disabled = True + duration_number.value = camera_path.compute_duration() # Clear move handles. @@ -840,6 +851,7 @@ def _(_) -> None: position=pose.translation(), color=(10, 200, 30), ) + if render_tab_state.preview_render: for client in server.get_clients().values(): client.camera.wxyz = pose.rotation().wxyz @@ -949,6 +961,7 @@ def _(_) -> None: @load_camera_path_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None + camera_path_dir = datapath / "camera_paths" camera_path_dir.mkdir(parents=True, exist_ok=True) preexisting_camera_paths = list(camera_path_dir.glob("*.json")) @@ -978,11 +991,13 @@ def _(_) -> None: for i in range(len(keyframes)): frame = keyframes[i] pose = tf.SE3.from_matrix(np.array(frame["matrix"]).reshape(4, 4)) + # apply the x rotation by 180 deg pose = tf.SE3.from_rotation_and_translation( pose.rotation() @ tf.SO3.from_x_radians(np.pi), pose.translation(), ) + camera_path.add_camera( Keyframe( position=pose.translation() * VISER_NERFSTUDIO_SCALE_RATIO, @@ -1004,6 +1019,10 @@ def _(_) -> None: # update the render name render_name_text.value = json_path.stem camera_path.update_spline() + + if 
len(camera_path._keyframes) > 1: + render_button.disabled = False + modal.close() cancel_button = event.client.gui.add_button("Cancel") @@ -1019,30 +1038,45 @@ def _(_) -> None: initial_value=now.strftime("%Y-%m-%d-%H-%M-%S"), hint="Name of the render", ) - render_button = server.gui.add_button( - "Generate Command", - color="green", - icon=viser.Icon.FILE_EXPORT, - hint="Generate the ns-render command for rendering the camera path.", - ) - reset_up_button = server.gui.add_button( - "Reset Up Direction", - icon=viser.Icon.ARROW_BIG_UP_LINES, - color="gray", - hint="Set the up direction of the camera orbit controls to the camera's current up direction.", - ) + render_folder = server.gui.add_folder("Render") + with render_folder: + server.gui.add_markdown( + "Render available after a checkpoint is saved (default minimum 2000 steps)" + ) - @reset_up_button.on_click - def _(event: viser.GuiEvent) -> None: - assert event.client is not None - event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) + render_button = server.gui.add_button( + "Render", + icon=viser.Icon.VIDEO, + hint="Render the camera path and save video as mp4 file.", + disabled=True, + ) - @render_button.on_click - def _(event: viser.GuiEvent) -> None: - assert event.client is not None + cancel_render_button = server.gui.add_button( + "Cancel Render", + icon=viser.Icon.CIRCLE_X, + hint="Cancel current render in progress.", + disabled=True, + ) + + download_render_button = server.gui.add_button( + "Download Render", + icon=viser.Icon.DOWNLOAD, + hint="Download the latest render locally as mp4 file.", + disabled=True, + ) + + generate_command_render_button = server.gui.add_button( + "Generate Command", + icon=viser.Icon.FILE_EXPORT, + hint="Generate the ns-render command for rendering the camera path instead of directly rendering.", + disabled=True, + ) + + def _write_json() -> Path: num_frames = int(framerate_number.value * duration_number.value) json_data = {} + # 
json data has the properties: # keyframes: list of keyframes with # matrix : flattened 4x4 matrix @@ -1059,13 +1093,15 @@ def _(event: viser.GuiEvent) -> None: # camera_to_world: flattened 4x4 matrix # fov: float in degrees # aspect: float + # first populate the keyframes: keyframes = [] - for keyframe, dummy in camera_path._keyframes.values(): + for keyframe, _ in camera_path._keyframes.values(): pose = tf.SE3.from_rotation_and_translation( tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi), keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO, ) + keyframe_dict = { "matrix": pose.as_matrix().flatten().tolist(), "fov": np.rad2deg(keyframe.override_fov_rad) if keyframe.override_fov_enabled else fov_degrees.value, @@ -1073,15 +1109,20 @@ def _(event: viser.GuiEvent) -> None: "override_transition_enabled": keyframe.override_transition_enabled, "override_transition_sec": keyframe.override_transition_sec, } + if render_time is not None: keyframe_dict["render_time"] = ( keyframe.override_time_val if keyframe.override_time_enabled else render_time.value ) keyframe_dict["override_time_enabled"] = keyframe.override_time_enabled + keyframes.append(keyframe_dict) + json_data["default_fov"] = fov_degrees.value + if render_time is not None: json_data["default_time"] = render_time.value if render_time is not None else None + json_data["default_transition_sec"] = transition_sec_number.value json_data["keyframes"] = keyframes json_data["camera_type"] = camera_type.value.lower() @@ -1091,31 +1132,36 @@ def _(event: viser.GuiEvent) -> None: json_data["seconds"] = duration_number.value json_data["is_cycle"] = loop.value json_data["smoothness_value"] = tension_slider.value + # now populate the camera path: camera_path_list = [] for i in range(num_frames): maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad(i / num_frames) if maybe_pose_and_fov is None: - return + return Path() time = None if len(maybe_pose_and_fov) == 3: # Time is enabled. 
@render_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Render the current camera path to an mp4 from inside the viewer.

    Writes the camera-path JSON, runs ``RenderCameraPath`` in this callback and
    keeps the Render / Cancel Render / Download Render buttons and the client
    notification in sync with the outcome (success, cancellation or failure).
    """
    client = event.client
    assert client is not None

    # NOTE(review): yaml.Loader can construct arbitrary Python objects;
    # presumably safe because this is the training run's own config -- confirm.
    config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
    if config.load_dir is None:
        render_button.disabled = True
        client.add_notification(
            title="Render unsuccessful",
            body="Cannot render video before 2000 training steps.",
            loading=False,
            with_close_button=True,
            color="red",
        )
        return

    # Write the camera path before touching any button state: _write_json
    # returns an empty Path when the path cannot be interpolated, and starting
    # a render on it would fail with the buttons stuck in "rendering" state.
    json_outfile = _write_json()
    if json_outfile == Path():
        client.add_notification(
            title="Render unsuccessful",
            body="Could not interpolate the camera path.",
            loading=False,
            with_close_button=True,
            color="red",
        )
        return

    render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4"

    render_button.disabled = True
    cancel_render_button.disabled = False

    notif = client.add_notification(
        title="Rendering trajectory",
        body="Saving rendered video as " + render_path,
        loading=True,
        with_close_button=False,
    )

    from nerfstudio.scripts.render import RenderCameraPath

    render = RenderCameraPath(
        load_config=config_path,
        camera_path_filename=json_outfile,
        output_path=Path(render_path),
    )

    # NOTE(review): a new cancel callback is registered on every render, so
    # callbacks from earlier renders stay attached to the button. Fixing that
    # needs state outside this handler.
    @cancel_render_button.on_click
    def _(_) -> None:
        render.kill()
        render_button.disabled = False
        cancel_render_button.disabled = True

        notif.title = "Render canceled"
        notif.body = "The render in progress has been canceled."
        notif.loading = False
        notif.with_close_button = True

    try:
        render.main()
    finally:
        # Runs on success, cancellation, exceptions and the sys.exit() the
        # renderer uses for unknown output names, so the GUI never stays stuck
        # with a spinning notification and a disabled Render button.
        render_button.disabled = False
        cancel_render_button.disabled = True

        if render._complete:
            notif.title = "Render complete!"
            notif.body = "Video saved as " + render_path
            notif.loading = False
            notif.with_close_button = True

            download_render_button.disabled = False
        elif not render._kill_flag[0]:
            # Neither finished nor cancelled: the render aborted on its own.
            notif.title = "Render failed"
            notif.body = "See the terminal output for details."
            notif.loading = False
            notif.with_close_button = True
+ + @reset_up_button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) + if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) else: camera_path = CameraPath(server, duration_number) + camera_path.tension = tension_slider.value camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value