From cd7807f39f30bf1c82cd4dee95e17530131472fd Mon Sep 17 00:00:00 2001 From: Ryan Partridge Date: Fri, 27 Sep 2024 21:18:23 +0100 Subject: [PATCH] feat(routeset): Added `schema` content creation. --- zentra_api/cli/commands/add.py | 74 ++++++++++++++++++++++++------ zentra_api/cli/constants/routes.py | 19 +++++--- 2 files changed, 71 insertions(+), 22 deletions(-) diff --git a/zentra_api/cli/commands/add.py b/zentra_api/cli/commands/add.py index c641f34..868b490 100644 --- a/zentra_api/cli/commands/add.py +++ b/zentra_api/cli/commands/add.py @@ -1,5 +1,6 @@ import json import os +import textwrap from typing import Callable from pathlib import Path @@ -162,6 +163,24 @@ def __init__(self, name: Name, root: Path, option: RouteOptions) -> None: self.schema_content = None self.response_content = None + def _build_base_schema_models(self) -> str: + """Creates the content for the main schema models to put in the 'schema.py' file.""" + name_title = self.name.singular.title() + name_lower = self.name.singular + + return textwrap.dedent(f""" + class {name_title}Base(BaseModel): + pass + + + class {name_title}(BaseModel): + pass + + + class {name_title}ID(BaseModel): + id: int = Field(..., description="The ID of the {name_lower}.") + """).lstrip("\n") + def _get_routes(self) -> list[Route]: """Retrieves the routes from the route map.""" routes = [] @@ -175,30 +194,38 @@ def _get_routes(self) -> list[Route]: def _create_init_content(self, routes: list[Route]) -> None: """Creates the '__init__.py' file content.""" + add_auth = any([route.auth for route in routes]) response_models = [ route.response_model for route in routes if route.response_model not in ROUTE_RESPONSE_MODEL_BLACKLIST ] schema_models = [route.schema_model for route in routes if route.schema_model] - add_auth = any([route.auth for route in routes]) - folder_imports = [ + if "c" in self.option or "u" in self.option: + schema_models.append(f"{self.name.singular.title()}ID") + + local_file_imports = [ Import( root=".", modules=[RouteFile.RESPONSES.value.split(".")[0]], items=response_models, add_dot=False, - ), - Import( - root=".", - modules=[RouteFile.SCHEMA.value.split(".")[0]], - items=schema_models, - add_dot=False, - ), + ) ] + + if schema_models: + local_file_imports.append( + Import( + root=".", + modules=[RouteFile.SCHEMA.value.split(".")[0]], + items=schema_models, + add_dot=False, + ), + ) + file_imports: list[list[Import]] = route_imports(add_auth=add_auth) - file_imports.insert(1, folder_imports) + file_imports.insert(1, local_file_imports) file_imports = Imports(items=file_imports).to_str() self.init_content = "\n".join( @@ -213,7 +240,23 @@ def _create_init_content(self, routes: list[Route]) -> None: def _create_schema_content(self, routes: list[Route]) -> None: """Creates the 'schema.py' file content.""" - pass + file_imports = [ + [Import(root="pydantic", items=["BaseModel", "Field"], add_dot=False)] + ] + route_schema_models = [ + route.schema_model_content() + for route in routes + if route.method not in [RouteMethods.GET, RouteMethods.DELETE] + ] + self.schema_content = "\n".join( + [ + Imports(items=file_imports).to_str(), + "", + self._build_base_schema_models() + + "\n\n" + + "\n\n".join(route_schema_models), + ] + ) def _create_responses_content(self, routes: list[Route]) -> None: """Creates the 'responses.py' file content.""" @@ -228,9 +271,10 @@ def _create_responses_content(self, routes: list[Route]) -> None: file_imports = [ [ Import( - root="app", - modules=["api", self.name.plural, "schema"], + root=".", + modules=["schema"], items=schema_models, + add_dot=False, ) ], [ @@ -256,7 +300,7 @@ def _create_responses_content(self, routes: list[Route]) -> None: def _update_files(self) -> None: """Updates the '__init__.py', 'schema.py', and 'responses.py' files.""" self.asset_paths.init_file.write_text(self.init_content) - # self.asset_paths.schema_file.write_text(self.schema_content) + self.asset_paths.schema_file.write_text(self.schema_content) self.asset_paths.responses_file.write_text(self.response_content) def _create_route_files(self) -> None: @@ -275,7 +319,7 @@ def get_tasks_for_set(self) -> list[Callable]: routes = self._get_routes() self._create_init_content(routes) - # self._create_schema_content(routes) + self._create_schema_content(routes) self._create_responses_content(routes) tasks.extend([self._update_files]) diff --git a/zentra_api/cli/constants/routes.py b/zentra_api/cli/constants/routes.py index 92be753..603ca8e 100644 --- a/zentra_api/cli/constants/routes.py +++ b/zentra_api/cli/constants/routes.py @@ -139,7 +139,7 @@ class Route(BaseModel): def model_post_init(self, __context: Any) -> None: self._func_name = f"{self.func_name_start()}_{self.name.lower()}" self._response_model = self.set_response_model() - self._schema_model = self.set_schema_model() + self._schema_model = self.set_schema_model_name() details = RouteDefaultDetails( method=self.method, @@ -207,7 +207,7 @@ def set_response_model(self) -> str: name = self.name.title() return f"{method}{name}Response" - def set_schema_model(self) -> str | None: + def set_schema_model_name(self) -> str | None: """Creates the schema model (parameter) name.""" if self.method == RouteMethods.GET or self.method == RouteMethods.DELETE: return None @@ -247,9 +247,7 @@ def to_str(self, name: Name) -> str: return "\n".join(text).rstrip() def response_model_class(self, name: Name) -> str: - """ - Creates the route response model class. - """ + """Creates the route response model class.""" def data_type() -> str: """A helper method for creating the response data type T.""" @@ -269,6 +267,13 @@ class {self.response_model}(SuccessResponse[{data_type()}]): pass ''').strip("\n") + def schema_model_content(self) -> str: + """Creates the schema model content.""" + return textwrap.dedent(f""" + class {self.schema_model}(BaseModel): + pass + """).lstrip("\n") + def route_dict_set(name: Name) -> dict[str, Route]: """ @@ -383,7 +388,7 @@ def db_get_method(param: str | None = None) -> str: {out_name} = CONNECT.{name.plural}.create(db, {name.singular}.model_dump()) return {response_model}( code=status.HTTP_201_CREATED, - data={name.singular}.model_dump(), + data={name.singular.title()}ID(id={out_name}.id).model_dump(), ) """) elif method == RouteMethods.PATCH or method == RouteMethods.PUT: @@ -398,7 +403,7 @@ def db_get_method(param: str | None = None) -> str: {out_name} = CONNECT.{name.plural}.{db_get_method()} return {response_model}( code=status.HTTP_202_ACCEPTED, - data={out_name}.model_dump(), + data={name.singular.title()}ID(id=id).model_dump(), ) """) elif method == RouteMethods.DELETE: