Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify external projects, unify plugin API #2146

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions nerfstudio/configs/annotated_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from collections import OrderedDict
from typing import TYPE_CHECKING

import tyro

from nerfstudio.configs.dataparser_configs import dataparser_configs
from nerfstudio.configs.external_methods import get_external_methods
from nerfstudio.configs.method_configs import descriptions as method_descriptions
from nerfstudio.configs.method_configs import method_configs
from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.plugins.registry import discover_dataparsers, discover_methods


def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while we're making changes, can we add annotations here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

"""Merge new methods and descriptions into existing methods and descriptions.
Args:
methods: Existing methods.
method_descriptions: Existing descriptions.
new_methods: New methods to merge in.
new_descriptions: New descriptions to merge in.
Returns:
Merged methods and descriptions.
"""
methods = OrderedDict(**methods)
method_descriptions = OrderedDict(**method_descriptions)
brentyi marked this conversation as resolved.
Show resolved Hide resolved
for k, v in new_methods.items():
if overwrite or k not in methods:
methods[k] = v
method_descriptions[k] = new_descriptions.get(k, "")
return methods, method_descriptions


def sort_methods(methods, method_descriptions):
"""Sort methods and descriptions by method name."""
methods = OrderedDict(sorted(methods.items(), key=lambda x: x[0]))
method_descriptions = OrderedDict(sorted(method_descriptions.items(), key=lambda x: x[0]))
return methods, method_descriptions


def to_snake_case(name: str) -> str:
"""Convert a name to snake case."""
return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")


# DataParsers
external_dataparsers, _external_dataparsers_help = discover_dataparsers()
all_dataparsers = {**dataparser_configs, **external_dataparsers}

# Methods
all_methods, all_descriptions = method_configs, method_descriptions
# Add discovered external methods
discovered_methods, discovered_descriptions = discover_methods()
all_methods, all_descriptions = merge_methods(
all_methods, all_descriptions, discovered_methods, discovered_descriptions
)
all_methods, all_descriptions = sort_methods(all_methods, all_descriptions)

# Register all possible external methods which can be installed with Nerfstudio
all_methods, all_descriptions = merge_methods(
all_methods, all_descriptions, *sort_methods(*get_external_methods()), overwrite=False
)

# We also register all external dataparsers found in the external methods
_registered_dataparsers = set(map(type, all_dataparsers.values()))
for method_name, method in discovered_methods.items():
if not hasattr(method.pipeline.datamanager, "dataparser"):
continue
dataparser = method.pipeline.datamanager.dataparser # type: ignore
if type(dataparser) not in _registered_dataparsers:
name = method_name + "-" + to_snake_case(type(dataparser).__name__).replace("_", "-")
all_dataparsers[name] = dataparser

if TYPE_CHECKING:
# For static analysis (tab completion, type checking, etc), just use the base
# dataparser config.
DataParserUnion = DataParserConfig
else:
# At runtime, populate a Union type dynamically. This is used by `tyro` to generate
# subcommands in the CLI.
DataParserUnion = tyro.extras.subcommand_type_from_defaults(
all_dataparsers,
prefix_names=False, # Omit prefixes in subcommands themselves.
)

AnnotatedDataParserUnion = tyro.conf.OmitSubcommandPrefixes[DataParserUnion] # Omit prefixes of flags in subcommands.
"""Union over possible dataparser types, annotated with metadata for tyro. This is
the same as the vanilla union, but results in shorter subcommand names."""


def fix_method_dataparser_type(method_config: TrainerConfig):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brentyi , can you please take a closer look at this function?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense but feels scary to me, is there no way to keep the parsers annotated as AnnotatedDataParserUnion? Is the issue with circular imports? Would breaking this file up into smaller files (like one where the dataparser union is defined and one where the method union is defined) help?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think having the AnnotatedDataParserUnion in the DataManagerConfig is a bad design which leads to circular imports when you want to define the DataParser and something else (trainerconfig) in the same file. I agree that redefining the types here is not the safest. I think it would be preferable if there was a way to instruct tyro to do it instead: replace base types with the annotated unions? What do you think? If you like this solution I can try to implement it in tyro.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for missing your reply, I had accidentally left my comment twice and assumed you hadn't replied yet because I was looking at the other one...

