
Commit

add-support-for-non-trainable-params (#456)
* cleanup

* comments
nikmo33 authored Nov 7, 2024
1 parent ec3e715 commit 36ef2dd
Showing 3 changed files with 97 additions and 21 deletions.
9 changes: 6 additions & 3 deletions gsplat/strategy/base.py
@@ -18,9 +18,12 @@ def check_sanity(
         optimizers: Dict[str, torch.optim.Optimizer],
     ):
         """Sanity check for the parameters and optimizers."""
-        assert set(params.keys()) == set(optimizers.keys()), (
-            "params and optimizers must have the same keys, "
-            f"but got {params.keys()} and {optimizers.keys()}"
+        trainable_params = set(
+            [name for name, param in params.items() if param.requires_grad]
+        )
+        assert trainable_params == set(optimizers.keys()), (
+            "trainable parameters and optimizers must have the same keys, "
+            f"but got {trainable_params} and {optimizers.keys()}"
         )
 
         for optimizer in optimizers.values():
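
Note: check_sanity now matches the optimizer keys against only the trainable parameters, so a frozen parameter must not have an optimizer entry. A minimal sketch of the intended calling pattern (illustrative names, not part of this commit):

import torch

# Hypothetical parameter dict with one frozen entry.
params = torch.nn.ParameterDict({
    "means": torch.nn.Parameter(torch.randn(10, 3)),
    "features": torch.nn.Parameter(torch.rand(10, 3), requires_grad=False),
})

# Build optimizers only for the trainable parameters, matching the new assertion.
optimizers = {
    name: torch.optim.Adam([p], lr=1e-3)
    for name, p in params.items()
    if p.requires_grad
}

trainable = {name for name, p in params.items() if p.requires_grad}
assert trainable == set(optimizers.keys())  # mirrors the updated check_sanity
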
42 changes: 24 additions & 18 deletions gsplat/strategy/ops.py
@@ -68,19 +68,25 @@ def _update_param_with_optimizer(
         names = list(params.keys())
 
     for name in names:
+        param = params[name]
+        new_param = param_fn(name, param)
+        params[name] = new_param
+        if name not in optimizers:
+            assert not param.requires_grad, (
+                f"Optimizer for {name} is not found, but the parameter is trainable."
+                f"Got requires_grad={param.requires_grad}"
+            )
+            continue
         optimizer = optimizers[name]
-        for i, param_group in enumerate(optimizer.param_groups):
-            p = param_group["params"][0]
-            p_state = optimizer.state[p]
-            del optimizer.state[p]
-            for key in p_state.keys():
+        for i in range(len(optimizer.param_groups)):
+            param_state = optimizer.state[param]
+            del optimizer.state[param]
+            for key in param_state.keys():
                 if key != "step":
-                    v = p_state[key]
-                    p_state[key] = optimizer_fn(key, v)
-            p_new = param_fn(name, p)
-            optimizer.param_groups[i]["params"] = [p_new]
-            optimizer.state[p_new] = p_state
-            params[name] = p_new
+                    v = param_state[key]
+                    param_state[key] = optimizer_fn(key, v)
+            optimizer.param_groups[i]["params"] = [new_param]
+            optimizer.state[new_param] = param_state
 
 
 @torch.no_grad()
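
Note: the rewritten loop applies param_fn to every parameter, but only rebinds optimizer state for names that actually have an optimizer; a parameter without one must be frozen (the assert above) and is skipped. Below is a minimal, self-contained sketch, not part of this commit, of that state-rebinding pattern with a plain Adam optimizer, so the resize logic is visible in isolation:

import torch

# One trainable parameter and an Adam optimizer with accumulated state.
p = torch.nn.Parameter(torch.randn(5, 3))
opt = torch.optim.Adam([p], lr=1e-3)
p.sum().backward()
opt.step()  # populates opt.state[p] with "step", "exp_avg", "exp_avg_sq"

# Replace the parameter (here: duplicate one row) and carry the state over,
# padding the moment estimates to the new shape, as the updated loop does
# via param_fn / optimizer_fn.
with torch.no_grad():
    new_p = torch.nn.Parameter(torch.cat([p, p[:1]]), requires_grad=p.requires_grad)
    state = opt.state.pop(p)
    for key, v in state.items():
        if key != "step":
            state[key] = torch.cat([v, torch.zeros_like(v[:1])])
    opt.param_groups[0]["params"] = [new_p]
    opt.state[new_p] = state
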
@@ -101,7 +107,7 @@ def duplicate(
     sel = torch.where(mask)[0]
 
     def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(torch.cat([p, p[sel]]))
+        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)])
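
Note: this hunk and the remaining ones in this file make the same fix: torch.nn.Parameter defaults to requires_grad=True, so rebuilding a frozen parameter without forwarding the flag would silently make it trainable again. A short illustration (hypothetical tensor, not from the repository):

import torch

frozen = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)

# Default construction re-enables gradients on the rebuilt parameter...
assert torch.nn.Parameter(torch.cat([frozen, frozen[:1]])).requires_grad is True

# ...so the flag is forwarded explicitly, as in the updated param_fn above.
dup = torch.nn.Parameter(
    torch.cat([frozen, frozen[:1]]), requires_grad=frozen.requires_grad
)
assert dup.requires_grad is False
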
@@ -157,7 +163,7 @@ def param_fn(name: str, p: Tensor) -> Tensor:
         else:
             p_split = p[sel].repeat(repeats)
         p_new = torch.cat([p[rest], p_split])
-        p_new = torch.nn.Parameter(p_new)
+        p_new = torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
         return p_new
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
@@ -191,7 +197,7 @@ def remove(
     sel = torch.where(~mask)[0]
 
     def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(p[sel])
+        return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return v[sel]
@@ -222,7 +228,7 @@ def reset_opa(
     def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item())
-            return torch.nn.Parameter(opacities)
+            return torch.nn.Parameter(opacities, requires_grad=p.requires_grad)
         else:
             raise ValueError(f"Unexpected parameter name: {name}")
 
@@ -277,7 +283,7 @@ def param_fn(name: str, p: Tensor) -> Tensor:
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
         p[dead_indices] = p[sampled_idxs]
-        return torch.nn.Parameter(p)
+        return torch.nn.Parameter(p, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v[sampled_idxs] = 0
@@ -318,8 +324,8 @@ def param_fn(name: str, p: Tensor) -> Tensor:
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
-        p = torch.cat([p, p[sampled_idxs]])
-        return torch.nn.Parameter(p)
+        p_new = torch.cat([p, p[sampled_idxs]])
+        return torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)

67 changes: 67 additions & 0 deletions tests/test_strategy.py
@@ -62,5 +62,72 @@ def test_strategy():
     strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_strategy_requires_grad():
+    from gsplat.rendering import rasterization
+    from gsplat.strategy import DefaultStrategy, MCMCStrategy
+
+    def assert_consistent_sizes(params):
+        sizes = [v.shape[0] for v in params.values()]
+        assert all([s == sizes[0] for s in sizes])
+
+    torch.manual_seed(42)
+
+    # Prepare Gaussians
+    N = 100
+    params = torch.nn.ParameterDict(
+        {
+            "means": torch.randn(N, 3),
+            "scales": torch.rand(N, 3),
+            "quats": torch.randn(N, 4),
+            "opacities": torch.rand(N),
+            "colors": torch.rand(N, 3),
+            "non_trainable_features": torch.rand(N, 3),
+        }
+    ).to(device)
+    params["non_trainable_features"].requires_grad = False
+    requires_grad_map = {k: v.requires_grad for k, v in params.items()}
+    optimizers = {
+        k: torch.optim.Adam([v], lr=1e-3) for k, v in params.items() if v.requires_grad
+    }
+
+    # A dummy rendering call
+    render_colors, render_alphas, info = rasterization(
+        means=params["means"],
+        quats=params["quats"],  # F.normalize is fused into the kernel
+        scales=torch.exp(params["scales"]),
+        opacities=torch.sigmoid(params["opacities"]),
+        colors=params["colors"],
+        viewmats=torch.eye(4).unsqueeze(0).to(device),
+        Ks=torch.eye(3).unsqueeze(0).to(device),
+        width=10,
+        height=10,
+        packed=False,
+    )
+
+    # Test DefaultStrategy
+    strategy = DefaultStrategy(verbose=True)
+    strategy.check_sanity(params, optimizers)
+    state = strategy.initialize_state()
+    strategy.step_pre_backward(params, optimizers, state, step=600, info=info)
+    render_colors.mean().backward(retain_graph=True)
+    strategy.step_post_backward(params, optimizers, state, step=600, info=info)
+    for k, v in params.items():
+        assert v.requires_grad == requires_grad_map[k]
+    assert params["non_trainable_features"].grad is None
+    assert_consistent_sizes(params)
+    # Test MCMCStrategy
+    strategy = MCMCStrategy(verbose=True)
+    strategy.check_sanity(params, optimizers)
+    state = strategy.initialize_state()
+    render_colors.mean().backward(retain_graph=True)
+    strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
+    assert params["non_trainable_features"].grad is None
+    for k, v in params.items():
+        assert v.requires_grad == requires_grad_map[k]
+    assert_consistent_sizes(params)
+
+
 if __name__ == "__main__":
     test_strategy()
+    test_strategy_requires_grad()
