Skip to content

Commit

Permalink
docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Oct 29, 2024
1 parent 834e4e8 commit cb826c6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/benchmarks/mcmc_deblur.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ do
--result_dir $RESULT_DIR/$SCENE
done

# Summarize the stats
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val
6 changes: 5 additions & 1 deletion examples/blur_opt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch import Tensor
Expand All @@ -6,9 +7,12 @@
from gsplat.utils import log_transform


@dataclass
class BlurOptModule(nn.Module):
"""Blur optimization module."""

num_warmup_steps: int = 2000

def __init__(self, n: int, embed_dim: int = 4):
super().__init__()
self.embeds = torch.nn.Embedding(n, embed_dim)
Expand Down Expand Up @@ -74,7 +78,7 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
"""Mask mean loss."""
x = blur_mask.mean()
if step <= 2000:
if step <= self.num_warmup_steps:
a = 20
else:
a = 10
Expand Down

0 comments on commit cb826c6

Please sign in to comment.