From 3eef9208b0567c8531b69a55a1812aaf6478536f Mon Sep 17 00:00:00 2001 From: Ruilong Li <397653553@qq.com> Date: Tue, 2 Jan 2024 17:40:11 -0800 Subject: [PATCH] format --- examples/datasets/nerf_360_v2.py | 16 ++++++++++--- examples/datasets/nerf_synthetic.py | 16 ++++++++++--- examples/train_ngp_nerf_occ.py | 12 +++++++--- nerfacc/estimators/n3tree.py | 35 ++++++++++++++++++++--------- 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/examples/datasets/nerf_360_v2.py b/examples/datasets/nerf_360_v2.py index c9f3035..d11ff67 100644 --- a/examples/datasets/nerf_360_v2.py +++ b/examples/datasets/nerf_360_v2.py @@ -276,7 +276,9 @@ def preprocess(self, data): if self.training: if self.color_bkgd_aug == "random": - color_bkgd = torch.rand(3, device=self.images.device, generator=self.g) + color_bkgd = torch.rand( + 3, device=self.images.device, generator=self.g + ) elif self.color_bkgd_aug == "white": color_bkgd = torch.ones(3, device=self.images.device) elif self.color_bkgd_aug == "black": @@ -311,10 +313,18 @@ def fetch_data(self, index): else: image_id = [index] * num_rays x = torch.randint( - 0, self.width, size=(num_rays,), device=self.images.device, generator=self.g + 0, + self.width, + size=(num_rays,), + device=self.images.device, + generator=self.g, ) y = torch.randint( - 0, self.height, size=(num_rays,), device=self.images.device, generator=self.g + 0, + self.height, + size=(num_rays,), + device=self.images.device, + generator=self.g, ) else: image_id = [index] diff --git a/examples/datasets/nerf_synthetic.py b/examples/datasets/nerf_synthetic.py index 5d2c4dc..bf0769b 100644 --- a/examples/datasets/nerf_synthetic.py +++ b/examples/datasets/nerf_synthetic.py @@ -143,7 +143,9 @@ def preprocess(self, data): if self.training: if self.color_bkgd_aug == "random": - color_bkgd = torch.rand(3, device=self.images.device, generator=self.g) + color_bkgd = torch.rand( + 3, device=self.images.device, generator=self.g + ) elif self.color_bkgd_aug == "white": color_bkgd = torch.ones(3, device=self.images.device) elif self.color_bkgd_aug == "black": @@ -179,10 +181,18 @@ def fetch_data(self, index): else: image_id = [index] * num_rays x = torch.randint( - 0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g + 0, + self.WIDTH, + size=(num_rays,), + device=self.images.device, + generator=self.g, ) y = torch.randint( - 0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g + 0, + self.HEIGHT, + size=(num_rays,), + device=self.images.device, + generator=self.g, ) else: image_id = [index] diff --git a/examples/train_ngp_nerf_occ.py b/examples/train_ngp_nerf_occ.py index 5dfaa0d..026c9e3 100644 --- a/examples/train_ngp_nerf_occ.py +++ b/examples/train_ngp_nerf_occ.py @@ -24,6 +24,7 @@ ) from nerfacc.estimators.occ_grid import OccGridEstimator + def run(args): device = "cuda:0" set_random_seed(42) @@ -102,7 +103,10 @@ def run(args): grad_scaler = torch.cuda.amp.GradScaler(2**10) radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device) optimizer = torch.optim.Adam( - radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay + radiance_field.parameters(), + lr=1e-2, + eps=1e-15, + weight_decay=weight_decay, ) scheduler = torch.optim.lr_scheduler.ChainedScheduler( [ @@ -167,7 +171,8 @@ def occ_eval_fn(x): # dynamic batch size for rays to keep sample batch size constant. num_rays = len(pixels) num_rays = int( - num_rays * (target_sample_batch_size / float(n_rendering_samples)) + num_rays + * (target_sample_batch_size / float(n_rendering_samples)) ) train_dataset.update_num_rays(num_rays) @@ -249,6 +254,7 @@ def occ_eval_fn(x): lpips_avg = sum(lpips) / len(lpips) print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}") + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -274,4 +280,4 @@ def occ_eval_fn(x): ) args = parser.parse_args() - run(args) \ No newline at end of file + run(args) diff --git a/nerfacc/estimators/n3tree.py b/nerfacc/estimators/n3tree.py index df796e6..5f1537d 100644 --- a/nerfacc/estimators/n3tree.py +++ b/nerfacc/estimators/n3tree.py @@ -2,13 +2,14 @@ from typing import Callable, List, Optional, Tuple, Union import torch +from torch import Tensor + from ..grid import _enlarge_aabb from ..volrend import ( render_visibility_from_alpha, render_visibility_from_density, ) from .base import AbstractEstimator -from torch import Tensor try: import svox @@ -21,7 +22,7 @@ class N3TreeEstimator(AbstractEstimator): """Use N3Tree to implement Occupancy Grid. - + This allows more flexible topologies than the cascaded grid. However, it is slower to create samples from the tree than the cascaded grid. By default, it has the same topology as the cascaded grid but `self.tree` can be @@ -43,7 +44,9 @@ def __init__( ) # check the resolution is legal - assert isinstance(resolution, int), "N3Tree only supports uniform resolution!" + assert isinstance( + resolution, int + ), "N3Tree only supports uniform resolution!" # check the roi_aabb is legal if isinstance(roi_aabb, (list, tuple)): @@ -148,16 +151,18 @@ def sampling( """ - assert t_min is None and t_max is None, ( - "Do not supported per-ray min max. Please use near_plane and far_plane instead." - ) + assert ( + t_min is None and t_max is None + ), "Do not supported per-ray min max. Please use near_plane and far_plane instead." if stratified: near_plane += torch.rand(()).item() * render_step_size t_starts, t_ends, packed_info, ray_indices = svox.volume_sample( self.tree, thresh=self.thresh, - rays=svox.Rays(rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()), + rays=svox.Rays( + rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous() + ), step_size=render_step_size, cone_angle=cone_angle, near_plane=near_plane, @@ -253,10 +258,16 @@ def update_every_n_steps( @torch.no_grad() def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]: """Samples both n uniform and occupied cells.""" - uniform_indices = torch.randint(len(self.tree), (n,), device=self.device) - occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0] + uniform_indices = torch.randint( + len(self.tree), (n,), device=self.device + ) + occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[ + :, 0 + ] if n < len(occupied_indices): - selector = torch.randint(len(occupied_indices), (n,), device=self.device) + selector = torch.randint( + len(occupied_indices), (n,), device=self.device + ) occupied_indices = occupied_indices[selector] indices = torch.cat([uniform_indices, occupied_indices], dim=0) return indices @@ -275,7 +286,9 @@ def _update( x = self.tree.sample(1).squeeze(1) occ = occ_eval_fn(x).squeeze(-1) sel = (*self.tree._all_leaves().T,) - self.tree.data.data[sel] = torch.maximum(self.tree.data.data[sel] * ema_decay, occ[:, None]) + self.tree.data.data[sel] = torch.maximum( + self.tree.data.data[sel] * ema_decay, occ[:, None] + ) else: N = len(self.tree) // 4 indices = self._sample_uniform_and_occupied_cells(N)