Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: 💄 typing improvements #186

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
import logging
import re
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from pathlib import Path

from pybind11_stubgen.parser.interface import IParser
Expand Down Expand Up @@ -57,6 +57,24 @@
from pybind11_stubgen.writer import Writer


class CLIArgs(Namespace):
output_dir: str
root_suffix: str
ignore_invalid_expressions: re.Pattern | None
ignore_invalid_identifiers: re.Pattern | None
ignore_unresolved_names: re.Pattern | None
ignore_all_errors: bool
enum_class_locations: list[tuple[re.Pattern, str]]
numpy_array_wrap_with_annotated: bool
numpy_array_remove_parameters: bool
print_invalid_expressions_as_is: bool
print_safe_value_reprs: re.Pattern | None
exit_code: bool
dry_run: bool
stub_extension: str
module_name: str


def arg_parser() -> ArgumentParser:
def regex(pattern_str: str) -> re.Pattern:
try:
Expand Down Expand Up @@ -196,7 +214,7 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
return parser


def stub_parser_from_args(args) -> IParser:
def stub_parser_from_args(args: CLIArgs) -> IParser:
error_handlers_top: list[type] = [
LoggerData,
*([IgnoreAllErrors] if args.ignore_all_errors else []),
Expand Down Expand Up @@ -273,7 +291,7 @@ def main():
level=logging.INFO,
format="%(name)s - [%(levelname)7s] %(message)s",
)
args = arg_parser().parse_args()
args = arg_parser().parse_args(namespace=CLIArgs())

parser = stub_parser_from_args(args)
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
Expand All @@ -296,7 +314,7 @@ def main():


def to_output_and_subdir(
output_dir: Path, module_name: str, root_suffix: str | None
output_dir: str, module_name: str, root_suffix: str | None
) -> tuple[Path, Path | None]:
out_dir = Path(output_dir)

Expand Down
100 changes: 51 additions & 49 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import sys
import types
from logging import getLogger
from typing import Any
from typing import Any, Sequence, TypeVar

from typing_extensions import TypeGuard
sizmailov marked this conversation as resolved.
Show resolved Hide resolved

from pybind11_stubgen.parser.errors import (
InvalidExpressionError,
Expand Down Expand Up @@ -37,6 +39,12 @@

logger = getLogger("pybind11_stubgen")

T = TypeVar("T")


def all_isinstance(seq: Sequence[Any], type_: type[T]) -> TypeGuard[Sequence[T]]:
return all(isinstance(item, type_) for item in seq)


class RemoveSelfAnnotation(IParser):
def handle_method(self, path: QualifiedName, method: Any) -> list[Method]:
Expand Down Expand Up @@ -112,7 +120,7 @@ def handle_module(

def handle_type(self, type_: type) -> QualifiedName:
result = super().handle_type(type_)
if not inspect.ismodule(type):
if not inspect.ismodule(type_):
self._add_import(result)
return result

Expand Down Expand Up @@ -292,16 +300,14 @@ def handle_field(self, path: QualifiedName, field: Any) -> Field | None:

class FixPEP585CollectionNames(IParser):
__typing_collection_names: set[Identifier] = set(
map(
Identifier,
[
"Dict",
"List",
"Set",
"Tuple",
"FrozenSet",
"Type",
],
Identifier(name)
for name in (
"Dict",
"List",
"Set",
"Tuple",
"FrozenSet",
"Type",
)
)

Expand All @@ -325,47 +331,42 @@ def parse_annotation_str(

class FixTypingTypeNames(IParser):
__typing_names: set[Identifier] = set(
map(
Identifier,
[
"Annotated",
"Any",
"Buffer",
"Callable",
"Dict",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"List",
"Optional",
"Sequence",
"Set",
"Tuple",
"Union",
"ValuesView",
# Old pybind11 annotations were not capitalized
"buffer",
"iterable",
"iterator",
"sequence",
],
Identifier(name)
for name in (
"Annotated",
"Any",
"Buffer",
"Callable",
"Dict",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"List",
"Optional",
"Sequence",
"Set",
"Tuple",
"Union",
"ValuesView",
# Old pybind11 annotations were not capitalized
"buffer",
"iterable",
"iterator",
"sequence",
)
)
__typing_extensions_names: set[Identifier] = set(
map(
Identifier,
[
"buffer",
"Buffer",
],
Identifier(name)
for name in (
"buffer",
"Buffer",
)
)

def __init__(self):
super().__init__()
py_version = sys.version_info[:2]
if py_version < (3, 9):
if sys.version_info < (3, 9):
self.__typing_extensions_names.add(Identifier("Annotated"))

def parse_annotation_str(
Expand Down Expand Up @@ -548,18 +549,20 @@ def parse_annotation_str(
return result

def __wrap_with_size_helper(self, dims: list[int | str]) -> FixedSize | DynamicSize:
if all(isinstance(d, int) for d in dims):
if all_isinstance(dims, int):
return_t = FixedSize
result = return_t(*dims)
else:
return_t = DynamicSize
result = return_t(*dims)

# TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
# properly added to the list of imports
self.handle_type(return_t)
return return_t(*dims)
return result

def __to_dims(
self, dimensions: list[ResolvedType | Value | InvalidExpression]
self, dimensions: Sequence[ResolvedType | Value | InvalidExpression]
) -> list[int | str] | None:
result = []
for dim_param in dimensions:
Expand All @@ -578,7 +581,6 @@ def __to_dims(
return result

def report_error(self, error: ParserError) -> None:

if (
isinstance(error, NameResolutionError)
and len(error.name) == 1
Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def print_function(self, func: Function) -> list[str]:
kw_only = True
if not pos_only and not arg.pos_only:
pos_only = True
if sys.version_info[:2] >= (3, 8):
if sys.version_info >= (3, 8):
args.append("/")
if not kw_only and arg.kw_only:
kw_only = True
Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import field as field_
from typing import Tuple, Union

if sys.version_info[:2] >= (3, 8):
if sys.version_info >= (3, 8):
from typing import Literal

Modifier = Literal["static", "class", None]
Expand Down
Loading