Skip to content

Commit

Permalink
Bump version, docs + tests updates
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 30, 2023
1 parent 7a50cae commit 5907cd3
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 15 deletions.
31 changes: 31 additions & 0 deletions docs/source/examples/02_nesting/05_subcommands_func.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,37 @@ Subcommands from Functions
:func:`tyro.extras.subcommand_cli_from_dict()` provides a shorthand that generates a
subcommand CLI from a dictionary.

For an input like:

.. code-block:: python
tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
}
)
This is internally accomplished by generating and calling:

.. code-block:: python
from typing import Annotated, Any, Union
import tyro
tyro.cli(
Union[
Annotated[
Any,
tyro.conf.subcommand(name="checkout", constructor=checkout),
],
Annotated[
Any,
tyro.conf.subcommand(name="commit", constructor=commit),
],
]
)
.. code-block:: python
Expand Down
2 changes: 0 additions & 2 deletions docs/source/examples/04_additional/10_custom_constructors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ which makes it easier to load complex objects.
:linenos:
import dataclasses
import json as json_
from typing import Dict
from typing_extensions import Annotated
Expand Down
31 changes: 31 additions & 0 deletions examples/02_nesting/05_subcommands_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,37 @@
:func:`tyro.extras.subcommand_cli_from_dict()` provides a shorthand that generates a
subcommand CLI from a dictionary.
For an input like:
```python
tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
}
)
```
This is internally accomplished by generating and calling:
```python
from typing import Annotated, Any, Union
import tyro
tyro.cli(
Union[
Annotated[
Any,
tyro.conf.subcommand(name="checkout", constructor=checkout),
],
Annotated[
Any,
tyro.conf.subcommand(name="commit", constructor=commit),
],
]
)
```
Usage:
`python ./05_subcommands_func.py --help`
`python ./05_subcommands_func.py commit --help`
Expand Down
2 changes: 0 additions & 2 deletions examples/04_additional/10_custom_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}"` --dict2.json "{\"hello\": \"world\"}"`
"""

import dataclasses
import json as json_
from typing import Dict

from typing_extensions import Annotated

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "tyro"
authors = [
{name = "brentyi", email = "[email protected]"},
]
version = "0.5.10"
version = "0.5.11"
description = "Strongly typed, zero-effort CLI interfaces"
readme = "README.md"
license = { text="MIT" }
Expand Down
189 changes: 189 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import argparse
import contextlib
import dataclasses
import io
import json as json_
import shlex
from typing import Annotated, Any, Dict, Generic, List, Tuple, TypeVar, Union

import pytest
Expand Down Expand Up @@ -900,3 +904,188 @@ class TrainConfig:
assert tyro.cli(
tyro.conf.OmitArgPrefixes[TrainConfig], args="--num-slots 3".split(" ")
) == TrainConfig(ModelConfig(num_slots=3))


def test_custom_constructor_0() -> None:
def times_two(n: str) -> int:
return int(n) * 2

@dataclasses.dataclass
class Config:
x: Annotated[int, tyro.conf.arg(constructor=times_two)]

assert tyro.cli(Config, args="--x.n 5".split(" ")) == Config(x=10)


def test_custom_constructor_1() -> None:
def times_two(n: int) -> int:
return int(n) * 2

@dataclasses.dataclass
class Config:
x: Annotated[int, tyro.conf.arg(constructor=times_two)]

assert tyro.cli(Config, args="--x.n 5".split(" ")) == Config(x=10)


def test_custom_constructor_2() -> None:
@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=int)]

assert tyro.cli(Config, args="--x 5".split(" ")) == Config(x=5)
with pytest.raises(SystemExit):
tyro.cli(Config, args="--x 5.23".split(" "))


def test_custom_constructor_3() -> None:
def dict_from_json(json: str) -> dict:
out = json_.loads(json)
if not isinstance(out, dict):
raise ValueError(f"{json} is not a dict!")
return out

@dataclasses.dataclass
class Config:
x: Annotated[
dict,
tyro.conf.arg(
metavar="JSON",
constructor=dict_from_json,
),
]

assert tyro.cli(
Config, args=shlex.split('--x.json \'{"hello": "world"}\'')
) == Config(x={"hello": "world"})

target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(Config, args="--x.json 5".split(" "))

error = target.getvalue()
assert "Error parsing x: 5 is not a dict!" in error


def test_custom_constructor_4() -> None:
@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=int)] = 3.23

assert tyro.cli(Config, args="--x 5".split(" ")) == Config(x=5)
assert tyro.cli(Config, args=[]) == Config(x=3.23)


def test_custom_constructor_5() -> None:
def make_float(a: float, b: float, c: float = 3) -> float:
return a * b * c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

assert tyro.cli(Config, args=[]) == Config(x=3.23)
assert tyro.cli(Config, args="--x.a 5 --x.b 2 --x.c 3".split(" ")) == Config(x=30)
assert tyro.cli(Config, args="--x.a 5 --x.b 2".split(" ")) == Config(x=30)

# --x.b is required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="--x.a 5".split(" "))

# --x.a and --x.b are required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="--x.c 5".split(" "))


def test_custom_constructor_6() -> None:
def make_float(a: tyro.conf.Positional[float], b: float, c: float = 3) -> float:
return a * b * c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

assert tyro.cli(Config, args=[]) == Config(x=3.23)
assert tyro.cli(Config, args="--x.b 2 --x.c 3 5".split(" ")) == Config(x=30)
assert tyro.cli(Config, args="--x.b 2 5".split(" ")) == Config(x=30)

# --x.b is required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="5".split(" "))

# --x.a and --x.b are required!
target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(Config, args="--x.c 5".split(" "))
error = target.getvalue()
assert "We're missing" in error


def test_custom_constructor_7() -> None:
@dataclasses.dataclass
class Struct:
a: int
b: int
c: int = 3

def make_float(struct: Struct) -> float:
return struct.a * struct.b * struct.c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

assert tyro.cli(Config, args=[]) == Config(x=3.23)
assert tyro.cli(
Config, args="--x.struct.a 5 --x.struct.b 2 --x.struct.c 3".split(" ")
) == Config(x=30)
assert tyro.cli(Config, args="--x.struct.a 5 --x.struct.b 2".split(" ")) == Config(
x=30
)

# --x.struct.b is required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="--x.struct.a 5".split(" "))

# --x.struct.a and --x.struct.b are required!
target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(Config, args="--x.struct.c 5".split(" "))
error = target.getvalue()
assert "We're missing arguments" in error
assert "'b'" in error
assert "'a'" in error # The 5 is parsed into `a`.


def test_custom_constructor_8() -> None:
@dataclasses.dataclass
class Struct:
a: tyro.conf.Positional[int]
b: int
c: int = 3

def make_float(struct: Struct) -> float:
return struct.a * struct.b * struct.c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

assert tyro.cli(Config, args=[]) == Config(x=3.23)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 5".split(" ")
) == Config(x=30)
assert tyro.cli(Config, args="--x.struct.b 2 5".split(" ")) == Config(x=30)

# --x.struct.b is required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="5".split(" "))

# --x.struct.a and --x.struct.b are required!
target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(Config, args="--x.struct.b 5".split(" "))
error = target.getvalue()
assert "We're missing arguments" in error
assert "'a'" in error
assert "'b'" not in error
35 changes: 35 additions & 0 deletions tests/test_py311_generated/test_nested_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,3 +1026,38 @@ class Level3(BaseConfig):
child: Level2 = dataclasses.field(default_factory=lambda: Level2())

tyro.cli(Level3, args=[])


def test_subcommand_dict_helper() -> None:
def checkout(branch: str) -> str:
"""Check out a branch."""
return branch

def commit(message: str, all: bool = False) -> Tuple[str, bool]:
"""Make a commit."""
return (message, all)

assert (
tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
},
args="checkout --branch main".split(" "),
)
== "main"
)
assert tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
},
args="commit --message hello".split(" "),
) == ("hello", False)
assert tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
},
args="commit --message hello --all".split(" "),
) == ("hello", True)
18 changes: 12 additions & 6 deletions tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,24 @@ def unwrap_annotated(
- Annotated[int, 1], int => (int, (1,))
- Annotated[int, "1"], int => (int, ())
"""
targets = tuple(
x
for x in getattr(typ, "__tyro_markers__", tuple())
if search_type is not None and isinstance(x, search_type)
)
assert isinstance(targets, tuple)
if not hasattr(typ, "__metadata__"):
return typ, ()
return typ, targets

args = get_args(typ)
assert len(args) >= 2

# Don't search for a specific metadata type if `None` is passed in.
if search_type is None:
return args[0], ()

# Look through metadata for desired metadata type.
targets = tuple(x for x in args[1:] if isinstance(x, search_type))
targets = tuple(
x
for x in targets + args[1:]
if search_type is not None and isinstance(x, search_type)
)
return args[0], targets


Expand Down
9 changes: 7 additions & 2 deletions tyro/conf/_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __getitem__(self, key):


def configure(*markers: Marker) -> Callable[[CallableType], CallableType]:
"""Decorator for configuring functions.
"""Decorator for applying configuration options.
Configuration markers are implemented via `typing.Annotated` and straightforward to
apply to types, for example:
Expand All @@ -153,10 +153,15 @@ def configure(*markers: Marker) -> Callable[[CallableType], CallableType]:
def main(field: bool) -> None:
...
```
Args:
markers: Options to apply.
"""

def _inner(callable: CallableType) -> CallableType:
return Annotated.__class_getitem__((callable,) + tuple(markers)) # type: ignore
# We'll read from __tyro_markers__ in `_resolver.unwrap_annotated()`.
callable.__tyro_markers__ = markers # type: ignore
return callable

return _inner

Expand Down
Loading

0 comments on commit 5907cd3

Please sign in to comment.