From 3002d232de3c9c4519f62db7c8a097106e656027 Mon Sep 17 00:00:00 2001 From: Jeff Epler Date: Tue, 2 Jan 2024 10:24:50 -0600 Subject: [PATCH] Add advanced source transformations to reduce type checking overhead The new 'munge' module performs transformations on the source code. It uses the AST (abstract syntax tree) representation of Python code to recognize some idioms such as `if TYPE_CHECKING:` and transform them into alternatives that have zero overhead in mpy-compiled files (e.g., `if TYPE_CHECKING:` is transformed into `if 0:`, which is eliminated at compile time due to mpy-cross constant-propagation and dead branch elimination). The code assumes the input file is black-formatted. In particular, it would malfunction if an if-statement and its body are on the same line: `if TYPE_CHECKING: print("boo")` would be incorrectly munged. --- .github/workflows/build.yml | 2 + .gitignore | 1 + circuitpython_build_tools/build.py | 47 +++++------- circuitpython_build_tools/munge.py | 117 +++++++++++++++++++++++++++++ requirements.txt | 3 +- testcases/test1.exp | 33 ++++++++ testcases/test1.py | 33 ++++++++ tests/test_munge.py | 22 ++++++ 8 files changed, 228 insertions(+), 30 deletions(-) create mode 100644 circuitpython_build_tools/munge.py create mode 100644 testcases/test1.exp create mode 100644 testcases/test1.py create mode 100644 tests/test_munge.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a301583..f7d5c63 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,6 +48,8 @@ jobs: git clone --recurse-submodules https://github.com/adafruit/CircuitPython_Community_Bundle.git cd CircuitPython_Community_Bundle circuitpython-build-bundles --filename_prefix test-bundle --library_location libraries --library_depth 2 + - name: Munge tests + run: pytest - name: Build Python package run: | pip install --upgrade setuptools wheel twine readme_renderer testresources diff --git a/.gitignore b/.gitignore index 
1c50e6e..08aafad 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ version.py .env/* .DS_Store .idea/* +testcases/*.out diff --git a/circuitpython_build_tools/build.py b/circuitpython_build_tools/build.py index dc35ade..44aa72e 100644 --- a/circuitpython_build_tools/build.py +++ b/circuitpython_build_tools/build.py @@ -36,6 +36,8 @@ import subprocess import tempfile +from .munge import munge + # pyproject.toml `py_modules` values that are incorrect. These should all have PRs filed! # and should be removed when the fixed version is incorporated in its respective bundle. @@ -170,16 +172,6 @@ def mpy_cross(mpy_cross_filename, circuitpython_tag, quiet=False): shutil.copy("build_deps/circuitpython/mpy-cross/mpy-cross", mpy_cross_filename) -def _munge_to_temp(original_path, temp_file, library_version): - with open(original_path, "r", encoding="utf-8") as original_file: - for line in original_file: - line = line.strip("\n") - if line.startswith("__version__"): - line = line.replace("0.0.0-auto.0", library_version) - line = line.replace("0.0.0+auto.0", library_version) - print(line, file=temp_file) - temp_file.flush() - def get_package_info(library_path, package_folder_prefix): lib_path = pathlib.Path(library_path) parent_idx = len(lib_path.parts) @@ -289,25 +281,22 @@ def library(library_path, output_directory, package_folder_prefix, full_path = os.path.join(library_path, filename) output_file = output_directory / filename.relative_to(library_path) if filename.suffix == ".py": - with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file: - temp_file_name = temp_file.name - try: - _munge_to_temp(full_path, temp_file, library_version) - temp_file.close() - if mpy_cross and os.stat(temp_file.name).st_size != 0: - output_file = output_file.with_suffix(".mpy") - mpy_success = subprocess.call([ - mpy_cross, - "-o", output_file, - "-s", str(filename.relative_to(library_path)), - temp_file.name - ]) - if mpy_success != 0: - raise RuntimeError("mpy-cross 
failed on", full_path) - else: - shutil.copyfile(temp_file_name, output_file) - finally: - os.remove(temp_file_name) + content = munge(full_path, library_version) + if mpy_cross and content: + # TODO: Once 8.x bundles are no longer built, switch to + # sending mpy-cross the code on stdin instead of via + # temporary file (supports the "-" input argument) + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file: + temp_file.write(content) + temp_file.flush() + subprocess.check_output([ + mpy_cross, + "-o", output_file.with_suffix(".mpy"), + "-s", str(filename.relative_to(library_path)), + temp_file.name + ], input=content.encode('utf-8')) + else: + output_file.write_text(content, encoding="utf-8") else: shutil.copyfile(full_path, output_file) diff --git a/circuitpython_build_tools/munge.py b/circuitpython_build_tools/munge.py new file mode 100644 index 0000000..4026efc --- /dev/null +++ b/circuitpython_build_tools/munge.py @@ -0,0 +1,117 @@ +# The MIT License (MIT) +# +# Copyright (c) 2024 Jeff Epler for Adafruit Industries +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +# Filter program removes some code patterns introduced by type checking, +# to move towards zero overhead static typing in circuitpython libraries +# +# Recognized: +# from __future__ import ... -- eliminated +# try: import typing -- eliminated, but first except: preserved +# try: from typing import ... -- eliminated, but first except: preserved +# if TYPE_CHECKING: -- transformed to 'if 0:' +# if sys.implementation.name... -- transformed to 'if 1:' or 'if 0:' +# __version__ = ... -- set to library version string +# +# mpy-cross does constant propagation and dead branch elimination of +# 'if 0:' and 'if 1:' +# +# Depends on the file being black-formatted! + +import pathlib +import sys +import ast + +VERBOSE = 0 + +# The canonical spelling of this test... 
+sys_implementation_is_circuitpython = ast.unparse(ast.parse('sys.implementation.name == "circuitpython"')) +sys_implementation_not_circuitpython = ast.unparse(ast.parse('not sys.implementation.name == "circuitpython"')) +sys_implementation_not_circuitpython2 = ast.unparse(ast.parse('sys.implementation.name != "circuitpython"')) + +def munge(src: pathlib.Path|str, version_str: str) -> str: + path = pathlib.Path(src) + replacements = {} + + def replace(line, new): + if VERBOSE: + replacements[line] = f"{new:<40s} ### {lines[line]}" + else: + replacements[line] = new + + def blank_range(node): + for i in range(node.lineno, node.end_lineno+1): + replace(i, "") + + def unblank_range(node): + for i in range(node.lineno, node.end_lineno+1): + replacements.pop(i, None) + + def imports_from_typing(node): + if isinstance(node, ast.Import) and node.names[0].name == 'typing': + return True + if isinstance(node, ast.ImportFrom) and node.module == 'typing': + return True + return False + + def process_statement(node): + # filter out 'from future import...' + if isinstance(node, ast.ImportFrom): + if node.module == '__future__': + blank_range(node) + # filter out 'try: import typing...' + # but preserve the first 'except:' or 'except ImportError' + elif isinstance(node, ast.Try): + b = node.body[0] + if imports_from_typing(node.body[0]): + blank_range(node) + for h in node.handlers: + if h.type is None or ast.unparse(h.type) == 'ImportError' or ast.unparse(h.type) == 'Exception': + unblank_range(h) + replace(h.lineno, 'if 1:') + break + return + elif isinstance(node, ast.If): + node_test = ast.unparse(node.test) + # return the statements in the 'if' branch of 'if sys.implementation...: ...' + if node_test == sys_implementation_is_circuitpython: + replace(node.lineno, 'if 1:') + # return the statements in the 'else' branch of 'if sys.implementation...: ...' 
+ elif node_test == sys_implementation_not_circuitpython or node_test == sys_implementation_not_circuitpython2: + replace(node.lineno, 'if 0:') + # return the statements in the else branch of 'if TYPE_CHECKING: ...' + elif node_test == 'TYPE_CHECKING': + replace(node.lineno, 'if 0:') + elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name) and node.targets[0].id == '__version__': + replace(node.lineno, f"__version__ = \"{version_str}\"") + + content = pathlib.Path(path).read_text(encoding="utf-8") + # Insert a blank line 0 because ast line numbers are 1-based + lines = [''] + content.rstrip().split('\n') + a = ast.parse(content, path.name) + + for node in a.body: process_statement(node) + + result = [] + for i in range(1, len(lines)): + result.append(replacements.get(i, lines[i])) + + return "\n".join(result) + "\n" diff --git a/requirements.txt b/requirements.txt index 8a3514c..b9c2e4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ Click +pytest requests semver -wheel tomli; python_version < "3.11" +wheel diff --git a/testcases/test1.exp b/testcases/test1.exp new file mode 100644 index 0000000..02d7b82 --- /dev/null +++ b/testcases/test1.exp @@ -0,0 +1,33 @@ + + + + +if 1: + pass + + + +if 1: + pass + + + + +if 1: + pass + + + +if 1: + pass + +__version__ = "1.2.3" + +if 1: + print("is circuitpython") + +if 0: + print("not circuitpython (1)") + +if 0: + print("not circuitpython (2)") diff --git a/testcases/test1.py b/testcases/test1.py new file mode 100644 index 0000000..60f4e0f --- /dev/null +++ b/testcases/test1.py @@ -0,0 +1,33 @@ +from __future__ import annotation + +try: + from typing import TYPE_CHECKING +except ImportError: + pass + +try: + from typing import TYPE_CHECKING as T +except ImportError: + pass + + +try: + import typing +except: + pass + +try: + import typing as T +except: + pass + +__version__ = "0.0.0-auto" + +if sys.implementation.name == "circuitpython": + print("is circuitpython") + +if 
sys.implementation.name != "circuitpython": + print("not circuitpython (1)") + +if not sys.implementation.name == "circuitpython": + print("not circuitpython (2)") diff --git a/tests/test_munge.py b/tests/test_munge.py new file mode 100644 index 0000000..48e95f2 --- /dev/null +++ b/tests/test_munge.py @@ -0,0 +1,22 @@ +import sys, pathlib +import pytest + +top = pathlib.Path(__file__).parent.parent +sys.path.insert(0, str(top)) + +from circuitpython_build_tools.munge import munge + +@pytest.mark.parametrize("test_path", top.glob("testcases/*.py")) +def test_munge(test_path): + result_path = test_path.with_suffix(".out") + result_path.unlink(missing_ok = True) + + result_content = munge(test_path, "1.2.3") + result_path.write_text(result_content, encoding="utf-8") + + expected_path = test_path.with_suffix(".exp") + expected_content = expected_path.read_text(encoding="utf-8") + + assert result_content == expected_content + + result_path.unlink()