
Commit

nikmo33 committed Nov 6, 2024
1 parent 79da6bf commit fede42c
Showing 1 changed file with 19 additions and 16 deletions.
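In short: every `param_fn` callback in this file drops its third `requires_grad` argument and instead reads the flag off the incoming `torch.nn.Parameter` as `p.requires_grad`; the bare assert in `_update_param_with_optimizer` also gains an explanatory message. Below is a minimal sketch of the new callback contract. The `params`/`optimizers` contents are made up for illustration; gsplat keys Gaussian attributes such as "means" and "opacities" in a ParameterDict with one optimizer per entry.

import torch
from torch import Tensor
from gsplat.strategy.ops import _update_param_with_optimizer

# Hypothetical single-parameter setup, just to exercise the API.
params = torch.nn.ParameterDict(
    {"means": torch.nn.Parameter(torch.randn(100, 3))}
)
optimizers = {"means": torch.optim.Adam([params["means"]], lr=1.6e-4)}

def param_fn(name: str, p: Tensor) -> Tensor:
    # New signature: (name, p) only. The trainable flag comes from the
    # Parameter itself rather than from a threaded-through argument.
    return torch.nn.Parameter(torch.cat([p, p]), requires_grad=p.requires_grad)

def optimizer_fn(key: str, v: Tensor) -> Tensor:
    # Grow per-parameter optimizer state (e.g. Adam's exp_avg) to match.
    return torch.cat([v, torch.zeros_like(v)])

_update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)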
gsplat/strategy/ops.py (19 additions, 16 deletions)
@@ -46,7 +46,7 @@ def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Tensor:
 
 @torch.no_grad()
 def _update_param_with_optimizer(
-    param_fn: Callable[[str, Tensor, bool], Tensor],
+    param_fn: Callable[[str, Tensor], Tensor],
     optimizer_fn: Callable[[str, Tensor], Tensor],
     params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
     optimizers: Dict[str, torch.optim.Optimizer],
@@ -69,10 +69,13 @@ def _update_param_with_optimizer(
 
     for name in names:
         param = params[name]
-        new_param = param_fn(name, param, param.requires_grad)
+        new_param = param_fn(name, param)
         params[name] = new_param
         if name not in optimizers:
-            assert not param.requires_grad
+            assert not param.requires_grad, (
+                f"Optimizer for {name} is not found, but the parameter is trainable. "
+                f"Got requires_grad={param.requires_grad}"
+            )
             continue
         optimizer = optimizers[name]
         for i in range(len(optimizer.param_groups)):
@@ -103,8 +106,8 @@ def duplicate(
     device = mask.device
     sel = torch.where(mask)[0]
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
-        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=requires_grad)
+    def param_fn(name: str, p: Tensor) -> Tensor:
+        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)])
@@ -148,7 +151,7 @@ def split(
         torch.randn(2, len(scales), 3, device=device),
     )  # [2, N, 3]
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         repeats = [2] + [1] * (p.dim() - 1)
         if name == "means":
             p_split = (p[sel] + samples).reshape(-1, 3)  # [2N, 3]
@@ -160,7 +163,7 @@ def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
         else:
            p_split = p[sel].repeat(repeats)
         p_new = torch.cat([p[rest], p_split])
-        p_new = torch.nn.Parameter(p_new, requires_grad=requires_grad)
+        p_new = torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
         return p_new
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
@@ -193,8 +196,8 @@ def remove(
     """
     sel = torch.where(~mask)[0]
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
-        return torch.nn.Parameter(p[sel], requires_grad=requires_grad)
+    def param_fn(name: str, p: Tensor) -> Tensor:
+        return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return v[sel]
@@ -222,10 +225,10 @@ def reset_opa(
         value: The value to reset the opacities
     """
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item())
-            return torch.nn.Parameter(opacities, requires_grad=requires_grad)
+            return torch.nn.Parameter(opacities, requires_grad=p.requires_grad)
         else:
             raise ValueError(f"Unexpected parameter name: {name}")
 
@@ -274,13 +277,13 @@ def relocate(
     )
     new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
         p[dead_indices] = p[sampled_idxs]
-        return torch.nn.Parameter(p, requires_grad=requires_grad)
+        return torch.nn.Parameter(p, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v[sampled_idxs] = 0
@@ -316,13 +319,13 @@ def sample_add(
     )
     new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
-        p = torch.cat([p, p[sampled_idxs]])
-        return torch.nn.Parameter(p, requires_grad=requires_grad)
+        p_new = torch.cat([p, p[sampled_idxs]])
+        return torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
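Worth noting in the last hunk: `sample_add` now binds the concatenated tensor to a fresh name `p_new` instead of rebinding `p`. Under the new contract the trainable flag is read from `p` at the moment the replacement Parameter is built, and these ops run under `@torch.no_grad()`, where `torch.cat` always returns a tensor with `requires_grad=False`; rebinding `p` first would therefore report the wrong flag. A throwaway sketch of the pitfall (names here are illustrative, not from the repo):

import torch

p = torch.nn.Parameter(torch.randn(4, 3))  # requires_grad=True

with torch.no_grad():
    cat = torch.cat([p, p[:2]])
    # Results computed under no_grad never require grad, so the flag
    # must be taken from the original Parameter, not from the cat result.
    assert cat.requires_grad is False
    p_new = torch.nn.Parameter(cat, requires_grad=p.requires_grad)
    assert p_new.requires_grad is True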
