Skip to content

Commit

Permalink
Merge pull request #26 from giladmoav/fix-loader-issue
Browse files Browse the repository at this point in the history
Implement get_resource_reader method in rewrite's loader
  • Loading branch information
ayalash authored Oct 5, 2024
2 parents a705352 + 5ae1c5d commit db24c60
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
21 changes: 17 additions & 4 deletions dessert/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from . import util
from .util import format_explanation as _format_explanation

from .pathlib import PurePath, fnmatch_ex
from .pathlib import Path, PurePath, fnmatch_ex

import logging
from munch import Munch
Expand Down Expand Up @@ -65,8 +65,8 @@ def __init__(self, config=None):
self.session = AssertRewritingSession()
self.state = Munch()
self.fnpats = []
self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str]
self._rewritten_names: Dict[str, Path] = {}
self._must_rewrite: Set[str] = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
Expand Down Expand Up @@ -117,7 +117,7 @@ def exec_module(self, module):
fn = module.__spec__.origin
state = self.state

self._rewritten_names.add(module.__name__)
self._rewritten_names[module.__name__] = fn

# The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the
Expand Down Expand Up @@ -211,6 +211,19 @@ def get_data(self, pathname):
with open(pathname, 'rb') as f:
return f.read()

if sys.version_info >= (3, 10):
if sys.version_info >= (3, 12):
from importlib.resources.abc import TraversableResources
else:
from importlib.abc import TraversableResources

def get_resource_reader(self, name: str) -> TraversableResources:
if sys.version_info < (3, 11):
from importlib.readers import FileReader
else:
from importlib.resources.readers import FileReader

return FileReader(types.SimpleNamespace(path=self._rewritten_names[name]))

def _write_pyc(state, co, source_stat, pyc):
# Technically, we don't have to have the same pyc format as
Expand Down
15 changes: 15 additions & 0 deletions tests/test_dessert.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,18 @@ def delete(): # pylint: disable=unused-variable
f.write(source)

return filename


@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="importlib.resources.files was introduced in 3.9",
)
def test_load_resource_via_files_with_rewrite() -> None:
package_dir = mkdtemp()

with open(os.path.join(package_dir, '__init__.py'), 'w') as init_file:
init_file.write("""from importlib.resources import files
assert files(__package__).exists()""")

with dessert.rewrite_assertions_context():
emport.import_file(os.path.join(package_dir, '__init__.py'))

0 comments on commit db24c60

Please sign in to comment.