Skip to content

Commit

Permalink
image fitting cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Sep 8, 2024
1 parent 4630230 commit 61ce019
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 4 additions & 6 deletions examples/image_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import time
from pathlib import Path
from typing import Optional, Literal
from typing import Literal, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -104,7 +104,7 @@ def train(
for iter in range(iterations):
start = time.time()

renders, _ = rasterize_fnc(
renders = rasterize_fnc(
self.means,
self.quats / self.quats.norm(dim=-1, keepdim=True),
self.scales,
Expand All @@ -115,8 +115,8 @@ def train(
self.W,
self.H,
packed=False,
)
out_img = renders[0].squeeze(0)
)[0]
out_img = renders[0]
torch.cuda.synchronize()
times[0] += time.time() - start
loss = mse_loss(out_img, self.gt_image)
Expand All @@ -130,8 +130,6 @@ def train(

if save_imgs and iter % 5 == 0:
frames.append((out_img.detach().cpu().numpy() * 255).astype(np.uint8))
# break

if save_imgs:
# save them as a gif with PIL
frames = [Image.fromarray(frame) for frame in frames]
Expand Down
4 changes: 3 additions & 1 deletion gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ def rasterization_2dgs(
**render_normals**: The rendered normals. [C, height, width, 3].
**render_normals_from_depth: surface normal from depth. [C, height, width, 3]
**render_normals_from_depth**: surface normal from depth. [C, height, width, 3]
**render_distort**: The rendered distortions. [C, height, width, 1].
L1 version, different from L2 version in 2DGS paper.
Expand Down Expand Up @@ -1278,6 +1278,8 @@ def rasterization_2dgs(

render_normals = render_normals @ torch.linalg.inv(viewmats)[0, :3, :3].T

# Note(ruilong): To align with the rasterization function, this should return
# (render_colors, ..., meta) instead of (render_colors, render_alphas), meta.
return (
render_colors,
render_alphas,
Expand Down

0 comments on commit 61ce019

Please sign in to comment.