From 7ed261092f27d6fbf4ba371a1b47e1b63c53daff Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sat, 28 Sep 2024 10:10:48 +0200 Subject: [PATCH 1/7] implement ply saving --- examples/simple_trainer.py | 43 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3e544201b..67c293f08 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -2,6 +2,7 @@ import math import os import time +import open3d as o3d from dataclasses import dataclass, field from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -181,6 +182,45 @@ def adjust_steps(self, factor: float): assert_never(strategy) +def save_ply(splats: torch.nn.ParameterDict, dir: str): + # Convert all tensors to numpy arrays in one go + print(f"Saving ply to {dir}") + numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} + + means = numpy_data["means"] + scales = numpy_data["scales"] + quats = numpy_data["quats"] + opacities = numpy_data["opacities"] + sh0 = numpy_data["sh0"].reshape(means.shape[0], -1) + shN = numpy_data["shN"].reshape(means.shape[0], -1) + + ply_data = { + "positions": o3d.core.Tensor(means, dtype=o3d.core.Dtype.Float32), + "normals": o3d.core.Tensor(np.zeros_like(means), dtype=o3d.core.Dtype.Float32), + "opacity": o3d.core.Tensor( + opacities.reshape(-1, 1), dtype=o3d.core.Dtype.Float32 + ), + } + + # Add sh0 and shN data + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + ply_data[f"{prefix}_{j}"] = o3d.core.Tensor( + data[:, j : j + 1], dtype=o3d.core.Dtype.Float32 + ) + + # Add scales and quats data + for name, data in [("scale", scales), ("rot", quats)]: + for i in range(data.shape[1]): + ply_data[f"{name}_{i}"] = o3d.core.Tensor( + data[:, i : i + 1], dtype=o3d.core.Dtype.Float32 + ) + + pcd = o3d.t.geometry.PointCloud(ply_data) + o3d.t.io.write_point_cloud(str(dir), pcd) + + def create_splats_with_optimizers( parser: Parser, init_type: str = "sfm", @@ -283,6 +323,8 @@ def __init__( os.makedirs(self.stats_dir, exist_ok=True) self.render_dir = f"{cfg.result_dir}/renders" os.makedirs(self.render_dir, exist_ok=True) + self.ply_dir = f"{cfg.result_dir}/ply" + os.makedirs(self.ply_dir, exist_ok=True) # Tensorboard self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") @@ -723,6 +765,7 @@ def train(self): torch.save( data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) + save_ply(self.splats, f"{self.ply_dir}/point_cloud.ply") if isinstance(self.cfg.strategy, DefaultStrategy): self.cfg.strategy.step_post_backward( From cba1d3ea7defe067a31a5a3b96a637af5e6ddf53 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sat, 28 Sep 2024 21:26:53 +0200 Subject: [PATCH 2/7] fix colors --- examples/simple_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 67c293f08..73978234e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -191,8 +191,8 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str): scales = numpy_data["scales"] quats = numpy_data["quats"] opacities = numpy_data["opacities"] - sh0 = numpy_data["sh0"].reshape(means.shape[0], -1) - shN = numpy_data["shN"].reshape(means.shape[0], -1) + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() ply_data = { "positions": o3d.core.Tensor(means, dtype=o3d.core.Dtype.Float32), From d30a1a9bd74796175f92e94664f5da750406b00a Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sat, 28 Sep 2024 21:36:07 +0200 Subject: [PATCH 3/7] Add own flag for saveing ply files --- examples/simple_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 73978234e..f52b886b6 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -85,6 +85,8 @@ class Config: eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model as ply + ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -165,6 +167,7 @@ class Config: def adjust_steps(self, factor: float): self.eval_steps = [int(i * factor) for i in self.eval_steps] self.save_steps = [int(i * factor) for i in self.save_steps] + self.ply_steps = [int(i * factor) for i in self.ply_steps] self.max_steps = int(self.max_steps * factor) self.sh_degree_interval = int(self.sh_degree_interval * factor) @@ -765,7 +768,8 @@ def train(self): torch.save( data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) - save_ply(self.splats, f"{self.ply_dir}/point_cloud.ply") + if step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1: + save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply") if isinstance(self.cfg.strategy, DefaultStrategy): self.cfg.strategy.step_post_backward( From 9a5336af0fbc987be58ffbb2d7a1655258a0011f Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Mon, 7 Oct 2024 09:05:13 +0200 Subject: [PATCH 4/7] fix appearance embeddings --- examples/simple_trainer.py | 47 ++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index f52b886b6..3690fb87b 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -185,7 +185,7 @@ def adjust_steps(self, factor: float): assert_never(strategy) -def save_ply(splats: torch.nn.ParameterDict, dir: str): +def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None): # Convert all tensors to numpy arrays in one go print(f"Saving ply to {dir}") numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} @@ -194,9 +194,6 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str): scales = numpy_data["scales"] quats = numpy_data["quats"] opacities = numpy_data["opacities"] - sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() - shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() - ply_data = { "positions": o3d.core.Tensor(means, dtype=o3d.core.Dtype.Float32), "normals": o3d.core.Tensor(np.zeros_like(means), dtype=o3d.core.Dtype.Float32), @@ -205,13 +202,25 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str): ), } - # Add sh0 and shN data - for i, data in enumerate([sh0, shN]): - prefix = "f_dc" if i == 0 else "f_rest" - for j in range(data.shape[1]): - ply_data[f"{prefix}_{j}"] = o3d.core.Tensor( - data[:, j : j + 1], dtype=o3d.core.Dtype.Float32 + if colors is not None: + color = colors.detach().cpu().numpy().copy() # + for j in range(color.shape[1]): + # Needs to be converted to shs as that's what all viewers take. + ply_data[f"f_dc_{j}"] = o3d.core.Tensor( + (color[:, j : j + 1] - 0.5) / 0.2820947917738781, + dtype=o3d.core.Dtype.Float32, ) + else: + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() + + # Add sh0 and shN data + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + ply_data[f"{prefix}_{j}"] = o3d.core.Tensor( + data[:, j : j + 1], dtype=o3d.core.Dtype.Float32 + ) # Add scales and quats data for name, data in [("scale", scales), ("rot", quats)]: @@ -221,7 +230,9 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str): ) pcd = o3d.t.geometry.PointCloud(ply_data) - o3d.t.io.write_point_cloud(str(dir), pcd) + + success = o3d.t.io.write_point_cloud(dir, pcd) + assert success, "Could not save ply file." def create_splats_with_optimizers( @@ -769,7 +780,19 @@ def train(self): data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) if step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1: - save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply") + rgb = None + if self.cfg.app_opt: + # eval at origin to bake the appeareance into the colors + rgb = self.app_module( + features=self.splats["features"], + embed_ids=None, + dirs=torch.zeros_like(self.splats["means"][None, :, :]), + sh_degree=sh_degree_to_use, + ) + rgb = rgb + self.splats["colors"] + rgb = torch.sigmoid(rgb).squeeze(0) + + save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply", rgb) if isinstance(self.cfg.strategy, DefaultStrategy): self.cfg.strategy.step_post_backward( From 36e3df32a997300f384990722e368010ba08e09a Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Tue, 8 Oct 2024 00:04:53 +0200 Subject: [PATCH 5/7] remove open3d --- examples/simple_trainer.py | 90 ++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3690fb87b..354763a13 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -2,7 +2,7 @@ import math import os import time -import open3d as o3d +import struct from dataclasses import dataclass, field from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -194,45 +194,59 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No scales = numpy_data["scales"] quats = numpy_data["quats"] opacities = numpy_data["opacities"] - ply_data = { - "positions": o3d.core.Tensor(means, dtype=o3d.core.Dtype.Float32), - "normals": o3d.core.Tensor(np.zeros_like(means), dtype=o3d.core.Dtype.Float32), - "opacity": o3d.core.Tensor( - opacities.reshape(-1, 1), dtype=o3d.core.Dtype.Float32 - ), - } - - if colors is not None: - color = colors.detach().cpu().numpy().copy() # - for j in range(color.shape[1]): - # Needs to be converted to shs as that's what all viewers take. - ply_data[f"f_dc_{j}"] = o3d.core.Tensor( - (color[:, j : j + 1] - 0.5) / 0.2820947917738781, - dtype=o3d.core.Dtype.Float32, - ) - else: - sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() - shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() - - # Add sh0 and shN data - for i, data in enumerate([sh0, shN]): - prefix = "f_dc" if i == 0 else "f_rest" - for j in range(data.shape[1]): - ply_data[f"{prefix}_{j}"] = o3d.core.Tensor( - data[:, j : j + 1], dtype=o3d.core.Dtype.Float32 - ) - - # Add scales and quats data - for name, data in [("scale", scales), ("rot", quats)]: - for i in range(data.shape[1]): - ply_data[f"{name}_{i}"] = o3d.core.Tensor( - data[:, i : i + 1], dtype=o3d.core.Dtype.Float32 - ) - pcd = o3d.t.geometry.PointCloud(ply_data) + num_points = means.shape[0] + + with open(dir, "wb") as f: + # Write PLY header + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(f"element vertex {num_points}\n".encode()) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + f.write(b"property float nx\n") + f.write(b"property float ny\n") + f.write(b"property float nz\n") + f.write(b"property float opacity\n") + + if colors is not None: + for j in range(colors.shape[1]): + f.write(f"property float f_dc_{j}\n".encode()) + else: + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1) + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1) + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + f.write(f"property float {prefix}_{j}\n".encode()) + + for i in range(scales.shape[1]): + f.write(f"property float scale_{i}\n".encode()) + for i in range(quats.shape[1]): + f.write(f"property float rot_{i}\n".encode()) + + f.write(b"end_header\n") + + # Write vertex data + for i in range(num_points): + f.write(struct.pack(" Date: Wed, 16 Oct 2024 09:49:47 +0200 Subject: [PATCH 6/7] align order with INRIA ply --- examples/simple_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 354763a13..b44f1a72a 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -208,7 +208,6 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No f.write(b"property float nx\n") f.write(b"property float ny\n") f.write(b"property float nz\n") - f.write(b"property float opacity\n") if colors is not None: for j in range(colors.shape[1]): @@ -221,6 +220,8 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No for j in range(data.shape[1]): f.write(f"property float {prefix}_{j}\n".encode()) + f.write(b"property float opacity\n") + for i in range(scales.shape[1]): f.write(f"property float scale_{i}\n".encode()) for i in range(quats.shape[1]): @@ -232,7 +233,6 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No for i in range(num_points): f.write(struct.pack(" Date: Wed, 16 Oct 2024 10:03:46 +0200 Subject: [PATCH 7/7] filter Nan and infs --- examples/simple_trainer.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index b44f1a72a..52e7a233e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -195,6 +195,33 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No quats = numpy_data["quats"] opacities = numpy_data["opacities"] + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1) + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1) + + # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays + invalid_mask = ( + np.isnan(means).any(axis=1) + | np.isinf(means).any(axis=1) + | np.isnan(scales).any(axis=1) + | np.isinf(scales).any(axis=1) + | np.isnan(quats).any(axis=1) + | np.isinf(quats).any(axis=1) + | np.isnan(opacities).any(axis=0) + | np.isinf(opacities).any(axis=0) + | np.isnan(sh0).any(axis=1) + | np.isinf(sh0).any(axis=1) + | np.isnan(shN).any(axis=1) + | np.isinf(shN).any(axis=1) + ) + + # Filter out rows with NaNs or Infs from all data arrays + means = means[~invalid_mask] + scales = scales[~invalid_mask] + quats = quats[~invalid_mask] + opacities = opacities[~invalid_mask] + sh0 = sh0[~invalid_mask] + shN = shN[~invalid_mask] + num_points = means.shape[0] with open(dir, "wb") as f: @@ -213,8 +240,6 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No for j in range(colors.shape[1]): f.write(f"property float f_dc_{j}\n".encode()) else: - sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1) - shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1) for i, data in enumerate([sh0, shN]): prefix = "f_dc" if i == 0 else "f_rest" for j in range(data.shape[1]):