diff --git a/README.rst b/README.rst index 667e0b5..3c45c7a 100644 --- a/README.rst +++ b/README.rst @@ -61,6 +61,9 @@ Latest Release Changelog --------- +- Generates type hints for non-scalar parameters + [Daverball] + 0.20.0 (2024-08-28) ~~~~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index 0da8538..a2b6612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ log_level = "INFO" testpaths = ["tests"] filterwarnings = [ "ignore:The _yaml extension module is now located at yaml._yaml", + "ignore:AnsibleCollectionFinder has already been configured" ] [tool.coverage.run] diff --git a/scripts/generate_module_hints.py b/scripts/generate_module_hints.py index 4ba677f..c23b932 100755 --- a/scripts/generate_module_hints.py +++ b/scripts/generate_module_hints.py @@ -38,14 +38,14 @@ 'raw': 'str', 'str': 'str', 'int': 'int', + 'float': 'float', 'bool': 'bool', # NOTE: Technically you can construct a `PathLike` that will not work # but since pretty much all of them convert to str just fine # we accept it anyways 'path': 'StrPath', - # FIXME: We don't support complex parameter types - 'list': 'NotSupported', - 'dict': 'NotSupported', + 'list': 'Sequence', + 'dict': 'Mapping', } @@ -81,7 +81,7 @@ def write_function_parameter_list( # if the type is not set it appears to always be str type_name = type_map.get(meta.get('type', 'string'), 'Incomplete') if print_default := ( - type_name not in ('NotSupported', 'Incomplete') + type_name not in ('Mapping', 'Sequence', 'Incomplete') and 'default' in meta and (default := meta['default']) is not None ): @@ -106,7 +106,22 @@ def write_function_parameter_list( else: type_name = f'{default_type} | {type_name}' - if ( + if type_name == 'Sequence': + element_type = type_map.get(meta.get('elements', ''), 'Incomplete') + if element_type == 'Mapping': + element_type = 'Mapping[str, Incomplete]' + type_name = f'{type_name}[{element_type}]' + # NOTE: Technically Ansible will turn scalar values into + # 1-element lists automatically, so we could accept + # `sequence | scalar`, but that probably defeats the + # purpose of static typing. I'd rather be a little bit + # more strict here. There's already plenty of other + # argument types where Ansible is more forgiviing. + + elif type_name == 'Mapping': + type_name = 'Mapping[str, Incomplete]' + + elif ( type_name == 'str' and 'choices' in meta # this means we need to support arbitrary strings @@ -491,12 +506,12 @@ def write_return_type(returns: dict[str, Any] | None) -> None: if TYPE_CHECKING: from _typeshed import StrPath + from collections.abc import Mapping, Sequence from suitable._module_types import ( ''') modules_py.write('''\ ) from suitable.types import Incomplete - from typing_extensions import Never as NotSupported # HACK: Get Sphinx to display the default values we don't diff --git a/src/suitable/_modules.py b/src/suitable/_modules.py index fbf3260..96ec293 100644 --- a/src/suitable/_modules.py +++ b/src/suitable/_modules.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from _typeshed import StrPath + from collections.abc import Mapping, Sequence from suitable._module_types import ( AddHostResults, AptResults, @@ -156,7 +157,6 @@ WinWhoamiResults, ) from suitable.types import Incomplete - from typing_extensions import Never as NotSupported # HACK: Get Sphinx to display the default values we don't @@ -174,7 +174,7 @@ def add_host( self, *, name: str, - groups: NotSupported = _Unknown, + groups: Sequence[str] = _Unknown, ) -> AddHostResults: """ Add a host (and alternatively a group) to the ansible-playbook @@ -194,7 +194,7 @@ def add_host( def apt( self, *, - name: NotSupported = _Unknown, + name: Sequence[str] = _Unknown, state: Literal[ 'absent', 'build-dep', @@ -607,7 +607,7 @@ def assemble( def assert_( self, *, - that: NotSupported, + that: Sequence[str], fail_msg: str = _Unknown, success_msg: str = _Unknown, quiet: bool = False, @@ -857,7 +857,7 @@ def command( *, expand_argument_vars: bool = True, cmd: str = _Unknown, - argv: NotSupported = _Unknown, + argv: Sequence[str] = _Unknown, creates: StrPath = _Unknown, removes: StrPath = _Unknown, chdir: StrPath = _Unknown, @@ -1209,23 +1209,23 @@ def deb822_repository( allow_downgrade_to_insecure: bool = _Unknown, allow_insecure: bool = _Unknown, allow_weak: bool = _Unknown, - architectures: NotSupported = _Unknown, + architectures: Sequence[str] = _Unknown, by_hash: bool = _Unknown, check_date: bool = _Unknown, check_valid_until: bool = _Unknown, - components: NotSupported = _Unknown, + components: Sequence[str] = _Unknown, date_max_future: int = _Unknown, enabled: bool = _Unknown, inrelease_path: str = _Unknown, - languages: NotSupported = _Unknown, + languages: Sequence[str] = _Unknown, name: str, pdiffs: bool = _Unknown, signed_by: str = _Unknown, - suites: NotSupported = _Unknown, - targets: NotSupported = _Unknown, + suites: Sequence[str] = _Unknown, + targets: Sequence[str] = _Unknown, trusted: bool = _Unknown, - types: NotSupported = _Unknown, - uris: NotSupported = _Unknown, + types: Sequence[str] = _Unknown, + uris: Sequence[str] = _Unknown, mode: str = '0644', state: Literal['absent', 'present'] = 'present', ) -> Deb822RepositoryResults: @@ -1390,7 +1390,7 @@ def dnf( 'dnf4', 'dnf5', ] = 'auto', - name: NotSupported = _Unknown, + name: Sequence[str] = _Unknown, list: str = _Unknown, state: Literal[ 'absent', @@ -1399,21 +1399,21 @@ def dnf( 'removed', 'latest', ] = _Unknown, - enablerepo: NotSupported = _Unknown, - disablerepo: NotSupported = _Unknown, + enablerepo: Sequence[str] = _Unknown, + disablerepo: Sequence[str] = _Unknown, conf_file: str = _Unknown, disable_gpg_check: bool = False, installroot: str = '/', releasever: str = _Unknown, autoremove: bool = False, - exclude: NotSupported = _Unknown, + exclude: Sequence[str] = _Unknown, skip_broken: bool = False, update_cache: bool = False, update_only: bool = False, security: bool = False, bugfix: bool = False, - enable_plugin: NotSupported = _Unknown, - disable_plugin: NotSupported = _Unknown, + enable_plugin: Sequence[str] = _Unknown, + disable_plugin: Sequence[str] = _Unknown, disable_excludes: str = _Unknown, validate_certs: bool = True, sslverify: bool = True, @@ -1579,7 +1579,7 @@ def dnf( def dnf5( self, *, - name: NotSupported = _Unknown, + name: Sequence[str] = _Unknown, list: str = _Unknown, state: Literal[ 'absent', @@ -1588,21 +1588,21 @@ def dnf5( 'removed', 'latest', ] = _Unknown, - enablerepo: NotSupported = _Unknown, - disablerepo: NotSupported = _Unknown, + enablerepo: Sequence[str] = _Unknown, + disablerepo: Sequence[str] = _Unknown, conf_file: str = _Unknown, disable_gpg_check: bool = False, installroot: str = '/', releasever: str = _Unknown, autoremove: bool = False, - exclude: NotSupported = _Unknown, + exclude: Sequence[str] = _Unknown, skip_broken: bool = False, update_cache: bool = False, update_only: bool = False, security: bool = False, bugfix: bool = False, - enable_plugin: NotSupported = _Unknown, - disable_plugin: NotSupported = _Unknown, + enable_plugin: Sequence[str] = _Unknown, + disable_plugin: Sequence[str] = _Unknown, disable_excludes: str = _Unknown, validate_certs: bool = True, sslverify: bool = True, @@ -1799,7 +1799,7 @@ def expect( creates: StrPath = _Unknown, removes: StrPath = _Unknown, chdir: StrPath = _Unknown, - responses: NotSupported, + responses: Mapping[str, Incomplete], timeout: int | str = 30, echo: bool = False, ) -> ExpectResults: @@ -2092,11 +2092,11 @@ def find( self, *, age: str = _Unknown, - patterns: NotSupported = _Unknown, - excludes: NotSupported = _Unknown, + patterns: Sequence[str] = _Unknown, + excludes: Sequence[str] = _Unknown, contains: str = _Unknown, read_whole_file: bool = False, - paths: NotSupported, + paths: Sequence[str], file_type: Literal['any', 'directory', 'file', 'link'] = 'file', recurse: bool = False, size: str = _Unknown, @@ -2239,7 +2239,7 @@ def gather_facts( def get_url( self, *, - ciphers: NotSupported = _Unknown, + ciphers: Sequence[str] = _Unknown, decompress: bool = True, url: str, dest: StrPath, @@ -2250,14 +2250,14 @@ def get_url( use_proxy: bool = True, validate_certs: bool = True, timeout: int = 10, - headers: NotSupported = _Unknown, + headers: Mapping[str, Incomplete] = _Unknown, url_username: str = _Unknown, url_password: str = _Unknown, force_basic_auth: bool = False, client_cert: StrPath = _Unknown, client_key: StrPath = _Unknown, http_agent: str = 'ansible-httpget', - unredirected_headers: NotSupported = _Unknown, + unredirected_headers: Sequence[str] = _Unknown, use_gssapi: bool = False, use_netrc: bool = True, mode: str = _Unknown, @@ -2544,7 +2544,7 @@ def git( archive: StrPath = _Unknown, archive_prefix: str = _Unknown, separate_git_dir: StrPath = _Unknown, - gpg_allowlist: NotSupported = _Unknown, + gpg_allowlist: Sequence[str] = _Unknown, ) -> GitResults: """ Deploy software (or files) from git checkouts. @@ -2731,7 +2731,7 @@ def group_by( self, *, key: str, - parents: NotSupported = _Unknown, + parents: Sequence[str] = _Unknown, ) -> GroupByResults: """ Create Ansible groups based on facts. @@ -2957,8 +2957,8 @@ def include_vars( name: str = _Unknown, depth: int = 0, files_matching: str = _Unknown, - ignore_files: NotSupported = _Unknown, - extensions: NotSupported = _Unknown, + ignore_files: Sequence[str] = _Unknown, + extensions: Sequence[str] = _Unknown, ignore_unknown_extensions: bool = False, hash_behaviour: Literal['replace', 'merge'] = _Unknown, ) -> IncludeVarsResults: ... @@ -3032,8 +3032,8 @@ def iptables( protocol: str = _Unknown, source: str = _Unknown, destination: str = _Unknown, - tcp_flags: NotSupported = _Unknown, - match: NotSupported = _Unknown, + tcp_flags: Mapping[str, Incomplete] = _Unknown, + match: Sequence[str] = _Unknown, jump: str = _Unknown, gateway: str = _Unknown, log_prefix: str = _Unknown, @@ -3062,7 +3062,7 @@ def iptables( set_counters: str = _Unknown, source_port: str = _Unknown, destination_port: str = _Unknown, - destination_ports: NotSupported = _Unknown, + destination_ports: Sequence[str] = _Unknown, to_ports: str = _Unknown, to_destination: str = _Unknown, to_source: str = _Unknown, @@ -3070,7 +3070,7 @@ def iptables( set_dscp_mark: str = _Unknown, set_dscp_mark_class: str = _Unknown, comment: str = _Unknown, - ctstate: NotSupported = _Unknown, + ctstate: Sequence[str] = _Unknown, src_range: str = _Unknown, dst_range: str = _Unknown, match_set: str = _Unknown, @@ -3647,7 +3647,7 @@ def package( def package_facts( self, *, - manager: NotSupported = _Unknown, + manager: Sequence[str] = _Unknown, strategy: Literal['first', 'all'] = 'first', ) -> PackageFactsResults: """ @@ -3724,7 +3724,7 @@ def ping( def pip( self, *, - name: NotSupported = _Unknown, + name: Sequence[str] = _Unknown, version: str = _Unknown, requirements: str = _Unknown, virtualenv: StrPath = _Unknown, @@ -3855,7 +3855,7 @@ def reboot( connect_timeout: int = _Unknown, test_command: str = 'whoami', msg: str = _Unknown, - search_paths: NotSupported = _Unknown, + search_paths: Sequence[str] = _Unknown, boot_time_command: str = _Unknown, reboot_command: str = _Unknown, ) -> RebootResults: @@ -4272,7 +4272,7 @@ def set_fact( def set_stats( self, *, - data: NotSupported, + data: Mapping[str, Incomplete], per_host: bool = False, aggregate: bool = True, ) -> SetStatsResults: @@ -4296,9 +4296,9 @@ def set_stats( def setup( self, *, - gather_subset: NotSupported = _Unknown, + gather_subset: Sequence[str] = _Unknown, gather_timeout: int = 10, - filter: NotSupported = _Unknown, + filter: Sequence[str] = _Unknown, fact_path: StrPath = '/etc/ansible/facts.d', ) -> SetupResults: """ @@ -4713,7 +4713,7 @@ def sysvinit( enabled: bool = _Unknown, sleep: int = 1, pattern: str = _Unknown, - runlevels: NotSupported = _Unknown, + runlevels: Sequence[str] = _Unknown, arguments: str = _Unknown, daemonize: bool = False, ) -> SysvinitResults: @@ -4971,10 +4971,10 @@ def unarchive( creates: StrPath = _Unknown, io_buffer_size: int = 65536, list_files: bool = False, - exclude: NotSupported = _Unknown, - include: NotSupported = _Unknown, + exclude: Sequence[str] = _Unknown, + include: Sequence[str] = _Unknown, keep_newer: bool = False, - extra_opts: NotSupported = _Unknown, + extra_opts: Sequence[str] = _Unknown, remote_src: bool = False, validate_certs: bool = True, decrypt: bool = True, @@ -5137,7 +5137,7 @@ def unarchive( def uri( self, *, - ciphers: NotSupported = _Unknown, + ciphers: Sequence[str] = _Unknown, decompress: bool = True, url: str, dest: StrPath = _Unknown, @@ -5163,9 +5163,9 @@ def uri( ] = 'safe', creates: StrPath = _Unknown, removes: StrPath = _Unknown, - status_code: NotSupported = _Unknown, + status_code: Sequence[int] = _Unknown, timeout: int = 30, - headers: NotSupported = _Unknown, + headers: Mapping[str, Incomplete] = _Unknown, validate_certs: bool = True, client_cert: StrPath = _Unknown, client_key: StrPath = _Unknown, @@ -5176,7 +5176,7 @@ def uri( use_proxy: bool = True, unix_socket: StrPath = _Unknown, http_agent: str = 'ansible-httpget', - unredirected_headers: NotSupported = _Unknown, + unredirected_headers: Sequence[str] = _Unknown, use_gssapi: bool = False, use_netrc: bool = True, mode: str = _Unknown, @@ -5431,7 +5431,7 @@ def user( non_unique: bool = False, seuser: str = _Unknown, group: str = _Unknown, - groups: NotSupported = _Unknown, + groups: Sequence[str] = _Unknown, append: bool = False, shell: str = _Unknown, home: StrPath = _Unknown, @@ -5451,7 +5451,7 @@ def user( ssh_key_comment: str = _Unknown, ssh_key_passphrase: str = _Unknown, update_password: Literal['always', 'on_create'] = 'always', - expires: Incomplete = _Unknown, + expires: float = _Unknown, password_lock: bool = _Unknown, local: bool = False, profile: str = _Unknown, @@ -5702,7 +5702,7 @@ def wait_for( connect_timeout: int = 5, delay: int = 0, port: int = _Unknown, - active_connection_states: NotSupported = _Unknown, + active_connection_states: Sequence[str] = _Unknown, state: Literal[ 'absent', 'drained', @@ -5712,7 +5712,7 @@ def wait_for( ] = 'started', path: StrPath = _Unknown, search_regex: str = _Unknown, - exclude_hosts: NotSupported = _Unknown, + exclude_hosts: Sequence[str] = _Unknown, sleep: int = 1, msg: str = _Unknown, ) -> WaitForResults: @@ -5797,7 +5797,7 @@ def yum( self, *, use_backend: Literal['auto', 'yum', 'yum4', 'dnf'] = 'auto', - name: NotSupported = _Unknown, + name: Sequence[str] = _Unknown, exclude: str = _Unknown, list: str = _Unknown, state: Literal[ @@ -5972,23 +5972,23 @@ def yum_repository( *, async_: bool = _Unknown, bandwidth: str = _Unknown, - baseurl: NotSupported = _Unknown, + baseurl: Sequence[str] = _Unknown, cost: str = _Unknown, deltarpm_metadata_percentage: str = _Unknown, deltarpm_percentage: str = _Unknown, description: str = _Unknown, enabled: bool = _Unknown, enablegroups: bool = _Unknown, - exclude: NotSupported = _Unknown, + exclude: Sequence[str] = _Unknown, failovermethod: Literal['roundrobin', 'priority'] = _Unknown, file: str = _Unknown, gpgcakey: str = _Unknown, gpgcheck: bool = _Unknown, - gpgkey: NotSupported = _Unknown, + gpgkey: Sequence[str] = _Unknown, module_hotfixes: bool = _Unknown, http_caching: Literal['all', 'packages', 'none'] = _Unknown, include: str = _Unknown, - includepkgs: NotSupported = _Unknown, + includepkgs: Sequence[str] = _Unknown, ip_resolve: Literal[ '4', '6', @@ -6377,8 +6377,8 @@ def cli_command( self, *, command: str, - prompt: NotSupported = _Unknown, - answer: NotSupported = _Unknown, + prompt: Sequence[str] = _Unknown, + answer: Sequence[str] = _Unknown, sendonly: bool = False, newline: bool = True, check_all: bool = False, @@ -6429,8 +6429,8 @@ def cli_config( multiline_delimiter: str = _Unknown, diff_replace: Literal['line', 'block', 'config'] = _Unknown, diff_match: Literal['line', 'strict', 'exact', 'none'] = _Unknown, - diff_ignore_lines: NotSupported = _Unknown, - backup_options: NotSupported = _Unknown, + diff_ignore_lines: Sequence[str] = _Unknown, + backup_options: Mapping[str, Incomplete] = _Unknown, ) -> CliConfigResults: """ Push text based configuration to network devices over network_cli. @@ -6554,7 +6554,7 @@ def grpc_config( config: str = _Unknown, state: str = _Unknown, backup: bool = False, - backup_options: NotSupported = _Unknown, + backup_options: Mapping[str, Incomplete] = _Unknown, ) -> GrpcConfigResults: """ Fetch configuration/state data from gRPC enabled target hosts. @@ -6732,7 +6732,7 @@ def netconf_config( delete: bool = False, commit: bool = True, validate: bool = False, - backup_options: NotSupported = _Unknown, + backup_options: Mapping[str, Incomplete] = _Unknown, get_filter: str = _Unknown, ) -> NetconfConfigResults: """ @@ -7054,13 +7054,13 @@ def restconf_get( def telnet( self, *, - command: NotSupported, + command: Sequence[str], host: str = 'remote_addr', user: str = 'remote_user', password: str = _Unknown, port: int = 23, timeout: int = 120, - prompts: NotSupported = _Unknown, + prompts: Sequence[str] = _Unknown, login_prompt: str = _Unknown, password_prompt: str = _Unknown, pause: int = 1, @@ -7285,7 +7285,7 @@ def firewalld( service: str = _Unknown, protocol: str = _Unknown, port: str = _Unknown, - port_forward: NotSupported = _Unknown, + port_forward: Sequence[Mapping[str, Incomplete]] = _Unknown, rich_rule: str = _Unknown, source: str = _Unknown, interface: str = _Unknown, @@ -7384,7 +7384,7 @@ def firewalld_info( self, *, active_zones: bool = False, - zones: NotSupported = _Unknown, + zones: Sequence[str] = _Unknown, ) -> FirewalldInfoResults: """ Gather information about firewalld. @@ -7582,7 +7582,7 @@ def rhel_facts(self, arg: str, /) -> RhelFactsResults: def rhel_rpm_ostree( self, *, - name: NotSupported = _Unknown, + name: Sequence[str] = _Unknown, state: Literal[ 'absent', 'installed', @@ -7735,11 +7735,11 @@ def synchronize( set_remote_user: bool = True, use_ssh_args: bool = False, ssh_connection_multiplexing: bool = False, - rsync_opts: NotSupported = _Unknown, + rsync_opts: Sequence[str] = _Unknown, partial: bool = False, verify_host: bool = False, private_key: StrPath = _Unknown, - link_dest: NotSupported = _Unknown, + link_dest: Sequence[str] = _Unknown, delay_updates: bool = True, ) -> SynchronizeResults: """ @@ -7908,7 +7908,7 @@ def cli_parse( *, command: str = _Unknown, text: str = _Unknown, - parser: NotSupported, + parser: Mapping[str, Incomplete], set_fact: str = _Unknown, ) -> CliParseResults: """ @@ -7933,7 +7933,7 @@ def fact_diff( *, before: str, after: str, - plugin: NotSupported = _Unknown, + plugin: Mapping[str, Incomplete] = _Unknown, ) -> FactDiffResults: """ Find the difference between currently set facts. @@ -7953,7 +7953,7 @@ def fact_diff( def update_fact( self, *, - updates: NotSupported, + updates: Sequence[Mapping[str, Incomplete]], ) -> UpdateFactResults: """ Update currently set facts. @@ -8198,7 +8198,7 @@ def win_command( *, _raw_params: str = _Unknown, cmd: str = _Unknown, - argv: NotSupported = _Unknown, + argv: Sequence[str] = _Unknown, creates: StrPath = _Unknown, removes: StrPath = _Unknown, chdir: StrPath = _Unknown, @@ -8324,8 +8324,8 @@ def win_copy( def win_dns_client( self, *, - adapter_names: NotSupported, - dns_servers: NotSupported, + adapter_names: Sequence[str], + dns_servers: Sequence[str], ) -> WinDnsClientResults: """ Configures DNS lookup on Windows hosts. @@ -8591,7 +8591,7 @@ def win_environment( state: Literal['absent', 'present'] = _Unknown, name: str = _Unknown, value: str = _Unknown, - variables: NotSupported = _Unknown, + variables: Mapping[str, Incomplete] = _Unknown, level: Literal['machine', 'process', 'user'], ) -> WinEnvironmentResults: """ @@ -8630,7 +8630,7 @@ def win_environment( def win_feature( self, *, - name: NotSupported, + name: Sequence[str], state: Literal['absent', 'present'] = 'present', include_sub_features: bool = False, include_management_tools: bool = False, @@ -8714,8 +8714,8 @@ def win_find( follow: bool = False, get_checksum: bool = True, hidden: bool = False, - paths: NotSupported, - patterns: NotSupported = _Unknown, + paths: Sequence[str], + patterns: Sequence[str] = _Unknown, recurse: bool = False, size: str = _Unknown, use_regex: bool = False, @@ -8803,7 +8803,7 @@ def win_get_url( url_method: str = _Unknown, url_timeout: int = 30, follow_redirects: Literal['all', 'none', 'safe'] = 'safe', - headers: NotSupported = _Unknown, + headers: Mapping[str, Incomplete] = _Unknown, http_agent: str = 'ansible-httpget', maximum_redirection: int = 50, validate_certs: bool = True, @@ -8974,7 +8974,7 @@ def win_group_membership( self, *, name: str, - members: NotSupported, + members: Sequence[str], state: Literal['absent', 'present', 'pure'] = 'present', ) -> WinGroupMembershipResults: """ @@ -9019,7 +9019,7 @@ def win_hostname( def win_optional_feature( self, *, - name: NotSupported, + name: Sequence[str], state: Literal['absent', 'present'] = 'present', include_parent: bool = False, source: str = _Unknown, @@ -9076,7 +9076,7 @@ def win_package( creates_path: StrPath = _Unknown, creates_service: str = _Unknown, creates_version: str = _Unknown, - expected_return_code: NotSupported = _Unknown, + expected_return_code: Sequence[int] = _Unknown, log_path: StrPath = _Unknown, path: str = _Unknown, product_id: str = _Unknown, @@ -9091,7 +9091,7 @@ def win_package( wait_for_children: bool = False, url_method: str = _Unknown, follow_redirects: Literal['all', 'none', 'safe'] = 'safe', - headers: NotSupported = _Unknown, + headers: Mapping[str, Incomplete] = _Unknown, http_agent: str = 'ansible-httpget', maximum_redirection: int = 50, url_timeout: int = 30, @@ -9349,7 +9349,7 @@ def win_path( self, *, name: str = 'PATH', - elements: NotSupported, + elements: Sequence[str], state: Literal['absent', 'present'] = 'present', scope: Literal['machine', 'user'] = 'machine', ) -> WinPathResults: @@ -9408,7 +9408,7 @@ def win_ping( def win_powershell( self, *, - arguments: NotSupported = _Unknown, + arguments: Sequence[str] = _Unknown, chdir: str = _Unknown, creates: str = _Unknown, depth: int = 2, @@ -9418,10 +9418,10 @@ def win_powershell( 'stop', ] = 'continue', executable: str = _Unknown, - parameters: NotSupported = _Unknown, + parameters: Mapping[str, Incomplete] = _Unknown, removes: str = _Unknown, script: str, - sensitive_parameters: NotSupported = _Unknown, + sensitive_parameters: Sequence[Mapping[str, Incomplete]] = _Unknown, ) -> WinPowershellResults: """ Run PowerShell scripts. @@ -9489,10 +9489,10 @@ def win_powershell( def win_reboot( self, *, - pre_reboot_delay: Incomplete = _Unknown, - post_reboot_delay: Incomplete = _Unknown, - reboot_timeout: Incomplete = _Unknown, - connect_timeout: Incomplete = _Unknown, + pre_reboot_delay: int | float = 2, + post_reboot_delay: int | float = 0, + reboot_timeout: int | float = 600, + connect_timeout: int | float = 5, test_command: str = _Unknown, msg: str = _Unknown, boot_time_command: str = _Unknown, @@ -9631,7 +9631,7 @@ def win_regedit( def win_service( self, *, - dependencies: NotSupported = _Unknown, + dependencies: Sequence[str] = _Unknown, dependency_action: Literal['add', 'remove', 'set'] = 'set', desktop_interact: bool = False, description: str = _Unknown, @@ -9642,7 +9642,7 @@ def win_service( 'normal', 'severe', ] = _Unknown, - failure_actions: NotSupported = _Unknown, + failure_actions: Sequence[Mapping[str, Incomplete]] = _Unknown, failure_actions_on_non_crash_failure: bool = _Unknown, failure_command: str = _Unknown, failure_reboot_msg: str = _Unknown, @@ -9653,7 +9653,7 @@ def win_service( path: str = _Unknown, password: str = _Unknown, pre_shutdown_timeout_ms: str = _Unknown, - required_privileges: NotSupported = _Unknown, + required_privileges: Sequence[str] = _Unknown, service_type: Literal[ 'user_own_process', 'user_share_process', @@ -10120,8 +10120,8 @@ def win_template( def win_updates( self, *, - accept_list: NotSupported = _Unknown, - category_names: NotSupported = _Unknown, + accept_list: Sequence[str] = _Unknown, + category_names: Sequence[str] = _Unknown, skip_optional: bool = False, reboot: bool = False, reboot_timeout: int = 1200, @@ -10136,9 +10136,9 @@ def win_updates( 'downloaded', ] = 'installed', log_path: StrPath = _Unknown, - reject_list: NotSupported = _Unknown, + reject_list: Sequence[str] = _Unknown, _operation: Literal['start', 'cancel', 'poll'] = 'start', - _operation_options: NotSupported = _Unknown, + _operation_options: Mapping[str, Incomplete] = _Unknown, ) -> WinUpdatesResults: """ Download and install Windows updates. @@ -10226,11 +10226,11 @@ def win_uri( creates: StrPath = _Unknown, removes: StrPath = _Unknown, return_content: bool = False, - status_code: NotSupported = _Unknown, + status_code: Sequence[int] = _Unknown, url_method: str = 'GET', url_timeout: int = 30, follow_redirects: Literal['all', 'none', 'safe'] = 'safe', - headers: NotSupported = _Unknown, + headers: Mapping[str, Incomplete] = _Unknown, http_agent: str = 'ansible-httpget', maximum_redirection: int = 50, validate_certs: bool = True, @@ -10377,7 +10377,7 @@ def win_user( account_locked: bool = _Unknown, description: str = _Unknown, fullname: str = _Unknown, - groups: NotSupported = _Unknown, + groups: Sequence[str] = _Unknown, groups_action: Literal['add', 'replace', 'remove'] = 'replace', home_directory: str = _Unknown, login_script: str = _Unknown, @@ -10461,7 +10461,7 @@ def win_user_right( self, *, name: str, - users: NotSupported, + users: Sequence[str], action: Literal['add', 'remove', 'set'] = 'set', ) -> WinUserRightResults: """ @@ -10500,7 +10500,7 @@ def win_wait_for( *, connect_timeout: int = 5, delay: int = _Unknown, - exclude_hosts: NotSupported = _Unknown, + exclude_hosts: Sequence[str] = _Unknown, host: str = '127.0.0.1', path: StrPath = _Unknown, port: int = _Unknown, diff --git a/src/suitable/module_runner.py b/src/suitable/module_runner.py index 11b6323..8d204b1 100644 --- a/src/suitable/module_runner.py +++ b/src/suitable/module_runner.py @@ -150,6 +150,9 @@ def f(*args: Any, **kwargs: Any) -> RunnerResults: if self.module_name == 'assert': api.__dict__['assert_'] = f + # TODO: Check whether modules with a free-form argument allow combining + # free-form with named parameters. + # Check whether more than one free-form parameter is possible. def get_module_args( self, args: tuple[Any, ...], diff --git a/tests/test_api.py b/tests/test_api.py index c2641fc..69e1b91 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,10 +1,8 @@ -import gc import os import os.path from secrets import token_hex import pytest -from ansible.utils.display import Display from suitable.api import Api, list_ansible_modules from suitable.errors import ModuleError, UnreachableError @@ -42,6 +40,23 @@ def test_module_args(): assert e.result['invocation']['module_args']['_raw_params'] == upgrade +def test_module_args_non_scalar(): + upgrade = ( + 'apt-get', + 'upgrade', + '-y', + '-o', + 'Dpkg::Options::="--force-confdef"', + '-o', + 'Dpkg::Options::="--force-confold"' + ) + + try: + Api('localhost').command(argv=upgrade) + except ModuleError as e: + assert e.result['invocation']['module_args']['argv'] == list(upgrade) + + def test_results(): result = Api('localhost').command('whoami') assert result.rc('localhost') == 0 @@ -87,6 +102,15 @@ def test_whoami_multiple_servers(server): assert results.rc(server[1]) == 0 +def test_non_scalar_parameter(): + host = Api('localhost') + result = host.command(argv=['echo', 'hello world']) + + assert result.rc() == 0 + assert result.cmd() == ['echo', 'hello world'] + assert result.stdout() == 'hello world' + + def test_valid_return_codes(): host = Api('localhost') assert host._valid_return_codes == (0,) @@ -242,10 +266,6 @@ def test_same_server_multiple_ports(): assert len(result['contacted']) == 2 -def test_single_display_module(): - assert sum(1 for obj in gc.get_objects() if isinstance(obj, Display)) == 1 - - @pytest.mark.skipif(not is_mitogen_supported(), reason="incompatible mitogen") def test_mitogen_integration(): try: