Skip to content

Commit

Permalink
feat(routeset): Added schema content creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Achronus committed Sep 27, 2024
1 parent 988e09c commit cd7807f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
74 changes: 59 additions & 15 deletions zentra_api/cli/commands/add.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import textwrap
from typing import Callable
from pathlib import Path

Expand Down Expand Up @@ -162,6 +163,24 @@ def __init__(self, name: Name, root: Path, option: RouteOptions) -> None:
self.schema_content = None
self.response_content = None

def _build_base_schema_models(self) -> str:
"""Creates the content for the main schema models to put in the 'schema.py' file."""
name_title = self.name.singular.title()
name_lower = self.name.singular

return textwrap.dedent(f"""
class {name_title}Base(BaseModel):
pass
class {name_title}(BaseModel):
pass
class {name_title}ID(BaseModel):
id: int = Field(..., description="The ID of the {name_lower}.")
""").lstrip("\n")

def _get_routes(self) -> list[Route]:
"""Retrieves the routes from the route map."""
routes = []
Expand All @@ -175,30 +194,38 @@ def _get_routes(self) -> list[Route]:

def _create_init_content(self, routes: list[Route]) -> None:
"""Creates the '__init__.py' file content."""
add_auth = any([route.auth for route in routes])
response_models = [
route.response_model
for route in routes
if route.response_model not in ROUTE_RESPONSE_MODEL_BLACKLIST
]
schema_models = [route.schema_model for route in routes if route.schema_model]
add_auth = any([route.auth for route in routes])

folder_imports = [
if "c" in self.option or "u" in self.option:
schema_models.append(f"{self.name.singular.title()}ID")

local_file_imports = [
Import(
root=".",
modules=[RouteFile.RESPONSES.value.split(".")[0]],
items=response_models,
add_dot=False,
),
Import(
root=".",
modules=[RouteFile.SCHEMA.value.split(".")[0]],
items=schema_models,
add_dot=False,
),
)
]

if schema_models:
local_file_imports.append(
Import(
root=".",
modules=[RouteFile.SCHEMA.value.split(".")[0]],
items=schema_models,
add_dot=False,
),
)

file_imports: list[list[Import]] = route_imports(add_auth=add_auth)
file_imports.insert(1, folder_imports)
file_imports.insert(1, local_file_imports)
file_imports = Imports(items=file_imports).to_str()

self.init_content = "\n".join(
Expand All @@ -213,7 +240,23 @@ def _create_init_content(self, routes: list[Route]) -> None:

def _create_schema_content(self, routes: list[Route]) -> None:
"""Creates the 'schema.py' file content."""
pass
file_imports = [
[Import(root="pydantic", items=["BaseModel", "Field"], add_dot=False)]
]
route_schema_models = [
route.schema_model_content()
for route in routes
if route.method not in [RouteMethods.GET, RouteMethods.DELETE]
]
self.schema_content = "\n".join(
[
Imports(items=file_imports).to_str(),
"",
self._build_base_schema_models()
+ "\n\n"
+ "\n\n".join(route_schema_models),
]
)

def _create_responses_content(self, routes: list[Route]) -> None:
"""Creates the 'responses.py' file content."""
Expand All @@ -228,9 +271,10 @@ def _create_responses_content(self, routes: list[Route]) -> None:
file_imports = [
[
Import(
root="app",
modules=["api", self.name.plural, "schema"],
root=".",
modules=["schema"],
items=schema_models,
add_dot=False,
)
],
[
Expand All @@ -256,7 +300,7 @@ def _create_responses_content(self, routes: list[Route]) -> None:
def _update_files(self) -> None:
"""Updates the '__init__.py', 'schema.py', and 'responses.py' files."""
self.asset_paths.init_file.write_text(self.init_content)
# self.asset_paths.schema_file.write_text(self.schema_content)
self.asset_paths.schema_file.write_text(self.schema_content)
self.asset_paths.responses_file.write_text(self.response_content)

def _create_route_files(self) -> None:
Expand All @@ -275,7 +319,7 @@ def get_tasks_for_set(self) -> list[Callable]:

routes = self._get_routes()
self._create_init_content(routes)
# self._create_schema_content(routes)
self._create_schema_content(routes)
self._create_responses_content(routes)

tasks.extend([self._update_files])
Expand Down
19 changes: 12 additions & 7 deletions zentra_api/cli/constants/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Route(BaseModel):
def model_post_init(self, __context: Any) -> None:
self._func_name = f"{self.func_name_start()}_{self.name.lower()}"
self._response_model = self.set_response_model()
self._schema_model = self.set_schema_model()
self._schema_model = self.set_schema_model_name()

details = RouteDefaultDetails(
method=self.method,
Expand Down Expand Up @@ -207,7 +207,7 @@ def set_response_model(self) -> str:
name = self.name.title()
return f"{method}{name}Response"

def set_schema_model(self) -> str | None:
def set_schema_model_name(self) -> str | None:
"""Creates the schema model (parameter) name."""
if self.method == RouteMethods.GET or self.method == RouteMethods.DELETE:
return None
Expand Down Expand Up @@ -247,9 +247,7 @@ def to_str(self, name: Name) -> str:
return "\n".join(text).rstrip()

def response_model_class(self, name: Name) -> str:
"""
Creates the route response model class.
"""
"""Creates the route response model class."""

def data_type() -> str:
"""A helper method for creating the response data type T."""
Expand All @@ -269,6 +267,13 @@ class {self.response_model}(SuccessResponse[{data_type()}]):
pass
''').strip("\n")

def schema_model_content(self) -> str:
"""Creates the schema model content."""
return textwrap.dedent(f"""
class {self.schema_model}(BaseModel):
pass
""").lstrip("\n")


def route_dict_set(name: Name) -> dict[str, Route]:
"""
Expand Down Expand Up @@ -383,7 +388,7 @@ def db_get_method(param: str | None = None) -> str:
{out_name} = CONNECT.{name.plural}.create(db, {name.singular}.model_dump())
return {response_model}(
code=status.HTTP_201_CREATED,
data={name.singular}.model_dump(),
data={name.singular.title()}ID(id={out_name}.id).model_dump(),
)
""")
elif method == RouteMethods.PATCH or method == RouteMethods.PUT:
Expand All @@ -398,7 +403,7 @@ def db_get_method(param: str | None = None) -> str:
{out_name} = CONNECT.{name.plural}.{db_get_method()}
return {response_model}(
code=status.HTTP_202_ACCEPTED,
data={out_name}.model_dump(),
data={name.singular.title()}ID(id=id).model_dump(),
)
""")
elif method == RouteMethods.DELETE:
Expand Down

0 comments on commit cd7807f

Please sign in to comment.