Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add-support-for-non-trainable-params #456

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions gsplat/strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ def check_sanity(
optimizers: Dict[str, torch.optim.Optimizer],
):
"""Sanity check for the parameters and optimizers."""
assert set(params.keys()) == set(optimizers.keys()), (
"params and optimizers must have the same keys, "
f"but got {params.keys()} and {optimizers.keys()}"
trainable_params = set(
[name for name, param in params.items() if param.requires_grad]
)
assert trainable_params == set(optimizers.keys()), (
"trainable parameters and optimizers must have the same keys, "
f"but got {trainable_params} and {optimizers.keys()}"
)

for optimizer in optimizers.values():
Expand Down
51 changes: 27 additions & 24 deletions gsplat/strategy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Te

@torch.no_grad()
def _update_param_with_optimizer(
param_fn: Callable[[str, Tensor], Tensor],
param_fn: Callable[[str, Tensor, bool], Tensor],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks the API of this _update_param_with_optimizer function. Though it is kinda ok because it's an internal function, there seem to be a simple way that wouldn't break it:

Instead of

def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
        return torch.nn.Parameter(p[sel], requires_grad=requires_grad)

We could do

def param_fn(name: str, p: Tensor) -> Tensor:
        return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)

optimizer_fn: Callable[[str, Tensor], Tensor],
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
Expand All @@ -68,19 +68,22 @@ def _update_param_with_optimizer(
names = list(params.keys())

for name in names:
param = params[name]
new_param = param_fn(name, param, param.requires_grad)
params[name] = new_param
if name not in optimizers:
assert not param.requires_grad
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a msg to explain. sth like assert not param.requires_grad, f"param {name} does not present in the optimizer, it's requires_grad should be False, but found True"

continue
optimizer = optimizers[name]
for i, param_group in enumerate(optimizer.param_groups):
p = param_group["params"][0]
p_state = optimizer.state[p]
del optimizer.state[p]
for key in p_state.keys():
for i in range(len(optimizer.param_groups)):
param_state = optimizer.state[param]
del optimizer.state[param]
for key in param_state.keys():
if key != "step":
v = p_state[key]
p_state[key] = optimizer_fn(key, v)
p_new = param_fn(name, p)
optimizer.param_groups[i]["params"] = [p_new]
optimizer.state[p_new] = p_state
params[name] = p_new
v = param_state[key]
param_state[key] = optimizer_fn(key, v)
optimizer.param_groups[i]["params"] = [new_param]
optimizer.state[new_param] = param_state


@torch.no_grad()
Expand All @@ -100,8 +103,8 @@ def duplicate(
device = mask.device
sel = torch.where(mask)[0]

def param_fn(name: str, p: Tensor) -> Tensor:
return torch.nn.Parameter(torch.cat([p, p[sel]]))
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=requires_grad)

def optimizer_fn(key: str, v: Tensor) -> Tensor:
return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)])
Expand Down Expand Up @@ -145,7 +148,7 @@ def split(
torch.randn(2, len(scales), 3, device=device),
) # [2, N, 3]

def param_fn(name: str, p: Tensor) -> Tensor:
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
repeats = [2] + [1] * (p.dim() - 1)
if name == "means":
p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3]
Expand All @@ -157,7 +160,7 @@ def param_fn(name: str, p: Tensor) -> Tensor:
else:
p_split = p[sel].repeat(repeats)
p_new = torch.cat([p[rest], p_split])
p_new = torch.nn.Parameter(p_new)
p_new = torch.nn.Parameter(p_new, requires_grad=requires_grad)
return p_new

