Skip to content

Commit

Permalink
Implement get_resource_reader method in rewrite's loader
Browse files Browse the repository at this point in the history
  • Loading branch information
giladmoav authored and gmoav committed May 29, 2024
1 parent a705352 commit 09ed945
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
21 changes: 17 additions & 4 deletions dessert/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from . import util
from .util import format_explanation as _format_explanation

from .pathlib import PurePath, fnmatch_ex
from .pathlib import Path, PurePath, fnmatch_ex

import logging
from munch import Munch
Expand Down Expand Up @@ -65,8 +65,8 @@ def __init__(self, config=None):
self.session = AssertRewritingSession()
self.state = Munch()
self.fnpats = []
self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str]
self._rewritten_names: Dict[str, Path] = {}
self._must_rewrite: Set[str] = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
Expand Down Expand Up @@ -117,7 +117,7 @@ def exec_module(self, module):
fn = module.__spec__.origin
state = self.state

self._rewritten_names.add(module.__name__)
self._rewritten_names[module.__name__] = fn

# The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the
Expand Down Expand Up @@ -211,6 +211,19 @@ def get_data(self, pathname):
with open(pathname, 'rb') as f:
return f.read()

if sys.version_info >= (3, 10):
if sys.version_info >= (3, 12):
from importlib.resources.abc import TraversableResources
else:
from importlib.abc import TraversableResources

def get_resource_reader(self, name: str) -> TraversableResources:
if sys.version_info < (3, 11):
from importlib.readers import FileReader
else:
from importlib.resources.readers import FileReader

return FileReader(types.SimpleNamespace(path=self._rewritten_names[name]))

def _write_pyc(state, co, source_stat, pyc):
# Technically, we don't have to have the same pyc format as
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# put py.test fixtures here

pytest_plugins = ["pytester"]
30 changes: 30 additions & 0 deletions tests/test_dessert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import dessert
import pytest
from _pytest.config import ExitCode
from _pytest.pytester import Pytester


def test_dessert(module):
Expand Down Expand Up @@ -148,3 +150,31 @@ def delete(): # pylint: disable=unused-variable
f.write(source)

return filename


@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="importlib.resources.files was introduced in 3.9",
)
def test_load_resource_via_files_with_rewrite(pytester: Pytester) -> None:
example = pytester.path.joinpath("demo") / "example"
init = pytester.path.joinpath("demo") / "__init__.py"
pytester.makepyfile(
**{
"demo/__init__.py": """
from importlib.resources import files
def load():
return files(__name__)
""",
"test_load": f"""
pytest_plugins = ["demo"]
def test_load():
from demo import load
found = {{str(i) for i in load().iterdir() if i.name != "__pycache__"}}
assert found == {{{str(example)!r}, {str(init)!r}}}
""",
}
)
example.mkdir()

assert pytester.runpytest("-vv").ret == ExitCode.OK

0 comments on commit 09ed945

Please sign in to comment.