diff --git a/examples/blur_opt.py b/examples/blur_opt.py index a84880df..68f270e0 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -73,20 +73,19 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): blur_mask = torch.sigmoid(mlp_out) return blur_mask - def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): + def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): """Loss function for regularizing the blur mask by controlling its mean. - The loss function is designed to diverge to +infinity at 0 and 1. This - prevents the mask from collapsing to predicting all 0s or 1s. It is also - bias towards 0 to encourage sparsity. During warmup, we set this bias even - higher to start with a sparse and not collapsed blur mask.""" + The loss function diverges to +infinity at 0 and 1. This prevents the mask + from collapsing to all 0s or 1s. It is also biased towards 0 to encourage + sparsity. During warmup, the bias is even higher to start with a sparse mask.""" x = blur_mask.mean() if step <= self.num_warmup_steps: a = 2 else: a = 1 - meanloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1) - return meanloss + maskloss = a * (1 / (1 - x + eps) - 1) + 0.1 * (1 / (x + eps) - 1) + return maskloss def get_encoder(num_freqs: int, input_dims: int): diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 52603c0f..0bed4d6a 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -83,7 +83,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -152,8 +152,8 @@ class Config: blur_opt: bool = False # Learning rate for blur optimization blur_opt_lr: float = 1e-3 - # Regularization for blur mask 
mean - blur_mean_reg: float = 0.002 + # Regularization for blur mask + blur_mask_reg: float = 0.002 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -705,9 +705,7 @@ def train(self): tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss if cfg.blur_opt: - loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss( - blur_mask, step - ) + loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step) # regularizations if cfg.opacity_reg > 0.0: @@ -721,6 +719,7 @@ def train(self): loss + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() ) + loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "