diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index d21ad850..7121ce01 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -110,6 +110,15 @@ class ArgumentDefinition: subcommand_prefix: str # Prefix for nesting. field: _fields.FieldDefinition + def __post_init__(self) -> None: + if ( + _markers.Fixed in self.field.markers + or _markers.Suppress in self.field.markers + ) and self.field.default in _singleton.MISSING_AND_MISSING_NONPROP: + raise UnsupportedTypeAnnotationError( + f"Field {self.field.intern_name} is missing a default value!" + ) + def add_argument( self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup] ) -> None: diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 1d1ce112..94db68ac 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -16,9 +16,7 @@ from ._singleton import MISSING_AND_MISSING_NONPROP, MISSING_NONPROP from ._typing import TypeForm from .conf import _confstruct, _markers -from .constructors._primitive_spec import ( - UnsupportedTypeAnnotationError, -) +from .constructors._primitive_spec import UnsupportedTypeAnnotationError from .constructors._registry import ConstructorRegistry from .constructors._struct_spec import ( StructFieldSpec, @@ -51,14 +49,6 @@ class FieldDefinition: # doesn't match the keyword expected by our callable. call_argname: Any - def __post_init__(self): - if ( - _markers.Fixed in self.markers or _markers.Suppress in self.markers - ) and self.default in MISSING_AND_MISSING_NONPROP: - raise UnsupportedTypeAnnotationError( - f"Field {self.intern_name} is missing a default value!" - ) - @staticmethod @contextlib.contextmanager def marker_context(markers: Tuple[_markers.Marker, ...]): diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 0b237726..5a90817d 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -354,11 +354,7 @@ def handle_field( field.type, field.markers, nondefault_only=True ) - if ( - not force_primitive - and _markers.Fixed not in field.markers - and _markers.Suppress not in field.markers - ): + if not force_primitive: # (1) Handle Unions over callables; these result in subparsers. subparsers_attempt = SubparsersSpecification.from_field( field, @@ -367,9 +363,9 @@ def handle_field( extern_prefix=_strings.make_field_name([extern_prefix, field.extern_name]), ) if subparsers_attempt is not None: - if ( - not subparsers_attempt.required - and _markers.AvoidSubcommands in field.markers + if not subparsers_attempt.required and ( + _markers.AvoidSubcommands in field.markers + or _markers.Suppress in field.markers ): # Don't make a subparser. field = field.with_new_type_stripped(type(field.default)) diff --git a/tests/test_conf.py b/tests/test_conf.py index 9d506898..77abee13 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1545,3 +1545,16 @@ def main( # Doesn't work in Python 3.7 because of argparse limitations. assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2) assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3) + + +def test_nested_suppress() -> None: + @dataclasses.dataclass + class Bconfig: + b: int = 1 + + @dataclasses.dataclass + class Aconfig: + a: str = "hello" + b_conf: Bconfig = dataclasses.field(default_factory=Bconfig) + + assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig() diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 2572a7c6..0785bd68 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1540,3 +1540,16 @@ def main( # Doesn't work in Python 3.7 because of argparse limitations. assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2) assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3) + + +def test_nested_suppress() -> None: + @dataclasses.dataclass + class Bconfig: + b: int = 1 + + @dataclasses.dataclass + class Aconfig: + a: str = "hello" + b_conf: Bconfig = dataclasses.field(default_factory=Bconfig) + + assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()