From 2f1eb16baa57916abfa66fe1b7c77a41f72812f1 Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Tue, 27 Aug 2024 16:20:35 +0200 Subject: [PATCH 1/8] gaussian sparse adam --- examples/simple_trainer.py | 15 ++++++- gsplat/cuda/_wrapper.py | 55 ++++++++++++++++++++++++ gsplat/cuda/csrc/adam.cu | 83 +++++++++++++++++++++++++++++++++++++ gsplat/cuda/csrc/bindings.h | 16 +++++++ gsplat/cuda/csrc/ext.cpp | 2 + 5 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 gsplat/cuda/csrc/adam.cu diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 096682c38..0a67c743e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -40,6 +40,7 @@ from gsplat.distributed import cli from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.cuda._wrapper import SparseGaussianAdam @dataclass @@ -247,8 +248,17 @@ def create_splats_with_optimizers( # Note that this would not make the training exactly equivalent, see # https://arxiv.org/pdf/2402.18824v1 BS = batch_size * world_size + # optimizers = { + # name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + # [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + # eps=1e-15 / math.sqrt(BS), + # # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. + # betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + # ) + # for name, _, lr in params + # } optimizers = { - name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + name: SparseGaussianAdam( [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], eps=1e-15 / math.sqrt(BS), # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. @@ -741,7 +751,8 @@ def train(self): # optimize for optimizer in self.optimizers.values(): - optimizer.step() + N = self.splats.means.shape[0] + optimizer.step(info["radii"] > 0, N) optimizer.zero_grad(set_to_none=True) for optimizer in self.pose_optimizers: optimizer.step() diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 1c3826110..32a076cd6 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -15,6 +15,32 @@ def call_cuda(*args, **kwargs): return call_cuda +def adam_update( + param: Tensor, + param_grad: Tensor, + exp_avg: Tensor, + exp_avg_sq: Tensor, + tiles_touched: Tensor, + lr: float, + b1: float, + b2: float, + eps: float, + N: int, + M: int +) -> None: + _make_lazy_cuda_func("adam_update")( + param, + param_grad, + exp_avg, + exp_avg_sq, + tiles_touched, + lr, + b1, + b2, + eps, + N, + M + ) def _make_lazy_cuda_obj(name: str) -> Any: # pylint: disable=import-outside-toplevel @@ -1976,3 +2002,32 @@ def backward( None, None, ) +class SparseGaussianAdam(torch.optim.Adam): + def __init__(self, params, eps, betas): + super().__init__(params=params, eps=eps, betas=betas) + + @torch.no_grad() + def step(self, visibility, N): + for group in self.param_groups: + lr = group["lr"] + eps = group["eps"] + beta1, beta2 = group["betas"] + + assert len(group["params"]) == 1, "more than one tensor in group" + param = group["params"][0] + if param.grad is None: + continue + + # Lazy state initialization + state = self.state[param] + if len(state) == 0: + state['step'] = torch.tensor(0.0, dtype=torch.float32) + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + + + stored_state = self.state.get(param, None) + exp_avg = stored_state["exp_avg"] + 
exp_avg_sq = stored_state["exp_avg_sq"] + M = param.numel() // N + adam_update(param, param.grad, exp_avg, exp_avg_sq, visibility, lr, beta1, beta2, eps, N, M) diff --git a/gsplat/cuda/csrc/adam.cu b/gsplat/cuda/csrc/adam.cu new file mode 100644 index 000000000..1e5a7a1b3 --- /dev/null +++ b/gsplat/cuda/csrc/adam.cu @@ -0,0 +1,83 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + +template +__global__ void adam_update_kernel( + T* __restrict__ param, + const T* __restrict__ param_grad, + T* __restrict__ exp_avg, + T* __restrict__ exp_avg_sq, + const bool* tiles_touched, + const float lr, + const float b1, + const float b2, + const float eps, + const uint32_t N, + const uint32_t M +) { + auto p_idx = cg::this_grid().thread_rank(); + const uint32_t g_idx = p_idx / M; + if (g_idx >= N) return; + if (tiles_touched[g_idx]) { + T Register_param_grad = param_grad[p_idx]; + T Register_exp_avg = exp_avg[p_idx]; + T Register_exp_avg_sq = exp_avg_sq[p_idx]; + Register_exp_avg = b1 * Register_exp_avg + (1.0f - b1) * Register_param_grad; + Register_exp_avg_sq = b2 * Register_exp_avg_sq + (1.0f - b2) * Register_param_grad * Register_param_grad; + T step = -lr * Register_exp_avg / (sqrt(Register_exp_avg_sq) + eps); + + param[p_idx] += step; + exp_avg[p_idx] = Register_exp_avg; + exp_avg_sq[p_idx] = Register_exp_avg_sq; + } +} + +void adam_update( + torch::Tensor ¶m, + torch::Tensor ¶m_grad, + torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, + torch::Tensor &tiles_touched, + const float lr, + const float b1, + const float b2, + const float eps, + const uint32_t N, + const uint32_t M +) { + GSPLAT_DEVICE_GUARD(param); + GSPLAT_CHECK_INPUT(param); + GSPLAT_CHECK_INPUT(param_grad); + GSPLAT_CHECK_INPUT(exp_avg); + GSPLAT_CHECK_INPUT(exp_avg_sq); + GSPLAT_CHECK_INPUT(tiles_touched); + + const uint32_t cnt = N * M; + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + adam_update_kernel<<<(cnt + 255) / 256, 256, 0, stream>>>( + param.data_ptr(), + param_grad.data_ptr(), + exp_avg.data_ptr(), + exp_avg_sq.data_ptr(), + tiles_touched.data_ptr(), + lr, + b1, + b2, + eps, + N, + M + ); +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index cf0dc8751..703941524 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -316,6 +316,7 @@ std::tuple compute_relocation_tensor( const int n_max ); +<<<<<<< HEAD //====== 2DGS ======// std::tuple< torch::Tensor, @@ -487,6 +488,21 @@ fully_fused_projection_packed_bwd_2dgs_tensor( const bool viewmats_requires_grad, const bool sparse_grad ); +======= + +void adam_update( + torch::Tensor ¶m, + torch::Tensor ¶m_grad, + torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, + torch::Tensor &tiles_touched, + const float lr, + const float b1, + const float b2, + const float eps, + const uint32_t N, + const uint32_t M); +>>>>>>> 7fa96fb (gaussian sparse adam) } // namespace gsplat diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index a85129959..9febf5859 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -87,4 +87,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "rasterize_to_indices_in_range_2dgs", &gsplat::rasterize_to_indices_in_range_2dgs_tensor ); + + m.def("adam_update", &gsplat::adam_update); } From c7b4fa15c47b6a94b6e2b1b63250263fbfa3e58a Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Sun, 22 
Sep 2024 16:06:08 +0200 Subject: [PATCH 2/8] update --- examples/simple_trainer.py | 34 +++++++++++++++++++++------------- gsplat/cuda/_wrapper.py | 11 ++++++++--- gsplat/cuda/csrc/adam.cu | 6 +++--- gsplat/cuda/csrc/bindings.h | 2 +- gsplat/cuda/csrc/ext.cpp | 2 +- 5 files changed, 34 insertions(+), 21 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 0a67c743e..47c0dee54 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -40,7 +40,7 @@ from gsplat.distributed import cli from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy -from gsplat.cuda._wrapper import SparseGaussianAdam +from gsplat.cuda._wrapper import SelectiveAdam @dataclass @@ -116,6 +116,8 @@ class Config: packed: bool = False # Use sparse gradients for optimization. (experimental) sparse_grad: bool = False + # Use visible adam from Taming 3DGS. (experimental) + visible_adam: bool = False # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. antialiased: bool = False @@ -192,6 +194,7 @@ def create_splats_with_optimizers( scene_scale: float = 1.0, sh_degree: int = 3, sparse_grad: bool = False, + visible_adam: bool = False, batch_size: int = 1, feature_dim: Optional[int] = None, device: str = "cuda", @@ -248,17 +251,15 @@ def create_splats_with_optimizers( # Note that this would not make the training exactly equivalent, see # https://arxiv.org/pdf/2402.18824v1 BS = batch_size * world_size - # optimizers = { - # name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( - # [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], - # eps=1e-15 / math.sqrt(BS), - # # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. - # betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), - # ) - # for name, _, lr in params - # } + optimizer_class = None + if sparse_grad: + optimizer_class = torch.optim.SparseAdam + elif visible_adam: + optimizer_class = SelectiveAdam + else: + optimizer_class = torch.optim.Adam optimizers = { - name: SparseGaussianAdam( + name: optimizer_class( [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], eps=1e-15 / math.sqrt(BS), # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. 
@@ -326,6 +327,7 @@ def __init__( scene_scale=self.scene_scale, sh_degree=cfg.sh_degree, sparse_grad=cfg.sparse_grad, + visible_adam=cfg.visible_adam, batch_size=cfg.batch_size, feature_dim=feature_dim, device=self.device, @@ -749,10 +751,16 @@ def train(self): is_coalesced=len(Ks) == 1, ) + if cfg.visible_adam: + visibility_mask = info["radii"] > 0 + gaussian_cnt = self.splats.means.shape[0] + # optimize for optimizer in self.optimizers.values(): - N = self.splats.means.shape[0] - optimizer.step(info["radii"] > 0, N) + if cfg.visible_adam: + optimizer.step(visibility_mask, gaussian_cnt) + else: + optimizer.step() optimizer.zero_grad(set_to_none=True) for optimizer in self.pose_optimizers: optimizer.step() diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 32a076cd6..efc5b45a4 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -15,7 +15,7 @@ def call_cuda(*args, **kwargs): return call_cuda -def adam_update( +def selective_adam_update( param: Tensor, param_grad: Tensor, exp_avg: Tensor, @@ -28,7 +28,7 @@ def adam_update( N: int, M: int ) -> None: - _make_lazy_cuda_func("adam_update")( + _make_lazy_cuda_func("selective_adam_update")( param, param_grad, exp_avg, @@ -1263,6 +1263,7 @@ def backward(ctx, v_colors: Tensor): v_dirs = None return None, v_dirs, v_coeffs, None +<<<<<<< HEAD ###### 2DGS ###### def fully_fused_projection_2dgs( @@ -2003,6 +2004,9 @@ def backward( None, ) class SparseGaussianAdam(torch.optim.Adam): +======= +class SelectiveAdam(torch.optim.Adam): +>>>>>>> 4aef74b (update) def __init__(self, params, eps, betas): super().__init__(params=params, eps=eps, betas=betas) @@ -2030,4 +2034,5 @@ def step(self, visibility, N): exp_avg = stored_state["exp_avg"] exp_avg_sq = stored_state["exp_avg_sq"] M = param.numel() // N - adam_update(param, param.grad, exp_avg, exp_avg_sq, visibility, lr, beta1, beta2, eps, N, M) + + selective_adam_update(param, param.grad, exp_avg, exp_avg_sq, visibility, lr, beta1, beta2, eps, N, M) \ No newline at end of file diff --git a/gsplat/cuda/csrc/adam.cu b/gsplat/cuda/csrc/adam.cu index 1e5a7a1b3..3018cfd95 100644 --- a/gsplat/cuda/csrc/adam.cu +++ b/gsplat/cuda/csrc/adam.cu @@ -13,7 +13,7 @@ namespace gsplat { namespace cg = cooperative_groups; template -__global__ void adam_update_kernel( +__global__ void selective_adam_update_kernel( T* __restrict__ param, const T* __restrict__ param_grad, T* __restrict__ exp_avg, @@ -43,7 +43,7 @@ __global__ void adam_update_kernel( } } -void adam_update( +void selective_adam_update( torch::Tensor ¶m, torch::Tensor ¶m_grad, torch::Tensor &exp_avg, @@ -65,7 +65,7 @@ void adam_update( const uint32_t cnt = N * M; at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); - adam_update_kernel<<<(cnt + 255) / 256, 256, 0, stream>>>( + selective_adam_update_kernel<<<(cnt + 255) / 256, 256, 0, stream>>>( param.data_ptr(), param_grad.data_ptr(), exp_avg.data_ptr(), diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 703941524..0f99a4532 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -490,7 +490,7 @@ fully_fused_projection_packed_bwd_2dgs_tensor( ); ======= -void adam_update( +void selective_adam_update( torch::Tensor ¶m, torch::Tensor ¶m_grad, torch::Tensor &exp_avg, diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index 9febf5859..c97ee7fda 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -88,5 +88,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
&gsplat::rasterize_to_indices_in_range_2dgs_tensor ); - m.def("adam_update", &gsplat::adam_update); + m.def("selective_adam_update", &gsplat::selective_adam_update); } From 7235a7f6bf654061124e7b9e50ce2fb02d6cccf0 Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Sun, 22 Sep 2024 17:41:51 +0200 Subject: [PATCH 3/8] merge conflict fix --- gsplat/cuda/_wrapper.py | 6 +----- gsplat/cuda/csrc/bindings.h | 3 --- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index efc5b45a4..403ce6b99 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -1263,8 +1263,6 @@ def backward(ctx, v_colors: Tensor): v_dirs = None return None, v_dirs, v_coeffs, None -<<<<<<< HEAD - ###### 2DGS ###### def fully_fused_projection_2dgs( means: Tensor, # [N, 3] @@ -2003,10 +2001,8 @@ def backward( None, None, ) -class SparseGaussianAdam(torch.optim.Adam): -======= + class SelectiveAdam(torch.optim.Adam): ->>>>>>> 4aef74b (update) def __init__(self, params, eps, betas): super().__init__(params=params, eps=eps, betas=betas) diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 0f99a4532..66f7a2567 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -316,7 +316,6 @@ std::tuple compute_relocation_tensor( const int n_max ); -<<<<<<< HEAD //====== 2DGS ======// std::tuple< torch::Tensor, @@ -488,7 +487,6 @@ fully_fused_projection_packed_bwd_2dgs_tensor( const bool viewmats_requires_grad, const bool sparse_grad ); -======= void selective_adam_update( torch::Tensor ¶m, @@ -502,7 +500,6 @@ void selective_adam_update( const float eps, const uint32_t N, const uint32_t M); ->>>>>>> 7fa96fb (gaussian sparse adam) } // namespace gsplat From febe97eb9933cb9aa79158c15a1bcf342404fbb7 Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Mon, 30 Sep 2024 12:33:51 +0200 Subject: [PATCH 4/8] masking for packed --- examples/simple_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 47c0dee54..05ad2e823 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -752,8 +752,12 @@ def train(self): ) if cfg.visible_adam: - visibility_mask = info["radii"] > 0 gaussian_cnt = self.splats.means.shape[0] + if cfg.packed: + visibility_mask = torch.zeros_like(self.splats["opacities"], dtype=bool) + visibility_mask.scatter_(0, info["gaussian_ids"], 1) + else: + visibility_mask = (info["radii"] > 0).any(0) # optimize for optimizer in self.optimizers.values(): From 35412ec9d40bee279409ae659e45972a8b41482f Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Mon, 30 Sep 2024 12:57:20 +0200 Subject: [PATCH 5/8] separete directory for optimizers --- examples/simple_trainer.py | 6 ++-- gsplat/cuda/_wrapper.py | 48 +++------------------------ gsplat/optimizers/__init__.py | 1 + gsplat/optimizers/selective_adam.py | 50 +++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 45 deletions(-) create mode 100644 gsplat/optimizers/__init__.py create mode 100644 gsplat/optimizers/selective_adam.py diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 05ad2e823..f3c748236 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -40,7 +40,7 @@ from gsplat.distributed import cli from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy -from gsplat.cuda._wrapper import SelectiveAdam +from gsplat.optimizers import SelectiveAdam @dataclass @@ -754,7 +754,9 @@ def 
train(self): if cfg.visible_adam: gaussian_cnt = self.splats.means.shape[0] if cfg.packed: - visibility_mask = torch.zeros_like(self.splats["opacities"], dtype=bool) + visibility_mask = torch.zeros_like( + self.splats["opacities"], dtype=bool + ) visibility_mask.scatter_(0, info["gaussian_ids"], 1) else: visibility_mask = (info["radii"] > 0).any(0) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 403ce6b99..5a1065af0 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -15,6 +15,7 @@ def call_cuda(*args, **kwargs): return call_cuda + def selective_adam_update( param: Tensor, param_grad: Tensor, @@ -26,22 +27,13 @@ def selective_adam_update( b2: float, eps: float, N: int, - M: int + M: int, ) -> None: _make_lazy_cuda_func("selective_adam_update")( - param, - param_grad, - exp_avg, - exp_avg_sq, - tiles_touched, - lr, - b1, - b2, - eps, - N, - M + param, param_grad, exp_avg, exp_avg_sq, tiles_touched, lr, b1, b2, eps, N, M ) + def _make_lazy_cuda_obj(name: str) -> Any: # pylint: disable=import-outside-toplevel from ._backend import _C @@ -1263,6 +1255,7 @@ def backward(ctx, v_colors: Tensor): v_dirs = None return None, v_dirs, v_coeffs, None + ###### 2DGS ###### def fully_fused_projection_2dgs( means: Tensor, # [N, 3] @@ -2001,34 +1994,3 @@ def backward( None, None, ) - -class SelectiveAdam(torch.optim.Adam): - def __init__(self, params, eps, betas): - super().__init__(params=params, eps=eps, betas=betas) - - @torch.no_grad() - def step(self, visibility, N): - for group in self.param_groups: - lr = group["lr"] - eps = group["eps"] - beta1, beta2 = group["betas"] - - assert len(group["params"]) == 1, "more than one tensor in group" - param = group["params"][0] - if param.grad is None: - continue - - # Lazy state initialization - state = self.state[param] - if len(state) == 0: - state['step'] = torch.tensor(0.0, dtype=torch.float32) - state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) - state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) - - - stored_state = self.state.get(param, None) - exp_avg = stored_state["exp_avg"] - exp_avg_sq = stored_state["exp_avg_sq"] - M = param.numel() // N - - selective_adam_update(param, param.grad, exp_avg, exp_avg_sq, visibility, lr, beta1, beta2, eps, N, M) \ No newline at end of file diff --git a/gsplat/optimizers/__init__.py b/gsplat/optimizers/__init__.py new file mode 100644 index 000000000..6be9896c0 --- /dev/null +++ b/gsplat/optimizers/__init__.py @@ -0,0 +1 @@ +from .selective_adam import SelectiveAdam diff --git a/gsplat/optimizers/selective_adam.py b/gsplat/optimizers/selective_adam.py new file mode 100644 index 000000000..9a535d420 --- /dev/null +++ b/gsplat/optimizers/selective_adam.py @@ -0,0 +1,50 @@ +import torch + +from ..cuda._wrapper import selective_adam_update + + +class SelectiveAdam(torch.optim.Adam): + def __init__(self, params, eps, betas): + super().__init__(params=params, eps=eps, betas=betas) + + @torch.no_grad() + def step(self, visibility, N): + for group in self.param_groups: + lr = group["lr"] + eps = group["eps"] + beta1, beta2 = group["betas"] + + assert len(group["params"]) == 1, "more than one tensor in group" + param = group["params"][0] + if param.grad is None: + continue + + # Lazy state initialization + state = self.state[param] + if len(state) == 0: + state["step"] = torch.tensor(0.0, dtype=torch.float32) + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = 
torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + stored_state = self.state.get(param, None) + exp_avg = stored_state["exp_avg"] + exp_avg_sq = stored_state["exp_avg_sq"] + M = param.numel() // N + + selective_adam_update( + param, + param.grad, + exp_avg, + exp_avg_sq, + visibility, + lr, + beta1, + beta2, + eps, + N, + M, + ) From 464d0647ada9863719e1fdab87dbad2b7204a113 Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Tue, 1 Oct 2024 07:39:21 +0200 Subject: [PATCH 6/8] add docstring --- gsplat/optimizers/selective_adam.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/gsplat/optimizers/selective_adam.py b/gsplat/optimizers/selective_adam.py index 9a535d420..4f13d4302 100644 --- a/gsplat/optimizers/selective_adam.py +++ b/gsplat/optimizers/selective_adam.py @@ -4,6 +4,41 @@ class SelectiveAdam(torch.optim.Adam): + """ + A custom optimizer that extends the standard Adam optimizer by + incorporating selective updates. + + This class is useful for situations where only a subset of parameters + should be updated at each step, such as in sparse models or in cases where + parameter visibility is controlled by an external mask. + + Additionally, the operations are fused into a single kernel. This optimizer + leverages the `selective_adam_update` function from a CUDA backend for + optimized sparse updates. + + Args: + params (iterable): Iterable of parameters to optimize or dicts defining parameter groups. + eps (float): Term added to the denominator to improve numerical stability (default: 1e-8). + betas (Tuple[float, float]): Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)). + + Examples: + + >>> N = 100 + >>> param = torch.randn(N, requires_grad=True) + >>> optimizer = SelectiveAdam([param], eps=1e-8, betas=(0.9, 0.999)) + >>> visibility_mask = torch.cat([torch.ones(50), torch.zeros(50)]) # Visible first half, hidden second half + + >>> # Forward pass + >>> loss = torch.sum(param ** 2) + + >>> # Backward pass + >>> loss.backward() + + >>> # Optimization step with selective updates + >>> optimizer.step(visibility=visibility_mask, N=N) + + """ + def __init__(self, params, eps, betas): super().__init__(params=params, eps=eps, betas=betas) From dc19061f7cdb09a3b75f30c638fde1c04f6f1b21 Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Tue, 1 Oct 2024 07:41:52 +0200 Subject: [PATCH 7/8] t3dgs --- gsplat/optimizers/selective_adam.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gsplat/optimizers/selective_adam.py b/gsplat/optimizers/selective_adam.py index 4f13d4302..9bf37e1fe 100644 --- a/gsplat/optimizers/selective_adam.py +++ b/gsplat/optimizers/selective_adam.py @@ -16,6 +16,8 @@ class SelectiveAdam(torch.optim.Adam): leverages the `selective_adam_update` function from a CUDA backend for optimized sparse updates. + This is one of the two optimizers mentioned in the Taming3DGS paper. + Args: params (iterable): Iterable of parameters to optimize or dicts defining parameter groups. eps (float): Term added to the denominator to improve numerical stability (default: 1e-8). 
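For readers of the series, the fused selective_adam_update kernel added in [PATCH 1/8] boils down to the update sketched below in plain PyTorch. This is a rough reference rather than gsplat code: the helper name reference_selective_adam_step is hypothetical, and, like the kernel, it applies no bias correction and leaves the parameters and moments of non-visible Gaussians untouched.

import torch


def reference_selective_adam_step(param, exp_avg, exp_avg_sq, visibility, lr, b1, b2, eps):
    # param, exp_avg, exp_avg_sq: [N, ...] tensors for N Gaussians; visibility: [N] bool mask.
    grad = param.grad
    # Broadcast the per-Gaussian mask over the remaining per-parameter dimensions.
    mask = visibility.view(-1, *([1] * (param.dim() - 1)))
    # Standard Adam moment updates, kept only where the Gaussian was visible.
    new_avg = b1 * exp_avg + (1.0 - b1) * grad
    new_avg_sq = b2 * exp_avg_sq + (1.0 - b2) * grad * grad
    step = -lr * new_avg / (new_avg_sq.sqrt() + eps)
    param.data.add_(torch.where(mask, step, torch.zeros_like(step)))
    exp_avg.copy_(torch.where(mask, new_avg, exp_avg))
    exp_avg_sq.copy_(torch.where(mask, new_avg_sq, exp_avg_sq))
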
From 5186c475a0f1db744f56791edd39aa876f6dc781 Mon Sep 17 00:00:00 2001 From: Rahul Goel Date: Wed, 2 Oct 2024 09:59:15 +0200 Subject: [PATCH 8/8] move N inside step --- examples/simple_trainer.py | 2 +- gsplat/optimizers/selective_adam.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index f3c748236..93e70002f 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -764,7 +764,7 @@ def train(self): # optimize for optimizer in self.optimizers.values(): if cfg.visible_adam: - optimizer.step(visibility_mask, gaussian_cnt) + optimizer.step(visibility_mask) else: optimizer.step() optimizer.zero_grad(set_to_none=True) diff --git a/gsplat/optimizers/selective_adam.py b/gsplat/optimizers/selective_adam.py index 9bf37e1fe..e7decf7a9 100644 --- a/gsplat/optimizers/selective_adam.py +++ b/gsplat/optimizers/selective_adam.py @@ -37,7 +37,7 @@ class SelectiveAdam(torch.optim.Adam): >>> loss.backward() >>> # Optimization step with selective updates - >>> optimizer.step(visibility=visibility_mask, N=N) + >>> optimizer.step(visibility=visibility_mask) """ @@ -45,7 +45,8 @@ def __init__(self, params, eps, betas): super().__init__(params=params, eps=eps, betas=betas) @torch.no_grad() - def step(self, visibility, N): + def step(self, visibility): + N = visibility.numel() for group in self.param_groups: lr = group["lr"] eps = group["eps"]
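
After the final patch, the Gaussian count is inferred from the visibility mask inside step(), so callers only pass the mask. Below is a minimal end-to-end sketch of the resulting API, assuming a CUDA build of gsplat with this series applied; the synthetic mask stands in for the one the trainer builds from info["gaussian_ids"] (packed) or (info["radii"] > 0).any(0) (non-packed).

import torch

from gsplat.optimizers import SelectiveAdam

N = 100
means = torch.nn.Parameter(torch.randn(N, 3, device="cuda"))
optimizer = SelectiveAdam(
    [{"params": [means], "lr": 1e-3, "name": "means"}],
    eps=1e-15,
    betas=(0.9, 0.999),
)

# Forward/backward as usual.
loss = (means**2).sum()
loss.backward()

# Fabricated visibility: only the first half of the Gaussians gets updated.
visibility_mask = torch.arange(N, device="cuda") < N // 2

# The Gaussian count is recovered from visibility_mask.numel() inside step().
optimizer.step(visibility_mask)
optimizer.zero_grad(set_to_none=True)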