Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement normal consistency loss. #273

Open
wants to merge 77 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
d31aa2b
use inria cuda to train
jefequien Jul 2, 2024
6a58dc3
canvas list
jefequien Jul 2, 2024
1b6e094
2dgs
jefequien Jul 3, 2024
55b39bb
lambda_dist
jefequien Jul 3, 2024
f4ea0e0
clean up
jefequien Jul 3, 2024
87748f4
format
jefequien Jul 3, 2024
7ee6c20
3m garden first
jefequien Jul 3, 2024
143af51
2dgs_mcmc_sfm
jefequien Jul 4, 2024
6670314
train 3dgs without normal loss
jefequien Jul 4, 2024
286867b
Merge branch 'main' into jeff/inria
jefequien Jul 4, 2024
a8fd7a9
3dgs normal working
jefequien Jul 4, 2024
5870b25
baseline no gradient
jefequien Jul 5, 2024
a8adfe5
normal backprob
jefequien Jul 5, 2024
08637cb
cleanup
jefequien Jul 5, 2024
ea6ce1a
cleanup
jefequien Jul 5, 2024
f7f0b9c
cleanup
jefequien Jul 5, 2024
468c451
cmocean dense
jefequien Jul 5, 2024
8cfa0fc
cmo ice
jefequien Jul 5, 2024
20cd888
voltage
jefequien Jul 5, 2024
503ac93
edit bash
jefequien Jul 5, 2024
23f9984
clean up benchmark script
jefequien Jul 5, 2024
02e00fb
remove dist_loss
jefequien Jul 5, 2024
96e8d23
depth must be last
jefequien Jul 5, 2024
f336b95
colors normal depth
jefequien Jul 5, 2024
54f137e
reduce diff
jefequien Jul 8, 2024
50749de
cleanup
jefequien Jul 8, 2024
bcac804
remove distloss
jefequien Jul 8, 2024
5752261
benchmark script
jefequien Jul 8, 2024
dadaf75
refactor
jefequien Jul 8, 2024
6800602
compile bug
jefequien Jul 8, 2024
98573e6
Merge branch 'main' into jeff/normal_consistency
jefequien Jul 9, 2024
6d156d2
remove benchmark script
jefequien Jul 9, 2024
9c06a24
support packed and sparse
jefequien Jul 15, 2024
7294be4
refactor
jefequien Jul 15, 2024
0a9ce04
bugfix
jefequien Jul 15, 2024
1bd9e91
utils/
jefequien Jul 15, 2024
d7a2916
rasterization backend
jefequien Jul 15, 2024
0274d1b
v_rotmat
jefequien Jul 15, 2024
6840727
render traj
jefequien Jul 15, 2024
5fdc934
point inward during ellipse
jefequien Jul 15, 2024
e3dc5e5
add tests for fwd and bwd pass
jefequien Jul 16, 2024
467ffe6
all but one test passing
jefequien Jul 16, 2024
8e158a2
weird bug
jefequien Jul 16, 2024
e5c9b64
merge
jefequien Jul 16, 2024
08b4631
test passes but test suite does not pass
jefequien Jul 16, 2024
c7e2fd4
count radii
jefequien Jul 16, 2024
ee40bb2
change sel to pass tests
jefequien Jul 16, 2024
6f21cb7
__sel
jefequien Jul 16, 2024
d2f89ab
cleanup
jefequien Jul 16, 2024
12723b1
simplify utils
jefequien Jul 17, 2024
433b313
merge
jefequien Jul 19, 2024
fadb7b4
util
jefequien Jul 19, 2024
f40b879
test rasterization
jefequien Jul 19, 2024
d766117
benchmark script
jefequien Jul 19, 2024
8e8d681
__init__
jefequien Jul 19, 2024
90c62af
merge
jefequien Aug 29, 2024
24eb1b5
fix normal consistency
jefequien Aug 31, 2024
a86ef37
ellipse
jefequien Aug 31, 2024
f0b93f0
cleanup
jefequien Aug 31, 2024
a014c89
uncomment
jefequien Aug 31, 2024
35b21c4
canvas list
jefequien Aug 31, 2024
9aaf257
Merge branch 'main' into jeff/normal_consistency
jefequien Aug 31, 2024
1afd41a
merge with traj
jefequien Aug 31, 2024
f072f17
cleanup
jefequien Aug 31, 2024
1fa189c
fix tests
jefequien Aug 31, 2024
b9ef876
remove 2dgs inria
jefequien Aug 31, 2024
c08a23e
script
jefequien Sep 1, 2024
1535db1
merge
jefequien Sep 3, 2024
6346132
merge
jefequien Sep 13, 2024
bfd78bc
fix merge
jefequien Sep 13, 2024
0a1dc09
fix utils
jefequien Sep 13, 2024
5b5a7c3
reduce diff test_basic
jefequien Sep 13, 2024
9c4186a
tests not passing
jefequien Sep 13, 2024
0c5e3ed
all tests passed
jefequien Sep 13, 2024
541990a
merge
jefequien Sep 22, 2024
38f2532
summarize stats
jefequien Sep 24, 2024
72d38d6
merge
jefequien Sep 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 125 additions & 47 deletions examples/simple_trainer_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,20 @@
set_random_seed,
)
from gsplat import quat_scale_to_covar_preci
from gsplat.rendering import rasterization
from gsplat.rendering import (
rasterization,
rasterization_2dgs_inria_wrapper,
rasterization_inria_wrapper,
)
from gsplat.relocation import compute_relocation
from gsplat.cuda_legacy._torch_impl import scale_rot_to_cov3d
from gsplat.normal_utils import depth_to_normal
from simple_trainer import create_splats_with_optimizers


