From b414397493337b70fc8f4454287088ae67b7919e Mon Sep 17 00:00:00 2001 From: Zhensheng Yuan <564361+yzslab@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:00:28 +0800 Subject: [PATCH] Fix merging and normalized averaged embedding fusing --- utils/fuse_appearance_embeddings_into_shs_dc.py | 7 ++++++- utils/merge_partitions_v2.py | 8 ++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/utils/fuse_appearance_embeddings_into_shs_dc.py b/utils/fuse_appearance_embeddings_into_shs_dc.py index 3dda632..ce1e0fe 100644 --- a/utils/fuse_appearance_embeddings_into_shs_dc.py +++ b/utils/fuse_appearance_embeddings_into_shs_dc.py @@ -180,10 +180,15 @@ def average_embedding_fusing( # merge `n_average_cameras` embedding to a single embedding final_appearance_embeddings = torch.sum(weighted_appearance_embeddings, dim=1) + gaussian_appearance_features = gaussian_model.get_appearance_features() + if renderer.model.config.normalize: + final_appearance_embeddings = torch.nn.functional.normalize(final_appearance_embeddings, dim=-1) + gaussian_appearance_features = torch.nn.functional.normalize(gaussian_appearance_features, dim=-1) + # embedding network forward, output rgb_offset embedding_network = renderer.model.network input_tensor_list = [ - gaussian_model.get_appearance_features(), + gaussian_appearance_features, final_appearance_embeddings, ] diff --git a/utils/merge_partitions_v2.py b/utils/merge_partitions_v2.py index 4e28a02..34850c4 100644 --- a/utils/merge_partitions_v2.py +++ b/utils/merge_partitions_v2.py @@ -167,10 +167,6 @@ def main(): orientation_transformation, ) - if isinstance(gaussian_model, MipSplattingModelMixin): - t.set_postfix_str("Fusing MipSplatting filters...") - fuse_mip_filters(gaussian_model) - if isinstance(gaussian_model, AppearanceFeatureGaussianModel): with open(os.path.join( os.path.dirname(os.path.dirname(ckpt_file)), @@ -209,6 +205,10 @@ def main(): image_name_to_camera=image_name_to_camera, ) + if isinstance(gaussian_model, MipSplattingModelMixin): + t.set_postfix_str("Fusing MipSplatting filters...") + fuse_mip_filters(gaussian_model) + if args.preprocess: update_ckpt(ckpt, {k: gaussian_model.get_property(k) for k in MERGABLE_PROPERTY_NAMES}, gaussian_model.max_sh_degree) torch.save(ckpt, os.path.join(