Selective Adam #432

Merged · 8 commits · Oct 2, 2024
29 changes: 27 additions & 2 deletions examples/simple_trainer.py
@@ -40,6 +40,7 @@
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from gsplat.optimizers import SelectiveAdam


@dataclass
@@ -115,6 +116,8 @@ class Config:
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible Adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False

@@ -191,6 +194,7 @@ def create_splats_with_optimizers(
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
@@ -247,8 +251,15 @@ def create_splats_with_optimizers(
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size
    optimizer_class = None
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: check betas logic; when BS is larger than 10, betas[0] will be zero.
@@ -316,6 +327,7 @@ def __init__(
            scene_scale=self.scene_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=feature_dim,
            device=self.device,
@@ -739,9 +751,22 @@ def train(self):
                        is_coalesced=len(Ks) == 1,
                    )

            if cfg.visible_adam:
                gaussian_cnt = self.splats.means.shape[0]
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).any(0)

            # optimize
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for optimizer in self.pose_optimizers:
                optimizer.step()
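For intuition on the mask construction above, here is a minimal standalone sketch (not part of this diff), assuming info["radii"] has shape [C, N] (cameras x Gaussians) and info["gaussian_ids"] holds the indices of the Gaussians kept by packing; the sizes are illustrative.

import torch

C, N = 2, 1000  # cameras x Gaussians (illustrative sizes)

# Non-packed mode: a Gaussian is visible if any camera projects it with a positive radius.
radii = torch.randint(0, 5, (C, N))           # stand-in for info["radii"]
visibility_mask = (radii > 0).any(0)          # bool tensor of shape [N]

# Packed mode: radii only covers intersecting Gaussians, so build the mask from their ids.
gaussian_ids = torch.unique(torch.randint(0, N, (500,)))  # stand-in for info["gaussian_ids"]
visibility_mask = torch.zeros(N, dtype=torch.bool)
visibility_mask[gaussian_ids] = True          # equivalent to the scatter_ call in the diff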
18 changes: 18 additions & 0 deletions gsplat/cuda/_wrapper.py
@@ -16,6 +16,24 @@ def call_cuda(*args, **kwargs):
    return call_cuda


def selective_adam_update(
    param: Tensor,
    param_grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    tiles_touched: Tensor,
    lr: float,
    b1: float,
    b2: float,
    eps: float,
    N: int,
    M: int,
) -> None:
    _make_lazy_cuda_func("selective_adam_update")(
        param, param_grad, exp_avg, exp_avg_sq, tiles_touched, lr, b1, b2, eps, N, M
    )


def _make_lazy_cuda_obj(name: str) -> Any:
    # pylint: disable=import-outside-toplevel
    from ._backend import _C
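A rough sketch of calling the wrapper directly (not part of this diff), assuming the extension is built and a CUDA device is available; the shape convention (N Gaussians x M elements per Gaussian, with a per-Gaussian bool mask) is inferred from the kernel's indexing.

import torch
from gsplat.cuda._wrapper import selective_adam_update

N, M = 1024, 3                                   # Gaussians x elements per Gaussian
param = torch.randn(N, M, device="cuda")
grad = torch.randn(N, M, device="cuda")
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)
mask = torch.rand(N, device="cuda") > 0.5        # per-Gaussian visibility (bool)

# In-place update of param / exp_avg / exp_avg_sq, applied only where mask is True.
selective_adam_update(param, grad, exp_avg, exp_avg_sq, mask, 1e-3, 0.9, 0.999, 1e-8, N, M)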
83 changes: 83 additions & 0 deletions gsplat/cuda/csrc/adam.cu
@@ -0,0 +1,83 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_runtime.h>

namespace gsplat {

namespace cg = cooperative_groups;

template<typename T>
__global__ void selective_adam_update_kernel(
T* __restrict__ param,
const T* __restrict__ param_grad,
T* __restrict__ exp_avg,
T* __restrict__ exp_avg_sq,
const bool* tiles_touched,
const float lr,
const float b1,
const float b2,
const float eps,
const uint32_t N,
const uint32_t M
) {
auto p_idx = cg::this_grid().thread_rank();
const uint32_t g_idx = p_idx / M;
if (g_idx >= N) return;
if (tiles_touched[g_idx]) {
T Register_param_grad = param_grad[p_idx];
T Register_exp_avg = exp_avg[p_idx];
T Register_exp_avg_sq = exp_avg_sq[p_idx];
Register_exp_avg = b1 * Register_exp_avg + (1.0f - b1) * Register_param_grad;
Register_exp_avg_sq = b2 * Register_exp_avg_sq + (1.0f - b2) * Register_param_grad * Register_param_grad;
T step = -lr * Register_exp_avg / (sqrt(Register_exp_avg_sq) + eps);

param[p_idx] += step;
exp_avg[p_idx] = Register_exp_avg;
exp_avg_sq[p_idx] = Register_exp_avg_sq;
}
}

void selective_adam_update(
torch::Tensor &param,
torch::Tensor &param_grad,
torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq,
torch::Tensor &tiles_touched,
const float lr,
const float b1,
const float b2,
const float eps,
const uint32_t N,
const uint32_t M
) {
GSPLAT_DEVICE_GUARD(param);
GSPLAT_CHECK_INPUT(param);
GSPLAT_CHECK_INPUT(param_grad);
GSPLAT_CHECK_INPUT(exp_avg);
GSPLAT_CHECK_INPUT(exp_avg_sq);
GSPLAT_CHECK_INPUT(tiles_touched);

const uint32_t cnt = N * M;
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
selective_adam_update_kernel<float><<<(cnt + 255) / 256, 256, 0, stream>>>(
param.data_ptr<float>(),
param_grad.data_ptr<float>(),
exp_avg.data_ptr<float>(),
exp_avg_sq.data_ptr<float>(),
tiles_touched.data_ptr<bool>(),
lr,
b1,
b2,
eps,
N,
M
);
}

} // namespace gsplat
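For reference, a rough PyTorch equivalent of the per-element math above (a sketch for reasoning only, not part of the PR). Note that, unlike torch.optim.Adam, the kernel applies no bias correction, and the moments of non-visible Gaussians are left untouched.

import torch

def selective_adam_reference(param, grad, exp_avg, exp_avg_sq, mask, lr, b1, b2, eps):
    # param/grad/exp_avg/exp_avg_sq: [N, M]; mask: bool [N], broadcast over M
    m = mask.view(-1, 1)
    new_avg = b1 * exp_avg + (1 - b1) * grad
    new_avg_sq = b2 * exp_avg_sq + (1 - b2) * grad * grad
    step = -lr * new_avg / (new_avg_sq.sqrt() + eps)
    return (
        torch.where(m, param + step, param),
        torch.where(m, new_avg, exp_avg),
        torch.where(m, new_avg_sq, exp_avg_sq),
    )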
13 changes: 13 additions & 0 deletions gsplat/cuda/csrc/bindings.h
@@ -488,6 +488,19 @@ fully_fused_projection_packed_bwd_2dgs_tensor(
    const bool sparse_grad
);

void selective_adam_update(
    torch::Tensor &param,
    torch::Tensor &param_grad,
    torch::Tensor &exp_avg,
    torch::Tensor &exp_avg_sq,
    torch::Tensor &tiles_touched,
    const float lr,
    const float b1,
    const float b2,
    const float eps,
    const uint32_t N,
    const uint32_t M);

} // namespace gsplat

#endif // GSPLAT_CUDA_BINDINGS_H
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/ext.cpp
@@ -87,4 +87,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "rasterize_to_indices_in_range_2dgs",
        &gsplat::rasterize_to_indices_in_range_2dgs_tensor
    );

    m.def("selective_adam_update", &gsplat::selective_adam_update);
}
1 change: 1 addition & 0 deletions gsplat/optimizers/__init__.py
@@ -0,0 +1 @@
from .selective_adam import SelectiveAdam
88 changes: 88 additions & 0 deletions gsplat/optimizers/selective_adam.py
@@ -0,0 +1,88 @@
import torch

from ..cuda._wrapper import selective_adam_update


class SelectiveAdam(torch.optim.Adam):
    """
    A custom optimizer that extends the standard Adam optimizer by
    incorporating selective updates.

    This class is useful for situations where only a subset of parameters
    should be updated at each step, such as in sparse models or in cases where
    parameter visibility is controlled by an external mask.

    Additionally, the operations are fused into a single kernel. This optimizer
    leverages the `selective_adam_update` function from a CUDA backend for
    optimized sparse updates.

    This is one of the two optimizers mentioned in the Taming 3DGS paper.

    Args:
        params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
        eps (float): Term added to the denominator to improve numerical stability (default: 1e-8).
        betas (Tuple[float, float]): Coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999)).

    Examples:

        >>> N = 100
        >>> param = torch.randn(N, device="cuda", requires_grad=True)
        >>> optimizer = SelectiveAdam([param], eps=1e-8, betas=(0.9, 0.999))
        >>> visibility_mask = torch.cat([torch.ones(50), torch.zeros(50)]).bool().cuda()  # Visible first half, hidden second half

        >>> # Forward pass
        >>> loss = torch.sum(param ** 2)

        >>> # Backward pass
        >>> loss.backward()

        >>> # Optimization step with selective updates
        >>> optimizer.step(visibility=visibility_mask)

    """

    def __init__(self, params, eps, betas):
        super().__init__(params=params, eps=eps, betas=betas)

    @torch.no_grad()
    def step(self, visibility):
        N = visibility.numel()
        for group in self.param_groups:
            lr = group["lr"]
            eps = group["eps"]
            beta1, beta2 = group["betas"]

            assert len(group["params"]) == 1, "more than one tensor in group"
            param = group["params"][0]
            if param.grad is None:
                continue

            # Lazy state initialization
            state = self.state[param]
            if len(state) == 0:
                state["step"] = torch.tensor(0.0, dtype=torch.float32)
                state["exp_avg"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )
                state["exp_avg_sq"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )

            stored_state = self.state.get(param, None)
            exp_avg = stored_state["exp_avg"]
            exp_avg_sq = stored_state["exp_avg_sq"]
            # Each parameter tensor has N Gaussians in its first dimension;
            # M is the number of elements per Gaussian, flattened.
            M = param.numel() // N

            # Fused CUDA update: only entries whose visibility flag is set are
            # touched; moments of hidden Gaussians are left unchanged.
            selective_adam_update(
                param,
                param.grad,
                exp_avg,
                exp_avg_sq,
                visibility,
                lr,
                beta1,
                beta2,
                eps,
                N,
                M,
            )
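An end-to-end usage sketch mirroring the trainer change above (not part of this diff; the [N, 3] parameter shape and the random mask are illustrative, and a CUDA build of gsplat is assumed):

import torch
from gsplat.optimizers import SelectiveAdam

N = 1000
means = torch.nn.Parameter(torch.randn(N, 3, device="cuda"))
optimizer = SelectiveAdam(
    [{"params": [means], "lr": 1e-3, "name": "means"}],
    eps=1e-15,
    betas=(0.9, 0.999),
)

loss = (means ** 2).sum()
loss.backward()

visibility = torch.rand(N, device="cuda") > 0.5  # per-Gaussian bool mask
optimizer.step(visibility)                       # updates only the visible rows
optimizer.zero_grad(set_to_none=True)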