diff --git a/nerfstudio/viewer/metrics_panel.py b/nerfstudio/viewer/metrics_panel.py new file mode 100644 index 0000000000..387e1c0873 --- /dev/null +++ b/nerfstudio/viewer/metrics_panel.py @@ -0,0 +1,62 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +from pathlib import Path +from typing import Literal + +import viser + +from nerfstudio.models.base_model import Model +from nerfstudio.models.splatfacto import SplatfactoModel +from nerfstudio.viewer.control_panel import ControlPanel +from nerfstudio.viewer.viewer_elements import ViewerPlot + +@dataclasses.dataclass +class MetricsPanel: + server: viser.ViserServer + control_panel: ControlPanel + config_path: Path + viewer_model: Model + +def populate_metrics_tab( + server: viser.ViserServer, + control_panel: ControlPanel, + config_path: Path, + viewer_model: Model, +) -> None: + viewing_gsplat = isinstance(viewer_model, SplatfactoModel) + + with server.add_gui_folder("Training Metrics"): + populate_train_metrics_tab(server, control_panel, config_path, viewing_gsplat) + # with server.add_gui_folder("Training Loss"): + # populate_train_loss_tab(server, control_panel, config_path, viewing_gsplat) + #training rays? + # with server.add_gui_folder("Eval Metrics"): + # populate_eval_metrics_tab(server, control_panel, config_path, viewing_gsplat) + # with server.add_gui_folder("Eval Metrics (All Images)"): + # populate_eval_metrics_all_images_tab(server, control_panel, config_path, viewing_gsplat) + # with server.add_gui_folder("Eval Loss"): + # populate_eval_loss_tab(server, control_panel, config_path, viewing_gsplat) + + +def populate_train_metrics_tab( + server: viser.ViserServer, + control_panel: ControlPanel, + config_path: Path, + viewing_gsplat: bool, +) -> None: + ViewerPlot() \ No newline at end of file diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index a5093f0dad..101d72fcef 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -40,6 +40,7 @@ from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName from nerfstudio.viewer.control_panel import ControlPanel from nerfstudio.viewer.export_panel import populate_export_tab +from nerfstudio.viewer.metrics_panel import populate_metrics_tab from nerfstudio.viewer.render_panel import populate_render_tab from nerfstudio.viewer.render_state_machine import RenderAction, RenderStateMachine from nerfstudio.viewer.utils import CameraState, parse_object @@ -199,6 +200,7 @@ def __init__( self._output_split_type_change, default_composite_depth=self.config.default_composite_depth, ) + config_path = self.log_filename.parents[0] / "config.yml" with tabs.add_tab("Render", viser.Icon.CAMERA): self.render_tab_state = populate_render_tab( @@ -208,6 +210,9 @@ def __init__( with tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT): populate_export_tab(self.viser_server, self.control_panel, config_path, self.pipeline.model) + with tabs.add_tab("Metrics", viser.Icon.GRAPH): + populate_metrics_tab(self.viser_server, self.control_panel, config_path, self.pipeline.model) + # Keep track of the pointers to generated GUI folders, because each generated folder holds a unique ID. viewer_gui_folders = dict() diff --git a/nerfstudio/viewer/viewer_elements.py b/nerfstudio/viewer/viewer_elements.py index 654503a3c4..d71f6649ee 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -32,6 +32,7 @@ GuiButtonHandle, GuiDropdownHandle, GuiInputHandle, + GuiPlotlyHandle, ScenePointerEvent, ViserServer, ) @@ -43,6 +44,10 @@ if TYPE_CHECKING: from nerfstudio.viewer.viewer import Viewer +import plotly +import plotly.basedatatypes +import plotly.graph_objects as go + TValue = TypeVar("TValue") TString = TypeVar("TString", default=str, bound=str) @@ -284,7 +289,9 @@ def __init__( cb_hook: Callable = lambda element: None, ) -> None: self.name = name - self.gui_handle: Optional[Union[GuiInputHandle[TValue], GuiButtonHandle, GuiButtonGroupHandle]] = None + self.gui_handle: Optional[ + Union[GuiInputHandle[TValue], GuiButtonHandle, GuiButtonGroupHandle, GuiPlotlyHandle] + ] = None self.disabled = disabled self.visible = visible self.cb_hook = cb_hook @@ -710,3 +717,127 @@ def _create_gui_handle(self, viser_server: ViserServer) -> None: self.gui_handle = viser_server.add_gui_vector3( self.name, self.default_value, step=self.step, disabled=self.disabled, visible=self.visible, hint=self.hint ) + + +class ViewerPlot(ViewerElement[go.Figure]): + """Base class for viewer figures, using plotly backend. + Includes misc wrapper methods for setting plotly figure properties. + """ + + gui_handle: GuiPlotlyHandle + + _figure: go.Figure + """Figure to be displayed. Do not access this directly, exists only for initial statekeeping.""" + _aspect: float + """Aspect ratio of the plot (h/w). Default is 1.0.""" + _dark_mode: bool + """If the plot is in dark mode (i.e., `plotly_dark` template). Default is True. Uses `plotly` template for light mode.""" + _margin: int + """Margin of the plot. Default is 0.""" + + def __init__( + self, + figure: Optional[go.Figure] = None, + aspect: float = 1.0, + margin: int = 0, + dark_mode: bool = True, + visible: bool = True, + ): + """ + Args: + figure: The plotly figure to display -- if None, an empty figure is created. + aspect: Aspect ratio of the plot (h/w). Default is 1.0. + visible: If the plot is visible. + margin: Margin of the plot. Default is 0. + """ + self._figure = go.Figure() if figure is None else figure + self._aspect = aspect + self._margin = margin + self._dark_mode = dark_mode + super().__init__(name="", visible=visible) # plots have no name. + + def _create_gui_handle(self, viser_server: ViserServer) -> None: + self.gui_handle = viser_server.add_gui_plotly(figure=self._figure, visible=self.visible, aspect=self._aspect) + + def install(self, viser_server: ViserServer) -> None: + self._create_gui_handle(viser_server) + assert self.gui_handle is not None + + @property + def figure(self): + assert self.gui_handle is not None + return self.gui_handle.figure + + @figure.setter + def figure(self, figure: Union[go.Figure, plotly.basedatatypes.BaseTraceType]): + if isinstance(figure, plotly.basedatatypes.BaseTraceType): + figure = go.Figure(data=[figure]) + assert self.gui_handle is not None + self._figure = figure + self._update_plot() + + @property + def aspect(self): + return self._aspect + + @aspect.setter + def aspect(self, aspect: float): + self._aspect = aspect + self._update_plot() + + @property + def dark(self): + return self._dark_mode + + @dark.setter + def dark(self, dark_mode: bool): + self._dark_mode = dark_mode + self._update_plot() + + @property + def margin(self): + return self._margin + + @margin.setter + def margin(self, margin: int): + self._margin = margin + self._update_plot() + + def _update_plot(self) -> None: + """Refresh the plot with: + - the current figure + - aspect ratio + - dark mode + """ + template = "plotly_dark" if self._dark_mode else "plotly" + self._figure.update_layout(template=template) + + # Set margins. Also, set automargin for title, so that title doesn't get cut off. + self._figure.update_layout( + margin=dict(l=self._margin, r=self._margin, t=self._margin, b=self._margin), + ) + if self._margin == 0 and self._figure.layout.title.text is not None: # type: ignore + self._figure.layout.title.automargin = True # type: ignore + + if self.gui_handle is not None: + self.gui_handle.aspect = self._aspect + self.gui_handle.figure = self._figure + + @staticmethod + def plot_line(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blue") -> go.Scatter: + """Wrapper for plotting a line in a plotly figure.""" + return go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color)) + + @staticmethod + def plot_scatter(x: np.ndarray, y: np.ndarray, name: str = "", color: str = "blue") -> go.Scatter: + """Wrapper for plotting a scatter in a plotly figure.""" + return go.Scatter(x=x, y=y, mode="markers", name=name, marker=dict(color=color)) + + @staticmethod + def plot_image(image: np.ndarray, name: str = "") -> go.Image: + """Wrapper for plotting an image in a plotly figure. + `plotly.graph_object.Image` expects [0...255], so images [0...1] is automatically scaled here. + """ + if image.dtype != np.uint8: + image = (image * 255).astype(np.uint8) + return go.Image(z=image, name=name) diff --git a/pyproject.toml b/pyproject.toml index a5b3d39fc6..394423afb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "torchvision>=0.14.1", "torchmetrics[image]>=1.0.1", "typing_extensions>=4.4.0", - "viser==0.1.27", + "viser==0.1.29", "nuscenes-devkit>=1.1.1", "wandb>=0.13.3", "xatlas",