
Commit

… into main
Anthony-Tafoya committed Sep 11, 2024
2 parents be44674 + e7b7dc9 commit e16135e
Showing 6 changed files with 39 additions and 22 deletions.
18 changes: 9 additions & 9 deletions nerfstudio/cameras/cameras.py
@@ -778,15 +778,15 @@ def _compute_rays_for_vr180(

return vr180_origins, directions_stack

- for cam in cam_types:
- if CameraType.PERSPECTIVE.value in cam_types:
+ for cam_type in cam_types:
+ if CameraType.PERSPECTIVE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)
directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
directions_stack[..., 2][mask] = -1.0

- elif CameraType.FISHEYE.value in cam_types:
+ elif CameraType.FISHEYE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

@@ -803,7 +803,7 @@ def _compute_rays_for_vr180(
).float()
directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()

- elif CameraType.EQUIRECTANGULAR.value in cam_types:
+ elif CameraType.EQUIRECTANGULAR.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

@@ -816,22 +816,22 @@ def _compute_rays_for_vr180(
directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()

- elif CameraType.OMNIDIRECTIONALSTEREO_L.value in cam_types:
+ elif CameraType.OMNIDIRECTIONALSTEREO_L.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("left")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

- elif CameraType.OMNIDIRECTIONALSTEREO_R.value in cam_types:
+ elif CameraType.OMNIDIRECTIONALSTEREO_R.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("right")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

- elif CameraType.VR180_L.value in cam_types:
+ elif CameraType.VR180_L.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("left")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins

- elif CameraType.VR180_R.value in cam_types:
+ elif CameraType.VR180_R.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("right")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins
@@ -880,7 +880,7 @@ def _compute_rays_for_vr180(
directions_stack[coord_mask] = camera_utils.fisheye624_unproject(masked_coords, camera_params)

else:
- raise ValueError(f"Camera type {cam} not supported.")
+ raise ValueError(f"Camera type {cam_type} not supported.")

assert directions_stack.shape == (3,) + num_rays_shape + (3,)

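Note: the substantive fix above is replacing membership tests against the whole cam_types collection with equality checks against the loop variable. With the old code, a ray batch containing several camera types always fell into the first branch whose type was merely present in the batch. A self-contained sketch of the difference (plain Python with a stand-in enum, not the real nerfstudio CameraType):

from enum import Enum

class CameraType(Enum):  # stand-in for nerfstudio.cameras.cameras.CameraType
    PERSPECTIVE = 1
    FISHEYE = 2

cam_types = [CameraType.PERSPECTIVE.value, CameraType.FISHEYE.value]

# Old logic: the membership test is True on every pass, so the fisheye
# branch is unreachable whenever a perspective camera is in the batch.
for cam in cam_types:
    if CameraType.PERSPECTIVE.value in cam_types:
        print("old:", cam, "-> perspective branch")
    elif CameraType.FISHEYE.value in cam_types:
        print("old:", cam, "-> fisheye branch")

# New logic: each camera type dispatches to its own branch.
for cam_type in cam_types:
    if CameraType.PERSPECTIVE.value == cam_type:
        print("new:", cam_type, "-> perspective branch")
    elif CameraType.FISHEYE.value == cam_type:
        print("new:", cam_type, "-> fisheye branch")
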
8 changes: 6 additions & 2 deletions nerfstudio/engine/trainer.py
@@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig):
"""Optionally log gradients during training"""
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
+ start_paused: bool = False
+ """Whether to start the training in a paused state."""


class Trainer:
@@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
self.device += f":{local_rank}"
self.mixed_precision: bool = self.config.mixed_precision
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
- self.training_state: Literal["training", "paused", "completed"] = "training"
+ self.training_state: Literal["training", "paused", "completed"] = (
+ "paused" if self.config.start_paused else "training"
+ )
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)

@@ -361,7 +365,7 @@ def _init_viewer_state(self) -> None:
assert self.viewer_state and self.pipeline.datamanager.train_dataset
self.viewer_state.init_scene(
train_dataset=self.pipeline.datamanager.train_dataset,
- train_state="training",
+ train_state=self.training_state,
eval_dataset=self.pipeline.datamanager.eval_dataset,
)

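Note: a minimal stand-in (plain dataclasses, not the real TrainerConfig/Trainer) showing the behaviour the new start_paused field adds — the trainer's initial training_state now follows the config instead of always being "training":

from dataclasses import dataclass
from typing import Literal

@dataclass
class TrainerConfig:              # stand-in; the real config has many more fields
    start_paused: bool = False    # new field from this commit

class Trainer:                    # stand-in for nerfstudio.engine.trainer.Trainer
    def __init__(self, config: TrainerConfig) -> None:
        self.config = config
        # mirrors the change above: honor start_paused at construction time
        self.training_state: Literal["training", "paused", "completed"] = (
            "paused" if self.config.start_paused else "training"
        )

print(Trainer(TrainerConfig(start_paused=True)).training_state)   # paused
print(Trainer(TrainerConfig()).training_state)                    # training

Since TrainerConfig fields are surfaced through nerfstudio's tyro-generated CLI, this presumably appears as a --start-paused option on ns-train; that wiring is not part of this diff.
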
6 changes: 1 addition & 5 deletions nerfstudio/models/base_model.py
@@ -214,11 +214,7 @@ def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "r
RGBA image.
"""
accumulation_name = output_name.replace("rgb", "accumulation")
- if (
- not hasattr(self, "renderer_rgb")
- or not hasattr(self.renderer_rgb, "background_color")
- or accumulation_name not in outputs
- ):
+ if accumulation_name not in outputs:
raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}")
rgb = outputs[output_name]
if self.renderer_rgb.background_color == "random": # type: ignore
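
Note: the old guard also rejected models without a renderer_rgb.background_color attribute; after this change only the presence of the accumulation output is required (the background color is still read further down, under # type: ignore). For orientation, a rough sketch of the kind of RGBA composition such a method performs — illustrative only, with made-up tensors; the real implementation is Model.get_rgba_image and also blends the background color:

import torch

# Made-up model outputs for illustration.
outputs = {
    "rgb": torch.rand(64, 64, 3),           # rendered color
    "accumulation": torch.rand(64, 64, 1),  # per-ray opacity used as alpha
}

accumulation_name = "rgb".replace("rgb", "accumulation")
if accumulation_name not in outputs:        # the simplified guard from this diff
    raise NotImplementedError("get_rgba_image is not implemented for this model")

rgba = torch.cat([outputs["rgb"], outputs[accumulation_name]], dim=-1)
print(rgba.shape)  # torch.Size([64, 64, 4])
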
11 changes: 10 additions & 1 deletion nerfstudio/scripts/render.py
@@ -197,6 +197,9 @@ def _render_trajectory_video(
outputs = pipeline.model.get_outputs_for_camera(
cameras[camera_idx : camera_idx + 1], obb_box=obb_box
)
+ if rendered_output_names is not None and "rgba" in rendered_output_names:
+ rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
+ outputs["rgba"] = rgba

render_image = []
for rendered_output_name in rendered_output_names:
@@ -221,6 +224,8 @@
.cpu()
.numpy()
)
+ elif rendered_output_name == "rgba":
+ output_image = output_image.detach().cpu().numpy()
else:
output_image = (
colormaps.apply_colormap(
@@ -790,6 +795,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))):
with torch.no_grad():
outputs = pipeline.model.get_outputs_for_camera(camera)
+ if self.rendered_output_names is not None and "rgba" in self.rendered_output_names:
+ rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
+ outputs["rgba"] = rgba

gt_batch = batch.copy()
gt_batch["rgb"] = gt_batch.pop("image")
@@ -841,11 +849,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
output_image = gt_batch[output_name]
else:
output_image = outputs[output_name]
- del output_name

# Map to color spaces / numpy
if is_raw:
output_image = output_image.cpu().numpy()
+ elif output_name == "rgba":
+ output_image = output_image.detach().cpu().numpy()
elif is_depth:
output_image = (
colormaps.apply_depth_colormap(
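
Note: both render paths now synthesize an "rgba" output on demand (via get_rgba_image on the model's "rgb" output) and, unlike other outputs, pass it to numpy without applying a colormap. A small sketch of that post-processing step in isolation — NumPy only; the file-writing call is a hypothetical placeholder, not taken from this diff:

import numpy as np

# Stand-in for outputs["rgba"] after .detach().cpu().numpy(): floats in [0, 1].
rgba = np.random.rand(32, 32, 4).astype(np.float32)

# No colormap is applied; the image is only clipped and quantized for writing.
rgba_u8 = (np.clip(rgba, 0.0, 1.0) * 255).astype(np.uint8)
print(rgba_u8.shape, rgba_u8.dtype)   # (32, 32, 4) uint8
# e.g. mediapy.write_image("frame_0000.png", rgba_u8)   # hypothetical call
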
12 changes: 9 additions & 3 deletions nerfstudio/viewer/viewer.py
@@ -103,8 +103,10 @@ def __init__(
self.output_type_changed = True
self.output_split_type_changed = True
self.step = 0
- self.train_btn_state: Literal["training", "paused", "completed"] = "training"
- self._prev_train_state: Literal["training", "paused", "completed"] = "training"
+ self.train_btn_state: Literal["training", "paused", "completed"] = (
+ "training" if self.trainer is None else self.trainer.training_state
+ )
+ self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
self.last_move_time = 0
# track the camera index that last being clicked
self.current_camera_idx = 0
@@ -174,7 +176,11 @@ def __init__(
)
self.resume_train.on_click(lambda _: self.toggle_pause_button())
self.resume_train.on_click(lambda han: self._toggle_training_state(han))
- self.resume_train.visible = False
+ if self.train_btn_state == "training":
+ self.resume_train.visible = False
+ else:
+ self.pause_train.visible = False

# Add buttons to toggle training image visibility
self.hide_images = self.viser_server.gui.add_button(
label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None
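
Note: with the trainer able to start paused, the viewer's pause/resume buttons now seed their state from trainer.training_state and show exactly one of the pair. The rule restated as a tiny dependency-free function (not the viser/viewer code itself):

from typing import Dict

def initial_button_visibility(train_btn_state: str) -> Dict[str, bool]:
    # exactly one of the two buttons is visible at startup
    return {
        "pause_train": train_btn_state == "training",
        "resume_train": train_btn_state != "training",
    }

print(initial_button_visibility("training"))  # pause visible, resume hidden
print(initial_button_visibility("paused"))    # resume visible, pause hidden
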
6 changes: 4 additions & 2 deletions nerfstudio/viewer_legacy/server/viewer_state.py
@@ -116,8 +116,10 @@ def __init__(
self.output_type_changed = True
self.output_split_type_changed = True
self.step = 0
- self.train_btn_state: Literal["training", "paused", "completed"] = "training"
- self._prev_train_state: Literal["training", "paused", "completed"] = "training"
+ self.train_btn_state: Literal["training", "paused", "completed"] = (
+ "training" if self.trainer is None else self.trainer.training_state
+ )
+ self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state

self.camera_message = None

