Skip to content

Commit

Permalink
Add option to render camera trajectory as an ellipse. (#380)
Browse files Browse the repository at this point in the history
* ellipse

* cleanup

* uncomment

* canvas list
  • Loading branch information
jefequien authored Sep 3, 2024
1 parent 9c6e591 commit 57c77b9
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 22 deletions.
3 changes: 3 additions & 0 deletions examples/benchmarks/basic.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
SCENE_DIR="data/360_v2"
RESULT_DIR="results/benchmark"
SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers
RENDER_TRAJ_PATH="ellipse"

for SCENE in $SCENE_LIST;
do
Expand All @@ -14,13 +15,15 @@ do

# train without eval
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/

# run eval and render
for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --disable_viewer --data_factor $DATA_FACTOR \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/ \
--ckpt $CKPT
Expand Down
3 changes: 3 additions & 0 deletions examples/benchmarks/mcmc.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
SCENE_DIR="data/360_v2"
RESULT_DIR="results/benchmark_mcmc_1M"
SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers
RENDER_TRAJ_PATH="ellipse"

CAP_MAX=1000000

Expand All @@ -17,6 +18,7 @@ do
# train without eval
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/

Expand All @@ -25,6 +27,7 @@ do
do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/ \
--ckpt $CKPT
Expand Down
2 changes: 1 addition & 1 deletion examples/datasets/traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_positions(theta):
ind_up = np.argmax(np.abs(avg_up))
up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])

return np.stack([viewmatrix(p - center, up, p) for p in positions])
return np.stack([viewmatrix(center - p, up, p) for p in positions])


def generate_ellipse_path_y(
Expand Down
64 changes: 43 additions & 21 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import viser
import yaml
from datasets.colmap import Dataset, Parser
from datasets.traj import generate_interpolated_path
from datasets.traj import generate_interpolated_path, generate_ellipse_path_z
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
Expand All @@ -38,6 +38,8 @@ class Config:
ckpt: Optional[List[str]] = None
# Name of compression strategy to use
compression: Optional[Literal["png"]] = None
# Render trajectory path
render_traj_path: str = "interp"

# Path to the Mip-NeRF 360 dataset
data_dir: str = "data/360_v2/garden"
Expand Down Expand Up @@ -749,16 +751,19 @@ def eval(self, step: int, stage: str = "val"):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
) # [1, H, W, 3]
colors = torch.clamp(colors, 0.0, 1.0)
torch.cuda.synchronize()
ellipse_time += time.time() - tic

colors = torch.clamp(colors, 0.0, 1.0)
canvas_list = [pixels, colors]

if world_rank == 0:
# write images
canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy()
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
canvas = (canvas * 255).astype(np.uint8)
imageio.imwrite(
f"{self.render_dir}/{stage}_{i:04d}.png",
(canvas * 255).astype(np.uint8),
f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
canvas,
)

pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W]
Expand Down Expand Up @@ -800,41 +805,58 @@ def render_traj(self, step: int):
cfg = self.cfg
device = self.device

camtoworlds = self.parser.camtoworlds[5:-5]
camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4]
camtoworlds = np.concatenate(
camtoworlds_all = self.parser.camtoworlds[5:-5]
if cfg.render_traj_path == "interp":
camtoworlds_all = generate_interpolated_path(
camtoworlds_all, 1
) # [N, 3, 4]
elif cfg.render_traj_path == "ellipse":
height = camtoworlds_all[:, 2, 3].mean()
camtoworlds_all = generate_ellipse_path_z(
camtoworlds_all, height=height
) # [N, 3, 4]
else:
raise ValueError(
f"Render trajectory type not supported: {cfg.render_traj_path}"
)

camtoworlds_all = np.concatenate(
[
camtoworlds,
np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0),
camtoworlds_all,
np.repeat(
np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
),
],
axis=1,
) # [N, 4, 4]

camtoworlds = torch.from_numpy(camtoworlds).float().to(device)
camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
width, height = list(self.parser.imsize_dict.values())[0]

canvas_all = []
for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"):
for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
camtoworlds = camtoworlds_all[i : i + 1]
Ks = K[None]

renders, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds[i : i + 1],
Ks=K[None],
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
render_mode="RGB+ED",
) # [1, H, W, 4]
colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
depths = renders[0, ..., 3:4] # [H, W, 1]
colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3]
depths = renders[..., 3:4] # [1, H, W, 1]
depths = (depths - depths.min()) / (depths.max() - depths.min())
canvas_list = [colors, depths.repeat(1, 1, 1, 3)]

# write images
canvas = torch.cat(
[colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1
)
canvas = (canvas.cpu().numpy() * 255).astype(np.uint8)
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
canvas = (canvas * 255).astype(np.uint8)
canvas_all.append(canvas)

# save to video
Expand Down Expand Up @@ -903,7 +925,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
step = ckpts[0]["step"]
runner.eval(step=step)
# runner.render_traj(step=step)
runner.render_traj(step=step)
if cfg.compression is not None:
runner.run_compression(step=step)
else:
Expand Down

0 comments on commit 57c77b9

Please sign in to comment.