Skip to content

Commit

Permalink
style: typing improvements
Browse files Browse the repository at this point in the history
CLIArgs argparse.Namespace dataclass for improved data modeling in pybind11_stubgen/__init__.py; not necessary to index sys.version_info; custom all_isinstance TypeGuard in fix.py to resolve a type error in arguments to FixedSize vs DynamicSize; use Sequence to accept broader types; typo in FixMissingImport.handle_type?; use comprehension instead of map
  • Loading branch information
ringohoffman committed Nov 20, 2023
1 parent 223fa12 commit 42858ba
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 55 deletions.
26 changes: 22 additions & 4 deletions pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
import logging
import re
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from pathlib import Path

from pybind11_stubgen.parser.interface import IParser
Expand Down Expand Up @@ -57,6 +57,24 @@
from pybind11_stubgen.writer import Writer


class CLIArgs(Namespace):
output_dir: str
root_suffix: str
ignore_invalid_expressions: re.Pattern | None
ignore_invalid_identifiers: re.Pattern | None
ignore_unresolved_names: re.Pattern | None
ignore_all_errors: bool
enum_class_locations: list[tuple[re.Pattern, str]]
numpy_array_wrap_with_annotated: bool
numpy_array_remove_parameters: bool
print_invalid_expressions_as_is: bool
print_safe_value_reprs: re.Pattern | None
exit_code: bool
dry_run: bool
stub_extension: str
module_name: str


def arg_parser() -> ArgumentParser:
def regex(pattern_str: str) -> re.Pattern:
try:
Expand Down Expand Up @@ -196,7 +214,7 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
return parser


def stub_parser_from_args(args) -> IParser:
def stub_parser_from_args(args: CLIArgs) -> IParser:
error_handlers_top: list[type] = [
LoggerData,
*([IgnoreAllErrors] if args.ignore_all_errors else []),
Expand Down Expand Up @@ -273,7 +291,7 @@ def main():
level=logging.INFO,
format="%(name)s - [%(levelname)7s] %(message)s",
)
args = arg_parser().parse_args()
args = arg_parser().parse_args(namespace=CLIArgs())

parser = stub_parser_from_args(args)
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
Expand All @@ -296,7 +314,7 @@ def main():


def to_output_and_subdir(
output_dir: Path, module_name: str, root_suffix: str | None
output_dir: str, module_name: str, root_suffix: str | None
) -> tuple[Path, Path | None]:
out_dir = Path(output_dir)

Expand Down
100 changes: 51 additions & 49 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import sys
import types
from logging import getLogger
from typing import Any
from typing import Any, Sequence, TypeVar

from typing_extensions import TypeGuard

from pybind11_stubgen.parser.errors import (
InvalidExpressionError,
Expand Down Expand Up @@ -37,6 +39,12 @@

logger = getLogger("pybind11_stubgen")

T = TypeVar("T")


def all_isinstance(seq: Sequence[Any], type_: type[T]) -> TypeGuard[Sequence[T]]:
return all(isinstance(item, type_) for item in seq)


class RemoveSelfAnnotation(IParser):
def handle_method(self, path: QualifiedName, method: Any) -> list[Method]:
Expand Down Expand Up @@ -112,7 +120,7 @@ def handle_module(

def handle_type(self, type_: type) -> QualifiedName:
result = super().handle_type(type_)
if not inspect.ismodule(type):
if not inspect.ismodule(type_):
self._add_import(result)
return result

Expand Down Expand Up @@ -292,16 +300,14 @@ def handle_field(self, path: QualifiedName, field: Any) -> Field | None:

class FixPEP585CollectionNames(IParser):
__typing_collection_names: set[Identifier] = set(
map(
Identifier,
[
"Dict",
"List",
"Set",
"Tuple",
"FrozenSet",
"Type",
],
Identifier(name)
for name in (
"Dict",
"List",
"Set",
"Tuple",
"FrozenSet",
"Type",
)
)

Expand All @@ -325,47 +331,42 @@ def parse_annotation_str(

class FixTypingTypeNames(IParser):
__typing_names: set[Identifier] = set(
map(
Identifier,
[
"Annotated",
"Any",
"Buffer",
"Callable",
"Dict",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"List",
"Optional",
"Sequence",
"Set",
"Tuple",
"Union",
"ValuesView",
# Old pybind11 annotations were not capitalized
"buffer",
"iterable",
"iterator",
"sequence",
],
Identifier(name)
for name in (
"Annotated",
"Any",
"Buffer",
"Callable",
"Dict",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"List",
"Optional",
"Sequence",
"Set",
"Tuple",
"Union",
"ValuesView",
# Old pybind11 annotations were not capitalized
"buffer",
"iterable",
"iterator",
"sequence",
)
)
__typing_extensions_names: set[Identifier] = set(
map(
Identifier,
[
"buffer",
"Buffer",
],
Identifier(name)
for name in (
"buffer",
"Buffer",
)
)

def __init__(self):
super().__init__()
py_version = sys.version_info[:2]
if py_version < (3, 9):
if sys.version_info < (3, 9):
self.__typing_extensions_names.add(Identifier("Annotated"))

def parse_annotation_str(
Expand Down Expand Up @@ -548,18 +549,20 @@ def parse_annotation_str(
return result

def __wrap_with_size_helper(self, dims: list[int | str]) -> FixedSize | DynamicSize:
if all(isinstance(d, int) for d in dims):
if all_isinstance(dims, int):
return_t = FixedSize
result = return_t(*dims)
else:
return_t = DynamicSize
result = return_t(*dims)

# TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
# properly added to the list of imports
self.handle_type(return_t)
return return_t(*dims)
return result

def __to_dims(
self, dimensions: list[ResolvedType | Value | InvalidExpression]
self, dimensions: Sequence[ResolvedType | Value | InvalidExpression]
) -> list[int | str] | None:
result = []
for dim_param in dimensions:
Expand All @@ -578,7 +581,6 @@ def __to_dims(
return result

def report_error(self, error: ParserError) -> None:

if (
isinstance(error, NameResolutionError)
and len(error.name) == 1
Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def print_function(self, func: Function) -> list[str]:
kw_only = True
if not pos_only and not arg.pos_only:
pos_only = True
if sys.version_info[:2] >= (3, 8):
if sys.version_info >= (3, 8):
args.append("/")
if not kw_only and arg.kw_only:
kw_only = True
Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import field as field_
from typing import Tuple, Union

if sys.version_info[:2] >= (3, 8):
if sys.version_info >= (3, 8):
from typing import Literal

Modifier = Literal["static", "class", None]
Expand Down

0 comments on commit 42858ba

Please sign in to comment.