From 47b1e3e3a571fe668bbf018e3a9a1d2ebf11558c Mon Sep 17 00:00:00 2001 From: brentyi Date: Wed, 18 Dec 2024 18:23:46 -0800 Subject: [PATCH 1/2] Support recursive argument suppression --- src/tyro/_arguments.py | 9 +++++++++ src/tyro/_fields.py | 12 +----------- src/tyro/_parsers.py | 12 ++++-------- tests/test_conf.py | 17 +++++++++++++++-- .../test_py311_generated/test_conf_generated.py | 13 +++++++++++++ 5 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index d21ad850..7121ce01 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -110,6 +110,15 @@ class ArgumentDefinition: subcommand_prefix: str # Prefix for nesting. field: _fields.FieldDefinition + def __post_init__(self) -> None: + if ( + _markers.Fixed in self.field.markers + or _markers.Suppress in self.field.markers + ) and self.field.default in _singleton.MISSING_AND_MISSING_NONPROP: + raise UnsupportedTypeAnnotationError( + f"Field {self.field.intern_name} is missing a default value!" + ) + def add_argument( self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup] ) -> None: diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 1d1ce112..94db68ac 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -16,9 +16,7 @@ from ._singleton import MISSING_AND_MISSING_NONPROP, MISSING_NONPROP from ._typing import TypeForm from .conf import _confstruct, _markers -from .constructors._primitive_spec import ( - UnsupportedTypeAnnotationError, -) +from .constructors._primitive_spec import UnsupportedTypeAnnotationError from .constructors._registry import ConstructorRegistry from .constructors._struct_spec import ( StructFieldSpec, @@ -51,14 +49,6 @@ class FieldDefinition: # doesn't match the keyword expected by our callable. call_argname: Any - def __post_init__(self): - if ( - _markers.Fixed in self.markers or _markers.Suppress in self.markers - ) and self.default in MISSING_AND_MISSING_NONPROP: - raise UnsupportedTypeAnnotationError( - f"Field {self.intern_name} is missing a default value!" - ) - @staticmethod @contextlib.contextmanager def marker_context(markers: Tuple[_markers.Marker, ...]): diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 0b237726..5a90817d 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -354,11 +354,7 @@ def handle_field( field.type, field.markers, nondefault_only=True ) - if ( - not force_primitive - and _markers.Fixed not in field.markers - and _markers.Suppress not in field.markers - ): + if not force_primitive: # (1) Handle Unions over callables; these result in subparsers. subparsers_attempt = SubparsersSpecification.from_field( field, @@ -367,9 +363,9 @@ def handle_field( extern_prefix=_strings.make_field_name([extern_prefix, field.extern_name]), ) if subparsers_attempt is not None: - if ( - not subparsers_attempt.required - and _markers.AvoidSubcommands in field.markers + if not subparsers_attempt.required and ( + _markers.AvoidSubcommands in field.markers + or _markers.Suppress in field.markers ): # Don't make a subparser. field = field.with_new_type_stripped(type(field.default)) diff --git a/tests/test_conf.py b/tests/test_conf.py index 9d506898..5782643a 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -8,10 +8,10 @@ from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union import pytest -from helptext_utils import get_helptext_with_checks +import tyro from typing_extensions import Annotated -import tyro +from helptext_utils import get_helptext_with_checks def test_suppress_subcommand() -> None: @@ -1545,3 +1545,16 @@ def main( # Doesn't work in Python 3.7 because of argparse limitations. assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2) assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3) + + +def test_nested_suppress() -> None: + @dataclasses.dataclass + class Bconfig: + b: int = 1 + + @dataclasses.dataclass + class Aconfig: + a: str = "hello" + b_conf: Bconfig = dataclasses.field(default_factory=Bconfig) + + assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig() diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 2572a7c6..0785bd68 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1540,3 +1540,16 @@ def main( # Doesn't work in Python 3.7 because of argparse limitations. assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2) assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3) + + +def test_nested_suppress() -> None: + @dataclasses.dataclass + class Bconfig: + b: int = 1 + + @dataclasses.dataclass + class Aconfig: + a: str = "hello" + b_conf: Bconfig = dataclasses.field(default_factory=Bconfig) + + assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig() From c958629b07abe7e0ae47f6dae50ce24aea110cd7 Mon Sep 17 00:00:00 2001 From: brentyi Date: Wed, 18 Dec 2024 18:26:19 -0800 Subject: [PATCH 2/2] ruff --- tests/test_conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_conf.py b/tests/test_conf.py index 5782643a..77abee13 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -8,10 +8,10 @@ from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union import pytest -import tyro +from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated -from helptext_utils import get_helptext_with_checks +import tyro def test_suppress_subcommand() -> None: