diff --git a/examples/image_fitting.py b/examples/image_fitting.py index 155a6dee5..51a07beb6 100644 --- a/examples/image_fitting.py +++ b/examples/image_fitting.py @@ -2,7 +2,7 @@ import os import time from pathlib import Path -from typing import Optional, Literal +from typing import Literal, Optional import numpy as np import torch @@ -104,7 +104,7 @@ def train( for iter in range(iterations): start = time.time() - renders, _ = rasterize_fnc( + renders = rasterize_fnc( self.means, self.quats / self.quats.norm(dim=-1, keepdim=True), self.scales, @@ -115,8 +115,8 @@ def train( self.W, self.H, packed=False, - ) - out_img = renders[0].squeeze(0) + )[0] + out_img = renders[0] torch.cuda.synchronize() times[0] += time.time() - start loss = mse_loss(out_img, self.gt_image) @@ -130,8 +130,6 @@ def train( if save_imgs and iter % 5 == 0: frames.append((out_img.detach().cpu().numpy() * 255).astype(np.uint8)) - # break - if save_imgs: # save them as a gif with PIL frames = [Image.fromarray(frame) for frame in frames] diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 0393f7c19..472573f21 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -1082,7 +1082,7 @@ def rasterization_2dgs( **render_normals**: The rendered normals. [C, height, width, 3]. - **render_normals_from_depth: surface normal from depth. [C, height, width, 3] + **render_normals_from_depth**: surface normal from depth. [C, height, width, 3] **render_distort**: The rendered distortions. [C, height, width, 1]. L1 version, different from L2 version in 2DGS paper. @@ -1278,6 +1278,8 @@ def rasterization_2dgs( render_normals = render_normals @ torch.linalg.inv(viewmats)[0, :3, :3].T + # Note(ruilong): To align with the rasterization function, this should return + # (render_colors, ..., meta) instead of (render_colors, render_alphas), meta. return ( render_colors, render_alphas,