def optimizer_fn(key: str, v: Tensor) -> Tensor:
Expand Down Expand Up @@ -190,8 +193,8 @@ def remove(
"""
sel = torch.where(~mask)[0]

def param_fn(name: str, p: Tensor) -> Tensor:
return torch.nn.Parameter(p[sel])
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
return torch.nn.Parameter(p[sel], requires_grad=requires_grad)

def optimizer_fn(key: str, v: Tensor) -> Tensor:
return v[sel]
Expand Down Expand Up @@ -219,10 +222,10 @@ def reset_opa(
value: The value to reset the opacities
"""

def param_fn(name: str, p: Tensor) -> Tensor:
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
if name == "opacities":
opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item())
return torch.nn.Parameter(opacities)
return torch.nn.Parameter(opacities, requires_grad=requires_grad)
else:
raise ValueError(f"Unexpected parameter name: {name}")

Expand Down Expand Up @@ -271,13 +274,13 @@ def relocate(
)
new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)

def param_fn(name: str, p: Tensor) -> Tensor:
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
if name == "opacities":
p[sampled_idxs] = torch.logit(new_opacities)
elif name == "scales":
p[sampled_idxs] = torch.log(new_scales)
p[dead_indices] = p[sampled_idxs]
return torch.nn.Parameter(p)
return torch.nn.Parameter(p, requires_grad=requires_grad)

def optimizer_fn(key: str, v: Tensor) -> Tensor:
v[sampled_idxs] = 0
Expand Down Expand Up @@ -313,13 +316,13 @@ def sample_add(
)
new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)

def param_fn(name: str, p: Tensor) -> Tensor:
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
if name == "opacities":
p[sampled_idxs] = torch.logit(new_opacities)
elif name == "scales":
p[sampled_idxs] = torch.log(new_scales)
p = torch.cat([p, p[sampled_idxs]])
return torch.nn.Parameter(p)
return torch.nn.Parameter(p, requires_grad=requires_grad)

def optimizer_fn(key: str, v: Tensor) -> Tensor:
v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
Expand Down
67 changes: 67 additions & 0 deletions tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,72 @@ def test_strategy():
strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_strategy_requires_grad():
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy

def assert_consistent_sizes(params):
sizes = [v.shape[0] for v in params.values()]
assert all([s == sizes[0] for s in sizes])

torch.manual_seed(42)

# Prepare Gaussians
N = 100
params = torch.nn.ParameterDict(
{
"means": torch.randn(N, 3),
"scales": torch.rand(N, 3),
"quats": torch.randn(N, 4),
"opacities": torch.rand(N),
"colors": torch.rand(N, 3),
"non_trainable_features": torch.rand(N, 3),
}
).to(device)
params["non_trainable_features"].requires_grad = False
requires_grad_map = {k: v.requires_grad for k, v in params.items()}
optimizers = {
k: torch.optim.Adam([v], lr=1e-3) for k, v in params.items() if v.requires_grad
}

# A dummy rendering call
render_colors, render_alphas, info = rasterization(
means=params["means"],
quats=params["quats"], # F.normalize is fused into the kernel
scales=torch.exp(params["scales"]),
opacities=torch.sigmoid(params["opacities"]),
colors=params["colors"],
viewmats=torch.eye(4).unsqueeze(0).to(device),
Ks=torch.eye(3).unsqueeze(0).to(device),
width=10,
height=10,
packed=False,
)

# Test DefaultStrategy
strategy = DefaultStrategy(verbose=True)
strategy.check_sanity(params, optimizers)
state = strategy.initialize_state()
strategy.step_pre_backward(params, optimizers, state, step=600, info=info)
render_colors.mean().backward(retain_graph=True)
strategy.step_post_backward(params, optimizers, state, step=600, info=info)
for k, v in params.items():
assert v.requires_grad == requires_grad_map[k]
assert params["non_trainable_features"].grad is None
assert_consistent_sizes(params)
# Test MCMCStrategy
strategy = MCMCStrategy(verbose=True)
strategy.check_sanity(params, optimizers)
state = strategy.initialize_state()
render_colors.mean().backward(retain_graph=True)
strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)
assert params["non_trainable_features"].grad is None
for k, v in params.items():
assert v.requires_grad == requires_grad_map[k]
assert_consistent_sizes(params)


if __name__ == "__main__":
test_strategy()
test_strategy_requires_grad()
Loading