* gaussian sparse adam
* update
* merge conflict fix
* masking for packed
* separate directory for optimizers
* add docstring
* t3dgs
* move N inside step
1 parent e19da37 · commit d4020bc · Showing 7 changed files with 232 additions and 2 deletions.
@@ -0,0 +1,83 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_runtime.h>

namespace gsplat {

namespace cg = cooperative_groups;

template <typename T>
__global__ void selective_adam_update_kernel(
    T *__restrict__ param,
    const T *__restrict__ param_grad,
    T *__restrict__ exp_avg,
    T *__restrict__ exp_avg_sq,
    const bool *tiles_touched,
    const float lr,
    const float b1,
    const float b2,
    const float eps,
    const uint32_t N,
    const uint32_t M
) {
    // One thread per parameter element; parameters are treated as a flattened
    // [N, M] array with N Gaussians and M elements per Gaussian.
    auto p_idx = cg::this_grid().thread_rank();
    const uint32_t g_idx = p_idx / M;
    if (g_idx >= N) return;
    // Only update elements whose Gaussian was touched (visible) this step.
    if (tiles_touched[g_idx]) {
        T Register_param_grad = param_grad[p_idx];
        T Register_exp_avg = exp_avg[p_idx];
        T Register_exp_avg_sq = exp_avg_sq[p_idx];
        Register_exp_avg = b1 * Register_exp_avg + (1.0f - b1) * Register_param_grad;
        Register_exp_avg_sq = b2 * Register_exp_avg_sq + (1.0f - b2) * Register_param_grad * Register_param_grad;
        T step = -lr * Register_exp_avg / (sqrt(Register_exp_avg_sq) + eps);

        param[p_idx] += step;
        exp_avg[p_idx] = Register_exp_avg;
        exp_avg_sq[p_idx] = Register_exp_avg_sq;
    }
}

void selective_adam_update(
    torch::Tensor &param,
    torch::Tensor &param_grad,
    torch::Tensor &exp_avg,
    torch::Tensor &exp_avg_sq,
    torch::Tensor &tiles_touched,
    const float lr,
    const float b1,
    const float b2,
    const float eps,
    const uint32_t N,
    const uint32_t M
) {
    GSPLAT_DEVICE_GUARD(param);
    GSPLAT_CHECK_INPUT(param);
    GSPLAT_CHECK_INPUT(param_grad);
    GSPLAT_CHECK_INPUT(exp_avg);
    GSPLAT_CHECK_INPUT(exp_avg_sq);
    GSPLAT_CHECK_INPUT(tiles_touched);

    // Launch one thread per element, 256 threads per block.
    const uint32_t cnt = N * M;
    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
    selective_adam_update_kernel<float><<<(cnt + 255) / 256, 256, 0, stream>>>(
        param.data_ptr<float>(),
        param_grad.data_ptr<float>(),
        exp_avg.data_ptr<float>(),
        exp_avg_sq.data_ptr<float>(),
        tiles_touched.data_ptr<bool>(),
        lr,
        b1,
        b2,
        eps,
        N,
        M
    );
}

} // namespace gsplat
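For reference while reading the kernel: every thread owns one element p_idx of the flattened [N, M] parameter tensor, derives its Gaussian index as g_idx = p_idx / M, and exits early when that Gaussian's mask entry is false. A minimal PyTorch sketch of the same masked, non-bias-corrected Adam update (tensor-op form only; the function name and signature are illustrative, not part of the binding):

import torch

@torch.no_grad()
def selective_adam_update_reference(param, grad, exp_avg, exp_avg_sq,
                                    visible, lr, b1, b2, eps):
    # param/grad/exp_avg/exp_avg_sq: [N, M]; visible: [N] bool mask.
    m = visible.view(-1, 1)  # broadcast the per-Gaussian mask over M columns
    new_avg = b1 * exp_avg + (1 - b1) * grad
    new_avg_sq = b2 * exp_avg_sq + (1 - b2) * grad * grad
    step = -lr * new_avg / (new_avg_sq.sqrt() + eps)
    # Only visible rows are written; hidden rows keep their old values.
    param.add_(torch.where(m, step, torch.zeros_like(step)))
    exp_avg.copy_(torch.where(m, new_avg, exp_avg))
    exp_avg_sq.copy_(torch.where(m, new_avg_sq, exp_avg_sq))

Hidden rows keep both their parameter values and their Adam moments, which is exactly what the kernel's early-exit branch achieves.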
@@ -0,0 +1 @@
from .selective_adam import SelectiveAdam
@@ -0,0 +1,88 @@
import torch

from ..cuda._wrapper import selective_adam_update


class SelectiveAdam(torch.optim.Adam):
    """
    A custom optimizer that extends the standard Adam optimizer by
    incorporating selective updates.

    This class is useful for situations where only a subset of parameters
    should be updated at each step, such as in sparse models or in cases where
    parameter visibility is controlled by an external mask.

    Additionally, the operations are fused into a single kernel. This optimizer
    leverages the `selective_adam_update` function from a CUDA backend for
    optimized sparse updates.

    This is one of the two optimizers mentioned in the Taming3DGS paper.

    Args:
        params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
        eps (float): Term added to the denominator to improve numerical stability (default: 1e-8).
        betas (Tuple[float, float]): Coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999)).

    Examples:

        >>> N = 100
        >>> param = torch.randn(N, device="cuda", requires_grad=True)
        >>> optimizer = SelectiveAdam([param], eps=1e-8, betas=(0.9, 0.999))
        >>> visibility_mask = torch.cat([torch.ones(50), torch.zeros(50)]).bool().cuda()  # Update only the first half.
        >>> # Forward pass
        >>> loss = torch.sum(param ** 2)
        >>> # Backward pass
        >>> loss.backward()
        >>> # Optimization step with selective updates
        >>> optimizer.step(visibility=visibility_mask)
    """

    def __init__(self, params, eps, betas):
        super().__init__(params=params, eps=eps, betas=betas)

    @torch.no_grad()
    def step(self, visibility):
        # One boolean per Gaussian; each parameter tensor is treated as
        # [N, M] with M = numel // N elements per Gaussian.
        N = visibility.numel()
        for group in self.param_groups:
            lr = group["lr"]
            eps = group["eps"]
            beta1, beta2 = group["betas"]

            assert len(group["params"]) == 1, "more than one tensor in group"
            param = group["params"][0]
            if param.grad is None:
                continue

            # Lazy state initialization
            state = self.state[param]
            if len(state) == 0:
                state["step"] = torch.tensor(0.0, dtype=torch.float32)
                state["exp_avg"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )
                state["exp_avg_sq"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )

            exp_avg = state["exp_avg"]
            exp_avg_sq = state["exp_avg_sq"]
            M = param.numel() // N

            # Fused CUDA update: Adam moments and parameter values are only
            # touched for Gaussians marked visible (no bias correction).
            selective_adam_update(
                param,
                param.grad,
                exp_avg,
                exp_avg_sq,
                visibility,
                lr,
                beta1,
                beta2,
                eps,
                N,
                M,
            )
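As a usage note, SelectiveAdam expects exactly one tensor per parameter group and a per-Gaussian boolean mask at every step; the mask length N must divide each parameter's numel. A short training-loop sketch under those rules (the gsplat.optimizers import path, the means tensor, and the random mask are illustrative assumptions, not taken from this commit):

import torch
from gsplat.optimizers import SelectiveAdam

# Hypothetical per-Gaussian parameter; shape [N, 3] -> N Gaussians, M = 3.
means = torch.randn(1000, 3, device="cuda", requires_grad=True)
optimizer = SelectiveAdam([means], eps=1e-8, betas=(0.9, 0.999))

for _ in range(10):
    optimizer.zero_grad()
    loss = (means ** 2).sum()  # stand-in for a rendering loss
    loss.backward()
    # Boolean mask with one entry per Gaussian; here a random placeholder
    # where a real pipeline would use the rasterizer's visibility output.
    visibility = torch.rand(means.shape[0], device="cuda") > 0.5
    optimizer.step(visibility=visibility)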