diff --git a/nerfstudio/configs/annotated_types.py b/nerfstudio/configs/annotated_types.py new file mode 100644 index 0000000000..8154676792 --- /dev/null +++ b/nerfstudio/configs/annotated_types.py @@ -0,0 +1,159 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from collections import OrderedDict +from typing import TYPE_CHECKING + +import tyro + +from nerfstudio.configs.dataparser_configs import dataparser_configs +from nerfstudio.configs.external_methods import get_external_methods +from nerfstudio.configs.method_configs import descriptions as method_descriptions +from nerfstudio.configs.method_configs import method_configs +from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig +from nerfstudio.engine.trainer import TrainerConfig +from nerfstudio.plugins.registry import discover_dataparsers, discover_methods + + +def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True): + """Merge new methods and descriptions into existing methods and descriptions. + Args: + methods: Existing methods. + method_descriptions: Existing descriptions. + new_methods: New methods to merge in. + new_descriptions: New descriptions to merge in. + Returns: + Merged methods and descriptions. + """ + methods = OrderedDict(**methods) + method_descriptions = OrderedDict(**method_descriptions) + for k, v in new_methods.items(): + if overwrite or k not in methods: + methods[k] = v + method_descriptions[k] = new_descriptions.get(k, "") + return methods, method_descriptions + + +def sort_methods(methods, method_descriptions): + """Sort methods and descriptions by method name.""" + methods = OrderedDict(sorted(methods.items(), key=lambda x: x[0])) + method_descriptions = OrderedDict(sorted(method_descriptions.items(), key=lambda x: x[0])) + return methods, method_descriptions + + +def to_snake_case(name: str) -> str: + """Convert a name to snake case.""" + return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") + + +# DataParsers +external_dataparsers, _external_dataparsers_help = discover_dataparsers() +all_dataparsers = {**dataparser_configs, **external_dataparsers} + +# Methods +all_methods, all_descriptions = method_configs, method_descriptions +# Add discovered external methods +discovered_methods, discovered_descriptions = discover_methods() +all_methods, all_descriptions = merge_methods( + all_methods, all_descriptions, discovered_methods, discovered_descriptions +) +all_methods, all_descriptions = sort_methods(all_methods, all_descriptions) + +# Register all possible external methods which can be installed with Nerfstudio +all_methods, all_descriptions = merge_methods( + all_methods, all_descriptions, *sort_methods(*get_external_methods()), overwrite=False +) + +# We also register all external dataparsers found in the external methods +_registered_dataparsers = set(map(type, all_dataparsers.values())) +for method_name, method in discovered_methods.items(): + if not hasattr(method.pipeline.datamanager, "dataparser"): + continue + dataparser = method.pipeline.datamanager.dataparser # type: ignore + if type(dataparser) not in _registered_dataparsers: + name = method_name + "-" + to_snake_case(type(dataparser).__name__).replace("_", "-") + all_dataparsers[name] = dataparser + +if TYPE_CHECKING: + # For static analysis (tab completion, type checking, etc), just use the base + # dataparser config. + DataParserUnion = DataParserConfig +else: + # At runtime, populate a Union type dynamically. This is used by `tyro` to generate + # subcommands in the CLI. + DataParserUnion = tyro.extras.subcommand_type_from_defaults( + all_dataparsers, + prefix_names=False, # Omit prefixes in subcommands themselves. + ) + +AnnotatedDataParserUnion = tyro.conf.OmitSubcommandPrefixes[DataParserUnion] # Omit prefixes of flags in subcommands. +"""Union over possible dataparser types, annotated with metadata for tyro. This is +the same as the vanilla union, but results in shorter subcommand names.""" + + +def fix_method_dataparser_type(method_config: TrainerConfig): + def replace_type(instance, field, value=None, new_type=None): + assert dataclasses.is_dataclass(instance) + if value is None: + value = getattr(instance, field) + if new_type is None: + new_type = type(value) + cls = type(instance) + fields = dataclasses.fields(cls) + new_fields = [(field, new_type)] + config = cls.__dataclass_params__ + new_cls = dataclasses.make_dataclass( + cls.__name__, + new_fields, + bases=(cls,), + init=config.init, + repr=config.repr, + eq=config.eq, + order=config.order, + unsafe_hash=config.unsafe_hash, + frozen=config.frozen, + ) + kwargs = {i.name: getattr(instance, i.name) for i in fields} + kwargs[field] = value + return new_cls(**kwargs) + + if hasattr(method_config.pipeline.datamanager, "dataparser"): + return replace_type( + method_config, + "pipeline", + replace_type( + method_config.pipeline, + "datamanager", + replace_type( + method_config.pipeline.datamanager, + "dataparser", + new_type=AnnotatedDataParserUnion, + ), + ), + ) + return method_config + + +for key, method in all_methods.items(): + all_methods[key] = fix_method_dataparser_type(method) + +AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ # Don't show unparseable (fixed) arguments in helptext. + tyro.conf.FlagConversionOff[ + tyro.extras.subcommand_type_from_defaults(defaults=all_methods, descriptions=all_descriptions) + ] +] +"""Union[] type over config types, annotated with default instances for use with +tyro.cli(). Allows the user to pick between one of several base configurations, and +then override values in it.""" diff --git a/nerfstudio/configs/dataparser_configs.py b/nerfstudio/configs/dataparser_configs.py index 46a7281411..c617b6d85f 100644 --- a/nerfstudio/configs/dataparser_configs.py +++ b/nerfstudio/configs/dataparser_configs.py @@ -16,12 +16,7 @@ Aggregate all the dataparser configs in one location. """ -from typing import TYPE_CHECKING - -import tyro - from nerfstudio.data.dataparsers.arkitscenes_dataparser import ARKitScenesDataParserConfig -from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig @@ -35,9 +30,8 @@ from nerfstudio.data.dataparsers.scannet_dataparser import ScanNetDataParserConfig from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig -from nerfstudio.plugins.registry_dataparser import discover_dataparsers -dataparsers = { +dataparser_configs = { "nerfstudio-data": NerfstudioDataParserConfig(), "minimal-parser": MinimalDataParserConfig(), "arkit-data": ARKitScenesDataParserConfig(), @@ -53,22 +47,3 @@ "sitcoms3d-data": Sitcoms3DDataParserConfig(), "colmap": ColmapDataParserConfig(), } - -external_dataparsers = discover_dataparsers() -all_dataparsers = {**dataparsers, **external_dataparsers} - -if TYPE_CHECKING: - # For static analysis (tab completion, type checking, etc), just use the base - # dataparser config. - DataParserUnion = DataParserConfig -else: - # At runtime, populate a Union type dynamically. This is used by `tyro` to generate - # subcommands in the CLI. - DataParserUnion = tyro.extras.subcommand_type_from_defaults( - all_dataparsers, - prefix_names=False, # Omit prefixes in subcommands themselves. - ) - -AnnotatedDataParserUnion = tyro.conf.OmitSubcommandPrefixes[DataParserUnion] # Omit prefixes of flags in subcommands. -"""Union over possible dataparser types, annotated with metadata for tyro. This is -the same as the vanilla union, but results in shorter subcommand names.""" diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index b8ec7fedef..dbc73f3136 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -18,18 +18,13 @@ from __future__ import annotations -from collections import OrderedDict from typing import Dict -import tyro from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig from nerfstudio.configs.base_config import ViewerConfig -from nerfstudio.configs.external_methods import get_external_methods - -from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig - +from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig from nerfstudio.data.dataparsers.instant_ngp_dataparser import InstantNGPDataParserConfig @@ -61,7 +56,6 @@ from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig -from nerfstudio.plugins.registry import discover_methods method_configs: Dict[str, TrainerConfig] = {} descriptions = { @@ -301,7 +295,6 @@ vis="viewer", ) - method_configs["instant-ngp-bounded"] = TrainerConfig( method_name="instant-ngp-bounded", steps_per_eval_batch=500, @@ -617,49 +610,3 @@ viewer=ViewerConfig(num_rays_per_chunk=1 << 15), vis="viewer", ) - - -def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True): - """Merge new methods and descriptions into existing methods and descriptions. - Args: - methods: Existing methods. - method_descriptions: Existing descriptions. - new_methods: New methods to merge in. - new_descriptions: New descriptions to merge in. - Returns: - Merged methods and descriptions. - """ - methods = OrderedDict(**methods) - method_descriptions = OrderedDict(**method_descriptions) - for k, v in new_methods.items(): - if overwrite or k not in methods: - methods[k] = v - method_descriptions[k] = new_descriptions.get(k, "") - return methods, method_descriptions - - -def sort_methods(methods, method_descriptions): - """Sort methods and descriptions by method name.""" - methods = OrderedDict(sorted(methods.items(), key=lambda x: x[0])) - method_descriptions = OrderedDict(sorted(method_descriptions.items(), key=lambda x: x[0])) - return methods, method_descriptions - - -all_methods, all_descriptions = method_configs, descriptions -# Add discovered external methods -all_methods, all_descriptions = merge_methods(all_methods, all_descriptions, *discover_methods()) -all_methods, all_descriptions = sort_methods(all_methods, all_descriptions) - -# Register all possible external methods which can be installed with Nerfstudio -all_methods, all_descriptions = merge_methods( - all_methods, all_descriptions, *sort_methods(*get_external_methods()), overwrite=False -) - -AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ # Don't show unparseable (fixed) arguments in helptext. - tyro.conf.FlagConversionOff[ - tyro.extras.subcommand_type_from_defaults(defaults=all_methods, descriptions=all_descriptions) - ] -] -"""Union[] type over config types, annotated with default instances for use with -tyro.cli(). Allows the user to pick between one of several base configurations, and -then override values in it.""" diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index e48e3621ad..892c77b679 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -21,12 +21,13 @@ from abc import abstractmethod from collections import defaultdict from dataclasses import dataclass, field -from pathlib import Path from functools import cached_property +from pathlib import Path from typing import ( Any, Callable, Dict, + ForwardRef, Generic, List, Literal, @@ -35,9 +36,8 @@ Type, Union, cast, - ForwardRef, - get_origin, get_args, + get_origin, ) import torch @@ -50,26 +50,16 @@ from nerfstudio.cameras.cameras import CameraType from nerfstudio.cameras.rays import RayBundle from nerfstudio.configs.base_config import InstantiateConfig -from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion -from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs +from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig, DataparserOutputs from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset -from nerfstudio.data.pixel_samplers import ( - EquirectangularPixelSampler, - PatchPixelSampler, - PixelSampler, -) -from nerfstudio.data.utils.dataloaders import ( - CacheDataloader, - FixedIndicesEvalDataloader, - RandIndicesEvalDataloader, -) +from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PatchPixelSampler, PixelSampler +from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator -from nerfstudio.utils.misc import IterableWrapper +from nerfstudio.utils.misc import IterableWrapper, get_orig_class from nerfstudio.utils.rich_utils import CONSOLE -from nerfstudio.utils.misc import get_orig_class def variable_res_collate(batch: List[Dict]) -> Dict: @@ -317,7 +307,7 @@ class VanillaDataManagerConfig(DataManagerConfig): _target: Type = field(default_factory=lambda: VanillaDataManager) """Target class to instantiate.""" - dataparser: AnnotatedDataParserUnion = BlenderDataParserConfig() + dataparser: DataParserConfig = BlenderDataParserConfig() """Specifies the dataparser used to unpack the data.""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" diff --git a/nerfstudio/plugins/registry.py b/nerfstudio/plugins/registry.py index 8087726782..357f0ccee2 100644 --- a/nerfstudio/plugins/registry.py +++ b/nerfstudio/plugins/registry.py @@ -21,8 +21,9 @@ import sys import typing as t +from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig from nerfstudio.engine.trainer import TrainerConfig -from nerfstudio.plugins.types import MethodSpecification +from nerfstudio.plugins.types import DataParserSpecification, MethodSpecification from nerfstudio.utils.rich_utils import CONSOLE if sys.version_info < (3, 10): @@ -68,3 +69,44 @@ def discover_methods() -> t.Tuple[t.Dict[str, TrainerConfig], t.Dict[str, str]]: CONSOLE.print("[bold red]Error: Could not load methods from environment variable NERFSTUDIO_METHOD_CONFIGS") return methods, descriptions + + +def discover_dataparsers() -> t.Tuple[t.Dict[str, DataParserConfig], t.Dict[str, str]]: + """ + Discovers all DataParsers registered using the `nerfstudio.dataparser_configs` entrypoint. + And also DataParsers in the NERFSTUDIO_DATAPARSER_CONFIGS environment variable. + """ + configs = {} + descriptions = {} + discovered_entry_points = entry_points(group="nerfstudio.dataparser_configs") + for name in discovered_entry_points.names: + spec = discovered_entry_points[name].load() + if not isinstance(spec, DataParserSpecification): + CONSOLE.print( + f"[bold yellow]Warning: Could not entry point {spec} as it is not an instance of DataParserSpecification" + ) + continue + spec = t.cast(DataParserSpecification, spec) + configs[name] = spec.config + descriptions[name] = spec.description + + if "NERFSTUDIO_DATAPARSER_CONFIGS" in os.environ: + try: + strings = os.environ["NERFSTUDIO_DATAPARSER_CONFIGS"].split(",") + for definition in strings: + if not definition: + continue + name, path = definition.split("=") + CONSOLE.print(f"[bold green]Info: Loading method {name} from environment variable") + module, config_name = path.split(":") + dataparser_config = getattr(importlib.import_module(module), config_name) + assert isinstance(dataparser_config, DataParserSpecification) + configs[name] = dataparser_config.config + descriptions[name] = dataparser_config.description + except Exception: + CONSOLE.print_exception() + CONSOLE.print( + "[bold red]Error: Could not load methods from environment variable NERFSTUDIO_DATAPARSER_CONFIGS" + ) + + return configs, descriptions diff --git a/nerfstudio/plugins/registry_dataparser.py b/nerfstudio/plugins/registry_dataparser.py index f15fdc78cf..f0c72b7a7d 100644 --- a/nerfstudio/plugins/registry_dataparser.py +++ b/nerfstudio/plugins/registry_dataparser.py @@ -13,49 +13,22 @@ # limitations under the License. """ -Module that keeps all registered plugins and allows for plugin discovery. +[LEGACY] Module that keeps all registered plugins and allows for plugin discovery. """ -import sys -import typing as t -from dataclasses import dataclass -from rich.progress import Console +import warnings +from typing import Dict -from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig +from .registry import discover_dataparsers as new_discover_dataparsers +from .types import DataParserConfig # pylint: disable=unused-import -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points -CONSOLE = Console(width=120) +warnings.warn( + "This module is deprecated and will be removed in future releases. Use nerfstudio.plugins.registry and nerfstudio.plugins.types instead!", + DeprecationWarning, +) -@dataclass -class DataParserSpecification: - """ - DataParser specification class used to register custom dataparsers with Nerfstudio. - The registered dataparsers will be available in commands such as `ns-train` - """ - - config: DataParserConfig - """Dataparser configuration""" - - -def discover_dataparsers() -> t.Dict[str, DataParserConfig]: - """ - Discovers all dataparsers registered using the `nerfstudio.dataparser_configs` entrypoint. - """ - dataparsers = {} - discovered_entry_points = entry_points(group="nerfstudio.dataparser_configs") - for name in discovered_entry_points.names: - spec = discovered_entry_points[name].load() - if not isinstance(spec, DataParserSpecification): - CONSOLE.print( - f"[bold yellow]Warning: Could not entry point {spec} as it is an instance of DataParserSpecification" - ) - continue - spec = t.cast(DataParserSpecification, spec) - dataparsers[name] = spec.config - - return dataparsers +def discover_dataparsers() -> Dict[str, DataParserConfig]: + configs, _docs = new_discover_dataparsers() + return configs diff --git a/nerfstudio/plugins/types.py b/nerfstudio/plugins/types.py index cfdb54b55c..acb9899cdd 100644 --- a/nerfstudio/plugins/types.py +++ b/nerfstudio/plugins/types.py @@ -16,7 +16,9 @@ This package contains specifications used to register plugins. """ from dataclasses import dataclass +from typing import Optional +from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig from nerfstudio.engine.trainer import TrainerConfig @@ -31,3 +33,16 @@ class MethodSpecification: """Trainer configuration""" description: str """Method description shown in `ns-train` help""" + + +@dataclass +class DataParserSpecification: + """ + DataParser specification class used to register custom dataparsers with Nerfstudio. + The registered dataparsers will be available in commands such as `ns-train` + """ + + config: DataParserConfig + """DataParser configuration""" + description: Optional[str] = None + """DataParser description shown in `ns-train` help""" diff --git a/nerfstudio/scripts/train.py b/nerfstudio/scripts/train.py index 9ca18c71a7..4e0b9856a5 100644 --- a/nerfstudio/scripts/train.py +++ b/nerfstudio/scripts/train.py @@ -58,8 +58,8 @@ import tyro import yaml +from nerfstudio.configs.annotated_types import AnnotatedBaseConfigUnion from nerfstudio.configs.config_utils import convert_markup_to_ansi -from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion from nerfstudio.engine.trainer import TrainerConfig from nerfstudio.utils import comms, profiler from nerfstudio.utils.rich_utils import CONSOLE diff --git a/nerfstudio/utils/eval_utils.py b/nerfstudio/utils/eval_utils.py index 13948678bb..cdea58e494 100644 --- a/nerfstudio/utils/eval_utils.py +++ b/nerfstudio/utils/eval_utils.py @@ -25,7 +25,7 @@ import torch import yaml -from nerfstudio.configs.method_configs import all_methods +from nerfstudio.configs.annotated_types import all_methods from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig from nerfstudio.engine.trainer import TrainerConfig from nerfstudio.pipelines.base_pipeline import Pipeline