Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Jan 3, 2024
1 parent 88a6aec commit 3eef920
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 20 deletions.
16 changes: 13 additions & 3 deletions examples/datasets/nerf_360_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def preprocess(self, data):

if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
Expand Down Expand Up @@ -311,10 +313,18 @@ def fetch_data(self, index):
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.width,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.height,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index]
Expand Down
16 changes: 13 additions & 3 deletions examples/datasets/nerf_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def preprocess(self, data):

if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
Expand Down Expand Up @@ -179,10 +181,18 @@ def fetch_data(self, index):
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.WIDTH,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.HEIGHT,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index]
Expand Down
12 changes: 9 additions & 3 deletions examples/train_ngp_nerf_occ.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from nerfacc.estimators.occ_grid import OccGridEstimator


def run(args):
device = "cuda:0"
set_random_seed(42)
Expand Down Expand Up @@ -102,7 +103,10 @@ def run(args):
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
radiance_field.parameters(),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
Expand Down Expand Up @@ -167,7 +171,8 @@ def occ_eval_fn(x):
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)

Expand Down Expand Up @@ -249,6 +254,7 @@ def occ_eval_fn(x):
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -274,4 +280,4 @@ def occ_eval_fn(x):
)
args = parser.parse_args()

run(args)
run(args)
35 changes: 24 additions & 11 deletions nerfacc/estimators/n3tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor

from ..grid import _enlarge_aabb
from ..volrend import (
render_visibility_from_alpha,
render_visibility_from_density,
)
from .base import AbstractEstimator
from torch import Tensor

try:
import svox
Expand All @@ -21,7 +22,7 @@

class N3TreeEstimator(AbstractEstimator):
"""Use N3Tree to implement Occupancy Grid.
This allows more flexible topologies than the cascaded grid. However, it is
slower to create samples from the tree than the cascaded grid. By default,
it has the same topology as the cascaded grid but `self.tree` can be
Expand All @@ -43,7 +44,9 @@ def __init__(
)

# check the resolution is legal
assert isinstance(resolution, int), "N3Tree only supports uniform resolution!"
assert isinstance(
resolution, int
), "N3Tree only supports uniform resolution!"

# check the roi_aabb is legal
if isinstance(roi_aabb, (list, tuple)):
Expand Down Expand Up @@ -148,16 +151,18 @@ def sampling(
"""

assert t_min is None and t_max is None, (
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
)
assert (
t_min is None and t_max is None
), "Do not supported per-ray min max. Please use near_plane and far_plane instead."
if stratified:
near_plane += torch.rand(()).item() * render_step_size

t_starts, t_ends, packed_info, ray_indices = svox.volume_sample(
self.tree,
thresh=self.thresh,
rays=svox.Rays(rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()),
rays=svox.Rays(
rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()
),
step_size=render_step_size,
cone_angle=cone_angle,
near_plane=near_plane,
Expand Down Expand Up @@ -253,10 +258,16 @@ def update_every_n_steps(
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
"""Samples both n uniform and occupied cells."""
uniform_indices = torch.randint(len(self.tree), (n,), device=self.device)
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0]
uniform_indices = torch.randint(
len(self.tree), (n,), device=self.device
)
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[
:, 0
]
if n < len(occupied_indices):
selector = torch.randint(len(occupied_indices), (n,), device=self.device)
selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
return indices
Expand All @@ -275,7 +286,9 @@ def _update(
x = self.tree.sample(1).squeeze(1)
occ = occ_eval_fn(x).squeeze(-1)
sel = (*self.tree._all_leaves().T,)
self.tree.data.data[sel] = torch.maximum(self.tree.data.data[sel] * ema_decay, occ[:, None])
self.tree.data.data[sel] = torch.maximum(
self.tree.data.data[sel] * ema_decay, occ[:, None]
)
else:
N = len(self.tree) // 4
indices = self._sample_uniform_and_occupied_cells(N)
Expand Down

0 comments on commit 3eef920

Please sign in to comment.