From 5907cd37fb732be43881f1ed929d49ff603cb33d Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 29 Oct 2023 23:13:06 -0700 Subject: [PATCH] Bump version, docs + tests updates --- .../02_nesting/05_subcommands_func.rst | 31 +++ .../04_additional/10_custom_constructors.rst | 2 - examples/02_nesting/05_subcommands_func.py | 31 +++ .../04_additional/10_custom_constructors.py | 2 - pyproject.toml | 2 +- .../test_conf_generated.py | 189 ++++++++++++++++++ .../test_nested_generated.py | 35 ++++ tyro/_resolver.py | 18 +- tyro/conf/_markers.py | 9 +- tyro/extras/_subcommand_cli_from_dict.py | 49 ++++- 10 files changed, 353 insertions(+), 15 deletions(-) diff --git a/docs/source/examples/02_nesting/05_subcommands_func.rst b/docs/source/examples/02_nesting/05_subcommands_func.rst index 361f6af8..3eddf238 100644 --- a/docs/source/examples/02_nesting/05_subcommands_func.rst +++ b/docs/source/examples/02_nesting/05_subcommands_func.rst @@ -8,6 +8,37 @@ Subcommands from Functions :func:`tyro.extras.subcommand_cli_from_dict()` provides a shorthand that generates a subcommand CLI from a dictionary. +For an input like: + +.. code-block:: python + + tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + } + ) + +This is internally accomplished by generating and calling: + +.. code-block:: python + + from typing import Annotated, Any, Union + import tyro + + tyro.cli( + Union[ + Annotated[ + Any, + tyro.conf.subcommand(name="checkout", constructor=checkout), + ], + Annotated[ + Any, + tyro.conf.subcommand(name="commit", constructor=commit), + ], + ] + ) + .. code-block:: python diff --git a/docs/source/examples/04_additional/10_custom_constructors.rst b/docs/source/examples/04_additional/10_custom_constructors.rst index ed2b08b8..fb3f8880 100644 --- a/docs/source/examples/04_additional/10_custom_constructors.rst +++ b/docs/source/examples/04_additional/10_custom_constructors.rst @@ -14,9 +14,7 @@ which makes it easier to load complex objects. :linenos: - import dataclasses import json as json_ - from typing import Dict from typing_extensions import Annotated diff --git a/examples/02_nesting/05_subcommands_func.py b/examples/02_nesting/05_subcommands_func.py index 29e1e14f..0cf9309b 100644 --- a/examples/02_nesting/05_subcommands_func.py +++ b/examples/02_nesting/05_subcommands_func.py @@ -3,6 +3,37 @@ :func:`tyro.extras.subcommand_cli_from_dict()` provides a shorthand that generates a subcommand CLI from a dictionary. +For an input like: + +```python +tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + } +) +``` + +This is internally accomplished by generating and calling: + +```python +from typing import Annotated, Any, Union +import tyro + +tyro.cli( + Union[ + Annotated[ + Any, + tyro.conf.subcommand(name="checkout", constructor=checkout), + ], + Annotated[ + Any, + tyro.conf.subcommand(name="commit", constructor=commit), + ], + ] +) +``` + Usage: `python ./05_subcommands_func.py --help` `python ./05_subcommands_func.py commit --help` diff --git a/examples/04_additional/10_custom_constructors.py b/examples/04_additional/10_custom_constructors.py index 2d6b9d15..993133b0 100644 --- a/examples/04_additional/10_custom_constructors.py +++ b/examples/04_additional/10_custom_constructors.py @@ -9,9 +9,7 @@ `python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}"` --dict2.json "{\"hello\": \"world\"}"` """ -import dataclasses import json as json_ -from typing import Dict from typing_extensions import Annotated diff --git a/pyproject.toml b/pyproject.toml index 35fb6ee3..d08495e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "tyro" authors = [ {name = "brentyi", email = "brentyi@berkeley.edu"}, ] -version = "0.5.10" +version = "0.5.11" description = "Strongly typed, zero-effort CLI interfaces" readme = "README.md" license = { text="MIT" } diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index d9370e7d..15dcdac1 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1,5 +1,9 @@ import argparse +import contextlib import dataclasses +import io +import json as json_ +import shlex from typing import Annotated, Any, Dict, Generic, List, Tuple, TypeVar, Union import pytest @@ -900,3 +904,188 @@ class TrainConfig: assert tyro.cli( tyro.conf.OmitArgPrefixes[TrainConfig], args="--num-slots 3".split(" ") ) == TrainConfig(ModelConfig(num_slots=3)) + + +def test_custom_constructor_0() -> None: + def times_two(n: str) -> int: + return int(n) * 2 + + @dataclasses.dataclass + class Config: + x: Annotated[int, tyro.conf.arg(constructor=times_two)] + + assert tyro.cli(Config, args="--x.n 5".split(" ")) == Config(x=10) + + +def test_custom_constructor_1() -> None: + def times_two(n: int) -> int: + return int(n) * 2 + + @dataclasses.dataclass + class Config: + x: Annotated[int, tyro.conf.arg(constructor=times_two)] + + assert tyro.cli(Config, args="--x.n 5".split(" ")) == Config(x=10) + + +def test_custom_constructor_2() -> None: + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=int)] + + assert tyro.cli(Config, args="--x 5".split(" ")) == Config(x=5) + with pytest.raises(SystemExit): + tyro.cli(Config, args="--x 5.23".split(" ")) + + +def test_custom_constructor_3() -> None: + def dict_from_json(json: str) -> dict: + out = json_.loads(json) + if not isinstance(out, dict): + raise ValueError(f"{json} is not a dict!") + return out + + @dataclasses.dataclass + class Config: + x: Annotated[ + dict, + tyro.conf.arg( + metavar="JSON", + constructor=dict_from_json, + ), + ] + + assert tyro.cli( + Config, args=shlex.split('--x.json \'{"hello": "world"}\'') + ) == Config(x={"hello": "world"}) + + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli(Config, args="--x.json 5".split(" ")) + + error = target.getvalue() + assert "Error parsing x: 5 is not a dict!" in error + + +def test_custom_constructor_4() -> None: + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=int)] = 3.23 + + assert tyro.cli(Config, args="--x 5".split(" ")) == Config(x=5) + assert tyro.cli(Config, args=[]) == Config(x=3.23) + + +def test_custom_constructor_5() -> None: + def make_float(a: float, b: float, c: float = 3) -> float: + return a * b * c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + assert tyro.cli(Config, args=[]) == Config(x=3.23) + assert tyro.cli(Config, args="--x.a 5 --x.b 2 --x.c 3".split(" ")) == Config(x=30) + assert tyro.cli(Config, args="--x.a 5 --x.b 2".split(" ")) == Config(x=30) + + # --x.b is required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="--x.a 5".split(" ")) + + # --x.a and --x.b are required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="--x.c 5".split(" ")) + + +def test_custom_constructor_6() -> None: + def make_float(a: tyro.conf.Positional[float], b: float, c: float = 3) -> float: + return a * b * c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + assert tyro.cli(Config, args=[]) == Config(x=3.23) + assert tyro.cli(Config, args="--x.b 2 --x.c 3 5".split(" ")) == Config(x=30) + assert tyro.cli(Config, args="--x.b 2 5".split(" ")) == Config(x=30) + + # --x.b is required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="5".split(" ")) + + # --x.a and --x.b are required! + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli(Config, args="--x.c 5".split(" ")) + error = target.getvalue() + assert "We're missing" in error + + +def test_custom_constructor_7() -> None: + @dataclasses.dataclass + class Struct: + a: int + b: int + c: int = 3 + + def make_float(struct: Struct) -> float: + return struct.a * struct.b * struct.c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + assert tyro.cli(Config, args=[]) == Config(x=3.23) + assert tyro.cli( + Config, args="--x.struct.a 5 --x.struct.b 2 --x.struct.c 3".split(" ") + ) == Config(x=30) + assert tyro.cli(Config, args="--x.struct.a 5 --x.struct.b 2".split(" ")) == Config( + x=30 + ) + + # --x.struct.b is required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="--x.struct.a 5".split(" ")) + + # --x.struct.a and --x.struct.b are required! + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli(Config, args="--x.struct.c 5".split(" ")) + error = target.getvalue() + assert "We're missing arguments" in error + assert "'b'" in error + assert "'a'" in error # The 5 is parsed into `a`. + + +def test_custom_constructor_8() -> None: + @dataclasses.dataclass + class Struct: + a: tyro.conf.Positional[int] + b: int + c: int = 3 + + def make_float(struct: Struct) -> float: + return struct.a * struct.b * struct.c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + assert tyro.cli(Config, args=[]) == Config(x=3.23) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 5".split(" ") + ) == Config(x=30) + assert tyro.cli(Config, args="--x.struct.b 2 5".split(" ")) == Config(x=30) + + # --x.struct.b is required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="5".split(" ")) + + # --x.struct.a and --x.struct.b are required! + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli(Config, args="--x.struct.b 5".split(" ")) + error = target.getvalue() + assert "We're missing arguments" in error + assert "'a'" in error + assert "'b'" not in error diff --git a/tests/test_py311_generated/test_nested_generated.py b/tests/test_py311_generated/test_nested_generated.py index 79e94a6b..c60ca3e9 100644 --- a/tests/test_py311_generated/test_nested_generated.py +++ b/tests/test_py311_generated/test_nested_generated.py @@ -1026,3 +1026,38 @@ class Level3(BaseConfig): child: Level2 = dataclasses.field(default_factory=lambda: Level2()) tyro.cli(Level3, args=[]) + + +def test_subcommand_dict_helper() -> None: + def checkout(branch: str) -> str: + """Check out a branch.""" + return branch + + def commit(message: str, all: bool = False) -> Tuple[str, bool]: + """Make a commit.""" + return (message, all) + + assert ( + tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + }, + args="checkout --branch main".split(" "), + ) + == "main" + ) + assert tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + }, + args="commit --message hello".split(" "), + ) == ("hello", False) + assert tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + }, + args="commit --message hello --all".split(" "), + ) == ("hello", True) diff --git a/tyro/_resolver.py b/tyro/_resolver.py index cc468898..3c685238 100644 --- a/tyro/_resolver.py +++ b/tyro/_resolver.py @@ -193,18 +193,24 @@ def unwrap_annotated( - Annotated[int, 1], int => (int, (1,)) - Annotated[int, "1"], int => (int, ()) """ + targets = tuple( + x + for x in getattr(typ, "__tyro_markers__", tuple()) + if search_type is not None and isinstance(x, search_type) + ) + assert isinstance(targets, tuple) if not hasattr(typ, "__metadata__"): - return typ, () + return typ, targets args = get_args(typ) assert len(args) >= 2 - # Don't search for a specific metadata type if `None` is passed in. - if search_type is None: - return args[0], () - # Look through metadata for desired metadata type. - targets = tuple(x for x in args[1:] if isinstance(x, search_type)) + targets = tuple( + x + for x in targets + args[1:] + if search_type is not None and isinstance(x, search_type) + ) return args[0], targets diff --git a/tyro/conf/_markers.py b/tyro/conf/_markers.py index d80e30c4..3fbb717c 100644 --- a/tyro/conf/_markers.py +++ b/tyro/conf/_markers.py @@ -136,7 +136,7 @@ def __getitem__(self, key): def configure(*markers: Marker) -> Callable[[CallableType], CallableType]: - """Decorator for configuring functions. + """Decorator for applying configuration options. Configuration markers are implemented via `typing.Annotated` and straightforward to apply to types, for example: @@ -153,10 +153,15 @@ def configure(*markers: Marker) -> Callable[[CallableType], CallableType]: def main(field: bool) -> None: ... ``` + + Args: + markers: Options to apply. """ def _inner(callable: CallableType) -> CallableType: - return Annotated.__class_getitem__((callable,) + tuple(markers)) # type: ignore + # We'll read from __tyro_markers__ in `_resolver.unwrap_annotated()`. + callable.__tyro_markers__ = markers # type: ignore + return callable return _inner diff --git a/tyro/extras/_subcommand_cli_from_dict.py b/tyro/extras/_subcommand_cli_from_dict.py index f270dbc8..0e1cd56c 100644 --- a/tyro/extras/_subcommand_cli_from_dict.py +++ b/tyro/extras/_subcommand_cli_from_dict.py @@ -42,8 +42,53 @@ def subcommand_cli_from_dict( args: Optional[Sequence[str]] = None, use_underscores: bool = False, ) -> Any: - """Generate a subcommand CLI from a dictionary that maps subcommand name to the - corresponding function to call (or object to instantiate).""" + """Generate a subcommand CLI from a dictionary of functions. + + For an input like: + + ```python + tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + } + ) + ``` + + This is internally accomplished by generating and calling: + + ```python + from typing import Annotated, Any, Union + import tyro + + tyro.cli( + Union[ + Annotated[ + Any, + tyro.conf.subcommand(name="checkout", constructor=checkout), + ], + Annotated[ + Any, + tyro.conf.subcommand(name="commit", constructor=commit), + ], + ] + ) + ``` + + Args: + subcommands: Dictionary that maps the subcommand name to function to call. + prog: The name of the program printed in helptext. Mirrors argument from + `argparse.ArgumentParser()`. + description: Description text for the parser, displayed when the --help flag is + passed in. If not specified, `f`'s docstring is used. Mirrors argument from + `argparse.ArgumentParser()`. + args: If set, parse arguments from a sequence of strings instead of the + commandline. Mirrors argument from `argparse.ArgumentParser.parse_args()`. + use_underscores: If True, use underscores as a word delimeter instead of hyphens. + This primarily impacts helptext; underscores and hyphens are treated equivalently + when parsing happens. We default helptext to hyphens to follow the GNU style guide. + https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html + """ return cli( Union.__getitem__( # type: ignore tuple(