@dataclass
class Config:
# Model type can be 3dgs, 3dgs_inria, or 2dgs_inria
model_type: str = "3dgs"
# Disable viewer
disable_viewer: bool = False
# Path to the .pt file. If provide, it will skip training and render a video
Expand Down Expand Up @@ -139,6 +145,13 @@ class Config:
# Weight for depth loss
depth_lambda: float = 1e-2

# Enable normal consistency loss. (experimental)
normal_consistency_loss: bool = False
# Weight for normal consistency loss
normal_consistency_lambda: float = 0.05
# Start applying normal consistency loss after this iteration
normal_consistency_start_iter: int = 7000

# Dump information to tensorboard every this steps
tb_every: int = 100
# Save training images to tensorboard
Expand All @@ -162,6 +175,18 @@ def __init__(self, cfg: Config) -> None:

self.cfg = cfg
self.device = "cuda"
if cfg.model_type == "3dgs":
self.rasterization_fn = rasterization
elif cfg.model_type == "3dgs_inria":
self.rasterization_fn = rasterization_inria_wrapper
elif cfg.model_type == "2dgs_inria":
self.rasterization_fn = rasterization_2dgs_inria_wrapper
else:
raise ValueError(f"Unsupported model type: {cfg.model_type}")

self.render_mode = "RGB"
if cfg.depth_loss or cfg.normal_consistency_loss:
self.render_mode = "RGB+ED"

# Where to dump results.
os.makedirs(cfg.result_dir, exist_ok=True)
Expand Down Expand Up @@ -273,8 +298,6 @@ def rasterize_splats(
**kwargs,
) -> Tuple[Tensor, Tensor, Dict]:
means = self.splats["means3d"] # [N, 3]
# quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
# rasterization does normalization internally
quats = self.splats["quats"] # [N, 4]
scales = torch.exp(self.splats["scales"]) # [N, 3]
opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
Expand All @@ -293,7 +316,8 @@ def rasterize_splats(
colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3]

rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
render_colors, render_alphas, info = rasterization(

render_colors, render_alphas, info = self.rasterization_fn(
means=means,
quats=quats,
scales=scales,
Expand Down Expand Up @@ -394,26 +418,28 @@ def train(self):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
render_mode=self.render_mode,
)
if renders.shape[-1] == 4:
colors, depths = renders[..., 0:3], renders[..., 3:4]
else:
colors, depths = renders, None
colors = renders[..., :3]

if cfg.random_bkgd:
bkgd = torch.rand(1, 3, device=device)
colors = colors + bkgd * (1.0 - alphas)

info["means2d"].retain_grad() # used for running stats

# loss
l1loss = F.l1_loss(colors, pixels)
ssimloss = 1.0 - self.ssim(
pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2)
)
loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
loss += (
cfg.opacity_reg
* torch.abs(torch.sigmoid(self.splats["opacities"])).mean()
)
loss += cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean()

