Skip to content

Commit

Permalink
cleanup appearance embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
MrNeRF committed Sep 27, 2024
1 parent 1d3f786 commit 60426af
Showing 1 changed file with 2 additions and 54 deletions.
56 changes: 2 additions & 54 deletions examples/simple_trainer_scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Config:
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [2_000, 30_000])
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

# voxel size for Scaffold-GS
voxel_size = 0.001
Expand Down Expand Up @@ -134,15 +134,6 @@ class Config:
# Add noise to camera extrinsics. This is only to test the camera pose optimization.
pose_noise: float = 0.0

# Enable appearance optimization. (experimental)
app_opt: bool = False
# Appearance embedding dimension
app_embed_dim: int = 16
# Learning rate for appearance optimization
app_opt_lr: float = 1e-3
# Regularization for appearance optimization as weight decay
app_opt_reg: float = 1e-6

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
# Shape of the bilateral grid (X, Y, W)
Expand Down Expand Up @@ -349,7 +340,6 @@ def __init__(
print("Scene scale:", self.scene_scale)

# Model
feature_dim = 32 if cfg.app_opt else None
self.splats, self.optimizers = create_splats_with_optimizers(
self.parser,
init_extent=cfg.init_extent,
Expand All @@ -358,7 +348,6 @@ def __init__(
scene_scale=self.scene_scale,
sparse_grad=cfg.sparse_grad,
batch_size=cfg.batch_size,
feature_dim=feature_dim,
device=self.device,
world_rank=world_rank,
world_size=world_size,
Expand Down Expand Up @@ -400,29 +389,6 @@ def __init__(
if world_size > 1:
self.pose_perturb = DDP(self.pose_perturb)

self.app_optimizers = []
if cfg.app_opt:
assert feature_dim is not None
self.app_module = AppearanceOptModule(
len(self.trainset), feature_dim, cfg.app_embed_dim, None
).to(self.device)
# initialize the last layer to be zero so that the initial output is zero.
torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
self.app_optimizers = [
torch.optim.Adam(
self.app_module.embeds.parameters(),
lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
weight_decay=cfg.app_opt_reg,
),
torch.optim.Adam(
self.app_module.color_head.parameters(),
lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
),
]
if world_size > 1:
self.app_module = DDP(self.app_module)

self.bil_grid_optimizers = []
if cfg.use_bilateral_grid:
self.bil_grids = BilateralGrid(
Expand Down Expand Up @@ -622,17 +588,7 @@ def rasterize_splats(
rasterize_mode="antialiased" if self.cfg.antialiased else "classic",
)

image_ids = kwargs.pop("image_ids", None)
if self.cfg.app_opt:
colors = self.app_module(
features=self.splats["gauss_params"]["features"],
embed_ids=image_ids,
dirs=info["means"][None, :, :] - camtoworlds[:, None, :3, 3],
)
colors = colors + info["colors"]
colors = torch.sigmoid(colors)
else:
colors = info["colors"] # [N, K, 3]
colors = info["colors"] # [N, K, 3]

rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
render_colors, render_alphas, raster_info = rasterization(
Expand Down Expand Up @@ -907,11 +863,6 @@ def train(self):
data["pose_adjust"] = self.pose_adjust.module.state_dict()
else:
data["pose_adjust"] = self.pose_adjust.state_dict()
if cfg.app_opt:
if world_size > 1:
data["app_module"] = self.app_module.module.state_dict()
else:
data["app_module"] = self.app_module.state_dict()
torch.save(
data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
)
Expand Down Expand Up @@ -952,9 +903,6 @@ def train(self):
for optimizer in self.pose_optimizers:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
for optimizer in self.app_optimizers:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
for optimizer in self.bil_grid_optimizers:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
Expand Down

0 comments on commit 60426af

Please sign in to comment.