From 2ed1716980191abd818d048b686475f897267c96 Mon Sep 17 00:00:00 2001
From: Congrong Xu <50019703+KevinXu02@users.noreply.github.com>
Date: Wed, 14 Aug 2024 03:18:43 +0800
Subject: [PATCH] Bilagrid for splatfacto (#3316)

add bilateral grid as an option to splatfacto training
---
 nerfstudio/configs/method_configs.py        |  12 +
 nerfstudio/model_components/lib_bilagrid.py | 547 ++++++++++++++++++++
 nerfstudio/models/splatfacto.py             |  58 +++
 pyproject.toml                              |   3 +-
 4 files changed, 619 insertions(+), 1 deletion(-)
 create mode 100644 nerfstudio/model_components/lib_bilagrid.py

diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py
index ce54a35a35..e77ab130c4 100644
--- a/nerfstudio/configs/method_configs.py
+++ b/nerfstudio/configs/method_configs.py
@@ -637,6 +637,12 @@
                 lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
             ),
         },
+        "bilateral_grid": {
+            "optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
+            "scheduler": ExponentialDecaySchedulerConfig(
+                lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
+            ),
+        },
     },
     viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
     vis="viewer",
@@ -692,6 +698,12 @@
                 lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
             ),
         },
+        "bilateral_grid": {
+            "optimizer": AdamOptimizerConfig(lr=5e-3, eps=1e-15),
+            "scheduler": ExponentialDecaySchedulerConfig(
+                lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
+            ),
+        },
     },
     viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
     vis="viewer",
diff --git a/nerfstudio/model_components/lib_bilagrid.py b/nerfstudio/model_components/lib_bilagrid.py
new file mode 100644
index 0000000000..93869cc385
--- /dev/null
+++ b/nerfstudio/model_components/lib_bilagrid.py
@@ -0,0 +1,547 @@
+# Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of the code is borrowed from ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This is a standalone PyTorch implementation of the 3D bilateral grid and the CP-decomposed 4D bilateral grid.
+To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory.
+
+For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+
+#### Dependencies
+
+In addition to PyTorch and NumPy, please install [tensorly](https://github.com/tensorly/tensorly).
+We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and NumPy 1.25.2.
+
+#### Overview
+
+- For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids
+  for input views. Then, use the `slice` function to obtain the transformed RGB output and the corresponding affine transformations.
+
+- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`.
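+
+A minimal sketch of the training-time workflow described above (`n_views` and the tensor
+shapes are illustrative placeholders, not names defined by this module):
+
+```python
+bil_grids = BilateralGrid(num=n_views)     # one 3D bilateral grid per input view
+out = slice(bil_grids, xy, rgb, grid_idx)  # xy: (B, 2), rgb: (B, 3), grid_idx: (B, 1)
+rgb_out, affine_mats = out["rgb"], out["rgb_affine_mats"]
+```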
+
+#### Examples
+
+- Bilateral grid for approximating ISP:
+
+  Open In Colab
+
+- Low-rank 4D bilateral grid for MR enhancement:
+
+  Open In Colab
+
+
+Below is the API reference.
+
+"""
+
+import tensorly as tl
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+tl.set_backend("pytorch")
+
+
+def color_correct(img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255) -> torch.Tensor:
+    """
+    Warp `img` to match the colors in `ref` using iterative color matching.
+
+    This function performs color correction by warping the colors of the input image
+    to match those of a reference image. It uses a least squares method to find a
+    transformation that maps the input image's colors to the reference image's colors.
+
+    The algorithm iteratively solves a system of linear equations, updating the set of
+    unsaturated pixels in each iteration. This approach helps handle non-linear color
+    transformations and reduces the impact of clipping.
+
+    Args:
+        img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels]
+        ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels]
+        num_iters (int, optional): Number of iterations for the color matching process.
+            Default is 5.
+        eps (float, optional): Small value to determine the range of unclipped pixels.
+            Default is 0.5 / 255.
+
+    Returns:
+        torch.Tensor: Color corrected image with the same shape as the input image.
+
+    Note:
+        - Both the input and reference images should be in the range [0, 1].
+        - The function works with any number of channels, but it is typically used with 3 (RGB).
+    """
+    if img.shape[-1] != ref.shape[-1]:
+        raise ValueError(f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match")
+    num_channels = img.shape[-1]
+    img_mat = img.reshape([-1, num_channels])
+    ref_mat = ref.reshape([-1, num_channels])
+
+    def is_unclipped(z):
+        return (z >= eps) & (z <= 1 - eps)  # z \in [eps, 1-eps].
+
+    mask0 = is_unclipped(img_mat)
+    # Because the set of saturated pixels may change after solving for a
+    # transformation, we repeatedly solve a system `num_iters` times and update
+    # our estimate of which pixels are saturated.
+    for _ in range(num_iters):
+        # Construct the left hand side of a linear system that contains a quadratic
+        # expansion of each pixel of `img`.
+        a_mat = []
+        for c in range(num_channels):
+            a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:])  # Quadratic term.
+        a_mat.append(img_mat)  # Linear term.
+        a_mat.append(torch.ones_like(img_mat[:, :1]))  # Bias term.
+        a_mat = torch.cat(a_mat, dim=-1)
+        warp = []
+        for c in range(num_channels):
+            # Construct the right hand side of a linear system containing each color
+            # of `ref`.
+            b = ref_mat[:, c]
+            # Ignore rows of the linear system that were saturated in the input or are
+            # saturated in the current corrected color estimate.
+            mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
+            ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
+            mb = torch.where(mask, b, torch.zeros_like(b))
+            w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
+            assert torch.all(torch.isfinite(w))
+            warp.append(w)
+        warp = torch.stack(warp, dim=-1)
+        # Apply the warp to update img_mat.
+        img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1)
+    corrected_img = torch.reshape(img_mat, img.shape)
+    return corrected_img
+
+
+def bilateral_grid_tv_loss(model, config):
+    """Computes the total variation loss of the model's bilateral grids."""
+    total_loss = 0.0
+
+    for bil_grids in model.bil_grids:
+        total_loss += config.bilgrid_tv_loss_mult * total_variation_loss(bil_grids.grids)
+
+    return total_loss
+
+
+def color_affine_transform(affine_mats, rgb):
+    """Applies color affine transformations.
+
+    Args:
+        affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$.
+        rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$.
+
+    Returns:
+        Output transformed colors of shape $(..., 3)$.
+    """
+    return torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1) + affine_mats[..., 3]
+
+
+def _num_tensor_elems(t):
+    return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0)
+
+
+def total_variation_loss(x):  # noqa: F811
+    """Returns the total variation of a multi-dimensional tensor.
+
+    Args:
+        x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size.
+    """
+    batch_size = x.shape[0]
+    tv = 0
+    for i in range(2, len(x.shape)):
+        n_res = x.shape[i]
+        idx1 = torch.arange(1, n_res, device=x.device)
+        idx2 = torch.arange(0, n_res - 1, device=x.device)
+        x1 = x.index_select(i, idx1)
+        x2 = x.index_select(i, idx2)
+        count = _num_tensor_elems(x1)
+        tv += torch.pow(x1 - x2, 2).sum() / count
+    return tv / batch_size
+
+
+def slice(bil_grids, xy, rgb, grid_idx):
+    """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`.
+
+    Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size
+    and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`.
+
+    The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and
+    the output color `rgb` after applying the affine transformations.
+
+    In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor.
+    Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`.
+    For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same as in the 2-D case.
+
+    .. note::
+        This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement.
+        When `grid_idx` contains a unique index, only a single bilateral grid will be used during the slicing. In this case, this function
+        will not perform tensor indexing to avoid data copy and extra memory
+        (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)).
+
+    Args:
+        bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids.
+        xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$.
+        rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$.
+        grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$.
+
+    Returns:
+        A dictionary with keys and values as follows:
+        ```
+        {
+            "rgb": Transformed RGB colors. Shape: (..., 3),
+            "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids.
+                Shape: (..., 3, 4)
+        }
+        ```
+    """
+
+    sh_ = rgb.shape
+
+    grid_idx_unique = torch.unique(grid_idx)
+    if len(grid_idx_unique) == 1:
+        # All pixels are from a single view.
+        grid_idx = grid_idx_unique  # (1,)
+        xy = xy.unsqueeze(0)  # (1, ..., 2)
+        rgb = rgb.unsqueeze(0)  # (1, ..., 3)
+    else:
+        # Pixels are randomly sampled from different views.
+        if len(grid_idx.shape) == 4:
+            grid_idx = grid_idx[:, 0, 0, 0]  # (chunk_size,)
+        elif len(grid_idx.shape) == 3:
+            grid_idx = grid_idx[:, 0, 0]  # (chunk_size,)
+        elif len(grid_idx.shape) == 2:
+            grid_idx = grid_idx[:, 0]  # (chunk_size,)
+        else:
+            raise ValueError("The input to bilateral grid slicing is not supported yet.")
+
+    affine_mats = bil_grids(xy, rgb, grid_idx)
+    rgb = color_affine_transform(affine_mats, rgb)
+
+    return {
+        "rgb": rgb.reshape(*sh_),
+        "rgb_affine_mats": affine_mats.reshape(*sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1]),
+    }
+
+
+class BilateralGrid(nn.Module):
+    """Class for 3D bilateral grids.
+
+    Holds one or more bilateral grids.
+    """
+
+    def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8):
+        """
+        Args:
+            num (int): The number of bilateral grids (i.e., the number of views).
+            grid_X (int): Defines grid width $W$.
+            grid_Y (int): Defines grid height $H$.
+            grid_W (int): Defines grid guidance dimension $L$.
+        """
+        super(BilateralGrid, self).__init__()
+
+        self.grid_width = grid_X
+        """Grid width. Type: int."""
+        self.grid_height = grid_Y
+        """Grid height. Type: int."""
+        self.grid_guidance = grid_W
+        """Grid guidance dimension. Type: int."""
+
+        # Initialize grids.
+        grid = self._init_identity_grid()
+        self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1))  # (N, 12, L, H, W)
+        """A 5-D tensor of shape $(N, 12, L, H, W)$."""
+
+        # Weights of BT601 RGB-to-gray.
+        self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
+        self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+        """A function that converts RGB to gray-scale guidance in $[-1, 1]$."""
+
+    def _init_identity_grid(self):
+        # Each grid cell stores a row-major 3x4 identity affine color transformation.
+        grid = torch.tensor(
+            [
+                1.0, 0, 0, 0,
+                0, 1.0, 0, 0,
+                0, 0, 1.0, 0,
+            ]
+        ).float()
+        grid = grid.repeat([self.grid_guidance * self.grid_height * self.grid_width, 1])  # (L * H * W, 12)
+        grid = grid.reshape(1, self.grid_guidance, self.grid_height, self.grid_width, -1)  # (1, L, H, W, 12)
+        grid = grid.permute(0, 4, 1, 2, 3)  # (1, 12, L, H, W)
+        return grid
+
+    def tv_loss(self):
+        """Computes and returns the total variation loss on the bilateral grids."""
+        return total_variation_loss(self.grids)
+
+    def forward(self, grid_xy, rgb, idx=None):
+        """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input.
+        For the 2-D, 3-D, and 4-D cases, please refer to `slice`.
+        For the 5-D case, `idx` will be unused and the first dimension of `xy` should be
+        equal to the number of bilateral grids. Then this function becomes PyTorch's
+        [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+
+        Args:
+            grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$.
+            rgb (torch.Tensor): The RGB values in the range of $[0,1]$.
+            idx (torch.Tensor): The bilateral grid indices.
+
+        Returns:
+            Sliced affine matrices of shape $(..., 3, 4)$.
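+
+        Example (a minimal sketch of the 2-D case; the shapes and sizes are illustrative,
+        not taken from the original codebase):
+        ```python
+        bil_grids = BilateralGrid(num=10)      # grids for 10 views
+        xy = torch.rand(1024, 2)               # pixel coordinates in [0, 1]
+        rgb = torch.rand(1024, 3)              # pixel colors in [0, 1]
+        idx = torch.randint(0, 10, (1024,))    # per-pixel view indices
+        affine_mats = bil_grids(xy, rgb, idx)  # -> (1024, 3, 4)
+        ```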
+ """ + + grids = self.grids + input_ndims = len(grid_xy.shape) + assert len(rgb.shape) == input_ndims + + if input_ndims > 1 and input_ndims < 5: + # Convert input into 5D + for i in range(5 - input_ndims): + grid_xy = grid_xy.unsqueeze(1) + rgb = rgb.unsqueeze(1) + assert idx is not None + elif input_ndims != 5: + raise ValueError("Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs") + + grids = self.grids + if idx is not None: + grids = grids[idx] + assert grids.shape[0] == grid_xy.shape[0] + + # Generate slicing coordinates. + grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1]. + grid_z = self.rgb2gray(rgb) + + # print(grid_xy.shape, grid_z.shape) + # exit() + grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3) + + affine_mats = F.grid_sample( + grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border" + ) # (N, 12, m, h, w) + affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12) + affine_mats = affine_mats.reshape(*affine_mats.shape[:-1], 3, 4) # (N, m, h, w, 3, 4) + + for _ in range(5 - input_ndims): + affine_mats = affine_mats.squeeze(1) + + return affine_mats + + +def slice4d(bil_grid4d, xyz, rgb): + """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`. + + Args: + bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid. + xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. + rgb (torch.Tensor): The RGB values with shape $(..., 3)$. + + Returns: + A dictionary with keys and values as follows: + ``` + { + "rgb": Transformed radiance RGB colors. Shape: (..., 3), + "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4) + } + ``` + """ + + affine_mats = bil_grid4d(xyz, rgb) + rgb = color_affine_transform(affine_mats, rgb) + + return {"rgb": rgb, "rgb_affine_mats": affine_mats} + + +class _ScaledTanh(nn.Module): + def __init__(self, s=2.0): + super().__init__() + self.scaler = s + + def forward(self, x): + return torch.tanh(self.scaler * x) + + +class BilateralGridCP4D(nn.Module): + """Class for low-rank 4D bilateral grids.""" + + def __init__( + self, + grid_X=16, + grid_Y=16, + grid_Z=16, + grid_W=8, + rank=5, + learn_gray=True, + gray_mlp_width=8, + gray_mlp_depth=2, + init_noise_scale=1e-6, + bound=2.0, + ): + """ + Args: + grid_X (int): Defines grid width. + grid_Y (int): Defines grid height. + grid_Z (int): Defines grid depth. + grid_W (int): Defines grid guidance dimension. + rank (int): Rank of the 4D bilateral grid. + learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances. + gray_mlp_width (int): The MLP width for learnable guidance. + gray_mlp_depth (int): The number of MLP layers for learnable guidance. + init_noise_scale (float): The noise scale of the initialized factors. + bound (float): The bound of the xyz coordinates. + """ + super(BilateralGridCP4D, self).__init__() + + self.grid_X = grid_X + """Grid width. Type: int.""" + self.grid_Y = grid_Y + """Grid height. Type: int.""" + self.grid_Z = grid_Z + """Grid depth. Type: int.""" + self.grid_W = grid_W + """Grid guidance dimension. Type: int.""" + self.rank = rank + """Rank of the 4D bilateral grid. Type: int.""" + self.learn_gray = learn_gray + """Flags of learnable guidance is used. Type: bool.""" + self.gray_mlp_width = gray_mlp_width + """The MLP width for learnable guidance. Type: int.""" + self.gray_mlp_depth = gray_mlp_depth + """The MLP depth for learnable guidance. 
Type: int.""" + self.init_noise_scale = init_noise_scale + """The noise scale of the initialized factors. Type: float.""" + self.bound = bound + """The bound of the xyz coordinates. Type: float.""" + + self._init_cp_factors_parafac() + + self.rgb2gray = None + """ A function that converts RGB to gray-scale guidances in $[-1, 1]$. + If `learn_gray` is True, this will be an MLP network.""" + + if self.learn_gray: + + def rgb2gray_mlp_linear(layer): + return nn.Linear(self.gray_mlp_width, self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1) + + def rgb2gray_mlp_actfn(_): + return nn.ReLU(inplace=True) + + self.rgb2gray = nn.Sequential( + *( + [nn.Linear(3, self.gray_mlp_width)] + + [ + nn_module(layer) + for layer in range(1, self.gray_mlp_depth) + for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear] + ] + + [_ScaledTanh(2.0)] + ) + ) + else: + # Weights of BT601/BT470 RGB-to-gray. + self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])) + self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 + + def _init_identity_grid(self): + grid = torch.tensor( + [ + 1.0, + 0, + 0, + 0, + 0, + 1.0, + 0, + 0, + 0, + 0, + 1.0, + 0, + ] + ).float() + grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1]) + grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1) + grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X) + return grid + + def _init_cp_factors_parafac(self): + # Initialize identity grids. + init_grids = self._init_identity_grid() + # Random noises are added to avoid singularity. + init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids + from tensorly.decomposition import parafac + + # Initialize grid CP factors + _, facs = parafac(init_grids.clone().detach(), rank=self.rank) + + self.num_facs = len(facs) + + self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False) + self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank) + + for i in range(1, self.num_facs): + fac = facs[i].T # (rank, grid_size) + fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1) + self.register_buffer(f"fac_{i}_init", fac) + + fac_resid = torch.zeros_like(fac) + self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid)) + + def tv_loss(self): + """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids.""" + + total_loss = 0 + for i in range(1, self.num_facs): + fac = self.get_parameter(f"fac_{i}") + total_loss += total_variation_loss(fac) + + return total_loss + + def forward(self, xyz, rgb): + """Low-rank 4D bilateral grid slicing. + + Args: + xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. + rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$. + + Returns: + Sliced affine matrices with shape $(..., 3, 4)$. 
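+
+        Example (a minimal sketch; the shapes and values are illustrative, not taken from
+        the original codebase):
+        ```python
+        bil_grid4d = BilateralGridCP4D(bound=2.0)
+        xyz = (torch.rand(1024, 3) * 2 - 1) * 2.0  # points within [-bound, bound]
+        rgb = torch.rand(1024, 3)                  # radiance colors in [0, 1]
+        affine_mats = bil_grid4d(xyz, rgb)         # -> (1024, 3, 4)
+        ```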
+ """ + sh_ = xyz.shape + xyz = xyz.reshape(-1, 3) # flatten (N, 3) + rgb = rgb.reshape(-1, 3) # flatten (N, 3) + + xyz = xyz / self.bound + assert self.rgb2gray is not None + gray = self.rgb2gray(rgb) + xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4) + xyzw = xyzw.transpose(0, 1) # (4, N) + coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2) + coords = coords.unsqueeze(1) # (4, 1, N, 2) + + coef = 1.0 + for i in range(1, self.num_facs): + fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init") + coef = coef * F.grid_sample( + fac, coords[[i - 1]], align_corners=True, padding_mode="border" + ) # [1, rank, 1, N] + coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore + mat = self.fac_0(coef) + return mat.reshape(*sh_[:-1], 3, 4) diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 37856e3891..05587c642b 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -40,6 +40,7 @@ from nerfstudio.data.scene_box import OrientedBox from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation from nerfstudio.engine.optimizers import Optimizers +from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils.colors import get_color from nerfstudio.utils.misc import torch_compile @@ -184,6 +185,12 @@ class SplatfactoModelConfig(ModelConfig): """ camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off")) """Config of the camera optimizer to use""" + use_bilateral_grid: bool = False + """If True, use bilateral grid to handle the ISP changes in the image space. This technique was introduced in the paper 'Bilateral Guided Radiance Field Processing' (https://bilarfpro.github.io/).""" + grid_shape: Tuple[int, int, int] = (16, 16, 8) + """Shape of the bilateral grid (X, Y, W)""" + color_corrected_metrics: bool = False + """If True, apply color correction to the rendered images before computing the metrics.""" class SplatfactoModel(Model): @@ -271,6 +278,13 @@ def populate_modules(self): ) # This color is the same as the default background color in Viser. This would only affect the background color when rendering. 
else: self.background_color = get_color(self.config.background_color) + if self.config.use_bilateral_grid: + self.bil_grids = BilateralGrid( + num=self.num_train_data, + grid_X=self.config.grid_shape[0], + grid_Y=self.config.grid_shape[1], + grid_W=self.config.grid_shape[2], + ) @property def colors(self): @@ -647,6 +661,8 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: Mapping of different parameter groups """ gps = self.get_gaussian_param_groups() + if self.config.use_bilateral_grid: + gps["bilateral_grid"] = list(self.bil_grids.parameters()) self.camera_optimizer.get_param_groups(param_groups=gps) return gps @@ -686,6 +702,23 @@ def _get_background_color(self): raise ValueError(f"Unknown background color {self.config.background_color}") return background + def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int) -> torch.Tensor: + # make xy grid + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, 1.0, H, device=self.device), + torch.linspace(0, 1.0, W, device=self.device), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + + out = slice( + bil_grids=self.bil_grids, + rgb=rgb, + xy=grid_xy, + grid_idx=torch.tensor(cam_idx, device=self.device, dtype=torch.long), + ) + return out["rgb"] + def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: """Takes in a camera and returns a dictionary of outputs. @@ -789,6 +822,11 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: rgb = render[:, ..., :3] + (1 - alpha) * background rgb = torch.clamp(rgb, 0.0, 1.0) + # apply bilateral grid + if self.config.use_bilateral_grid and self.training: + if camera.metadata is not None and "cam_idx" in camera.metadata: + rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W) + if render_mode == "RGB+ED": depth_im = render[:, ..., 3:4] depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0) @@ -839,7 +877,11 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"]) metrics_dict = {} predicted_rgb = outputs["rgb"] + metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb) + if self.config.color_corrected_metrics: + cc_rgb = color_correct(predicted_rgb, gt_rgb) + metrics_dict["cc_psnr"] = self.psnr(cc_rgb, gt_rgb) metrics_dict["gaussian_count"] = self.num_points @@ -890,6 +932,8 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te if self.training: # Add loss from camera optimizer self.camera_optimizer.get_loss_dict(loss_dict) + if self.config.use_bilateral_grid: + loss_dict["tv_loss"] = 10 * total_variation_loss(self.bil_grids.grids) return loss_dict @@ -922,9 +966,14 @@ def get_image_metrics_and_images( """ gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"]) predicted_rgb = outputs["rgb"] + cc_rgb = None combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) + if self.config.color_corrected_metrics: + cc_rgb = color_correct(predicted_rgb, gt_rgb) + cc_rgb = torch.moveaxis(cc_rgb, -1, 0)[None, ...] + # Switch images from [H, W, C] to [1, C, H, W] for metrics computations gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] 
@@ -937,6 +986,15 @@ def get_image_metrics_and_images(
         metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)}  # type: ignore
         metrics_dict["lpips"] = float(lpips)

+        if self.config.color_corrected_metrics:
+            assert cc_rgb is not None
+            cc_psnr = self.psnr(gt_rgb, cc_rgb)
+            cc_ssim = self.ssim(gt_rgb, cc_rgb)
+            cc_lpips = self.lpips(gt_rgb, cc_rgb)
+            metrics_dict["cc_psnr"] = float(cc_psnr.item())
+            metrics_dict["cc_ssim"] = float(cc_ssim)
+            metrics_dict["cc_lpips"] = float(cc_lpips)
+
         images_dict = {"img": combined_rgb}

         return metrics_dict, images_dict
diff --git a/pyproject.toml b/pyproject.toml
index 397433cba1..fa7d2c319c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,8 @@ dependencies = [
     "pytorch-msssim",
    "pathos",
    "packaging",
-    "fpsample"
+    "fpsample",
+    "tensorly"
 ]

 [project.urls]
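
For reference, a minimal sketch of enabling the new options through the Python config API
(illustrative only, not part of the patch; the field names come from the
`SplatfactoModelConfig` additions above):

```python
from nerfstudio.models.splatfacto import SplatfactoModelConfig

model_config = SplatfactoModelConfig(
    use_bilateral_grid=True,       # per-view 3D bilateral grids, applied to renders during training
    grid_shape=(16, 16, 8),        # (X, Y, W): spatial resolution and guidance dimension
    color_corrected_metrics=True,  # additionally report cc_psnr / cc_ssim / cc_lpips at eval
)
```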