Skip to content

Commit

Permalink
Update forward & backward for rendered alpha image (#70)
Browse files: browse the repository at this point in the history
* finish alpha forward & backward

* black format

* fix some merging issues

* remove unnecessary .cuda()

* add return_alpha keyword

* add some notes

* black reformat
  • Loading branch information
Zhuoyang-Pan authored Jan 16, 2024
1 parent e8696bd commit 0c305ab
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/test_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
frames = []
for i in range(iterations):
optimizer.zero_grad()
slow_out = self.forward_slow()
slow_out, _ = self.forward_slow()

loss = mse_loss(slow_out, self.gt_image)
loss.backward()
Expand All @@ -168,7 +168,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
]

optimizer.zero_grad()
new_out = self.forward_new()
new_out, _ = self.forward_new()
loss = mse_loss(new_out, self.gt_image)
loss.backward()

Expand Down
8 changes: 7 additions & 1 deletion gsplat/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ __global__ void nd_rasterize_backward_kernel(
const float* __restrict__ final_Ts,
const int* __restrict__ final_index,
const float* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float3* __restrict__ v_conic,
float* __restrict__ v_rgb,
Expand Down Expand Up @@ -45,6 +46,7 @@ __global__ void nd_rasterize_backward_kernel(
int2 range = tile_bins[tile_id];
// df/d_out for this pixel
const float *v_out = &(v_output[channels * pix_id]);
const float v_out_alpha = v_output_alpha[pix_id];
// this is the T AFTER the last gaussian in this pixel
float T_final = final_Ts[pix_id];
float T = T_final;
Expand Down Expand Up @@ -97,7 +99,7 @@ __global__ void nd_rasterize_backward_kernel(
// update the running sum
S[c] += rgbs[channels * g + c] * fac;
}

v_alpha += T_final * ra * v_out_alpha;
// update v_opacity for this gaussian
atomicAdd(&(v_opacity[g]), vis * v_alpha);

Expand Down Expand Up @@ -146,6 +148,7 @@ __global__ void rasterize_backward_kernel(
const float* __restrict__ final_Ts,
const int* __restrict__ final_index,
const float3* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float3* __restrict__ v_conic,
float3* __restrict__ v_rgb,
Expand Down Expand Up @@ -188,6 +191,7 @@ __global__ void rasterize_backward_kernel(

// df/d_out for this pixel
const float3 v_out = v_output[pix_id];
const float v_out_alpha = v_output_alpha[pix_id];

// collect and process batches of gaussians
// each thread loads one gaussian at a time before rasterizing
Expand Down Expand Up @@ -265,6 +269,8 @@ __global__ void rasterize_backward_kernel(
v_alpha += (rgb.x * T - buffer.x * ra) * v_out.x;
v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y;
v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z;

v_alpha += T_final * ra * v_out_alpha;
// contribution from background pixel
v_alpha += -T_final * ra * background.x * v_out.x;
v_alpha += -T_final * ra * background.y * v_out.y;
Expand Down
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/backward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ __global__ void nd_rasterize_backward_kernel(
const float* __restrict__ final_Ts,
const int* __restrict__ final_index,
const float* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float3* __restrict__ v_conic,
float* __restrict__ v_rgb,
Expand All @@ -63,6 +64,7 @@ __global__ void rasterize_backward_kernel(
const float* __restrict__ final_Ts,
const int* __restrict__ final_index,
const float3* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float3* __restrict__ v_conic,
float3* __restrict__ v_rgb,
Expand Down
8 changes: 6 additions & 2 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ std::
const torch::Tensor &background,
const torch::Tensor &final_Ts,
const torch::Tensor &final_idx,
const torch::Tensor &v_output // dL_dout_color
const torch::Tensor &v_output, // dL_dout_color
const torch::Tensor &v_output_alpha // dL_dout_alpha
) {

CHECK_INPUT(xys);
Expand Down Expand Up @@ -540,6 +541,7 @@ std::
final_Ts.contiguous().data_ptr<float>(),
final_idx.contiguous().data_ptr<int>(),
v_output.contiguous().data_ptr<float>(),
v_output_alpha.contiguous().data_ptr<float>(),
(float2 *)v_xy.contiguous().data_ptr<float>(),
(float3 *)v_conic.contiguous().data_ptr<float>(),
v_colors.contiguous().data_ptr<float>(),
Expand Down Expand Up @@ -569,7 +571,8 @@ std::
const torch::Tensor &background,
const torch::Tensor &final_Ts,
const torch::Tensor &final_idx,
const torch::Tensor &v_output // dL_dout_color
const torch::Tensor &v_output, // dL_dout_color
const torch::Tensor &v_output_alpha // dL_dout_alpha
) {

CHECK_INPUT(xys);
Expand Down Expand Up @@ -612,6 +615,7 @@ std::
final_Ts.contiguous().data_ptr<float>(),
final_idx.contiguous().data_ptr<int>(),
(float3 *)v_output.contiguous().data_ptr<float>(),
v_output_alpha.contiguous().data_ptr<float>(),
(float2 *)v_xy.contiguous().data_ptr<float>(),
(float3 *)v_conic.contiguous().data_ptr<float>(),
(float3 *)v_colors.contiguous().data_ptr<float>(),
Expand Down
6 changes: 4 additions & 2 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ std::
const torch::Tensor &background,
const torch::Tensor &final_Ts,
const torch::Tensor &final_idx,
const torch::Tensor &v_output // dL_dout_color
const torch::Tensor &v_output, // dL_dout_color
const torch::Tensor &v_output_alpha
);

std::
Expand All @@ -179,5 +180,6 @@ std::
const torch::Tensor &background,
const torch::Tensor &final_Ts,
const torch::Tensor &final_idx,
const torch::Tensor &v_output // dL_dout_color
const torch::Tensor &v_output, // dL_dout_color
const torch::Tensor &v_output_alpha
);
18 changes: 16 additions & 2 deletions gsplat/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def rasterize_gaussians(
img_height: int,
img_width: int,
background: Optional[Float[Tensor, "channels"]] = None,
return_alpha: Optional[bool] = False,
) -> Tensor:
"""Rasterizes 2D gaussians by sorting and binning gaussian intersections for each tile and returns an N-dimensional output using alpha-compositing.
Expand All @@ -39,11 +40,13 @@ def rasterize_gaussians(
img_height (int): height of the rendered image.
img_width (int): width of the rendered image.
background (Tensor): background color
return_alpha (bool): whether to return alpha channel
Returns:
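A Tensor, or a tuple ``(out_img, out_alpha)`` when ``return_alpha`` is True:
- **out_img** (Tensor): N-dimensional rendered output image.
- **out_alpha** (Optional[Tensor]): Alpha channel of the rendered output image; only returned when ``return_alpha`` is True.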
"""
if colors.dtype == torch.uint8:
# make sure colors are float [0,1]
Expand Down Expand Up @@ -75,6 +78,7 @@ def rasterize_gaussians(
img_height,
img_width,
background.contiguous(),
return_alpha,
)


Expand All @@ -94,6 +98,7 @@ def forward(
img_height: int,
img_width: int,
background: Optional[Float[Tensor, "channels"]] = None,
return_alpha: Optional[bool] = False,
) -> Tensor:
num_points = xys.size(0)
BLOCK_X, BLOCK_Y = 16, 16
Expand Down Expand Up @@ -148,13 +153,20 @@ def forward(
final_idx,
)

return out_img
if return_alpha:
out_alpha = 1 - final_Ts
return out_img, out_alpha
else:
return out_img

@staticmethod
def backward(ctx, v_out_img):
def backward(ctx, v_out_img, v_out_alpha=None):
img_height = ctx.img_height
img_width = ctx.img_width

if v_out_alpha is None:
v_out_alpha = torch.zeros_like(v_out_img[..., 0])

(
gaussian_ids_sorted,
tile_bins,
Expand Down Expand Up @@ -184,6 +196,7 @@ def backward(ctx, v_out_img):
final_Ts,
final_idx,
v_out_img,
v_out_alpha,
)

return (
Expand All @@ -197,4 +210,5 @@ def backward(ctx, v_out_img):
None, # img_height
None, # img_width
None, # background
None, # return_alpha
)

0 comments on commit 0c305ab

Please sign in to comment.