Skip to content

Commit

Permalink
Add special casing for flax, formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Feb 12, 2023
1 parent 4ee8be0 commit 4ea2bdd
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 32 deletions.
5 changes: 5 additions & 0 deletions tests/test_flax_ignore_py310.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax
import pytest
from flax import linen as nn
from helptext_utils import get_helptext
from jax import numpy as jnp

import tyro
Expand Down Expand Up @@ -46,6 +47,10 @@ def test_ok():
params = network.init(jax.random.PRNGKey(0), x)
assert network.apply(params, x).shape == (10, 3)

helptext = get_helptext(Classifier)
assert "parent" not in helptext
assert "name" not in helptext


def test_missing():
with pytest.raises(SystemExit):
Expand Down
4 changes: 1 addition & 3 deletions tests/test_helptext.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,7 @@ def main(

def test_metavar_5() -> None:
def main(
x: List[
Union[Tuple[int, int], Tuple[str, str]],
] = [(1, 1), (2, 2)]
x: List[Union[Tuple[int, int], Tuple[str, str]],] = [(1, 1), (2, 2)]
) -> None:
pass

Expand Down
20 changes: 16 additions & 4 deletions tests/test_nested_in_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,10 @@ def main(
) -> Any:
return x

assert tyro.cli(main, args="--x.float.g 0.1".split(" "),)[
assert tyro.cli(
main,
args="--x.float.g 0.1".split(" "),
)[
"float"
] == GenericColor(0.5, 0.1, 0.3)
assert tyro.cli(
Expand Down Expand Up @@ -349,10 +352,16 @@ def main(
) -> Any:
return x

assert tyro.cli(main, args="--x.hello.float.g 0.1".split(" "),)["hello"][
assert tyro.cli(
main,
args="--x.hello.float.g 0.1".split(" "),
)["hello"][
"float"
] == GenericColor(0.5, 0.1, 0.3)
assert tyro.cli(main, args="--x.hello.int.g 0".split(" "),) == {
assert tyro.cli(
main,
args="--x.hello.int.g 0".split(" "),
) == {
"hello": {"float": GenericColor(0.5, 0.2, 0.3), "int": GenericColor(25, 0, 3)}
}

Expand All @@ -368,6 +377,9 @@ def main(
) -> Any:
return x

assert tyro.cli(main, args="--x.hello.a.g 1".split(" "),)["hello"][
assert tyro.cli(
main,
args="--x.hello.a.g 1".split(" "),
)["hello"][
"a"
] == Color(5, 1, 3)
14 changes: 8 additions & 6 deletions tyro/_argparse_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def _format_args(self, action, default_metavar):
else str_from_rich(
Text.from_ansi(
out,
style=THEME.metavar_fixed
if out == "{fixed}"
else THEME.metavar,
style=(
THEME.metavar_fixed if out == "{fixed}" else THEME.metavar
),
),
soft_wrap=True,
)
Expand Down Expand Up @@ -424,9 +424,11 @@ def _tyro_format_nonroot(self):
)
.rstrip()
.split("\n"),
item_parts + [description_part]
if description_part is not None
else item_parts,
(
item_parts + [description_part]
if description_part is not None
else item_parts
),
)
)
)
Expand Down
6 changes: 5 additions & 1 deletion tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,11 @@ def fix_arg(arg: str) -> str:

if print_completion:
_arguments.USE_RICH = True
assert completion_shell in ("bash", "zsh", "tcsh",), (
assert completion_shell in (
"bash",
"zsh",
"tcsh",
), (
"Shell should be one `bash`, `zsh`, or `tcsh`, but got"
f" {completion_shell}"
)
Expand Down
28 changes: 22 additions & 6 deletions tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,23 @@ def _field_list_from_namedtuple(
def _field_list_from_dataclass(
cls: TypeForm[Any], default_instance: _DefaultInstance
) -> Union[List[FieldDefinition], UnsupportedNestedTypeMessage]:
# Check if dataclass is a flax module.
is_flax_module = False
try:
import flax

if issubclass(cls, flax.linen.Module):
is_flax_module = True
except ImportError:
pass

# Handle dataclasses.
field_list = []
for dc_field in filter(lambda field: field.init, _resolver.resolved_fields(cls)):
# For flax modules, we ignore the built-in "name" and "parent" fields.
if is_flax_module and dc_field.name in ("name", "parent"):
continue

default = _get_dataclass_field_default(dc_field, default_instance)

# Try to get helptext from field metadata. This is also intended to be
Expand Down Expand Up @@ -423,9 +437,9 @@ def _field_list_from_pydantic(
FieldDefinition.make(
name=pd_field.name,
typ=pd_field.outer_type_,
default=MISSING_NONPROP
if pd_field.required
else pd_field.get_default(),
default=(
MISSING_NONPROP if pd_field.required else pd_field.get_default()
),
helptext=helptext,
)
)
Expand Down Expand Up @@ -708,9 +722,11 @@ def _field_list_from_params(
typ=hints[param.name],
default=default,
helptext=helptext,
markers=(_markers.Positional, _markers._PositionalCall)
if param.kind is inspect.Parameter.POSITIONAL_ONLY
else (),
markers=(
(_markers.Positional, _markers._PositionalCall)
if param.kind is inspect.Parameter.POSITIONAL_ONLY
else ()
),
)
)

Expand Down
8 changes: 5 additions & 3 deletions tyro/_instantiators.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,11 @@ def instantiator_base_case(strings: List[str]) -> Any:

return instantiator_base_case, InstantiatorMetadata(
nargs=1,
metavar=typ.__name__.upper()
if auto_choices is None
else "{" + ",".join(map(str, auto_choices)) + "}",
metavar=(
typ.__name__.upper()
if auto_choices is None
else "{" + ",".join(map(str, auto_choices)) + "}"
),
choices=auto_choices,
)

Expand Down
22 changes: 13 additions & 9 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,13 @@ def from_callable_or_type(
),
)
nested_parser = ParserSpecification.from_callable_or_type(
# Recursively apply marker types.
field.typ
if len(field.markers) == 0
else Annotated.__class_getitem__( # type: ignore
(field.typ,) + tuple(field.markers)
(
# Recursively apply marker types.
field.typ
if len(field.markers) == 0
else Annotated.__class_getitem__( # type: ignore
(field.typ,) + tuple(field.markers)
)
),
description=None,
parent_classes=parent_classes,
Expand Down Expand Up @@ -439,10 +441,12 @@ def from_field(
subcommand_config, default=field.default
)
subparser = ParserSpecification.from_callable_or_type(
# Recursively apply markers.
Annotated.__class_getitem__((option,) + tuple(field.markers)) # type: ignore
if len(field.markers) > 0
else option,
(
# Recursively apply markers.
Annotated.__class_getitem__((option,) + tuple(field.markers)) # type: ignore
if len(field.markers) > 0
else option
),
description=subcommand_config.description,
parent_classes=parent_classes,
default_instance=subcommand_config.default,
Expand Down

0 comments on commit 4ea2bdd

Please sign in to comment.