From 2ed1716980191abd818d048b686475f897267c96 Mon Sep 17 00:00:00 2001
From: Congrong Xu <50019703+KevinXu02@users.noreply.github.com>
Date: Wed, 14 Aug 2024 03:18:43 +0800
Subject: [PATCH] Bilagrid for splatfacto (#3316)
Add bilateral grid as an option for splatfacto training
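
This teaches splatfacto to optimize per-view 3D bilateral grids that absorb
exposure and ISP differences between training views, following "Bilateral
Guided Radiance Field Processing" (https://bilarfpro.github.io/). A minimal
sketch of enabling it through the model config added in this patch (tyro
exposes the same fields on the CLI):

    config = SplatfactoModelConfig(
        use_bilateral_grid=True,       # train one bilateral grid per view
        grid_shape=(16, 16, 8),        # (X, Y, W) grid resolution (default)
        color_corrected_metrics=True,  # also report cc_psnr / cc_ssim / cc_lpips
    )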
---
nerfstudio/configs/method_configs.py | 12 +
nerfstudio/model_components/lib_bilagrid.py | 547 ++++++++++++++++++++
nerfstudio/models/splatfacto.py | 58 +++
pyproject.toml | 3 +-
4 files changed, 619 insertions(+), 1 deletion(-)
create mode 100644 nerfstudio/model_components/lib_bilagrid.py
diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py
index ce54a35a35..e77ab130c4 100644
--- a/nerfstudio/configs/method_configs.py
+++ b/nerfstudio/configs/method_configs.py
@@ -637,6 +637,12 @@
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
+ "bilateral_grid": {
+ "optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
+ "scheduler": ExponentialDecaySchedulerConfig(
+ lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
+ ),
+ },
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
vis="viewer",
@@ -692,6 +698,12 @@
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
+ "bilateral_grid": {
+ "optimizer": AdamOptimizerConfig(lr=5e-3, eps=1e-15),
+ "scheduler": ExponentialDecaySchedulerConfig(
+ lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
+ ),
+ },
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
vis="viewer",
diff --git a/nerfstudio/model_components/lib_bilagrid.py b/nerfstudio/model_components/lib_bilagrid.py
new file mode 100644
index 0000000000..93869cc385
--- /dev/null
+++ b/nerfstudio/model_components/lib_bilagrid.py
@@ -0,0 +1,547 @@
+# Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of the code is borrowed from ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This is a standalone PyTorch implementation of the 3D bilateral grid and the CP-decomposed 4D bilateral grid.
+To use this module, download the "lib_bilagrid.py" file and put it in your project directory.
+
+For details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+
+#### Dependencies
+
+In addition to PyTorch and NumPy, please install [tensorly](https://github.com/tensorly/tensorly).
+We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and NumPy 1.25.2.
+
+#### Overview
+
+- For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids
+ for input views. Then, use the `slice` function to obtain the transformed RGB output and the corresponding affine transformations.
+
+- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`.
+
+#### Examples
+
+- Bilateral grid for approximating ISP (see the figure on the project page).
+
+- Low-rank 4D bilateral grid for MR enhancement (see the figure on the project page).
+
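+- A minimal bilateral guided training sketch (tensor sizes are illustrative, and we assume this file is importable as `lib_bilagrid`):
+
+```
+import torch
+from lib_bilagrid import BilateralGrid, slice
+
+bil_grids = BilateralGrid(num=10) # one 3D bilateral grid per training view
+xy = torch.rand(4096, 2) # normalized pixel coordinates in [0, 1]
+rgb = torch.rand(4096, 3) # rendered colors in [0, 1]
+grid_idx = torch.zeros(4096, 1, dtype=torch.long) # all pixels come from view 0
+out = slice(bil_grids, xy, rgb, grid_idx)
+out["rgb"].shape, out["rgb_affine_mats"].shape # (4096, 3), (4096, 3, 4)
+tv = bil_grids.tv_loss() # total variation regularizer for training
+```
+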
+Below is the API reference.
+
+"""
+
+import tensorly as tl
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+tl.set_backend("pytorch")
+
+
+def color_correct(img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255) -> torch.Tensor:
+ """
+ Warp `img` to match the colors in `ref` using iterative color matching.
+
+ This function performs color correction by warping the colors of the input image
+ to match those of a reference image. It uses a least squares method to find a
+ transformation that maps the input image's colors to the reference image's colors.
+
+ The algorithm iteratively solves a system of linear equations, updating the set of
+ unsaturated pixels in each iteration. This approach helps handle non-linear color
+ transformations and reduces the impact of clipping.
+
+ Args:
+ img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels]
+ ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels]
+ num_iters (int, optional): Number of iterations for the color matching process.
+ Default is 5.
+ eps (float, optional): Small value to determine the range of unclipped pixels.
+ Default is 0.5 / 255.
+
+ Returns:
+ torch.Tensor: Color corrected image with the same shape as the input image.
+
+ Note:
+ - Both input and reference images should be in the range [0, 1].
+ - The function works with any number of channels, but it is typically used with 3 (RGB).
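+
+ Example (a minimal sketch; random images stand in for a render and its ground truth):
+ ```
+ img = torch.rand(32, 32, 3)
+ ref = torch.rand(32, 32, 3)
+ cc = color_correct(img, ref) # same shape as img, values in [0, 1]
+ ```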
+ """
+ if img.shape[-1] != ref.shape[-1]:
+ raise ValueError(f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match")
+ num_channels = img.shape[-1]
+ img_mat = img.reshape([-1, num_channels])
+ ref_mat = ref.reshape([-1, num_channels])
+
+ def is_unclipped(z):
+ return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps].
+
+ mask0 = is_unclipped(img_mat)
+ # Because the set of saturated pixels may change after solving for a
+ # transformation, we repeatedly solve a system `num_iters` times and update
+ # our estimate of which pixels are saturated.
+ for _ in range(num_iters):
+ # Construct the left hand side of a linear system that contains a quadratic
+ # expansion of each pixel of `img`.
+ a_mat = []
+ for c in range(num_channels):
+ a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term.
+ a_mat.append(img_mat) # Linear term.
+ a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term.
+ a_mat = torch.cat(a_mat, dim=-1)
+ warp = []
+ for c in range(num_channels):
+ # Construct the right hand side of a linear system containing each color
+ # of `ref`.
+ b = ref_mat[:, c]
+ # Ignore rows of the linear system that were saturated in the input or are
+ # saturated in the current corrected color estimate.
+ mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
+ ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
+ mb = torch.where(mask, b, torch.zeros_like(b))
+ w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
+ assert torch.all(torch.isfinite(w))
+ warp.append(w)
+ warp = torch.stack(warp, dim=-1)
+ # Apply the warp to update img_mat.
+ img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1)
+ corrected_img = torch.reshape(img_mat, img.shape)
+ return corrected_img
+
+
+def bilateral_grid_tv_loss(model, config):
+ """Computes total variations of bilateral grids."""
+ total_loss = 0.0
+
+ for bil_grids in model.bil_grids:
+ total_loss += config.bilgrid_tv_loss_mult * total_variation_loss(bil_grids.grids)
+
+ return total_loss
+
+
+def color_affine_transform(affine_mats, rgb):
+ """Applies color affine transformations.
+
+ Args:
+ affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$.
+ rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$.
+
+ Returns:
+ Output transformed colors of shape $(..., 3)$.
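+
+ Example (a sketch; an identity affine transform leaves colors unchanged):
+ ```
+ mats = torch.zeros(5, 3, 4)
+ mats[..., :3] = torch.eye(3) # linear part = I, offset = 0
+ rgb = torch.rand(5, 3)
+ assert torch.allclose(color_affine_transform(mats, rgb), rgb)
+ ```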
+ """
+ return torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1) + affine_mats[..., 3]
+
+
+def _num_tensor_elems(t):
+ return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0)
+
+
+def total_variation_loss(x): # noqa: F811
+ """Returns total variation on multi-dimensional tensors.
+
+ Args:
+ x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size.
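+
+ Example (a sketch; a batch of two bilateral grids):
+ ```
+ x = torch.rand(2, 12, 8, 16, 16) # (B, C, L, H, W)
+ tv = total_variation_loss(x) # scalar tensor
+ ```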
+ """
+ batch_size = x.shape[0]
+ tv = 0
+ for i in range(2, len(x.shape)):
+ n_res = x.shape[i]
+ idx1 = torch.arange(1, n_res, device=x.device)
+ idx2 = torch.arange(0, n_res - 1, device=x.device)
+ x1 = x.index_select(i, idx1)
+ x2 = x.index_select(i, idx2)
+ count = _num_tensor_elems(x1)
+ tv += torch.pow((x1 - x2), 2).sum() / count
+ return tv / batch_size
+
+
+def slice(bil_grids, xy, rgb, grid_idx):
+ """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`.
+
+ Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size
+ and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`.
+
+ The return value is a dictionary containing the affine transformations `affine_mats` sliced from the bilateral grids and
+ the output color `rgb_out` after applying the affine transformations.
+
+ In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor.
+ Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`.
+ For the 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same as in the 2-D case.
+
+ .. note::
+ This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement.
+ When `grid_idx` contains a single unique index, only one bilateral grid is used during the slicing. In this case, this function will not
+ perform tensor indexing, which avoids copying data and allocating extra memory
+ (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)).
+
+ Args:
+ bil_grids (`BilateralGrid`): An instance holding $N$ bilateral grids.
+ xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$.
+ rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$.
+ grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$.
+
+ Returns:
+ A dictionary with keys and values as follows:
+ ```
+ {
+ "rgb": Transformed RGB colors. Shape: (..., 3),
+ "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4)
+ }
+ ```
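+
+ Example (a sketch for the 2-D case, with pixels drawn from mixed views):
+ ```
+ bil_grids = BilateralGrid(num=10)
+ xy = torch.rand(256, 2)
+ rgb = torch.rand(256, 3)
+ grid_idx = torch.randint(0, 10, (256, 1))
+ out = slice(bil_grids, xy, rgb, grid_idx)
+ out["rgb"].shape, out["rgb_affine_mats"].shape # (256, 3), (256, 3, 4)
+ ```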
+ """
+
+ sh_ = rgb.shape
+
+ grid_idx_unique = torch.unique(grid_idx)
+ if len(grid_idx_unique) == 1:
+ # All pixels are from a single view.
+ grid_idx = grid_idx_unique # (1,)
+ xy = xy.unsqueeze(0) # (1, ..., 2)
+ rgb = rgb.unsqueeze(0) # (1, ..., 3)
+ else:
+ # Pixels are randomly sampled from different views.
+ if len(grid_idx.shape) == 4:
+ grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,)
+ elif len(grid_idx.shape) == 3:
+ grid_idx = grid_idx[:, 0, 0] # (chunk_size,)
+ elif len(grid_idx.shape) == 2:
+ grid_idx = grid_idx[:, 0] # (chunk_size,)
+ else:
+ raise ValueError("The input to bilateral grid slicing is not supported yet.")
+
+ affine_mats = bil_grids(xy, rgb, grid_idx)
+ rgb = color_affine_transform(affine_mats, rgb)
+
+ return {
+ "rgb": rgb.reshape(*sh_),
+ "rgb_affine_mats": affine_mats.reshape(*sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1]),
+ }
+
+
+class BilateralGrid(nn.Module):
+ """Class for 3D bilateral grids.
+
+ Holds one or more bilateral grids.
+ """
+
+ def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8):
+ """
+ Args:
+ num (int): The number of bilateral grids (i.e., the number of views).
+ grid_X (int): Defines grid width $W$.
+ grid_Y (int): Defines grid height $H$.
+ grid_W (int): Defines grid guidance dimension $L$.
+ """
+ super(BilateralGrid, self).__init__()
+
+ self.grid_width = grid_X
+ """Grid width. Type: int."""
+ self.grid_height = grid_Y
+ """Grid height. Type: int."""
+ self.grid_guidance = grid_W
+ """Grid guidance dimension. Type: int."""
+
+ # Initialize grids.
+ grid = self._init_identity_grid()
+ self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W)
+ """ A 5-D tensor of shape $(N, 12, L, H, W)$."""
+
+ # Weights of BT601 RGB-to-gray.
+ self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
+ self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+ """ A function that converts RGB to gray-scale guidance in $[-1, 1]$."""
+
+ def _init_identity_grid(self):
+ # Identity affine transformation, flattened as a 12-vector [I | 0].
+ grid = torch.tensor([1.0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0]).float()
+ grid = grid.repeat([self.grid_guidance * self.grid_height * self.grid_width, 1]) # (L * H * W, 12)
+ grid = grid.reshape(1, self.grid_guidance, self.grid_height, self.grid_width, -1) # (1, L, H, W, 12)
+ grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W)
+ return grid
+
+ def tv_loss(self):
+ """Computes and returns total variation loss on the bilateral grids."""
+ return total_variation_loss(self.grids)
+
+ def forward(self, grid_xy, rgb, idx=None):
+ """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input.
+ For the 2-D, 3-D, and 4-D cases, please refer to `slice`.
+ For the 5-D case, `idx` will be unused and the first dimension of `grid_xy` should be
+ equal to the number of bilateral grids. Then this function behaves like PyTorch's
+ [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+
+ Args:
+ grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$.
+ rgb (torch.Tensor): The RGB values in the range of $[0,1]$.
+ idx (torch.Tensor): The bilateral grid indices.
+
+ Returns:
+ Sliced affine matrices of shape $(..., 3, 4)$.
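+
+ Example (a sketch for the 5-D case, where `idx` is unused):
+ ```
+ bg = BilateralGrid(num=2)
+ xy = torch.rand(2, 1, 8, 8, 2) # (N, m, h, w, 2)
+ rgb = torch.rand(2, 1, 8, 8, 3) # (N, m, h, w, 3)
+ mats = bg(xy, rgb) # (2, 1, 8, 8, 3, 4)
+ ```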
+ """
+
+ input_ndims = len(grid_xy.shape)
+ assert len(rgb.shape) == input_ndims
+
+ if input_ndims > 1 and input_ndims < 5:
+ # Convert input into 5D
+ for i in range(5 - input_ndims):
+ grid_xy = grid_xy.unsqueeze(1)
+ rgb = rgb.unsqueeze(1)
+ assert idx is not None
+ elif input_ndims != 5:
+ raise ValueError("Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs")
+
+ grids = self.grids
+ if idx is not None:
+ grids = grids[idx]
+ assert grids.shape[0] == grid_xy.shape[0]
+
+ # Generate slicing coordinates.
+ grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1].
+ grid_z = self.rgb2gray(rgb)
+
+ grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3)
+
+ affine_mats = F.grid_sample(
+ grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border"
+ ) # (N, 12, m, h, w)
+ affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12)
+ affine_mats = affine_mats.reshape(*affine_mats.shape[:-1], 3, 4) # (N, m, h, w, 3, 4)
+
+ for _ in range(5 - input_ndims):
+ affine_mats = affine_mats.squeeze(1)
+
+ return affine_mats
+
+
+def slice4d(bil_grid4d, xyz, rgb):
+ """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`.
+
+ Args:
+ bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid.
+ xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
+ rgb (torch.Tensor): The RGB values with shape $(..., 3)$.
+
+ Returns:
+ A dictionary with keys and values as follows:
+ ```
+ {
+ "rgb": Transformed radiance RGB colors. Shape: (..., 3),
+ "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4)
+ }
+ ```
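+
+ Example (a sketch; `xyz` is assumed to lie in `[-bound, bound]`):
+ ```
+ bg4d = BilateralGridCP4D(bound=2.0)
+ xyz = torch.rand(1024, 3) * 4.0 - 2.0
+ rgb = torch.rand(1024, 3)
+ out = slice4d(bg4d, xyz, rgb)
+ out["rgb"].shape, out["rgb_affine_mats"].shape # (1024, 3), (1024, 3, 4)
+ ```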
+ """
+
+ affine_mats = bil_grid4d(xyz, rgb)
+ rgb = color_affine_transform(affine_mats, rgb)
+
+ return {"rgb": rgb, "rgb_affine_mats": affine_mats}
+
+
+class _ScaledTanh(nn.Module):
+ def __init__(self, s=2.0):
+ super().__init__()
+ self.scaler = s
+
+ def forward(self, x):
+ return torch.tanh(self.scaler * x)
+
+
+class BilateralGridCP4D(nn.Module):
+ """Class for low-rank 4D bilateral grids."""
+
+ def __init__(
+ self,
+ grid_X=16,
+ grid_Y=16,
+ grid_Z=16,
+ grid_W=8,
+ rank=5,
+ learn_gray=True,
+ gray_mlp_width=8,
+ gray_mlp_depth=2,
+ init_noise_scale=1e-6,
+ bound=2.0,
+ ):
+ """
+ Args:
+ grid_X (int): Defines grid width.
+ grid_Y (int): Defines grid height.
+ grid_Z (int): Defines grid depth.
+ grid_W (int): Defines grid guidance dimension.
+ rank (int): Rank of the 4D bilateral grid.
+ learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances.
+ gray_mlp_width (int): The MLP width for learnable guidance.
+ gray_mlp_depth (int): The number of MLP layers for learnable guidance.
+ init_noise_scale (float): The noise scale of the initialized factors.
+ bound (float): The bound of the xyz coordinates.
+ """
+ super(BilateralGridCP4D, self).__init__()
+
+ self.grid_X = grid_X
+ """Grid width. Type: int."""
+ self.grid_Y = grid_Y
+ """Grid height. Type: int."""
+ self.grid_Z = grid_Z
+ """Grid depth. Type: int."""
+ self.grid_W = grid_W
+ """Grid guidance dimension. Type: int."""
+ self.rank = rank
+ """Rank of the 4D bilateral grid. Type: int."""
+ self.learn_gray = learn_gray
+ """Flags of learnable guidance is used. Type: bool."""
+ self.gray_mlp_width = gray_mlp_width
+ """The MLP width for learnable guidance. Type: int."""
+ self.gray_mlp_depth = gray_mlp_depth
+ """The MLP depth for learnable guidance. Type: int."""
+ self.init_noise_scale = init_noise_scale
+ """The noise scale of the initialized factors. Type: float."""
+ self.bound = bound
+ """The bound of the xyz coordinates. Type: float."""
+
+ self._init_cp_factors_parafac()
+
+ self.rgb2gray = None
+ """ A function that converts RGB to gray-scale guidances in $[-1, 1]$.
+ If `learn_gray` is True, this will be an MLP network."""
+
+ if self.learn_gray:
+
+ def rgb2gray_mlp_linear(layer):
+ return nn.Linear(self.gray_mlp_width, self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1)
+
+ def rgb2gray_mlp_actfn(_):
+ return nn.ReLU(inplace=True)
+
+ self.rgb2gray = nn.Sequential(
+ *(
+ [nn.Linear(3, self.gray_mlp_width)]
+ + [
+ nn_module(layer)
+ for layer in range(1, self.gray_mlp_depth)
+ for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear]
+ ]
+ + [_ScaledTanh(2.0)]
+ )
+ )
+ else:
+ # Weights of BT601/BT470 RGB-to-gray.
+ self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
+ self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+
+ def _init_identity_grid(self):
+ # Identity affine transformation, flattened as a 12-vector [I | 0].
+ grid = torch.tensor([1.0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0]).float()
+ grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1])
+ grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1)
+ grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X)
+ return grid
+
+ def _init_cp_factors_parafac(self):
+ # Initialize identity grids.
+ init_grids = self._init_identity_grid()
+ # Random noises are added to avoid singularity.
+ init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids
+ from tensorly.decomposition import parafac
+
+ # Initialize grid CP factors
+ _, facs = parafac(init_grids.clone().detach(), rank=self.rank)
+
+ self.num_facs = len(facs)
+
+ self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False)
+ self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank)
+
+ for i in range(1, self.num_facs):
+ fac = facs[i].T # (rank, grid_size)
+ fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1)
+ self.register_buffer(f"fac_{i}_init", fac)
+
+ fac_resid = torch.zeros_like(fac)
+ self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid))
+
+ def tv_loss(self):
+ """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids."""
+
+ total_loss = 0
+ for i in range(1, self.num_facs):
+ fac = self.get_parameter(f"fac_{i}")
+ total_loss += total_variation_loss(fac)
+
+ return total_loss
+
+ def forward(self, xyz, rgb):
+ """Low-rank 4D bilateral grid slicing.
+
+ Args:
+ xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
+ rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$.
+
+ Returns:
+ Sliced affine matrices with shape $(..., 3, 4)$.
+ """
+ sh_ = xyz.shape
+ xyz = xyz.reshape(-1, 3) # flatten (N, 3)
+ rgb = rgb.reshape(-1, 3) # flatten (N, 3)
+
+ xyz = xyz / self.bound
+ assert self.rgb2gray is not None
+ gray = self.rgb2gray(rgb)
+ xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4)
+ xyzw = xyzw.transpose(0, 1) # (4, N)
+ coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2)
+ coords = coords.unsqueeze(1) # (4, 1, N, 2)
+
+ coef = 1.0
+ for i in range(1, self.num_facs):
+ fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init")
+ coef = coef * F.grid_sample(
+ fac, coords[[i - 1]], align_corners=True, padding_mode="border"
+ ) # [1, rank, 1, N]
+ coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore
+ mat = self.fac_0(coef)
+ return mat.reshape(*sh_[:-1], 3, 4)
diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py
index 37856e3891..05587c642b 100644
--- a/nerfstudio/models/splatfacto.py
+++ b/nerfstudio/models/splatfacto.py
@@ -40,6 +40,7 @@
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.engine.optimizers import Optimizers
+from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils.colors import get_color
from nerfstudio.utils.misc import torch_compile
@@ -184,6 +185,12 @@ class SplatfactoModelConfig(ModelConfig):
"""
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off"))
"""Config of the camera optimizer to use"""
+ use_bilateral_grid: bool = False
+ """If True, use bilateral grid to handle the ISP changes in the image space. This technique was introduced in the paper 'Bilateral Guided Radiance Field Processing' (https://bilarfpro.github.io/)."""
+ grid_shape: Tuple[int, int, int] = (16, 16, 8)
+ """Shape of the bilateral grid (X, Y, W)"""
+ color_corrected_metrics: bool = False
+ """If True, apply color correction to the rendered images before computing the metrics."""
class SplatfactoModel(Model):
@@ -271,6 +278,13 @@ def populate_modules(self):
) # This color is the same as the default background color in Viser. This would only affect the background color when rendering.
else:
self.background_color = get_color(self.config.background_color)
+ if self.config.use_bilateral_grid:
+ self.bil_grids = BilateralGrid(
+ num=self.num_train_data,
+ grid_X=self.config.grid_shape[0],
+ grid_Y=self.config.grid_shape[1],
+ grid_W=self.config.grid_shape[2],
+ )
@property
def colors(self):
@@ -647,6 +661,8 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
Mapping of different parameter groups
"""
gps = self.get_gaussian_param_groups()
+ if self.config.use_bilateral_grid:
+ gps["bilateral_grid"] = list(self.bil_grids.parameters())
self.camera_optimizer.get_param_groups(param_groups=gps)
return gps
@@ -686,6 +702,23 @@ def _get_background_color(self):
raise ValueError(f"Unknown background color {self.config.background_color}")
return background
+ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int) -> torch.Tensor:
+ # make xy grid
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0, 1.0, H, device=self.device),
+ torch.linspace(0, 1.0, W, device=self.device),
+ indexing="ij",
+ )
+ grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
+
+ out = slice(
+ bil_grids=self.bil_grids,
+ rgb=rgb,
+ xy=grid_xy,
+ grid_idx=torch.tensor(cam_idx, device=self.device, dtype=torch.long),
+ )
+ return out["rgb"]
+
def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
"""Takes in a camera and returns a dictionary of outputs.
@@ -789,6 +822,11 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
rgb = render[:, ..., :3] + (1 - alpha) * background
rgb = torch.clamp(rgb, 0.0, 1.0)
+ # apply bilateral grid
+ if self.config.use_bilateral_grid and self.training:
+ if camera.metadata is not None and "cam_idx" in camera.metadata:
+ rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
+
if render_mode == "RGB+ED":
depth_im = render[:, ..., 3:4]
depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
@@ -839,7 +877,11 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
metrics_dict = {}
predicted_rgb = outputs["rgb"]
+
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)
+ if self.config.color_corrected_metrics:
+ cc_rgb = color_correct(predicted_rgb, gt_rgb)
+ metrics_dict["cc_psnr"] = self.psnr(cc_rgb, gt_rgb)
metrics_dict["gaussian_count"] = self.num_points
@@ -890,6 +932,8 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
if self.training:
# Add loss from camera optimizer
self.camera_optimizer.get_loss_dict(loss_dict)
+ if self.config.use_bilateral_grid:
+ loss_dict["tv_loss"] = 10 * total_variation_loss(self.bil_grids.grids)
return loss_dict
@@ -922,9 +966,14 @@ def get_image_metrics_and_images(
"""
gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
predicted_rgb = outputs["rgb"]
+ cc_rgb = None
combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
+ if self.config.color_corrected_metrics:
+ cc_rgb = color_correct(predicted_rgb, gt_rgb)
+ cc_rgb = torch.moveaxis(cc_rgb, -1, 0)[None, ...]
+
# Switch images from [H, W, C] to [1, C, H, W] for metrics computations
gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...]
predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...]
@@ -937,6 +986,15 @@ def get_image_metrics_and_images(
metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)} # type: ignore
metrics_dict["lpips"] = float(lpips)
+ if self.config.color_corrected_metrics:
+ assert cc_rgb is not None
+ cc_psnr = self.psnr(gt_rgb, cc_rgb)
+ cc_ssim = self.ssim(gt_rgb, cc_rgb)
+ cc_lpips = self.lpips(gt_rgb, cc_rgb)
+ metrics_dict["cc_psnr"] = float(cc_psnr.item())
+ metrics_dict["cc_ssim"] = float(cc_ssim)
+ metrics_dict["cc_lpips"] = float(cc_lpips)
+
images_dict = {"img": combined_rgb}
return metrics_dict, images_dict
diff --git a/pyproject.toml b/pyproject.toml
index 397433cba1..fa7d2c319c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,8 @@ dependencies = [
"pytorch-msssim",
"pathos",
"packaging",
- "fpsample"
+ "fpsample",
+ "tensorly"
]
[project.urls]