diff --git a/examples/benchmarks/basic.sh b/examples/benchmarks/basic.sh
index 6a0986aa3..6b043c567 100644
--- a/examples/benchmarks/basic.sh
+++ b/examples/benchmarks/basic.sh
@@ -1,6 +1,7 @@
 SCENE_DIR="data/360_v2"
 RESULT_DIR="results/benchmark"
 SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers
+RENDER_TRAJ_PATH="ellipse"
 
 for SCENE in $SCENE_LIST;
 do
@@ -14,6 +15,7 @@ do
 
     # train without eval
     CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
+        --render_traj_path $RENDER_TRAJ_PATH \
         --data_dir data/360_v2/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE/
 
@@ -21,6 +23,7 @@ do
     for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
     do
         CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --disable_viewer --data_factor $DATA_FACTOR \
+            --render_traj_path $RENDER_TRAJ_PATH \
             --data_dir data/360_v2/$SCENE/ \
             --result_dir $RESULT_DIR/$SCENE/ \
             --ckpt $CKPT
diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh
index 4b0add753..23e40838d 100644
--- a/examples/benchmarks/mcmc.sh
+++ b/examples/benchmarks/mcmc.sh
@@ -1,6 +1,7 @@
 SCENE_DIR="data/360_v2"
 RESULT_DIR="results/benchmark_mcmc_1M"
 SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers
+RENDER_TRAJ_PATH="ellipse"
 
 CAP_MAX=1000000
 
@@ -17,6 +18,7 @@ do
     # train without eval
     CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
         --strategy.cap-max $CAP_MAX \
+        --render_traj_path $RENDER_TRAJ_PATH \
         --data_dir data/360_v2/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE/
 
@@ -25,6 +27,7 @@ do
     do
         CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
             --strategy.cap-max $CAP_MAX \
+            --render_traj_path $RENDER_TRAJ_PATH \
             --data_dir $SCENE_DIR/$SCENE/ \
             --result_dir $RESULT_DIR/$SCENE/ \
             --ckpt $CKPT
diff --git a/examples/datasets/traj.py b/examples/datasets/traj.py
index 8d49aa711..8fcc981b2 100644
--- a/examples/datasets/traj.py
+++ b/examples/datasets/traj.py
@@ -90,7 +90,7 @@ def get_positions(theta):
     ind_up = np.argmax(np.abs(avg_up))
     up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
 
-    return np.stack([viewmatrix(p - center, up, p) for p in positions])
+    return np.stack([viewmatrix(center - p, up, p) for p in positions])
 
 
 def generate_ellipse_path_y(
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index a6c8f65dc..ad4b3d353 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -15,7 +15,7 @@
 import viser
 import yaml
 from datasets.colmap import Dataset, Parser
-from datasets.traj import generate_interpolated_path
+from datasets.traj import generate_interpolated_path, generate_ellipse_path_z
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -38,6 +38,8 @@ class Config:
     ckpt: Optional[List[str]] = None
     # Name of compression strategy to use
     compression: Optional[Literal["png"]] = None
+    # Render trajectory path
+    render_traj_path: str = "interp"
 
     # Path to the Mip-NeRF 360 dataset
     data_dir: str = "data/360_v2/garden"
@@ -749,16 +751,19 @@ def eval(self, step: int, stage: str = "val"):
                 near_plane=cfg.near_plane,
                 far_plane=cfg.far_plane,
             )  # [1, H, W, 3]
-            colors = torch.clamp(colors, 0.0, 1.0)
             torch.cuda.synchronize()
             ellipse_time += time.time() - tic
 
+            colors = torch.clamp(colors, 0.0, 1.0)
+            canvas_list = [pixels, colors]
+
             if world_rank == 0:
                 # write images
-                canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy()
+                canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+                canvas = (canvas * 255).astype(np.uint8)
                 imageio.imwrite(
-                    f"{self.render_dir}/{stage}_{i:04d}.png",
-                    (canvas * 255).astype(np.uint8),
+                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
+                    canvas,
                 )
 
             pixels = pixels.permute(0, 3, 1, 2)  # [1, 3, H, W]
@@ -800,25 +805,43 @@ def render_traj(self, step: int):
         cfg = self.cfg
         device = self.device
 
-        camtoworlds = self.parser.camtoworlds[5:-5]
-        camtoworlds = generate_interpolated_path(camtoworlds, 1)  # [N, 3, 4]
-        camtoworlds = np.concatenate(
+        camtoworlds_all = self.parser.camtoworlds[5:-5]
+        if cfg.render_traj_path == "interp":
+            camtoworlds_all = generate_interpolated_path(
+                camtoworlds_all, 1
+            )  # [N, 3, 4]
+        elif cfg.render_traj_path == "ellipse":
+            height = camtoworlds_all[:, 2, 3].mean()
+            camtoworlds_all = generate_ellipse_path_z(
+                camtoworlds_all, height=height
+            )  # [N, 3, 4]
+        else:
+            raise ValueError(
+                f"Render trajectory type not supported: {cfg.render_traj_path}"
+            )
+
+        camtoworlds_all = np.concatenate(
             [
-                camtoworlds,
-                np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0),
+                camtoworlds_all,
+                np.repeat(
+                    np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
+                ),
             ],
             axis=1,
         )  # [N, 4, 4]
 
-        camtoworlds = torch.from_numpy(camtoworlds).float().to(device)
+        camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
         K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
         width, height = list(self.parser.imsize_dict.values())[0]
 
         canvas_all = []
-        for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"):
+        for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
+            camtoworlds = camtoworlds_all[i : i + 1]
+            Ks = K[None]
+
             renders, _, _ = self.rasterize_splats(
-                camtoworlds=camtoworlds[i : i + 1],
-                Ks=K[None],
+                camtoworlds=camtoworlds,
+                Ks=Ks,
                 width=width,
                 height=height,
                 sh_degree=cfg.sh_degree,
@@ -826,15 +849,14 @@ def render_traj(self, step: int):
                 far_plane=cfg.far_plane,
                 render_mode="RGB+ED",
             )  # [1, H, W, 4]
-            colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0)  # [H, W, 3]
-            depths = renders[0, ..., 3:4]  # [H, W, 1]
+            colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)  # [1, H, W, 3]
+            depths = renders[..., 3:4]  # [1, H, W, 1]
             depths = (depths - depths.min()) / (depths.max() - depths.min())
+            canvas_list = [colors, depths.repeat(1, 1, 1, 3)]
 
             # write images
-            canvas = torch.cat(
-                [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1
-            )
-            canvas = (canvas.cpu().numpy() * 255).astype(np.uint8)
+            canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+            canvas = (canvas * 255).astype(np.uint8)
             canvas_all.append(canvas)
 
         # save to video
@@ -903,7 +925,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
             runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
         step = ckpts[0]["step"]
         runner.eval(step=step)
-        # runner.render_traj(step=step)
+        runner.render_traj(step=step)
         if cfg.compression is not None:
            runner.run_compression(step=step)
     else:
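
Note on the examples/datasets/traj.py hunk: the sign flip is the substantive fix. viewmatrix(lookdir, up, position) builds a camera-to-world matrix from a viewing direction, so a camera at position p orbiting a scene centered at center must look along center - p (toward the scene), not p - center (away from it); the old sign left every camera on the ellipse facing outward. Below is a minimal sketch of the geometry, assuming viewmatrix follows the usual look-at construction in NeRF-style codebases (the real helper lives in examples/datasets/traj.py and is not shown in this diff, so treat the axis conventions here as illustrative):

    import numpy as np

    def normalize(v):
        return v / np.linalg.norm(v)

    def viewmatrix(lookdir, up, position):
        # Look-at: forward axis along lookdir; right and up axes from cross products.
        vec2 = normalize(lookdir)
        vec0 = normalize(np.cross(up, vec2))
        vec1 = normalize(np.cross(vec2, vec0))
        return np.stack([vec0, vec1, vec2, position], axis=1)  # [3, 4] camera-to-world

    # A camera at p that should frame the scene center must look along center - p.
    p = np.array([3.0, 0.0, 1.0])
    center = np.zeros(3)
    c2w = viewmatrix(center - p, up=np.array([0.0, 0.0, 1.0]), position=p)

With the patch applied, --render_traj_path defaults to "interp" (the previous interpolated behavior), while the benchmark scripts opt into the new ellipse trajectory via RENDER_TRAJ_PATH="ellipse"; any other value raises the ValueError added in render_traj.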