From 07d9188590ae1ee1e9ff7794a4217cc4591b44cf Mon Sep 17 00:00:00 2001 From: Hang Date: Fri, 7 Jun 2024 17:07:51 -0700 Subject: [PATCH] Update 3DGS examples to use the latest nerfview from pypi (#206) --- examples/requirements.txt | 4 ++-- examples/simple_trainer.py | 37 ++++++++++++++++++------------------- examples/simple_viewer.py | 26 +++++++++++--------------- 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/examples/requirements.txt b/examples/requirements.txt index 534414028..695e8b5ae 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -2,11 +2,11 @@ # pycolmap for data parsing git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e -# nerfview for viewer -git+https://github.com/hangg7/nerfview@4dde5291debd21ba33d768d9a8193aca87fc38fd # (optional) nerfacc for torch version rasterization # git+https://github.com/nerfstudio-project/nerfacc +viser +nerfview==0.0.2 imageio[ffmpeg] numpy scikit-learn diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index cfd2b6138..e2e6cf51b 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -11,9 +11,10 @@ import torch.nn.functional as F import tqdm import tyro +import viser +import nerfview from datasets.colmap import Dataset, Parser from datasets.traj import generate_interpolated_path -from nerfview import VIEWER_LOCK, CameraState, ViewerServer from torch import Tensor from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure @@ -305,7 +306,12 @@ def __init__(self, cfg: Config) -> None: # Viewer if not self.cfg.disable_viewer: - self.server = ViewerServer(port=cfg.port, render_fn=self._viewer_render_fn) + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) # Running stats for prunning & growing. n_gauss = len(self.splats["means3d"]) @@ -401,9 +407,9 @@ def train(self): pbar = tqdm.tqdm(range(init_step, max_steps)) for step in pbar: if not cfg.disable_viewer: - while self.server.state.status == "paused": + while self.viewer.state.status == "paused": time.sleep(0.01) - VIEWER_LOCK.acquire() + self.viewer.lock.acquire() tic = time.time() try: @@ -624,15 +630,15 @@ def train(self): self.render_traj(step) if not cfg.disable_viewer: - VIEWER_LOCK.release() + self.viewer.lock.release() num_train_steps_per_sec = 1.0 / (time.time() - tic) num_train_rays_per_sec = ( num_train_rays_per_step * num_train_steps_per_sec ) # Update the viewer state. - self.server.state.num_train_rays_per_sec = num_train_rays_per_sec + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec # Update the scene. - self.server.update(step, num_train_rays_per_step) + self.viewer.update(step, num_train_rays_per_step) @torch.no_grad() def update_running_stats(self, info: Dict): @@ -909,20 +915,13 @@ def render_traj(self, step: int): print(f"Video saved to {video_dir}/traj_{step}.mp4") @torch.no_grad() - def _viewer_render_fn(self, camera_state: CameraState, img_wh: Tuple[int, int]): + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): """Callable function for the viewer.""" - fov = camera_state.fov - c2w = camera_state.c2w W, H = img_wh - - focal_length = H / 2.0 / np.tan(fov / 2.0) - K = np.array( - [ - [focal_length, 0.0, W / 2.0], - [0.0, focal_length, H / 2.0], - [0.0, 0.0, 1.0], - ] - ) + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) c2w = torch.from_numpy(c2w).float().to(self.device) K = torch.from_numpy(K).float().to(self.device) diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 5a7ac9c71..84dd6ba56 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -12,11 +12,11 @@ from typing import Tuple import imageio +import nerfview import numpy as np import torch import torch.nn.functional as F -from nerfview import CameraState, ViewerServer - +import viser from gsplat._helper import load_test_data from gsplat.rendering import rasterization @@ -135,19 +135,10 @@ # register and open viewer @torch.no_grad() -def viewer_render_fn(camera_state: CameraState, img_wh: Tuple[int, int]): - fov = camera_state.fov - c2w = camera_state.c2w +def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]): width, height = img_wh - - focal_length = height / 2.0 / np.tan(fov / 2.0) - K = np.array( - [ - [focal_length, 0.0, width / 2.0], - [0.0, focal_length, height / 2.0], - [0.0, 0.0, 1.0], - ] - ) + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) c2w = torch.from_numpy(c2w).float().to(device) K = torch.from_numpy(K).float().to(device) viewmat = c2w.inverse() @@ -184,6 +175,11 @@ def viewer_render_fn(camera_state: CameraState, img_wh: Tuple[int, int]): return render_rgbs -server = ViewerServer(port=args.port, render_fn=viewer_render_fn, mode="rendering") +server = viser.ViserServer(port=args.port, verbose=False) +_ = nerfview.Viewer( + server=server, + render_fn=viewer_render_fn, + mode="rendering", +) print("Viewer running... Ctrl+C to exit.") time.sleep(100000)