Skip to content

Commit

Permalink
latest
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Oct 26, 2024
1 parent ed830d7 commit d874ef1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
2 changes: 2 additions & 0 deletions examples/benchmarks/mcmc_deblur.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--blur_opt \
--blur_opt_lr 1e-3 \
--blur_mean_reg 0.005 \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE
Expand Down
21 changes: 16 additions & 5 deletions examples/blur_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from kornia.filters import median_blur
from examples.mlp import create_mlp
from gsplat.utils import log_transform

Expand Down Expand Up @@ -65,17 +66,27 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
depths_emb = self.depths_encoder.encode(log_transform(depths))
images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1)
mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1)
mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1]))
mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])).reshape(
depths.shape
)
blur_mask = torch.sigmoid(mlp_out)
blur_mask = blur_mask.reshape(depths.shape)
return blur_mask

def mask_variation_loss(self, blur_mask: Tensor, eps: float = 1e-2):
"""Mask variation loss."""
def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
"""Mask mean loss."""
x = blur_mask.mean()
meanloss = (1 / (1 - x + eps) - 1) + (0.1 / (x + eps))
a = 0.9
meanloss = a * (1 / (1 - x + eps) - 1) + (1 - a) * (1 / (x + eps) - 1)
return meanloss

def mask_smoothness_loss(self, blur_mask: Tensor):
"""Mask smoothness loss."""
blurred_xy = median_blur(blur_mask.permute(0, 3, 1, 2), (5, 5)).permute(
0, 2, 3, 1
)
smoothloss = F.huber_loss(blur_mask, blurred_xy)
return smoothloss


def get_encoder(num_freqs: int, input_dims: int):
kwargs = {
Expand Down
31 changes: 21 additions & 10 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ class Config:
blur_opt: bool = False
# Learning rate for blur optimization
blur_opt_lr: float = 1e-3
# Regularization for blur mask
blur_mask_reg: float = 0.001
# Regularization for blur mask mean
blur_mean_reg: float = 0.001
# Regularization for blur mask smoothness
blur_smoothness_reg: float = 0.0

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
Expand Down Expand Up @@ -661,12 +663,15 @@ def train(self):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB",
render_mode="RGB+ED",
masks=masks,
blur=True,
)
colors_blur = renders_blur[..., 0:3]
blur_mask = self.blur_module.predict_mask(image_ids, depths)
colors_blur, depths_blur = (
renders_blur[..., 0:3],
renders_blur[..., 3:4],
)
blur_mask = self.blur_module.predict_mask(image_ids, depths_blur)
colors = (1 - blur_mask) * colors + blur_mask * colors_blur

self.cfg.strategy.step_pre_backward(
Expand Down Expand Up @@ -706,7 +711,10 @@ def train(self):
tvloss = 10 * total_variation_loss(self.bil_grids.grids)
loss += tvloss
if cfg.blur_opt:
loss += cfg.blur_mask_reg * self.blur_module.mask_variation_loss(
loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss(
blur_mask, step
)
loss += cfg.blur_smoothness_reg * self.blur_module.mask_smoothness_loss(
blur_mask
)

Expand Down Expand Up @@ -932,8 +940,6 @@ def eval(self, step: int, stage: str = "val"):
colors = torch.clamp(colors, 0.0, 1.0)
canvas_list = [pixels, colors]
if self.cfg.blur_opt and stage == "train":
blur_mask = self.blur_module.predict_mask(image_ids, depths)
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
renders_blur, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
Expand All @@ -943,11 +949,16 @@ def eval(self, step: int, stage: str = "val"):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB",
render_mode="RGB+ED",
masks=masks,
blur=True,
)
colors_blur = renders_blur[..., 0:3]
colors_blur, depths_blur = (
renders_blur[..., 0:3],
renders_blur[..., 3:4],
)
blur_mask = self.blur_module.predict_mask(image_ids, depths_blur)
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0))
colors = (1 - blur_mask) * colors + blur_mask * colors_blur
colors = torch.clamp(colors, 0.0, 1.0)
Expand Down

0 comments on commit d874ef1

Please sign in to comment.