Skip to content

Commit

Permalink
Merge pull request #229 from mirumee/enums_custom_name
Browse files Browse the repository at this point in the history
Fix `enums_module_name` option
  • Loading branch information
mat-sop authored Oct 11, 2023
2 parents 878b4ff + 20835bf commit a7b3c19
Show file tree
Hide file tree
Showing 26 changed files with 782 additions and 676 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Digits in Python names are now preceded by an underscore (breaking change).
- Fixed parsing of unions and interfaces to always add `__typename` to generated result models.
- Added escaping of enum values which are Python keywords by appending `_` to them.
- Fixed `enums_module_name` option not being passed to generators.


## 0.9.0 (2023-09-11)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Optional settings:
- `client_name` (defaults to `"Client"`) - name of generated client class
- `client_file_name` (defaults to `"client"`) - name of file with generated client class
- `base_client_name` (defaults to `"AsyncBaseClient"`) - name of base client class
- `base_client_file_path` (defaults to `.../graphql_sdk_gen/generators/async_base_client.py`) - path to file where `base_client_name` is defined
- `base_client_file_path` (defaults to `.../ariadne_codegen/client_generators/dependencies/async_base_client.py`) - path to file where `base_client_name` is defined
- `enums_module_name` (defaults to `"enums"`) - name of file with generated enums models
- `input_types_module_name` (defaults to `"input_types"`) - name of file with generated input types models
- `fragments_module_name` (defaults to `"fragments"`) - name of file with generated fragments models
Expand Down
16 changes: 9 additions & 7 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,23 @@
OPTIONAL,
TYPING_MODULE,
UNION,
UNSET_IMPORT,
UPLOAD_IMPORT,
)
from .scalars import ScalarData, generate_scalar_imports


class ClientGenerator:
def __init__(
self,
name: str,
base_client: str,
enums_module_name: str,
input_types_module_name: str,
arguments_generator: ArgumentsGenerator,
base_client_import: ast.ImportFrom,
unset_import: ast.ImportFrom,
upload_import: ast.ImportFrom,
arguments_generator: ArgumentsGenerator,
name: str = "Client",
base_client: str = "AsyncBaseClient",
enums_module_name: str = "enums",
input_types_module_name: str = "input_types",
unset_import: ast.ImportFrom = UNSET_IMPORT,
upload_import: ast.ImportFrom = UPLOAD_IMPORT,
custom_scalars: Optional[Dict[str, ScalarData]] = None,
plugin_manager: Optional[PluginManager] = None,
) -> None:
Expand Down
20 changes: 17 additions & 3 deletions ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
from pathlib import Path

SIMPLE_TYPE_MAP = {
Expand Down Expand Up @@ -26,7 +27,23 @@
SOURCE_COMMENT = "# Source: {}"
COMMENT_DATETIME_FORMAT = "%Y-%m-%d %H:%M"

BASE_MODEL_FILE_PATH = Path(__file__).parent / "dependencies" / "base_model.py"
BASE_MODEL_CLASS_NAME = "BaseModel"
BASE_MODEL_IMPORT = ast.ImportFrom(
module=BASE_MODEL_FILE_PATH.stem, names=[ast.alias(BASE_MODEL_CLASS_NAME)], level=1
)
UPLOAD_IMPORT = ast.ImportFrom(
module=BASE_MODEL_FILE_PATH.stem, names=[ast.alias(UPLOAD_CLASS_NAME)], level=1
)
UNSET_NAME = "UNSET"
UNSET_TYPE_NAME = "UnsetType"
UNSET_IMPORT = ast.ImportFrom(
module=BASE_MODEL_FILE_PATH.stem,
names=[ast.alias(UNSET_NAME), ast.alias(UNSET_TYPE_NAME)],
level=1,
)

EXCEPTIONS_FILE_PATH = Path(__file__).parent / "dependencies" / "exceptions.py"

TYPENAME_FIELD_NAME = "__typename"
TYPENAME_ALIAS = "typename__"
Expand Down Expand Up @@ -66,6 +83,3 @@

SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS"
SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS"

UNSET_NAME = "UNSET"
UNSET_TYPE_NAME = "UnsetType"
16 changes: 7 additions & 9 deletions ariadne_codegen/client_generators/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..codegen import generate_expr, generate_method_call, generate_module
from ..plugins.manager import PluginManager
from .constants import MODEL_REBUILD_METHOD
from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD
from .result_types import ResultTypesGenerator
from .scalars import ScalarData

Expand All @@ -14,33 +14,31 @@ class FragmentsGenerator:
def __init__(
self,
schema: GraphQLSchema,
enums_module_name: str,
fragments_definitions: Dict[str, FragmentDefinitionNode],
exclude_names: Optional[Set[str]] = None,
base_model_import: Optional[ast.ImportFrom] = None,
enums_module_name: str = "enums",
base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT,
convert_to_snake_case: bool = True,
custom_scalars: Optional[Dict[str, ScalarData]] = None,
plugin_manager: Optional[PluginManager] = None,
) -> None:
self.schema = schema
self.enums_module_name = enums_module_name
self.exclude_names = exclude_names or set()
self.fragments_definitions = fragments_definitions
self.base_model_import = base_model_import
self.convert_to_snake_case = convert_to_snake_case
self.custom_scalars = custom_scalars
self.plugin_manager = plugin_manager

self._fragments_names = (
set(self.fragments_definitions.keys()) - self.exclude_names
)
self._fragments_names = set(self.fragments_definitions.keys())
self._generated_public_names: List[str] = []

def generate(self) -> ast.Module:
def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict: Dict[str, List[ast.ClassDef]] = {}
imports: List[ast.ImportFrom] = []
dependencies_dict: Dict[str, Set[str]] = {}

names_to_exclude = exclude_names or set()
self._fragments_names = self._fragments_names - names_to_exclude
for name in self._fragments_names:
fragmanet_def = self.fragments_definitions[name]
generator = ResultTypesGenerator(
Expand Down
8 changes: 5 additions & 3 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ANNOTATED,
ANY,
BASE_MODEL_CLASS_NAME,
BASE_MODEL_IMPORT,
FIELD_CLASS,
LIST,
MODEL_REBUILD_METHOD,
Expand All @@ -35,6 +36,7 @@
PYDANTIC_MODULE,
TYPING_MODULE,
UNION,
UPLOAD_IMPORT,
)
from .input_fields import parse_input_field_default_value, parse_input_field_type
from .scalars import ScalarData, generate_scalar_imports
Expand All @@ -44,9 +46,9 @@ class InputTypesGenerator:
def __init__(
self,
schema: GraphQLSchema,
enums_module: str,
base_model_import: ast.ImportFrom,
upload_import: ast.ImportFrom,
enums_module: str = "enums",
base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT,
upload_import: ast.ImportFrom = UPLOAD_IMPORT,
convert_to_snake_case: bool = True,
custom_scalars: Optional[Dict[str, ScalarData]] = None,
plugin_manager: Optional[PluginManager] = None,
Expand Down
Loading

0 comments on commit a7b3c19

Please sign in to comment.