if cfg.depth_loss:
depths = renders[..., -1:]
# query depths from depth map
points = torch.stack(
[
Expand All @@ -433,15 +459,22 @@ def train(self):
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
loss += depthloss * cfg.depth_lambda

loss = (
loss
+ cfg.opacity_reg
* torch.abs(torch.sigmoid(self.splats["opacities"])).mean()
)
loss = (
loss
+ cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean()
)
if cfg.normal_consistency_loss:
depths = renders[..., -1:]
normals = renders[..., -4:-1]
normals_surf = depth_to_normal(
depths,
camtoworlds,
Ks,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
)
normals_surf = normals_surf * (alphas).detach()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to multiply with alphas? Isn’t normals_surf already normalized here: https://github.com/jefequien/gsplat-mcmc/blob/6d156d2d447d9557e432b7f4b94fa6b3b9527187/gsplat/normal_utils.py#L41?

Copy link
Contributor Author

@jefequien jefequien Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can either normalize both the rendered normals and surface normals or use the alpha-composited rendered normals and multiply the surface normals by alpha. The latter results in normals that look better visually.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification. I was using the first one in GOF.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latter one should be more robust when dealing with semi-transparent pixels (accumulate_alpha < 1).

normalconsistencyloss = (
1 - (normals * normals_surf).sum(dim=-1)
).mean()
if step > cfg.normal_consistency_start_iter:
loss += normalconsistencyloss * cfg.normal_consistency_lambda

loss.backward()

Expand All @@ -465,6 +498,12 @@ def train(self):
self.writer.add_scalar("train/mem", mem, step)
if cfg.depth_loss:
self.writer.add_scalar("train/depthloss", depthloss.item(), step)
if cfg.normal_consistency_loss:
self.writer.add_scalar(
"train/normalconsistencyloss",
normalconsistencyloss.item(),
step,
)
if cfg.tb_save_image:
canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
canvas = canvas.reshape(-1, *canvas.shape[2:])
Expand Down Expand Up @@ -681,24 +720,43 @@ def eval(self, step: int):

torch.cuda.synchronize()
tic = time.time()
colors, _, _ = self.rasterize_splats(
renders, alphas, info = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
) # [1, H, W, 3]
colors = torch.clamp(colors, 0.0, 1.0)
render_mode=self.render_mode,
) # [1, H, W, C]
torch.cuda.synchronize()
ellipse_time += time.time() - tic

colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)
canvas_list = [pixels, colors]
if cfg.depth_loss:
depths = renders[..., -1:]
depths = (depths - depths.min()) / (depths.max() - depths.min())
canvas_list.append(depths)
if cfg.normal_consistency_loss:
depths = renders[..., -1:]
normals = renders[..., -4:-1]
normals_surf = depth_to_normal(
depths,
camtoworlds,
Ks,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
)
normals_surf = normals_surf * (alphas).detach()
canvas_list.extend([normals * 0.5 + 0.5])
canvas_list.extend([normals_surf * 0.5 + 0.5])

# write images
canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy()
imageio.imwrite(
f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8)
)
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
canvas = (canvas * 255).astype(np.uint8)
imageio.imwrite(f"{self.render_dir}/val_step{step:04d}_{i:04d}.png", canvas)

pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W]
colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W]
Expand Down Expand Up @@ -738,41 +796,61 @@ def render_traj(self, step: int):
cfg = self.cfg
device = self.device

camtoworlds = self.parser.camtoworlds[5:-5]
camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4]
camtoworlds = np.concatenate(
camtoworlds_all = self.parser.camtoworlds[5:-5]
camtoworlds_all = generate_interpolated_path(camtoworlds_all, 1) # [N, 3, 4]
camtoworlds_all = np.concatenate(
[
camtoworlds,
np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0),
camtoworlds_all,
np.repeat(
np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
),
],
axis=1,
) # [N, 4, 4]

camtoworlds = torch.from_numpy(camtoworlds).float().to(device)
camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
width, height = list(self.parser.imsize_dict.values())[0]

canvas_all = []
for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"):
renders, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds[i : i + 1],
Ks=K[None],
for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
camtoworlds = camtoworlds_all[i : i + 1]
Ks = K[None]

renders, alphas, info = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
render_mode="RGB+ED",
) # [1, H, W, 4]
colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
depths = renders[0, ..., 3:4] # [H, W, 1]
depths = (depths - depths.min()) / (depths.max() - depths.min())
render_mode=self.render_mode,
) # [1, H, W, C]

colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)
canvas_list = [colors]
if cfg.depth_loss:
depths = renders[..., -1:]
depths = (depths - depths.min()) / (depths.max() - depths.min())
canvas_list.append(depths)
if cfg.normal_consistency_loss:
depths = renders[..., -1:]
normals = renders[..., -4:-1]
normals_surf = depth_to_normal(
depths,
camtoworlds,
Ks,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
)
normals_surf = normals_surf * (alphas).detach()
canvas_list.extend([normals * 0.5 + 0.5])
canvas_list.extend([normals_surf * 0.5 + 0.5])

