diff --git a/gsplat/strategy/base.py b/gsplat/strategy/base.py
index 5089b00b9..0d6e61f30 100644
--- a/gsplat/strategy/base.py
+++ b/gsplat/strategy/base.py
@@ -18,9 +18,12 @@ def check_sanity(
     optimizers: Dict[str, torch.optim.Optimizer],
 ):
     """Sanity check for the parameters and optimizers."""
-    assert set(params.keys()) == set(optimizers.keys()), (
-        "params and optimizers must have the same keys, "
-        f"but got {params.keys()} and {optimizers.keys()}"
+    trainable_params = set(
+        [name for name, param in params.items() if param.requires_grad]
+    )
+    assert trainable_params == set(optimizers.keys()), (
+        "trainable parameters and optimizers must have the same keys, "
+        f"but got {trainable_params} and {optimizers.keys()}"
     )
 
     for optimizer in optimizers.values():
diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py
index 0789dfcbc..83c90a25e 100644
--- a/gsplat/strategy/ops.py
+++ b/gsplat/strategy/ops.py
@@ -68,19 +68,25 @@ def _update_param_with_optimizer(
         names = list(params.keys())
 
     for name in names:
+        param = params[name]
+        new_param = param_fn(name, param)
+        params[name] = new_param
+        if name not in optimizers:
+            assert not param.requires_grad, (
+                f"Optimizer for {name} is not found, but the parameter is trainable."
+                f"Got requires_grad={param.requires_grad}"
+            )
+            continue
         optimizer = optimizers[name]
-        for i, param_group in enumerate(optimizer.param_groups):
-            p = param_group["params"][0]
-            p_state = optimizer.state[p]
-            del optimizer.state[p]
-            for key in p_state.keys():
+        for i in range(len(optimizer.param_groups)):
+            param_state = optimizer.state[param]
+            del optimizer.state[param]
+            for key in param_state.keys():
                 if key != "step":
-                    v = p_state[key]
-                    p_state[key] = optimizer_fn(key, v)
-            p_new = param_fn(name, p)
-            optimizer.param_groups[i]["params"] = [p_new]
-            optimizer.state[p_new] = p_state
-            params[name] = p_new
+                    v = param_state[key]
+                    param_state[key] = optimizer_fn(key, v)
+            optimizer.param_groups[i]["params"] = [new_param]
+            optimizer.state[new_param] = param_state
 
 
 @torch.no_grad()
@@ -101,7 +107,7 @@ def duplicate(
     sel = torch.where(mask)[0]
 
     def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(torch.cat([p, p[sel]]))
+        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)])
@@ -157,7 +163,7 @@ def param_fn(name: str, p: Tensor) -> Tensor:
         else:
             p_split = p[sel].repeat(repeats)
         p_new = torch.cat([p[rest], p_split])
-        p_new = torch.nn.Parameter(p_new)
+        p_new = torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
         return p_new
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
@@ -191,7 +197,7 @@ def remove(
     sel = torch.where(~mask)[0]
 
     def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(p[sel])
+        return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return v[sel]
@@ -222,7 +228,7 @@ def reset_opa(
     def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item())
-            return torch.nn.Parameter(opacities)
+            return torch.nn.Parameter(opacities, requires_grad=p.requires_grad)
         else:
             raise ValueError(f"Unexpected parameter name: {name}")
 
@@ -277,7 +283,7 @@ def param_fn(name: str, p: Tensor) -> Tensor:
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
         p[dead_indices] = p[sampled_idxs]
-        return torch.nn.Parameter(p)
+        return torch.nn.Parameter(p, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v[sampled_idxs] = 0
@@ -318,8 +324,8 @@ def param_fn(name: str, p: Tensor) -> Tensor:
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
-        p = torch.cat([p, p[sampled_idxs]])
-        return torch.nn.Parameter(p)
+        p_new = torch.cat([p, p[sampled_idxs]])
+        return torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
diff --git a/tests/test_strategy.py b/tests/test_strategy.py
index 5f432a9ba..031154152 100644
--- a/tests/test_strategy.py
+++ b/tests/test_strategy.py
@@ -62,5 +62,72 @@ def test_strategy():
     strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_strategy_requires_grad():
+    from gsplat.rendering import rasterization
+    from gsplat.strategy import DefaultStrategy, MCMCStrategy
+
+    def assert_consistent_sizes(params):
+        sizes = [v.shape[0] for v in params.values()]
+        assert all([s == sizes[0] for s in sizes])
+
+    torch.manual_seed(42)
+
+    # Prepare Gaussians
+    N = 100
+    params = torch.nn.ParameterDict(
+        {
+            "means": torch.randn(N, 3),
+            "scales": torch.rand(N, 3),
+            "quats": torch.randn(N, 4),
+            "opacities": torch.rand(N),
+            "colors": torch.rand(N, 3),
+            "non_trainable_features": torch.rand(N, 3),
+        }
+    ).to(device)
+    params["non_trainable_features"].requires_grad = False
+    requires_grad_map = {k: v.requires_grad for k, v in params.items()}
+    optimizers = {
+        k: torch.optim.Adam([v], lr=1e-3) for k, v in params.items() if v.requires_grad
+    }
+
+    # A dummy rendering call
+    render_colors, render_alphas, info = rasterization(
+        means=params["means"],
+        quats=params["quats"],  # F.normalize is fused into the kernel
+        scales=torch.exp(params["scales"]),
+        opacities=torch.sigmoid(params["opacities"]),
+        colors=params["colors"],
+        viewmats=torch.eye(4).unsqueeze(0).to(device),
+        Ks=torch.eye(3).unsqueeze(0).to(device),
+        width=10,
+        height=10,
+        packed=False,
+    )
+
+    # Test DefaultStrategy
+    strategy = DefaultStrategy(verbose=True)
+    strategy.check_sanity(params, optimizers)
+    state = strategy.initialize_state()
+    strategy.step_pre_backward(params, optimizers, state, step=600, info=info)
+    render_colors.mean().backward(retain_graph=True)
+    strategy.step_post_backward(params, optimizers, state, step=600, info=info)
+    for k, v in params.items():
+        assert v.requires_grad == requires_grad_map[k]
+    assert params["non_trainable_features"].grad is None
+    assert_consistent_sizes(params)
+    # Test MCMCStrategy
+    strategy = MCMCStrategy(verbose=True)
+    strategy.check_sanity(params, optimizers)
+    state = strategy.initialize_state()
+    render_colors.mean().backward(retain_graph=True)
+    strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
+    assert params["non_trainable_features"].grad is None
+    for k, v in params.items():
+        assert v.requires_grad == requires_grad_map[k]
+    assert_consistent_sizes(params)
+
+
 if __name__ == "__main__":
     test_strategy()
+    test_strategy_requires_grad()
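
Usage note (not part of the patch): a minimal sketch of how the new requires_grad handling is meant to be exercised, mirroring the test added above. It only uses the public API shown in this diff (DefaultStrategy, check_sanity); the parameter names, shapes, and the "non_trainable_features" key are illustrative, not prescribed by gsplat.

    import torch
    from gsplat.strategy import DefaultStrategy

    device = "cuda"  # the test above requires a CUDA device
    N = 100
    params = torch.nn.ParameterDict(
        {
            "means": torch.randn(N, 3),
            "scales": torch.rand(N, 3),
            "quats": torch.randn(N, 4),
            "opacities": torch.rand(N),
            "colors": torch.rand(N, 3),
            # Illustrative per-Gaussian attribute that should stay frozen.
            "non_trainable_features": torch.rand(N, 3),
        }
    ).to(device)
    # Freeze one parameter: it gets no optimizer, and the strategy ops now
    # preserve its requires_grad=False flag when densifying/pruning.
    params["non_trainable_features"].requires_grad = False

    # Optimizers are created only for trainable parameters ...
    optimizers = {
        k: torch.optim.Adam([v], lr=1e-3)
        for k, v in params.items()
        if v.requires_grad
    }

    # ... and check_sanity now compares optimizer keys against the set of
    # trainable parameters instead of all parameter keys.
    strategy = DefaultStrategy(verbose=True)
    strategy.check_sanity(params, optimizers)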