-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add advanced source transformations to reduce type checking overhead
The new 'munge' module performs transformations on the source code. It uses the AST (abstract syntax tree) representation of Python code to recognize some idioms such as `if STATIC_TYPING:` and transforms them into alternatives that have zero overhead in mpy-compiled files (e.g., `if STATIC_TYPING:` is transformed into `if 0:`, which is eliminated at compile time due to mpy-cross constant-propagation and dead branch elimination) The code assumes the input file is black-formatted. In particular, it would malfunction if an if-statement and its body are on the same line: `if STATIC_TYPING: print("boo")` would be incorrectly munged.
- Loading branch information
Showing
8 changed files
with
228 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ version.py | |
.env/* | ||
.DS_Store | ||
.idea/* | ||
testcases/*.out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# The MIT License (MIT) | ||
# | ||
# Copyright (c) 2024 Jeff Epler for Adafruit Industries | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
# THE SOFTWARE. | ||
|
||
# Filter program removes some code patterns introduced by type checking, | ||
# to move towards zero overhead static typing in circuitpython libraries | ||
# | ||
# Recognized: | ||
# from __future__ import ... -- eliminated | ||
# try: import typing -- eliminated, but first except: preserved | ||
# try: from typing import ... -- eliminated, but first except: preserved | ||
# if STATIC_TYPING: -- transformed to 'if 0:' | ||
# if sys.implementation_name... -- transformed to unconditional if | ||
# __version__ = ... -- set to library version string | ||
# | ||
# mpy-cross does constant propagation and dead branch elimination of | ||
# 'if 0:' and 'if 1:' | ||
# | ||
# Depends on the file being black-formatted! | ||
|
||
import pathlib | ||
import sys | ||
import ast | ||
|
||
VERBOSE = 0 | ||
|
||
# The canonical spelling of this test... | ||
sys_implementation_is_circuitpython = ast.unparse(ast.parse('sys.implementation.name == "circuitpython"')) | ||
sys_implementation_not_circuitpython = ast.unparse(ast.parse('not sys.implementation.name == "circuitpython"')) | ||
sys_implementation_not_circuitpython2 = ast.unparse(ast.parse('sys.implementation.name != "circuitpython"')) | ||
|
||
def munge(src: pathlib.Path|str, version_str: str) -> str: | ||
path = pathlib.Path(src) | ||
replacements = {} | ||
|
||
def replace(line, new): | ||
if VERBOSE: | ||
replacements[line] = f"{new:<40s} ### {lines[line]}" | ||
else: | ||
replacements[line] = new | ||
|
||
def blank_range(node): | ||
for i in range(node.lineno, node.end_lineno+1): | ||
replace(i, "") | ||
|
||
def unblank_range(node): | ||
for i in range(node.lineno, node.end_lineno+1): | ||
replacements.pop(i, None) | ||
|
||
def imports_from_typing(node): | ||
if isinstance(node, ast.Import) and node.names[0].name == 'typing': | ||
return True | ||
if isinstance(node, ast.ImportFrom) and node.module == 'typing': | ||
return True | ||
return False | ||
|
||
def process_statement(node): | ||
# filter out 'from future import...' | ||
if isinstance(node, ast.ImportFrom): | ||
if node.module == '__future__': | ||
blank_range(node) | ||
# filter out 'try: import typing...' | ||
# but preserve the first 'except:' or 'except ImportError' | ||
elif isinstance(node, ast.Try): | ||
b = node.body[0] | ||
if imports_from_typing(node.body[0]): | ||
blank_range(node) | ||
for h in node.handlers: | ||
if h.type is None or ast.unparse(h.type) == 'ImportError' or ast.unparse(h.type) == 'Exception': | ||
unblank_range(h) | ||
replace(h.lineno, 'if 1:') | ||
break | ||
return | ||
elif isinstance(node, ast.If): | ||
node_test = ast.unparse(node.test) | ||
# return the statements in the 'if' branch of 'if sys.implementation...: ...' | ||
if node_test == sys_implementation_is_circuitpython: | ||
replace(node.lineno, 'if 1:') | ||
# return the statements in the 'else' branch of 'if sys.implementation...: ...' | ||
elif node_test == sys_implementation_not_circuitpython or node_test == sys_implementation_not_circuitpython2: | ||
replace(node.lineno, 'if 0:') | ||
# return the statements in the else branch of 'if TYPE_CHECKING: ...' | ||
elif node_test == 'TYPE_CHECKING': | ||
replace(node.lineno, 'if 0:') | ||
elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name) and node.targets[0].id == '__version__': | ||
replace(node.lineno, f"__version__ = \"{version_str}\"") | ||
|
||
content = pathlib.Path(path).read_text(encoding="utf-8") | ||
# Insert a blank line 0 because ast line numbers are 1-based | ||
lines = [''] + content.rstrip().split('\n') | ||
a = ast.parse(content, path.name) | ||
|
||
for node in a.body: process_statement(node) | ||
|
||
result = [] | ||
for i in range(1, len(lines)): | ||
result.append(replacements.get(i, lines[i])) | ||
|
||
return "\n".join(result) + "\n" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
Click | ||
pytest | ||
requests | ||
semver | ||
wheel | ||
tomli; python_version < "3.11" | ||
wheel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
|
||
|
||
|
||
|
||
if 1: | ||
pass | ||
|
||
|
||
|
||
if 1: | ||
pass | ||
|
||
|
||
|
||
|
||
if 1: | ||
pass | ||
|
||
|
||
|
||
if 1: | ||
pass | ||
|
||
__version__ = "1.2.3" | ||
|
||
if 1: | ||
print("is circuitpython") | ||
|
||
if 0: | ||
print("not circuitpython (1)") | ||
|
||
if 0: | ||
print("not circuitpython (2)") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from __future__ import annotation | ||
|
||
try: | ||
from typing import TYPE_CHECKING | ||
except ImportError: | ||
pass | ||
|
||
try: | ||
from typing import TYPE_CHECKING as T | ||
except ImportError: | ||
pass | ||
|
||
|
||
try: | ||
import typing | ||
except: | ||
pass | ||
|
||
try: | ||
import typing as T | ||
except: | ||
pass | ||
|
||
__version__ = "0.0.0-auto" | ||
|
||
if sys.implementation.name == "circuitpython": | ||
print("is circuitpython") | ||
|
||
if sys.implementation.name != "circuitpython": | ||
print("not circuitpython (1)") | ||
|
||
if not sys.implementation.name == "circuitpython": | ||
print("not circuitpython (2)") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import sys, pathlib | ||
import pytest | ||
|
||
top = pathlib.Path(__file__).parent.parent | ||
sys.path.insert(0, str(top)) | ||
|
||
from circuitpython_build_tools.munge import munge | ||
|
||
@pytest.mark.parametrize("test_path", top.glob("testcases/*.py")) | ||
def test_munge(test_path): | ||
result_path = test_path.with_suffix(".out") | ||
result_path.unlink(missing_ok = True) | ||
|
||
result_content = munge(test_path, "1.2.3") | ||
result_path.write_text(result_content, encoding="utf-8") | ||
|
||
expected_path = test_path.with_suffix(".exp") | ||
expected_content = expected_path.read_text(encoding="utf-8") | ||
|
||
assert result_content == expected_content | ||
|
||
result_path.unlink() |