Skip to content

Commit

Permalink
Add return_unknown_args to tyro.cli (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
JesseFarebro authored Feb 15, 2023
1 parent 1917cc4 commit 1ba5d5c
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 25 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ exclude_lines = [
ignore = [
"E501", # Ignore line length errors.
]

[tool.pytest.ini_options]
pythonpath = [
"."
]
84 changes: 84 additions & 0 deletions tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,90 @@ def main() -> argparse.ArgumentParser:
assert isinstance(tyro.cli(main, args=[]), argparse.ArgumentParser)


def test_return_unknown_args() -> None:
@dataclasses.dataclass
class A:
x: int = 0

a, unknown_args = tyro.cli(
A, args=["positional", "--x", "5", "--y", "7"], return_unknown_args=True
)
assert a == A(x=5)
assert unknown_args == ["positional", "--y", "7"]


def test_unknown_args_with_arg_fixing() -> None:
@dataclasses.dataclass
class A:
x: int = 0

a, unknown_args = tyro.cli(
A,
args=["--x", "5", "--a_b", "--a-c"],
return_unknown_args=True,
)
assert a == A(x=5)
# Should return the unfixed arguments
assert unknown_args == ["--a_b", "--a-c"]


def test_allow_ambiguous_args_when_not_returning_unknown_args() -> None:
@dataclasses.dataclass
class A:
a_b: List[int] = dataclasses.field(default_factory=list)

a = tyro.cli(
A,
args=["--a_b", "5", "--a-b", "7"],
)
assert a == A(a_b=[7])


def test_disallow_ambiguous_args_when_returning_unknown_args() -> None:
@dataclasses.dataclass
class A:
x: int = 0

# If there's an argument that's ambiguous then we should raise an error when we're
# returning unknown args.
with pytest.raises(RuntimeError, match="Ambiguous .* --a_b and --a-b"):
tyro.cli(
A,
args=["--x", "5", "--a_b", "--a-b"],
return_unknown_args=True,
)


def test_unknown_args_with_consistent_duplicates() -> None:
@dataclasses.dataclass
class A:
a_b: List[int] = dataclasses.field(default_factory=list)
c_d: List[int] = dataclasses.field(default_factory=list)

# Tests logic for consistent duplicate arguments when performing argument fixing.
# i.e., we can fix arguments if the separator is consistent (all _'s or all -'s).
a, unknown_args = tyro.cli(
A,
args=[
"--a-b",
"5",
"--a-b",
"7",
"--c_d",
"5",
"--c_d",
"7",
"--e-f",
"--e-f",
"--g_h",
"--g_h",
],
return_unknown_args=True,
)
assert a == A(a_b=[7], c_d=[7])
assert unknown_args == ["--e-f", "--e-f", "--g_h", "--g_h"]


def test_pathlike():
def main(x: os.PathLike) -> os.PathLike:
return x
Expand Down
130 changes: 105 additions & 25 deletions tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,21 @@
import pathlib
import sys
import warnings
from typing import Callable, Optional, Sequence, TypeVar, Union, cast, overload
from typing import (
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
overload,
)

import shtab
from typing_extensions import Literal

from . import (
_argparse_formatter,
Expand Down Expand Up @@ -36,10 +48,24 @@ def cli(
description: Optional[str] = None,
args: Optional[Sequence[str]] = None,
default: Optional[OutT] = None,
return_unknown_args: Literal[False] = False,
) -> OutT:
...


@overload
def cli(
f: TypeForm[OutT],
*,
prog: Optional[str] = None,
description: Optional[str] = None,
args: Optional[Sequence[str]] = None,
default: Optional[OutT] = None,
return_unknown_args: Literal[True],
) -> Tuple[OutT, List[str]]:
...


@overload
def cli(
f: Callable[..., OutT],
Expand All @@ -51,19 +77,37 @@ def cli(
# supported for general callables. These can, however, be specified in the signature
# of the callable itself.
default: None = None,
return_unknown_args: Literal[False] = False,
) -> OutT:
...


@overload
def cli(
f: Callable[..., OutT],
*,
prog: Optional[str] = None,
description: Optional[str] = None,
args: Optional[Sequence[str]] = None,
# Note that passing a default makes sense for things like dataclasses, but are not
# supported for general callables. These can, however, be specified in the signature
# of the callable itself.
default: None = None,
return_unknown_args: Literal[True],
) -> Tuple[OutT, List[str]]:
...


def cli(
f: Union[TypeForm[OutT], Callable[..., OutT]],
*,
prog: Optional[str] = None,
description: Optional[str] = None,
args: Optional[Sequence[str]] = None,
default: Optional[OutT] = None,
return_unknown_args: bool = False,
**deprecated_kwargs,
) -> OutT:
) -> Union[OutT, Tuple[OutT, List[str]]]:
"""Call or instantiate `f`, with inputs populated from an automatically generated
CLI interface.
Expand Down Expand Up @@ -115,23 +159,29 @@ def cli(
type like a dataclass or dictionary, but not if `f` is a general callable like
a function or standard class. Helpful for merging CLI arguments with values
loaded from elsewhere. (for example, a config object loaded from a yaml file)
return_unknown_args: If True, return a tuple of the output of `f` and a list of
unknown arguments. Mirrors the unknown arguments returned from
`argparse.ArgumentParser.parse_known_args()`.
Returns:
The output of `f(...)` or an instance `f`. If `f` is a class, the two are
equivalent.
equivalent. If `return_unknown_args` is True, returns a tuple of the output of
`f(...)` and a list of unknown arguments.
"""
return cast(
OutT,
_cli_impl(
f,
prog=prog,
description=description,
args=args,
default=default,
return_parser=False,
**deprecated_kwargs,
),
output = _cli_impl(
f,
prog=prog,
description=description,
args=args,
default=default,
return_parser=False,
return_unknown_args=return_unknown_args,
**deprecated_kwargs,
)
if return_unknown_args:
return cast(Tuple[OutT, List[str]], output)
else:
return cast(OutT, output)


@overload
Expand Down Expand Up @@ -179,6 +229,7 @@ def get_parser(
args=None,
default=default,
return_parser=True,
return_unknown_args=False,
),
)

Expand All @@ -191,8 +242,9 @@ def _cli_impl(
args: Optional[Sequence[str]],
default: Optional[OutT],
return_parser: bool,
return_unknown_args: bool,
**deprecated_kwargs,
) -> Union[OutT, argparse.ArgumentParser]:
) -> Union[OutT, argparse.ArgumentParser, Tuple[OutT, List[str]],]:
"""Helper for stitching the `tyro` pipeline together.
Converts `f` into a
Expand Down Expand Up @@ -242,18 +294,32 @@ def _cli_impl(

# Read and fix arguments. If the user passes in --field_name instead of
# --field-name, correct for them.
args = sys.argv[1:] if args is None else args

def fix_arg(arg: str) -> str:
args = list(sys.argv[1:]) if args is None else list(args)

# Fix arguments. This will modify all option-style arguments replacing
# underscores with dashes. This is to support the common convention of using
# underscores in variable names, but dashes in command line arguments.
# If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error.
modified_args: Dict[str, str] = {}
for index, arg in enumerate(args):
if not arg.startswith("--"):
return arg
continue

if "=" in arg:
arg, _, val = arg.partition("=")
return arg.replace("_", "-") + "=" + val
fixed = arg.replace("_", "-") + "=" + val
else:
return arg.replace("_", "-")

args = list(map(fix_arg, args))
fixed = arg.replace("_", "-")
if (
return_unknown_args
and fixed in modified_args
and modified_args[fixed] != arg
):
raise RuntimeError(
f"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
)
modified_args[fixed] = arg
args[index] = fixed

# If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn
# formatting tags, and get the shell we want to generate a completion script for
Expand Down Expand Up @@ -338,7 +404,12 @@ def fix_arg(arg: str) -> str:
)
raise SystemExit()

value_from_prefixed_field_name = vars(parser.parse_args(args=args))
if return_unknown_args:
namespace, unknown_args = parser.parse_known_args(args=args)
else:
unknown_args = None
namespace = parser.parse_args(args=args)
value_from_prefixed_field_name = vars(namespace)

if dummy_wrapped:
value_from_prefixed_field_name = {
Expand Down Expand Up @@ -369,4 +440,13 @@ def fix_arg(arg: str) -> str:

if dummy_wrapped:
out = getattr(out, _strings.dummy_field_name)
return out

if return_unknown_args:
assert unknown_args is not None, "Should have parsed with `parse_known_args()`"
# If we're parsed unknown args, we should return the original args, not
# the fixed ones.
unknown_args = [modified_args.get(arg, arg) for arg in unknown_args]
return out, unknown_args
else:
assert unknown_args is None, "Should have parsed with `parse_args()`"
return out

0 comments on commit 1ba5d5c

Please sign in to comment.