
Commit

cleanup
nikmo33 committed Oct 18, 2024
1 parent ec3e715 commit 79da6bf
Showing 3 changed files with 100 additions and 27 deletions.
gsplat/strategy/base.py (6 additions, 3 deletions)
@@ -18,9 +18,12 @@ def check_sanity(
         optimizers: Dict[str, torch.optim.Optimizer],
     ):
         """Sanity check for the parameters and optimizers."""
-        assert set(params.keys()) == set(optimizers.keys()), (
-            "params and optimizers must have the same keys, "
-            f"but got {params.keys()} and {optimizers.keys()}"
+        trainable_params = set(
+            [name for name, param in params.items() if param.requires_grad]
+        )
+        assert trainable_params == set(optimizers.keys()), (
+            "trainable parameters and optimizers must have the same keys, "
+            f"but got {trainable_params} and {optimizers.keys()}"
         )
 
         for optimizer in optimizers.values():
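A minimal sketch (not part of the commit; the parameter names are illustrative) of the relaxed contract introduced above: an optimizer is required for every trainable parameter, while frozen parameters may be left out of `optimizers` entirely.

    import torch

    # hypothetical parameters; only "means" is trainable
    params = torch.nn.ParameterDict(
        {
            "means": torch.randn(10, 3),
            "non_trainable_features": torch.rand(10, 3),
        }
    )
    params["non_trainable_features"].requires_grad = False
    optimizers = {"means": torch.optim.Adam([params["means"]], lr=1e-3)}

    # the check performed by the new code above
    trainable_params = set(
        name for name, param in params.items() if param.requires_grad
    )
    assert trainable_params == set(optimizers.keys())  # passes: the frozen entry needs no optimizer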
gsplat/strategy/ops.py (27 additions, 24 deletions)
Expand Up @@ -46,7 +46,7 @@ def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Te

@torch.no_grad()
def _update_param_with_optimizer(
param_fn: Callable[[str, Tensor], Tensor],
param_fn: Callable[[str, Tensor, bool], Tensor],
optimizer_fn: Callable[[str, Tensor], Tensor],
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
@@ -68,19 +68,22 @@ def _update_param_with_optimizer(
     names = list(params.keys())
 
     for name in names:
+        param = params[name]
+        new_param = param_fn(name, param, param.requires_grad)
+        params[name] = new_param
+        if name not in optimizers:
+            assert not param.requires_grad
+            continue
         optimizer = optimizers[name]
-        for i, param_group in enumerate(optimizer.param_groups):
-            p = param_group["params"][0]
-            p_state = optimizer.state[p]
-            del optimizer.state[p]
-            for key in p_state.keys():
+        for i in range(len(optimizer.param_groups)):
+            param_state = optimizer.state[param]
+            del optimizer.state[param]
+            for key in param_state.keys():
                 if key != "step":
-                    v = p_state[key]
-                    p_state[key] = optimizer_fn(key, v)
-            p_new = param_fn(name, p)
-            optimizer.param_groups[i]["params"] = [p_new]
-            optimizer.state[p_new] = p_state
-            params[name] = p_new
+                    v = param_state[key]
+                    param_state[key] = optimizer_fn(key, v)
+            optimizer.param_groups[i]["params"] = [new_param]
+            optimizer.state[new_param] = param_state
 
 
 @torch.no_grad()
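A hedged usage sketch of the updated helper (the `param_fn` and `optimizer_fn` below are illustrative, and the call assumes the helper's remaining arguments keep their defaults): `param_fn` now receives each parameter's `requires_grad` flag, frozen entries are rewritten but skipped for optimizer-state surgery, and an entry without an optimizer is only tolerated when it is frozen.

    import torch
    from torch import Tensor

    from gsplat.strategy.ops import _update_param_with_optimizer

    params = torch.nn.ParameterDict(
        {
            "means": torch.randn(10, 3),
            "non_trainable_features": torch.rand(10, 3),
        }
    )
    params["non_trainable_features"].requires_grad = False
    optimizers = {"means": torch.optim.Adam([params["means"]], lr=1e-3)}

    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
        # keep the first 5 rows and preserve the trainability flag
        return torch.nn.Parameter(p[:5], requires_grad=requires_grad)

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        # slice the optimizer state (e.g. Adam's exp_avg / exp_avg_sq) the same way
        return v[:5]

    _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)

    assert params["means"].shape[0] == 5
    assert params["non_trainable_features"].shape[0] == 5
    assert params["non_trainable_features"].requires_grad is False  # flag survives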
@@ -100,8 +103,8 @@ def duplicate(
     device = mask.device
     sel = torch.where(mask)[0]
 
-    def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(torch.cat([p, p[sel]]))
+    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)])
@@ -145,7 +148,7 @@ def split(
         torch.randn(2, len(scales), 3, device=device),
     )  # [2, N, 3]
 
-    def param_fn(name: str, p: Tensor) -> Tensor:
+    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
         repeats = [2] + [1] * (p.dim() - 1)
         if name == "means":
             p_split = (p[sel] + samples).reshape(-1, 3)  # [2N, 3]
@@ -157,7 +160,7 @@ def param_fn(name: str, p: Tensor) -> Tensor:
         else:
             p_split = p[sel].repeat(repeats)
         p_new = torch.cat([p[rest], p_split])
-        p_new = torch.nn.Parameter(p_new)
+        p_new = torch.nn.Parameter(p_new, requires_grad=requires_grad)
         return p_new
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
@@ -190,8 +193,8 @@ def remove(
     """
     sel = torch.where(~mask)[0]
 
-    def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(p[sel])
+    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+        return torch.nn.Parameter(p[sel], requires_grad=requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return v[sel]
@@ -219,10 +222,10 @@ def reset_opa(
         value: The value to reset the opacities
     """
 
-    def param_fn(name: str, p: Tensor) -> Tensor:
+    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
         if name == "opacities":
             opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item())
-            return torch.nn.Parameter(opacities)
+            return torch.nn.Parameter(opacities, requires_grad=requires_grad)
         else:
             raise ValueError(f"Unexpected parameter name: {name}")
 
@@ -271,13 +274,13 @@ def relocate(
     )
     new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
 
-    def param_fn(name: str, p: Tensor) -> Tensor:
+    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
         if name == "opacities":
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
         p[dead_indices] = p[sampled_idxs]
-        return torch.nn.Parameter(p)
+        return torch.nn.Parameter(p, requires_grad=requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v[sampled_idxs] = 0
@@ -313,13 +316,13 @@ def sample_add(
     )
     new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
 
-    def param_fn(name: str, p: Tensor) -> Tensor:
+    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
         if name == "opacities":
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
         p = torch.cat([p, p[sampled_idxs]])
-        return torch.nn.Parameter(p)
+        return torch.nn.Parameter(p, requires_grad=requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
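All of the `param_fn` edits above (in duplicate, split, remove, reset_opa, relocate, and sample_add) follow the same pattern: whenever a parameter tensor is rebuilt, the new `torch.nn.Parameter` inherits the old `requires_grad` flag instead of defaulting to trainable. A standalone illustration of that pattern (names are hypothetical, not from the commit):

    import torch

    old = torch.nn.Parameter(torch.rand(100, 3), requires_grad=False)  # a frozen per-Gaussian feature
    keep = torch.arange(50)  # e.g. the subset kept by a prune/remove step

    # Before this commit, torch.nn.Parameter(old[keep]) would silently become trainable.
    new = torch.nn.Parameter(old[keep], requires_grad=old.requires_grad)
    assert new.requires_grad == old.requires_grad  # the flag survives the rebuild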
tests/test_strategy.py (67 additions, 0 deletions)
@@ -62,5 +62,72 @@ def test_strategy():
     strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_strategy_requires_grad():
+    from gsplat.rendering import rasterization
+    from gsplat.strategy import DefaultStrategy, MCMCStrategy
+
+    def assert_consistent_sizes(params):
+        sizes = [v.shape[0] for v in params.values()]
+        assert all([s == sizes[0] for s in sizes])
+
+    torch.manual_seed(42)
+
+    # Prepare Gaussians
+    N = 100
+    params = torch.nn.ParameterDict(
+        {
+            "means": torch.randn(N, 3),
+            "scales": torch.rand(N, 3),
+            "quats": torch.randn(N, 4),
+            "opacities": torch.rand(N),
+            "colors": torch.rand(N, 3),
+            "non_trainable_features": torch.rand(N, 3),
+        }
+    ).to(device)
+    params["non_trainable_features"].requires_grad = False
+    requires_grad_map = {k: v.requires_grad for k, v in params.items()}
+    optimizers = {
+        k: torch.optim.Adam([v], lr=1e-3) for k, v in params.items() if v.requires_grad
+    }
+
+    # A dummy rendering call
+    render_colors, render_alphas, info = rasterization(
+        means=params["means"],
+        quats=params["quats"],  # F.normalize is fused into the kernel
+        scales=torch.exp(params["scales"]),
+        opacities=torch.sigmoid(params["opacities"]),
+        colors=params["colors"],
+        viewmats=torch.eye(4).unsqueeze(0).to(device),
+        Ks=torch.eye(3).unsqueeze(0).to(device),
+        width=10,
+        height=10,
+        packed=False,
+    )
+
+    # Test DefaultStrategy
+    strategy = DefaultStrategy(verbose=True)
+    strategy.check_sanity(params, optimizers)
+    state = strategy.initialize_state()
+    strategy.step_pre_backward(params, optimizers, state, step=600, info=info)
+    render_colors.mean().backward(retain_graph=True)
+    strategy.step_post_backward(params, optimizers, state, step=600, info=info)
+    for k, v in params.items():
+        assert v.requires_grad == requires_grad_map[k]
+    assert params["non_trainable_features"].grad is None
+    assert_consistent_sizes(params)
+    # Test MCMCStrategy
+    strategy = MCMCStrategy(verbose=True)
+    strategy.check_sanity(params, optimizers)
+    state = strategy.initialize_state()
+    render_colors.mean().backward(retain_graph=True)
+    strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
+    assert params["non_trainable_features"].grad is None
+    for k, v in params.items():
+        assert v.requires_grad == requires_grad_map[k]
+    assert_consistent_sizes(params)
+
+
 if __name__ == "__main__":
     test_strategy()
+    test_strategy_requires_grad()