# write images
canvas = torch.cat(
[colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1
)
canvas = (canvas.cpu().numpy() * 255).astype(np.uint8)
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
canvas = (canvas * 255).astype(np.uint8)
canvas_all.append(canvas)

# save to video
Expand Down
9 changes: 6 additions & 3 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def forward(
calc_compensations: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# "covars" and {"quats", "scales"} are mutually exclusive
radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
radii, means2d, depths, normals, conics, compensations = _make_lazy_cuda_func(
"fully_fused_projection_fwd"
)(
means,
Expand All @@ -735,10 +735,12 @@ def forward(
ctx.height = height
ctx.eps2d = eps2d

return radii, means2d, depths, conics, compensations
return radii, means2d, depths, normals, conics, compensations

@staticmethod
def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
def backward(
ctx, v_radii, v_means2d, v_depths, v_normals, v_conics, v_compensations
):
(
means,
covars,
Expand Down Expand Up @@ -772,6 +774,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
compensations,
v_means2d.contiguous(),
v_depths.contiguous(),
v_normals.contiguous(),
v_conics.contiguous(),
v_compensations,
ctx.needs_input_grad[4], # viewmats_requires_grad
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ world_to_cam_bwd_tensor(const torch::Tensor &means, // [N, 3]
const bool means_requires_grad, const bool covars_requires_grad,
const bool viewmats_requires_grad);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_fwd_tensor(
const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6] optional
Expand Down Expand Up @@ -94,6 +94,7 @@ fully_fused_projection_bwd_tensor(
// grad outputs
const torch::Tensor &v_means2d, // [C, N, 2]
const torch::Tensor &v_depths, // [C, N]
const torch::Tensor &v_normals, // [C, N, 3]
const torch::Tensor &v_conics, // [C, N, 3]
const at::optional<torch::Tensor> &v_compensations, // [C, N] optional
const bool viewmats_requires_grad);
Expand Down
12 changes: 11 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ __global__ void fully_fused_projection_bwd_kernel(
// grad outputs
const T *__restrict__ v_means2d, // [C, N, 2]
const T *__restrict__ v_depths, // [C, N]
const T *__restrict__ v_normals, // [C, N, 3]
const T *__restrict__ v_conics, // [C, N, 3]
const T *__restrict__ v_compensations, // [C, N] optional
// grad inputs
Expand All @@ -59,6 +60,7 @@ __global__ void fully_fused_projection_bwd_kernel(

v_means2d += idx * 2;
v_depths += idx;
v_normals += idx * 3;
v_conics += idx * 3;

// vjp: compute the inverse of the 2d covariance
Expand Down Expand Up @@ -152,6 +154,12 @@ __global__ void fully_fused_projection_bwd_kernel(
vec4<T> v_quat(0.f);
vec3<T> v_scale(0.f);
quat_scale_to_covar_vjp<T>(quat, scale, rotmat, v_covar, v_quat, v_scale);

// add contribution from v_normals. Please check if this is correct.
mat3<T> v_R = quat_to_rotmat<T>(quat);
v_R[2] += glm::make_vec3(v_normals);
quat_to_rotmat_vjp<T>(quat, v_R, v_quat);

warpSum(v_quat, warp_group_g);
warpSum(v_scale, warp_group_g);
if (warp_group_g.thread_rank() == 0) {
Expand Down Expand Up @@ -201,6 +209,7 @@ fully_fused_projection_bwd_tensor(
// grad outputs
const torch::Tensor &v_means2d, // [C, N, 2]
const torch::Tensor &v_depths, // [C, N]
const torch::Tensor &v_normals, // [C, N, 3]
const torch::Tensor &v_conics, // [C, N, 3]
const at::optional<torch::Tensor> &v_compensations, // [C, N] optional
const bool viewmats_requires_grad) {
Expand All @@ -219,6 +228,7 @@ fully_fused_projection_bwd_tensor(
CHECK_INPUT(conics);
CHECK_INPUT(v_means2d);
CHECK_INPUT(v_depths);
CHECK_INPUT(v_normals);
CHECK_INPUT(v_conics);
if (compensations.has_value()) {
CHECK_INPUT(compensations.value());
Expand Down Expand Up @@ -255,7 +265,7 @@ fully_fused_projection_bwd_tensor(
compensations.has_value() ? compensations.value().data_ptr<float>()
: nullptr,
v_means2d.data_ptr<float>(), v_depths.data_ptr<float>(),
v_conics.data_ptr<float>(),
v_normals.data_ptr<float>(), v_conics.data_ptr<float>(),
v_compensations.has_value() ? v_compensations.value().data_ptr<float>()
: nullptr,
v_means.data_ptr<float>(),
Expand Down
Loading
Loading