diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 73978234..f52b886b 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -85,6 +85,8 @@ class Config: eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model as ply + ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -165,6 +167,7 @@ class Config: def adjust_steps(self, factor: float): self.eval_steps = [int(i * factor) for i in self.eval_steps] self.save_steps = [int(i * factor) for i in self.save_steps] + self.ply_steps = [int(i * factor) for i in self.ply_steps] self.max_steps = int(self.max_steps * factor) self.sh_degree_interval = int(self.sh_degree_interval * factor) @@ -765,7 +768,8 @@ def train(self): torch.save( data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) - save_ply(self.splats, f"{self.ply_dir}/point_cloud.ply") + if step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1: + save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply") if isinstance(self.cfg.strategy, DefaultStrategy): self.cfg.strategy.step_post_backward(