-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify external projects, unify plugin API #2146
base: main
Are you sure you want to change the base?
Changes from all commits
0a94a3e
541e385
2ae8c83
68adfb6
91b9fb1
6f690f1
076f2bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import dataclasses | ||
from collections import OrderedDict | ||
from typing import TYPE_CHECKING | ||
|
||
import tyro | ||
|
||
from nerfstudio.configs.dataparser_configs import dataparser_configs | ||
from nerfstudio.configs.external_methods import get_external_methods | ||
from nerfstudio.configs.method_configs import descriptions as method_descriptions | ||
from nerfstudio.configs.method_configs import method_configs | ||
from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig | ||
from nerfstudio.engine.trainer import TrainerConfig | ||
from nerfstudio.plugins.registry import discover_dataparsers, discover_methods | ||
|
||
|
||
def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True): | ||
"""Merge new methods and descriptions into existing methods and descriptions. | ||
Args: | ||
methods: Existing methods. | ||
method_descriptions: Existing descriptions. | ||
new_methods: New methods to merge in. | ||
new_descriptions: New descriptions to merge in. | ||
Returns: | ||
Merged methods and descriptions. | ||
""" | ||
methods = OrderedDict(**methods) | ||
method_descriptions = OrderedDict(**method_descriptions) | ||
brentyi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for k, v in new_methods.items(): | ||
if overwrite or k not in methods: | ||
methods[k] = v | ||
method_descriptions[k] = new_descriptions.get(k, "") | ||
return methods, method_descriptions | ||
|
||
|
||
def sort_methods(methods, method_descriptions): | ||
"""Sort methods and descriptions by method name.""" | ||
methods = OrderedDict(sorted(methods.items(), key=lambda x: x[0])) | ||
method_descriptions = OrderedDict(sorted(method_descriptions.items(), key=lambda x: x[0])) | ||
return methods, method_descriptions | ||
|
||
|
||
def to_snake_case(name: str) -> str: | ||
"""Convert a name to snake case.""" | ||
return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") | ||
|
||
|
||
# DataParsers | ||
external_dataparsers, _external_dataparsers_help = discover_dataparsers() | ||
all_dataparsers = {**dataparser_configs, **external_dataparsers} | ||
|
||
# Methods | ||
all_methods, all_descriptions = method_configs, method_descriptions | ||
# Add discovered external methods | ||
discovered_methods, discovered_descriptions = discover_methods() | ||
all_methods, all_descriptions = merge_methods( | ||
all_methods, all_descriptions, discovered_methods, discovered_descriptions | ||
) | ||
all_methods, all_descriptions = sort_methods(all_methods, all_descriptions) | ||
|
||
# Register all possible external methods which can be installed with Nerfstudio | ||
all_methods, all_descriptions = merge_methods( | ||
all_methods, all_descriptions, *sort_methods(*get_external_methods()), overwrite=False | ||
) | ||
|
||
# We also register all external dataparsers found in the external methods | ||
_registered_dataparsers = set(map(type, all_dataparsers.values())) | ||
for method_name, method in discovered_methods.items(): | ||
if not hasattr(method.pipeline.datamanager, "dataparser"): | ||
continue | ||
dataparser = method.pipeline.datamanager.dataparser # type: ignore | ||
if type(dataparser) not in _registered_dataparsers: | ||
name = method_name + "-" + to_snake_case(type(dataparser).__name__).replace("_", "-") | ||
all_dataparsers[name] = dataparser | ||
|
||
if TYPE_CHECKING: | ||
# For static analysis (tab completion, type checking, etc), just use the base | ||
# dataparser config. | ||
DataParserUnion = DataParserConfig | ||
else: | ||
# At runtime, populate a Union type dynamically. This is used by `tyro` to generate | ||
# subcommands in the CLI. | ||
DataParserUnion = tyro.extras.subcommand_type_from_defaults( | ||
all_dataparsers, | ||
prefix_names=False, # Omit prefixes in subcommands themselves. | ||
) | ||
|
||
AnnotatedDataParserUnion = tyro.conf.OmitSubcommandPrefixes[DataParserUnion] # Omit prefixes of flags in subcommands. | ||
"""Union over possible dataparser types, annotated with metadata for tyro. This is | ||
the same as the vanilla union, but results in shorter subcommand names.""" | ||
|
||
|
||
def fix_method_dataparser_type(method_config: TrainerConfig): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @brentyi , can you please take a closer look at this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes sense but feels scary to me, is there no way to keep the parsers annotated as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think having the AnnotatedDataParserUnion in the DataManagerConfig is a bad design which leads to circular imports when you want to define the DataParser and something else (trainerconfig) in the same file. I agree that redefining the types here is not the safest. I think it would be preferable if there was a way to instruct tyro to do it instead: replace base types with the annotated unions? What do you think? If you like this solution I can try to implement it in tyro. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for missing your reply, I had accidentally left my comment twice and assumed you hadn't replied yet because I was looking at the other one...
Yeah, this makes sense. There's an unfinished PR that would enable this here: brentyi/tyro#30 If you're willing to make a PR to tyro making this kind of thing possible — either building on my PR or starting from scratch — that'd of course be awesome, but I don't have a good sense of how long something like that would take (or how long getting ramped up on the tyro codebase would take). If it's too much effort just splitting things like the DataParser + trainerconfig that you described into separate files seems like an unideal but workable short-term solution. |
||
def replace_type(instance, field, value=None, new_type=None): | ||
assert dataclasses.is_dataclass(instance) | ||
if value is None: | ||
value = getattr(instance, field) | ||
if new_type is None: | ||
new_type = type(value) | ||
cls = type(instance) | ||
fields = dataclasses.fields(cls) | ||
new_fields = [(field, new_type)] | ||
config = cls.__dataclass_params__ | ||
new_cls = dataclasses.make_dataclass( | ||
cls.__name__, | ||
new_fields, | ||
bases=(cls,), | ||
init=config.init, | ||
repr=config.repr, | ||
eq=config.eq, | ||
order=config.order, | ||
unsafe_hash=config.unsafe_hash, | ||
frozen=config.frozen, | ||
) | ||
kwargs = {i.name: getattr(instance, i.name) for i in fields} | ||
kwargs[field] = value | ||
return new_cls(**kwargs) | ||
|
||
if hasattr(method_config.pipeline.datamanager, "dataparser"): | ||
return replace_type( | ||
method_config, | ||
"pipeline", | ||
replace_type( | ||
method_config.pipeline, | ||
"datamanager", | ||
replace_type( | ||
method_config.pipeline.datamanager, | ||
"dataparser", | ||
new_type=AnnotatedDataParserUnion, | ||
), | ||
), | ||
) | ||
return method_config | ||
|
||
|
||
for key, method in all_methods.items(): | ||
all_methods[key] = fix_method_dataparser_type(method) | ||
|
||
AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ # Don't show unparseable (fixed) arguments in helptext. | ||
tyro.conf.FlagConversionOff[ | ||
tyro.extras.subcommand_type_from_defaults(defaults=all_methods, descriptions=all_descriptions) | ||
] | ||
] | ||
"""Union[] type over config types, annotated with default instances for use with | ||
tyro.cli(). Allows the user to pick between one of several base configurations, and | ||
then override values in it.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
while we're making changes, can we add annotations here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure