Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Oct 29, 2024
1 parent a3d7d15 commit 10bac1d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
13 changes: 6 additions & 7 deletions examples/blur_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,19 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
blur_mask = torch.sigmoid(mlp_out)
return blur_mask

def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
"""Loss function for regularizing the blur mask by controlling its mean.
The loss function is designed to diverge to +infinity at 0 and 1. This
prevents the mask from collapsing to predicting all 0s or 1s. It is also
bias towards 0 to encourage sparsity. During warmup, we set this bias even
higher to start with a sparse and not collapsed blur mask."""
The loss function diverges to +infinity at 0 and 1. This prevents the mask
from collapsing to all 0s or 1s. It is also biased towards 0 to encourage
sparsity. During warmup, the bias is even higher to start with a sparse mask."""
x = blur_mask.mean()
if step <= self.num_warmup_steps:
a = 2
else:
a = 1
meanloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1)
return meanloss
maskloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1)
return maskloss


def get_encoder(num_freqs: int, input_dims: int):
Expand Down
11 changes: 5 additions & 6 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Config:
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

Expand Down Expand Up @@ -152,8 +152,8 @@ class Config:
blur_opt: bool = False
# Learning rate for blur optimization
blur_opt_lr: float = 1e-3
# Regularization for blur mask mean
blur_mean_reg: float = 0.002
# Regularization for blur mask
blur_mask_reg: float = 0.002

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
Expand Down Expand Up @@ -705,9 +705,7 @@ def train(self):
tvloss = 10 * total_variation_loss(self.bil_grids.grids)
loss += tvloss
if cfg.blur_opt:
loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss(
blur_mask, step
)
loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step)

# regularizations
if cfg.opacity_reg > 0.0:
Expand All @@ -721,6 +719,7 @@ def train(self):
loss
+ cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean()
)

loss.backward()

desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
Expand Down

0 comments on commit 10bac1d

Please sign in to comment.