diff --git a/docs/source/tab_completion.md b/docs/source/tab_completion.md index 14153cf3..7984d9a0 100644 --- a/docs/source/tab_completion.md +++ b/docs/source/tab_completion.md @@ -4,10 +4,10 @@ Interfaces built with :func:`tyro.cli()` can be tab completed in interactive shells without any source code modification. Completion scripts can be generated by passing the -`--tyro-print-completion {bash/zsh/tcsh}` flag to a tyro CLI. This generates -a completion script via [shtab](https://docs.iterative.ai/shtab/) and prints it -to stdout. To set up tab completion, the printed script simply needs to be -written somewhere where your shell will find it. +`--tyro-write-completion {bash/zsh/tcsh} PATH` flag to a tyro CLI. This +generates a completion script via [shtab](https://docs.iterative.ai/shtab/) and +prints it to stdout. To set up tab completion, the printed script simply needs +to be written somewhere where your shell will find it. ## Zsh @@ -23,7 +23,7 @@ mkdir -p ~/.zfunc # (2) Write completion script. The name here (_01_functions_py) doesn't matter, # as long as it's prefixed with an underscore. -python 01_functions.py --tyro-print-completion zsh > ~/.zfunc/_01_functions_py +python 01_functions.py --tyro-write-completion zsh ~/.zfunc/_01_functions_py ``` And if it's not in your `.zshrc` already: @@ -66,7 +66,7 @@ mkdir -p $completion_dir # (2) Write completion scripts. Note that the name of the completion script must # match the name of the file. -python 01_functions.py --tyro-print-completion bash > ${completion_dir}/01_functions.py +python 01_functions.py --tyro-write-completion bash ${completion_dir}/01_functions.py ``` In contrast to zsh, tab completion in bash requires that scripts are either set diff --git a/tests/helptext_utils.py b/tests/helptext_utils.py index 44b94e14..fa8e8e84 100644 --- a/tests/helptext_utils.py +++ b/tests/helptext_utils.py @@ -34,6 +34,10 @@ def get_helptext(f: Callable, args: List[str] = ["--help"]) -> str: tyro.cli(f, args=["--tyro-print-completion", "bash"]) with pytest.raises(SystemExit), contextlib.redirect_stdout(open(os.devnull, "w")): tyro.cli(f, args=["--tyro-print-completion", "zsh"]) + with pytest.raises(SystemExit), contextlib.redirect_stdout(open(os.devnull, "w")): + tyro.cli(f, args=["--tyro-write-completion", "bash", os.devnull]) + with pytest.raises(SystemExit), contextlib.redirect_stdout(open(os.devnull, "w")): + tyro.cli(f, args=["--tyro-write-completion", "zsh", os.devnull]) # Check helptext with vs without formatting. This can help catch text wrapping bugs # caused by ANSI sequences. diff --git a/tests/test_print_completion.py b/tests/test_completion.py similarity index 100% rename from tests/test_print_completion.py rename to tests/test_completion.py diff --git a/tyro/_cli.py b/tyro/_cli.py index ebdfdd27..6534fe05 100644 --- a/tyro/_cli.py +++ b/tyro/_cli.py @@ -1,6 +1,7 @@ """Core public API.""" import argparse import dataclasses +import pathlib import sys import warnings from typing import Callable, Optional, Sequence, TypeVar, Union, cast, overload @@ -97,9 +98,9 @@ def cli( - Optional unions over nested structures (optional subparsers). - Generics (including nested generics). - Completion script generation for interactive shells is also provided. To print a - script that can be used for tab completion, pass in `--tyro-print-completion - {bash/zsh/tcsh}`. + Completion script generation for interactive shells is also provided. To write a + script that can be used for tab completion, pass in: + `--tyro-write-completion {bash/zsh/tcsh} {path to script to write}`. Args: f: Function or type. @@ -167,7 +168,7 @@ def get_parser( """Get the `argparse.ArgumentParser` object generated under-the-hood by `tyro.cli()`. Useful for tools like `sphinx-argparse`, `argcomplete`, etc. - For tab completion, we recommend using `tyro.cli()`'s built-in `--tyro-print-completion` + For tab completion, we recommend using `tyro.cli()`'s built-in `--tyro-write-completion` flag.""" return cast( argparse.ArgumentParser, @@ -254,20 +255,28 @@ def fix_arg(arg: str) -> str: args = list(map(fix_arg, args)) - # If we pass in the --tyro-print-completion flag: turn formatting tags, and get - # the shell we want to generate a completion script for (bash/zsh/tcsh). + # If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn + # formatting tags, and get the shell we want to generate a completion script for + # (bash/zsh/tcsh). # - # Note that shtab also offers an add_argument_to() functions that fulfills a similar - # goal, but manual parsing of argv is convenient for turning off formatting. + # shtab also offers an add_argument_to() functions that fulfills a similar goal, but + # manual parsing of argv is convenient for turning off formatting. + # + # Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone + # to errors from accidental logging, print statements, etc. print_completion = len(args) >= 2 and args[0] == "--tyro-print-completion" + write_completion = len(args) >= 3 and args[0] == "--tyro-write-completion" # Note: setting USE_RICH must happen before the parser specification is generated. # TODO: revisit this. Ideally we should be able to eliminate the global state # changes. completion_shell = None - if print_completion: + completion_target_path = None + if print_completion or write_completion: completion_shell = args[1] - if print_completion or return_parser: + if write_completion: + completion_target_path = pathlib.Path(args[2]) + if print_completion or write_completion or return_parser: _arguments.USE_RICH = False else: _arguments.USE_RICH = True @@ -299,7 +308,7 @@ def fix_arg(arg: str) -> str: _arguments.USE_RICH = True return parser - if print_completion: + if print_completion or write_completion: _arguments.USE_RICH = True assert completion_shell in ( "bash", @@ -309,13 +318,24 @@ def fix_arg(arg: str) -> str: "Shell should be one `bash`, `zsh`, or `tcsh`, but got" f" {completion_shell}" ) - print( - shtab.complete( - parser=parser, - shell=completion_shell, - root_prefix=f"tyro_{parser.prog}", + + if write_completion: + assert completion_target_path is not None + completion_target_path.write_text( + shtab.complete( + parser=parser, + shell=completion_shell, + root_prefix=f"tyro_{parser.prog}", + ) + ) + else: + print( + shtab.complete( + parser=parser, + shell=completion_shell, + root_prefix=f"tyro_{parser.prog}", + ) ) - ) raise SystemExit() value_from_prefixed_field_name = vars(parser.parse_args(args=args))