Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support recursive argument suppression #217

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ class ArgumentDefinition:
subcommand_prefix: str # Prefix for nesting.
field: _fields.FieldDefinition

def __post_init__(self) -> None:
if (
_markers.Fixed in self.field.markers
or _markers.Suppress in self.field.markers
) and self.field.default in _singleton.MISSING_AND_MISSING_NONPROP:
raise UnsupportedTypeAnnotationError(
f"Field {self.field.intern_name} is missing a default value!"
)

def add_argument(
self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]
) -> None:
Expand Down
12 changes: 1 addition & 11 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from ._singleton import MISSING_AND_MISSING_NONPROP, MISSING_NONPROP
from ._typing import TypeForm
from .conf import _confstruct, _markers
from .constructors._primitive_spec import (
UnsupportedTypeAnnotationError,
)
from .constructors._primitive_spec import UnsupportedTypeAnnotationError
from .constructors._registry import ConstructorRegistry
from .constructors._struct_spec import (
StructFieldSpec,
Expand Down Expand Up @@ -51,14 +49,6 @@ class FieldDefinition:
# doesn't match the keyword expected by our callable.
call_argname: Any

def __post_init__(self):
if (
_markers.Fixed in self.markers or _markers.Suppress in self.markers
) and self.default in MISSING_AND_MISSING_NONPROP:
raise UnsupportedTypeAnnotationError(
f"Field {self.intern_name} is missing a default value!"
)

@staticmethod
@contextlib.contextmanager
def marker_context(markers: Tuple[_markers.Marker, ...]):
Expand Down
12 changes: 4 additions & 8 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,7 @@ def handle_field(
field.type, field.markers, nondefault_only=True
)

if (
not force_primitive
and _markers.Fixed not in field.markers
and _markers.Suppress not in field.markers
):
if not force_primitive:
# (1) Handle Unions over callables; these result in subparsers.
subparsers_attempt = SubparsersSpecification.from_field(
field,
Expand All @@ -367,9 +363,9 @@ def handle_field(
extern_prefix=_strings.make_field_name([extern_prefix, field.extern_name]),
)
if subparsers_attempt is not None:
if (
not subparsers_attempt.required
and _markers.AvoidSubcommands in field.markers
if not subparsers_attempt.required and (
_markers.AvoidSubcommands in field.markers
or _markers.Suppress in field.markers
):
# Don't make a subparser.
field = field.with_new_type_stripped(type(field.default))
Expand Down
13 changes: 13 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,3 +1545,16 @@ def main(
# Doesn't work in Python 3.7 because of argparse limitations.
assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2)
assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3)


def test_nested_suppress() -> None:
@dataclasses.dataclass
class Bconfig:
b: int = 1

@dataclasses.dataclass
class Aconfig:
a: str = "hello"
b_conf: Bconfig = dataclasses.field(default_factory=Bconfig)

assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()
13 changes: 13 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,3 +1540,16 @@ def main(
# Doesn't work in Python 3.7 because of argparse limitations.
assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2)
assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3)


def test_nested_suppress() -> None:
@dataclasses.dataclass
class Bconfig:
b: int = 1

@dataclasses.dataclass
class Aconfig:
a: str = "hello"
b_conf: Bconfig = dataclasses.field(default_factory=Bconfig)

assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()
Loading