From cf9a8062c3f648008fcb1798d3f8a655691ad8e9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 4 Oct 2024 23:18:05 -0700 Subject: [PATCH 01/53] gtnet --- examples/benchmarks/mcmc_deblur.sh | 23 +++++++ examples/blur_kernel.py | 106 +++++++++++++++++++++++++++++ examples/datasets/colmap.py | 7 +- examples/simple_trainer.py | 84 +++++++++++++++++++++-- examples/utils.py | 54 +++++++++++++++ 5 files changed, 266 insertions(+), 8 deletions(-) create mode 100644 examples/benchmarks/mcmc_deblur.sh create mode 100644 examples/blur_kernel.py diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh new file mode 100644 index 000000000..1cada6052 --- /dev/null +++ b/examples/benchmarks/mcmc_deblur.sh @@ -0,0 +1,23 @@ +# SCENE_DIR="data/deblur_dataset/synthetic_defocus_blur" +# SCENE_LIST="defocuscozy2room" +SCENE_DIR="data/deblur_dataset/real_defocus_blur" +SCENE_LIST="defocuscake defocuscaps defocuscisco defocustools" + +DATA_FACTOR=4 +RENDER_TRAJ_PATH="spiral" + +RESULT_DIR="results/benchmark_mcmc_deblur" +CAP_MAX=100000 + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + # train and eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + --strategy.cap-max $CAP_MAX \ + --blur_opt \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py new file mode 100644 index 000000000..377dcfab7 --- /dev/null +++ b/examples/blur_kernel.py @@ -0,0 +1,106 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn as nn +import numpy as np + + +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x : x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + +def get_embedder(multires, i=0): + if i == -1: + return nn.Identity(), 3 + + embed_kwargs = { + 'include_input' : True, + 'input_dims' : i, + 'max_freq_log2' : multires-1, + 'num_freqs' : multires, + 'log_sampling' : True, + 'periodic_fns' : [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + embed = lambda x, eo=embedder_obj : eo.embed(x) + return embed, embedder_obj.out_dim + +def init_linear_weights(m): + if isinstance(m, nn.Linear): + if m.weight.shape[0] in [2, 3]: + nn.init.xavier_normal_(m.weight, 0.1) + else: + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + +class GTnet(nn.Module): + def __init__(self, res_pos=3, res_view=10, num_hidden=3, width=64): + super().__init__() + self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) + self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) + in_cnl = self.embed_pos_cnl + self.embed_view_cnl + 7 # 7 for scales and rotations + + hiddens = [nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() + for i in 
range((num_hidden - 1) * 2)] + + self.linears = nn.Sequential( + nn.Linear(in_cnl, width), + nn.ReLU(), + *hiddens, + ).to("cuda") + self.s = nn.Linear(width, 3).to("cuda") + self.r = nn.Linear(width, 4).to("cuda") + + self.linears.apply(init_linear_weights) + self.s.apply(init_linear_weights) + self.r.apply(init_linear_weights) + + def forward(self, pos, scales, rotations, viewdirs): + pos = self.embed_pos(pos) + viewdirs = self.embed_view(viewdirs) + + x = torch.cat([pos, viewdirs, scales, rotations], dim=-1) + x1 = self.linears(x) + + scales_delta = self.s(x1) + rotations_delta = self.r(x1) + return scales_delta, rotations_delta + \ No newline at end of file diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 938bad265..ba12c2258 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -40,6 +40,11 @@ def __init__( self.factor = factor self.normalize = normalize self.test_every = test_every + li = os.listdir(data_dir) + for l in li: + if l.startswith("hold"): + self.test_every = int(l.split("=")[-1]) + break colmap_dir = os.path.join(data_dir, "sparse/0/") if not os.path.exists(colmap_dir): @@ -134,7 +139,7 @@ def __init__( # Load extended metadata. Used by Bilarf dataset. 
self.extconf = { - "spiral_radius_scale": 1.0, + "spiral_radius_scale": 0.1, "no_factor_suffix": False, } extconf_file = os.path.join(data_dir, "ext_metadata.json") diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 93e70002f..b5bbbcb98 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -28,20 +28,21 @@ from fused_ssim import fused_ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal, assert_never -from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed +from utils import AppearanceOptModule, BlurOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed from lib_bilagrid import ( BilateralGrid, slice, color_correct, total_variation_loss, ) +from blur_kernel import GTnet from gsplat.compression import PngCompression from gsplat.distributed import cli from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy from gsplat.optimizers import SelectiveAdam - +from gsplat.utils import log_transform @dataclass class Config: @@ -146,6 +147,11 @@ class Config: app_opt_lr: float = 1e-3 # Regularization for appearance optimization as weight decay app_opt_reg: float = 1e-6 + + # Enable blur optimization. (experimental) + blur_opt: bool = False + # Learning rate for blur optimization + blur_opt_lr: float = 1e-3 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -230,6 +236,7 @@ def create_splats_with_optimizers( ("scales", torch.nn.Parameter(scales), 5e-3), ("quats", torch.nn.Parameter(quats), 1e-3), ("opacities", torch.nn.Parameter(opacities), 5e-2), + ("f", torch.nn.Parameter(torch.rand(N, 32)), 7.5e-3), ] if feature_dim is None: @@ -398,6 +405,20 @@ def __init__( ] if world_size > 1: self.app_module = DDP(self.app_module) + + self.blur_optimizers = [] + if cfg.blur_opt: + # self.blur_module = BlurOptModule(cfg.sh_degree).to(self.device) + self.blur_module = GTnet() + self.blur_optimizers = [ + torch.optim.Adam( + # self.blur_module.mlp.parameters(), + self.blur_module.parameters(), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.blur_module = DDP(self.blur_module) self.bil_grid_optimizers = [] if cfg.use_bilateral_grid: @@ -447,13 +468,12 @@ def rasterize_splats( width: int, height: int, masks: Optional[Tensor] = None, + is_eval: bool = False, **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means"] # [N, 3] # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] # rasterization does normalization internally - quats = self.splats["quats"] # [N, 4] - scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] image_ids = kwargs.pop("image_ids", None) @@ -469,6 +489,37 @@ def rasterize_splats( else: colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + if self.cfg.blur_opt and not is_eval: + scales = torch.exp(self.splats["scales"]) + quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + + means_ = means.detach() + scales_ = scales.detach() + quats_ = quats.detach() + viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) + + # means_log_ = log_transform(means_) + # viewdir_ = means_ - camtoworlds[0, :3, 3] + # viewdir_ = F.normalize(viewdir_, dim=-1) + + scales_delta, rotations_delta = self.blur_module( + means_, + scales_, + 
quats_, + viewdir_, + # means[None, :, :] - camtoworlds[:, None, :3, 3], + # sh_degree=self.cfg.sh_degree, + ) + lambda_s = 0.01 + scales_delta = torch.clamp(lambda_s * scales_delta + (1-lambda_s), min=1.0, max=1.1) + rotations_delta = torch.clamp(lambda_s * rotations_delta + (1-lambda_s), min=1.0, max=1.1) + + scales = scales * scales_delta + quats = quats * rotations_delta + else: + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( means=means, @@ -732,6 +783,11 @@ def train(self): data["app_module"] = self.app_module.module.state_dict() else: data["app_module"] = self.app_module.state_dict() + if cfg.blur_opt: + if world_size > 1: + data["blur_module"] = self.blur_module.module.state_dict() + else: + data["blur_module"] = self.blur_module.state_dict() torch.save( data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) @@ -774,6 +830,9 @@ def train(self): for optimizer in self.app_optimizers: optimizer.step() optimizer.zero_grad(set_to_none=True) + for optimizer in self.blur_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) for optimizer in self.bil_grid_optimizers: optimizer.step() optimizer.zero_grad(set_to_none=True) @@ -805,7 +864,9 @@ def train(self): # eval the full set if step in [i - 1 for i in cfg.eval_steps]: self.eval(step) + self.eval(step, stage="debug") self.render_traj(step) + self.render_traj(step, stage="debug") # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -834,6 +895,13 @@ def eval(self, step: int, stage: str = "val"): valloader = torch.utils.data.DataLoader( self.valset, batch_size=1, shuffle=False, num_workers=1 ) + is_eval = True + if stage == "debug": + valloader = torch.utils.data.DataLoader( + self.trainset, batch_size=1, shuffle=False, num_workers=1 + ) + is_eval = False + ellipse_time = 0 metrics 
= defaultdict(list) for i, data in enumerate(valloader): @@ -854,6 +922,7 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, masks=masks, + is_eval=is_eval, ) # [1, H, W, 3] torch.cuda.synchronize() ellipse_time += time.time() - tic @@ -904,7 +973,7 @@ def eval(self, step: int, stage: str = "val"): self.writer.flush() @torch.no_grad() - def render_traj(self, step: int): + def render_traj(self, step: int, stage: str = "val"): """Entry for trajectory rendering.""" print("Running trajectory rendering...") cfg = self.cfg @@ -948,7 +1017,7 @@ def render_traj(self, step: int): # save to video video_dir = f"{cfg.result_dir}/videos" os.makedirs(video_dir, exist_ok=True) - writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + writer = imageio.get_writer(f"{video_dir}/traj_{stage}_{step}.mp4", fps=30) for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): camtoworlds = camtoworlds_all[i : i + 1] Ks = K[None] @@ -962,6 +1031,7 @@ def render_traj(self, step: int): near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode="RGB+ED", + is_eval=stage == "val", ) # [1, H, W, 4] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] depths = renders[..., 3:4] # [1, H, W, 1] @@ -973,7 +1043,7 @@ def render_traj(self, step: int): canvas = (canvas * 255).astype(np.uint8) writer.append_data(canvas) writer.close() - print(f"Video saved to {video_dir}/traj_{step}.mp4") + print(f"Video saved to {video_dir}/traj_{stage}_{step}.mp4") @torch.no_grad() def run_compression(self, step: int): diff --git a/examples/utils.py b/examples/utils.py index b79f4244b..963a01a04 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -114,6 +114,60 @@ def forward( return colors +class BlurOptModule(torch.nn.Module): + """Blur optimization module.""" + + def __init__( + self, + sh_degree: int = 3, + mlp_width: int = 64, + mlp_depth: int = 2, + ): + super().__init__() + self.sh_degree = sh_degree + layers = 
[] + layers.append( + torch.nn.Linear((sh_degree + 1) ** 2 + 4 + 3, mlp_width) + ) + layers.append(torch.nn.ReLU(inplace=True)) + for _ in range(mlp_depth - 1): + layers.append(torch.nn.Linear(mlp_width, mlp_width)) + layers.append(torch.nn.ReLU(inplace=True)) + layers.append(torch.nn.Linear(mlp_width, 7)) + self.mlp = torch.nn.Sequential(*layers) + + def forward( + self, features, quats, scales, dirs: Tensor, sh_degree: int + ) -> Tensor: + """Adjust blur based on MLP. + + Args: + features: (N, feature_dim) + embed_ids: (C,) + dirs: (C, N, 3) + + Returns: + colors: (C, N, 3) + """ + from gsplat.cuda._torch_impl import _eval_sh_bases_fast + + C, N = dirs.shape[:2] + # View directions + dirs = F.normalize(dirs, dim=-1) # [C, N, 3] + num_bases_to_use = (sh_degree + 1) ** 2 + num_bases = (self.sh_degree + 1) ** 2 + sh_bases = torch.zeros(C, N, num_bases, device=quats.device) # [C, N, K] + sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs) + + h = torch.cat([sh_bases[0], quats, scales], dim=-1) + x = 0.01 * self.mlp(h) + (1 + 0.01) + print(x.min(), x.mean(), x.max()) + x = torch.clip(x, 1.0, 2.0) + quats_s = x[..., :4] + scales_s = x[..., 4:] + return quats_s, scales_s + + def rotation_6d_to_matrix(d6: Tensor) -> Tensor: """ Converts 6D rotation representation by Zhou et al. 
[1] to rotation matrix From 5088b3b0043fa863d645e9665d71f12e5d609016 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 5 Oct 2024 11:52:09 -0700 Subject: [PATCH 02/53] reduce diff --- examples/benchmarks/mcmc_deblur.sh | 4 +- examples/blur_kernel.py | 74 ++++++++++++++++-------------- examples/simple_trainer.py | 39 ++++++++-------- examples/utils.py | 54 ---------------------- 4 files changed, 62 insertions(+), 109 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 1cada6052..419267199 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,7 +1,5 @@ -# SCENE_DIR="data/deblur_dataset/synthetic_defocus_blur" -# SCENE_LIST="defocuscozy2room" SCENE_DIR="data/deblur_dataset/real_defocus_blur" -SCENE_LIST="defocuscake defocuscaps defocuscisco defocustools" +SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 377dcfab7..56a6693c7 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
# # For inquiries contact george.drettakis@inria.fr @@ -18,51 +18,53 @@ class Embedder: def __init__(self, **kwargs): self.kwargs = kwargs self.create_embedding_fn() - + def create_embedding_fn(self): embed_fns = [] - d = self.kwargs['input_dims'] + d = self.kwargs["input_dims"] out_dim = 0 - if self.kwargs['include_input']: - embed_fns.append(lambda x : x) + if self.kwargs["include_input"]: + embed_fns.append(lambda x: x) out_dim += d - - max_freq = self.kwargs['max_freq_log2'] - N_freqs = self.kwargs['num_freqs'] - - if self.kwargs['log_sampling']: - freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) + + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] + + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) else: - freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) - + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) + for freq in freq_bands: - for p_fn in self.kwargs['periodic_fns']: - embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) out_dim += d - + self.embed_fns = embed_fns self.out_dim = out_dim - + def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + def get_embedder(multires, i=0): if i == -1: return nn.Identity(), 3 - + embed_kwargs = { - 'include_input' : True, - 'input_dims' : i, - 'max_freq_log2' : multires-1, - 'num_freqs' : multires, - 'log_sampling' : True, - 'periodic_fns' : [torch.sin, torch.cos], + "include_input": True, + "input_dims": i, + "max_freq_log2": multires - 1, + "num_freqs": multires, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], } - + embedder_obj = Embedder(**embed_kwargs) - embed = lambda x, eo=embedder_obj : eo.embed(x) + embed = lambda x, eo=embedder_obj: eo.embed(x) return embed, embedder_obj.out_dim + def init_linear_weights(m): if 
isinstance(m, nn.Linear): if m.weight.shape[0] in [2, 3]: @@ -71,20 +73,25 @@ def init_linear_weights(m): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) + class GTnet(nn.Module): def __init__(self, res_pos=3, res_view=10, num_hidden=3, width=64): super().__init__() self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) - in_cnl = self.embed_pos_cnl + self.embed_view_cnl + 7 # 7 for scales and rotations + in_cnl = ( + self.embed_pos_cnl + self.embed_view_cnl + 7 + ) # 7 for scales and rotations - hiddens = [nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() - for i in range((num_hidden - 1) * 2)] + hiddens = [ + nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() + for i in range((num_hidden - 1) * 2) + ] self.linears = nn.Sequential( - nn.Linear(in_cnl, width), - nn.ReLU(), - *hiddens, + nn.Linear(in_cnl, width), + nn.ReLU(), + *hiddens, ).to("cuda") self.s = nn.Linear(width, 3).to("cuda") self.r = nn.Linear(width, 4).to("cuda") @@ -92,7 +99,7 @@ def __init__(self, res_pos=3, res_view=10, num_hidden=3, width=64): self.linears.apply(init_linear_weights) self.s.apply(init_linear_weights) self.r.apply(init_linear_weights) - + def forward(self, pos, scales, rotations, viewdirs): pos = self.embed_pos(pos) viewdirs = self.embed_view(viewdirs) @@ -103,4 +110,3 @@ def forward(self, pos, scales, rotations, viewdirs): scales_delta = self.s(x1) rotations_delta = self.r(x1) return scales_delta, rotations_delta - \ No newline at end of file diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index b5bbbcb98..c6fe09e4e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -28,7 +28,13 @@ from fused_ssim import fused_ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal, assert_never -from utils import AppearanceOptModule, BlurOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed +from 
utils import ( + AppearanceOptModule, + CameraOptModule, + knn, + rgb_to_sh, + set_random_seed, +) from lib_bilagrid import ( BilateralGrid, slice, @@ -44,6 +50,7 @@ from gsplat.optimizers import SelectiveAdam from gsplat.utils import log_transform + @dataclass class Config: # Disable viewer @@ -147,7 +154,7 @@ class Config: app_opt_lr: float = 1e-3 # Regularization for appearance optimization as weight decay app_opt_reg: float = 1e-6 - + # Enable blur optimization. (experimental) blur_opt: bool = False # Learning rate for blur optimization @@ -236,7 +243,6 @@ def create_splats_with_optimizers( ("scales", torch.nn.Parameter(scales), 5e-3), ("quats", torch.nn.Parameter(quats), 1e-3), ("opacities", torch.nn.Parameter(opacities), 5e-2), - ("f", torch.nn.Parameter(torch.rand(N, 32)), 7.5e-3), ] if feature_dim is None: @@ -405,14 +411,12 @@ def __init__( ] if world_size > 1: self.app_module = DDP(self.app_module) - + self.blur_optimizers = [] if cfg.blur_opt: - # self.blur_module = BlurOptModule(cfg.sh_degree).to(self.device) self.blur_module = GTnet() self.blur_optimizers = [ torch.optim.Adam( - # self.blur_module.mlp.parameters(), self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), ), @@ -492,28 +496,27 @@ def rasterize_splats( if self.cfg.blur_opt and not is_eval: scales = torch.exp(self.splats["scales"]) quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] - + means_ = means.detach() scales_ = scales.detach() quats_ = quats.detach() viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) - - # means_log_ = log_transform(means_) - # viewdir_ = means_ - camtoworlds[0, :3, 3] - # viewdir_ = F.normalize(viewdir_, dim=-1) - scales_delta, rotations_delta = self.blur_module( means_, scales_, quats_, viewdir_, - # means[None, :, :] - camtoworlds[:, None, :3, 3], - # sh_degree=self.cfg.sh_degree, ) + lambda_s = 0.01 - scales_delta = torch.clamp(lambda_s * scales_delta + (1-lambda_s), min=1.0, max=1.1) - rotations_delta = 
torch.clamp(lambda_s * rotations_delta + (1-lambda_s), min=1.0, max=1.1) - + max_clamp = 1.1 + scales_delta = torch.clamp( + lambda_s * scales_delta + (1.0 - lambda_s), min=1.0, max=max_clamp + ) + rotations_delta = torch.clamp( + lambda_s * rotations_delta + (1 - lambda_s), min=1.0, max=max_clamp + ) + scales = scales * scales_delta quats = quats * rotations_delta else: @@ -901,7 +904,7 @@ def eval(self, step: int, stage: str = "val"): self.trainset, batch_size=1, shuffle=False, num_workers=1 ) is_eval = False - + ellipse_time = 0 metrics = defaultdict(list) for i, data in enumerate(valloader): diff --git a/examples/utils.py b/examples/utils.py index 963a01a04..b79f4244b 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -114,60 +114,6 @@ def forward( return colors -class BlurOptModule(torch.nn.Module): - """Blur optimization module.""" - - def __init__( - self, - sh_degree: int = 3, - mlp_width: int = 64, - mlp_depth: int = 2, - ): - super().__init__() - self.sh_degree = sh_degree - layers = [] - layers.append( - torch.nn.Linear((sh_degree + 1) ** 2 + 4 + 3, mlp_width) - ) - layers.append(torch.nn.ReLU(inplace=True)) - for _ in range(mlp_depth - 1): - layers.append(torch.nn.Linear(mlp_width, mlp_width)) - layers.append(torch.nn.ReLU(inplace=True)) - layers.append(torch.nn.Linear(mlp_width, 7)) - self.mlp = torch.nn.Sequential(*layers) - - def forward( - self, features, quats, scales, dirs: Tensor, sh_degree: int - ) -> Tensor: - """Adjust blur based on MLP. 
- - Args: - features: (N, feature_dim) - embed_ids: (C,) - dirs: (C, N, 3) - - Returns: - colors: (C, N, 3) - """ - from gsplat.cuda._torch_impl import _eval_sh_bases_fast - - C, N = dirs.shape[:2] - # View directions - dirs = F.normalize(dirs, dim=-1) # [C, N, 3] - num_bases_to_use = (sh_degree + 1) ** 2 - num_bases = (self.sh_degree + 1) ** 2 - sh_bases = torch.zeros(C, N, num_bases, device=quats.device) # [C, N, K] - sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs) - - h = torch.cat([sh_bases[0], quats, scales], dim=-1) - x = 0.01 * self.mlp(h) + (1 + 0.01) - print(x.min(), x.mean(), x.max()) - x = torch.clip(x, 1.0, 2.0) - quats_s = x[..., :4] - scales_s = x[..., 4:] - return quats_s, scales_s - - def rotation_6d_to_matrix(d6: Tensor) -> Tensor: """ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix From 214c8f06cd7ea4523bb200772252589930fc854c Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 5 Oct 2024 12:02:07 -0700 Subject: [PATCH 03/53] refactor --- examples/simple_trainer.py | 47 +++++++++++++++----------------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index c6fe09e4e..b0a76682e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -28,13 +28,7 @@ from fused_ssim import fused_ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal, assert_never -from utils import ( - AppearanceOptModule, - CameraOptModule, - knn, - rgb_to_sh, - set_random_seed, -) +from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed from lib_bilagrid import ( BilateralGrid, slice, @@ -48,7 +42,6 @@ from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy from gsplat.optimizers import SelectiveAdam -from gsplat.utils import log_transform @dataclass @@ -472,12 +465,14 @@ def rasterize_splats( 
width: int, height: int, masks: Optional[Tensor] = None, - is_eval: bool = False, + is_train: bool = False, **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means"] # [N, 3] # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] image_ids = kwargs.pop("image_ids", None) @@ -493,9 +488,9 @@ def rasterize_splats( else: colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] - if self.cfg.blur_opt and not is_eval: - scales = torch.exp(self.splats["scales"]) + if self.cfg.blur_opt and is_train: quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + scales = torch.exp(self.splats["scales"]) means_ = means.detach() scales_ = scales.detach() @@ -507,7 +502,6 @@ def rasterize_splats( quats_, viewdir_, ) - lambda_s = 0.01 max_clamp = 1.1 scales_delta = torch.clamp( @@ -519,9 +513,6 @@ def rasterize_splats( scales = scales * scales_delta quats = quats * rotations_delta - else: - quats = self.splats["quats"] # [N, 4] - scales = torch.exp(self.splats["scales"]) # [N, 3] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( @@ -655,6 +646,7 @@ def train(self): image_ids=image_ids, render_mode="RGB+ED" if cfg.depth_loss else "RGB", masks=masks, + is_train=True, ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -866,10 +858,10 @@ def train(self): # eval the full set if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step, stage="train") self.eval(step) - self.eval(step, stage="debug") + self.render_traj(step, stage="train") self.render_traj(step) - self.render_traj(step, stage="debug") # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -895,19 +887,18 @@ def eval(self, step: int, stage: 
str = "val"): world_rank = self.world_rank world_size = self.world_size - valloader = torch.utils.data.DataLoader( - self.valset, batch_size=1, shuffle=False, num_workers=1 - ) - is_eval = True - if stage == "debug": - valloader = torch.utils.data.DataLoader( + if stage == "train": + dataloader = torch.utils.data.DataLoader( self.trainset, batch_size=1, shuffle=False, num_workers=1 ) - is_eval = False + else: + dataloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) ellipse_time = 0 metrics = defaultdict(list) - for i, data in enumerate(valloader): + for i, data in enumerate(dataloader): camtoworlds = data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = data["image"].to(device) / 255.0 @@ -925,7 +916,7 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, masks=masks, - is_eval=is_eval, + is_train=stage == "train", ) # [1, H, W, 3] torch.cuda.synchronize() ellipse_time += time.time() - tic @@ -953,7 +944,7 @@ def eval(self, step: int, stage: str = "val"): metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) if world_rank == 0: - ellipse_time /= len(valloader) + ellipse_time /= len(dataloader) stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} stats.update( @@ -1034,7 +1025,7 @@ def render_traj(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode="RGB+ED", - is_eval=stage == "val", + is_train=stage == "train", ) # [1, H, W, 4] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] depths = renders[..., 3:4] # [1, H, W, 1] From d00cb1acc3be849714c593ac3563400ddb270643 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 5 Oct 2024 12:06:48 -0700 Subject: [PATCH 04/53] exact same gtnet --- examples/blur_kernel.py | 31 +++++++++++++++++++++++++++---- examples/simple_trainer.py | 3 +-- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/examples/blur_kernel.py 
b/examples/blur_kernel.py index 56a6693c7..bc752354a 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -75,8 +75,19 @@ def init_linear_weights(m): class GTnet(nn.Module): - def __init__(self, res_pos=3, res_view=10, num_hidden=3, width=64): + def __init__( + self, + res_pos=3, + res_view=10, + num_hidden=3, + width=64, + pos_delta=False, + num_moments=4, + ): super().__init__() + self.pos_delta = pos_delta + self.num_moments = num_moments + self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) in_cnl = ( @@ -93,14 +104,22 @@ def __init__(self, res_pos=3, res_view=10, num_hidden=3, width=64): nn.ReLU(), *hiddens, ).to("cuda") - self.s = nn.Linear(width, 3).to("cuda") - self.r = nn.Linear(width, 4).to("cuda") + if not pos_delta: # Defocus + self.s = nn.Linear(width, 3).to("cuda") + self.r = nn.Linear(width, 4).to("cuda") + else: # Motion + self.s = nn.Linear(width, 3 * (num_moments + 1)).to("cuda") + self.r = nn.Linear(width, 4 * (num_moments + 1)).to("cuda") + self.p = nn.Linear(width, 3 * num_moments).to("cuda") self.linears.apply(init_linear_weights) self.s.apply(init_linear_weights) self.r.apply(init_linear_weights) + if pos_delta: + self.p.apply(init_linear_weights) def forward(self, pos, scales, rotations, viewdirs): + pos_delta = None pos = self.embed_pos(pos) viewdirs = self.embed_view(viewdirs) @@ -109,4 +128,8 @@ def forward(self, pos, scales, rotations, viewdirs): scales_delta = self.s(x1) rotations_delta = self.r(x1) - return scales_delta, rotations_delta + + if self.pos_delta: + pos_delta = self.p(x1) + + return scales_delta, rotations_delta, pos_delta diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index b0a76682e..014ca8f35 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -496,7 +496,7 @@ def rasterize_splats( scales_ = scales.detach() quats_ = quats.detach() viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 
1) - scales_delta, rotations_delta = self.blur_module( + scales_delta, rotations_delta, _ = self.blur_module( means_, scales_, quats_, @@ -510,7 +510,6 @@ def rasterize_splats( rotations_delta = torch.clamp( lambda_s * rotations_delta + (1 - lambda_s), min=1.0, max=max_clamp ) - scales = scales * scales_delta quats = quats * rotations_delta From 5aecb73a056f9cf0431995dab1ee3adca1e2ec58 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 11 Oct 2024 17:49:30 -0700 Subject: [PATCH 05/53] wip --- examples/benchmarks/compression/mcmc.sh | 2 +- .../benchmarks/compression/summarize_stats.py | 7 +- examples/benchmarks/mcmc_deblur.sh | 19 ++- examples/datasets/colmap.py | 1 + examples/simple_trainer.py | 134 +++++++++++++----- gsplat/strategy/mcmc.py | 5 + gsplat/strategy/ops.py | 1 + 7 files changed, 129 insertions(+), 40 deletions(-) diff --git a/examples/benchmarks/compression/mcmc.sh b/examples/benchmarks/compression/mcmc.sh index 4c7165f3d..a28bfb6f5 100644 --- a/examples/benchmarks/compression/mcmc.sh +++ b/examples/benchmarks/compression/mcmc.sh @@ -49,7 +49,7 @@ done if command -v zip &> /dev/null then echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR + python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST else echo "zip command not found, skipping zipping" fi \ No newline at end of file diff --git a/examples/benchmarks/compression/summarize_stats.py b/examples/benchmarks/compression/summarize_stats.py index d11dbed6f..7fa821015 100644 --- a/examples/benchmarks/compression/summarize_stats.py +++ b/examples/benchmarks/compression/summarize_stats.py @@ -10,7 +10,7 @@ def main(results_dir: str, scenes: List[str]): print("scenes:", scenes) - stage = "compress" + stage = "val" summary = defaultdict(list) for scene in scenes: @@ -33,8 +33,11 @@ def main(results_dir: str, scenes: List[str]): summary[k].append(v) for k, v in summary.items(): - print(k, np.mean(v)) + summary[k] = 
np.mean(v) + summary["scenes"] = scenes + with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f: + json.dump(summary, f, indent=2) if __name__ == "__main__": tyro.cli(main) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 419267199..954b1cc29 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,11 +1,15 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" +# RETRY_LIST="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19" +# for RETRY in $RETRY_LIST; + # SCENE="defocuscaps" + # --ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" -RESULT_DIR="results/benchmark_mcmc_deblur" -CAP_MAX=100000 +RESULT_DIR="results/benchmark_mcmc_deblur_cheating3" +CAP_MAX=250000 for SCENE in $SCENE_LIST; do @@ -17,5 +21,14 @@ do --blur_opt \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ + --result_dir $RESULT_DIR/$SCENE done + +# Zip the compressed files and summarize the stats +if command -v zip &> /dev/null +then + echo "Zipping results" + python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST +else + echo "zip command not found, skipping zipping" +fi diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index ba12c2258..23266e34c 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -321,6 +321,7 @@ def __init__( indices = np.arange(len(self.parser.image_names)) if split == "train": self.indices = indices[indices % self.parser.test_every != 0] + # self.indices = np.concatenate([[0], self.indices]) else: self.indices = indices[indices % self.parser.test_every == 0] diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 014ca8f35..5c61c9375 100644 
--- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -83,7 +83,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -229,21 +229,25 @@ def create_splats_with_optimizers( N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + embeds = 0.1 * torch.randn(N, 50, 1) # [N, 3] params = [ # name, value, lr - ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale), - ("scales", torch.nn.Parameter(scales), 5e-3), - ("quats", torch.nn.Parameter(quats), 1e-3), - ("opacities", torch.nn.Parameter(opacities), 5e-2), + ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale, 0.0), + ("scales", torch.nn.Parameter(scales), 5e-3, 0.0), + ("quats", torch.nn.Parameter(quats), 1e-3, 0.0), + ("opacities", torch.nn.Parameter(opacities), 5e-2, 0.0), + # ("embeds", torch.nn.Parameter(0.01 * torch.randn(N, 27, 1)), 1e-3, 0), + # ("embeds", torch.nn.Parameter(torch.zeros(N, 100, 1)), 1e-3, 0), + ("embeds", torch.nn.Parameter(embeds), 1e-3, 0), ] if feature_dim is None: # color is SH coefficients. 
colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] colors[:, 0, :] = rgb_to_sh(rgbs) - params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) - params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3, 0.0)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20, 0.0)) else: # features will be used for appearance and view-dependent shading features = torch.rand(N, feature_dim) # [N, feature_dim] @@ -251,7 +255,7 @@ def create_splats_with_optimizers( colors = torch.logit(rgbs) # [N, 3] params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) - splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + splats = torch.nn.ParameterDict({n: v for n, v, _, _ in params}).to(device) # Scale learning rate based on batch size, reference: # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ # Note that this would not make the training exactly equivalent, see @@ -270,8 +274,9 @@ def create_splats_with_optimizers( eps=1e-15 / math.sqrt(BS), # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. 
betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + weight_decay=weight_decay, ) - for name, _, lr in params + for name, _, lr, weight_decay in params } return splats, optimizers @@ -490,34 +495,80 @@ def rasterize_splats( if self.cfg.blur_opt and is_train: quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] - scales = torch.exp(self.splats["scales"]) - - means_ = means.detach() - scales_ = scales.detach() - quats_ = quats.detach() - viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) - scales_delta, rotations_delta, _ = self.blur_module( - means_, - scales_, - quats_, - viewdir_, - ) - lambda_s = 0.01 - max_clamp = 1.1 - scales_delta = torch.clamp( - lambda_s * scales_delta + (1.0 - lambda_s), min=1.0, max=max_clamp - ) - rotations_delta = torch.clamp( - lambda_s * rotations_delta + (1 - lambda_s), min=1.0, max=max_clamp - ) - scales = scales * scales_delta - quats = quats * rotations_delta + # scales = torch.exp(self.splats["scales"]) + + # t = (10 - 1) * (image_ids[0] / (31 - 1 + 0.0001)) + # i = t.int() + # r = t - i + # scales_delta = (1 - r) * self.splats["embeds"][:, i, :3] + r * self.splats["embeds"][:, i + 1, :3] + + scales_delta = self.splats["embeds"][:, image_ids[0], :1] + # scales_delta = torch.clamp(scales_delta, min=0.0) + # scales_delta = torch.abs(scales_delta) + # rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] + + # d = means[:, :] - camtoworlds[:, :3, 3] + # z = d[:, 2] + # a = 0.2 * torch.sigmoid(self.splats["embeds"][:, image_ids[0], 0]) + # f = self.splats["embeds"][:, image_ids[0], 1] + # scales_delta = a * (f - z) + # scales_delta = scales_delta[:, None] + + # a = self.splats["embeds"][:, 0, 0] + # b = self.splats["embeds"][:, 0, 1] + # r = torch.sigmoid(self.splats["embeds"][:, image_ids[0], 2]) + # scales_delta = a * r + b * (1 - r) + # scales_delta = scales_delta[:, None] + + # if image_ids[0] == 0: + # scales_delta = 0.0 * scales_delta + + print("scales", self.splats["scales"].min().item(), 
self.splats["scales"].max().item(), self.splats["scales"].mean().item()) + print("deltas", scales_delta.min().item(), scales_delta.max().item(), scales_delta.mean().item()) + + + # defocus_mask = self.splats["embeds"][:, 0, 0] > 0 + # print(self.splats["embeds"][:, 0, 0].max().item(), defocus_mask.sum().item()) + # defocus_mask = self.splats["embeds"][:, 1, 0] > 0 + # print(self.splats["embeds"][:, 1, 0].max().item(), defocus_mask.sum().item()) + + + + # scales_delta = torch.clamp(scales_delta, min=-0.5, max=0.5) + # rotations_delta = torch.clamp(rotations_delta, min=0.0, max=0.2) + + # scales_delta = torch.abs(scales_delta) + # scales_delta = torch.clamp(scales_delta, max=1.0) + + # means_ = means.detach() + # scales_ = scales.detach() + # quats_ = quats.detach() + # viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) + # scales_delta, rotations_delta, _ = self.blur_module( + # means_, + # scales_, + # quats_, + # viewdir_, + # ) + # lambda_s = 0.01 + # max_clamp = 1.05 + # scales_delta = torch.clamp( + # lambda_s * scales_delta + (1.0 - lambda_s), min=1.0, max=max_clamp + # ) + # rotations_delta = torch.clamp( + # lambda_s * rotations_delta + (1.0 - lambda_s), min=1.0, max=max_clamp + # ) + # scales = scales * (scales_delta + 1) + # quats = quats * (rotations_delta + 1) + + scales = torch.exp(self.splats["scales"] + scales_delta) + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( means=means, quats=quats, - scales=scales, + scales=scales * 1.2, opacities=opacities, colors=colors, viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] @@ -714,7 +765,21 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - + + # delta_loss = torch.abs(scales_delta).mean() + # print(delta_loss) + # loss += 0.1 * delta_loss + # loss += 1000.0 * cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"] + scales_delta) - torch.exp(self.splats["scales"])).mean() + + # 
mask = self.splats["embeds"] > 0 + # embed_reg = -0.8 * self.splats["embeds"][~mask].mean() + 0.2 * self.splats["embeds"][mask].mean() + + embed_reg = torch.abs(self.splats["embeds"]).mean() + # embed_reg = torch.sqrt(torch.abs(self.splats["embeds"])).mean() + # embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() + print(embed_reg) + loss += 1.0 * embed_reg + loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " @@ -859,7 +924,6 @@ def train(self): if step in [i - 1 for i in cfg.eval_steps]: self.eval(step, stage="train") self.eval(step) - self.render_traj(step, stage="train") self.render_traj(step) # run compression @@ -901,6 +965,7 @@ def eval(self, step: int, stage: str = "val"): camtoworlds = data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = data["image"].to(device) / 255.0 + image_ids = data["image_id"].to(device) masks = data["mask"].to(device) if "mask" in data else None height, width = pixels.shape[1:3] @@ -914,6 +979,7 @@ def eval(self, step: int, stage: str = "val"): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, + image_ids=image_ids, masks=masks, is_train=stage == "train", ) # [1, H, W, 3] diff --git a/gsplat/strategy/mcmc.py b/gsplat/strategy/mcmc.py index c07e17376..579f517e3 100644 --- a/gsplat/strategy/mcmc.py +++ b/gsplat/strategy/mcmc.py @@ -143,6 +143,11 @@ def step_post_backward( inject_noise_to_position( params=params, optimizers=optimizers, state={}, scaler=lr * self.noise_lr ) + + # with torch.no_grad(): + # deltas_min, _ = params["embeds"].min(dim=1) + # params["scales"] += deltas_min.repeat(1, 3) + # params["embeds"] -= deltas_min[:, :, None].repeat(1, 100, 1) @torch.no_grad() def _relocate_gs( diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 0789dfcbc..486352871 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -361,3 +361,4 @@ def op_sigmoid(x, k=100, x0=0.995): ) noise = torch.einsum("bij,bj->bi", covars, noise) 
params["means"].add_(noise) + \ No newline at end of file From 15f8898673ff130ecb38d0f345e171c49ec5bd4f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 12 Oct 2024 21:51:04 -0700 Subject: [PATCH 06/53] cleanup --- examples/benchmarks/mcmc_deblur.sh | 6 +- examples/blur_kernel.py | 2 + examples/datasets/colmap.py | 3 +- examples/simple_trainer.py | 101 ++++------------------------- 4 files changed, 19 insertions(+), 93 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 954b1cc29..7973ddd3a 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,14 +1,10 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" -# RETRY_LIST="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19" -# for RETRY in $RETRY_LIST; - # SCENE="defocuscaps" - # --ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" -RESULT_DIR="results/benchmark_mcmc_deblur_cheating3" +RESULT_DIR="results/benchmark_mcmc_deblur" CAP_MAX=250000 for SCENE in $SCENE_LIST; diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index bc752354a..409d27baa 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -77,6 +77,7 @@ def init_linear_weights(m): class GTnet(nn.Module): def __init__( self, + n, res_pos=3, res_view=10, num_hidden=3, @@ -85,6 +86,7 @@ def __init__( num_moments=4, ): super().__init__() + self.focals = torch.nn.Parameter(0.1 * torch.ones(n)) self.pos_delta = pos_delta self.num_moments = num_moments diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 23266e34c..40a649c5e 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -321,9 +321,10 @@ def __init__( indices = np.arange(len(self.parser.image_names)) if split == "train": self.indices = indices[indices % 
self.parser.test_every != 0] - # self.indices = np.concatenate([[0], self.indices]) + self.indices = np.concatenate([[0], self.indices]) else: self.indices = indices[indices % self.parser.test_every == 0] + self.indices = self.indices[1:] def __len__(self): return len(self.indices) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 5c61c9375..f061e163b 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -229,7 +229,7 @@ def create_splats_with_optimizers( N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] - embeds = 0.1 * torch.randn(N, 50, 1) # [N, 3] + embeds = 0.1 * torch.randn(N, 50, 7) params = [ # name, value, lr @@ -237,8 +237,6 @@ def create_splats_with_optimizers( ("scales", torch.nn.Parameter(scales), 5e-3, 0.0), ("quats", torch.nn.Parameter(quats), 1e-3, 0.0), ("opacities", torch.nn.Parameter(opacities), 5e-2, 0.0), - # ("embeds", torch.nn.Parameter(0.01 * torch.randn(N, 27, 1)), 1e-3, 0), - # ("embeds", torch.nn.Parameter(torch.zeros(N, 100, 1)), 1e-3, 0), ("embeds", torch.nn.Parameter(embeds), 1e-3, 0), ] @@ -412,7 +410,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = GTnet() + self.blur_module = GTnet(len(self.trainset)) self.blur_optimizers = [ torch.optim.Adam( self.blur_module.parameters(), @@ -494,81 +492,21 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] if self.cfg.blur_opt and is_train: - quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] - # scales = torch.exp(self.splats["scales"]) - - # t = (10 - 1) * (image_ids[0] / (31 - 1 + 0.0001)) - # i = t.int() - # r = t - i - # scales_delta = (1 - r) * self.splats["embeds"][:, i, :3] + r * self.splats["embeds"][:, i + 1, :3] - - scales_delta = self.splats["embeds"][:, image_ids[0], :1] - # scales_delta = torch.clamp(scales_delta, min=0.0) - # scales_delta = torch.abs(scales_delta) - # 
rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] - - # d = means[:, :] - camtoworlds[:, :3, 3] - # z = d[:, 2] - # a = 0.2 * torch.sigmoid(self.splats["embeds"][:, image_ids[0], 0]) - # f = self.splats["embeds"][:, image_ids[0], 1] - # scales_delta = a * (f - z) - # scales_delta = scales_delta[:, None] - - # a = self.splats["embeds"][:, 0, 0] - # b = self.splats["embeds"][:, 0, 1] - # r = torch.sigmoid(self.splats["embeds"][:, image_ids[0], 2]) - # scales_delta = a * r + b * (1 - r) - # scales_delta = scales_delta[:, None] - - # if image_ids[0] == 0: - # scales_delta = 0.0 * scales_delta - - print("scales", self.splats["scales"].min().item(), self.splats["scales"].max().item(), self.splats["scales"].mean().item()) - print("deltas", scales_delta.min().item(), scales_delta.max().item(), scales_delta.mean().item()) - - - # defocus_mask = self.splats["embeds"][:, 0, 0] > 0 - # print(self.splats["embeds"][:, 0, 0].max().item(), defocus_mask.sum().item()) - # defocus_mask = self.splats["embeds"][:, 1, 0] > 0 - # print(self.splats["embeds"][:, 1, 0].max().item(), defocus_mask.sum().item()) - - - - # scales_delta = torch.clamp(scales_delta, min=-0.5, max=0.5) - # rotations_delta = torch.clamp(rotations_delta, min=0.0, max=0.2) - - # scales_delta = torch.abs(scales_delta) - # scales_delta = torch.clamp(scales_delta, max=1.0) - - # means_ = means.detach() - # scales_ = scales.detach() - # quats_ = quats.detach() - # viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) - # scales_delta, rotations_delta, _ = self.blur_module( - # means_, - # scales_, - # quats_, - # viewdir_, - # ) - # lambda_s = 0.01 - # max_clamp = 1.05 - # scales_delta = torch.clamp( - # lambda_s * scales_delta + (1.0 - lambda_s), min=1.0, max=max_clamp - # ) - # rotations_delta = torch.clamp( - # lambda_s * rotations_delta + (1.0 - lambda_s), min=1.0, max=max_clamp - # ) - # scales = scales * (scales_delta + 1) - # quats = quats * (rotations_delta + 1) - + scales_delta = 
self.splats["embeds"][:, image_ids[0], :3] + rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] + scales_delta = torch.abs(scales_delta) + if image_ids[0] == 0: + scales_delta = 0.0 * scales_delta + rotations_delta = 0.0 * rotations_delta + scales = torch.exp(self.splats["scales"] + scales_delta) - + quats = F.normalize(self.splats["quats"] + rotations_delta, dim=-1) # [N, 4] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( means=means, quats=quats, - scales=scales * 1.2, + scales=scales, opacities=opacities, colors=colors, viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] @@ -766,20 +704,9 @@ def train(self): + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - # delta_loss = torch.abs(scales_delta).mean() - # print(delta_loss) - # loss += 0.1 * delta_loss - # loss += 1000.0 * cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"] + scales_delta) - torch.exp(self.splats["scales"])).mean() - - # mask = self.splats["embeds"] > 0 - # embed_reg = -0.8 * self.splats["embeds"][~mask].mean() + 0.2 * self.splats["embeds"][mask].mean() - - embed_reg = torch.abs(self.splats["embeds"]).mean() - # embed_reg = torch.sqrt(torch.abs(self.splats["embeds"])).mean() - # embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() - print(embed_reg) - loss += 1.0 * embed_reg - + # embed_reg = torch.abs(self.splats["embeds"]).mean() + embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() + loss += 10.0 * embed_reg loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " From 9767f142874555d245464bb4fee22078ef29fea1 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 12 Oct 2024 21:53:51 -0700 Subject: [PATCH 07/53] simplify --- .../benchmarks/compression/summarize_stats.py | 1 + examples/simple_trainer.py | 27 ++++++++++--------- gsplat/strategy/mcmc.py | 5 ---- gsplat/strategy/ops.py | 1 - 4 files changed, 15 insertions(+), 19 
deletions(-) diff --git a/examples/benchmarks/compression/summarize_stats.py b/examples/benchmarks/compression/summarize_stats.py index 7fa821015..f2552a049 100644 --- a/examples/benchmarks/compression/summarize_stats.py +++ b/examples/benchmarks/compression/summarize_stats.py @@ -39,5 +39,6 @@ def main(results_dir: str, scenes: List[str]): with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f: json.dump(summary, f, indent=2) + if __name__ == "__main__": tyro.cli(main) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index f061e163b..1379a5aca 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -233,19 +233,19 @@ def create_splats_with_optimizers( params = [ # name, value, lr - ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale, 0.0), - ("scales", torch.nn.Parameter(scales), 5e-3, 0.0), - ("quats", torch.nn.Parameter(quats), 1e-3, 0.0), - ("opacities", torch.nn.Parameter(opacities), 5e-2, 0.0), - ("embeds", torch.nn.Parameter(embeds), 1e-3, 0), + ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + ("embeds", torch.nn.Parameter(embeds), 1e-3), ] if feature_dim is None: # color is SH coefficients. 
colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] colors[:, 0, :] = rgb_to_sh(rgbs) - params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3, 0.0)) - params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20, 0.0)) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) else: # features will be used for appearance and view-dependent shading features = torch.rand(N, feature_dim) # [N, feature_dim] @@ -253,7 +253,7 @@ def create_splats_with_optimizers( colors = torch.logit(rgbs) # [N, 3] params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) - splats = torch.nn.ParameterDict({n: v for n, v, _, _ in params}).to(device) + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) # Scale learning rate based on batch size, reference: # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ # Note that this would not make the training exactly equivalent, see @@ -272,9 +272,8 @@ def create_splats_with_optimizers( eps=1e-15 / math.sqrt(BS), # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. 
betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), - weight_decay=weight_decay, ) - for name, _, lr, weight_decay in params + for name, _, lr in params } return splats, optimizers @@ -498,9 +497,11 @@ def rasterize_splats( if image_ids[0] == 0: scales_delta = 0.0 * scales_delta rotations_delta = 0.0 * rotations_delta - + scales = torch.exp(self.splats["scales"] + scales_delta) - quats = F.normalize(self.splats["quats"] + rotations_delta, dim=-1) # [N, 4] + quats = F.normalize( + self.splats["quats"] + rotations_delta, dim=-1 + ) # [N, 4] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( @@ -703,7 +704,7 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - + # embed_reg = torch.abs(self.splats["embeds"]).mean() embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() loss += 10.0 * embed_reg diff --git a/gsplat/strategy/mcmc.py b/gsplat/strategy/mcmc.py index 579f517e3..c07e17376 100644 --- a/gsplat/strategy/mcmc.py +++ b/gsplat/strategy/mcmc.py @@ -143,11 +143,6 @@ def step_post_backward( inject_noise_to_position( params=params, optimizers=optimizers, state={}, scaler=lr * self.noise_lr ) - - # with torch.no_grad(): - # deltas_min, _ = params["embeds"].min(dim=1) - # params["scales"] += deltas_min.repeat(1, 3) - # params["embeds"] -= deltas_min[:, :, None].repeat(1, 100, 1) @torch.no_grad() def _relocate_gs( diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 486352871..0789dfcbc 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -361,4 +361,3 @@ def op_sigmoid(x, k=100, x0=0.995): ) noise = torch.einsum("bij,bj->bi", covars, noise) params["means"].add_(noise) - \ No newline at end of file From e16873815ee1d00bbd871fa789a685c11181890c Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sun, 13 Oct 2024 19:35:43 -0700 Subject: [PATCH 08/53] masks works --- examples/blur_kernel.py | 4 +- examples/datasets/colmap.py 
| 4 +- examples/simple_trainer.py | 104 ++++++++++++++++++++++++++++++------ 3 files changed, 93 insertions(+), 19 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 409d27baa..9b2b5727a 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -12,6 +12,7 @@ import torch import torch.nn as nn import numpy as np +import torch.nn.functional as F class Embedder: @@ -86,7 +87,8 @@ def __init__( num_moments=4, ): super().__init__() - self.focals = torch.nn.Parameter(0.1 * torch.ones(n)) + self.blur_masks = torch.nn.Parameter(-0.5 * torch.ones(n, 1, 40, 60)) + self.pos_delta = pos_delta self.num_moments = num_moments diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 40a649c5e..83482148a 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -321,10 +321,10 @@ def __init__( indices = np.arange(len(self.parser.image_names)) if split == "train": self.indices = indices[indices % self.parser.test_every != 0] - self.indices = np.concatenate([[0], self.indices]) + # self.indices = np.concatenate([[0], self.indices]) else: self.indices = indices[indices % self.parser.test_every == 0] - self.indices = self.indices[1:] + # self.indices = self.indices[1:] def __len__(self): return len(self.indices) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 1379a5aca..3b0a60762 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -11,6 +11,7 @@ import numpy as np import torch import torch.nn.functional as F +import torchvision import tqdm import tyro import viser @@ -42,6 +43,7 @@ from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy from gsplat.optimizers import SelectiveAdam +from gsplat.utils import log_transform @dataclass @@ -409,7 +411,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = GTnet(len(self.trainset)) + self.blur_module = 
GTnet(len(self.trainset)).to(self.device) self.blur_optimizers = [ torch.optim.Adam( self.blur_module.parameters(), @@ -467,7 +469,7 @@ def rasterize_splats( width: int, height: int, masks: Optional[Tensor] = None, - is_train: bool = False, + blur: bool = False, **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means"] # [N, 3] @@ -490,18 +492,36 @@ def rasterize_splats( else: colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] - if self.cfg.blur_opt and is_train: + if blur: scales_delta = self.splats["embeds"][:, image_ids[0], :3] rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] - scales_delta = torch.abs(scales_delta) - if image_ids[0] == 0: - scales_delta = 0.0 * scales_delta - rotations_delta = 0.0 * rotations_delta - + scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) + rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) + + # means_ = log_transform(means.detach()) + # scales_ = self.splats["scales"].detach() + # quats_ = F.normalize(self.splats["quats"], dim=-1).detach() + # viewdir_ = 0.1 * camtoworlds[0, :3, 3].repeat(means.shape[0], 1) + # scales_delta, rotations_delta, _ = self.blur_module( + # means_, + # scales_, + # quats_, + # viewdir_, + # ) + # scales_delta = 0.0001 * scales_delta + # rotations_delta = 0.0001 * rotations_delta + # scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) + # rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) + + # print("scales_delta", scales_delta.min().item(), scales_delta.max().item(), scales_delta.mean().item()) + # print("rotations_delta", rotations_delta.min().item(), rotations_delta.max().item(), rotations_delta.mean().item()) scales = torch.exp(self.splats["scales"] + scales_delta) - quats = F.normalize( - self.splats["quats"] + rotations_delta, dim=-1 - ) # [N, 4] + quats = F.normalize(self.splats["quats"], dim=-1) + rotations_delta + quats = F.normalize(quats, dim=-1) + + # if image_ids[0] == 0: + # scales_delta = 
0.0 * scales_delta + # rotations_delta = 0.0 * rotations_delta rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( @@ -528,6 +548,9 @@ def rasterize_splats( ) if masks is not None: render_colors[~masks] = 0 + + # if self.cfg.blur_opt and is_train: + # info["scales_delta"] = scales_delta return render_colors, render_alphas, info def train(self): @@ -635,7 +658,6 @@ def train(self): image_ids=image_ids, render_mode="RGB+ED" if cfg.depth_loss else "RGB", masks=masks, - is_train=True, ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -654,6 +676,26 @@ def train(self): if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) + if cfg.blur_opt: + renders_blur, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + masks=masks, + blur=True + ) + blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
+ blur_mask = torchvision.transforms.functional.gaussian_blur(blur_mask, kernel_size=15) + blur_mask = F.interpolate(blur_mask, scale_factor=10, mode='bilinear', align_corners=False) + blur_mask = torch.sigmoid(10 * blur_mask)[0, ...][..., None] + print(blur_mask.min().item(), blur_mask.max().item(), blur_mask.mean().item()) + colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( params=self.splats, @@ -704,10 +746,15 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) + # mask_reg = torch.sigmoid(10 * self.blur_module.blur_masks).mean() + # loss += 0.01 * mask_reg + + # delta_reg = torch.abs(info["scales_delta"]).mean() + # loss += 0.01 * delta_reg # embed_reg = torch.abs(self.splats["embeds"]).mean() - embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() - loss += 10.0 * embed_reg + # # embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() + # loss += 10.0 * embed_reg loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " @@ -909,13 +956,39 @@ def eval(self, step: int, stage: str = "val"): far_plane=cfg.far_plane, image_ids=image_ids, masks=masks, - is_train=stage == "train", ) # [1, H, W, 3] torch.cuda.synchronize() ellipse_time += time.time() - tic colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] + if stage == "train": + blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
+ blur_mask = torchvision.transforms.functional.gaussian_blur(blur_mask, kernel_size=15) + blur_mask = F.interpolate(blur_mask, scale_factor=10, mode='bilinear', align_corners=False) + blur_mask = torch.sigmoid(10 * blur_mask)[0, ...][..., None] + blur_mask_color = blur_mask.repeat(1, 1, 1, 3) + canvas_list.append(blur_mask_color) + + colors_blur, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + masks=masks, + blur=True, + ) # [1, H, W, 3] + colors_blur = torch.clamp(colors_blur, 0.0, 1.0) + canvas_list.append(colors_blur) + + colors_mix = (1 - blur_mask) * colors + blur_mask * colors_blur + colors_mix = torch.clamp(colors_mix, 0.0, 1.0) + canvas_list.append(colors_mix) + colors = colors_mix if world_rank == 0: # write images @@ -1018,7 +1091,6 @@ def render_traj(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode="RGB+ED", - is_train=stage == "train", ) # [1, H, W, 4] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] depths = renders[..., 3:4] # [1, H, W, 1] From 1e338e361a50bc4307a1ee737b901ca9675c8f30 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sun, 13 Oct 2024 22:51:46 -0700 Subject: [PATCH 09/53] mask is working --- examples/blur_kernel.py | 4 ++-- examples/simple_trainer.py | 42 ++++++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 9b2b5727a..aa377b55b 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -87,8 +87,8 @@ def __init__( num_moments=4, ): super().__init__() - self.blur_masks = torch.nn.Parameter(-0.5 * torch.ones(n, 1, 40, 60)) - + self.blur_masks = torch.nn.Parameter(-1.0 * torch.ones(n, 1, 20, 30)) + self.pos_delta = pos_delta self.num_moments = num_moments diff --git a/examples/simple_trainer.py 
b/examples/simple_trainer.py index 3b0a60762..e2f9e28b8 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -153,7 +153,7 @@ class Config: # Enable blur optimization. (experimental) blur_opt: bool = False # Learning rate for blur optimization - blur_opt_lr: float = 1e-3 + blur_opt_lr: float = 1e-2 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -497,7 +497,7 @@ def rasterize_splats( rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) - + # means_ = log_transform(means.detach()) # scales_ = self.splats["scales"].detach() # quats_ = F.normalize(self.splats["quats"], dim=-1).detach() @@ -518,7 +518,7 @@ def rasterize_splats( scales = torch.exp(self.splats["scales"] + scales_delta) quats = F.normalize(self.splats["quats"], dim=-1) + rotations_delta quats = F.normalize(quats, dim=-1) - + # if image_ids[0] == 0: # scales_delta = 0.0 * scales_delta # rotations_delta = 0.0 * rotations_delta @@ -548,7 +548,7 @@ def rasterize_splats( ) if masks is not None: render_colors[~masks] = 0 - + # if self.cfg.blur_opt and is_train: # info["scales_delta"] = scales_delta return render_colors, render_alphas, info @@ -688,13 +688,21 @@ def train(self): image_ids=image_ids, render_mode="RGB+ED" if cfg.depth_loss else "RGB", masks=masks, - blur=True + blur=True, ) blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
- blur_mask = torchvision.transforms.functional.gaussian_blur(blur_mask, kernel_size=15) - blur_mask = F.interpolate(blur_mask, scale_factor=10, mode='bilinear', align_corners=False) - blur_mask = torch.sigmoid(10 * blur_mask)[0, ...][..., None] - print(blur_mask.min().item(), blur_mask.max().item(), blur_mask.mean().item()) + blur_mask = torchvision.transforms.functional.gaussian_blur( + blur_mask, kernel_size=5 + ) + blur_mask = F.interpolate( + blur_mask, scale_factor=20, mode="bilinear", align_corners=False + ) + blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] + print( + blur_mask.min().item(), + blur_mask.max().item(), + blur_mask.mean().item(), + ) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( @@ -748,7 +756,7 @@ def train(self): ) # mask_reg = torch.sigmoid(10 * self.blur_module.blur_masks).mean() # loss += 0.01 * mask_reg - + # delta_reg = torch.abs(info["scales_delta"]).mean() # loss += 0.01 * delta_reg @@ -964,12 +972,16 @@ def eval(self, step: int, stage: str = "val"): canvas_list = [pixels, colors] if stage == "train": blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
- blur_mask = torchvision.transforms.functional.gaussian_blur(blur_mask, kernel_size=15) - blur_mask = F.interpolate(blur_mask, scale_factor=10, mode='bilinear', align_corners=False) - blur_mask = torch.sigmoid(10 * blur_mask)[0, ...][..., None] + blur_mask = torchvision.transforms.functional.gaussian_blur( + blur_mask, kernel_size=5 + ) + blur_mask = F.interpolate( + blur_mask, scale_factor=20, mode="bilinear", align_corners=False + ) + blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] blur_mask_color = blur_mask.repeat(1, 1, 1, 3) canvas_list.append(blur_mask_color) - + colors_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -984,7 +996,7 @@ def eval(self, step: int, stage: str = "val"): ) # [1, H, W, 3] colors_blur = torch.clamp(colors_blur, 0.0, 1.0) canvas_list.append(colors_blur) - + colors_mix = (1 - blur_mask) * colors + blur_mask * colors_blur colors_mix = torch.clamp(colors_mix, 0.0, 1.0) canvas_list.append(colors_mix) From abad1d179b1e7835d8e731378b46004d370153c7 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 14 Oct 2024 19:41:09 -0700 Subject: [PATCH 10/53] mlp mask --- examples/blur_kernel.py | 19 ++++++++- examples/simple_trainer.py | 85 +++++++++++++++++++++++++++++--------- 2 files changed, 83 insertions(+), 21 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index aa377b55b..c60588c10 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -87,7 +87,24 @@ def __init__( num_moments=4, ): super().__init__() - self.blur_masks = torch.nn.Parameter(-1.0 * torch.ones(n, 1, 20, 30)) + self.blur_masks = torch.nn.Parameter( + -1.0 * torch.ones(n, 1, 400 // 8, 600 // 8) + ) + # self.image_feats = torch.nn.Parameter(torch.rand(n, 4)) + self.depth_mlps = nn.ModuleList() + for _ in range(n): + mlp = nn.Sequential( + nn.Linear(3, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 64, bias=False), + 
nn.ReLU(), + nn.Linear(64, 1, bias=False), + ).to("cuda") + self.depth_mlps.append(mlp) self.pos_delta = pos_delta self.num_moments = num_moments diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index e2f9e28b8..694d30d9b 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -85,7 +85,9 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000]) + eval_steps: List[int] = field( + default_factory=lambda: [2_000, 7_000, 15_000, 30_000] + ) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -153,7 +155,9 @@ class Config: # Enable blur optimization. (experimental) blur_opt: bool = False # Learning rate for blur optimization - blur_opt_lr: float = 1e-2 + blur_opt_lr: float = 1e-3 + # Regularization for blur optimization as weight decay + blur_opt_reg: float = 1e-4 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -416,6 +420,7 @@ def __init__( torch.optim.Adam( self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.blur_opt_reg, ), ] if world_size > 1: @@ -656,7 +661,7 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode="RGB+ED", # if cfg.depth_loss else "RGB", masks=masks, ) if renders.shape[-1] == 4: @@ -690,14 +695,31 @@ def train(self): masks=masks, blur=True, ) - blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
- blur_mask = torchvision.transforms.functional.gaussian_blur( - blur_mask, kernel_size=5 - ) - blur_mask = F.interpolate( - blur_mask, scale_factor=20, mode="bilinear", align_corners=False + + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", ) - blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + x = grid_xy.reshape(-1, 2) + x = torch.cat([x, depths.reshape(x.shape[0], 1)], dim=-1) + mlp_out = self.blur_module.depth_mlps[image_ids[0]](x) + mlp_out = mlp_out - mlp_out.mean() + print(mlp_out.min(), mlp_out.max(), mlp_out.mean()) + blur_mask = torch.sigmoid(mlp_out) + blur_mask = blur_mask.reshape(depths.shape) + + # blur_mask = blur_mask - blur_mask.mean() + 0.5 + + # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] + # blur_mask = torchvision.transforms.functional.gaussian_blur( + # blur_mask, kernel_size=3 + # ) + # blur_mask = F.interpolate( + # blur_mask, scale_factor=8, mode="bilinear", align_corners=False + # ) + # blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] print( blur_mask.min().item(), blur_mask.max().item(), @@ -754,8 +776,9 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - # mask_reg = torch.sigmoid(10 * self.blur_module.blur_masks).mean() - # loss += 0.01 * mask_reg + + # mlp_reg = torch.abs().mean() + # loss += 0.0001 * mlp_reg # delta_reg = torch.abs(info["scales_delta"]).mean() # loss += 0.01 * delta_reg @@ -954,7 +977,7 @@ def eval(self, step: int, stage: str = "val"): torch.cuda.synchronize() tic = time.time() - colors, _, _ = self.rasterize_splats( + renders, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -963,22 +986,44 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, + 
render_mode="RGB+ED", # if cfg.depth_loss else "RGB", masks=masks, ) # [1, H, W, 3] + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + torch.cuda.synchronize() ellipse_time += time.time() - tic colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if stage == "train": - blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] - blur_mask = torchvision.transforms.functional.gaussian_blur( - blur_mask, kernel_size=5 - ) - blur_mask = F.interpolate( - blur_mask, scale_factor=20, mode="bilinear", align_corners=False + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", ) - blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + x = grid_xy.reshape(-1, 2) + x = torch.cat([x, depths.reshape(x.shape[0], 1)], dim=-1) + mlp_out = self.blur_module.depth_mlps[image_ids[0]](x) + mlp_out = mlp_out - mlp_out.mean() + blur_mask = torch.sigmoid(mlp_out) + blur_mask = blur_mask.reshape(depths.shape) + + # blur_mask = blur_mask - blur_mask.mean() + 0.5 + + # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
+ # blur_mask = torchvision.transforms.functional.gaussian_blur( + # blur_mask, kernel_size=3 + # ) + # blur_mask = F.interpolate( + # blur_mask, scale_factor=8, mode="bilinear", align_corners=False + # ) + # blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] + blur_mask_color = blur_mask.repeat(1, 1, 1, 3) canvas_list.append(blur_mask_color) From 052d2727ed27ff2b16cf9e9eee39491fea170860 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 14 Oct 2024 23:08:58 -0700 Subject: [PATCH 11/53] mlp mask --- examples/benchmarks/mcmc_deblur.sh | 1 + examples/blur_kernel.py | 2 +- examples/simple_trainer.py | 20 +++++++------------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 7973ddd3a..6dfb04676 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,6 +1,7 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" + DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index c60588c10..325d5f6e7 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -94,7 +94,7 @@ def __init__( self.depth_mlps = nn.ModuleList() for _ in range(n): mlp = nn.Sequential( - nn.Linear(3, 64, bias=False), + nn.Linear(1, 64, bias=False), nn.ReLU(), nn.Linear(64, 64, bias=False), nn.ReLU(), diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 694d30d9b..bdd63712e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -702,15 +702,10 @@ def train(self): indexing="ij", ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - x = grid_xy.reshape(-1, 2) - x = torch.cat([x, depths.reshape(x.shape[0], 1)], dim=-1) - mlp_out = self.blur_module.depth_mlps[image_ids[0]](x) - mlp_out = mlp_out - mlp_out.mean() - 
print(mlp_out.min(), mlp_out.max(), mlp_out.mean()) + # x = torch.cat([colors, depths, grid_xy], dim=-1) + mlp_out = self.blur_module.depth_mlps[image_ids[0]](depths) + mlp_out = mlp_out - torch.quantile(mlp_out, 0.25) blur_mask = torch.sigmoid(mlp_out) - blur_mask = blur_mask.reshape(depths.shape) - - # blur_mask = blur_mask - blur_mask.mean() + 0.5 # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] # blur_mask = torchvision.transforms.functional.gaussian_blur( @@ -1006,13 +1001,12 @@ def eval(self, step: int, stage: str = "val"): indexing="ij", ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - x = grid_xy.reshape(-1, 2) - x = torch.cat([x, depths.reshape(x.shape[0], 1)], dim=-1) - mlp_out = self.blur_module.depth_mlps[image_ids[0]](x) - mlp_out = mlp_out - mlp_out.mean() + # x = torch.cat([colors, depths, grid_xy], dim=-1) + mlp_out = self.blur_module.depth_mlps[image_ids[0]](depths) + mlp_out = mlp_out - torch.quantile(mlp_out, 0.25) blur_mask = torch.sigmoid(mlp_out) - blur_mask = blur_mask.reshape(depths.shape) + # blur_mask = blur_mask.reshape(depths.shape) # blur_mask = blur_mask - blur_mask.mean() + 0.5 # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
From 3d4b5c04003c54260fb433abf558b673716a6cd4 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 15 Oct 2024 18:10:47 -0700 Subject: [PATCH 12/53] no need for quantile --- examples/blur_kernel.py | 3 ++- examples/simple_trainer.py | 27 ++++++++++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 325d5f6e7..ba2ed2e45 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -91,10 +91,11 @@ def __init__( -1.0 * torch.ones(n, 1, 400 // 8, 600 // 8) ) # self.image_feats = torch.nn.Parameter(torch.rand(n, 4)) + self.embed_depth, self.embed_depth_cnl = get_embedder(3, 3) self.depth_mlps = nn.ModuleList() for _ in range(n): mlp = nn.Sequential( - nn.Linear(1, 64, bias=False), + nn.Linear(self.embed_depth_cnl, 64, bias=False), nn.ReLU(), nn.Linear(64, 64, bias=False), nn.ReLU(), diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index bdd63712e..ea98b0a8e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -157,7 +157,7 @@ class Config: # Learning rate for blur optimization blur_opt_lr: float = 1e-3 # Regularization for blur optimization as weight decay - blur_opt_reg: float = 1e-4 + blur_opt_reg: float = 1e-6 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -702,9 +702,9 @@ def train(self): indexing="ij", ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - # x = torch.cat([colors, depths, grid_xy], dim=-1) - mlp_out = self.blur_module.depth_mlps[image_ids[0]](depths) - mlp_out = mlp_out - torch.quantile(mlp_out, 0.25) + x = torch.cat([grid_xy, depths], dim=-1) + x_embed = self.blur_module.embed_depth(x) + mlp_out = self.blur_module.depth_mlps[image_ids[0]](x_embed) blur_mask = torch.sigmoid(mlp_out) # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
@@ -775,6 +775,19 @@ def train(self): # mlp_reg = torch.abs().mean() # loss += 0.0001 * mlp_reg + mask_mean = torch.abs(blur_mask).mean() + mask_std = torch.std(blur_mask) + print(mask_mean, mask_std) + lambda_mean = 0.001 + lambda_std = 0.001 + if step >= 2000: + lambda_mean = 0.001 + lambda_std = 0.01 + loss += ( + lambda_mean * (mask_mean - 0.5) ** 2 + + lambda_std * (mask_std - 0.4) ** 2 + ) + # delta_reg = torch.abs(info["scales_delta"]).mean() # loss += 0.01 * delta_reg @@ -1001,9 +1014,9 @@ def eval(self, step: int, stage: str = "val"): indexing="ij", ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - # x = torch.cat([colors, depths, grid_xy], dim=-1) - mlp_out = self.blur_module.depth_mlps[image_ids[0]](depths) - mlp_out = mlp_out - torch.quantile(mlp_out, 0.25) + x = torch.cat([grid_xy, depths], dim=-1) + x_embed = self.blur_module.embed_depth(x) + mlp_out = self.blur_module.depth_mlps[image_ids[0]](x_embed) blur_mask = torch.sigmoid(mlp_out) # blur_mask = blur_mask.reshape(depths.shape) From 649acd036fa40a3c4383a98e5c04d32c960de5a9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 15 Oct 2024 19:51:10 -0700 Subject: [PATCH 13/53] single mlp --- examples/blur_kernel.py | 30 ++++++++++++------------------ examples/simple_trainer.py | 22 ++++++++++------------ 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index ba2ed2e45..5551850f6 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -87,25 +87,19 @@ def __init__( num_moments=4, ): super().__init__() - self.blur_masks = torch.nn.Parameter( - -1.0 * torch.ones(n, 1, 400 // 8, 600 // 8) - ) - # self.image_feats = torch.nn.Parameter(torch.rand(n, 4)) + self.image_feats = torch.eye(n).cuda() self.embed_depth, self.embed_depth_cnl = get_embedder(3, 3) - self.depth_mlps = nn.ModuleList() - for _ in range(n): - mlp = nn.Sequential( - nn.Linear(self.embed_depth_cnl, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 64, 
bias=False), - nn.ReLU(), - nn.Linear(64, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 1, bias=False), - ).to("cuda") - self.depth_mlps.append(mlp) + self.depth_mlp = nn.Sequential( + nn.Linear(self.embed_depth_cnl + n, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 64, bias=False), + nn.ReLU(), + nn.Linear(64, 1, bias=False), + ).to("cuda") self.pos_delta = pos_delta self.num_moments = num_moments diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index ea98b0a8e..b39ae7a1d 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -703,8 +703,10 @@ def train(self): ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) x = torch.cat([grid_xy, depths], dim=-1) - x_embed = self.blur_module.embed_depth(x) - mlp_out = self.blur_module.depth_mlps[image_ids[0]](x_embed) + x = self.blur_module.embed_depth(x) + x_img = self.blur_module.image_feats[image_ids[0]][None, None, None, :].repeat(1, height, width, 1) + x = torch.cat([x, x_img], dim=-1) + mlp_out = self.blur_module.depth_mlp(x) blur_mask = torch.sigmoid(mlp_out) # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
@@ -772,20 +774,14 @@ def train(self): + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - # mlp_reg = torch.abs().mean() - # loss += 0.0001 * mlp_reg - - mask_mean = torch.abs(blur_mask).mean() + mask_mean = torch.mean(blur_mask) mask_std = torch.std(blur_mask) print(mask_mean, mask_std) lambda_mean = 0.001 lambda_std = 0.001 - if step >= 2000: - lambda_mean = 0.001 - lambda_std = 0.01 loss += ( lambda_mean * (mask_mean - 0.5) ** 2 - + lambda_std * (mask_std - 0.4) ** 2 + + lambda_std * (mask_std - 0.5) ** 2 ) # delta_reg = torch.abs(info["scales_delta"]).mean() @@ -1015,8 +1011,10 @@ def eval(self, step: int, stage: str = "val"): ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) x = torch.cat([grid_xy, depths], dim=-1) - x_embed = self.blur_module.embed_depth(x) - mlp_out = self.blur_module.depth_mlps[image_ids[0]](x_embed) + x = self.blur_module.embed_depth(x) + x_img = self.blur_module.image_feats[image_ids[0]][None, None, None, :].repeat(1, height, width, 1) + x = torch.cat([x, x_img], dim=-1) + mlp_out = self.blur_module.depth_mlp(x) blur_mask = torch.sigmoid(mlp_out) # blur_mask = blur_mask.reshape(depths.shape) From fad9b21018c15c33439bcb74f3c440f4d7650d32 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 15 Oct 2024 19:54:50 -0700 Subject: [PATCH 14/53] cleanup --- examples/simple_trainer.py | 40 +++++++------------------------------- 1 file changed, 7 insertions(+), 33 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index b39ae7a1d..937d8049f 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -11,7 +11,6 @@ import numpy as np import torch import torch.nn.functional as F -import torchvision import tqdm import tyro import viser @@ -695,7 +694,6 @@ def train(self): masks=masks, blur=True, ) - grid_y, grid_x = torch.meshgrid( (torch.arange(height, device=self.device) + 0.5) / height, (torch.arange(width, device=self.device) + 0.5) / width, @@ -704,24 +702,12 @@ def 
train(self): grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) x = torch.cat([grid_xy, depths], dim=-1) x = self.blur_module.embed_depth(x) - x_img = self.blur_module.image_feats[image_ids[0]][None, None, None, :].repeat(1, height, width, 1) + x_img = self.blur_module.image_feats[image_ids[0]][ + None, None, None, : + ].repeat(1, height, width, 1) x = torch.cat([x, x_img], dim=-1) mlp_out = self.blur_module.depth_mlp(x) blur_mask = torch.sigmoid(mlp_out) - - # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] - # blur_mask = torchvision.transforms.functional.gaussian_blur( - # blur_mask, kernel_size=3 - # ) - # blur_mask = F.interpolate( - # blur_mask, scale_factor=8, mode="bilinear", align_corners=False - # ) - # blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] - print( - blur_mask.min().item(), - blur_mask.max().item(), - blur_mask.mean().item(), - ) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( @@ -1012,25 +998,13 @@ def eval(self, step: int, stage: str = "val"): grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) x = torch.cat([grid_xy, depths], dim=-1) x = self.blur_module.embed_depth(x) - x_img = self.blur_module.image_feats[image_ids[0]][None, None, None, :].repeat(1, height, width, 1) + x_img = self.blur_module.image_feats[image_ids[0]][ + None, None, None, : + ].repeat(1, height, width, 1) x = torch.cat([x, x_img], dim=-1) mlp_out = self.blur_module.depth_mlp(x) blur_mask = torch.sigmoid(mlp_out) - - # blur_mask = blur_mask.reshape(depths.shape) - # blur_mask = blur_mask - blur_mask.mean() + 0.5 - - # blur_mask = self.blur_module.blur_masks[image_ids[0]][None, ...] 
- # blur_mask = torchvision.transforms.functional.gaussian_blur( - # blur_mask, kernel_size=3 - # ) - # blur_mask = F.interpolate( - # blur_mask, scale_factor=8, mode="bilinear", align_corners=False - # ) - # blur_mask = torch.sigmoid(blur_mask)[0, ...][..., None] - - blur_mask_color = blur_mask.repeat(1, 1, 1, 3) - canvas_list.append(blur_mask_color) + canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) colors_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, From 2ae57b9cc0a4faeb603dda8320b244bda429e259 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 15 Oct 2024 19:58:47 -0700 Subject: [PATCH 15/53] cleanup --- examples/simple_trainer.py | 43 +++----------------------------------- 1 file changed, 3 insertions(+), 40 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 937d8049f..47755a635 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -501,32 +501,10 @@ def rasterize_splats( rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) - - # means_ = log_transform(means.detach()) - # scales_ = self.splats["scales"].detach() - # quats_ = F.normalize(self.splats["quats"], dim=-1).detach() - # viewdir_ = 0.1 * camtoworlds[0, :3, 3].repeat(means.shape[0], 1) - # scales_delta, rotations_delta, _ = self.blur_module( - # means_, - # scales_, - # quats_, - # viewdir_, - # ) - # scales_delta = 0.0001 * scales_delta - # rotations_delta = 0.0001 * rotations_delta - # scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - # rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) - - # print("scales_delta", scales_delta.min().item(), scales_delta.max().item(), scales_delta.mean().item()) - # print("rotations_delta", rotations_delta.min().item(), rotations_delta.max().item(), rotations_delta.mean().item()) scales = torch.exp(self.splats["scales"] + 
scales_delta) quats = F.normalize(self.splats["quats"], dim=-1) + rotations_delta quats = F.normalize(quats, dim=-1) - # if image_ids[0] == 0: - # scales_delta = 0.0 * scales_delta - # rotations_delta = 0.0 * rotations_delta - rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( means=means, @@ -552,9 +530,6 @@ def rasterize_splats( ) if masks is not None: render_colors[~masks] = 0 - - # if self.cfg.blur_opt and is_train: - # info["scales_delta"] = scales_delta return render_colors, render_alphas, info def train(self): @@ -759,23 +734,11 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - - mask_mean = torch.mean(blur_mask) - mask_std = torch.std(blur_mask) - print(mask_mean, mask_std) - lambda_mean = 0.001 - lambda_std = 0.001 + if cfg.blur_opt: loss += ( - lambda_mean * (mask_mean - 0.5) ** 2 - + lambda_std * (mask_std - 0.5) ** 2 + 0.001 * (torch.mean(blur_mask) - 0.5) ** 2 + + 0.001 * (torch.std(blur_mask) - 0.5) ** 2 ) - - # delta_reg = torch.abs(info["scales_delta"]).mean() - # loss += 0.01 * delta_reg - - # embed_reg = torch.abs(self.splats["embeds"]).mean() - # # embed_reg = torch.log(1 + torch.abs(self.splats["embeds"])).mean() - # loss += 10.0 * embed_reg loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " From 904979e24efa5293f11c6dea1e35cb59be1ee1c6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 15 Oct 2024 21:07:16 -0700 Subject: [PATCH 16/53] focal embedding --- examples/blur_kernel.py | 6 +++--- examples/simple_trainer.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 5551850f6..9ef44e41c 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -87,10 +87,10 @@ def __init__( num_moments=4, ): super().__init__() - self.image_feats = torch.eye(n).cuda() + self.focals = torch.nn.Embedding(n, 1) 
self.embed_depth, self.embed_depth_cnl = get_embedder(3, 3) self.depth_mlp = nn.Sequential( - nn.Linear(self.embed_depth_cnl + n, 64, bias=False), + nn.Linear(self.embed_depth_cnl + 1, 64, bias=False), nn.ReLU(), nn.Linear(64, 64, bias=False), nn.ReLU(), @@ -99,7 +99,7 @@ def __init__( nn.Linear(64, 64, bias=False), nn.ReLU(), nn.Linear(64, 1, bias=False), - ).to("cuda") + ) self.pos_delta = pos_delta self.num_moments = num_moments diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 47755a635..e25e3df34 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -417,10 +417,14 @@ def __init__( self.blur_module = GTnet(len(self.trainset)).to(self.device) self.blur_optimizers = [ torch.optim.Adam( - self.blur_module.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + self.blur_module.focals.parameters(), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10.0, weight_decay=cfg.blur_opt_reg, ), + torch.optim.Adam( + self.blur_module.depth_mlp.parameters(), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + ), ] if world_size > 1: self.blur_module = DDP(self.blur_module) @@ -677,7 +681,7 @@ def train(self): grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) x = torch.cat([grid_xy, depths], dim=-1) x = self.blur_module.embed_depth(x) - x_img = self.blur_module.image_feats[image_ids[0]][ + x_img = self.blur_module.focals(image_ids[0])[ None, None, None, : ].repeat(1, height, width, 1) x = torch.cat([x, x_img], dim=-1) @@ -736,7 +740,7 @@ def train(self): ) if cfg.blur_opt: loss += ( - 0.001 * (torch.mean(blur_mask) - 0.5) ** 2 + 0.01 * (torch.mean(blur_mask) - 0.5) ** 2 + 0.001 * (torch.std(blur_mask) - 0.5) ** 2 ) loss.backward() @@ -961,7 +965,7 @@ def eval(self, step: int, stage: str = "val"): grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) x = torch.cat([grid_xy, depths], dim=-1) x = self.blur_module.embed_depth(x) - x_img = self.blur_module.image_feats[image_ids[0]][ + x_img = 
self.blur_module.focals(image_ids[0])[ None, None, None, : ].repeat(1, height, width, 1) x = torch.cat([x, x_img], dim=-1) From c367c0a8fe4eb815e073977292ec4c6b993cd719 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 15 Oct 2024 23:25:10 -0700 Subject: [PATCH 17/53] init blur --- examples/blur_kernel.py | 3 ++- examples/simple_trainer.py | 31 ++++++++++++------------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 9ef44e41c..0f340043d 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -88,7 +88,8 @@ def __init__( ): super().__init__() self.focals = torch.nn.Embedding(n, 1) - self.embed_depth, self.embed_depth_cnl = get_embedder(3, 3) + self.focals.weight.data = torch.linspace(-1, 1, n)[:, None] + self.embed_depth, self.embed_depth_cnl = get_embedder(10, 1) self.depth_mlp = nn.Sequential( nn.Linear(self.embed_depth_cnl + 1, 64, bias=False), nn.ReLU(), diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index e25e3df34..2277b0a48 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -418,7 +418,7 @@ def __init__( self.blur_optimizers = [ torch.optim.Adam( self.blur_module.focals.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), weight_decay=cfg.blur_opt_reg, ), torch.optim.Adam( @@ -673,14 +673,7 @@ def train(self): masks=masks, blur=True, ) - grid_y, grid_x = torch.meshgrid( - (torch.arange(height, device=self.device) + 0.5) / height, - (torch.arange(width, device=self.device) + 0.5) / width, - indexing="ij", - ) - grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - x = torch.cat([grid_xy, depths], dim=-1) - x = self.blur_module.embed_depth(x) + x = self.blur_module.embed_depth(depths) x_img = self.blur_module.focals(image_ids[0])[ None, None, None, : ].repeat(1, height, width, 1) @@ -739,9 +732,16 @@ def train(self): + cfg.scale_reg * 
torch.abs(torch.exp(self.splats["scales"])).mean() ) if cfg.blur_opt: + if step < 2000: + lambda_mean = 0.01 + lambda_std = 0.001 + else: + lambda_mean = 0.0 + lambda_std = 0.0 + print(lambda_mean, lambda_std, blur_mask.mean().item(), blur_mask.std().item()) loss += ( - 0.01 * (torch.mean(blur_mask) - 0.5) ** 2 - + 0.001 * (torch.std(blur_mask) - 0.5) ** 2 + lambda_mean * (torch.mean(blur_mask) - 0.5) ** 2 + + lambda_std * (torch.std(blur_mask) - 0.5) ** 2 ) loss.backward() @@ -957,14 +957,7 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if stage == "train": - grid_y, grid_x = torch.meshgrid( - (torch.arange(height, device=self.device) + 0.5) / height, - (torch.arange(width, device=self.device) + 0.5) / width, - indexing="ij", - ) - grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - x = torch.cat([grid_xy, depths], dim=-1) - x = self.blur_module.embed_depth(x) + x = self.blur_module.embed_depth(depths) x_img = self.blur_module.focals(image_ids[0])[ None, None, None, : ].repeat(1, height, width, 1) From ba514edcb3d93ce87492666d867b45c801a99894 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 10:30:51 -0700 Subject: [PATCH 18/53] reg to prevent collapse --- examples/benchmarks/mcmc_deblur.sh | 4 ++-- examples/simple_trainer.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 6dfb04676..ba844c804 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,6 +1,6 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" -SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" - +# SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" 
+SCENE_LIST="defocuscupcake" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 2277b0a48..532868093 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -418,7 +418,7 @@ def __init__( self.blur_optimizers = [ torch.optim.Adam( self.blur_module.focals.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10, weight_decay=cfg.blur_opt_reg, ), torch.optim.Adam( @@ -736,9 +736,14 @@ def train(self): lambda_mean = 0.01 lambda_std = 0.001 else: - lambda_mean = 0.0 - lambda_std = 0.0 - print(lambda_mean, lambda_std, blur_mask.mean().item(), blur_mask.std().item()) + lambda_mean = 0.001 + lambda_std = 0.001 + print( + lambda_mean, + lambda_std, + blur_mask.mean().item(), + blur_mask.std().item(), + ) loss += ( lambda_mean * (torch.mean(blur_mask) - 0.5) ** 2 + lambda_std * (torch.std(blur_mask) - 0.5) ** 2 From 6da4119b7d0505ec6f41b9afd94ddfce1fd116f6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 14:39:17 -0700 Subject: [PATCH 19/53] mlp --- examples/blur_kernel.py | 129 +++++++++++++++++-------------- examples/external.py | 58 ++++++++++++++ examples/mlp.py | 154 +++++++++++++++++++++++++++++++++++++ examples/simple_trainer.py | 27 ++----- 4 files changed, 288 insertions(+), 80 deletions(-) create mode 100644 examples/external.py create mode 100644 examples/mlp.py diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 0f340043d..fd24a5135 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -13,6 +13,8 @@ import torch.nn as nn import numpy as np import torch.nn.functional as F +from examples.mlp import create_mlp, _create_mlp_torch, _create_mlp_tcnn +from gsplat.utils import log_transform class Embedder: @@ -89,64 +91,73 @@ def __init__( super().__init__() self.focals = torch.nn.Embedding(n, 1) self.focals.weight.data = torch.linspace(-1, 1, n)[:, None] - self.embed_depth, 
self.embed_depth_cnl = get_embedder(10, 1) - self.depth_mlp = nn.Sequential( - nn.Linear(self.embed_depth_cnl + 1, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 64, bias=False), - nn.ReLU(), - nn.Linear(64, 1, bias=False), + self.embed_depth, self.embed_depth_cnl = get_embedder(14, 1) + self.depth_mlp = _create_mlp_torch( + in_dim=self.embed_depth_cnl + 1, + num_layers=5, + layer_width=64, + out_dim=1, ) - self.pos_delta = pos_delta - self.num_moments = num_moments - - self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) - self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) - in_cnl = ( - self.embed_pos_cnl + self.embed_view_cnl + 7 - ) # 7 for scales and rotations - - hiddens = [ - nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() - for i in range((num_hidden - 1) * 2) - ] - - self.linears = nn.Sequential( - nn.Linear(in_cnl, width), - nn.ReLU(), - *hiddens, - ).to("cuda") - if not pos_delta: # Defocus - self.s = nn.Linear(width, 3).to("cuda") - self.r = nn.Linear(width, 4).to("cuda") - else: # Motion - self.s = nn.Linear(width, 3 * (num_moments + 1)).to("cuda") - self.r = nn.Linear(width, 4 * (num_moments + 1)).to("cuda") - self.p = nn.Linear(width, 3 * num_moments).to("cuda") - - self.linears.apply(init_linear_weights) - self.s.apply(init_linear_weights) - self.r.apply(init_linear_weights) - if pos_delta: - self.p.apply(init_linear_weights) - - def forward(self, pos, scales, rotations, viewdirs): - pos_delta = None - pos = self.embed_pos(pos) - viewdirs = self.embed_view(viewdirs) - - x = torch.cat([pos, viewdirs, scales, rotations], dim=-1) - x1 = self.linears(x) - - scales_delta = self.s(x1) - rotations_delta = self.r(x1) - - if self.pos_delta: - pos_delta = self.p(x1) - - return scales_delta, rotations_delta, pos_delta + # self.pos_delta = pos_delta + # self.num_moments = num_moments + + # self.embed_pos, self.embed_pos_cnl = 
get_embedder(res_pos, 3) + # self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) + # in_cnl = ( + # self.embed_pos_cnl + self.embed_view_cnl + 7 + # ) # 7 for scales and rotations + + # hiddens = [ + # nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() + # for i in range((num_hidden - 1) * 2) + # ] + + # self.linears = nn.Sequential( + # nn.Linear(in_cnl, width), + # nn.ReLU(), + # *hiddens, + # ).to("cuda") + # if not pos_delta: # Defocus + # self.s = nn.Linear(width, 3).to("cuda") + # self.r = nn.Linear(width, 4).to("cuda") + # else: # Motion + # self.s = nn.Linear(width, 3 * (num_moments + 1)).to("cuda") + # self.r = nn.Linear(width, 4 * (num_moments + 1)).to("cuda") + # self.p = nn.Linear(width, 3 * num_moments).to("cuda") + + # self.linears.apply(init_linear_weights) + # self.s.apply(init_linear_weights) + # self.r.apply(init_linear_weights) + # if pos_delta: + # self.p.apply(init_linear_weights) + + def forward(self, depths, image_ids): + height, width = depths.shape[1:3] + + depths_emb = self.embed_depth(depths) + focals_emb = self.focals(image_ids[0])[None, None, None, :].repeat( + 1, height, width, 1 + ) + x = torch.cat([depths_emb, focals_emb], dim=-1) + x = x.reshape(-1, x.shape[-1]).half().float() + blur_mask = self.depth_mlp(x).float() + # blur_mask = torch.sigmoid(mlp_out) + blur_mask = blur_mask.reshape(1, height, width, 1) + return blur_mask + + # def forward(self, pos, scales, rotations, viewdirs): + # pos_delta = None + # pos = self.embed_pos(pos) + # viewdirs = self.embed_view(viewdirs) + + # x = torch.cat([pos, viewdirs, scales, rotations], dim=-1) + # x1 = self.linears(x) + + # scales_delta = self.s(x1) + # rotations_delta = self.r(x1) + + # if self.pos_delta: + # pos_delta = self.p(x1) + + # return scales_delta, rotations_delta, pos_delta diff --git a/examples/external.py b/examples/external.py new file mode 100644 index 000000000..f14219bd5 --- /dev/null +++ b/examples/external.py @@ -0,0 +1,58 @@ +# Copyright 2022 the Regents 
of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + + +class _LazyError: + def __init__(self, data): + self.__data = data # pylint: disable=unused-private-member + + class LazyErrorObj: + def __init__(self, data): + self.__data = data # pylint: disable=unused-private-member + + def __call__(self, *args, **kwds): + name, exc = object.__getattribute__(self, "__data") + raise RuntimeError(f"Could not load package {name}.") from exc + + def __getattr__(self, __name: str): + name, exc = object.__getattribute__(self, "__data") + raise RuntimeError(f"Could not load package {name}") from exc + + def __getattr__(self, __name: str): + return _LazyError.LazyErrorObj(object.__getattribute__(self, "__data")) + + +TCNN_EXISTS = False +tcnn_import_exception = None +tcnn = None +try: + import tinycudann + + tcnn = tinycudann + del tinycudann + TCNN_EXISTS = True +except ModuleNotFoundError as _exp: + tcnn_import_exception = _exp +except ImportError as _exp: + tcnn_import_exception = _exp +except EnvironmentError as _exp: + if "Unknown compute capability" not in _exp.args[0]: + raise _exp + print("Could not load tinycudann: " + str(_exp), file=sys.stderr) + tcnn_import_exception = _exp + +if tcnn_import_exception is not None: + tcnn = _LazyError(tcnn_import_exception) diff --git a/examples/mlp.py b/examples/mlp.py new file mode 100644 index 000000000..4ad80c2d4 --- /dev/null +++ 
b/examples/mlp.py @@ -0,0 +1,154 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi Layer Perceptron +""" + +from typing import Union + +import torch +from torch import nn + +from examples.external import TCNN_EXISTS, tcnn + + +def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str: + """Converts a torch.nn activation function to a string that can be used to + initialize a TCNN activation function. + + Args: + activation: torch.nn activation function + Returns: + str: TCNN activation function string + """ + + if isinstance(activation, nn.ReLU): + return "ReLU" + if isinstance(activation, nn.LeakyReLU): + return "Leaky ReLU" + if isinstance(activation, nn.Sigmoid): + return "Sigmoid" + if isinstance(activation, nn.Softplus): + return "Softplus" + if isinstance(activation, nn.Tanh): + return "Tanh" + if isinstance(activation, type(None)): + return "None" + tcnn_documentation_url = "https://github.com/NVlabs/tiny-cuda-nn/blob/master/DOCUMENTATION.md#activation-functions" + raise ValueError( + f"TCNN activation {activation} not supported for now.\nSee {tcnn_documentation_url} for TCNN documentation." 
+ ) + + +def get_tcnn_network_config( + activation, out_activation, layer_width, num_layers +) -> dict: + """Get the network configuration for tcnn if implemented""" + activation_str = activation_to_tcnn_string(activation) + output_activation_str = activation_to_tcnn_string(out_activation) + assert layer_width in [16, 32, 64, 128] + network_config = { + "otype": "FullyFusedMLP", + "activation": activation_str, + "output_activation": output_activation_str, + "n_neurons": layer_width, + "n_hidden_layers": num_layers - 1, + } + return network_config + + +def create_mlp( + in_dim: int, + num_layers: int, + layer_width: int, + out_dim: int, + initialize_last_layer_zeros: bool = False, +): + if TCNN_EXISTS: + return _create_mlp_tcnn( + in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros + ) + else: + return _create_mlp_torch( + in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros + ) + + +def _create_mlp_tcnn( + in_dim: int, + num_layers: int, + layer_width: int, + out_dim: int, + initialize_last_layer_zeros: bool = False, +): + """Create a fully-connected neural network with tiny-cuda-nn.""" + network_config = get_tcnn_network_config( + activation=nn.ReLU(), + out_activation=nn.Sigmoid(), + layer_width=layer_width, + num_layers=num_layers, + ) + print(network_config) + tcnn_encoding = tcnn.Network( + n_input_dims=in_dim, + n_output_dims=out_dim, + network_config=network_config, + ) + + mlp_torch = _create_mlp_torch(in_dim, num_layers, layer_width, out_dim) + print(mlp_torch) + output_layer = mlp_torch[8].weight.data + output_layer = nn.functional.pad(output_layer, pad=(0, 0, 0, 16 - (out_dim % 16))) + params = torch.cat( + [ + mlp_torch[0].weight.data.flatten(), + mlp_torch[2].weight.data.flatten(), + mlp_torch[4].weight.data.flatten(), + mlp_torch[6].weight.data.flatten(), + output_layer.flatten(), + ] + ).half() + tcnn_encoding.params.data[...] 
= params + + # if initialize_last_layer_zeros: + # # tcnn always pads the output layer's width to a multiple of 16 + # params = tcnn_encoding.state_dict()["params"] + # params[-1 * (layer_width * 16 * (out_dim // 16 + 1)) :] = 0 + # tcnn_encoding.load_state_dict({"params": params}) + return tcnn_encoding + + +def _create_mlp_torch( + in_dim: int, + num_layers: int, + layer_width: int, + out_dim: int, + initialize_last_layer_zeros: bool = False, +): + """Create a fully-connected neural network with PyTorch.""" + layers = [] + layer_in = in_dim + for i in range(num_layers): + layer_out = layer_width if i != num_layers - 1 else out_dim + layers.append(nn.Linear(layer_in, layer_out, bias=False)) + if i != num_layers - 1: + layers.append(nn.ReLU()) + else: + layers.append(nn.Sigmoid()) + layer_in = layer_width + + if initialize_last_layer_zeros: + nn.init.zeros_(layers[-1].weight) + return nn.Sequential(*layers) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 532868093..060cfeb42 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -418,12 +418,13 @@ def __init__( self.blur_optimizers = [ torch.optim.Adam( self.blur_module.focals.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10, + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), weight_decay=cfg.blur_opt_reg, ), torch.optim.Adam( self.blur_module.depth_mlp.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.blur_opt_reg, ), ] if world_size > 1: @@ -673,13 +674,7 @@ def train(self): masks=masks, blur=True, ) - x = self.blur_module.embed_depth(depths) - x_img = self.blur_module.focals(image_ids[0])[ - None, None, None, : - ].repeat(1, height, width, 1) - x = torch.cat([x, x_img], dim=-1) - mlp_out = self.blur_module.depth_mlp(x) - blur_mask = torch.sigmoid(mlp_out) + blur_mask = self.blur_module(depths, image_ids) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( @@ 
-732,12 +727,8 @@ def train(self): + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) if cfg.blur_opt: - if step < 2000: - lambda_mean = 0.01 - lambda_std = 0.001 - else: - lambda_mean = 0.001 - lambda_std = 0.001 + lambda_mean = 0.01 + lambda_std = 0.001 print( lambda_mean, lambda_std, @@ -962,13 +953,7 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if stage == "train": - x = self.blur_module.embed_depth(depths) - x_img = self.blur_module.focals(image_ids[0])[ - None, None, None, : - ].repeat(1, height, width, 1) - x = torch.cat([x, x_img], dim=-1) - mlp_out = self.blur_module.depth_mlp(x) - blur_mask = torch.sigmoid(mlp_out) + blur_mask = self.blur_module(depths, image_ids) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) colors_blur, _, _ = self.rasterize_splats( From 5714315a83667f7774653280db534f638577fa86 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 15:26:23 -0700 Subject: [PATCH 20/53] deltas mlp --- examples/blur_kernel.py | 89 ++++++++++++++++++-------------------- examples/mlp.py | 30 +++---------- examples/simple_trainer.py | 75 ++++++++++++++++---------------- 3 files changed, 86 insertions(+), 108 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index fd24a5135..f667a3d6b 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -99,38 +99,35 @@ def __init__( out_dim=1, ) - # self.pos_delta = pos_delta - # self.num_moments = num_moments - - # self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) - # self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) - # in_cnl = ( - # self.embed_pos_cnl + self.embed_view_cnl + 7 - # ) # 7 for scales and rotations - - # hiddens = [ - # nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() - # for i in range((num_hidden - 1) * 2) - # ] - - # self.linears = nn.Sequential( - # nn.Linear(in_cnl, width), - # nn.ReLU(), - # *hiddens, - # ).to("cuda") - # 
if not pos_delta: # Defocus - # self.s = nn.Linear(width, 3).to("cuda") - # self.r = nn.Linear(width, 4).to("cuda") - # else: # Motion - # self.s = nn.Linear(width, 3 * (num_moments + 1)).to("cuda") - # self.r = nn.Linear(width, 4 * (num_moments + 1)).to("cuda") - # self.p = nn.Linear(width, 3 * num_moments).to("cuda") - - # self.linears.apply(init_linear_weights) - # self.s.apply(init_linear_weights) - # self.r.apply(init_linear_weights) - # if pos_delta: - # self.p.apply(init_linear_weights) + self.pos_delta = pos_delta + self.num_moments = num_moments + self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) + self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) + in_cnl = ( + self.embed_pos_cnl + self.embed_view_cnl + 7 + ) # 7 for scales and rotations + hiddens = [ + nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() + for i in range((num_hidden - 1) * 2) + ] + self.linears = nn.Sequential( + nn.Linear(in_cnl, width), + nn.ReLU(), + *hiddens, + ).to("cuda") + if not pos_delta: # Defocus + self.s = nn.Linear(width, 3).to("cuda") + self.r = nn.Linear(width, 4).to("cuda") + else: # Motion + self.s = nn.Linear(width, 3 * (num_moments + 1)).to("cuda") + self.r = nn.Linear(width, 4 * (num_moments + 1)).to("cuda") + self.p = nn.Linear(width, 3 * num_moments).to("cuda") + + self.linears.apply(init_linear_weights) + self.s.apply(init_linear_weights) + self.r.apply(init_linear_weights) + if pos_delta: + self.p.apply(init_linear_weights) def forward(self, depths, image_ids): height, width = depths.shape[1:3] @@ -140,24 +137,24 @@ def forward(self, depths, image_ids): 1, height, width, 1 ) x = torch.cat([depths_emb, focals_emb], dim=-1) - x = x.reshape(-1, x.shape[-1]).half().float() - blur_mask = self.depth_mlp(x).float() - # blur_mask = torch.sigmoid(mlp_out) + x = x.reshape(-1, x.shape[-1]) + mlp_out = self.depth_mlp(x) + blur_mask = torch.sigmoid(mlp_out) blur_mask = blur_mask.reshape(1, height, width, 1) return blur_mask - # def forward(self, pos, 
scales, rotations, viewdirs): - # pos_delta = None - # pos = self.embed_pos(pos) - # viewdirs = self.embed_view(viewdirs) + def forward_deltas(self, pos, scales, rotations, viewdirs): + pos_delta = None + pos = self.embed_pos(pos) + viewdirs = self.embed_view(viewdirs) - # x = torch.cat([pos, viewdirs, scales, rotations], dim=-1) - # x1 = self.linears(x) + x = torch.cat([pos, viewdirs, scales, rotations], dim=-1) + x1 = self.linears(x) - # scales_delta = self.s(x1) - # rotations_delta = self.r(x1) + scales_delta = self.s(x1) + rotations_delta = self.r(x1) - # if self.pos_delta: - # pos_delta = self.p(x1) + if self.pos_delta: + pos_delta = self.p(x1) - # return scales_delta, rotations_delta, pos_delta + return scales_delta, rotations_delta, pos_delta diff --git a/examples/mlp.py b/examples/mlp.py index 4ad80c2d4..4d5987cfa 100644 --- a/examples/mlp.py +++ b/examples/mlp.py @@ -96,37 +96,21 @@ def _create_mlp_tcnn( """Create a fully-connected neural network with tiny-cuda-nn.""" network_config = get_tcnn_network_config( activation=nn.ReLU(), - out_activation=nn.Sigmoid(), + out_activation=None, layer_width=layer_width, num_layers=num_layers, ) - print(network_config) tcnn_encoding = tcnn.Network( n_input_dims=in_dim, n_output_dims=out_dim, network_config=network_config, ) - mlp_torch = _create_mlp_torch(in_dim, num_layers, layer_width, out_dim) - print(mlp_torch) - output_layer = mlp_torch[8].weight.data - output_layer = nn.functional.pad(output_layer, pad=(0, 0, 0, 16 - (out_dim % 16))) - params = torch.cat( - [ - mlp_torch[0].weight.data.flatten(), - mlp_torch[2].weight.data.flatten(), - mlp_torch[4].weight.data.flatten(), - mlp_torch[6].weight.data.flatten(), - output_layer.flatten(), - ] - ).half() - tcnn_encoding.params.data[...] 
= params - - # if initialize_last_layer_zeros: - # # tcnn always pads the output layer's width to a multiple of 16 - # params = tcnn_encoding.state_dict()["params"] - # params[-1 * (layer_width * 16 * (out_dim // 16 + 1)) :] = 0 - # tcnn_encoding.load_state_dict({"params": params}) + if initialize_last_layer_zeros: + # tcnn always pads the output layer's width to a multiple of 16 + params = tcnn_encoding.state_dict()["params"] + params[-1 * (layer_width * 16 * (out_dim // 16 + 1)) :] = 0 + tcnn_encoding.load_state_dict({"params": params}) return tcnn_encoding @@ -145,8 +129,6 @@ def _create_mlp_torch( layers.append(nn.Linear(layer_in, layer_out, bias=False)) if i != num_layers - 1: layers.append(nn.ReLU()) - else: - layers.append(nn.Sigmoid()) layer_in = layer_width if initialize_last_layer_zeros: diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 060cfeb42..78bd86e04 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -296,6 +296,9 @@ def __init__( self.local_rank = local_rank self.world_size = world_size self.device = f"cuda:{local_rank}" + self.render_mode = "RGB" + if cfg.depth_loss or cfg.blur_opt: + self.render_mode = "RGB+ED" # Where to dump results. 
os.makedirs(cfg.result_dir, exist_ok=True) @@ -417,12 +420,7 @@ def __init__( self.blur_module = GTnet(len(self.trainset)).to(self.device) self.blur_optimizers = [ torch.optim.Adam( - self.blur_module.focals.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), - weight_decay=cfg.blur_opt_reg, - ), - torch.optim.Adam( - self.blur_module.depth_mlp.parameters(), + self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), weight_decay=cfg.blur_opt_reg, ), @@ -482,10 +480,6 @@ def rasterize_splats( **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means"] # [N, 3] - # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] - # rasterization does normalization internally - quats = self.splats["quats"] # [N, 4] - scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] image_ids = kwargs.pop("image_ids", None) @@ -501,14 +495,26 @@ def rasterize_splats( else: colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] - if blur: - scales_delta = self.splats["embeds"][:, image_ids[0], :3] - rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] + if self.cfg.blur_opt and blur: + means_ = means.detach() + scales_ = self.splats["scales"].detach() + quats_ = self.splats["quats"].detach() + viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) + scales_delta, rotations_delta, _ = self.blur_module.forward_deltas( + means_, + scales_, + quats_, + viewdir_, + ) + # scales_delta = self.splats["embeds"][:, image_ids[0], :3] + # rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) scales = torch.exp(self.splats["scales"] + scales_delta) quats = F.normalize(self.splats["quats"], dim=-1) + rotations_delta - quats = F.normalize(quats, dim=-1) + else: + scales = torch.exp(self.splats["scales"]) # [N, 3] + quats = 
self.splats["quats"] # [N, 4] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( @@ -640,7 +646,7 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", # if cfg.depth_loss else "RGB", + render_mode=self.render_mode, masks=masks, ) if renders.shape[-1] == 4: @@ -670,7 +676,7 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode="RGB", masks=masks, blur=True, ) @@ -909,14 +915,10 @@ def eval(self, step: int, stage: str = "val"): world_rank = self.world_rank world_size = self.world_size - if stage == "train": - dataloader = torch.utils.data.DataLoader( - self.trainset, batch_size=1, shuffle=False, num_workers=1 - ) - else: - dataloader = torch.utils.data.DataLoader( - self.valset, batch_size=1, shuffle=False, num_workers=1 - ) + dataset = self.trainset if stage == "train" else self.valset + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, shuffle=False, num_workers=1 + ) ellipse_time = 0 metrics = defaultdict(list) @@ -939,7 +941,7 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", # if cfg.depth_loss else "RGB", + render_mode=self.render_mode, masks=masks, ) # [1, H, W, 3] if renders.shape[-1] == 4: @@ -952,11 +954,8 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] - if stage == "train": - blur_mask = self.blur_module(depths, image_ids) - canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) - - colors_blur, _, _ = self.rasterize_splats( + if self.cfg.blur_opt and stage == "train": + renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -965,16 +964,16 @@ def eval(self, step: int, stage: str = "val"): 
near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, + render_mode="RGB", masks=masks, blur=True, - ) # [1, H, W, 3] - colors_blur = torch.clamp(colors_blur, 0.0, 1.0) - canvas_list.append(colors_blur) - - colors_mix = (1 - blur_mask) * colors + blur_mask * colors_blur - colors_mix = torch.clamp(colors_mix, 0.0, 1.0) - canvas_list.append(colors_mix) - colors = colors_mix + ) + canvas_list.append(torch.clamp(renders_blur[..., 0:3], 0.0, 1.0)) + blur_mask = self.blur_module(depths, image_ids) + canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) + colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list.append(colors) if world_rank == 0: # write images From 291a74c424b3313d0b29f2eeb8482b3388be7315 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 15:40:20 -0700 Subject: [PATCH 21/53] cleanup --- examples/blur_kernel.py | 69 +++++++++----------------------------- examples/simple_trainer.py | 18 ++++------ 2 files changed, 23 insertions(+), 64 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index f667a3d6b..f14e5d4f0 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -78,16 +78,7 @@ def init_linear_weights(m): class GTnet(nn.Module): - def __init__( - self, - n, - res_pos=3, - res_view=10, - num_hidden=3, - width=64, - pos_delta=False, - num_moments=4, - ): + def __init__(self, n): super().__init__() self.focals = torch.nn.Embedding(n, 1) self.focals.weight.data = torch.linspace(-1, 1, n)[:, None] @@ -99,35 +90,14 @@ def __init__( out_dim=1, ) - self.pos_delta = pos_delta - self.num_moments = num_moments - self.embed_pos, self.embed_pos_cnl = get_embedder(res_pos, 3) - self.embed_view, self.embed_view_cnl = get_embedder(res_view, 3) - in_cnl = ( - self.embed_pos_cnl + self.embed_view_cnl + 7 - ) # 7 for scales and rotations - hiddens = [ - nn.Linear(width, width) if i % 2 == 0 else nn.ReLU() - for i in range((num_hidden - 
1) * 2) - ] - self.linears = nn.Sequential( - nn.Linear(in_cnl, width), - nn.ReLU(), - *hiddens, - ).to("cuda") - if not pos_delta: # Defocus - self.s = nn.Linear(width, 3).to("cuda") - self.r = nn.Linear(width, 4).to("cuda") - else: # Motion - self.s = nn.Linear(width, 3 * (num_moments + 1)).to("cuda") - self.r = nn.Linear(width, 4 * (num_moments + 1)).to("cuda") - self.p = nn.Linear(width, 3 * num_moments).to("cuda") - - self.linears.apply(init_linear_weights) - self.s.apply(init_linear_weights) - self.r.apply(init_linear_weights) - if pos_delta: - self.p.apply(init_linear_weights) + self.embed_pos, self.embed_pos_cnl = get_embedder(3, 3) + self.embed_view, self.embed_view_cnl = get_embedder(10, 3) + self.linears = _create_mlp_torch( + in_dim=self.embed_pos_cnl + self.embed_view_cnl + 7, + num_layers=5, + layer_width=64, + out_dim=7, + ) def forward(self, depths, image_ids): height, width = depths.shape[1:3] @@ -144,17 +114,10 @@ def forward(self, depths, image_ids): return blur_mask def forward_deltas(self, pos, scales, rotations, viewdirs): - pos_delta = None - pos = self.embed_pos(pos) - viewdirs = self.embed_view(viewdirs) - - x = torch.cat([pos, viewdirs, scales, rotations], dim=-1) - x1 = self.linears(x) - - scales_delta = self.s(x1) - rotations_delta = self.r(x1) - - if self.pos_delta: - pos_delta = self.p(x1) - - return scales_delta, rotations_delta, pos_delta + pos_embed = self.embed_pos(pos) + viewdirs_embed = self.embed_view(viewdirs) + x = torch.cat([pos_embed, viewdirs_embed, scales, rotations], dim=-1) + x = self.linears(x) + scales_delta = x[:, :3] + rotations_delta = x[:, 3:] + return scales_delta, rotations_delta diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 78bd86e04..73865cf85 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -234,7 +234,7 @@ def create_splats_with_optimizers( N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # 
[N,] - embeds = 0.1 * torch.randn(N, 50, 7) + # embeds = 0.1 * torch.randn(N, 50, 7) params = [ # name, value, lr @@ -242,7 +242,7 @@ def create_splats_with_optimizers( ("scales", torch.nn.Parameter(scales), 5e-3), ("quats", torch.nn.Parameter(quats), 1e-3), ("opacities", torch.nn.Parameter(opacities), 5e-2), - ("embeds", torch.nn.Parameter(embeds), 1e-3), + # ("embeds", torch.nn.Parameter(embeds), 1e-3), ] if feature_dim is None: @@ -496,15 +496,11 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] if self.cfg.blur_opt and blur: - means_ = means.detach() - scales_ = self.splats["scales"].detach() - quats_ = self.splats["quats"].detach() - viewdir_ = camtoworlds[0, :3, 3].repeat(means.shape[0], 1) - scales_delta, rotations_delta, _ = self.blur_module.forward_deltas( - means_, - scales_, - quats_, - viewdir_, + scales_delta, rotations_delta = self.blur_module.forward_deltas( + self.splats["means"], + self.splats["scales"], + self.splats["quats"], + camtoworlds[0, :3, 3].repeat(self.splats["means"].shape[0], 1), ) # scales_delta = self.splats["embeds"][:, image_ids[0], :3] # rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] From ad9c5cdd2b245bf8839b4d6f0ce616746c4669fd Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 16:49:27 -0700 Subject: [PATCH 22/53] tcnn works with log_transform --- examples/blur_kernel.py | 9 +++++---- examples/simple_trainer.py | 6 ------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index f14e5d4f0..f8626a547 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -92,7 +92,7 @@ def __init__(self, n): self.embed_pos, self.embed_pos_cnl = get_embedder(3, 3) self.embed_view, self.embed_view_cnl = get_embedder(10, 3) - self.linears = _create_mlp_torch( + self.linears = _create_mlp_tcnn( in_dim=self.embed_pos_cnl + self.embed_view_cnl + 7, num_layers=5, layer_width=64, @@ -114,10 +114,11 @@ 
def forward(self, depths, image_ids): return blur_mask def forward_deltas(self, pos, scales, rotations, viewdirs): + pos = log_transform(pos) pos_embed = self.embed_pos(pos) viewdirs_embed = self.embed_view(viewdirs) x = torch.cat([pos_embed, viewdirs_embed, scales, rotations], dim=-1) - x = self.linears(x) - scales_delta = x[:, :3] - rotations_delta = x[:, 3:] + mlp_out = self.linears(x).float() + scales_delta = mlp_out[:, :3] + rotations_delta = mlp_out[:, 3:] return scales_delta, rotations_delta diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 73865cf85..baf4f3b72 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -731,12 +731,6 @@ def train(self): if cfg.blur_opt: lambda_mean = 0.01 lambda_std = 0.001 - print( - lambda_mean, - lambda_std, - blur_mask.mean().item(), - blur_mask.std().item(), - ) loss += ( lambda_mean * (torch.mean(blur_mask) - 0.5) ** 2 + lambda_std * (torch.std(blur_mask) - 0.5) ** 2 From fa5957ad42966292433caee31367e68721a1c7d0 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 18:02:41 -0700 Subject: [PATCH 23/53] cleanup --- examples/blur_kernel.py | 76 ++++++++++++++------------------------ examples/simple_trainer.py | 18 +++------ 2 files changed, 34 insertions(+), 60 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index f8626a547..c9380a1cf 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -11,8 +11,6 @@ import torch import torch.nn as nn -import numpy as np -import torch.nn.functional as F from examples.mlp import create_mlp, _create_mlp_torch, _create_mlp_tcnn from gsplat.utils import log_transform @@ -46,14 +44,11 @@ def create_embedding_fn(self): self.embed_fns = embed_fns self.out_dim = out_dim - def embed(self, inputs): + def encode(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) -def get_embedder(multires, i=0): - if i == -1: - return nn.Identity(), 3 - +def get_encoder(multires, i=0): 
embed_kwargs = { "include_input": True, "input_dims": i, @@ -62,63 +57,48 @@ def get_embedder(multires, i=0): "log_sampling": True, "periodic_fns": [torch.sin, torch.cos], } - - embedder_obj = Embedder(**embed_kwargs) - embed = lambda x, eo=embedder_obj: eo.embed(x) - return embed, embedder_obj.out_dim - - -def init_linear_weights(m): - if isinstance(m, nn.Linear): - if m.weight.shape[0] in [2, 3]: - nn.init.xavier_normal_(m.weight, 0.1) - else: - nn.init.xavier_normal_(m.weight) - nn.init.constant_(m.bias, 0) + embedder = Embedder(**embed_kwargs) + return embedder -class GTnet(nn.Module): +class BlurOptModule(nn.Module): def __init__(self, n): super().__init__() - self.focals = torch.nn.Embedding(n, 1) - self.focals.weight.data = torch.linspace(-1, 1, n)[:, None] - self.embed_depth, self.embed_depth_cnl = get_embedder(14, 1) - self.depth_mlp = _create_mlp_torch( - in_dim=self.embed_depth_cnl + 1, + self.embeds = torch.nn.Embedding(n, 1) + self.embeds.weight.data = torch.linspace(-1, 1, n)[:, None] + + self.depth_encoder = get_encoder(14, 1) + self.means_encoder = get_encoder(3, 3) + self.blur_mask_mlp = _create_mlp_torch( + in_dim=self.depth_encoder.out_dim + 1, num_layers=5, layer_width=64, out_dim=1, ) - - self.embed_pos, self.embed_pos_cnl = get_embedder(3, 3) - self.embed_view, self.embed_view_cnl = get_embedder(10, 3) - self.linears = _create_mlp_tcnn( - in_dim=self.embed_pos_cnl + self.embed_view_cnl + 7, + self.blur_deltas_mlp = _create_mlp_tcnn( + in_dim=self.means_encoder.out_dim + 8, num_layers=5, layer_width=64, out_dim=7, ) - def forward(self, depths, image_ids): - height, width = depths.shape[1:3] - - depths_emb = self.embed_depth(depths) - focals_emb = self.focals(image_ids[0])[None, None, None, :].repeat( - 1, height, width, 1 - ) - x = torch.cat([depths_emb, focals_emb], dim=-1) - x = x.reshape(-1, x.shape[-1]) - mlp_out = self.depth_mlp(x) + def predict_mask(self, image_ids, depths): + depths_emb = self.depth_encoder.encode(depths.reshape(-1, 1)) 
+ images_emb = self.embeds(image_ids).repeat(depths_emb.shape[0], 1) + x = torch.cat([images_emb, depths_emb], dim=-1) + mlp_out = self.blur_mask_mlp(x) blur_mask = torch.sigmoid(mlp_out) - blur_mask = blur_mask.reshape(1, height, width, 1) + blur_mask = blur_mask.reshape(depths.shape) return blur_mask - def forward_deltas(self, pos, scales, rotations, viewdirs): - pos = log_transform(pos) - pos_embed = self.embed_pos(pos) - viewdirs_embed = self.embed_view(viewdirs) - x = torch.cat([pos_embed, viewdirs_embed, scales, rotations], dim=-1) - mlp_out = self.linears(x).float() + def predict_deltas(self, image_ids, means, scales, quats): + means_log = log_transform(means) + means_embed = self.means_encoder.encode(means_log) + images_emb = self.embeds(image_ids).repeat(means.shape[0], 1) + x = torch.cat([images_emb, means_embed, scales, quats], dim=-1) + mlp_out = self.blur_deltas_mlp(x).float() scales_delta = mlp_out[:, :3] rotations_delta = mlp_out[:, 3:] + scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) + rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) return scales_delta, rotations_delta diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index baf4f3b72..6811c63da 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -35,7 +35,7 @@ color_correct, total_variation_loss, ) -from blur_kernel import GTnet +from blur_kernel import BlurOptModule from gsplat.compression import PngCompression from gsplat.distributed import cli @@ -234,7 +234,6 @@ def create_splats_with_optimizers( N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] - # embeds = 0.1 * torch.randn(N, 50, 7) params = [ # name, value, lr @@ -242,7 +241,6 @@ def create_splats_with_optimizers( ("scales", torch.nn.Parameter(scales), 5e-3), ("quats", torch.nn.Parameter(quats), 1e-3), ("opacities", torch.nn.Parameter(opacities), 5e-2), - # ("embeds", torch.nn.Parameter(embeds), 
1e-3), ] if feature_dim is None: @@ -417,7 +415,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = GTnet(len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_optimizers = [ torch.optim.Adam( self.blur_module.parameters(), @@ -496,16 +494,12 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] if self.cfg.blur_opt and blur: - scales_delta, rotations_delta = self.blur_module.forward_deltas( + scales_delta, rotations_delta = self.blur_module.predict_deltas( + image_ids, self.splats["means"], self.splats["scales"], self.splats["quats"], - camtoworlds[0, :3, 3].repeat(self.splats["means"].shape[0], 1), ) - # scales_delta = self.splats["embeds"][:, image_ids[0], :3] - # rotations_delta = self.splats["embeds"][:, image_ids[0], 3:] - scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) scales = torch.exp(self.splats["scales"] + scales_delta) quats = F.normalize(self.splats["quats"], dim=-1) + rotations_delta else: @@ -676,7 +670,7 @@ def train(self): masks=masks, blur=True, ) - blur_mask = self.blur_module(depths, image_ids) + blur_mask = self.blur_module.predict_mask(image_ids, depths) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( @@ -959,7 +953,7 @@ def eval(self, step: int, stage: str = "val"): blur=True, ) canvas_list.append(torch.clamp(renders_blur[..., 0:3], 0.0, 1.0)) - blur_mask = self.blur_module(depths, image_ids) + blur_mask = self.blur_module.predict_mask(image_ids, depths) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] colors = torch.clamp(colors, 0.0, 1.0) From 2659a9dd2ef44a1fdf4e27761af55a0564986e94 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 18:19:05 -0700 Subject: [PATCH 24/53] send --- 
examples/benchmarks/mcmc_deblur.sh | 3 +-- examples/blur_kernel.py | 33 ++++++++++++++++++++---------- examples/simple_trainer.py | 17 +++++++-------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index ba844c804..7973ddd3a 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,6 +1,5 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" -# SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" -SCENE_LIST="defocuscupcake" +SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index c9380a1cf..ecfafdf89 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from torch import Tensor from examples.mlp import create_mlp, _create_mlp_torch, _create_mlp_tcnn from gsplat.utils import log_transform @@ -62,43 +63,53 @@ def get_encoder(multires, i=0): class BlurOptModule(nn.Module): - def __init__(self, n): + """Blur optimization module.""" + + def __init__(self, n: int, embed_dim: int = 1): super().__init__() - self.embeds = torch.nn.Embedding(n, 1) + self.embeds = torch.nn.Embedding(n, embed_dim) self.embeds.weight.data = torch.linspace(-1, 1, n)[:, None] self.depth_encoder = get_encoder(14, 1) self.means_encoder = get_encoder(3, 3) self.blur_mask_mlp = _create_mlp_torch( - in_dim=self.depth_encoder.out_dim + 1, + in_dim=embed_dim + self.depth_encoder.out_dim, num_layers=5, layer_width=64, out_dim=1, ) self.blur_deltas_mlp = _create_mlp_tcnn( - in_dim=self.means_encoder.out_dim + 8, + in_dim=embed_dim + self.means_encoder.out_dim + 7, num_layers=5, layer_width=64, out_dim=7, ) - def predict_mask(self, 
image_ids, depths): + def predict_mask(self, image_ids: Tensor, depths: Tensor): depths_emb = self.depth_encoder.encode(depths.reshape(-1, 1)) images_emb = self.embeds(image_ids).repeat(depths_emb.shape[0], 1) - x = torch.cat([images_emb, depths_emb], dim=-1) - mlp_out = self.blur_mask_mlp(x) + mlp_out = self.blur_mask_mlp(torch.cat([images_emb, depths_emb], dim=-1)) blur_mask = torch.sigmoid(mlp_out) blur_mask = blur_mask.reshape(depths.shape) return blur_mask - def predict_deltas(self, image_ids, means, scales, quats): + def predict_deltas( + self, image_ids: Tensor, means: Tensor, scales: Tensor, quats: Tensor + ): means_log = log_transform(means) - means_embed = self.means_encoder.encode(means_log) + means_emb = self.means_encoder.encode(means_log) images_emb = self.embeds(image_ids).repeat(means.shape[0], 1) - x = torch.cat([images_emb, means_embed, scales, quats], dim=-1) - mlp_out = self.blur_deltas_mlp(x).float() + mlp_out = self.blur_deltas_mlp( + torch.cat([images_emb, means_emb, scales, quats], dim=-1) + ).float() scales_delta = mlp_out[:, :3] rotations_delta = mlp_out[:, 3:] scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) return scales_delta, rotations_delta + + def mask_reg_loss(self, blur_mask: Tensor): + """Mask regularization loss.""" + meanloss = (torch.mean(blur_mask) - 0.5) ** 2 + stdloss = (torch.std(blur_mask) - 0.5) ** 2 + return meanloss + 0.1 * stdloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 6811c63da..931b8123e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -157,6 +157,8 @@ class Config: blur_opt_lr: float = 1e-3 # Regularization for blur optimization as weight decay blur_opt_reg: float = 1e-6 + # Regularization for blur mask + blur_mask_reg: float = 0.01 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -657,6 +659,7 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) if cfg.blur_opt: + blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -670,7 +673,6 @@ def train(self): masks=masks, blur=True, ) - blur_mask = self.blur_module.predict_mask(image_ids, depths) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( @@ -709,6 +711,8 @@ def train(self): if cfg.use_bilateral_grid: tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss + if cfg.blur_opt: + loss += cfg.blur_mask_reg * self.blur_module.mask_reg_loss(blur_mask) # regularizations if cfg.opacity_reg > 0.0: @@ -722,13 +726,6 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) - if cfg.blur_opt: - lambda_mean = 0.01 - lambda_std = 0.001 - loss += ( - lambda_mean * (torch.mean(blur_mask) - 0.5) ** 2 - + lambda_std * (torch.std(blur_mask) - 0.5) ** 2 - ) loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " @@ -939,6 +936,8 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.blur_opt and stage == "train": + blur_mask = self.blur_module.predict_mask(image_ids, depths) + canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -953,8 +952,6 @@ def eval(self, step: int, stage: str = "val"): blur=True, ) canvas_list.append(torch.clamp(renders_blur[..., 0:3], 0.0, 1.0)) - blur_mask = self.blur_module.predict_mask(image_ids, depths) - canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] colors = torch.clamp(colors, 0.0, 1.0) canvas_list.append(colors) From 
464249dc447d6b35b5e06856645803d6b5e44621 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 19:24:50 -0700 Subject: [PATCH 25/53] log depth --- examples/benchmarks/mcmc_deblur.sh | 3 ++- examples/blur_kernel.py | 5 +++-- examples/datasets/colmap.py | 2 -- examples/simple_trainer.py | 1 - 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 7973ddd3a..ba844c804 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,5 +1,6 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" -SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" +# SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" +SCENE_LIST="defocuscupcake" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index ecfafdf89..f081b3fcb 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -70,7 +70,7 @@ def __init__(self, n: int, embed_dim: int = 1): self.embeds = torch.nn.Embedding(n, embed_dim) self.embeds.weight.data = torch.linspace(-1, 1, n)[:, None] - self.depth_encoder = get_encoder(14, 1) + self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) self.blur_mask_mlp = _create_mlp_torch( in_dim=embed_dim + self.depth_encoder.out_dim, @@ -86,7 +86,8 @@ def __init__(self, n: int, embed_dim: int = 1): ) def predict_mask(self, image_ids: Tensor, depths: Tensor): - depths_emb = self.depth_encoder.encode(depths.reshape(-1, 1)) + depths_log = log_transform(depths) + depths_emb = self.depth_encoder.encode(depths_log.reshape(-1, 1)) images_emb = self.embeds(image_ids).repeat(depths_emb.shape[0], 1) mlp_out = self.blur_mask_mlp(torch.cat([images_emb, depths_emb], dim=-1)) blur_mask = torch.sigmoid(mlp_out) diff --git 
a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 83482148a..ba12c2258 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -321,10 +321,8 @@ def __init__( indices = np.arange(len(self.parser.image_names)) if split == "train": self.indices = indices[indices % self.parser.test_every != 0] - # self.indices = np.concatenate([[0], self.indices]) else: self.indices = indices[indices % self.parser.test_every == 0] - # self.indices = self.indices[1:] def __len__(self): return len(self.indices) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 931b8123e..a5bbd3827 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -42,7 +42,6 @@ from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy from gsplat.optimizers import SelectiveAdam -from gsplat.utils import log_transform @dataclass From af99af2c9869e1299ac1e4835d7a2d7933e72e20 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 21:27:34 -0700 Subject: [PATCH 26/53] send --- examples/benchmarks/mcmc_deblur.sh | 3 +-- examples/blur_kernel.py | 12 ++++++------ examples/simple_trainer.py | 18 +++++++++++++----- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index ba844c804..7973ddd3a 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,6 +1,5 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" -# SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" -SCENE_LIST="defocuscupcake" +SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index f081b3fcb..f15f0ec2c 100644 --- 
a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -68,7 +68,6 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 1): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) - self.embeds.weight.data = torch.linspace(-1, 1, n)[:, None] self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) @@ -104,13 +103,14 @@ def predict_deltas( torch.cat([images_emb, means_emb, scales, quats], dim=-1) ).float() scales_delta = mlp_out[:, :3] - rotations_delta = mlp_out[:, 3:] + quats_delta = mlp_out[:, 3:] scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - rotations_delta = torch.clamp(rotations_delta, min=-0.05, max=0.05) - return scales_delta, rotations_delta + quats_delta = torch.clamp(quats_delta, min=-0.05, max=0.05) + return scales_delta, quats_delta - def mask_reg_loss(self, blur_mask: Tensor): + def mask_reg_loss(self, blur_mask: Tensor, step: int): """Mask regularization loss.""" meanloss = (torch.mean(blur_mask) - 0.5) ** 2 stdloss = (torch.std(blur_mask) - 0.5) ** 2 - return meanloss + 0.1 * stdloss + lambda_mean = 10.0 if step < 2000 else 1.0 + return lambda_mean * meanloss + stdloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a5bbd3827..105d05b46 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -157,7 +157,7 @@ class Config: # Regularization for blur optimization as weight decay blur_opt_reg: float = 1e-6 # Regularization for blur mask - blur_mask_reg: float = 0.01 + blur_mask_reg: float = 0.001 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -418,6 +418,11 @@ def __init__( if cfg.blur_opt: self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_optimizers = [ + torch.optim.Adam( + self.blur_module.embeds.parameters(), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.blur_opt_reg, + ), torch.optim.Adam( self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), @@ -495,14 +500,15 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] if self.cfg.blur_opt and blur: - scales_delta, rotations_delta = self.blur_module.predict_deltas( + quats = F.normalize(self.splats["quats"], dim=-1) + scales_delta, quats_delta = self.blur_module.predict_deltas( image_ids, self.splats["means"], self.splats["scales"], - self.splats["quats"], + quats, ) scales = torch.exp(self.splats["scales"] + scales_delta) - quats = F.normalize(self.splats["quats"], dim=-1) + rotations_delta + quats += quats_delta else: scales = torch.exp(self.splats["scales"]) # [N, 3] quats = self.splats["quats"] # [N, 4] @@ -711,7 +717,9 @@ def train(self): tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss if cfg.blur_opt: - loss += cfg.blur_mask_reg * self.blur_module.mask_reg_loss(blur_mask) + loss += cfg.blur_mask_reg * self.blur_module.mask_reg_loss( + blur_mask, step + ) # regularizations if cfg.opacity_reg > 0.0: From e58b2da84f269a59fde76c56d1a124c15e2a5f15 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 16 Oct 2024 23:08:57 -0700 Subject: [PATCH 27/53] init focal --- examples/blur_kernel.py | 2 ++ examples/simple_trainer.py | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index f15f0ec2c..fd1b478e5 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -68,6 +68,8 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 1): super().__init__() 
self.embeds = torch.nn.Embedding(n, embed_dim) + # Initialize focal lengths in first channel + self.embeds.weight.data[:, 0] = torch.linspace(-1, 1, n) self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 105d05b46..a85141bf1 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -418,11 +418,6 @@ def __init__( if cfg.blur_opt: self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_optimizers = [ - torch.optim.Adam( - self.blur_module.embeds.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10.0, - weight_decay=cfg.blur_opt_reg, - ), torch.optim.Adam( self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), From e70f474b1959cd6be8877641955832f27653b496 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 17 Oct 2024 19:25:44 -0700 Subject: [PATCH 28/53] init strat --- examples/blur_kernel.py | 31 ++++++++++++++++++++++--------- examples/simple_trainer.py | 19 ++++++++++++++++++- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index fd1b478e5..8e2d3fcd4 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -65,11 +65,10 @@ def get_encoder(multires, i=0): class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, n: int, embed_dim: int = 1): + def __init__(self, n: int, embed_dim: int = 4): super().__init__() + self.num_warmup_steps = 2000 self.embeds = torch.nn.Embedding(n, embed_dim) - # Initialize focal lengths in first channel - self.embeds.weight.data[:, 0] = torch.linspace(-1, 1, n) self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) @@ -96,7 +95,12 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): return blur_mask def predict_deltas( - self, image_ids: Tensor, means: Tensor, scales: Tensor, quats: Tensor + self, + image_ids: Tensor, + 
means: Tensor, + scales: Tensor, + quats: Tensor, + step: int, ): means_log = log_transform(means) means_emb = self.means_encoder.encode(means_log) @@ -106,13 +110,22 @@ def predict_deltas( ).float() scales_delta = mlp_out[:, :3] quats_delta = mlp_out[:, 3:] - scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - quats_delta = torch.clamp(quats_delta, min=-0.05, max=0.05) + if step < self.num_warmup_steps: + scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) + quats_delta = torch.clamp(quats_delta, min=0.0, max=0.0) + else: + scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) + quats_delta = torch.clamp(quats_delta, min=-0.05, max=0.05) return scales_delta, quats_delta def mask_reg_loss(self, blur_mask: Tensor, step: int): """Mask regularization loss.""" - meanloss = (torch.mean(blur_mask) - 0.5) ** 2 + meanloss = (torch.mean(blur_mask) - 0.0) ** 2 stdloss = (torch.std(blur_mask) - 0.5) ** 2 - lambda_mean = 10.0 if step < 2000 else 1.0 - return lambda_mean * meanloss + stdloss + if step < self.num_warmup_steps: + lambda_mean = 1.0 + lambda_std = 1.0 + else: + lambda_mean = 1.0 + lambda_std = 1.0 + return lambda_mean * meanloss + lambda_std * stdloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a85141bf1..f1e172409 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -419,7 +419,17 @@ def __init__( self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_optimizers = [ torch.optim.Adam( - self.blur_module.parameters(), + self.blur_module.embeds.parameters(), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.blur_opt_reg, + ), + torch.optim.Adam( + self.blur_module.blur_mask_mlp.parameters(), + lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.blur_opt_reg, + ), + torch.optim.Adam( + self.blur_module.blur_deltas_mlp.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), weight_decay=cfg.blur_opt_reg, ), @@ -476,6 +486,7 
@@ def rasterize_splats( height: int, masks: Optional[Tensor] = None, blur: bool = False, + step: int = 0, **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means"] # [N, 3] @@ -501,6 +512,7 @@ def rasterize_splats( self.splats["means"], self.splats["scales"], quats, + step, ) scales = torch.exp(self.splats["scales"] + scales_delta) quats += quats_delta @@ -672,6 +684,7 @@ def train(self): render_mode="RGB", masks=masks, blur=True, + step=step, ) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] @@ -874,6 +887,9 @@ def train(self): self.eval(step) self.render_traj(step) + if step == 7001: + break + # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: self.run_compression(step=step) @@ -952,6 +968,7 @@ def eval(self, step: int, stage: str = "val"): render_mode="RGB", masks=masks, blur=True, + step=step, ) canvas_list.append(torch.clamp(renders_blur[..., 0:3], 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] From beac915690230f713b965bda0e42bdf56305cc84 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 17 Oct 2024 20:51:06 -0700 Subject: [PATCH 29/53] log 10 center init around 0.2 --- examples/blur_kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 8e2d3fcd4..8cef43f99 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -70,7 +70,7 @@ def __init__(self, n: int, embed_dim: int = 4): self.num_warmup_steps = 2000 self.embeds = torch.nn.Embedding(n, embed_dim) - self.depth_encoder = get_encoder(7, 1) + self.depth_encoder = get_encoder(10, 1) self.means_encoder = get_encoder(3, 3) self.blur_mask_mlp = _create_mlp_torch( in_dim=embed_dim + self.depth_encoder.out_dim, @@ -120,10 +120,10 @@ def predict_deltas( def mask_reg_loss(self, blur_mask: Tensor, step: int): """Mask regularization loss.""" - meanloss = (torch.mean(blur_mask) - 0.0) ** 2 + meanloss = 
(torch.mean(blur_mask) - 0.2) ** 2 stdloss = (torch.std(blur_mask) - 0.5) ** 2 if step < self.num_warmup_steps: - lambda_mean = 1.0 + lambda_mean = 10.0 lambda_std = 1.0 else: lambda_mean = 1.0 From 56f7937f6bad0bf0c8e0dd1961f6e2a2091a1206 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 18 Oct 2024 01:54:16 -0700 Subject: [PATCH 30/53] slower lr --- examples/blur_kernel.py | 7 ++++--- examples/simple_trainer.py | 5 +---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 8cef43f99..6f5e3d757 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -120,12 +120,13 @@ def predict_deltas( def mask_reg_loss(self, blur_mask: Tensor, step: int): """Mask regularization loss.""" - meanloss = (torch.mean(blur_mask) - 0.2) ** 2 + meanloss = (torch.mean(blur_mask) - 0.0) ** 2 + # meanloss = torch.mean(blur_mask) stdloss = (torch.std(blur_mask) - 0.5) ** 2 if step < self.num_warmup_steps: - lambda_mean = 10.0 + lambda_mean = 1.0 lambda_std = 1.0 else: - lambda_mean = 1.0 + lambda_mean = 10.0 lambda_std = 1.0 return lambda_mean * meanloss + lambda_std * stdloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index f1e172409..7fe55dceb 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -153,7 +153,7 @@ class Config: # Enable blur optimization. 
(experimental) blur_opt: bool = False # Learning rate for blur optimization - blur_opt_lr: float = 1e-3 + blur_opt_lr: float = 1e-4 # Regularization for blur optimization as weight decay blur_opt_reg: float = 1e-6 # Regularization for blur mask @@ -887,9 +887,6 @@ def train(self): self.eval(step) self.render_traj(step) - if step == 7001: - break - # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: self.run_compression(step=step) From a8a8fe84361b0e3ad778cb0fd68cef60e2227de0 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 18 Oct 2024 14:38:36 -0700 Subject: [PATCH 31/53] successfully init tcnn --- examples/blur_kernel.py | 21 ++++++--------------- examples/simple_trainer.py | 4 +--- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 6f5e3d757..945ad0a64 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn from torch import Tensor -from examples.mlp import create_mlp, _create_mlp_torch, _create_mlp_tcnn +from examples.mlp import create_mlp, _create_mlp_tcnn from gsplat.utils import log_transform @@ -72,7 +72,7 @@ def __init__(self, n: int, embed_dim: int = 4): self.depth_encoder = get_encoder(10, 1) self.means_encoder = get_encoder(3, 3) - self.blur_mask_mlp = _create_mlp_torch( + self.blur_mask_mlp = _create_mlp_tcnn( in_dim=embed_dim + self.depth_encoder.out_dim, num_layers=5, layer_width=64, @@ -110,23 +110,14 @@ def predict_deltas( ).float() scales_delta = mlp_out[:, :3] quats_delta = mlp_out[:, 3:] - if step < self.num_warmup_steps: - scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - quats_delta = torch.clamp(quats_delta, min=0.0, max=0.0) - else: - scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - quats_delta = torch.clamp(quats_delta, min=-0.05, max=0.05) + scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) + quats_delta = torch.clamp(quats_delta, 
min=0.0, max=0.1) return scales_delta, quats_delta def mask_reg_loss(self, blur_mask: Tensor, step: int): """Mask regularization loss.""" + lambda_mean = 10.0 + lambda_std = 0.0 meanloss = (torch.mean(blur_mask) - 0.0) ** 2 - # meanloss = torch.mean(blur_mask) stdloss = (torch.std(blur_mask) - 0.5) ** 2 - if step < self.num_warmup_steps: - lambda_mean = 1.0 - lambda_std = 1.0 - else: - lambda_mean = 10.0 - lambda_std = 1.0 return lambda_mean * meanloss + lambda_std * stdloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 7fe55dceb..8ed4a49fa 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -83,9 +83,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field( - default_factory=lambda: [2_000, 7_000, 15_000, 30_000] - ) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) From 9eb1560d0125230f58739c171736a028fe843fff Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 18 Oct 2024 17:20:38 -0700 Subject: [PATCH 32/53] remove warmup and std --- examples/blur_kernel.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 945ad0a64..fe015d1c8 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -67,7 +67,6 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() - self.num_warmup_steps = 2000 self.embeds = torch.nn.Embedding(n, embed_dim) self.depth_encoder = get_encoder(10, 1) @@ -117,7 +116,5 @@ def predict_deltas( def mask_reg_loss(self, blur_mask: Tensor, step: int): """Mask regularization loss.""" lambda_mean = 10.0 - lambda_std = 0.0 meanloss = (torch.mean(blur_mask) - 0.0) ** 2 - stdloss = (torch.std(blur_mask) - 0.5) ** 2 - return lambda_mean * meanloss + lambda_std * stdloss + 
return lambda_mean * meanloss From 2a73703848d2d837799a038f239e48fb03849eb4 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 18 Oct 2024 23:17:54 -0700 Subject: [PATCH 33/53] new loss function --- examples/blur_kernel.py | 11 +++++------ examples/simple_trainer.py | 7 +++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index fe015d1c8..95b744ab8 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -99,7 +99,6 @@ def predict_deltas( means: Tensor, scales: Tensor, quats: Tensor, - step: int, ): means_log = log_transform(means) means_emb = self.means_encoder.encode(means_log) @@ -113,8 +112,8 @@ def predict_deltas( quats_delta = torch.clamp(quats_delta, min=0.0, max=0.1) return scales_delta, quats_delta - def mask_reg_loss(self, blur_mask: Tensor, step: int): - """Mask regularization loss.""" - lambda_mean = 10.0 - meanloss = (torch.mean(blur_mask) - 0.0) ** 2 - return lambda_mean * meanloss + def mask_variation_loss(self, blur_mask: Tensor): + """Mask variation loss.""" + x = blur_mask.mean() + meanloss = (1 / (1 - x)) - x - 1 + return meanloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 8ed4a49fa..a457a76d8 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -151,7 +151,7 @@ class Config: # Enable blur optimization. 
(experimental) blur_opt: bool = False # Learning rate for blur optimization - blur_opt_lr: float = 1e-4 + blur_opt_lr: float = 1e-3 # Regularization for blur optimization as weight decay blur_opt_reg: float = 1e-6 # Regularization for blur mask @@ -510,7 +510,6 @@ def rasterize_splats( self.splats["means"], self.splats["scales"], quats, - step, ) scales = torch.exp(self.splats["scales"] + scales_delta) quats += quats_delta @@ -723,8 +722,8 @@ def train(self): tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss if cfg.blur_opt: - loss += cfg.blur_mask_reg * self.blur_module.mask_reg_loss( - blur_mask, step + loss += cfg.blur_mask_reg * self.blur_module.mask_variation_loss( + blur_mask ) # regularizations From 4f4c1013a6a364eb6c50887d69546593f35c18a9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sun, 20 Oct 2024 00:12:57 -0700 Subject: [PATCH 34/53] dialed --- examples/blur_kernel.py | 2 +- examples/simple_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 95b744ab8..46115682e 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -115,5 +115,5 @@ def predict_deltas( def mask_variation_loss(self, blur_mask: Tensor): """Mask variation loss.""" x = blur_mask.mean() - meanloss = (1 / (1 - x)) - x - 1 + meanloss = 1 / (1 - x) + (0.01 / x) - 1.21 return meanloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a457a76d8..3f7300fb2 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -151,11 +151,11 @@ class Config: # Enable blur optimization. (experimental) blur_opt: bool = False # Learning rate for blur optimization - blur_opt_lr: float = 1e-3 + blur_opt_lr: float = 1e-4 # Regularization for blur optimization as weight decay blur_opt_reg: float = 1e-6 # Regularization for blur mask - blur_mask_reg: float = 0.001 + blur_mask_reg: float = 0.002 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False From 7676b7927ad54ff09860b35a45b2164fa6e27a40 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sun, 20 Oct 2024 19:37:21 -0700 Subject: [PATCH 35/53] embed init 0 clean lossfn high lr --- examples/blur_kernel.py | 3 ++- examples/simple_trainer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 46115682e..7604509fb 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -68,6 +68,7 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) + self.embeds.weight.data.fill_(0) self.depth_encoder = get_encoder(10, 1) self.means_encoder = get_encoder(3, 3) @@ -115,5 +116,5 @@ def predict_deltas( def mask_variation_loss(self, blur_mask: Tensor): """Mask variation loss.""" x = blur_mask.mean() - meanloss = 1 / (1 - x) + (0.01 / x) - 1.21 + meanloss = (1 / (1 - x) - 1) + (0.01 / x) return meanloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3f7300fb2..a5b9cf0ba 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -151,7 +151,7 @@ class Config: # Enable blur optimization. 
(experimental) blur_opt: bool = False # Learning rate for blur optimization - blur_opt_lr: float = 1e-4 + blur_opt_lr: float = 1e-3 # Regularization for blur optimization as weight decay blur_opt_reg: float = 1e-6 # Regularization for blur mask From 6e00996f8a237f2f66b2377ad20dd68fa9b58183 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 23 Oct 2024 21:08:21 -0700 Subject: [PATCH 36/53] test features --- examples/blur_kernel.py | 33 +++++++++++++++++++++++++-------- examples/simple_trainer.py | 38 ++++++++++++++++++++------------------ 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 7604509fb..605bd4605 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -70,10 +70,12 @@ def __init__(self, n: int, embed_dim: int = 4): self.embeds = torch.nn.Embedding(n, embed_dim) self.embeds.weight.data.fill_(0) - self.depth_encoder = get_encoder(10, 1) + self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) + self.grid_encoder = get_encoder(3, 2) self.blur_mask_mlp = _create_mlp_tcnn( - in_dim=embed_dim + self.depth_encoder.out_dim, + in_dim=embed_dim + + self.depth_encoder.out_dim + self.grid_encoder.out_dim, num_layers=5, layer_width=64, out_dim=1, @@ -85,11 +87,22 @@ def __init__(self, n: int, embed_dim: int = 4): out_dim=7, ) - def predict_mask(self, image_ids: Tensor, depths: Tensor): - depths_log = log_transform(depths) - depths_emb = self.depth_encoder.encode(depths_log.reshape(-1, 1)) - images_emb = self.embeds(image_ids).repeat(depths_emb.shape[0], 1) - mlp_out = self.blur_mask_mlp(torch.cat([images_emb, depths_emb], dim=-1)) + def predict_mask(self, image_ids: Tensor, colors: Tensor, depths: Tensor, step: int = -1): + height, width = depths.shape[1:3] + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device="cuda") + 0.5) / height, + (torch.arange(width, device="cuda") + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, 
grid_y], dim=-1).unsqueeze(0) + grid_emb = self.grid_encoder.encode(grid_xy) + depths_emb = self.depth_encoder.encode(log_transform(depths)) + images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) + + mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1) + # mlp_in = torch.cat([images_emb, grid_emb, colors, depths_emb], dim=-1) + mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])) + print(mlp_out.min().item(), mlp_out.max().item()) blur_mask = torch.sigmoid(mlp_out) blur_mask = blur_mask.reshape(depths.shape) return blur_mask @@ -100,6 +113,7 @@ def predict_deltas( means: Tensor, scales: Tensor, quats: Tensor, + step: int = -1, ): means_log = log_transform(means) means_emb = self.means_encoder.encode(means_log) @@ -116,5 +130,8 @@ def predict_deltas( def mask_variation_loss(self, blur_mask: Tensor): """Mask variation loss.""" x = blur_mask.mean() - meanloss = (1 / (1 - x) - 1) + (0.01 / x) + print(x.item()) + eps = 1e-6 + meanloss = (1 / (1 - x + eps) - 1) + (0.1 / (x + eps)) + # meanloss = 1 / (1 - x + eps) + 1 / (x + eps) - 4 return meanloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a5b9cf0ba..188bccab1 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -153,9 +153,9 @@ class Config: # Learning rate for blur optimization blur_opt_lr: float = 1e-3 # Regularization for blur optimization as weight decay - blur_opt_reg: float = 1e-6 + blur_opt_reg: float = 0.0 # Regularization for blur mask - blur_mask_reg: float = 0.002 + blur_mask_reg: float = 0.001 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -417,17 +417,7 @@ def __init__( self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_optimizers = [ torch.optim.Adam( - self.blur_module.embeds.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size) * 10.0, - weight_decay=cfg.blur_opt_reg, - ), - torch.optim.Adam( - self.blur_module.blur_mask_mlp.parameters(), - lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), - weight_decay=cfg.blur_opt_reg, - ), - torch.optim.Adam( - self.blur_module.blur_deltas_mlp.parameters(), + self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), weight_decay=cfg.blur_opt_reg, ), @@ -510,6 +500,7 @@ def rasterize_splats( self.splats["means"], self.splats["scales"], quats, + step, ) scales = torch.exp(self.splats["scales"] + scales_delta) quats += quats_delta @@ -649,6 +640,7 @@ def train(self): image_ids=image_ids, render_mode=self.render_mode, masks=masks, + step=step, ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -668,7 +660,6 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) if cfg.blur_opt: - blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -678,11 +669,17 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB", + render_mode="RGB+ED", masks=masks, blur=True, step=step, ) + blur_mask = self.blur_module.predict_mask( + image_ids, + renders[..., 0:3], + renders[..., 3:4], + step=step, + ) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( @@ -948,8 +945,6 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.blur_opt and stage == "train": - blur_mask = self.blur_module.predict_mask(image_ids, depths) - 
canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -959,11 +954,18 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB", + render_mode="RGB+ED", masks=masks, blur=True, step=step, ) + blur_mask = self.blur_module.predict_mask( + image_ids, + renders[..., 0:3], + renders[..., 3:4], + step=step, + ) + canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) canvas_list.append(torch.clamp(renders_blur[..., 0:3], 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] colors = torch.clamp(colors, 0.0, 1.0) From 6c611dc745ff7a0352495b67168fe06ffec0682d Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 23 Oct 2024 21:16:22 -0700 Subject: [PATCH 37/53] cleanup --- examples/benchmarks/mcmc_deblur.sh | 1 + examples/blur_kernel.py | 35 ++++-------------------------- examples/simple_trainer.py | 32 +++++++-------------------- 3 files changed, 13 insertions(+), 55 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 7973ddd3a..1a2375925 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,5 +1,6 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" +SCENE_LIST="defocustools" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 605bd4605..303c8b460 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -1,14 +1,3 @@ -# -# Copyright (C) 2023, Inria -# GRAPHDECO research group, https://team.inria.fr/graphdeco -# All rights reserved. -# -# This software is free for non-commercial, research and evaluation use -# under the terms of the LICENSE.md file. 
-# -# For inquiries contact george.drettakis@inria.fr -# - import torch import torch.nn as nn from torch import Tensor @@ -72,10 +61,8 @@ def __init__(self, n: int, embed_dim: int = 4): self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) - self.grid_encoder = get_encoder(3, 2) self.blur_mask_mlp = _create_mlp_tcnn( - in_dim=embed_dim - + self.depth_encoder.out_dim + self.grid_encoder.out_dim, + in_dim=embed_dim + self.depth_encoder.out_dim, num_layers=5, layer_width=64, out_dim=1, @@ -87,22 +74,11 @@ def __init__(self, n: int, embed_dim: int = 4): out_dim=7, ) - def predict_mask(self, image_ids: Tensor, colors: Tensor, depths: Tensor, step: int = -1): - height, width = depths.shape[1:3] - grid_y, grid_x = torch.meshgrid( - (torch.arange(height, device="cuda") + 0.5) / height, - (torch.arange(width, device="cuda") + 0.5) / width, - indexing="ij", - ) - grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) - grid_emb = self.grid_encoder.encode(grid_xy) + def predict_mask(self, image_ids: Tensor, depths: Tensor): depths_emb = self.depth_encoder.encode(log_transform(depths)) images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) - - mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1) - # mlp_in = torch.cat([images_emb, grid_emb, colors, depths_emb], dim=-1) + mlp_in = torch.cat([images_emb, depths_emb], dim=-1) mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])) - print(mlp_out.min().item(), mlp_out.max().item()) blur_mask = torch.sigmoid(mlp_out) blur_mask = blur_mask.reshape(depths.shape) return blur_mask @@ -113,7 +89,6 @@ def predict_deltas( means: Tensor, scales: Tensor, quats: Tensor, - step: int = -1, ): means_log = log_transform(means) means_emb = self.means_encoder.encode(means_log) @@ -129,9 +104,7 @@ def predict_deltas( def mask_variation_loss(self, blur_mask: Tensor): """Mask variation loss.""" - x = blur_mask.mean() - print(x.item()) eps = 1e-6 + x = blur_mask.mean() meanloss = (1 / 
(1 - x + eps) - 1) + (0.1 / (x + eps)) - # meanloss = 1 / (1 - x + eps) + 1 / (x + eps) - 4 return meanloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 188bccab1..ae89d5d6d 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -152,8 +152,6 @@ class Config: blur_opt: bool = False # Learning rate for blur optimization blur_opt_lr: float = 1e-3 - # Regularization for blur optimization as weight decay - blur_opt_reg: float = 0.0 # Regularization for blur mask blur_mask_reg: float = 0.001 @@ -419,7 +417,6 @@ def __init__( torch.optim.Adam( self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), - weight_decay=cfg.blur_opt_reg, ), ] if world_size > 1: @@ -474,7 +471,6 @@ def rasterize_splats( height: int, masks: Optional[Tensor] = None, blur: bool = False, - step: int = 0, **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means"] # [N, 3] @@ -500,7 +496,6 @@ def rasterize_splats( self.splats["means"], self.splats["scales"], quats, - step, ) scales = torch.exp(self.splats["scales"] + scales_delta) quats += quats_delta @@ -640,7 +635,6 @@ def train(self): image_ids=image_ids, render_mode=self.render_mode, masks=masks, - step=step, ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -669,18 +663,13 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", + render_mode="RGB", masks=masks, blur=True, - step=step, - ) - blur_mask = self.blur_module.predict_mask( - image_ids, - renders[..., 0:3], - renders[..., 3:4], - step=step, ) - colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] + colors_blur = renders_blur[..., 0:3] + blur_mask = self.blur_module.predict_mask(image_ids, depths) + colors = (1 - blur_mask) * colors + blur_mask * colors_blur self.cfg.strategy.step_pre_backward( params=self.splats, @@ -957,17 +946,12 @@ def eval(self, step: int, stage: str = "val"): 
render_mode="RGB+ED", masks=masks, blur=True, - step=step, - ) - blur_mask = self.blur_module.predict_mask( - image_ids, - renders[..., 0:3], - renders[..., 3:4], - step=step, ) + colors_blur = renders_blur[..., 0:3] + blur_mask = self.blur_module.predict_mask(image_ids, depths) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) - canvas_list.append(torch.clamp(renders_blur[..., 0:3], 0.0, 1.0)) - colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] + canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0)) + colors = (1 - blur_mask) * colors + blur_mask * colors_blur colors = torch.clamp(colors, 0.0, 1.0) canvas_list.append(colors) From 9e3bc195748d7355c40db61c2acc685c00a69d04 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 23 Oct 2024 21:40:54 -0700 Subject: [PATCH 38/53] minor --- examples/benchmarks/mcmc_deblur.sh | 2 +- examples/blur_kernel.py | 134 ++++++++++++++--------------- examples/simple_trainer.py | 18 ++-- 3 files changed, 76 insertions(+), 78 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 1a2375925..d39d52e30 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,6 +1,6 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" -SCENE_LIST="defocustools" +SCENE_LIST="defocussausage" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index 303c8b460..ca80804c0 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -1,89 +1,35 @@ import torch import torch.nn as nn from torch import Tensor -from examples.mlp import create_mlp, _create_mlp_tcnn +from examples.mlp import create_mlp from gsplat.utils import log_transform -class Embedder: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.create_embedding_fn() - - def 
create_embedding_fn(self): - embed_fns = [] - d = self.kwargs["input_dims"] - out_dim = 0 - if self.kwargs["include_input"]: - embed_fns.append(lambda x: x) - out_dim += d - - max_freq = self.kwargs["max_freq_log2"] - N_freqs = self.kwargs["num_freqs"] - - if self.kwargs["log_sampling"]: - freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) - else: - freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) - - for freq in freq_bands: - for p_fn in self.kwargs["periodic_fns"]: - embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) - out_dim += d - - self.embed_fns = embed_fns - self.out_dim = out_dim - - def encode(self, inputs): - return torch.cat([fn(inputs) for fn in self.embed_fns], -1) - - -def get_encoder(multires, i=0): - embed_kwargs = { - "include_input": True, - "input_dims": i, - "max_freq_log2": multires - 1, - "num_freqs": multires, - "log_sampling": True, - "periodic_fns": [torch.sin, torch.cos], - } - embedder = Embedder(**embed_kwargs) - return embedder - - class BlurOptModule(nn.Module): """Blur optimization module.""" def __init__(self, n: int, embed_dim: int = 4): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) - self.embeds.weight.data.fill_(0) - self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) - self.blur_mask_mlp = _create_mlp_tcnn( + self.blur_mask_mlp = create_mlp( in_dim=embed_dim + self.depth_encoder.out_dim, num_layers=5, layer_width=64, out_dim=1, ) - self.blur_deltas_mlp = _create_mlp_tcnn( + self.blur_deltas_mlp = create_mlp( in_dim=embed_dim + self.means_encoder.out_dim + 7, num_layers=5, layer_width=64, out_dim=7, ) - def predict_mask(self, image_ids: Tensor, depths: Tensor): - depths_emb = self.depth_encoder.encode(log_transform(depths)) - images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) - mlp_in = torch.cat([images_emb, depths_emb], dim=-1) - mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])) - blur_mask = 
torch.sigmoid(mlp_out) - blur_mask = blur_mask.reshape(depths.shape) - return blur_mask + def zero_init(self): + torch.nn.init.zeros_(self.embeds.weight) - def predict_deltas( + def forward( self, image_ids: Tensor, means: Tensor, @@ -96,15 +42,69 @@ def predict_deltas( mlp_out = self.blur_deltas_mlp( torch.cat([images_emb, means_emb, scales, quats], dim=-1) ).float() - scales_delta = mlp_out[:, :3] - quats_delta = mlp_out[:, 3:] - scales_delta = torch.clamp(scales_delta, min=0.0, max=0.1) - quats_delta = torch.clamp(quats_delta, min=0.0, max=0.1) - return scales_delta, quats_delta + scales_delta = torch.clamp(mlp_out[:, :3], min=0.0, max=0.1) + quats_delta = torch.clamp(mlp_out[:, 3:], min=0.0, max=0.1) + scales = torch.exp(scales + scales_delta) + quats = quats + quats_delta + return scales, quats - def mask_variation_loss(self, blur_mask: Tensor): + def predict_mask(self, image_ids: Tensor, depths: Tensor): + depths_emb = self.depth_encoder.encode(log_transform(depths)) + images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) + mlp_in = torch.cat([images_emb, depths_emb], dim=-1) + mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])) + blur_mask = torch.sigmoid(mlp_out) + blur_mask = blur_mask.reshape(depths.shape) + return blur_mask + + def mask_variation_loss(self, blur_mask: Tensor, eps: float = 1e-6): """Mask variation loss.""" - eps = 1e-6 x = blur_mask.mean() meanloss = (1 / (1 - x + eps) - 1) + (0.1 / (x + eps)) return meanloss + + +def get_encoder(num_freqs: int, input_dims: int): + kwargs = { + "include_input": True, + "input_dims": input_dims, + "max_freq_log2": num_freqs - 1, + "num_freqs": num_freqs, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], + } + encoder = Encoder(**kwargs) + return encoder + + +class Encoder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs["input_dims"] + out_dim = 0 + if 
self.kwargs["include_input"]: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] + + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def encode(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index ae89d5d6d..e28c1ba03 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -413,6 +413,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) + self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( self.blur_module.parameters(), @@ -490,15 +491,12 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] if self.cfg.blur_opt and blur: - quats = F.normalize(self.splats["quats"], dim=-1) - scales_delta, quats_delta = self.blur_module.predict_deltas( - image_ids, - self.splats["means"], - self.splats["scales"], - quats, + scales, quats = self.blur_module( + image_ids=image_ids, + means=self.splats["means"], + scales=self.splats["scales"], + quats=F.normalize(self.splats["quats"], dim=-1), ) - scales = torch.exp(self.splats["scales"] + scales_delta) - quats += quats_delta else: scales = torch.exp(self.splats["scales"]) # [N, 3] quats = self.splats["quats"] # [N, 4] @@ -934,6 +932,8 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.blur_opt and stage == "train": + blur_mask = self.blur_module.predict_mask(image_ids, depths) + 
canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -948,8 +948,6 @@ def eval(self, step: int, stage: str = "val"): blur=True, ) colors_blur = renders_blur[..., 0:3] - blur_mask = self.blur_module.predict_mask(image_ids, depths) - canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * colors_blur colors = torch.clamp(colors, 0.0, 1.0) From 63a4bb632a2bbeb12eee1552e9beb3cbd963afe1 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 23 Oct 2024 21:47:25 -0700 Subject: [PATCH 39/53] summarize stats --- examples/benchmarks/compression/summarize_stats.py | 3 +-- examples/benchmarks/mcmc_deblur.sh | 10 ++-------- examples/simple_trainer.py | 2 +- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/examples/benchmarks/compression/summarize_stats.py b/examples/benchmarks/compression/summarize_stats.py index f2552a049..5efa729fe 100644 --- a/examples/benchmarks/compression/summarize_stats.py +++ b/examples/benchmarks/compression/summarize_stats.py @@ -8,9 +8,8 @@ import tyro -def main(results_dir: str, scenes: List[str]): +def main(results_dir: str, scenes: List[str], stage: str = "compress"): print("scenes:", scenes) - stage = "val" summary = defaultdict(list) for scene in scenes: diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index d39d52e30..d7b4e6fd3 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -21,11 +21,5 @@ do --result_dir $RESULT_DIR/$SCENE done -# Zip the compressed files and summarize the stats -if command -v zip &> /dev/null -then - echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST -else - echo "zip command not found, skipping zipping" -fi +# Summarize the stats +python benchmarks/compression/summarize_stats.py --results_dir 
$RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index e28c1ba03..0daa54482 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -943,7 +943,7 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", + render_mode="RGB", masks=masks, blur=True, ) From ed830d7141d7758823e5d29dec0716e6eac562cc Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 23 Oct 2024 22:54:52 -0700 Subject: [PATCH 40/53] less freqs --- examples/benchmarks/mcmc_deblur.sh | 1 - examples/blur_kernel.py | 23 ++++++++++++++++++----- examples/simple_trainer.py | 2 +- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index d7b4e6fd3..5af1f1b17 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,6 +1,5 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" -SCENE_LIST="defocussausage" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index ca80804c0..eaf479cab 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from torch import Tensor +import torch.nn.functional as F from examples.mlp import create_mlp from gsplat.utils import log_transform @@ -11,10 +12,11 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) - self.depth_encoder = get_encoder(7, 1) self.means_encoder = get_encoder(3, 3) + self.depths_encoder = get_encoder(3, 1) + self.grid_encoder = get_encoder(1, 2) self.blur_mask_mlp = create_mlp( - in_dim=embed_dim + self.depth_encoder.out_dim, + 
in_dim=embed_dim + self.depths_encoder.out_dim + self.grid_encoder.out_dim, num_layers=5, layer_width=64, out_dim=1, @@ -36,7 +38,9 @@ def forward( scales: Tensor, quats: Tensor, ): + quats = F.normalize(quats, dim=-1) means_log = log_transform(means) + means_emb = self.means_encoder.encode(means_log) images_emb = self.embeds(image_ids).repeat(means.shape[0], 1) mlp_out = self.blur_deltas_mlp( @@ -49,15 +53,24 @@ def forward( return scales, quats def predict_mask(self, image_ids: Tensor, depths: Tensor): - depths_emb = self.depth_encoder.encode(log_transform(depths)) + height, width = depths.shape[1:3] + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=depths.device) + 0.5) / height, + (torch.arange(width, device=depths.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + grid_emb = self.grid_encoder.encode(grid_xy) + + depths_emb = self.depths_encoder.encode(log_transform(depths)) images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) - mlp_in = torch.cat([images_emb, depths_emb], dim=-1) + mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1) mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])) blur_mask = torch.sigmoid(mlp_out) blur_mask = blur_mask.reshape(depths.shape) return blur_mask - def mask_variation_loss(self, blur_mask: Tensor, eps: float = 1e-6): + def mask_variation_loss(self, blur_mask: Tensor, eps: float = 1e-2): """Mask variation loss.""" x = blur_mask.mean() meanloss = (1 / (1 - x + eps) - 1) + (0.1 / (x + eps)) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 0daa54482..d864814b5 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -495,7 +495,7 @@ def rasterize_splats( image_ids=image_ids, means=self.splats["means"], scales=self.splats["scales"], - quats=F.normalize(self.splats["quats"], dim=-1), + quats=self.splats["quats"], ) else: scales = torch.exp(self.splats["scales"]) # [N, 3] From 
d874ef1e4e0035e907959289197c5f2703175646 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 26 Oct 2024 13:02:59 -0700 Subject: [PATCH 41/53] latest --- examples/benchmarks/mcmc_deblur.sh | 2 ++ examples/blur_kernel.py | 21 +++++++++++++++----- examples/simple_trainer.py | 31 ++++++++++++++++++++---------- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 5af1f1b17..0ef02ad69 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -15,6 +15,8 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ + --blur_opt_lr 1e-3 \ + --blur_mean_reg 0.005 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index eaf479cab..b6daf971d 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch import Tensor import torch.nn.functional as F +from kornia.filters import median_blur from examples.mlp import create_mlp from gsplat.utils import log_transform @@ -65,17 +66,27 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): depths_emb = self.depths_encoder.encode(log_transform(depths)) images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1) - mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])) + mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])).reshape( + depths.shape + ) blur_mask = torch.sigmoid(mlp_out) - blur_mask = blur_mask.reshape(depths.shape) return blur_mask - def mask_variation_loss(self, blur_mask: Tensor, eps: float = 1e-2): - """Mask variation loss.""" + def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): + """Mask mean loss.""" x = 
blur_mask.mean() - meanloss = (1 / (1 - x + eps) - 1) + (0.1 / (x + eps)) + a = 0.9 + meanloss = a * (1 / (1 - x + eps) - 1) + (1 - a) * (1 / (x + eps) - 1) return meanloss + def mask_smoothness_loss(self, blur_mask: Tensor): + """Mask smoothness loss.""" + blurred_xy = median_blur(blur_mask.permute(0, 3, 1, 2), (5, 5)).permute( + 0, 2, 3, 1 + ) + smoothloss = F.huber_loss(blur_mask, blurred_xy) + return smoothloss + def get_encoder(num_freqs: int, input_dims: int): kwargs = { diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index d864814b5..4e73510da 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -152,8 +152,10 @@ class Config: blur_opt: bool = False # Learning rate for blur optimization blur_opt_lr: float = 1e-3 - # Regularization for blur mask - blur_mask_reg: float = 0.001 + # Regularization for blur mask mean + blur_mean_reg: float = 0.001 + # Regularization for blur mask smoothness + blur_smoothness_reg: float = 0.0 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -661,12 +663,15 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB", + render_mode="RGB+ED", masks=masks, blur=True, ) - colors_blur = renders_blur[..., 0:3] - blur_mask = self.blur_module.predict_mask(image_ids, depths) + colors_blur, depths_blur = ( + renders_blur[..., 0:3], + renders_blur[..., 3:4], + ) + blur_mask = self.blur_module.predict_mask(image_ids, depths_blur) colors = (1 - blur_mask) * colors + blur_mask * colors_blur self.cfg.strategy.step_pre_backward( @@ -706,7 +711,10 @@ def train(self): tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss if cfg.blur_opt: - loss += cfg.blur_mask_reg * self.blur_module.mask_variation_loss( + loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss( + blur_mask, step + ) + loss += cfg.blur_smoothness_reg * self.blur_module.mask_smoothness_loss( blur_mask ) @@ -932,8 +940,6 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.blur_opt and stage == "train": - blur_mask = self.blur_module.predict_mask(image_ids, depths) - canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -943,11 +949,16 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB", + render_mode="RGB+ED", masks=masks, blur=True, ) - colors_blur = renders_blur[..., 0:3] + colors_blur, depths_blur = ( + renders_blur[..., 0:3], + renders_blur[..., 3:4], + ) + blur_mask = self.blur_module.predict_mask(image_ids, depths_blur) + canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * colors_blur colors = torch.clamp(colors, 0.0, 1.0) From 8a02e745a552b117d2c6d720c1a9bab8def57c78 Mon Sep 17 00:00:00 
2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 16:53:54 -0700 Subject: [PATCH 42/53] latest run avg psnr 23.50 --- examples/benchmarks/mcmc_deblur.sh | 9 ++++----- examples/blur_kernel.py | 27 +++++++++++++++------------ examples/simple_trainer.py | 19 +++++++++++-------- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 0ef02ad69..d497515f9 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,12 +1,12 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" +SCENE_LIST="defocuscake defocustools defocussausage defocuscupcake defocuscups defocuscoral defocusdaisy defocusseal defocuscaps defocuscisco" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" - -RESULT_DIR="results/benchmark_mcmc_deblur" CAP_MAX=250000 +RESULT_DIR="results/benchmark_mcmc_deblur/c0.2_a10" for SCENE in $SCENE_LIST; do echo "Running $SCENE" @@ -15,12 +15,11 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ - --blur_opt_lr 1e-3 \ - --blur_mean_reg 0.005 \ + --blur_a 10 \ + --blur_c 0.2 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE done -# Summarize the stats python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index b6daf971d..a28f3244c 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -10,8 +10,11 @@ class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, n: int, embed_dim: int = 4): + def __init__(self, cfg, n: int, embed_dim: int = 4): super().__init__() + self.a = cfg.blur_a + self.c = cfg.blur_c + self.embeds = 
torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -75,17 +78,17 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): """Mask mean loss.""" x = blur_mask.mean() - a = 0.9 - meanloss = a * (1 / (1 - x + eps) - 1) + (1 - a) * (1 / (x + eps) - 1) - return meanloss - - def mask_smoothness_loss(self, blur_mask: Tensor): - """Mask smoothness loss.""" - blurred_xy = median_blur(blur_mask.permute(0, 3, 1, 2), (5, 5)).permute( - 0, 2, 3, 1 - ) - smoothloss = F.huber_loss(blur_mask, blurred_xy) - return smoothloss + if step <= 2000: + a = 20 + b = 1 + c = 0.2 + else: + a = self.a + b = 1 + c = self.c + print(x.item(), a, b, c) + meanloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) + return c * meanloss def get_encoder(num_freqs: int, input_dims: int): diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 4e73510da..1cd997697 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -155,7 +155,8 @@ class Config: # Regularization for blur mask mean blur_mean_reg: float = 0.001 # Regularization for blur mask smoothness - blur_smoothness_reg: float = 0.0 + blur_a: float = 4 + blur_c: float = 0.5 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -414,7 +415,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device) self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( @@ -671,7 +672,7 @@ def train(self): renders_blur[..., 0:3], renders_blur[..., 3:4], ) - blur_mask = self.blur_module.predict_mask(image_ids, depths_blur) + blur_mask = self.blur_module.predict_mask(image_ids, depths) colors = (1 - blur_mask) * colors + blur_mask * colors_blur self.cfg.strategy.step_pre_backward( @@ -714,9 +715,6 @@ def train(self): loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss( blur_mask, step ) - loss += cfg.blur_smoothness_reg * self.blur_module.mask_smoothness_loss( - blur_mask - ) # regularizations if cfg.opacity_reg > 0.0: @@ -875,6 +873,8 @@ def train(self): self.eval(step, stage="train") self.eval(step) self.render_traj(step) + if (step + 1) % 1000 == 0 or step == 0: + self.eval(step, stage="train", vis_skip=True) # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -892,7 +892,7 @@ def train(self): self.viewer.update(step, num_train_rays_per_step) @torch.no_grad() - def eval(self, step: int, stage: str = "val"): + def eval(self, step: int, stage: str = "val", vis_skip: bool = False): """Entry for evaluation.""" print("Running evaluation...") cfg = self.cfg @@ -904,10 +904,13 @@ def eval(self, step: int, stage: str = "val"): dataloader = torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, num_workers=1 ) + train_vis_image_ids = np.linspace(0, len(dataloader) - 1, 7).astype(int) ellipse_time = 0 metrics = defaultdict(list) for i, data in enumerate(dataloader): + if vis_skip and stage == "train" and i not in train_vis_image_ids: + continue camtoworlds = data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = 
data["image"].to(device) / 255.0 @@ -957,7 +960,7 @@ def eval(self, step: int, stage: str = "val"): renders_blur[..., 0:3], renders_blur[..., 3:4], ) - blur_mask = self.blur_module.predict_mask(image_ids, depths_blur) + blur_mask = self.blur_module.predict_mask(image_ids, depths) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * colors_blur From c8fd74dacea9afca82fb06b3b5bdc7ad3435b23f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 17:01:44 -0700 Subject: [PATCH 43/53] cleanup --- examples/benchmarks/mcmc_deblur.sh | 5 +--- examples/blur_kernel.py | 17 ++++---------- examples/simple_trainer.py | 37 +++++++++--------------------- 3 files changed, 16 insertions(+), 43 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index d497515f9..c3d5a2822 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,12 +1,11 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" -SCENE_LIST="defocuscake defocustools defocussausage defocuscupcake defocuscups defocuscoral defocusdaisy defocusseal defocuscaps defocuscisco" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" CAP_MAX=250000 -RESULT_DIR="results/benchmark_mcmc_deblur/c0.2_a10" +RESULT_DIR="results/benchmark_mcmc_deblur" for SCENE in $SCENE_LIST; do echo "Running $SCENE" @@ -15,8 +14,6 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ - --blur_a 10 \ - --blur_c 0.2 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index a28f3244c..90b5206b7 100644 --- a/examples/blur_kernel.py +++ 
b/examples/blur_kernel.py @@ -2,7 +2,6 @@ import torch.nn as nn from torch import Tensor import torch.nn.functional as F -from kornia.filters import median_blur from examples.mlp import create_mlp from gsplat.utils import log_transform @@ -10,11 +9,8 @@ class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, cfg, n: int, embed_dim: int = 4): + def __init__(self, n: int, embed_dim: int = 4): super().__init__() - self.a = cfg.blur_a - self.c = cfg.blur_c - self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -80,15 +76,10 @@ def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): x = blur_mask.mean() if step <= 2000: a = 20 - b = 1 - c = 0.2 else: - a = self.a - b = 1 - c = self.c - print(x.item(), a, b, c) - meanloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) - return c * meanloss + a = 10 + meanloss = a * (1 / (1 - x + eps) - 1) + (1 / (x + eps) - 1) + return meanloss def get_encoder(num_freqs: int, input_dims: int): diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 1cd997697..8e3696790 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -83,7 +83,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -153,10 +153,7 @@ class Config: # Learning rate for blur optimization blur_opt_lr: float = 1e-3 # Regularization for blur mask mean - blur_mean_reg: float = 0.001 - # Regularization for blur mask smoothness - blur_a: float = 4 - blur_c: float = 0.5 + blur_mean_reg: float = 0.0002 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -655,6 +652,7 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) if cfg.blur_opt: + blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -664,16 +662,11 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", + render_mode="RGB", masks=masks, blur=True, ) - colors_blur, depths_blur = ( - renders_blur[..., 0:3], - renders_blur[..., 3:4], - ) - blur_mask = self.blur_module.predict_mask(image_ids, depths) - colors = (1 - blur_mask) * colors + blur_mask * colors_blur + colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3] self.cfg.strategy.step_pre_backward( params=self.splats, @@ -871,10 +864,8 @@ def train(self): # eval the full set if step in [i - 1 for i in cfg.eval_steps]: self.eval(step, stage="train") - self.eval(step) + self.eval(step, stage="val") self.render_traj(step) - if (step + 1) % 1000 == 0 or step == 0: - self.eval(step, stage="train", vis_skip=True) # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -892,7 +883,7 @@ def train(self): self.viewer.update(step, num_train_rays_per_step) @torch.no_grad() - def eval(self, step: int, stage: str = "val", vis_skip: bool = False): + def eval(self, step: int, stage: str = "val"): """Entry for evaluation.""" print("Running evaluation...") cfg = self.cfg @@ -904,13 +895,10 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False): dataloader = torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, num_workers=1 ) - train_vis_image_ids = np.linspace(0, len(dataloader) - 1, 7).astype(int) ellipse_time = 0 metrics = defaultdict(list) for i, data in enumerate(dataloader): - if vis_skip and stage == "train" and i not in train_vis_image_ids: - continue camtoworlds = 
data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = data["image"].to(device) / 255.0 @@ -943,6 +931,8 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.blur_opt and stage == "train": + blur_mask = self.blur_module.predict_mask(image_ids, depths) + canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, @@ -952,16 +942,11 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", + render_mode="RGB", masks=masks, blur=True, ) - colors_blur, depths_blur = ( - renders_blur[..., 0:3], - renders_blur[..., 3:4], - ) - blur_mask = self.blur_module.predict_mask(image_ids, depths) - canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) + colors_blur = renders_blur[..., 0:3] canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * colors_blur colors = torch.clamp(colors, 0.0, 1.0) From 834e4e8a4473d5102f6442274ea945dccfcf8b05 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 17:03:46 -0700 Subject: [PATCH 44/53] cleanup --- examples/{blur_kernel.py => blur_opt.py} | 0 examples/simple_trainer.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename examples/{blur_kernel.py => blur_opt.py} (100%) diff --git a/examples/blur_kernel.py b/examples/blur_opt.py similarity index 100% rename from examples/blur_kernel.py rename to examples/blur_opt.py diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 8e3696790..2e1eaa777 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -28,6 +28,7 @@ from fused_ssim import fused_ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal, assert_never +from blur_opt import 
BlurOptModule from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed from lib_bilagrid import ( BilateralGrid, @@ -35,7 +36,6 @@ color_correct, total_variation_loss, ) -from blur_kernel import BlurOptModule from gsplat.compression import PngCompression from gsplat.distributed import cli @@ -412,7 +412,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( From cb826c6bc36db968a07c5c0fd204d8f4918e6acd Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 17:16:19 -0700 Subject: [PATCH 45/53] docstring --- examples/benchmarks/mcmc_deblur.sh | 1 + examples/blur_opt.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index c3d5a2822..fd2ea9b61 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -19,4 +19,5 @@ do --result_dir $RESULT_DIR/$SCENE done +# Summarize the stats python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/blur_opt.py b/examples/blur_opt.py index 90b5206b7..8deb135f6 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import torch import torch.nn as nn from torch import Tensor @@ -6,9 +7,12 @@ from gsplat.utils import log_transform +@dataclass class BlurOptModule(nn.Module): """Blur optimization module.""" + num_warmup_steps: int = 2000 + def __init__(self, n: int, embed_dim: int = 4): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) @@ -74,7 +78,7 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): """Mask mean loss.""" 
x = blur_mask.mean() - if step <= 2000: + if step <= self.num_warmup_steps: a = 20 else: a = 10 From ac0cef589a8ef09b9fbd9539b14933ac9febe01b Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 17:18:07 -0700 Subject: [PATCH 46/53] docstring --- examples/blur_opt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index 8deb135f6..1f0e97855 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass import torch import torch.nn as nn from torch import Tensor @@ -7,14 +6,13 @@ from gsplat.utils import log_transform -@dataclass class BlurOptModule(nn.Module): """Blur optimization module.""" - num_warmup_steps: int = 2000 - def __init__(self, n: int, embed_dim: int = 4): super().__init__() + self.num_warmup_steps = 2000 + self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -76,7 +74,12 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): return blur_mask def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): - """Mask mean loss.""" + """Loss function for regularizing the blur mask by controlling its mean. + + The loss function is designed to diverge to +infinity at 0 and 1. This + prevents the mask from collapsing to predicting all 0s or 1s. It is also + bias towards 0 to encourage sparsity. 
During warmup, we set this bias very + high to start with a sparse and not collapsed blur mask.""" x = blur_mask.mean() if step <= self.num_warmup_steps: a = 20 From a3d7d15ee8d262f0aefd0b8a96ad554a5b04d47d Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 17:24:57 -0700 Subject: [PATCH 47/53] rescale --- examples/blur_opt.py | 10 +++++----- examples/simple_trainer.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index 1f0e97855..a84880df6 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -78,14 +78,14 @@ def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): The loss function is designed to diverge to +infinity at 0 and 1. This prevents the mask from collapsing to predicting all 0s or 1s. It is also - bias towards 0 to encourage sparsity. During warmup, we set this bias very - high to start with a sparse and not collapsed blur mask.""" + bias towards 0 to encourage sparsity. During warmup, we set this bias even + higher to start with a sparse and not collapsed blur mask.""" x = blur_mask.mean() if step <= self.num_warmup_steps: - a = 20 + a = 2 else: - a = 10 - meanloss = a * (1 / (1 - x + eps) - 1) + (1 / (x + eps) - 1) + a = 1 + meanloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1) return meanloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 2e1eaa777..52603c0f3 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -153,7 +153,7 @@ class Config: # Learning rate for blur optimization blur_opt_lr: float = 1e-3 # Regularization for blur mask mean - blur_mean_reg: float = 0.0002 + blur_mean_reg: float = 0.002 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False From 10bac1d931dccdbf1cd61a668c55e1c4e605516e Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 28 Oct 2024 17:35:50 -0700 Subject: [PATCH 48/53] minor --- examples/blur_opt.py | 13 ++++++------- examples/simple_trainer.py | 11 +++++------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index a84880df6..68f270e02 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -73,20 +73,19 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): blur_mask = torch.sigmoid(mlp_out) return blur_mask - def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): + def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): """Loss function for regularizing the blur mask by controlling its mean. - The loss function is designed to diverge to +infinity at 0 and 1. This - prevents the mask from collapsing to predicting all 0s or 1s. It is also - bias towards 0 to encourage sparsity. During warmup, we set this bias even - higher to start with a sparse and not collapsed blur mask.""" + The loss function diverges to +infinity at 0 and 1. This prevents the mask + from collapsing all 0s or 1s. It is also biased towards 0 to encourage + sparsity. 
During warmup, the bias is even higher to start with a sparse mask.""" x = blur_mask.mean() if step <= self.num_warmup_steps: a = 2 else: a = 1 - meanloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1) - return meanloss + maskloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1) + return maskloss def get_encoder(num_freqs: int, input_dims: int): diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 52603c0f3..0bed4d6a2 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -83,7 +83,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -152,8 +152,8 @@ class Config: blur_opt: bool = False # Learning rate for blur optimization blur_opt_lr: float = 1e-3 - # Regularization for blur mask mean - blur_mean_reg: float = 0.002 + # Regularization for blur mask + blur_mask_reg: float = 0.002 # Enable bilateral grid. 
(experimental) use_bilateral_grid: bool = False @@ -705,9 +705,7 @@ def train(self): tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss if cfg.blur_opt: - loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss( - blur_mask, step - ) + loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step) # regularizations if cfg.opacity_reg > 0.0: @@ -721,6 +719,7 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) + loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " From 68289aeee5f1d3ab240bc86a61c3e7c33a8a9028 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 29 Oct 2024 11:39:36 -0700 Subject: [PATCH 49/53] warmup 3 --- examples/blur_opt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index 68f270e02..d01a552af 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -81,10 +81,12 @@ def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): sparsity. 
During warmup, the bias is even higher to start with a sparse mask.""" x = blur_mask.mean() if step <= self.num_warmup_steps: - a = 2 + a = 3 + b = 0.1 else: a = 1 - maskloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1) + b = 0.1 + maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) return maskloss From d95b9293fcbf75f283632041b77ccb64025e87f1 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 29 Oct 2024 13:35:42 -0700 Subject: [PATCH 50/53] delayed start instead of warmup --- examples/blur_opt.py | 16 +++++----------- examples/simple_trainer.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index d01a552af..fed26000c 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -11,8 +11,6 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() - self.num_warmup_steps = 2000 - self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -73,19 +71,15 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): blur_mask = torch.sigmoid(mlp_out) return blur_mask - def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): + def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2): """Loss function for regularizing the blur mask by controlling its mean. The loss function diverges to +infinity at 0 and 1. This prevents the mask - from collapsing all 0s or 1s. It is also biased towards 0 to encourage - sparsity. During warmup, the bias is even higher to start with a sparse mask.""" + from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity. 
+ """ x = blur_mask.mean() - if step <= self.num_warmup_steps: - a = 3 - b = 0.1 - else: - a = 1 - b = 0.1 + a = 2.0 + b = 0.1 maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) return maskloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 0bed4d6a2..e6b054ce9 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -153,7 +153,9 @@ class Config: # Learning rate for blur optimization blur_opt_lr: float = 1e-3 # Regularization for blur mask - blur_mask_reg: float = 0.002 + blur_mask_reg: float = 0.001 + # Blur start iteration + blur_start_iter: int = 2_000 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -651,7 +653,7 @@ def train(self): if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - if cfg.blur_opt: + if cfg.blur_opt and step >= cfg.blur_start_iter: blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, @@ -704,8 +706,8 @@ def train(self): if cfg.use_bilateral_grid: tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss - if cfg.blur_opt: - loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step) + if cfg.blur_opt and step >= cfg.blur_start_iter: + loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask) # regularizations if cfg.opacity_reg > 0.0: @@ -865,6 +867,8 @@ def train(self): self.eval(step, stage="train") self.eval(step, stage="val") self.render_traj(step) + if step % 1000 == 0: + self.eval(step, stage="vis") # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -890,7 +894,7 @@ def eval(self, step: int, stage: str = "val"): world_rank = self.world_rank world_size = self.world_size - dataset = self.trainset if stage == "train" else self.valset + dataset = self.valset if stage == "val" else self.trainset dataloader = torch.utils.data.DataLoader( dataset, 
batch_size=1, shuffle=False, num_workers=1 ) @@ -898,6 +902,9 @@ def eval(self, step: int, stage: str = "val"): ellipse_time = 0 metrics = defaultdict(list) for i, data in enumerate(dataloader): + if stage == "vis": + if i % 5 != 0: + continue camtoworlds = data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = data["image"].to(device) / 255.0 @@ -929,7 +936,7 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] - if self.cfg.blur_opt and stage == "train": + if self.cfg.blur_opt and stage != "val": blur_mask = self.blur_module.predict_mask(image_ids, depths) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats( From 8f7e92cef63ee1783a40b60fa922cf245a36c3ef Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 31 Oct 2024 07:55:38 -0700 Subject: [PATCH 51/53] wip --- examples/benchmarks/mcmc_deblur.sh | 6 ++++-- examples/blur_opt.py | 12 ++++-------- examples/simple_trainer.py | 12 +++++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index fd2ea9b61..b9f87c12c 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -5,7 +5,7 @@ DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" CAP_MAX=250000 -RESULT_DIR="results/benchmark_mcmc_deblur" +RESULT_DIR="results/benchmark_mcmc_deblur_wd/1e-3_0.8" for SCENE in $SCENE_LIST; do echo "Running $SCENE" @@ -14,10 +14,12 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ + --blur_opt_lr 1e-3 \ + --blur_a 0.8 \ + --blur_mask_reg 0.002 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE done - # Summarize the stats python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git 
a/examples/blur_opt.py b/examples/blur_opt.py index fed26000c..e02a5f5f9 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -9,8 +9,9 @@ class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, n: int, embed_dim: int = 4): + def __init__(self, cfg, n: int, embed_dim: int = 4): super().__init__() + self.blur_a = cfg.blur_a self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -39,9 +40,7 @@ def forward( quats: Tensor, ): quats = F.normalize(quats, dim=-1) - means_log = log_transform(means) - - means_emb = self.means_encoder.encode(means_log) + means_emb = self.means_encoder.encode(log_transform(means)) images_emb = self.embeds(image_ids).repeat(means.shape[0], 1) mlp_out = self.blur_deltas_mlp( torch.cat([images_emb, means_emb, scales, quats], dim=-1) @@ -61,7 +60,6 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) grid_emb = self.grid_encoder.encode(grid_xy) - depths_emb = self.depths_encoder.encode(log_transform(depths)) images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1) @@ -78,9 +76,7 @@ def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2): from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity. 
""" x = blur_mask.mean() - a = 2.0 - b = 0.1 - maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) + maskloss = self.blur_a * (1 / (1 - x + eps) - 1) + 0.2 * (1 / (x + eps) - 1) return maskloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index e6b054ce9..3f0c93ce0 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -154,8 +154,9 @@ class Config: blur_opt_lr: float = 1e-3 # Regularization for blur mask blur_mask_reg: float = 0.001 - # Blur start iteration - blur_start_iter: int = 2_000 + # Regularization for blur optimization as weight decay + blur_opt_reg: float = 1e-6 + blur_a: float = 0.8 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -414,12 +415,13 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device) self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.blur_opt_reg, ), ] if world_size > 1: @@ -653,7 +655,7 @@ def train(self): if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - if cfg.blur_opt and step >= cfg.blur_start_iter: + if cfg.blur_opt: blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, @@ -706,7 +708,7 @@ def train(self): if cfg.use_bilateral_grid: tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss - if cfg.blur_opt and step >= cfg.blur_start_iter: + if cfg.blur_opt: loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask) # regularizations From f3558773d21fb8b57bddc70f5be5ed005e85aa05 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 13 Nov 2024 11:57:48 -0800 Subject: [PATCH 52/53] bounded l1 losos --- examples/benchmarks/mcmc_deblur.sh | 
6 ++---- examples/blur_opt.py | 32 +++++++++++++++++++++++------- examples/mlp.py | 27 +++++-------------------- examples/simple_trainer.py | 7 +++---- 4 files changed, 35 insertions(+), 37 deletions(-) diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index b9f87c12c..da89187ef 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -4,8 +4,8 @@ SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake def DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" CAP_MAX=250000 +RESULT_DIR="results/benchmark_mcmc_deblur" -RESULT_DIR="results/benchmark_mcmc_deblur_wd/1e-3_0.8" for SCENE in $SCENE_LIST; do echo "Running $SCENE" @@ -14,12 +14,10 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ - --blur_opt_lr 1e-3 \ - --blur_a 0.8 \ - --blur_mask_reg 0.002 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE done + # Summarize the stats python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/blur_opt.py b/examples/blur_opt.py index e02a5f5f9..e2fcc506d 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -9,9 +9,8 @@ class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, cfg, n: int, embed_dim: int = 4): + def __init__(self, n: int, embed_dim: int = 4): super().__init__() - self.blur_a = cfg.blur_a self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -28,6 +27,7 @@ def __init__(self, cfg, n: int, embed_dim: int = 4): layer_width=64, out_dim=7, ) + self.bounded_l1_loss = bounded_l1_loss(10.0, 0.5) def zero_init(self): torch.nn.init.zeros_(self.embeds.weight) @@ -69,15 +69,33 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): blur_mask = 
torch.sigmoid(mlp_out) return blur_mask - def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2): + def mask_loss(self, blur_mask: Tensor): """Loss function for regularizing the blur mask by controlling its mean. - The loss function diverges to +infinity at 0 and 1. This prevents the mask - from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity. + Uses bounded l1 loss which diverges to +infinity at 0 and 1 to prevent the mask + from collapsing to all 0s or 1s. """ x = blur_mask.mean() - maskloss = self.blur_a * (1 / (1 - x + eps) - 1) + 0.2 * (1 / (x + eps) - 1) - return maskloss + return self.bounded_l1_loss(x) + + +def bounded_l1_loss(lambda_a: float, lambda_b: float, eps: float = 1e-2): + """L1 loss function with discontinuities at 0 and 1. + + Args: + lambda_a (float): Coefficient of L1 loss. + lambda_b (float): Coefficient of bounded loss. + eps (float, optional): Epsilon to prevent divide by zero. Defaults to 1e-2. + """ + + def loss_fn(x: Tensor): + return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps)) + + # Compute constant that sets min to zero + xs = torch.linspace(0, 1, 1000) + ys = loss_fn(xs) + c = ys.min() + return lambda x: loss_fn(x) - c def get_encoder(num_freqs: int, input_dims: int): diff --git a/examples/mlp.py b/examples/mlp.py index 4d5987cfa..f5bc48acd 100644 --- a/examples/mlp.py +++ b/examples/mlp.py @@ -18,7 +18,6 @@ from typing import Union -import torch from torch import nn from examples.external import TCNN_EXISTS, tcnn @@ -37,7 +36,7 @@ def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str: if isinstance(activation, nn.ReLU): return "ReLU" if isinstance(activation, nn.LeakyReLU): - return "Leaky ReLU" + return "LeakyReLU" if isinstance(activation, nn.Sigmoid): return "Sigmoid" if isinstance(activation, nn.Softplus): @@ -74,16 +73,11 @@ def create_mlp( num_layers: int, layer_width: int, out_dim: int, - initialize_last_layer_zeros: bool = False, ): if TCNN_EXISTS: - return
_create_mlp_tcnn( - in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros - ) + return _create_mlp_tcnn(in_dim, num_layers, layer_width, out_dim) else: - return _create_mlp_torch( - in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros - ) + return _create_mlp_torch(in_dim, num_layers, layer_width, out_dim) def _create_mlp_tcnn( @@ -91,11 +85,10 @@ def _create_mlp_tcnn( num_layers: int, layer_width: int, out_dim: int, - initialize_last_layer_zeros: bool = False, ): """Create a fully-connected neural network with tiny-cuda-nn.""" network_config = get_tcnn_network_config( - activation=nn.ReLU(), + activation=nn.LeakyReLU(), out_activation=None, layer_width=layer_width, num_layers=num_layers, @@ -105,12 +98,6 @@ def _create_mlp_tcnn( n_output_dims=out_dim, network_config=network_config, ) - - if initialize_last_layer_zeros: - # tcnn always pads the output layer's width to a multiple of 16 - params = tcnn_encoding.state_dict()["params"] - params[-1 * (layer_width * 16 * (out_dim // 16 + 1)) :] = 0 - tcnn_encoding.load_state_dict({"params": params}) return tcnn_encoding @@ -119,7 +106,6 @@ def _create_mlp_torch( num_layers: int, layer_width: int, out_dim: int, - initialize_last_layer_zeros: bool = False, ): """Create a fully-connected neural network with PyTorch.""" layers = [] @@ -128,9 +114,6 @@ def _create_mlp_torch( layer_out = layer_width if i != num_layers - 1 else out_dim layers.append(nn.Linear(layer_in, layer_out, bias=False)) if i != num_layers - 1: - layers.append(nn.ReLU()) + layers.append(nn.LeakyReLU()) layer_in = layer_width - - if initialize_last_layer_zeros: - nn.init.zeros_(layers[-1].weight) return nn.Sequential(*layers) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3f0c93ce0..49aee9860 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -156,7 +156,6 @@ class Config: blur_mask_reg: float = 0.001 # Regularization for blur optimization as weight decay blur_opt_reg: 
float = 1e-6 - blur_a: float = 0.8 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -415,7 +414,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( @@ -869,8 +868,8 @@ def train(self): self.eval(step, stage="train") self.eval(step, stage="val") self.render_traj(step) - if step % 1000 == 0: - self.eval(step, stage="vis") + # if step % 1000 == 0: + # self.eval(step, stage="vis") # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: From a14bcb7d26796cbb8cac0dc31fc863c94fbfc391 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 13 Nov 2024 12:47:34 -0800 Subject: [PATCH 53/53] mlp folder --- examples/blur_opt.py | 54 +++------------------------------- examples/mlp/__init__.py | 2 ++ examples/mlp/encoder.py | 47 +++++++++++++++++++++++++++++ examples/{ => mlp}/external.py | 0 examples/{ => mlp}/mlp.py | 2 +- 5 files changed, 54 insertions(+), 51 deletions(-) create mode 100644 examples/mlp/__init__.py create mode 100644 examples/mlp/encoder.py rename examples/{ => mlp}/external.py (100%) rename examples/{ => mlp}/mlp.py (98%) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index e2fcc506d..ee529d6c3 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -2,7 +2,7 @@ import torch.nn as nn from torch import Tensor import torch.nn.functional as F -from examples.mlp import create_mlp +from examples.mlp import create_mlp, get_encoder from gsplat.utils import log_transform @@ -12,9 +12,9 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() self.embeds = torch.nn.Embedding(n, embed_dim) - self.means_encoder = get_encoder(3, 3) - self.depths_encoder = get_encoder(3, 1) - self.grid_encoder = get_encoder(1, 2) + 
self.means_encoder = get_encoder(num_freqs=3, input_dims=3) + self.depths_encoder = get_encoder(num_freqs=3, input_dims=1) + self.grid_encoder = get_encoder(num_freqs=1, input_dims=2) self.blur_mask_mlp = create_mlp( in_dim=embed_dim + self.depths_encoder.out_dim + self.grid_encoder.out_dim, num_layers=5, @@ -96,49 +96,3 @@ def loss_fn(x: Tensor): ys = loss_fn(xs) c = ys.min() return lambda x: loss_fn(x) - c - - -def get_encoder(num_freqs: int, input_dims: int): - kwargs = { - "include_input": True, - "input_dims": input_dims, - "max_freq_log2": num_freqs - 1, - "num_freqs": num_freqs, - "log_sampling": True, - "periodic_fns": [torch.sin, torch.cos], - } - encoder = Encoder(**kwargs) - return encoder - - -class Encoder: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.create_embedding_fn() - - def create_embedding_fn(self): - embed_fns = [] - d = self.kwargs["input_dims"] - out_dim = 0 - if self.kwargs["include_input"]: - embed_fns.append(lambda x: x) - out_dim += d - - max_freq = self.kwargs["max_freq_log2"] - N_freqs = self.kwargs["num_freqs"] - - if self.kwargs["log_sampling"]: - freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) - else: - freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) - - for freq in freq_bands: - for p_fn in self.kwargs["periodic_fns"]: - embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) - out_dim += d - - self.embed_fns = embed_fns - self.out_dim = out_dim - - def encode(self, inputs): - return torch.cat([fn(inputs) for fn in self.embed_fns], -1) diff --git a/examples/mlp/__init__.py b/examples/mlp/__init__.py new file mode 100644 index 000000000..58f4382b4 --- /dev/null +++ b/examples/mlp/__init__.py @@ -0,0 +1,2 @@ +from .encoder import get_encoder +from .mlp import create_mlp diff --git a/examples/mlp/encoder.py b/examples/mlp/encoder.py new file mode 100644 index 000000000..188a25bf8 --- /dev/null +++ b/examples/mlp/encoder.py @@ -0,0 +1,47 @@ +import torch + + +def 
get_encoder(num_freqs: int, input_dims: int): + kwargs = { + "include_input": True, + "input_dims": input_dims, + "max_freq_log2": num_freqs - 1, + "num_freqs": num_freqs, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], + } + encoder = Encoder(**kwargs) + return encoder + + +class Encoder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs["input_dims"] + out_dim = 0 + if self.kwargs["include_input"]: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] + + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def encode(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) diff --git a/examples/external.py b/examples/mlp/external.py similarity index 100% rename from examples/external.py rename to examples/mlp/external.py diff --git a/examples/mlp.py b/examples/mlp/mlp.py similarity index 98% rename from examples/mlp.py rename to examples/mlp/mlp.py index f5bc48acd..f68330bee 100644 --- a/examples/mlp.py +++ b/examples/mlp/mlp.py @@ -20,7 +20,7 @@ from torch import nn -from examples.external import TCNN_EXISTS, tcnn +from examples.mlp.external import TCNN_EXISTS, tcnn def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str: