From 795161945b37747709d4da965b226a19fdf87d3f Mon Sep 17 00:00:00 2001 From: Jenia Golbstein Date: Sat, 23 Nov 2024 22:27:10 +0300 Subject: [PATCH] use mask in 2dgs (#497) --- examples/simple_trainer_2dgs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/simple_trainer_2dgs.py b/examples/simple_trainer_2dgs.py index be2e4e93..10900858 100644 --- a/examples/simple_trainer_2dgs.py +++ b/examples/simple_trainer_2dgs.py @@ -577,6 +577,10 @@ def train(self): step=step, info=info, ) + masks = data["mask"].to(device) if "mask" in data else None + if masks is not None: + pixels = pixels * masks[..., None] + colors = colors * masks[..., None] # loss l1loss = F.l1_loss(colors, pixels)