From 4ea2bddc0bcf16a8533a6f98b607f93fe6ca89f2 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 12 Feb 2023 01:19:01 -0800 Subject: [PATCH] Add special casing for flax, formatting --- tests/test_flax_ignore_py310.py | 5 +++++ tests/test_helptext.py | 4 +--- tests/test_nested_in_containers.py | 20 ++++++++++++++++---- tyro/_argparse_formatter.py | 14 ++++++++------ tyro/_cli.py | 6 +++++- tyro/_fields.py | 28 ++++++++++++++++++++++------ tyro/_instantiators.py | 8 +++++--- tyro/_parsers.py | 22 +++++++++++++--------- 8 files changed, 75 insertions(+), 32 deletions(-) diff --git a/tests/test_flax_ignore_py310.py b/tests/test_flax_ignore_py310.py index 9642692e..4448c556 100644 --- a/tests/test_flax_ignore_py310.py +++ b/tests/test_flax_ignore_py310.py @@ -2,6 +2,7 @@ import jax import pytest from flax import linen as nn +from helptext_utils import get_helptext from jax import numpy as jnp import tyro @@ -46,6 +47,10 @@ def test_ok(): params = network.init(jax.random.PRNGKey(0), x) assert network.apply(params, x).shape == (10, 3) + helptext = get_helptext(Classifier) + assert "parent" not in helptext + assert "name" not in helptext + def test_missing(): with pytest.raises(SystemExit): diff --git a/tests/test_helptext.py b/tests/test_helptext.py index af8dbaaf..b8aa3e47 100644 --- a/tests/test_helptext.py +++ b/tests/test_helptext.py @@ -510,9 +510,7 @@ def main( def test_metavar_5() -> None: def main( - x: List[ - Union[Tuple[int, int], Tuple[str, str]], - ] = [(1, 1), (2, 2)] + x: List[Union[Tuple[int, int], Tuple[str, str]],] = [(1, 1), (2, 2)] ) -> None: pass diff --git a/tests/test_nested_in_containers.py b/tests/test_nested_in_containers.py index 91e095aa..a3eab0e1 100644 --- a/tests/test_nested_in_containers.py +++ b/tests/test_nested_in_containers.py @@ -321,7 +321,10 @@ def main( ) -> Any: return x - assert tyro.cli(main, args="--x.float.g 0.1".split(" "),)[ + assert tyro.cli( + main, + args="--x.float.g 0.1".split(" "), + )[ "float" ] == GenericColor(0.5, 0.1, 0.3) assert tyro.cli( @@ -349,10 +352,16 @@ def main( ) -> Any: return x - assert tyro.cli(main, args="--x.hello.float.g 0.1".split(" "),)["hello"][ + assert tyro.cli( + main, + args="--x.hello.float.g 0.1".split(" "), + )["hello"][ "float" ] == GenericColor(0.5, 0.1, 0.3) - assert tyro.cli(main, args="--x.hello.int.g 0".split(" "),) == { + assert tyro.cli( + main, + args="--x.hello.int.g 0".split(" "), + ) == { "hello": {"float": GenericColor(0.5, 0.2, 0.3), "int": GenericColor(25, 0, 3)} } @@ -368,6 +377,9 @@ def main( ) -> Any: return x - assert tyro.cli(main, args="--x.hello.a.g 1".split(" "),)["hello"][ + assert tyro.cli( + main, + args="--x.hello.a.g 1".split(" "), + )["hello"][ "a" ] == Color(5, 1, 3) diff --git a/tyro/_argparse_formatter.py b/tyro/_argparse_formatter.py index d5ed0ea3..1e7f35aa 100644 --- a/tyro/_argparse_formatter.py +++ b/tyro/_argparse_formatter.py @@ -154,9 +154,9 @@ def _format_args(self, action, default_metavar): else str_from_rich( Text.from_ansi( out, - style=THEME.metavar_fixed - if out == "{fixed}" - else THEME.metavar, + style=( + THEME.metavar_fixed if out == "{fixed}" else THEME.metavar + ), ), soft_wrap=True, ) @@ -424,9 +424,11 @@ def _tyro_format_nonroot(self): ) .rstrip() .split("\n"), - item_parts + [description_part] - if description_part is not None - else item_parts, + ( + item_parts + [description_part] + if description_part is not None + else item_parts + ), ) ) ) diff --git a/tyro/_cli.py b/tyro/_cli.py index d76c1a9d..ebdfdd27 100644 --- a/tyro/_cli.py +++ b/tyro/_cli.py @@ -301,7 +301,11 @@ def fix_arg(arg: str) -> str: if print_completion: _arguments.USE_RICH = True - assert completion_shell in ("bash", "zsh", "tcsh",), ( + assert completion_shell in ( + "bash", + "zsh", + "tcsh", + ), ( "Shell should be one `bash`, `zsh`, or `tcsh`, but got" f" {completion_shell}" ) diff --git a/tyro/_fields.py b/tyro/_fields.py index 0fe1517c..baf55e34 100644 --- a/tyro/_fields.py +++ b/tyro/_fields.py @@ -369,9 +369,23 @@ def _field_list_from_namedtuple( def _field_list_from_dataclass( cls: TypeForm[Any], default_instance: _DefaultInstance ) -> Union[List[FieldDefinition], UnsupportedNestedTypeMessage]: + # Check if dataclass is a flax module. + is_flax_module = False + try: + import flax + + if issubclass(cls, flax.linen.Module): + is_flax_module = True + except ImportError: + pass + # Handle dataclasses. field_list = [] for dc_field in filter(lambda field: field.init, _resolver.resolved_fields(cls)): + # For flax modules, we ignore the built-in "name" and "parent" fields. + if is_flax_module and dc_field.name in ("name", "parent"): + continue + default = _get_dataclass_field_default(dc_field, default_instance) # Try to get helptext from field metadata. This is also intended to be @@ -423,9 +437,9 @@ def _field_list_from_pydantic( FieldDefinition.make( name=pd_field.name, typ=pd_field.outer_type_, - default=MISSING_NONPROP - if pd_field.required - else pd_field.get_default(), + default=( + MISSING_NONPROP if pd_field.required else pd_field.get_default() + ), helptext=helptext, ) ) @@ -708,9 +722,11 @@ def _field_list_from_params( typ=hints[param.name], default=default, helptext=helptext, - markers=(_markers.Positional, _markers._PositionalCall) - if param.kind is inspect.Parameter.POSITIONAL_ONLY - else (), + markers=( + (_markers.Positional, _markers._PositionalCall) + if param.kind is inspect.Parameter.POSITIONAL_ONLY + else () + ), ) ) diff --git a/tyro/_instantiators.py b/tyro/_instantiators.py index 08720160..08aee6f3 100644 --- a/tyro/_instantiators.py +++ b/tyro/_instantiators.py @@ -205,9 +205,11 @@ def instantiator_base_case(strings: List[str]) -> Any: return instantiator_base_case, InstantiatorMetadata( nargs=1, - metavar=typ.__name__.upper() - if auto_choices is None - else "{" + ",".join(map(str, auto_choices)) + "}", + metavar=( + typ.__name__.upper() + if auto_choices is None + else "{" + ",".join(map(str, auto_choices)) + "}" + ), choices=auto_choices, ) diff --git a/tyro/_parsers.py b/tyro/_parsers.py index f017e994..05351dec 100644 --- a/tyro/_parsers.py +++ b/tyro/_parsers.py @@ -142,11 +142,13 @@ def from_callable_or_type( ), ) nested_parser = ParserSpecification.from_callable_or_type( - # Recursively apply marker types. - field.typ - if len(field.markers) == 0 - else Annotated.__class_getitem__( # type: ignore - (field.typ,) + tuple(field.markers) + ( + # Recursively apply marker types. + field.typ + if len(field.markers) == 0 + else Annotated.__class_getitem__( # type: ignore + (field.typ,) + tuple(field.markers) + ) ), description=None, parent_classes=parent_classes, @@ -439,10 +441,12 @@ def from_field( subcommand_config, default=field.default ) subparser = ParserSpecification.from_callable_or_type( - # Recursively apply markers. - Annotated.__class_getitem__((option,) + tuple(field.markers)) # type: ignore - if len(field.markers) > 0 - else option, + ( + # Recursively apply markers. + Annotated.__class_getitem__((option,) + tuple(field.markers)) # type: ignore + if len(field.markers) > 0 + else option + ), description=subcommand_config.description, parent_classes=parent_classes, default_instance=subcommand_config.default,