From 1158e08b0da9dee60cbf4f21cc886f3c94050b5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Mon, 9 Oct 2023 18:31:17 +0200 Subject: [PATCH 1/5] Fix main to pass enums_module_name option --- CHANGELOG.md | 1 + ariadne_codegen/main.py | 1 + .../main/clients/custom_files_names/expected_client/__init__.py | 2 +- .../expected_client/{enums.py => custom_enums.py} | 0 .../custom_files_names/expected_client/custom_input_types.py | 2 +- 5 files changed, 4 insertions(+), 2 deletions(-) rename tests/main/clients/custom_files_names/expected_client/{enums.py => custom_enums.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0e63a80..10090d88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Digits in Python names are now preceded by an underscore (breaking change). - Fixed parsing of unions and interfaces to always add `__typename` to generated result models. - Added escaping of enum values which are Python keywords by appending `_` to them. +- Fixed `enums_module_name` option. ## 0.9.0 (2023-09-11) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 805b8852..0139d58b 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -74,6 +74,7 @@ def client(config_dict): client_file_name=settings.client_file_name, base_client_name=settings.base_client_name, base_client_file_path=settings.base_client_file_path, + enums_module_name=settings.enums_module_name, input_types_module_name=settings.input_types_module_name, fragments_module_name=settings.fragments_module_name, queries_source=settings.queries_path, diff --git a/tests/main/clients/custom_files_names/expected_client/__init__.py b/tests/main/clients/custom_files_names/expected_client/__init__.py index 08792a74..8889a2f2 100644 --- a/tests/main/clients/custom_files_names/expected_client/__init__.py +++ b/tests/main/clients/custom_files_names/expected_client/__init__.py @@ -1,8 +1,8 @@ from .async_base_client import AsyncBaseClient from .base_model import BaseModel, Upload from .custom_client import Client +from .custom_enums import enumA from .custom_input_types import inputA -from .enums import enumA from .exceptions import ( GraphQLClientError, GraphQLClientGraphQLError, diff --git a/tests/main/clients/custom_files_names/expected_client/enums.py b/tests/main/clients/custom_files_names/expected_client/custom_enums.py similarity index 100% rename from tests/main/clients/custom_files_names/expected_client/enums.py rename to tests/main/clients/custom_files_names/expected_client/custom_enums.py diff --git a/tests/main/clients/custom_files_names/expected_client/custom_input_types.py b/tests/main/clients/custom_files_names/expected_client/custom_input_types.py index a20b4a93..29cd071b 100644 --- a/tests/main/clients/custom_files_names/expected_client/custom_input_types.py +++ b/tests/main/clients/custom_files_names/expected_client/custom_input_types.py @@ -1,5 +1,5 @@ from .base_model import BaseModel -from .enums import enumA +from .custom_enums import enumA class inputA(BaseModel): From fbb70fd0da6077d8eddecadfd28324630c349794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Tue, 10 Oct 2023 17:31:16 +0200 Subject: [PATCH 2/5] Refactor PackageGenerator by moving inicialization of other generators to separate factory method --- README.md | 2 +- ariadne_codegen/client_generators/client.py | 16 +- .../client_generators/constants.py | 20 +- .../client_generators/input_types.py | 8 +- ariadne_codegen/client_generators/package.py | 256 +++++----- ariadne_codegen/main.py | 23 +- ariadne_codegen/settings.py | 16 +- tests/client_generators/conftest.py | 37 -- .../test_default_values.py | 40 +- .../input_types_generator/test_imports.py | 25 +- .../test_method_calls.py | 11 +- .../input_types_generator/test_names.py | 20 +- .../test_parsing_inputs.py | 20 +- .../test_plugin_hooks.py | 33 +- .../test_client_generator.py | 163 ++----- .../test_package_generator.py | 450 ++++++++++++++---- tests/test_settings.py | 6 +- 17 files changed, 608 insertions(+), 538 deletions(-) diff --git a/README.md b/README.md index 297e2363..f530182b 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ Optional settings: - `client_name` (defaults to `"Client"`) - name of generated client class - `client_file_name` (defaults to `"client"`) - name of file with generated client class - `base_client_name` (defaults to `"AsyncBaseClient"`) - name of base client class -- `base_client_file_path` (defaults to `.../graphql_sdk_gen/generators/async_base_client.py`) - path to file where `base_client_name` is defined +- `base_client_file_path` (defaults to `.../ariadne_codegen/client_generators/dependencies/async_base_client.py`) - path to file where `base_client_name` is defined - `enums_module_name` (defaults to `"enums"`) - name of file with generated enums models - `input_types_module_name` (defaults to `"input_types"`) - name of file with generated input types models - `fragments_module_name` (defaults to `"fragments"`) - name of file with generated fragments models diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 63542d5a..23b6798b 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -38,6 +38,8 @@ OPTIONAL, TYPING_MODULE, UNION, + UNSET_IMPORT, + UPLOAD_IMPORT, ) from .scalars import ScalarData, generate_scalar_imports @@ -45,14 +47,14 @@ class ClientGenerator: def __init__( self, - name: str, - base_client: str, - enums_module_name: str, - input_types_module_name: str, - arguments_generator: ArgumentsGenerator, base_client_import: ast.ImportFrom, - unset_import: ast.ImportFrom, - upload_import: ast.ImportFrom, + arguments_generator: ArgumentsGenerator, + name: str = "Client", + base_client: str = "AsyncBaseClient", + enums_module_name: str = "enums", + input_types_module_name: str = "input_types", + unset_import: ast.ImportFrom = UNSET_IMPORT, + upload_import: ast.ImportFrom = UPLOAD_IMPORT, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, ) -> None: diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index fbccbbe1..fd66599e 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -1,3 +1,4 @@ +import ast from pathlib import Path SIMPLE_TYPE_MAP = { @@ -26,7 +27,23 @@ SOURCE_COMMENT = "# Source: {}" COMMENT_DATETIME_FORMAT = "%Y-%m-%d %H:%M" +BASE_MODEL_FILE_PATH = Path(__file__).parent / "dependencies" / "base_model.py" BASE_MODEL_CLASS_NAME = "BaseModel" +BASE_MODEL_IMPORT = ast.ImportFrom( + module=BASE_MODEL_FILE_PATH.stem, names=[ast.alias(BASE_MODEL_CLASS_NAME)], level=1 +) +UPLOAD_IMPORT = ast.ImportFrom( + module=BASE_MODEL_FILE_PATH.stem, names=[ast.alias(UPLOAD_CLASS_NAME)], level=1 +) +UNSET_NAME = "UNSET" +UNSET_TYPE_NAME = "UnsetType" +UNSET_IMPORT = ast.ImportFrom( + module=BASE_MODEL_FILE_PATH.stem, + names=[ast.alias(UNSET_NAME), ast.alias(UNSET_TYPE_NAME)], + level=1, +) + +EXCEPTIONS_FILE_PATH = Path(__file__).parent / "dependencies" / "exceptions.py" TYPENAME_FIELD_NAME = "__typename" TYPENAME_ALIAS = "typename__" @@ -66,6 +83,3 @@ SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS" SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS" - -UNSET_NAME = "UNSET" -UNSET_TYPE_NAME = "UnsetType" diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 174c630f..06d73826 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -27,6 +27,7 @@ ANNOTATED, ANY, BASE_MODEL_CLASS_NAME, + BASE_MODEL_IMPORT, FIELD_CLASS, LIST, MODEL_REBUILD_METHOD, @@ -35,6 +36,7 @@ PYDANTIC_MODULE, TYPING_MODULE, UNION, + UPLOAD_IMPORT, ) from .input_fields import parse_input_field_default_value, parse_input_field_type from .scalars import ScalarData, generate_scalar_imports @@ -44,9 +46,9 @@ class InputTypesGenerator: def __init__( self, schema: GraphQLSchema, - enums_module: str, - base_model_import: ast.ImportFrom, - upload_import: ast.ImportFrom, + enums_module: str = "enums", + base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT, + upload_import: ast.ImportFrom = UPLOAD_IMPORT, convert_to_snake_case: bool = True, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index d7e66d04..af0aa5e6 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -7,19 +7,22 @@ from ..codegen import generate_import_from from ..exceptions import ParsingError from ..plugins.manager import PluginManager -from ..settings import CommentsStrategy +from ..settings import ClientSettings, CommentsStrategy from ..utils import ast_to_str, process_name, str_to_pascal_case from .arguments import ArgumentsGenerator from .client import ClientGenerator from .comments import get_comment from .constants import ( BASE_MODEL_CLASS_NAME, + BASE_MODEL_FILE_PATH, + BASE_MODEL_IMPORT, DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_BASE_CLIENT_PATH, + EXCEPTIONS_FILE_PATH, GRAPHQL_CLIENT_EXCEPTIONS_NAMES, - UNSET_NAME, - UNSET_TYPE_NAME, + UNSET_IMPORT, UPLOAD_CLASS_NAME, + UPLOAD_IMPORT, ) from .enums import EnumsGenerator from .fragments import FragmentsGenerator @@ -35,10 +38,16 @@ def __init__( package_name: str, target_path: str, schema: GraphQLSchema, + init_generator: InitFileGenerator, + client_generator: ClientGenerator, + enums_generator: EnumsGenerator, + input_types_generator: InputTypesGenerator, + fragments_definitions: Optional[Dict[str, FragmentDefinitionNode]] = None, client_name: str = "Client", - client_file_name: str = "client", + async_client: bool = True, base_client_name: str = "AsyncBaseClient", - base_client_file_path: Optional[str] = None, + base_client_file_path: str = DEFAULT_ASYNC_BASE_CLIENT_PATH.as_posix(), + client_file_name: str = "client", enums_module_name: str = "enums", input_types_module_name: str = "input_types", fragments_module_name: str = "fragments", @@ -46,124 +55,60 @@ def __init__( queries_source: str = "", schema_source: str = "", convert_to_snake_case: bool = True, - async_client: bool = True, - fragments: Optional[List[FragmentDefinitionNode]] = None, - init_generator: Optional[InitFileGenerator] = None, - client_generator: Optional[ClientGenerator] = None, - enums_generator: Optional[EnumsGenerator] = None, - input_types_generator: Optional[InputTypesGenerator] = None, + base_model_file_path: str = BASE_MODEL_FILE_PATH.as_posix(), + base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT, + upload_import: ast.ImportFrom = UPLOAD_IMPORT, + unset_import: ast.ImportFrom = UNSET_IMPORT, files_to_include: Optional[List[str]] = None, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, ) -> None: - self.package_name = package_name - self.target_path = target_path - self.schema = schema self.package_path = Path(target_path) / package_name - self.client_name = client_name - self.base_client_name = base_client_name - self.custom_scalars = custom_scalars if custom_scalars else {} - - self.plugin_manager = plugin_manager - self.base_model_file_path = ( - Path(__file__).parent / "dependencies" / "base_model.py" - ) - self.base_model_import = generate_import_from( - [BASE_MODEL_CLASS_NAME], self.base_model_file_path.stem, 1 - ) - self.upload_import = generate_import_from( - [UPLOAD_CLASS_NAME], self.base_model_file_path.stem, 1 - ) - self.unset_import = generate_import_from( - [UNSET_NAME, UNSET_TYPE_NAME], self.base_model_file_path.stem, 1 - ) - self.exceptions_file_path = ( - Path(__file__).parent / "dependencies" / "exceptions.py" + self.schema = schema + self.fragments_definitions = ( + fragments_definitions if fragments_definitions else {} ) - self.files_to_include = ( - [Path(f) for f in files_to_include] if files_to_include else [] - ) + self.init_generator = init_generator + self.client_generator = client_generator + self.enums_generator = enums_generator + self.input_types_generator = input_types_generator + + self.client_name = client_name + self.async_client = async_client + self.base_client_name = base_client_name + self.base_client_file_path = Path(base_client_file_path) + self.client_file_name = client_file_name self.enums_module_name = enums_module_name self.input_types_module_name = input_types_module_name self.fragments_module_name = fragments_module_name - self.client_file_name = client_file_name self.comments_strategy = comments_strategy self.queries_source = queries_source self.schema_source = schema_source - self.convert_to_snake_case = convert_to_snake_case - self.async_client = async_client - if base_client_file_path: - self.base_client_file_path = Path(base_client_file_path) - else: - if self.async_client: - self.base_client_file_path = DEFAULT_ASYNC_BASE_CLIENT_PATH - else: - self.base_client_file_path = DEFAULT_BASE_CLIENT_PATH - - self.init_generator = ( - init_generator - if init_generator - else InitFileGenerator(plugin_manager=self.plugin_manager) - ) - self.client_generator = ( - client_generator - if client_generator - else ClientGenerator( - name=self.client_name, - base_client=self.base_client_name, - enums_module_name=self.enums_module_name, - input_types_module_name=self.input_types_module_name, - arguments_generator=ArgumentsGenerator( - schema=self.schema, - convert_to_snake_case=self.convert_to_snake_case, - custom_scalars=self.custom_scalars, - plugin_manager=self.plugin_manager, - ), - base_client_import=generate_import_from( - names=[self.base_client_name], - from_=self.base_client_file_path.stem, - level=1, - ), - unset_import=self.unset_import, - upload_import=self.upload_import, - custom_scalars=self.custom_scalars, - plugin_manager=self.plugin_manager, - ) - ) - self.input_types_generator = ( - input_types_generator - if input_types_generator - else InputTypesGenerator( - schema=self.schema, - enums_module=self.enums_module_name, - base_model_import=self.base_model_import, - upload_import=self.upload_import, - convert_to_snake_case=self.convert_to_snake_case, - custom_scalars=self.custom_scalars, - plugin_manager=self.plugin_manager, - ) - ) - self.enums_generator = ( - enums_generator - if enums_generator - else EnumsGenerator(schema=self.schema, plugin_manager=self.plugin_manager) - ) + self.convert_to_snake_case = convert_to_snake_case - self.fragments_definitions = {f.name.value: f for f in fragments or []} + self.base_model_file_path = Path(base_model_file_path) + self.base_model_import = base_model_import + self.upload_import = upload_import + self.unset_import = unset_import - self.result_types_files: Dict[str, ast.Module] = {} - self.generated_files: List[str] = [] - self.include_exceptions_file = self._include_exceptions() + self.files_to_include = ( + [Path(f) for f in files_to_include] if files_to_include else [] + ) + self.custom_scalars = custom_scalars if custom_scalars else {} + self.plugin_manager = plugin_manager + self._result_types_files: Dict[str, ast.Module] = {} + self._generated_files: List[str] = [] self._unpacked_fragments: Set[str] = set() def generate(self) -> List[str]: """Generate package with graphql client.""" + self._include_exceptions() self._validate_unique_file_names() if not self.package_path.exists(): self.package_path.mkdir() @@ -175,7 +120,7 @@ def generate(self) -> List[str]: self._generate_client() self._generate_init() - return sorted(self.generated_files) + return sorted(self._generated_files) def add_operation(self, definition: OperationDefinitionNode): name = definition.name @@ -206,7 +151,7 @@ def add_operation(self, definition: OperationDefinitionNode): self._unpacked_fragments = self._unpacked_fragments.union( query_types_generator.get_unpacked_fragments() ) - self.result_types_files[file_name] = query_types_generator.generate() + self._result_types_files[file_name] = query_types_generator.generate() operation_str = query_types_generator.get_operation_as_str() self.init_generator.add_import( query_types_generator.get_generated_public_names(), module_name, 1 @@ -222,10 +167,16 @@ def add_operation(self, definition: OperationDefinitionNode): ) def _include_exceptions(self): - return self.base_client_file_path in ( + if self.base_client_file_path in ( DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_BASE_CLIENT_PATH, - ) + ): + self.files_to_include.append(EXCEPTIONS_FILE_PATH) + self.init_generator.add_import( + names=GRAPHQL_CLIENT_EXCEPTIONS_NAMES, + from_=EXCEPTIONS_FILE_PATH.stem, + level=1, + ) def _validate_unique_file_names(self): file_names = ( @@ -237,11 +188,9 @@ def _validate_unique_file_names(self): f"{self.input_types_module_name}.py", f"{self.fragments_module_name}.py", ] - + list(self.result_types_files.keys()) + + list(self._result_types_files.keys()) + [f.name for f in self.files_to_include] ) - if self.include_exceptions_file: - file_names.append(self.exceptions_file_path.name) if len(file_names) != len(set(file_names)): seen = set() @@ -257,7 +206,7 @@ def _generate_client(self): if self.plugin_manager: code = self.plugin_manager.generate_client_code(code) client_file_path.write_text(code) - self.generated_files.append(client_file_path.name) + self._generated_files.append(client_file_path.name) self.init_generator.add_import( names=[self.client_generator.name], from_=self.client_file_name, level=1 @@ -281,7 +230,7 @@ def _generate_enums(self): code = self.plugin_manager.generate_enums_code(code) enums_file_path = self.package_path / f"{self.enums_module_name}.py" enums_file_path.write_text(code) - self.generated_files.append(enums_file_path.name) + self._generated_files.append(enums_file_path.name) self.init_generator.add_import( self.enums_generator.get_generated_public_names(), self.enums_module_name, 1 ) @@ -293,7 +242,7 @@ def _generate_input_types(self): if self.plugin_manager: code = self.plugin_manager.generate_inputs_code(code) input_types_file_path.write_text(code) - self.generated_files.append(input_types_file_path.name) + self._generated_files.append(input_types_file_path.name) self.init_generator.add_import( self.input_types_generator.get_generated_public_names(), self.input_types_module_name, @@ -301,13 +250,13 @@ def _generate_input_types(self): ) def _generate_result_types(self): - for file_name, module in self.result_types_files.items(): + for file_name, module in self._result_types_files.items(): file_path = self.package_path / file_name code = self._add_comments_to_code(ast_to_str(module), self.queries_source) if self.plugin_manager: code = self.plugin_manager.generate_result_types_code(code) file_path.write_text(code) - self.generated_files.append(file_path.name) + self._generated_files.append(file_path.name) def _generate_fragments(self): if not set(self.fragments_definitions.keys()).difference( @@ -329,7 +278,7 @@ def _generate_fragments(self): file_path = self.package_path / f"{self.fragments_module_name}.py" code = self._add_comments_to_code(ast_to_str(module), self.queries_source) file_path.write_text(code) - self.generated_files.append(file_path.name) + self._generated_files.append(file_path.name) self.init_generator.add_import( generator.get_generated_public_names(), self.fragments_module_name, 1 ) @@ -339,20 +288,13 @@ def _copy_files(self): self.base_client_file_path, self.base_model_file_path, ] - if self.include_exceptions_file: - files_to_copy.append(self.exceptions_file_path) - self.init_generator.add_import( - names=GRAPHQL_CLIENT_EXCEPTIONS_NAMES, - from_=self.exceptions_file_path.stem, - level=1, - ) for source_path in files_to_copy: - code = self._add_comments_to_code(source_path.read_text()) + code = self._add_comments_to_code(source_path.read_text(encoding="utf-8")) if self.plugin_manager: code = self.plugin_manager.copy_code(code) target_path = self.package_path / source_path.name target_path.write_text(code) - self.generated_files.append(target_path.name) + self._generated_files.append(target_path.name) self.init_generator.add_import( names=[self.base_client_name], @@ -372,4 +314,74 @@ def _generate_init(self): if self.plugin_manager: code = self.plugin_manager.generate_init_code(code) init_file_path.write_text(code) - self.generated_files.append(init_file_path.name) + self._generated_files.append(init_file_path.name) + + +def get_package_generator( + schema: GraphQLSchema, + fragments: List[FragmentDefinitionNode], + settings: ClientSettings, + plugin_manager: PluginManager, +) -> PackageGenerator: + init_generator = InitFileGenerator(plugin_manager=plugin_manager) + client_generator = ClientGenerator( + base_client_import=generate_import_from( + names=[settings.base_client_name], + from_=Path(settings.base_client_file_path).stem, + level=1, + ), + arguments_generator=ArgumentsGenerator( + schema=schema, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ), + name=settings.client_name, + base_client=settings.base_client_name, + enums_module_name=settings.enums_module_name, + input_types_module_name=settings.input_types_module_name, + unset_import=UNSET_IMPORT, + upload_import=UPLOAD_IMPORT, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ) + enums_generator = EnumsGenerator(schema=schema, plugin_manager=plugin_manager) + input_types_generator = InputTypesGenerator( + schema=schema, + enums_module=settings.enums_module_name, + base_model_import=BASE_MODEL_IMPORT, + upload_import=UPLOAD_IMPORT, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ) + + return PackageGenerator( + package_name=settings.target_package_name, + target_path=settings.target_package_path, + schema=schema, + init_generator=init_generator, + client_generator=client_generator, + enums_generator=enums_generator, + input_types_generator=input_types_generator, + fragments_definitions={f.name.value: f for f in fragments or []}, + client_name=settings.client_name, + async_client=settings.async_client, + base_client_name=settings.base_client_name, + base_client_file_path=settings.base_client_file_path, + client_file_name=settings.client_file_name, + enums_module_name=settings.enums_module_name, + input_types_module_name=settings.input_types_module_name, + fragments_module_name=settings.fragments_module_name, + comments_strategy=settings.include_comments, + queries_source=settings.queries_path, + schema_source=settings.schema_source, + convert_to_snake_case=settings.convert_to_snake_case, + base_model_file_path=BASE_MODEL_FILE_PATH.as_posix(), + base_model_import=BASE_MODEL_IMPORT, + upload_import=UPLOAD_IMPORT, + unset_import=UNSET_IMPORT, + files_to_include=settings.files_to_include, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 0139d58b..8f7531fa 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -3,7 +3,7 @@ import click from graphql import assert_valid_schema -from .client_generators.package import PackageGenerator +from .client_generators.package import get_package_generator from .config import get_client_settings, get_config_dict, get_graphql_schema_settings from .graphql_schema_generators.schema import generate_graphql_schema_file from .plugins.explorer import get_plugins_types @@ -42,14 +42,12 @@ def client(config_dict): if settings.schema_path: schema = get_graphql_schema_from_path(settings.schema_path) - schema_source = settings.schema_path else: schema = get_graphql_schema_from_url( url=settings.remote_schema_url, headers=settings.remote_schema_headers, verify_ssl=settings.remote_schema_verify_ssl, ) - schema_source = settings.remote_schema_url plugin_manager = PluginManager( schema=schema, @@ -66,25 +64,10 @@ def client(config_dict): sys.stdout.write(settings.used_settings_message) - package_generator = PackageGenerator( - package_name=settings.target_package_name, - target_path=settings.target_package_path, + package_generator = get_package_generator( schema=schema, - client_name=settings.client_name, - client_file_name=settings.client_file_name, - base_client_name=settings.base_client_name, - base_client_file_path=settings.base_client_file_path, - enums_module_name=settings.enums_module_name, - input_types_module_name=settings.input_types_module_name, - fragments_module_name=settings.fragments_module_name, - queries_source=settings.queries_path, - schema_source=schema_source, - comments_strategy=settings.include_comments, fragments=fragments, - convert_to_snake_case=settings.convert_to_snake_case, - async_client=settings.async_client, - files_to_include=settings.files_to_include, - custom_scalars=settings.scalars, + settings=settings, plugin_manager=plugin_manager, ) for query in queries: diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 61cccacb..024a6149 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -4,7 +4,7 @@ from keyword import iskeyword from pathlib import Path from textwrap import dedent -from typing import Dict, List, Optional +from typing import Dict, List from .client_generators.constants import ( DEFAULT_ASYNC_BASE_CLIENT_PATH, @@ -27,8 +27,8 @@ class Strategy(str, enum.Enum): @dataclass class BaseSettings: - schema_path: Optional[str] = None - remote_schema_url: Optional[str] = None + schema_path: str = "" + remote_schema_url: str = "" remote_schema_headers: dict = field(default_factory=dict) remote_schema_verify_ssl: bool = True plugins: List[str] = field(default_factory=list) @@ -52,8 +52,8 @@ class ClientSettings(BaseSettings): target_package_path: str = field(default_factory=lambda: Path.cwd().as_posix()) client_name: str = "Client" client_file_name: str = "client" - base_client_name: Optional[str] = None - base_client_file_path: Optional[str] = None + base_client_name: str = "" + base_client_file_path: str = "" enums_module_name: str = "enums" input_types_module_name: str = "input_types" fragments_module_name: str = "fragments" @@ -112,7 +112,11 @@ def _set_default_base_client_data(self): self.base_client_name = "BaseClient" @property - def used_settings_message(self): + def schema_source(self) -> str: + return self.schema_path if self.schema_path else self.remote_schema_url + + @property + def used_settings_message(self) -> str: snake_case_msg = ( "Converting fields and arguments name to snake case." if self.convert_to_snake_case diff --git a/tests/client_generators/conftest.py b/tests/client_generators/conftest.py index d466ddc8..c89211ee 100644 --- a/tests/client_generators/conftest.py +++ b/tests/client_generators/conftest.py @@ -3,49 +3,12 @@ import pytest -from ariadne_codegen.client_generators.constants import ( - BASE_MODEL_CLASS_NAME, - UNSET_NAME, - UNSET_TYPE_NAME, - UPLOAD_CLASS_NAME, -) from ariadne_codegen.client_generators.dependencies import ( async_base_client, base_client, - base_model, ) -@pytest.fixture -def base_model_import(): - return ast.ImportFrom( - module=Path(base_model.__file__).stem, - names=[ast.alias(BASE_MODEL_CLASS_NAME)], - level=1, - ) - - -@pytest.fixture -def upload_import(): - return ast.ImportFrom( - module=Path(base_model.__file__).stem, - names=[ast.alias(UPLOAD_CLASS_NAME)], - level=1, - ) - - -@pytest.fixture -def unset_import(): - return ast.ImportFrom( - module=Path(base_model.__file__).stem, - names=[ - ast.alias(UNSET_NAME), - ast.alias(UNSET_TYPE_NAME), - ], - level=1, - ) - - @pytest.fixture def base_client_import(): return ast.ImportFrom( diff --git a/tests/client_generators/input_types_generator/test_default_values.py b/tests/client_generators/input_types_generator/test_default_values.py index 239fce35..68322917 100644 --- a/tests/client_generators/input_types_generator/test_default_values.py +++ b/tests/client_generators/input_types_generator/test_default_values.py @@ -29,15 +29,10 @@ ], ) def test_generate_returns_module_with_parsed_inputs_scalar_field_with_default_value( - field_str, expected_annotation, expected_value, base_model_import, upload_import + field_str, expected_annotation, expected_value ): schema_str = f"input TestInput {{{field_str}}}" - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) expected_class_def = ast.ClassDef( name="TestInput", bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)], @@ -103,15 +98,10 @@ def test_generate_returns_module_with_parsed_inputs_scalar_field_with_default_va ], ) def test_generate_returns_module_with_parsed_inputs_list_field_with_default_value( - field_str, expected_list, base_model_import, upload_import + field_str, expected_list ): schema_str = f"input TestInput {{{field_str}}}" - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) expected_field_value = ast.Call( func=ast.Name(id=FIELD_CLASS), args=[], @@ -144,9 +134,7 @@ def test_generate_returns_module_with_parsed_inputs_list_field_with_default_valu assert compare_ast(field_def.value, expected_field_value) -def test_generate_returns_module_with_parsed_inputs_object_field_with_default_value( - base_model_import, upload_import -): +def test_generate_returns_module_with_parsed_inputs_object_field_with_default_value(): schema_str = """ input TestInput { field: SecondInput = {val: 5} @@ -193,12 +181,7 @@ def test_generate_returns_module_with_parsed_inputs_object_field_with_default_va ], ) - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) module = generator.generate() @@ -212,9 +195,7 @@ def test_generate_returns_module_with_parsed_inputs_object_field_with_default_va assert compare_ast(field_def.value, expected_field_value) -def test_generate_returns_module_with_parsed_nested_object_as_default_value( - base_model_import, upload_import -): +def test_generate_returns_module_with_parsed_nested_object_as_default_value(): schema_str = """ input TestInput { field: SecondInput = { nested: { val: 1.5 } } @@ -269,12 +250,7 @@ def test_generate_returns_module_with_parsed_nested_object_as_default_value( ) ], ) - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) module = generator.generate() diff --git a/tests/client_generators/input_types_generator/test_imports.py b/tests/client_generators/input_types_generator/test_imports.py index ba8e0b8e..58ba492c 100644 --- a/tests/client_generators/input_types_generator/test_imports.py +++ b/tests/client_generators/input_types_generator/test_imports.py @@ -8,7 +8,7 @@ from ...utils import compare_ast, filter_imports, sorted_imports -def test_generate_returns_module_with_enum_imports(base_model_import, upload_import): +def test_generate_returns_module_with_enum_imports(): schema_str = """ input TestInput { field: TestEnum! @@ -19,12 +19,7 @@ def test_generate_returns_module_with_enum_imports(base_model_import, upload_imp VAL2 } """ - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) expected_import = ast.ImportFrom( module="enums", names=[ast.alias("TestEnum")], level=1 ) @@ -35,9 +30,7 @@ def test_generate_returns_module_with_enum_imports(base_model_import, upload_imp assert compare_ast(import_, expected_import) -def test_generate_returns_module_with_used_custom_scalars_imports( - base_model_import, upload_import -): +def test_generate_returns_module_with_used_custom_scalars_imports(): schema_str = """ input TestInput { field: SCALARA! @@ -47,9 +40,6 @@ def test_generate_returns_module_with_used_custom_scalars_imports( """ generator = InputTypesGenerator( schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, custom_scalars={ "SCALARA": ScalarData( type_=".custom_scalars.ScalarA", @@ -71,7 +61,7 @@ def test_generate_returns_module_with_used_custom_scalars_imports( assert compare_ast(sorted_imports(imports), sorted_imports(expected_imports)) -def test_generate_returns_module_with_upload_import(base_model_import, upload_import): +def test_generate_returns_module_with_upload_import(): schema_str = """ input TestInput { field: Upload! @@ -79,12 +69,7 @@ def test_generate_returns_module_with_upload_import(base_model_import, upload_im scalar Upload """ - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) expected_import = ast.ImportFrom( module="base_model", names=[ast.alias("Upload")], level=1 ) diff --git a/tests/client_generators/input_types_generator/test_method_calls.py b/tests/client_generators/input_types_generator/test_method_calls.py index be4d2a21..055ac46c 100644 --- a/tests/client_generators/input_types_generator/test_method_calls.py +++ b/tests/client_generators/input_types_generator/test_method_calls.py @@ -8,9 +8,7 @@ from ...utils import compare_ast, filter_ast_objects -def test_generate_returns_modules_with_model_rebuild_calls( - base_model_import, upload_import -): +def test_generate_returns_modules_with_model_rebuild_calls(): schema_str = """ input CustomInput { field: Int! @@ -53,12 +51,7 @@ def test_generate_returns_modules_with_model_rebuild_calls( ) ), ] - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) module = generator.generate() diff --git a/tests/client_generators/input_types_generator/test_names.py b/tests/client_generators/input_types_generator/test_names.py index 0283a893..c8dbb2ff 100644 --- a/tests/client_generators/input_types_generator/test_names.py +++ b/tests/client_generators/input_types_generator/test_names.py @@ -152,14 +152,9 @@ ], ) def test_generate_returns_module_with_fields_names_converted_to_snake_case( - schema_str, expected_class_def, base_model_import, upload_import + schema_str, expected_class_def ): - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) module = generator.generate() @@ -167,9 +162,7 @@ def test_generate_returns_module_with_fields_names_converted_to_snake_case( assert compare_ast(class_def, expected_class_def) -def test_generate_returns_module_with_valid_field_names( - base_model_import, upload_import -): +def test_generate_returns_module_with_valid_field_names(): schema = """ input KeywordInput { in: String! @@ -183,12 +176,7 @@ def test_generate_returns_module_with_valid_field_names( } """ - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema))) module = generator.generate() diff --git a/tests/client_generators/input_types_generator/test_parsing_inputs.py b/tests/client_generators/input_types_generator/test_parsing_inputs.py index 3a8e2b61..eee46667 100644 --- a/tests/client_generators/input_types_generator/test_parsing_inputs.py +++ b/tests/client_generators/input_types_generator/test_parsing_inputs.py @@ -84,14 +84,9 @@ ], ) def test_generate_returns_module_with_parsed_input_types( - schema_str, expected_class_defs, base_model_import, upload_import + schema_str, expected_class_defs ): - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) module = generator.generate() @@ -99,9 +94,7 @@ def test_generate_returns_module_with_parsed_input_types( assert compare_ast(class_defs, expected_class_defs) -def test_generate_returns_module_with_classes_in_the_same_order_as_declared( - base_model_import, upload_import -): +def test_generate_returns_module_with_classes_in_the_same_order_as_declared(): schema_str = """ input BeforeInput { field: Boolean! @@ -130,12 +123,7 @@ def test_generate_returns_module_with_classes_in_the_same_order_as_declared( "NestedInput", "AfterInput", ] - generator = InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - ) + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) module = generator.generate() diff --git a/tests/client_generators/input_types_generator/test_plugin_hooks.py b/tests/client_generators/input_types_generator/test_plugin_hooks.py index 602b0116..6397ea81 100644 --- a/tests/client_generators/input_types_generator/test_plugin_hooks.py +++ b/tests/client_generators/input_types_generator/test_plugin_hooks.py @@ -10,7 +10,7 @@ def test_generator_triggers_generate_input_class_hook_for_every_input_type( - mocked_plugin_manager, base_model_import, upload_import + mocked_plugin_manager, ): schema_str = """ input TestInputA { @@ -23,11 +23,7 @@ def test_generator_triggers_generate_input_class_hook_for_every_input_type( """ InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - plugin_manager=mocked_plugin_manager, + schema=build_ast_schema(parse(schema_str)), plugin_manager=mocked_plugin_manager ) assert mocked_plugin_manager.generate_input_class.call_count == 2 @@ -41,7 +37,7 @@ def test_generator_triggers_generate_input_class_hook_for_every_input_type( def test_generator_triggers_generate_input_field_hook_for_every_input_field( - mocked_plugin_manager, base_model_import, upload_import + mocked_plugin_manager, ): schema_str = """ input TestInputAB { @@ -55,11 +51,7 @@ def test_generator_triggers_generate_input_field_hook_for_every_input_field( """ InputTypesGenerator( - schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - plugin_manager=mocked_plugin_manager, + schema=build_ast_schema(parse(schema_str)), plugin_manager=mocked_plugin_manager ) assert mocked_plugin_manager.generate_input_field.call_count == 3 @@ -72,15 +64,9 @@ def test_generator_triggers_generate_input_field_hook_for_every_input_field( assert mock_calls[2].kwargs["field_name"] == "fieldC" -def test_generate_triggers_generate_inputs_module_hook( - mocked_plugin_manager, base_model_import, upload_import -): +def test_generate_triggers_generate_inputs_module_hook(mocked_plugin_manager): generator = InputTypesGenerator( - schema=GraphQLSchema(), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, - plugin_manager=mocked_plugin_manager, + schema=GraphQLSchema(), plugin_manager=mocked_plugin_manager ) generator.generate() @@ -88,9 +74,7 @@ def test_generate_triggers_generate_inputs_module_hook( assert mocked_plugin_manager.generate_inputs_module.called -def test_generate_triggers_process_name_hook_for_every_field( - mocked_plugin_manager, base_model_import, upload_import -): +def test_generate_triggers_process_name_hook_for_every_field(mocked_plugin_manager): schema_str = """ input TestInputAB { fieldA: String! @@ -104,9 +88,6 @@ def test_generate_triggers_process_name_hook_for_every_field( InputTypesGenerator( schema=build_ast_schema(parse(schema_str)), - enums_module="enums", - base_model_import=base_model_import, - upload_import=upload_import, convert_to_snake_case=False, plugin_manager=mocked_plugin_manager, ) diff --git a/tests/client_generators/test_client_generator.py b/tests/client_generators/test_client_generator.py index 222ff56e..f8a812e4 100644 --- a/tests/client_generators/test_client_generator.py +++ b/tests/client_generators/test_client_generator.py @@ -15,6 +15,8 @@ OPTIONAL, TYPING_MODULE, UNION, + UNSET_IMPORT, + UPLOAD_IMPORT, ) from ariadne_codegen.client_generators.scalars import ScalarData from ariadne_codegen.exceptions import NotSupported @@ -22,19 +24,12 @@ from ..utils import compare_ast, filter_imports, get_class_def, sorted_imports -def test_generate_returns_module_with_correct_class_name( - async_base_client_import, unset_import, upload_import -): +def test_generate_returns_module_with_correct_class_name(async_base_client_import): name = "ClientXyz" generator = ClientGenerator( - name, - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), + name=name, ) module = generator.generate() @@ -44,18 +39,10 @@ def test_generate_returns_module_with_correct_class_name( assert class_def.name == name -def test_generate_returns_module_with_gql_lambda_definition( - async_base_client_import, unset_import, upload_import -): +def test_generate_returns_module_with_gql_lambda_definition(async_base_client_import): generator = ClientGenerator( - "ClientXYZ", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), ) expected_def = ast.FunctionDef( name="gql", @@ -81,17 +68,11 @@ def test_generate_returns_module_with_gql_lambda_definition( def test_generate_triggers_generate_gql_function_hook( - mocked_plugin_manager, async_base_client_import, unset_import, upload_import + mocked_plugin_manager, async_base_client_import ): generator = ClientGenerator( - "ClientXYZ", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), plugin_manager=mocked_plugin_manager, ) @@ -101,36 +82,25 @@ def test_generate_triggers_generate_gql_function_hook( def test_generate_triggers_generate_client_class_hook( - mocked_plugin_manager, async_base_client_import, unset_import, upload_import + mocked_plugin_manager, async_base_client_import ): generator = ClientGenerator( - "ClientXYZ", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), plugin_manager=mocked_plugin_manager, ) + generator.generate() assert mocked_plugin_manager.generate_client_class.called def test_generate_triggers_generate_client_module_hook( - mocked_plugin_manager, async_base_client_import, unset_import, upload_import + mocked_plugin_manager, async_base_client_import ): generator = ClientGenerator( - "ClientXYZ", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), plugin_manager=mocked_plugin_manager, ) @@ -139,9 +109,7 @@ def test_generate_triggers_generate_client_module_hook( assert mocked_plugin_manager.generate_client_module.called -def test_generate_returns_module_with_correct_imports( - async_base_client_import, unset_import, upload_import -): +def test_generate_returns_module_with_correct_imports(async_base_client_import): schema_str = """ schema { query: Query } type Query { xyz(arg1: TestScalar!, arg2: TestEnum!, arg3: TestInput): TestType } @@ -174,24 +142,20 @@ def test_generate_returns_module_with_correct_imports( ) } generator = ClientGenerator( - "Client", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", + base_client_import=async_base_client_import, arguments_generator=ArgumentsGenerator( schema=build_schema(schema_str), custom_scalars=scalars ), - base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, custom_scalars=scalars, ) expected_imports = [ async_base_client_import, - unset_import, - upload_import, + UNSET_IMPORT, + UPLOAD_IMPORT, ast.ImportFrom(module="enums", names=[ast.alias(name="TestEnum")], level=1), - ast.ImportFrom(module="inputs", names=[ast.alias(name="TestInput")], level=1), + ast.ImportFrom( + module="input_types", names=[ast.alias(name="TestInput")], level=1 + ), ast.ImportFrom( module=".custom_scalars", names=[ast.alias(name="TestScalarType")], level=0 ), @@ -230,19 +194,12 @@ def test_generate_returns_module_with_correct_imports( ) -def test_add_method_adds_async_method_definition( - async_base_client_import, unset_import, upload_import -): +def test_add_method_adds_async_method_definition(async_base_client_import): generator = ClientGenerator( - "ClientXYZ", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), ) + method_name = "list_xyz" return_type = "ListXyz" return_type_module_name = method_name @@ -267,9 +224,7 @@ def test_add_method_adds_async_method_definition( assert method_def.returns.id == return_type -def test_add_method_generates_correct_async_method_body( - async_base_client_import, unset_import, upload_import -): +def test_add_method_generates_correct_async_method_body(async_base_client_import): schema_str = """ schema { query: Query } type Query { xyz(arg1: Int!): TestType } @@ -287,14 +242,8 @@ def test_add_method_generates_correct_async_method_body( } """ generator = ClientGenerator( - "Client", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), ) method_name = "list_xyz" return_type = "ListXyz" @@ -368,9 +317,7 @@ def test_add_method_generates_correct_async_method_body( assert compare_ast(method_def.body, expected_method_body) -def test_add_method_adds_method_definition( - unset_import, upload_import, base_client_import -): +def test_add_method_adds_method_definition(base_client_import): schema_str = """ schema { query: Query } type Query { xyz(arg1: Int!): TestType } @@ -388,14 +335,8 @@ def test_add_method_adds_method_definition( } """ generator = ClientGenerator( - "Client", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), base_client_import=base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), ) method_name = "list_xyz" return_type = "ListXyz" @@ -421,9 +362,7 @@ def test_add_method_adds_method_definition( assert method_def.returns.id == return_type -def test_add_method_generates_correct_method_body( - unset_import, upload_import, base_client_import -): +def test_add_method_generates_correct_method_body(base_client_import): schema_str = """ schema { query: Query } type Query { xyz(arg1: Int!): TestType } @@ -441,14 +380,8 @@ def test_add_method_generates_correct_method_body( } """ generator = ClientGenerator( - "Client", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), base_client_import=base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), ) method_name = "list_xyz" return_type = "ListXyz" @@ -521,7 +454,7 @@ def test_add_method_generates_correct_method_body( def test_add_method_generates_async_generator_for_subscription_definition( - async_base_client_import, unset_import, upload_import + async_base_client_import, ): schema_str = """ schema { subscription: Subscription } @@ -529,14 +462,8 @@ def test_add_method_generates_async_generator_for_subscription_definition( """ subscription_str = "subscription GetCounter { counter }" generator = ClientGenerator( - "ClientXYZ", - base_client="AsyncBaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), base_client_import=async_base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), ) expected_method_def = ast.AsyncFunctionDef( name="get_counter", @@ -617,19 +544,11 @@ def test_add_method_generates_async_generator_for_subscription_definition( assert compare_ast(class_def.body[0], expected_method_def) -def test_add_method_raises_not_supported_for_not_async_subscription( - unset_import, upload_import, base_client_import -): +def test_add_method_raises_not_supported_for_not_async_subscription(base_client_import): subscription_str = "subscription GetCounter { counter }" - generator = ClientGenerator( - "Client", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(GraphQLSchema()), + generator = generator = ClientGenerator( base_client_import=base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=GraphQLSchema()), ) with pytest.raises(NotSupported): @@ -646,7 +565,7 @@ def test_add_method_raises_not_supported_for_not_async_subscription( def test_add_method_triggers_generate_client_method_hook( - mocked_plugin_manager, unset_import, upload_import, base_client_import + mocked_plugin_manager, base_client_import ): schema_str = """ schema { query: Query } @@ -664,15 +583,9 @@ def test_add_method_triggers_generate_client_method_hook( } } """ - generator = ClientGenerator( - "Client", - base_client="BaseClient", - enums_module_name="enums", - input_types_module_name="inputs", - arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), + generator = generator = ClientGenerator( base_client_import=base_client_import, - unset_import=unset_import, - upload_import=upload_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), plugin_manager=mocked_plugin_manager, ) method_name = "list_xyz" diff --git a/tests/client_generators/test_package_generator.py b/tests/client_generators/test_package_generator.py index 4c0f3315..5a44aac9 100644 --- a/tests/client_generators/test_package_generator.py +++ b/tests/client_generators/test_package_generator.py @@ -5,11 +5,17 @@ from freezegun import freeze_time from graphql import GraphQLSchema, build_ast_schema, parse +from ariadne_codegen.client_generators.arguments import ArgumentsGenerator +from ariadne_codegen.client_generators.client import ClientGenerator from ariadne_codegen.client_generators.constants import ( + EXCEPTIONS_FILE_PATH, SOURCE_COMMENT, STABLE_COMMENT, TIMESTAMP_COMMENT, ) +from ariadne_codegen.client_generators.enums import EnumsGenerator +from ariadne_codegen.client_generators.init_file import InitFileGenerator +from ariadne_codegen.client_generators.input_types import InputTypesGenerator from ariadne_codegen.client_generators.package import PackageGenerator from ariadne_codegen.client_generators.scalars import ScalarData from ariadne_codegen.exceptions import ParsingError @@ -52,9 +58,21 @@ """ -def test_generate_creates_directory_and_files(tmp_path): +def test_generate_creates_directory_and_files(tmp_path, async_base_client_import): package_name = "test_graphql_client" - generator = PackageGenerator(package_name, tmp_path.as_posix(), GraphQLSchema()) + schema = GraphQLSchema() + generator = PackageGenerator( + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + ) generator.generate() @@ -81,9 +99,23 @@ def test_generate_creates_directory_and_files(tmp_path): assert base_model_path.is_file() -def test_generate_creates_files_with_correct_imports(tmp_path): +def test_generate_creates_files_with_correct_imports( + tmp_path, async_base_client_import +): package_name = "test_graphql_client" - generator = PackageGenerator(package_name, tmp_path.as_posix(), GraphQLSchema()) + schema = GraphQLSchema() + generator = PackageGenerator( + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + ) generator.generate() @@ -117,10 +149,20 @@ def test_generate_creates_files_with_correct_imports(tmp_path): assert "from .async_base_client import AsyncBaseClient" in client_content -def test_generate_creates_files_with_types(tmp_path): +def test_generate_creates_files_with_types(tmp_path, async_base_client_import): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, tmp_path.as_posix(), build_ast_schema(parse(SCHEMA_STR)) + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), ) expected_input_types = """ class CustomInput(BaseModel): @@ -147,10 +189,20 @@ class CustomEnum(str, Enum): assert dedent(expected_enums) in enums_content -def test_generate_creates_file_with_query_types(tmp_path): +def test_generate_creates_file_with_query_types(tmp_path, async_base_client_import): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, tmp_path.as_posix(), build_ast_schema(parse(SCHEMA_STR)) + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), ) query_str = """ query CustomQuery($id: ID!, $param: String) { @@ -191,10 +243,22 @@ class CustomQueryQuery1Field2(BaseModel): ) -def test_generate_creates_multiple_query_types_files(tmp_path): +def test_generate_creates_multiple_query_types_files( + tmp_path, async_base_client_import +): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, tmp_path.as_posix(), build_ast_schema(parse(SCHEMA_STR)) + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), ) query_str = """ query CustomQuery1 { @@ -223,7 +287,7 @@ def test_generate_creates_multiple_query_types_files(tmp_path): assert query2_file_path.is_file() -def test_generate_copies_base_client_file(tmp_path): +def test_generate_copies_base_client_file(tmp_path, async_base_client_import): base_client_file_content = """ class TestBaseClient: pass @@ -231,10 +295,18 @@ class TestBaseClient: package_name = "test_graphql_client" base_client_file_path = tmp_path / "test_base_client.py" base_client_file_path.write_text(dedent(base_client_file_content)) + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), base_client_name="TestBaseClient", base_client_file_path=base_client_file_path.as_posix(), ) @@ -249,12 +321,22 @@ class TestBaseClient: assert dedent(base_client_file_content) in dedent(copied_content) -def test_generate_creates_client_with_valid_method_names(tmp_path): +def test_generate_creates_client_with_valid_method_names( + tmp_path, async_base_client_import +): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), async_client=False, ) query_str = """ @@ -281,12 +363,23 @@ def test_generate_creates_client_with_valid_method_names(tmp_path): assert function.name == "from_" -def test_generate_with_conflicting_query_name_raises_parsing_error(tmp_path): +def test_generate_with_conflicting_query_name_raises_parsing_error( + tmp_path, async_base_client_import +): + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - "test_graphql_client", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), input_types_module_name="input_types", + convert_to_snake_case=True, ) query_str = """ query InputTypes { @@ -302,13 +395,21 @@ def test_generate_with_conflicting_query_name_raises_parsing_error(tmp_path): def test_generate_with_enum_as_query_argument_generates_client_with_correct_method( - tmp_path, + tmp_path, async_base_client_import ): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), async_client=True, ) query_str = """ @@ -332,10 +433,22 @@ def test_generate_with_enum_as_query_argument_generates_client_with_correct_meth assert expected_enum_import in client_content -def test_generate_creates_client_file_with_gql_lambda_definition(tmp_path): +def test_generate_creates_client_file_with_gql_lambda_definition( + tmp_path, async_base_client_import +): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, tmp_path.as_posix(), build_ast_schema(parse(SCHEMA_STR)) + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), ) generator.generate() @@ -358,12 +471,22 @@ def test_generate_creates_client_file_with_gql_lambda_definition(tmp_path): ], ) @freeze_time("01.01.2022 12:00") -def test_generate_adds_comment_to_generated_files(tmp_path, strategy, expected_comment): +def test_generate_adds_comment_to_generated_files( + tmp_path, strategy, expected_comment, async_base_client_import +): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), comments_strategy=strategy, ) query_str = """ @@ -396,15 +519,23 @@ def test_generate_adds_comment_to_generated_files(tmp_path, strategy, expected_c "strategy", [CommentsStrategy.STABLE, CommentsStrategy.TIMESTAMP] ) def test_generate_adds_comment_with_correct_source_to_generated_files( - tmp_path, strategy + tmp_path, async_base_client_import, strategy ): package_name = "test_graphql_client" schema_source = "schema_source.graphql" queries_source = "queries_source.graphql" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), comments_strategy=strategy, schema_source=schema_source, queries_source=queries_source, @@ -443,13 +574,21 @@ def test_generate_adds_comment_with_correct_source_to_generated_files( [CommentsStrategy.NONE, CommentsStrategy.STABLE, CommentsStrategy.TIMESTAMP], ) def test_generate_calls_get_file_comment_hook_for_every_file( - tmp_path, strategy, mocked_plugin_manager + tmp_path, async_base_client_import, strategy, mocked_plugin_manager ): package_name = "test_graphql_client" + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), comments_strategy=strategy, plugin_manager=mocked_plugin_manager, ) @@ -469,7 +608,9 @@ def test_generate_calls_get_file_comment_hook_for_every_file( ) -def test_generate_creates_result_types_from_operation_that_uses_fragment(tmp_path): +def test_generate_creates_result_types_from_operation_that_uses_fragment( + tmp_path, async_base_client_import +): package_name = "test_graphql_client" query_str = """ query CustomQuery($id: ID!) { @@ -495,11 +636,19 @@ class CustomQueryQuery1(TestFragment): field_3: CustomEnum = Field(alias="field3") """ query_def, fragment_def = parse(query_str).definitions + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), - fragments=[fragment_def], + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_definitions={"TestFragment": fragment_def}, ) generator.add_operation(query_def) @@ -511,12 +660,24 @@ class CustomQueryQuery1(TestFragment): assert dedent(expected_types) in result_types_content -def test_generate_returns_list_of_generated_files(tmp_path): +def test_generate_returns_list_of_generated_files(tmp_path, async_base_client_import): + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - "test_graphql_client", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), - fragments=[parse("fragment TestFragment on CustomType { id }").definitions[0]], + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_definitions={ + "TestFragment": parse( + "fragment TestFragment on CustomType { id }" + ).definitions[0] + }, custom_scalars={"SCALARABC": ScalarData(type_="str", graphql_name="SCALARABC")}, ) query_str = """ @@ -536,7 +697,7 @@ def test_generate_returns_list_of_generated_files(tmp_path): generator.base_client_file_path.name, "base_model.py", f"{generator.client_file_name}.py", - generator.exceptions_file_path.name, + EXCEPTIONS_FILE_PATH.name, f"{generator.input_types_module_name}.py", f"{generator.enums_module_name}.py", "custom_query.py", @@ -545,7 +706,7 @@ def test_generate_returns_list_of_generated_files(tmp_path): ) -def test_generate_copies_files_to_include(tmp_path): +def test_generate_copies_files_to_include(tmp_path, async_base_client_import): file1 = tmp_path / "file1.py" file1_content = "class TestBaseClass:\n pass" file1.write_text(file1_content) @@ -556,10 +717,18 @@ def test_generate_copies_files_to_include(tmp_path): file2_content = "class TestBaseClass2:\n pass" file2.write_text(file2_content) + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name="test_graphql_client", target_path=tmp_path.as_posix(), - schema=build_ast_schema(parse(SCHEMA_STR)), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), files_to_include=[file1.as_posix(), file2.as_posix()], ) generated_files = generator.generate() @@ -572,15 +741,31 @@ def test_generate_copies_files_to_include(tmp_path): assert file2_content in copied_file2.read() -def test_generate_creates_client_with_custom_scalars_imports(tmp_path): +def test_generate_creates_client_with_custom_scalars_imports( + tmp_path, async_base_client_import +): package_name = "test_graphql_client" + custom_scalars = { + "SCALARABC": ScalarData(type_=".abc.ScalarABC", graphql_name="SCALARABC") + } + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - package_name, - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), - custom_scalars={ - "SCALARABC": ScalarData(type_=".abc.ScalarABC", graphql_name="SCALARABC") - }, + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator( + schema=schema, custom_scalars=custom_scalars + ), + custom_scalars=custom_scalars, + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator( + schema=schema, custom_scalars=custom_scalars + ), + custom_scalars=custom_scalars, ) query_str = """ query CustomQuery($id: ID!, $param: SCALARABC) { @@ -600,33 +785,66 @@ def test_generate_creates_client_with_custom_scalars_imports(tmp_path): assert "from .abc import ScalarABC" in client_file.read() -def test_generate_triggers_generate_client_code_hook(mocked_plugin_manager, tmp_path): +def test_generate_triggers_generate_client_code_hook( + mocked_plugin_manager, tmp_path, async_base_client_import +): + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ).generate() assert mocked_plugin_manager.generate_client_code.called -def test_generate_triggers_generate_enums_code_hook(mocked_plugin_manager, tmp_path): +def test_generate_triggers_generate_enums_code_hook( + mocked_plugin_manager, tmp_path, async_base_client_import +): + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ).generate() assert mocked_plugin_manager.generate_enums_code.called -def test_generate_triggers_generate_inputs_code_hook(mocked_plugin_manager, tmp_path): +def test_generate_triggers_generate_inputs_code_hook( + mocked_plugin_manager, tmp_path, async_base_client_import +): + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ).generate() @@ -634,12 +852,20 @@ def test_generate_triggers_generate_inputs_code_hook(mocked_plugin_manager, tmp_ def test_generate_triggers_generate_result_types_code_hook_for_every_added_operation( - mocked_plugin_manager, tmp_path + mocked_plugin_manager, tmp_path, async_base_client_import ): + schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ) generator.add_operation(parse("query A { query2 { id } }").definitions[0]) @@ -651,12 +877,21 @@ def test_generate_triggers_generate_result_types_code_hook_for_every_added_opera def test_generate_triggers_copy_code_hook_for_every_attached_dependency_file( - mocked_plugin_manager, tmp_path + mocked_plugin_manager, tmp_path, async_base_client_import ): + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ).generate() @@ -664,14 +899,23 @@ def test_generate_triggers_copy_code_hook_for_every_attached_dependency_file( def test_generate_triggers_copy_code_hook_for_every_file_to_include( - mocked_plugin_manager, tmp_path + mocked_plugin_manager, tmp_path, async_base_client_import ): test_file_path = tmp_path / "xyz.py" test_file_path.touch() + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, files_to_include=[test_file_path.as_posix()], ).generate() @@ -679,18 +923,31 @@ def test_generate_triggers_copy_code_hook_for_every_file_to_include( assert mocked_plugin_manager.copy_code.call_count == 4 -def test_generate_triggers_generate_init_code_hook(mocked_plugin_manager, tmp_path): +def test_generate_triggers_generate_init_code_hook( + mocked_plugin_manager, tmp_path, async_base_client_import +): + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ).generate() assert mocked_plugin_manager.generate_init_code.called -def test_add_operation_triggers_process_name_hook(mocked_plugin_manager, tmp_path): +def test_add_operation_triggers_process_name_hook( + mocked_plugin_manager, tmp_path, async_base_client_import +): query_str = """ query custom_query_name { query2 { @@ -698,10 +955,19 @@ def test_add_operation_triggers_process_name_hook(mocked_plugin_manager, tmp_pat } } """ + schema = build_ast_schema(parse(SCHEMA_STR)) + PackageGenerator( - "package_name", - tmp_path.as_posix(), - build_ast_schema(parse(SCHEMA_STR)), + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), plugin_manager=mocked_plugin_manager, ).add_operation(parse(query_str).definitions[0]) diff --git a/tests/test_settings.py b/tests/test_settings.py index ac0a0719..6b024109 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -114,7 +114,7 @@ def test_client_settings_without_schema_path_with_remote_schema_url_is_valid(tmp remote_schema_url="http://testserver/graphq/", queries_path=queries_path ) - assert settings.schema_path is None + assert not settings.schema_path def test_client_settings_without_schema_path_or_remote_schema_url_raises_exception( @@ -211,13 +211,13 @@ def test_graphq_schema_settings_without_remote_schema_url_with_schema_path_is_va settings = GraphQLSchemaSettings(schema_path=schema_path.as_posix()) - assert settings.remote_schema_url is None + assert not settings.remote_schema_url def test_graphq_schema_settings_without_schema_path_with_remote_schema_url_is_valid(): settings = GraphQLSchemaSettings(remote_schema_url="http://testserver/graphq/") - assert settings.schema_path is None + assert not settings.schema_path def test_graphq_schema_settings_without_schema_path_or_remote_schema_url_is_not_valid(): From c634950faefca773629ddfd95a62cac3b98d2b70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Tue, 10 Oct 2023 17:46:19 +0200 Subject: [PATCH 3/5] Refactor PackageGenerator to get fragment generator as argument --- .../client_generators/fragments.py | 16 ++++---- ariadne_codegen/client_generators/package.py | 31 ++++++++------ .../test_fragments_generator.py | 3 +- .../test_package_generator.py | 41 ++++++++++++++++--- 4 files changed, 63 insertions(+), 28 deletions(-) diff --git a/ariadne_codegen/client_generators/fragments.py b/ariadne_codegen/client_generators/fragments.py index 0cc26b77..3729cc5b 100644 --- a/ariadne_codegen/client_generators/fragments.py +++ b/ariadne_codegen/client_generators/fragments.py @@ -5,7 +5,7 @@ from ..codegen import generate_expr, generate_method_call, generate_module from ..plugins.manager import PluginManager -from .constants import MODEL_REBUILD_METHOD +from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD from .result_types import ResultTypesGenerator from .scalars import ScalarData @@ -14,33 +14,31 @@ class FragmentsGenerator: def __init__( self, schema: GraphQLSchema, - enums_module_name: str, fragments_definitions: Dict[str, FragmentDefinitionNode], - exclude_names: Optional[Set[str]] = None, - base_model_import: Optional[ast.ImportFrom] = None, + enums_module_name: str = "enums", + base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT, convert_to_snake_case: bool = True, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, ) -> None: self.schema = schema self.enums_module_name = enums_module_name - self.exclude_names = exclude_names or set() self.fragments_definitions = fragments_definitions self.base_model_import = base_model_import self.convert_to_snake_case = convert_to_snake_case self.custom_scalars = custom_scalars self.plugin_manager = plugin_manager - self._fragments_names = ( - set(self.fragments_definitions.keys()) - self.exclude_names - ) + self._fragments_names = set(self.fragments_definitions.keys()) self._generated_public_names: List[str] = [] - def generate(self) -> ast.Module: + def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: class_defs_dict: Dict[str, List[ast.ClassDef]] = {} imports: List[ast.ImportFrom] = [] dependencies_dict: Dict[str, Set[str]] = {} + names_to_exclude = exclude_names or set() + self._fragments_names = self._fragments_names - names_to_exclude for name in self._fragments_names: fragmanet_def = self.fragments_definitions[name] generator = ResultTypesGenerator( diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index af0aa5e6..4740d3a8 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -42,6 +42,7 @@ def __init__( client_generator: ClientGenerator, enums_generator: EnumsGenerator, input_types_generator: InputTypesGenerator, + fragments_generator: FragmentsGenerator, fragments_definitions: Optional[Dict[str, FragmentDefinitionNode]] = None, client_name: str = "Client", async_client: bool = True, @@ -74,6 +75,7 @@ def __init__( self.client_generator = client_generator self.enums_generator = enums_generator self.input_types_generator = input_types_generator + self.fragments_generator = fragments_generator self.client_name = client_name self.async_client = async_client @@ -264,23 +266,17 @@ def _generate_fragments(self): ): return - generator = FragmentsGenerator( - schema=self.schema, - enums_module_name=self.enums_module_name, - fragments_definitions=self.fragments_definitions, - exclude_names=self._unpacked_fragments, - base_model_import=self.base_model_import, - convert_to_snake_case=self.convert_to_snake_case, - custom_scalars=self.custom_scalars, - plugin_manager=self.plugin_manager, + module = self.fragments_generator.generate( + exclude_names=self._unpacked_fragments ) - module = generator.generate() file_path = self.package_path / f"{self.fragments_module_name}.py" code = self._add_comments_to_code(ast_to_str(module), self.queries_source) file_path.write_text(code) self._generated_files.append(file_path.name) self.init_generator.add_import( - generator.get_generated_public_names(), self.fragments_module_name, 1 + self.fragments_generator.get_generated_public_names(), + self.fragments_module_name, + 1, ) def _copy_files(self): @@ -355,6 +351,16 @@ def get_package_generator( custom_scalars=settings.scalars, plugin_manager=plugin_manager, ) + fragments_definitions = {f.name.value: f for f in fragments or []} + fragments_generator = FragmentsGenerator( + schema=schema, + fragments_definitions=fragments_definitions, + enums_module_name=settings.enums_module_name, + base_model_import=BASE_MODEL_IMPORT, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ) return PackageGenerator( package_name=settings.target_package_name, @@ -364,7 +370,8 @@ def get_package_generator( client_generator=client_generator, enums_generator=enums_generator, input_types_generator=input_types_generator, - fragments_definitions={f.name.value: f for f in fragments or []}, + fragments_generator=fragments_generator, + fragments_definitions=fragments_definitions, client_name=settings.client_name, async_client=settings.async_client, base_client_name=settings.base_client_name, diff --git a/tests/client_generators/test_fragments_generator.py b/tests/client_generators/test_fragments_generator.py index 18d87b53..1deb3025 100644 --- a/tests/client_generators/test_fragments_generator.py +++ b/tests/client_generators/test_fragments_generator.py @@ -145,11 +145,10 @@ def test_generate_returns_module_without_models_for_excluded_fragments( "FragmentA": fragment_a, "FragmentB": fragment_b, }, - exclude_names={"TestFragment", "FragmentB"}, convert_to_snake_case=True, ) - module = generator.generate() + module = generator.generate(exclude_names={"TestFragment", "FragmentB"}) generated_class_defs = filter_class_defs(module) assert [c.name for c in generated_class_defs] == ["FragmentA"] diff --git a/tests/client_generators/test_package_generator.py b/tests/client_generators/test_package_generator.py index 5a44aac9..63310665 100644 --- a/tests/client_generators/test_package_generator.py +++ b/tests/client_generators/test_package_generator.py @@ -14,6 +14,7 @@ TIMESTAMP_COMMENT, ) from ariadne_codegen.client_generators.enums import EnumsGenerator +from ariadne_codegen.client_generators.fragments import FragmentsGenerator from ariadne_codegen.client_generators.init_file import InitFileGenerator from ariadne_codegen.client_generators.input_types import InputTypesGenerator from ariadne_codegen.client_generators.package import PackageGenerator @@ -72,6 +73,7 @@ def test_generate_creates_directory_and_files(tmp_path, async_base_client_import ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), ) generator.generate() @@ -115,6 +117,7 @@ def test_generate_creates_files_with_correct_imports( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), ) generator.generate() @@ -163,6 +166,7 @@ def test_generate_creates_files_with_types(tmp_path, async_base_client_import): ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), ) expected_input_types = """ class CustomInput(BaseModel): @@ -203,6 +207,7 @@ def test_generate_creates_file_with_query_types(tmp_path, async_base_client_impo ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), ) query_str = """ query CustomQuery($id: ID!, $param: String) { @@ -259,6 +264,7 @@ def test_generate_creates_multiple_query_types_files( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), ) query_str = """ query CustomQuery1 { @@ -307,6 +313,7 @@ class TestBaseClient: ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), base_client_name="TestBaseClient", base_client_file_path=base_client_file_path.as_posix(), ) @@ -337,6 +344,7 @@ def test_generate_creates_client_with_valid_method_names( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), async_client=False, ) query_str = """ @@ -378,6 +386,7 @@ def test_generate_with_conflicting_query_name_raises_parsing_error( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), input_types_module_name="input_types", convert_to_snake_case=True, ) @@ -410,6 +419,7 @@ def test_generate_with_enum_as_query_argument_generates_client_with_correct_meth ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), async_client=True, ) query_str = """ @@ -449,6 +459,7 @@ def test_generate_creates_client_file_with_gql_lambda_definition( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), ) generator.generate() @@ -487,6 +498,7 @@ def test_generate_adds_comment_to_generated_files( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), comments_strategy=strategy, ) query_str = """ @@ -536,6 +548,7 @@ def test_generate_adds_comment_with_correct_source_to_generated_files( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), comments_strategy=strategy, schema_source=schema_source, queries_source=queries_source, @@ -589,6 +602,7 @@ def test_generate_calls_get_file_comment_hook_for_every_file( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), comments_strategy=strategy, plugin_manager=mocked_plugin_manager, ) @@ -648,6 +662,9 @@ class CustomQueryQuery1(TestFragment): ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator( + schema=schema, fragments_definitions={"TestFragment": fragment_def} + ), fragments_definitions={"TestFragment": fragment_def}, ) @@ -662,6 +679,11 @@ class CustomQueryQuery1(TestFragment): def test_generate_returns_list_of_generated_files(tmp_path, async_base_client_import): schema = build_ast_schema(parse(SCHEMA_STR)) + fragments_definitions = { + "TestFragment": parse("fragment TestFragment on CustomType { id }").definitions[ + 0 + ] + } generator = PackageGenerator( package_name="test_graphql_client", target_path=tmp_path.as_posix(), @@ -673,11 +695,10 @@ def test_generate_returns_list_of_generated_files(tmp_path, async_base_client_im ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), - fragments_definitions={ - "TestFragment": parse( - "fragment TestFragment on CustomType { id }" - ).definitions[0] - }, + fragments_generator=FragmentsGenerator( + schema=schema, fragments_definitions=fragments_definitions + ), + fragments_definitions=fragments_definitions, custom_scalars={"SCALARABC": ScalarData(type_="str", graphql_name="SCALARABC")}, ) query_str = """ @@ -729,6 +750,7 @@ def test_generate_copies_files_to_include(tmp_path, async_base_client_import): ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), files_to_include=[file1.as_posix(), file2.as_posix()], ) generated_files = generator.generate() @@ -765,6 +787,7 @@ def test_generate_creates_client_with_custom_scalars_imports( input_types_generator=InputTypesGenerator( schema=schema, custom_scalars=custom_scalars ), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), custom_scalars=custom_scalars, ) query_str = """ @@ -801,6 +824,7 @@ def test_generate_triggers_generate_client_code_hook( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ).generate() @@ -823,6 +847,7 @@ def test_generate_triggers_generate_enums_code_hook( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ).generate() @@ -845,6 +870,7 @@ def test_generate_triggers_generate_inputs_code_hook( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ).generate() @@ -866,6 +892,7 @@ def test_generate_triggers_generate_result_types_code_hook_for_every_added_opera ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ) generator.add_operation(parse("query A { query2 { id } }").definitions[0]) @@ -892,6 +919,7 @@ def test_generate_triggers_copy_code_hook_for_every_attached_dependency_file( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ).generate() @@ -916,6 +944,7 @@ def test_generate_triggers_copy_code_hook_for_every_file_to_include( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, files_to_include=[test_file_path.as_posix()], ).generate() @@ -939,6 +968,7 @@ def test_generate_triggers_generate_init_code_hook( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ).generate() @@ -968,6 +998,7 @@ def test_add_operation_triggers_process_name_hook( ), enums_generator=EnumsGenerator(schema=schema), input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), plugin_manager=mocked_plugin_manager, ).add_operation(parse(query_str).definitions[0]) From bd115dfdb8d12f129e0376ce1714eb81b36f9183 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Tue, 10 Oct 2023 17:57:45 +0200 Subject: [PATCH 4/5] Divide package generator's tests into more files --- .../package_generator/__init__.py | 0 .../package_generator/conftest.py | 44 +++ .../test_generated_files.py} | 292 ++---------------- .../package_generator/test_plugin_hooks.py | 199 ++++++++++++ 4 files changed, 267 insertions(+), 268 deletions(-) create mode 100644 tests/client_generators/package_generator/__init__.py create mode 100644 tests/client_generators/package_generator/conftest.py rename tests/client_generators/{test_package_generator.py => package_generator/test_generated_files.py} (72%) create mode 100644 tests/client_generators/package_generator/test_plugin_hooks.py diff --git a/tests/client_generators/package_generator/__init__.py b/tests/client_generators/package_generator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client_generators/package_generator/conftest.py b/tests/client_generators/package_generator/conftest.py new file mode 100644 index 00000000..36f74191 --- /dev/null +++ b/tests/client_generators/package_generator/conftest.py @@ -0,0 +1,44 @@ +import pytest +from graphql import build_ast_schema, parse + + +@pytest.fixture +def schema_str(): + return """ + schema { + query: Query + } + + type Query { + query1(id: ID!, param: SCALARABC): CustomType + query2: [CustomType!] + query3(val: CustomEnum!): [CustomType] + } + + type CustomType { + id: ID! + field1: [String] + field2: CustomType2 + field3: CustomEnum! + } + + type CustomType2 { + fieldb: Int + } + + enum CustomEnum { + VAL1 + VAL2 + } + + input CustomInput { + value: Int! + } + + scalar SCALARABC + """ + + +@pytest.fixture +def schema(schema_str): + return build_ast_schema(parse(schema_str)) diff --git a/tests/client_generators/test_package_generator.py b/tests/client_generators/package_generator/test_generated_files.py similarity index 72% rename from tests/client_generators/test_package_generator.py rename to tests/client_generators/package_generator/test_generated_files.py index 63310665..a2ad4417 100644 --- a/tests/client_generators/test_package_generator.py +++ b/tests/client_generators/package_generator/test_generated_files.py @@ -3,7 +3,7 @@ import pytest from freezegun import freeze_time -from graphql import GraphQLSchema, build_ast_schema, parse +from graphql import GraphQLSchema, parse from ariadne_codegen.client_generators.arguments import ArgumentsGenerator from ariadne_codegen.client_generators.client import ClientGenerator @@ -22,46 +22,13 @@ from ariadne_codegen.exceptions import ParsingError from ariadne_codegen.settings import CommentsStrategy -from ..utils import get_class_def +from ...utils import get_class_def -SCHEMA_STR = """ -schema { - query: Query -} -type Query { - query1(id: ID!, param: SCALARABC): CustomType - query2: [CustomType!] - query3(val: CustomEnum!): [CustomType] -} - -type CustomType { - id: ID! - field1: [String] - field2: CustomType2 - field3: CustomEnum! -} - -type CustomType2 { - fieldb: Int -} - -enum CustomEnum { - VAL1 - VAL2 -} - -input CustomInput { - value: Int! -} - -scalar SCALARABC -""" - - -def test_generate_creates_directory_and_files(tmp_path, async_base_client_import): +def test_generate_creates_directory_and_files( + tmp_path, schema, async_base_client_import +): package_name = "test_graphql_client" - schema = GraphQLSchema() generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -152,9 +119,8 @@ def test_generate_creates_files_with_correct_imports( assert "from .async_base_client import AsyncBaseClient" in client_content -def test_generate_creates_files_with_types(tmp_path, async_base_client_import): +def test_generate_creates_files_with_types(tmp_path, schema, async_base_client_import): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -193,9 +159,10 @@ class CustomEnum(str, Enum): assert dedent(expected_enums) in enums_content -def test_generate_creates_file_with_query_types(tmp_path, async_base_client_import): +def test_generate_creates_file_with_query_types( + tmp_path, schema, async_base_client_import +): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -249,10 +216,9 @@ class CustomQueryQuery1Field2(BaseModel): def test_generate_creates_multiple_query_types_files( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -293,7 +259,7 @@ def test_generate_creates_multiple_query_types_files( assert query2_file_path.is_file() -def test_generate_copies_base_client_file(tmp_path, async_base_client_import): +def test_generate_copies_base_client_file(tmp_path, schema, async_base_client_import): base_client_file_content = """ class TestBaseClient: pass @@ -301,7 +267,6 @@ class TestBaseClient: package_name = "test_graphql_client" base_client_file_path = tmp_path / "test_base_client.py" base_client_file_path.write_text(dedent(base_client_file_content)) - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -329,10 +294,9 @@ class TestBaseClient: def test_generate_creates_client_with_valid_method_names( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -372,9 +336,8 @@ def test_generate_creates_client_with_valid_method_names( def test_generate_with_conflicting_query_name_raises_parsing_error( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name="test_graphql_client", target_path=tmp_path.as_posix(), @@ -404,10 +367,9 @@ def test_generate_with_conflicting_query_name_raises_parsing_error( def test_generate_with_enum_as_query_argument_generates_client_with_correct_method( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -444,10 +406,9 @@ def test_generate_with_enum_as_query_argument_generates_client_with_correct_meth def test_generate_creates_client_file_with_gql_lambda_definition( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -483,10 +444,9 @@ def test_generate_creates_client_file_with_gql_lambda_definition( ) @freeze_time("01.01.2022 12:00") def test_generate_adds_comment_to_generated_files( - tmp_path, strategy, expected_comment, async_base_client_import + tmp_path, schema, strategy, expected_comment, async_base_client_import ): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -531,12 +491,11 @@ def test_generate_adds_comment_to_generated_files( "strategy", [CommentsStrategy.STABLE, CommentsStrategy.TIMESTAMP] ) def test_generate_adds_comment_with_correct_source_to_generated_files( - tmp_path, async_base_client_import, strategy + tmp_path, schema, async_base_client_import, strategy ): package_name = "test_graphql_client" schema_source = "schema_source.graphql" queries_source = "queries_source.graphql" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -587,10 +546,9 @@ def test_generate_adds_comment_with_correct_source_to_generated_files( [CommentsStrategy.NONE, CommentsStrategy.STABLE, CommentsStrategy.TIMESTAMP], ) def test_generate_calls_get_file_comment_hook_for_every_file( - tmp_path, async_base_client_import, strategy, mocked_plugin_manager + tmp_path, schema, async_base_client_import, strategy, mocked_plugin_manager ): package_name = "test_graphql_client" - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -623,7 +581,7 @@ def test_generate_calls_get_file_comment_hook_for_every_file( def test_generate_creates_result_types_from_operation_that_uses_fragment( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): package_name = "test_graphql_client" query_str = """ @@ -650,7 +608,6 @@ class CustomQueryQuery1(TestFragment): field_3: CustomEnum = Field(alias="field3") """ query_def, fragment_def = parse(query_str).definitions - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -677,8 +634,9 @@ class CustomQueryQuery1(TestFragment): assert dedent(expected_types) in result_types_content -def test_generate_returns_list_of_generated_files(tmp_path, async_base_client_import): - schema = build_ast_schema(parse(SCHEMA_STR)) +def test_generate_returns_list_of_generated_files( + tmp_path, schema, async_base_client_import +): fragments_definitions = { "TestFragment": parse("fragment TestFragment on CustomType { id }").definitions[ 0 @@ -727,7 +685,7 @@ def test_generate_returns_list_of_generated_files(tmp_path, async_base_client_im ) -def test_generate_copies_files_to_include(tmp_path, async_base_client_import): +def test_generate_copies_files_to_include(tmp_path, schema, async_base_client_import): file1 = tmp_path / "file1.py" file1_content = "class TestBaseClass:\n pass" file1.write_text(file1_content) @@ -738,7 +696,6 @@ def test_generate_copies_files_to_include(tmp_path, async_base_client_import): file2_content = "class TestBaseClass2:\n pass" file2.write_text(file2_content) - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name="test_graphql_client", target_path=tmp_path.as_posix(), @@ -764,13 +721,12 @@ def test_generate_copies_files_to_include(tmp_path, async_base_client_import): def test_generate_creates_client_with_custom_scalars_imports( - tmp_path, async_base_client_import + tmp_path, schema, async_base_client_import ): package_name = "test_graphql_client" custom_scalars = { "SCALARABC": ScalarData(type_=".abc.ScalarABC", graphql_name="SCALARABC") } - schema = build_ast_schema(parse(SCHEMA_STR)) generator = PackageGenerator( package_name=package_name, target_path=tmp_path.as_posix(), @@ -806,203 +762,3 @@ def test_generate_creates_client_with_custom_scalars_imports( f"{generator.client_file_name}.py" ).open() as client_file: assert "from .abc import ScalarABC" in client_file.read() - - -def test_generate_triggers_generate_client_code_hook( - mocked_plugin_manager, tmp_path, async_base_client_import -): - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ).generate() - - assert mocked_plugin_manager.generate_client_code.called - - -def test_generate_triggers_generate_enums_code_hook( - mocked_plugin_manager, tmp_path, async_base_client_import -): - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ).generate() - - assert mocked_plugin_manager.generate_enums_code.called - - -def test_generate_triggers_generate_inputs_code_hook( - mocked_plugin_manager, tmp_path, async_base_client_import -): - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ).generate() - - assert mocked_plugin_manager.generate_inputs_code.called - - -def test_generate_triggers_generate_result_types_code_hook_for_every_added_operation( - mocked_plugin_manager, tmp_path, async_base_client_import -): - schema = build_ast_schema(parse(SCHEMA_STR)) - generator = PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ) - generator.add_operation(parse("query A { query2 { id } }").definitions[0]) - generator.add_operation(parse("query B { query2 { id } }").definitions[0]) - - generator.generate() - - assert mocked_plugin_manager.generate_result_types_code.call_count == 2 - - -def test_generate_triggers_copy_code_hook_for_every_attached_dependency_file( - mocked_plugin_manager, tmp_path, async_base_client_import -): - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ).generate() - - assert mocked_plugin_manager.copy_code.call_count == 3 - - -def test_generate_triggers_copy_code_hook_for_every_file_to_include( - mocked_plugin_manager, tmp_path, async_base_client_import -): - test_file_path = tmp_path / "xyz.py" - test_file_path.touch() - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - files_to_include=[test_file_path.as_posix()], - ).generate() - - assert mocked_plugin_manager.copy_code.call_count == 4 - - -def test_generate_triggers_generate_init_code_hook( - mocked_plugin_manager, tmp_path, async_base_client_import -): - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ).generate() - - assert mocked_plugin_manager.generate_init_code.called - - -def test_add_operation_triggers_process_name_hook( - mocked_plugin_manager, tmp_path, async_base_client_import -): - query_str = """ - query custom_query_name { - query2 { - id - } - } - """ - schema = build_ast_schema(parse(SCHEMA_STR)) - - PackageGenerator( - package_name="test_graphql_client", - target_path=tmp_path.as_posix(), - schema=schema, - init_generator=InitFileGenerator(), - client_generator=ClientGenerator( - base_client_import=async_base_client_import, - arguments_generator=ArgumentsGenerator(schema=schema), - ), - enums_generator=EnumsGenerator(schema=schema), - input_types_generator=InputTypesGenerator(schema=schema), - fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), - plugin_manager=mocked_plugin_manager, - ).add_operation(parse(query_str).definitions[0]) - - assert mocked_plugin_manager.process_name.called - assert "custom_query_name" in { - c.args[0] for c in mocked_plugin_manager.process_name.mock_calls - } diff --git a/tests/client_generators/package_generator/test_plugin_hooks.py b/tests/client_generators/package_generator/test_plugin_hooks.py new file mode 100644 index 00000000..022f10dd --- /dev/null +++ b/tests/client_generators/package_generator/test_plugin_hooks.py @@ -0,0 +1,199 @@ +from graphql import parse + +from ariadne_codegen.client_generators.arguments import ArgumentsGenerator +from ariadne_codegen.client_generators.client import ClientGenerator +from ariadne_codegen.client_generators.enums import EnumsGenerator +from ariadne_codegen.client_generators.fragments import FragmentsGenerator +from ariadne_codegen.client_generators.init_file import InitFileGenerator +from ariadne_codegen.client_generators.input_types import InputTypesGenerator +from ariadne_codegen.client_generators.package import PackageGenerator + + +def test_generate_triggers_generate_client_code_hook( + tmp_path, + schema, + async_base_client_import, + mocked_plugin_manager, +): + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ).generate() + + assert mocked_plugin_manager.generate_client_code.called + + +def test_generate_triggers_generate_enums_code_hook( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ).generate() + + assert mocked_plugin_manager.generate_enums_code.called + + +def test_generate_triggers_generate_inputs_code_hook( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ).generate() + + assert mocked_plugin_manager.generate_inputs_code.called + + +def test_generate_triggers_generate_result_types_code_hook_for_every_added_operation( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + generator = PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ) + generator.add_operation(parse("query A { query2 { id } }").definitions[0]) + generator.add_operation(parse("query B { query2 { id } }").definitions[0]) + + generator.generate() + + assert mocked_plugin_manager.generate_result_types_code.call_count == 2 + + +def test_generate_triggers_copy_code_hook_for_every_attached_dependency_file( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ).generate() + + assert mocked_plugin_manager.copy_code.call_count == 3 + + +def test_generate_triggers_copy_code_hook_for_every_file_to_include( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + test_file_path = tmp_path / "xyz.py" + test_file_path.touch() + + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + files_to_include=[test_file_path.as_posix()], + ).generate() + + assert mocked_plugin_manager.copy_code.call_count == 4 + + +def test_generate_triggers_generate_init_code_hook( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ).generate() + + assert mocked_plugin_manager.generate_init_code.called + + +def test_add_operation_triggers_process_name_hook( + tmp_path, schema, async_base_client_import, mocked_plugin_manager +): + query_str = """ + query custom_query_name { + query2 { + id + } + } + """ + + PackageGenerator( + package_name="test_graphql_client", + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + plugin_manager=mocked_plugin_manager, + ).add_operation(parse(query_str).definitions[0]) + + assert mocked_plugin_manager.process_name.called + assert "custom_query_name" in { + c.args[0] for c in mocked_plugin_manager.process_name.mock_calls + } From 20835bfe7a4f37511094d1ab02610714a8579c16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 11 Oct 2023 12:10:00 +0200 Subject: [PATCH 5/5] Update CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10090d88..6c09c6bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ - Digits in Python names are now preceded by an underscore (breaking change). - Fixed parsing of unions and interfaces to always add `__typename` to generated result models. - Added escaping of enum values which are Python keywords by appending `_` to them. -- Fixed `enums_module_name` option. +- Fixed `enums_module_name` option not being passed to generators. ## 0.9.0 (2023-09-11)