Commit

Migration complete
goFrendiAsgard committed Nov 20, 2024
1 parent 39c37d8 commit 4c48536
Showing 23 changed files with 714 additions and 130 deletions.
@@ -0,0 +1,3 @@
__pycache__
.venv
*.db
@@ -0,0 +1,7 @@
# App Name

# Principle

- Developers should be able to override everything with a sane amount of code.
- Simple tasks should only require a small amount of code.
- A bit of magic is okay as long as it is transparent and documented.
@@ -1,5 +1,6 @@
from typing import Any, Callable, Generic, Type, TypeVar

from common.error import NotFoundError
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlmodel import Session, SQLModel, select
@@ -15,6 +16,7 @@ class BaseDBRepository(Generic[DBModel, ResponseModel, CreateModel, UpdateModel]
response_model: Type[ResponseModel]
create_model: Type[CreateModel]
update_model: Type[UpdateModel]
entity_name: str = "entity"
column_preprocessors: dict[str, Callable[[Any], Any]] = {}

def __init__(self, engine: Engine | AsyncEngine):
@@ -30,7 +32,6 @@ async def create(self, data: CreateModel) -> ResponseModel:
if key in data_dict:
data_dict[key] = preprocessor(data_dict[key])
db_instance = self.db_model(**data_dict)

if self.is_async:
async with AsyncSession(self.engine) as session:
session.add(db_instance)
@@ -41,44 +42,41 @@ async def create(self, data: CreateModel) -> ResponseModel:
session.add(db_instance)
session.commit()
session.refresh(db_instance)

return self._to_response(db_instance)

async def get_by_id(self, item_id: str) -> ResponseModel | None:
async def get_by_id(self, item_id: str) -> ResponseModel:
if self.is_async:
async with AsyncSession(self.engine) as session:
db_instance = await session.get(self.db_model, item_id)
else:
with Session(self.engine) as session:
db_instance = session.get(self.db_model, item_id)

return self._to_response(db_instance) if db_instance else None
if not db_instance:
raise NotFoundError(f"{self.entity_name} not found")
return self._to_response(db_instance)

async def get_all(self, page: int = 1, page_size: int = 10) -> list[ResponseModel]:
offset = (page - 1) * page_size
statement = select(self.db_model).offset(offset).limit(page_size)

if self.is_async:
async with AsyncSession(self.engine) as session:
result = await session.execute(statement)
results = result.scalars().all()
else:
with Session(self.engine) as session:
results = session.exec(statement).all()

return [self._to_response(instance) for instance in results]

async def update(self, item_id: str, data: UpdateModel) -> ResponseModel | None:
async def update(self, item_id: str, data: UpdateModel) -> ResponseModel:
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
if key in self.column_preprocessors:
update_data[key] = self.column_preprocessors[key](value)

if self.is_async:
async with AsyncSession(self.engine) as session:
db_instance = await session.get(self.db_model, item_id)
if not db_instance:
return None
raise NotFoundError(f"{self.entity_name} not found")
for key, value in update_data.items():
setattr(db_instance, key, value)
session.add(db_instance)
@@ -88,32 +86,30 @@ async def update(self, item_id: str, data: UpdateModel) -> ResponseModel | None:
with Session(self.engine) as session:
db_instance = session.get(self.db_model, item_id)
if not db_instance:
return None
raise NotFoundError(f"{self.entity_name} not found")
for key, value in update_data.items():
setattr(db_instance, key, value)
session.add(db_instance)
session.commit()
session.refresh(db_instance)

return self._to_response(db_instance)

async def delete(self, item_id: str) -> bool:
async def delete(self, item_id: str) -> ResponseModel:
if self.is_async:
async with AsyncSession(self.engine) as session:
db_instance = await session.get(self.db_model, item_id)
if not db_instance:
return False
raise NotFoundError(f"{self.entity_name} not found")
await session.delete(db_instance)
await session.commit()
else:
with Session(self.engine) as session:
db_instance = session.get(self.db_model, item_id)
if not db_instance:
return False
raise NotFoundError(f"{self.entity_name} not found")
session.delete(db_instance)
session.commit()

return True
return self._to_response(db_instance)

async def create_bulk(self, data_list: list[CreateModel]) -> list[ResponseModel]:
db_instances = []
@@ -123,7 +119,6 @@ async def create_bulk(self, data_list: list[CreateModel]) -> list[ResponseModel]
if key in data_dict:
data_dict[key] = preprocessor(data_dict[key])
db_instances.append(self.db_model(**data_dict))

if self.is_async:
async with AsyncSession(self.engine) as session:
session.add_all(db_instances)
@@ -136,5 +131,4 @@ async def create_bulk(self, data_list: list[CreateModel]) -> list[ResponseModel]
session.commit()
for instance in db_instances:
session.refresh(instance)

return [self._to_response(instance) for instance in db_instances]
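The repository methods now raise NotFoundError instead of returning None or False, and entity_name lets the error message name the missing entity. A minimal sketch of a concrete subclass follows (the User* models and the password preprocessor are hypothetical and not part of this commit):

import hashlib

from sqlmodel import Field, SQLModel

class User(SQLModel, table=True):  # hypothetical table model
    id: str = Field(primary_key=True)
    username: str
    password: str

class UserResponse(SQLModel):  # hypothetical response schema
    id: str
    username: str

class UserCreate(SQLModel):  # hypothetical create schema
    username: str
    password: str

class UserUpdate(SQLModel):  # hypothetical update schema
    username: str | None = None
    password: str | None = None

class UserRepository(BaseDBRepository[User, UserResponse, UserCreate, UserUpdate]):
    db_model = User
    response_model = UserResponse
    create_model = UserCreate
    update_model = UserUpdate
    entity_name = "user"
    # column_preprocessors transforms a value before it is written;
    # here the plain-text password is hashed on create and update.
    column_preprocessors = {
        "password": lambda value: hashlib.sha256(value.encode()).hexdigest()
    }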
@@ -0,0 +1,8 @@
from typing import Dict

from fastapi import HTTPException


class NotFoundError(HTTPException):
def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None:
super().__init__(404, {"message": message}, headers)
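Because NotFoundError subclasses FastAPI's HTTPException with a dict detail, raising it from a repository call that reaches a FastAPI route surfaces as a 404 response roughly like this (values are illustrative):

raise NotFoundError("user not found")
# FastAPI's default HTTPException handler renders this as:
#   HTTP 404 with body {"detail": {"message": "user not found"}}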
@@ -10,7 +10,7 @@
from starlette.responses import JSONResponse, Response


class ApiRouteParam:
class RouteParam:
def __init__(
self,
path: str,
@@ -65,7 +65,7 @@ def __init__(


class BaseUsecase:
_methods: dict[str, ApiRouteParam] = {}
_route_params: dict[str, RouteParam] = {}

@classmethod
def route(
@@ -100,9 +100,7 @@ def route(
"""

def decorator(func: Callable):
if not hasattr(cls, "_methods"):
cls._methods = {}
cls._methods[func.__name__] = ApiRouteParam(
cls._route_params[func.__name__] = RouteParam(
path=path,
response_model=response_model,
status_code=status_code,
@@ -141,7 +139,7 @@ def as_direct_client(self):
"""
Dynamically create a direct client class.
"""
_methods = self._methods
_methods = self._route_params
DirectClient = create_client_class("DirectClient")
for name, details in _methods.items():
func = details.func
@@ -153,7 +151,7 @@ def as_api_client(self, base_url: str):
"""
Dynamically create an API client class.
"""
_methods = self._methods
_methods = self._route_params
APIClient = create_client_class("APIClient")
# Dynamically generate methods
for name, param in _methods.items():
@@ -165,7 +163,7 @@ def serve_route(self, app: APIRouter):
"""
Dynamically add routes to FastAPI.
"""
for _, route_param in self._methods.items():
for _, route_param in self._route_params.items():
bound_func = partial(route_param.func, self)
bound_func.__name__ = route_param.func.__name__
bound_func.__doc__ = route_param.func.__doc__
@@ -212,7 +210,7 @@ async def client_method(self, *args, **kwargs):
return client_method


def create_api_client_method(param: ApiRouteParam, base_url: str):
def create_api_client_method(param: RouteParam, base_url: str):
_url = param.path
_methods = [method.lower() for method in param.methods]

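The rename from _methods/ApiRouteParam to _route_params/RouteParam keeps the usage pattern intact. A hedged sketch of how a usecase might register routes and derive clients (the decorator arguments and the UserUsecase, UserResponse and user_repository names are assumptions; only route(), serve_route(), as_direct_client() and as_api_client() appear in this diff):

from fastapi import APIRouter, FastAPI

class UserUsecase(BaseUsecase):  # hypothetical usecase
    @BaseUsecase.route("/users/{user_id}", methods=["GET"], response_model=UserResponse)
    async def get_user(self, user_id: str) -> UserResponse:
        return await user_repository.get_by_id(user_id)  # hypothetical repository instance

app = FastAPI()
router = APIRouter()
usecase = UserUsecase()
usecase.serve_route(router)  # expose the decorated methods as HTTP routes
app.include_router(router)

DirectClient = usecase.as_direct_client()  # in-process client, e.g. for monolith mode
ApiClient = usecase.as_api_client("http://localhost:3001")  # HTTP client, e.g. for microservices mode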
19 changes: 10 additions & 9 deletions src/zrb/builtin/project/add/fastapp/application/app_name/config.py
@@ -1,5 +1,7 @@
import os

APP_PATH = os.path.dirname(__file__)

APP_MODE = os.getenv("APP_NAME_MODE", "monolith")
APP_MODULES = [
module.strip()
@@ -12,14 +14,13 @@
)
APP_REPOSITORY_TYPE = os.getenv("APP_REPOSITORY_TYPE", "db")

_DEFAULT_DB_URL = "sqlite:///monolith.db"
if APP_MODE != "monolith":
_DEFAULT_DB_URL = "sqlite:///microservices.db"
APP_DB_URL = os.getenv("APP_DB_URL", _DEFAULT_DB_URL)

_DEFAULT_MIGRATION_TABLE = "migration_table"
if APP_MODE != "monolith" and len(APP_MODULES) > 0:
_DEFAULT_MIGRATION_TABLE = f"{APP_MODULES[0]}_{_DEFAULT_MIGRATION_TABLE}"
APP_DB_MIGRATION_TABLE = os.getenv("APP_DB_MIGRATION_TABLE", _DEFAULT_MIGRATION_TABLE)
APP_DB_URL = os.getenv(
"APP_DB_URL",
(
f"sqlite:///{APP_PATH}/monolith.db"
if APP_MODE == "monolith" or len(APP_MODULES) == 0
else f"sqlite:///{APP_PATH}/{APP_MODULES[0]}_microservices.db"
),
)

APP_AUTH_BASE_URL = os.getenv("APP_NAME_AUTH_BASE_URL", "http://localhost:3001")
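The default database URL is now derived from APP_PATH and the runtime mode instead of a fixed working-directory filename, and the separate migration-table default was dropped. The resolution behaves roughly like this (illustrative values):

# APP_MODE == "monolith" or APP_MODULES empty        -> sqlite:///<APP_PATH>/monolith.db
# APP_MODE == "microservices", APP_MODULES[0] == "auth"
#                                                     -> sqlite:///<APP_PATH>/auth_microservices.db
# APP_DB_URL set in the environment                   -> that value is used unchanged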
@@ -0,0 +1,117 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = migration

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = ../..

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to migration/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migration/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
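This is the stock Alembic configuration with script_location pointing at the migration package. A hedged sketch of driving it from Python (the project may instead invoke the alembic CLI; this wiring is an assumption, not part of the commit):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # the configuration shown above
cfg.set_main_option("sqlalchemy.url", APP_DB_URL)  # assumed override of the placeholder URL
command.upgrade(cfg, "head")  # apply all pending migrations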
@@ -0,0 +1 @@
Generic single-database configuration.