I think it would be preferable if there was a way to instruct tyro to do it instead: replace base types with the annotated unions

Yeah, this makes sense. There's an unfinished PR that would enable this here: brentyi/tyro#30

If you're willing to make a PR to tyro making this kind of thing possible — either building on my PR or starting from scratch — that'd of course be awesome, but I don't have a good sense of how long something like that would take (or how long getting ramped up on the tyro codebase would take).

If it's too much effort just splitting things like the DataParser + trainerconfig that you described into separate files seems like an unideal but workable short-term solution.

def replace_type(instance, field, value=None, new_type=None):
assert dataclasses.is_dataclass(instance)
if value is None:
value = getattr(instance, field)
if new_type is None:
new_type = type(value)
cls = type(instance)
fields = dataclasses.fields(cls)
new_fields = [(field, new_type)]
config = cls.__dataclass_params__
new_cls = dataclasses.make_dataclass(
cls.__name__,
new_fields,
bases=(cls,),
init=config.init,
repr=config.repr,
eq=config.eq,
order=config.order,
unsafe_hash=config.unsafe_hash,
frozen=config.frozen,
)
kwargs = {i.name: getattr(instance, i.name) for i in fields}
kwargs[field] = value
return new_cls(**kwargs)

if hasattr(method_config.pipeline.datamanager, "dataparser"):
return replace_type(
method_config,
"pipeline",
replace_type(
method_config.pipeline,
"datamanager",
replace_type(
method_config.pipeline.datamanager,
"dataparser",
new_type=AnnotatedDataParserUnion,
),
),
)
return method_config


for key, method in all_methods.items():
all_methods[key] = fix_method_dataparser_type(method)

AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ # Don't show unparseable (fixed) arguments in helptext.
tyro.conf.FlagConversionOff[
tyro.extras.subcommand_type_from_defaults(defaults=all_methods, descriptions=all_descriptions)
]
]
"""Union[] type over config types, annotated with default instances for use with
tyro.cli(). Allows the user to pick between one of several base configurations, and
then override values in it."""
27 changes: 1 addition & 26 deletions nerfstudio/configs/dataparser_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
Aggregate all the dataparser configs in one location.
"""

from typing import TYPE_CHECKING

import tyro

from nerfstudio.data.dataparsers.arkitscenes_dataparser import ARKitScenesDataParserConfig
from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig
from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
Expand All @@ -35,9 +30,8 @@
from nerfstudio.data.dataparsers.scannet_dataparser import ScanNetDataParserConfig
from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig
from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig
from nerfstudio.plugins.registry_dataparser import discover_dataparsers

dataparsers = {
dataparser_configs = {
"nerfstudio-data": NerfstudioDataParserConfig(),
"minimal-parser": MinimalDataParserConfig(),
"arkit-data": ARKitScenesDataParserConfig(),
Expand All @@ -53,22 +47,3 @@
"sitcoms3d-data": Sitcoms3DDataParserConfig(),
"colmap": ColmapDataParserConfig(),
}

external_dataparsers = discover_dataparsers()
all_dataparsers = {**dataparsers, **external_dataparsers}

if TYPE_CHECKING:
# For static analysis (tab completion, type checking, etc), just use the base
# dataparser config.
DataParserUnion = DataParserConfig
else:
# At runtime, populate a Union type dynamically. This is used by `tyro` to generate
# subcommands in the CLI.
DataParserUnion = tyro.extras.subcommand_type_from_defaults(
all_dataparsers,
prefix_names=False, # Omit prefixes in subcommands themselves.
)

AnnotatedDataParserUnion = tyro.conf.OmitSubcommandPrefixes[DataParserUnion] # Omit prefixes of flags in subcommands.
"""Union over possible dataparser types, annotated with metadata for tyro. This is
the same as the vanilla union, but results in shorter subcommand names."""
55 changes: 1 addition & 54 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,13 @@

from __future__ import annotations

from collections import OrderedDict
from typing import Dict

import tyro

from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.configs.base_config import ViewerConfig
from nerfstudio.configs.external_methods import get_external_methods

from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig

from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
from nerfstudio.data.dataparsers.instant_ngp_dataparser import InstantNGPDataParserConfig
Expand Down Expand Up @@ -61,7 +56,6 @@
from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig
from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig
from nerfstudio.plugins.registry import discover_methods

method_configs: Dict[str, TrainerConfig] = {}
descriptions = {
Expand Down Expand Up @@ -301,7 +295,6 @@
vis="viewer",
)


method_configs["instant-ngp-bounded"] = TrainerConfig(
method_name="instant-ngp-bounded",
steps_per_eval_batch=500,
Expand Down Expand Up @@ -617,49 +610,3 @@
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
vis="viewer",
)


def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True):
"""Merge new methods and descriptions into existing methods and descriptions.
Args:
methods: Existing methods.
method_descriptions: Existing descriptions.
new_methods: New methods to merge in.
new_descriptions: New descriptions to merge in.
Returns:
Merged methods and descriptions.
"""
methods = OrderedDict(**methods)
method_descriptions = OrderedDict(**method_descriptions)
for k, v in new_methods.items():
if overwrite or k not in methods:
methods[k] = v
method_descriptions[k] = new_descriptions.get(k, "")
return methods, method_descriptions


def sort_methods(methods, method_descriptions):
"""Sort methods and descriptions by method name."""
methods = OrderedDict(sorted(methods.items(), key=lambda x: x[0]))
method_descriptions = OrderedDict(sorted(method_descriptions.items(), key=lambda x: x[0]))
return methods, method_descriptions


all_methods, all_descriptions = method_configs, descriptions
# Add discovered external methods
all_methods, all_descriptions = merge_methods(all_methods, all_descriptions, *discover_methods())
all_methods, all_descriptions = sort_methods(all_methods, all_descriptions)

# Register all possible external methods which can be installed with Nerfstudio
all_methods, all_descriptions = merge_methods(
all_methods, all_descriptions, *sort_methods(*get_external_methods()), overwrite=False
)

AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ # Don't show unparseable (fixed) arguments in helptext.
tyro.conf.FlagConversionOff[
tyro.extras.subcommand_type_from_defaults(defaults=all_methods, descriptions=all_descriptions)
]
]
"""Union[] type over config types, annotated with default instances for use with
tyro.cli(). Allows the user to pick between one of several base configurations, and
then override values in it."""
26 changes: 8 additions & 18 deletions nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from functools import cached_property
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
ForwardRef,
Generic,
List,
Literal,
Expand All @@ -35,9 +36,8 @@
Type,
Union,
cast,
ForwardRef,
get_origin,
get_args,
get_origin,
)

import torch
Expand All @@ -50,26 +50,16 @@
from nerfstudio.cameras.cameras import CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig, DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import (
EquirectangularPixelSampler,
PatchPixelSampler,
PixelSampler,
)
from nerfstudio.data.utils.dataloaders import (
CacheDataloader,
FixedIndicesEvalDataloader,
RandIndicesEvalDataloader,
)
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PatchPixelSampler, PixelSampler
from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper
from nerfstudio.utils.misc import IterableWrapper, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.misc import get_orig_class


def variable_res_collate(batch: List[Dict]) -> Dict:
Expand Down Expand Up @@ -317,7 +307,7 @@ class VanillaDataManagerConfig(DataManagerConfig):

_target: Type = field(default_factory=lambda: VanillaDataManager)
"""Target class to instantiate."""
dataparser: AnnotatedDataParserUnion = BlenderDataParserConfig()
dataparser: DataParserConfig = BlenderDataParserConfig()
"""Specifies the dataparser used to unpack the data."""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
Expand Down
Loading