diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py
index b6404af376..73aee72494 100644
--- a/nerfstudio/scripts/exporter.py
+++ b/nerfstudio/scripts/exporter.py
@@ -58,6 +58,8 @@ class Exporter:
"""Path to the config YAML file."""
output_dir: Path
"""Path to the output directory."""
+ _complete: tyro.conf.Suppress[bool] = False
+ """Set to True when export is finished."""
def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None:
@@ -141,7 +143,12 @@ def main(self) -> None:
# Increase the batchsize to speed up the evaluation.
assert isinstance(
pipeline.datamanager,
- (VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager),
+ (
+ VanillaDataManager,
+ ParallelDataManager,
+ FullImageDatamanager,
+ RandomCamerasDataManager,
+ ),
)
assert pipeline.datamanager.train_pixel_sampler is not None
pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch
@@ -159,7 +166,7 @@ def main(self) -> None:
estimate_normals=estimate_normals,
rgb_output_name=self.rgb_output_name,
depth_output_name=self.depth_output_name,
- normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None,
+ normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None),
crop_obb=crop_obb,
std_ratio=self.std_ratio,
)
@@ -186,6 +193,8 @@ def main(self) -> None:
print("\033[A\033[A")
CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud")
+ self._complete = True
+
@dataclass
class ExportTSDFMesh(Exporter):
@@ -254,18 +263,21 @@ def main(self) -> None:
if self.texture_method == "nerf":
# load the mesh from the tsdf export
mesh = get_mesh_from_filename(
- str(self.output_dir / "tsdf_mesh.ply"), target_num_faces=self.target_num_faces
+ str(self.output_dir / "tsdf_mesh.ply"),
+ target_num_faces=self.target_num_faces,
)
CONSOLE.print("Texturing mesh with NeRF")
texture_utils.export_textured_mesh(
mesh,
pipeline,
self.output_dir,
- px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None,
+ px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None),
unwrap_method=self.unwrap_method,
num_pixels_per_side=self.num_pixels_per_side,
)
+ self._complete = True
+
@dataclass
class ExportPoissonMesh(Exporter):
@@ -329,7 +341,12 @@ def main(self) -> None:
# Increase the batchsize to speed up the evaluation.
assert isinstance(
pipeline.datamanager,
- (VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager),
+ (
+ VanillaDataManager,
+ ParallelDataManager,
+ FullImageDatamanager,
+ RandomCamerasDataManager,
+ ),
)
assert pipeline.datamanager.train_pixel_sampler is not None
pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch
@@ -349,7 +366,7 @@ def main(self) -> None:
estimate_normals=estimate_normals,
rgb_output_name=self.rgb_output_name,
depth_output_name=self.depth_output_name,
- normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None,
+ normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None),
crop_obb=crop_obb,
std_ratio=self.std_ratio,
)
@@ -379,18 +396,21 @@ def main(self) -> None:
if self.texture_method == "nerf":
# load the mesh from the poisson reconstruction
mesh = get_mesh_from_filename(
- str(self.output_dir / "poisson_mesh.ply"), target_num_faces=self.target_num_faces
+ str(self.output_dir / "poisson_mesh.ply"),
+ target_num_faces=self.target_num_faces,
)
CONSOLE.print("Texturing mesh with NeRF")
texture_utils.export_textured_mesh(
mesh,
pipeline,
self.output_dir,
- px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None,
+ px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None),
unwrap_method=self.unwrap_method,
num_pixels_per_side=self.num_pixels_per_side,
)
+ self._complete = True
+
@dataclass
class ExportMarchingCubesMesh(Exporter):
@@ -452,11 +472,13 @@ def main(self) -> None:
mesh,
pipeline,
self.output_dir,
- px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None,
+ px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None),
unwrap_method=self.unwrap_method,
num_pixels_per_side=self.num_pixels_per_side,
)
+ self._complete = True
+
@dataclass
class ExportCameraPoses(Exporter):
@@ -473,7 +495,10 @@ def main(self) -> None:
assert isinstance(pipeline, VanillaPipeline)
train_frames, eval_frames = collect_camera_poses(pipeline)
- for file_name, frames in [("transforms_train.json", train_frames), ("transforms_eval.json", eval_frames)]:
+ for file_name, frames in [
+ ("transforms_train.json", train_frames),
+ ("transforms_eval.json", eval_frames),
+ ]:
if len(frames) == 0:
CONSOLE.print(f"[bold yellow]No frames found for {file_name}. Skipping.")
continue
@@ -654,6 +679,8 @@ def main(self) -> None:
ExportGaussianSplat.write_ply(str(filename), count, map_to_tensors)
+ self._complete = True
+
Commands = tyro.conf.FlagConversionOff[
Union[
diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py
index 4eb4a71840..3c3c57d16a 100644
--- a/nerfstudio/scripts/render.py
+++ b/nerfstudio/scripts/render.py
@@ -75,9 +75,10 @@ def _render_trajectory_video(
depth_near_plane: Optional[float] = None,
depth_far_plane: Optional[float] = None,
colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(),
- render_nearest_camera=False,
+ render_nearest_camera: bool = False,
check_occlusions: bool = False,
-) -> None:
+ _kill_flag: List[bool] = [False],
+) -> bool:
"""Helper function to create a video of the spiral trajectory.
Args:
@@ -137,6 +138,9 @@ def _render_trajectory_video(
with progress:
for camera_idx in progress.track(range(cameras.size), description=""):
+ if _kill_flag[0]:
+ return False
+
obb_box = None
if crop_data is not None:
obb_box = crop_data.obb
@@ -205,9 +209,13 @@ def _render_trajectory_video(
for rendered_output_name in rendered_output_names:
if rendered_output_name not in outputs:
CONSOLE.rule("Error", style="red")
- CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", justify="center")
CONSOLE.print(
- f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center"
+ f"Could not find {rendered_output_name} in the model outputs",
+ justify="center",
+ )
+ CONSOLE.print(
+ f"Please set --rendered_output_name to one of: {outputs.keys()}",
+ justify="center",
)
sys.exit(1)
output_image = outputs[rendered_output_name]
@@ -261,10 +269,17 @@ def _render_trajectory_video(
render_image = np.concatenate(render_image, axis=1)
if output_format == "images":
if image_format == "png":
- media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png")
+ media.write_image(
+ output_image_dir / f"{camera_idx:05d}.png",
+ render_image,
+ fmt="png",
+ )
if image_format == "jpeg":
media.write_image(
- output_image_dir / f"{camera_idx:05d}.jpg", render_image, fmt="jpeg", quality=jpeg_quality
+ output_image_dir / f"{camera_idx:05d}.jpg",
+ render_image,
+ fmt="jpeg",
+ quality=jpeg_quality,
)
if output_format == "video":
if writer is None:
@@ -292,7 +307,15 @@ def _render_trajectory_video(
table.add_row("Video", str(output_filename))
else:
table.add_row("Images", str(output_image_dir))
- CONSOLE.print(Panel(table, title="[bold][green]:tada: Render Complete :tada:[/bold]", expand=False))
+ CONSOLE.print(
+ Panel(
+ table,
+ title="[bold][green]:tada: Render Complete :tada:[/bold]",
+ expand=False,
+ )
+ )
+
+ return True
def insert_spherical_metadata_into_file(
@@ -437,6 +460,11 @@ class BaseRender:
"""If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other"""
camera_idx: Optional[int] = None
"""Index of the training camera to render."""
+ _kill_flag: tyro.conf.Suppress[List[bool]] = field(default_factory=lambda: [False])
+ """Stop execution of render if set to True."""
+
+ def kill(self) -> None:
+ self._kill_flag[0] = True
@dataclass
@@ -447,6 +475,8 @@ class RenderCameraPath(BaseRender):
"""Filename of the camera path to render."""
output_format: Literal["images", "video"] = "video"
"""How to save output data."""
+ _complete: tyro.conf.Suppress[bool] = False
+ """Set to True when render is finished."""
def main(self) -> None:
"""Main function."""
@@ -490,7 +520,7 @@ def main(self) -> None:
if self.camera_idx is not None:
camera_path.metadata = {"cam_idx": self.camera_idx}
- _render_trajectory_video(
+ self._complete = _render_trajectory_video(
pipeline,
camera_path,
output_filename=self.output_path,
@@ -506,6 +536,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
+ _kill_flag=self._kill_flag,
)
if (
@@ -541,6 +572,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
+ _kill_flag=self._kill_flag,
)
self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4")
@@ -644,6 +676,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
+ _kill_flag=self._kill_flag,
)
@@ -699,6 +732,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
+ _kill_flag=self._kill_flag,
)
@@ -736,7 +770,10 @@ def main(self):
def update_config(config: TrainerConfig) -> TrainerConfig:
data_manager_config = config.pipeline.datamanager
- assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig))
+ assert isinstance(
+ data_manager_config,
+ (VanillaDataManagerConfig, FullImageDatamanagerConfig),
+ )
data_manager_config.eval_num_images_to_sample_from = -1
data_manager_config.eval_num_times_to_repeat_images = -1
if isinstance(data_manager_config, VanillaDataManagerConfig):
@@ -746,7 +783,11 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
data_manager_config.data = self.data
if self.downscale_factor is not None:
assert hasattr(data_manager_config.dataparser, "downscale_factor")
- setattr(data_manager_config.dataparser, "downscale_factor", self.downscale_factor)
+ setattr(
+ data_manager_config.dataparser,
+ "downscale_factor",
+ self.downscale_factor,
+ )
return config
config, pipeline, _, _ = eval_setup(
@@ -814,10 +855,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
if rendered_output_name not in all_outputs:
CONSOLE.rule("Error", style="red")
CONSOLE.print(
- f"Could not find {rendered_output_name} in the model outputs", justify="center"
+ f"Could not find {rendered_output_name} in the model outputs",
+ justify="center",
)
CONSOLE.print(
- f"Please set --rendered-output-name to one of: {all_outputs}", justify="center"
+ f"Please set --rendered-output-name to one of: {all_outputs}",
+ justify="center",
)
sys.exit(1)
@@ -885,7 +928,10 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
media.write_image(output_path.with_suffix(".png"), output_image, fmt="png")
elif self.image_format == "jpeg":
media.write_image(
- output_path.with_suffix(".jpg"), output_image, fmt="jpeg", quality=self.jpeg_quality
+ output_path.with_suffix(".jpg"),
+ output_image,
+ fmt="jpeg",
+ quality=self.jpeg_quality,
)
else:
raise ValueError(f"Unknown image format {self.image_format}")
@@ -898,7 +944,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
)
for split in self.split.split("+"):
table.add_row(f"Outputs {split}", str(self.output_path / split))
- CONSOLE.print(Panel(table, title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", expand=False))
+ CONSOLE.print(
+ Panel(
+ table,
+ title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]",
+ expand=False,
+ )
+ )
Commands = tyro.conf.FlagConversionOff[
diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py
index 2bb3969cd5..e600fa5a31 100644
--- a/nerfstudio/viewer/export_panel.py
+++ b/nerfstudio/viewer/export_panel.py
@@ -15,10 +15,12 @@
from __future__ import annotations
from pathlib import Path
+from typing import cast
import viser
import viser.transforms as vtf
-from typing_extensions import Literal
+import yaml
+from typing_extensions import Literal, Tuple
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.models.base_model import Model
@@ -40,6 +42,7 @@ def populate_export_tab(
def _(_) -> None:
control_panel.crop_viewport = crop_output.value
+ server.gui.add_markdown("Export available after a checkpoint is saved (default minimum 2000 steps)")
with server.gui.add_folder("Splat"):
populate_splat_tab(server, control_panel, config_path, viewing_gsplat)
with server.gui.add_folder("Point Cloud"):
@@ -48,7 +51,11 @@ def _(_) -> None:
populate_mesh_tab(server, control_panel, config_path, viewing_gsplat)
-def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point cloud", "splat"], command: str) -> None:
+def show_command_modal(
+ client: viser.ClientHandle,
+ what: Literal["mesh", "point cloud", "splat"],
+ command: str,
+) -> None:
"""Show a modal to each currently connected client.
In the future, we should only show the modal to the client that pushes the
@@ -73,7 +80,7 @@ def _(_) -> None:
modal.close()
-def get_crop_string(obb: OrientedBox, crop_viewport: bool):
+def get_crop_string(obb: OrientedBox, crop_viewport: bool) -> str:
"""Takes in an oriented bounding box and returns a string of the form "--obb_{center,rotation,scale}
and each arg formatted with spaces around it
"""
@@ -89,6 +96,24 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool):
return f"--obb_center {posstring} --obb_rotation {rpystring} --obb_scale {scalestring}"
+Vec3f = Tuple[float, float, float]
+
+
+def get_crop_tuple(obb: OrientedBox, crop_viewport: bool):
+ """Takes in an oriented bounding box and returns tuples for obb_{center,rotation,scale}."""
+ if not crop_viewport:
+ return None, None, None
+ rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians()
+ obb_rotation = [rpy.roll, rpy.pitch, rpy.yaw]
+ obb_center = obb.T.squeeze().tolist()
+ obb_scale = obb.S.squeeze().tolist()
+ return (
+ cast(Vec3f, tuple(obb_center)),
+ cast(Vec3f, tuple(obb_rotation)),
+ cast(Vec3f, tuple(obb_scale)),
+ )
+
+
def populate_point_cloud_tab(
server: viser.ViserServer,
control_panel: ControlPanel,
@@ -114,9 +139,75 @@ def populate_point_cloud_tab(
initial_value="open3d",
hint="Normal map source.",
)
+
output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/")
+ export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT)
+ download_button = server.gui.add_button("Download Point Cloud", icon=viser.Icon.DOWNLOAD, disabled=True)
generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2)
+ @export_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
+ if config.load_dir is None:
+ notif = client.add_notification(
+ title="Export unsuccessful",
+ body="Cannot export before 2000 training steps.",
+ loading=False,
+ with_close_button=True,
+ color="red",
+ )
+ return
+
+ notif = client.add_notification(
+ title="Exporting point cloud",
+ body="File will be saved under " + str(output_dir.value),
+ loading=True,
+ with_close_button=False,
+ )
+
+ if control_panel.crop_obb is not None and control_panel.crop_viewport:
+ obb_center, obb_rotation, obb_scale = get_crop_tuple(
+ control_panel.crop_obb, control_panel.crop_viewport
+ )
+ else:
+ obb_center, obb_rotation, obb_scale = None, None, None
+
+ from nerfstudio.scripts.exporter import ExportPointCloud
+
+ export = ExportPointCloud(
+ load_config=config_path,
+ output_dir=Path(output_dir.value),
+ num_points=num_points.value,
+ remove_outliers=remove_outliers.value,
+ normal_method=normals.value,
+ save_world_frame=world_frame.value,
+ obb_center=obb_center,
+ obb_rotation=obb_rotation,
+ obb_scale=obb_scale,
+ )
+ export.main()
+
+ if export._complete:
+ notif.title = "Export complete!"
+ notif.body = "File saved under " + str(output_dir.value)
+ notif.loading = False
+ notif.with_close_button = True
+
+ download_button.disabled = False
+
+ @download_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ with open(str(output_dir.value) + "point_cloud.ply", "rb") as ply_file:
+ ply_bytes = ply_file.read()
+
+ client.send_file_download("point_cloud.ply", ply_bytes)
+
@generate_command.on_click
def _(event: viser.GuiEvent) -> None:
assert event.client is not None
@@ -157,12 +248,78 @@ def populate_mesh_tab(
)
num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1)
texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048)
- output_directory = server.gui.add_text("Output Directory", initial_value="exports/mesh/")
num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1)
remove_outliers = server.gui.add_checkbox("Remove outliers", True)
+ output_dir = server.gui.add_text("Output Directory", initial_value="exports/mesh/")
+ export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT)
+ download_button = server.gui.add_button("Download Mesh", icon=viser.Icon.DOWNLOAD, disabled=True)
generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2)
+ @export_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
+ if config.load_dir is None:
+ notif = client.add_notification(
+ title="Export unsuccessful",
+ body="Cannot export before 2000 training steps.",
+ loading=False,
+ with_close_button=True,
+ color="red",
+ )
+ return
+
+ notif = client.add_notification(
+ title="Exporting poisson mesh",
+ body="File will be saved under " + str(output_dir.value),
+ loading=True,
+ with_close_button=False,
+ )
+
+ if control_panel.crop_obb is not None and control_panel.crop_viewport:
+ obb_center, obb_rotation, obb_scale = get_crop_tuple(
+ control_panel.crop_obb, control_panel.crop_viewport
+ )
+ else:
+ obb_center, obb_rotation, obb_scale = None, None, None
+
+ from nerfstudio.scripts.exporter import ExportPoissonMesh
+
+ export = ExportPoissonMesh(
+ load_config=config_path,
+ output_dir=Path(output_dir.value),
+ target_num_faces=num_faces.value,
+ num_pixels_per_side=texture_resolution.value,
+ num_points=num_points.value,
+ remove_outliers=remove_outliers.value,
+ normal_method=normals.value,
+ obb_center=obb_center,
+ obb_rotation=obb_rotation,
+ obb_scale=obb_scale,
+ )
+ export.main()
+
+ if export._complete:
+ notif.title = "Export complete!"
+ notif.body = "File saved under " + str(output_dir.value)
+ notif.loading = False
+ notif.with_close_button = True
+
+ download_button.disabled = False
+
+ @download_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ with open(str(output_dir.value) + "poisson_mesh.ply", "rb") as ply_file:
+ ply_bytes = ply_file.read()
+
+ client.send_file_download("poisson_mesh.ply", ply_bytes)
+
@generate_command.on_click
def _(event: viser.GuiEvent) -> None:
assert event.client is not None
@@ -170,7 +327,7 @@ def _(event: viser.GuiEvent) -> None:
[
"ns-export poisson",
f"--load-config {config_path}",
- f"--output-dir {output_directory.value}",
+ f"--output-dir {output_dir.value}",
f"--target-num-faces {num_faces.value}",
f"--num-pixels-per-side {texture_resolution.value}",
f"--num-points {num_points.value}",
@@ -193,10 +350,70 @@ def populate_splat_tab(
) -> None:
if viewing_gsplat:
server.gui.add_markdown("Generate ply export of Gaussian Splat")
-
- output_directory = server.gui.add_text("Output Directory", initial_value="exports/splat/")
+ output_dir = server.gui.add_text("Output Directory", initial_value="exports/splat/")
+ export_button = server.gui.add_button("Export", icon=viser.Icon.FILE_EXPORT)
+ download_button = server.gui.add_button("Download Splat", icon=viser.Icon.DOWNLOAD, disabled=True)
generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2)
+ @export_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
+ if config.load_dir is None:
+ notif = client.add_notification(
+ title="Export unsuccessful",
+ body="Cannot export before 2000 training steps.",
+ loading=False,
+ with_close_button=True,
+ color="red",
+ )
+ return
+
+ notif = client.add_notification(
+ title="Exporting gaussian splat",
+ body="File will be saved under " + str(output_dir.value),
+ loading=True,
+ with_close_button=False,
+ )
+
+ if control_panel.crop_obb is not None and control_panel.crop_viewport:
+ obb_center, obb_rotation, obb_scale = get_crop_tuple(
+ control_panel.crop_obb, control_panel.crop_viewport
+ )
+ else:
+ obb_center, obb_rotation, obb_scale = None, None, None
+
+ from nerfstudio.scripts.exporter import ExportGaussianSplat
+
+ export = ExportGaussianSplat(
+ load_config=config_path,
+ output_dir=Path(output_dir.value),
+ obb_center=obb_center,
+ obb_rotation=obb_rotation,
+ obb_scale=obb_scale,
+ )
+ export.main()
+
+ if export._complete:
+ notif.title = "Export complete!"
+ notif.body = "File saved under " + str(output_dir.value)
+ notif.loading = False
+ notif.with_close_button = True
+
+ download_button.disabled = False
+
+ @download_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ with open(str(output_dir.value) + "splat.ply", "rb") as ply_file:
+ ply_bytes = ply_file.read()
+
+ client.send_file_download("splat.ply", ply_bytes)
+
@generate_command.on_click
def _(event: viser.GuiEvent) -> None:
assert event.client is not None
@@ -204,7 +421,7 @@ def _(event: viser.GuiEvent) -> None:
[
"ns-export gaussian-splat",
f"--load-config {config_path}",
- f"--output-dir {output_directory.value}",
+ f"--output-dir {output_dir.value}",
get_crop_string(control_panel.crop_obb, control_panel.crop_viewport),
]
)
diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py
index 10d263f8c2..b375b3ce73 100644
--- a/nerfstudio/viewer/render_panel.py
+++ b/nerfstudio/viewer/render_panel.py
@@ -28,6 +28,7 @@
import splines.quaternion
import viser
import viser.transforms as tf
+import yaml
from scipy import interpolate
from nerfstudio.viewer.control_panel import ControlPanel
@@ -62,7 +63,10 @@ def from_camera(camera: viser.CameraHandle, aspect: float) -> Keyframe:
class CameraPath:
def __init__(
- self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float], time_enabled: bool = False
+ self,
+ server: viser.ViserServer,
+ duration_element: viser.GuiInputHandle[float],
+ time_enabled: bool = False,
):
self._server = server
self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {}
@@ -609,6 +613,10 @@ def _(event: viser.GuiEvent) -> None:
duration_number.value = camera_path.compute_duration()
camera_path.update_spline()
+ # Enable render only if there are two or more keyframes
+ render_button.disabled = len(camera_path._keyframes) < 2
+ generate_command_render_button.disabled = len(camera_path._keyframes) < 2
+
clear_keyframes_button = server.gui.add_button(
"Clear Keyframes",
icon=viser.Icon.TRASH,
@@ -629,6 +637,9 @@ def _(_) -> None:
camera_path.reset()
modal.close()
+ render_button.disabled = True
+ generate_command_render_button.disabled = True
+
duration_number.value = camera_path.compute_duration()
# Clear move handles.
@@ -840,6 +851,7 @@ def _(_) -> None:
position=pose.translation(),
color=(10, 200, 30),
)
+
if render_tab_state.preview_render:
for client in server.get_clients().values():
client.camera.wxyz = pose.rotation().wxyz
@@ -949,6 +961,7 @@ def _(_) -> None:
@load_camera_path_button.on_click
def _(event: viser.GuiEvent) -> None:
assert event.client is not None
+
camera_path_dir = datapath / "camera_paths"
camera_path_dir.mkdir(parents=True, exist_ok=True)
preexisting_camera_paths = list(camera_path_dir.glob("*.json"))
@@ -978,11 +991,13 @@ def _(_) -> None:
for i in range(len(keyframes)):
frame = keyframes[i]
pose = tf.SE3.from_matrix(np.array(frame["matrix"]).reshape(4, 4))
+
# apply the x rotation by 180 deg
pose = tf.SE3.from_rotation_and_translation(
pose.rotation() @ tf.SO3.from_x_radians(np.pi),
pose.translation(),
)
+
camera_path.add_camera(
Keyframe(
position=pose.translation() * VISER_NERFSTUDIO_SCALE_RATIO,
@@ -1004,6 +1019,10 @@ def _(_) -> None:
# update the render name
render_name_text.value = json_path.stem
camera_path.update_spline()
+
+ if len(camera_path._keyframes) > 1:
+ render_button.disabled = False
+
modal.close()
cancel_button = event.client.gui.add_button("Cancel")
@@ -1019,30 +1038,45 @@ def _(_) -> None:
initial_value=now.strftime("%Y-%m-%d-%H-%M-%S"),
hint="Name of the render",
)
- render_button = server.gui.add_button(
- "Generate Command",
- color="green",
- icon=viser.Icon.FILE_EXPORT,
- hint="Generate the ns-render command for rendering the camera path.",
- )
- reset_up_button = server.gui.add_button(
- "Reset Up Direction",
- icon=viser.Icon.ARROW_BIG_UP_LINES,
- color="gray",
- hint="Set the up direction of the camera orbit controls to the camera's current up direction.",
- )
+ render_folder = server.gui.add_folder("Render")
+ with render_folder:
+ server.gui.add_markdown(
+ "Render available after a checkpoint is saved (default minimum 2000 steps)"
+ )
- @reset_up_button.on_click
- def _(event: viser.GuiEvent) -> None:
- assert event.client is not None
- event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array([0.0, -1.0, 0.0])
+ render_button = server.gui.add_button(
+ "Render",
+ icon=viser.Icon.VIDEO,
+ hint="Render the camera path and save video as mp4 file.",
+ disabled=True,
+ )
- @render_button.on_click
- def _(event: viser.GuiEvent) -> None:
- assert event.client is not None
+ cancel_render_button = server.gui.add_button(
+ "Cancel Render",
+ icon=viser.Icon.CIRCLE_X,
+ hint="Cancel current render in progress.",
+ disabled=True,
+ )
+
+ download_render_button = server.gui.add_button(
+ "Download Render",
+ icon=viser.Icon.DOWNLOAD,
+ hint="Download the latest render locally as mp4 file.",
+ disabled=True,
+ )
+
+ generate_command_render_button = server.gui.add_button(
+ "Generate Command",
+ icon=viser.Icon.FILE_EXPORT,
+ hint="Generate the ns-render command for rendering the camera path instead of directly rendering.",
+ disabled=True,
+ )
+
+ def _write_json() -> Path:
num_frames = int(framerate_number.value * duration_number.value)
json_data = {}
+
# json data has the properties:
# keyframes: list of keyframes with
# matrix : flattened 4x4 matrix
@@ -1059,13 +1093,15 @@ def _(event: viser.GuiEvent) -> None:
# camera_to_world: flattened 4x4 matrix
# fov: float in degrees
# aspect: float
+
# first populate the keyframes:
keyframes = []
- for keyframe, dummy in camera_path._keyframes.values():
+ for keyframe, _ in camera_path._keyframes.values():
pose = tf.SE3.from_rotation_and_translation(
tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi),
keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO,
)
+
keyframe_dict = {
"matrix": pose.as_matrix().flatten().tolist(),
"fov": np.rad2deg(keyframe.override_fov_rad) if keyframe.override_fov_enabled else fov_degrees.value,
@@ -1073,15 +1109,20 @@ def _(event: viser.GuiEvent) -> None:
"override_transition_enabled": keyframe.override_transition_enabled,
"override_transition_sec": keyframe.override_transition_sec,
}
+
if render_time is not None:
keyframe_dict["render_time"] = (
keyframe.override_time_val if keyframe.override_time_enabled else render_time.value
)
keyframe_dict["override_time_enabled"] = keyframe.override_time_enabled
+
keyframes.append(keyframe_dict)
+
json_data["default_fov"] = fov_degrees.value
+
if render_time is not None:
json_data["default_time"] = render_time.value if render_time is not None else None
+
json_data["default_transition_sec"] = transition_sec_number.value
json_data["keyframes"] = keyframes
json_data["camera_type"] = camera_type.value.lower()
@@ -1091,31 +1132,36 @@ def _(event: viser.GuiEvent) -> None:
json_data["seconds"] = duration_number.value
json_data["is_cycle"] = loop.value
json_data["smoothness_value"] = tension_slider.value
+
# now populate the camera path:
camera_path_list = []
for i in range(num_frames):
maybe_pose_and_fov = camera_path.interpolate_pose_and_fov_rad(i / num_frames)
if maybe_pose_and_fov is None:
- return
+ return Path()
time = None
if len(maybe_pose_and_fov) == 3: # Time is enabled.
pose, fov, time = maybe_pose_and_fov
else:
pose, fov = maybe_pose_and_fov
+
# rotate the axis of the camera 180 about x axis
pose = tf.SE3.from_rotation_and_translation(
pose.rotation() @ tf.SO3.from_x_radians(np.pi),
pose.translation() / VISER_NERFSTUDIO_SCALE_RATIO,
)
+
camera_path_list_dict = {
"camera_to_world": pose.as_matrix().flatten().tolist(),
"fov": np.rad2deg(fov),
"aspect": resolution.value[0] / resolution.value[1],
}
+
if time is not None:
camera_path_list_dict["render_time"] = time
camera_path_list.append(camera_path_list_dict)
json_data["camera_path"] = camera_path_list
+
# finally add crop data if crop is enabled
if control_panel is not None:
if control_panel.crop_viewport:
@@ -1134,18 +1180,100 @@ def _(event: viser.GuiEvent) -> None:
json_outfile.parent.mkdir(parents=True, exist_ok=True)
with open(json_outfile.absolute(), "w") as outfile:
json.dump(json_data, outfile)
- # now show the command
- with event.client.gui.add_modal("Render Command") as modal:
+
+ return json_outfile.absolute()
+
+ @render_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
+ if config.load_dir is None:
+ render_button.disabled = True
+
+ notif = client.add_notification(
+ title="Render unsuccessful",
+ body="Cannot render video before 2000 training steps.",
+ loading=False,
+ with_close_button=True,
+ color="red",
+ )
+ return
+
+ render_button.disabled = True
+ cancel_render_button.disabled = False
+
+ render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4"
+ json_outfile = _write_json()
+
+ notif = client.add_notification(
+ title="Rendering trajectory",
+ body="Saving rendered video as " + render_path,
+ loading=True,
+ with_close_button=False,
+ )
+
+ from nerfstudio.scripts.render import RenderCameraPath
+
+ render = RenderCameraPath(
+ load_config=config_path,
+ camera_path_filename=json_outfile,
+ output_path=Path(render_path),
+ )
+
+ @cancel_render_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ render.kill()
+ render_button.disabled = False
+ cancel_render_button.disabled = True
+
+ notif.title = "Render canceled"
+ notif.body = "The render in progress has been canceled."
+ notif.loading = False
+ notif.with_close_button = True
+
+ render.main()
+
+ if render._complete:
+ notif.title = "Render complete!"
+ notif.body = "Video saved as " + render_path
+ notif.loading = False
+ notif.with_close_button = True
+
+ render_button.disabled = False
+ download_render_button.disabled = False
+
+ @download_render_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ render_path = f"renders/{datapath.name}/{render_name_text.value}.mp4"
+
+ client = event.client
+ assert client is not None
+
+ with open(render_path, "rb") as file:
+ video_bytes = file.read()
+
+ client.send_file_download("render.mp4", video_bytes)
+
+ @generate_command_render_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+
+ json_outfile = _write_json()
+
+ with client.gui.add_modal("Render Command") as modal:
dataname = datapath.name
command = " ".join(
[
"ns-render camera-path",
f"--load-config {config_path}",
- f"--camera-path-filename {json_outfile.absolute()}",
+ f"--camera-path-filename {json_outfile}",
f"--output-path renders/{dataname}/{render_name_text.value}.mp4",
]
)
- event.client.gui.add_markdown(
+ client.gui.add_markdown(
"\n".join(
[
"To render the trajectory, run the following from the command line:",
@@ -1156,16 +1284,30 @@ def _(event: viser.GuiEvent) -> None:
]
)
)
- close_button = event.client.gui.add_button("Close")
+ close_button = client.gui.add_button("Close")
@close_button.on_click
def _(_) -> None:
modal.close()
+ reset_up_button = server.gui.add_button(
+ "Reset Up Direction",
+ icon=viser.Icon.ARROW_BIG_UP_LINES,
+ color="gray",
+ hint="Set the up direction of the camera orbit controls to the camera's current up direction.",
+ )
+
+ @reset_up_button.on_click
+ def _(event: viser.GuiEvent) -> None:
+ client = event.client
+ assert client is not None
+ client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array([0.0, -1.0, 0.0])
+
if control_panel is not None:
camera_path = CameraPath(server, duration_number, control_panel._time_enabled)
else:
camera_path = CameraPath(server, duration_number)
+
camera_path.tension = tension_slider.value
camera_path.default_fov = fov_degrees.value / 180.0 * np.pi
camera_path.default_transition_sec = transition_sec_number.value