Skip to content

Commit

Permalink
bugfix: handle multiple camera types in a single batch
Browse files Browse the repository at this point in the history
  • Loading branch information
decrispell committed Sep 5, 2024
1 parent 69ae4ab commit c2ad294
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,15 +778,15 @@ def _compute_rays_for_vr180(

return vr180_origins, directions_stack

for cam in cam_types:
if CameraType.PERSPECTIVE.value in cam_types:
for cam_type in cam_types:
if CameraType.PERSPECTIVE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)
directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
directions_stack[..., 2][mask] = -1.0

elif CameraType.FISHEYE.value in cam_types:
elif CameraType.FISHEYE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

Expand All @@ -803,7 +803,7 @@ def _compute_rays_for_vr180(
).float()
directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()

elif CameraType.EQUIRECTANGULAR.value in cam_types:
elif CameraType.EQUIRECTANGULAR.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

Expand All @@ -816,22 +816,22 @@ def _compute_rays_for_vr180(
directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()

elif CameraType.OMNIDIRECTIONALSTEREO_L.value in cam_types:
elif CameraType.OMNIDIRECTIONALSTEREO_L.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("left")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

elif CameraType.OMNIDIRECTIONALSTEREO_R.value in cam_types:
elif CameraType.OMNIDIRECTIONALSTEREO_R.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("right")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

elif CameraType.VR180_L.value in cam_types:
elif CameraType.VR180_L.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("left")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins

elif CameraType.VR180_R.value in cam_types:
elif CameraType.VR180_R.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("right")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins
Expand Down Expand Up @@ -880,7 +880,7 @@ def _compute_rays_for_vr180(
directions_stack[coord_mask] = camera_utils.fisheye624_unproject(masked_coords, camera_params)

else:
raise ValueError(f"Camera type {cam} not supported.")
raise ValueError(f"Camera type {cam_type} not supported.")

assert directions_stack.shape == (3,) + num_rays_shape + (3,)

Expand Down

0 comments on commit c2ad294

Please sign in to comment.