Skip to content

Commit

Permalink
feat: change CompileAPI to ecosystem based compilers
Browse files Browse the repository at this point in the history
  • Loading branch information
bilbeyt committed Aug 30, 2023
1 parent b45bf22 commit 8ca116a
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 45 deletions.
5 changes: 5 additions & 0 deletions src/ape/api/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class CompilerAPI(BaseInterfaceModel):
def name(self) -> str:
...

@property
@abstractmethod
def extension(self) -> str:
...

@abstractmethod
def get_versions(self, all_paths: List[Path]) -> Set[str]:
"""
Expand Down
3 changes: 2 additions & 1 deletion src/ape/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def contracts(self) -> Dict[str, ContractType]:

@property
def _cache_folder(self) -> Path:
folder = self.contracts_folder.parent / ".build"
current_ecosystem = self.network_manager.network.ecosystem.name
folder = self.contracts_folder.parent / ".build" / current_ecosystem
# NOTE: If we use the cache folder, we expect it to exist
folder.mkdir(exist_ok=True, parents=True)
return folder
Expand Down
51 changes: 30 additions & 21 deletions src/ape/managers/compilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __getattr__(self, name: str) -> Any:

raise ApeAttributeError(f"No attribute or compiler named '{name}'.")

@property
def supported_extensions(self) -> Set[str]:
return set(compiler.extension for compiler in self.registered_compilers.values())

@property
def registered_compilers(self) -> Dict[str, CompilerAPI]:
"""
Expand All @@ -52,33 +56,38 @@ def registered_compilers(self) -> Dict[str, CompilerAPI]:
Dict[str, :class:`~ape.api.compiler.CompilerAPI`]: The mapping of file-extensions
to compiler API classes.
"""
current_ecosystem = self.network_manager.network.ecosystem.name
ecosystem_config = self.config_manager.get_config(current_ecosystem)
try:
supported_compilers = ecosystem_config.compilers
except AttributeError:
raise CompilerError(f"No compilers defined for ecosystem={current_ecosystem}.")

cache_key = self.config_manager.PROJECT_FOLDER
if cache_key in self._registered_compilers_cache:
return self._registered_compilers_cache[cache_key]

registered_compilers = {}

for plugin_name, (extensions, compiler_class) in self.plugin_manager.register_compiler:
for plugin_name, compiler_class in self.plugin_manager.register_compiler:
# TODO: Investigate side effects of loading compiler plugins.
# See if this needs to be refactored.
self.config_manager.get_config(plugin_name=plugin_name)

compiler = compiler_class()
compiler = compiler_class() # type: ignore[operator]

for extension in extensions:
if extension not in registered_compilers:
registered_compilers[extension] = compiler
if compiler.name in supported_compilers:
registered_compilers[compiler.name] = compiler

self._registered_compilers_cache[cache_key] = registered_compilers
return registered_compilers

def get_compiler(self, name: str) -> Optional[CompilerAPI]:
def get_compiler(self, identifier: str) -> CompilerAPI:
for compiler in self.registered_compilers.values():
if compiler.name == name:
if compiler.name == identifier or compiler.extension == identifier:
return compiler

return None
raise ValueError("No compiler identified with '{identifier}'")

def compile(self, contract_filepaths: List[Path]) -> Dict[str, ContractType]:
"""
Expand Down Expand Up @@ -124,10 +133,8 @@ def compile(self, contract_filepaths: List[Path]) -> Dict[str, ContractType]:
for path in paths_to_compile:
source_id = get_relative_path(path, contracts_folder)
logger.info(f"Compiling '{source_id}'.")

compiled_contracts = self.registered_compilers[extension].compile(
paths_to_compile, base_path=contracts_folder
)
compiler = self.get_compiler(extension)
compiled_contracts = compiler.compile(paths_to_compile, base_path=contracts_folder)
for contract_type in compiled_contracts:
contract_name = contract_type.name
if not contract_name:
Expand Down Expand Up @@ -176,9 +183,11 @@ def get_imports(
imports_dict: Dict[str, List[str]] = {}
base_path = base_path or self.project_manager.contracts_folder

for ext, compiler in self.registered_compilers.items():
for compiler in self.registered_compilers.values():
try:
sources = [p for p in contract_filepaths if p.suffix == ext and p.is_file()]
sources = [
p for p in contract_filepaths if p.suffix == compiler.extension and p.is_file()
]
imports = compiler.get_imports(contract_filepaths=sources, base_path=base_path)
except NotImplementedError:
imports = None
Expand Down Expand Up @@ -214,7 +223,7 @@ def get_references(self, imports_dict: Dict[str, List[str]]) -> Dict[str, List[s

def _get_contract_extensions(self, contract_filepaths: List[Path]) -> Set[str]:
extensions = {path.suffix for path in contract_filepaths}
unhandled_extensions = {s for s in extensions - set(self.registered_compilers) if s}
unhandled_extensions = {s for s in extensions - self.supported_extensions if s}
if len(unhandled_extensions) > 0:
unhandled_extensions_str = ", ".join(unhandled_extensions)
raise CompilerError(f"No compiler found for extensions [{unhandled_extensions_str}].")
Expand Down Expand Up @@ -249,11 +258,11 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
return err

ext = Path(contract.source_id).suffix
if ext not in self.registered_compilers:
if ext not in self.supported_extensions:
# Compiler not found.
return err

compiler = self.registered_compilers[ext]
compiler = self.get_compiler(ext)
return compiler.enrich_error(err)

def flatten_contract(self, path: Path) -> Content:
Expand All @@ -268,12 +277,12 @@ def flatten_contract(self, path: Path) -> Content:
``ethpm_types.source.Content``: The flattened contract content.
"""

if path.suffix not in self.registered_compilers:
if path.suffix not in self.supported_extensions:
raise CompilerError(
f"Unable to flatten contract. Missing compiler for '{path.suffix}'."
)

compiler = self.registered_compilers[path.suffix]
compiler = self.get_compiler(path.suffix)
return compiler.flatten_contract(path)

def can_trace_source(self, filename: str) -> bool:
Expand All @@ -293,8 +302,8 @@ def can_trace_source(self, filename: str) -> bool:
return False

extension = path.suffix
if extension in self.registered_compilers:
compiler = self.registered_compilers[extension]
if extension in self.supported_extensions:
compiler = self.get_compiler(extension)
if compiler.supports_source_tracing:
return True

Expand Down
16 changes: 10 additions & 6 deletions src/ape/managers/project/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def source_paths(self) -> List[Path]:
if not self.contracts_folder.is_dir():
return files

for extension in self.compiler_manager.registered_compilers:
for extension in self.compiler_manager.supported_extensions:
files.extend((x for x in self.contracts_folder.rglob(f"*{extension}") if x.is_file()))

return files
Expand Down Expand Up @@ -169,8 +169,10 @@ def _get_compiler_data(self, compile_if_needed: bool = True):
)
compiler_list: List[Compiler] = []
contracts_folder = self.config_manager.contracts_folder
for ext, compiler in self.compiler_manager.registered_compilers.items():
sources = [x for x in self.source_paths if x.is_file() and x.suffix == ext]
for compiler in self.compiler_manager.registered_compilers.values():
sources = [
x for x in self.source_paths if x.is_file() and x.suffix == compiler.extension
]
if not sources:
continue

Expand All @@ -183,7 +185,9 @@ def _get_compiler_data(self, compile_if_needed: bool = True):
# These are unlikely to be part of the published manifest
continue
elif len(versions) > 1:
raise (ProjectError(f"Unable to create version map for '{ext}'."))
raise (
ProjectError(f"Unable to create version map for '{compiler.extension}'.")
)

version = versions[0]
version_map = {version: sources}
Expand Down Expand Up @@ -336,7 +340,7 @@ def get_project(
else path / "contracts"
)
if not contracts_folder.is_dir():
extensions = list(self.compiler_manager.registered_compilers.keys())
extensions = list(self.compiler_manager.supported_extensions)
path_patterns_to_ignore = self.config_manager.compiler.ignore_files

def find_contracts_folder(sub_dir: Path) -> Optional[Path]:
Expand Down Expand Up @@ -586,7 +590,7 @@ def _append_extensions_in_dir(directory: Path):
elif (
file.suffix
and file.suffix not in extensions_found
and file.suffix not in self.compiler_manager.registered_compilers
and file.suffix not in self.compiler_manager.supported_extensions
):
extensions_found.append(file.suffix)

Expand Down
4 changes: 2 additions & 2 deletions src/ape/managers/project/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def source_paths(self) -> List[Path]:
return files

compilers = self.compiler_manager.registered_compilers
for extension in compilers:
ext = extension.replace(".", "\\.")
for compiler in compilers.values():
ext = compiler.extension.replace(".", "\\.")
pattern = rf"[\w|-]+{ext}"
ext_files = get_all_files_in_directory(self.contracts_folder, pattern=pattern)
files.extend(ext_files)
Expand Down
8 changes: 4 additions & 4 deletions src/ape/plugins/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Type
from typing import Type

from ape.api import CompilerAPI

Expand All @@ -13,7 +13,7 @@ class CompilerPlugin(PluginType):
"""

@hookspec
def register_compiler(self) -> Tuple[Tuple[str], Type[CompilerAPI]]: # type: ignore[empty-body]
def register_compiler(self) -> Type[CompilerAPI]: # type: ignore[empty-body]
"""
A hook for returning the set of file extensions the plugin handles
and the compiler class that can be used to compile them.
Expand All @@ -22,8 +22,8 @@ def register_compiler(self) -> Tuple[Tuple[str], Type[CompilerAPI]]: # type: ig
@plugins.register(plugins.CompilerPlugin)
def register_compiler():
return (".json",), InterfaceCompiler
return InterfaceCompiler
Returns:
Tuple[Tuple[str], Type[:class:`~ape.api.CompilerAPI`]]
Type[:class:`~ape.api.CompilerAPI`]
"""
5 changes: 3 additions & 2 deletions src/ape/pytest/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ def _init_coverage_profile(
for src in self.sources:
source_cov = project_coverage.include(src)
ext = Path(src.source_id).suffix
if ext not in self.compiler_manager.registered_compilers:
if ext not in self.compiler_manager.supported_extensions:
continue

compiler = self.compiler_manager.registered_compilers[ext]
compiler = self.compiler_manager.get_compiler(ext)
assert compiler is not None
try:
compiler.init_coverage_profile(source_cov, src)
except NotImplementedError:
Expand Down
5 changes: 3 additions & 2 deletions src/ape/types/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,10 +503,11 @@ def create(
return cls.parse_obj([])

ext = f".{source_id.split('.')[-1]}"
if ext not in accessor.compiler_manager.registered_compilers:
if ext not in accessor.compiler_manager.supported_extensions:
return cls.parse_obj([])

compiler = accessor.compiler_manager.registered_compilers[ext]
compiler = accessor.compiler_manager.get_compiler(ext)
assert compiler is not None
try:
return compiler.trace_source(contract_type, trace, HexBytes(data))
except NotImplementedError:
Expand Down
19 changes: 16 additions & 3 deletions src/ape_compile/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
import click
from ethpm_types import ContractType

from ape.cli import ape_cli_context, contract_file_paths_argument
from ape.cli import (
NetworkBoundCommand,
ape_cli_context,
contract_file_paths_argument,
network_option,
)


def _include_dependencies_callback(ctx, param, value):
return value or ctx.obj.config_manager.get_config("compile").include_dependencies


@click.command(short_help="Compile select contract source files")
@click.command(short_help="Compile select contract source files", cls=NetworkBoundCommand)
@contract_file_paths_argument()
@click.option(
"-f",
Expand All @@ -37,7 +42,15 @@ def _include_dependencies_callback(ctx, param, value):
callback=_include_dependencies_callback,
)
@ape_cli_context()
def cli(cli_ctx, file_paths: Set[Path], use_cache: bool, display_size: bool, include_dependencies):
@network_option()
def cli(
cli_ctx,
file_paths: Set[Path],
use_cache: bool,
display_size: bool,
include_dependencies: bool,
network: str,
):
"""
Compiles the manifest for this project and saves the results
back to the manifest.
Expand Down
1 change: 1 addition & 0 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class EthereumConfig(PluginConfig):
sepolia_fork: NetworkConfig = _create_local_config()
local: NetworkConfig = _create_local_config(default_provider="test")
default_network: str = LOCAL_NETWORK_NAME
compilers: Dict[str, Dict[str, Any]] = {"ethpm": {}}


class Block(BlockAPI):
Expand Down
2 changes: 1 addition & 1 deletion src/ape_pm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

@plugins.register(plugins.CompilerPlugin)
def register_compiler():
return (".json",), InterfaceCompiler
return InterfaceCompiler
4 changes: 4 additions & 0 deletions src/ape_pm/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class InterfaceCompiler(CompilerAPI):
def name(self) -> str:
return "ethpm"

@property
def extension(self) -> str:
return ".json"

def get_versions(self, all_paths: List[Path]) -> Set[str]:
# NOTE: This bypasses the serialization of this compiler into the package manifest's
# ``compilers`` field. You should not do this with a real compiler plugin.
Expand Down
8 changes: 5 additions & 3 deletions src/ape_test/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from watchdog import events # type: ignore
from watchdog.observers import Observer # type: ignore

from ape.cli import ape_cli_context
from ape.cli import NetworkBoundCommand, ape_cli_context, network_option
from ape.utils import ManagerAccessMixin, cached_property

# Copied from https://github.com/olzhasar/pytest-watcher/blob/master/pytest_watcher/watcher.py
Expand Down Expand Up @@ -44,7 +44,7 @@ def dispatch(self, event: events.FileSystemEvent) -> None:

@cached_property
def _extensions_to_watch(self) -> List[str]:
return [".py", *self.compiler_manager.registered_compilers.keys()]
return [".py", *self.compiler_manager.supported_extensions]

def _is_path_watched(self, filepath: str) -> bool:
"""
Expand Down Expand Up @@ -78,8 +78,10 @@ def _run_main_loop(delay: float, pytest_args: Sequence[str]) -> None:
add_help_option=False, # NOTE: This allows pass-through to pytest's help
short_help="Launches pytest and runs the tests for a project",
context_settings=dict(ignore_unknown_options=True),
cls=NetworkBoundCommand,
)
@ape_cli_context()
@network_option()
@click.option(
"-w",
"--watch",
Expand All @@ -104,7 +106,7 @@ def _run_main_loop(delay: float, pytest_args: Sequence[str]) -> None:
help="Delay between polling cycles for `ape test --watch`. Defaults to 0.5 seconds.",
)
@click.argument("pytest_args", nargs=-1, type=click.UNPROCESSED)
def cli(cli_ctx, watch, watch_folders, watch_delay, pytest_args):
def cli(cli_ctx, watch, watch_folders, watch_delay, pytest_args, network):
if watch:
event_handler = EventHandler()

Expand Down

0 comments on commit 8ca116a

Please sign in to comment.