diff --git a/Makefile b/Makefile index dfe4932..83ece77 100644 --- a/Makefile +++ b/Makefile @@ -29,7 +29,7 @@ format: # Run formatters lint: ## Run linters uv run -m ruff format --check uv run -m ruff check - uv run -m mypy app + uv run -m mypy app tests build: ## Build the Docker image docker compose build app diff --git a/alembic/versions/20240909_154014_1c15c81ce8a8_.py b/alembic/versions/20240909_154014_1c15c81ce8a8_.py new file mode 100644 index 0000000..503923a --- /dev/null +++ b/alembic/versions/20240909_154014_1c15c81ce8a8_.py @@ -0,0 +1,50 @@ +"""empty message + +Revision ID: 1c15c81ce8a8 +Revises: b56e853bb310 +Create Date: 2024-09-09 15:40:14.071295 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1c15c81ce8a8" +down_revision: str | None = "b56e853bb310" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "task_registry", + sa.Column("task_name", sa.String(), nullable=False), + sa.Column("last_run", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_duration", sa.Float(), server_default=sa.text("0"), nullable=False), + sa.Column("last_error", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("task_name", name=op.f("pk_task_registry")), + ) + op.drop_column("job", "updated_at") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "job", + sa.Column( + "updated_at", + postgresql.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + autoincrement=False, + nullable=False, + ), + ) + op.drop_table("task_registry") + # ### end Alembic commands ### diff --git a/app/db/model.py b/app/db/model.py index 56006b4..57a9347 100644 --- a/app/db/model.py +++ b/app/db/model.py @@ -61,7 +61,6 @@ class Job(Base): service_type: Mapped[ServiceType] service_subtype: Mapped[ServiceSubtype] created_at: Mapped[CREATED_AT] - updated_at: Mapped[UPDATED_AT] reserved_at: Mapped[datetime | None] started_at: Mapped[datetime | None] last_alive_at: Mapped[datetime | None] @@ -130,6 +129,17 @@ class Price(Base): updated_at: Mapped[UPDATED_AT] +class TaskRegistry(Base): + """TaskRegistry table.""" + + __tablename__ = "task_registry" + + task_name: Mapped[str] = mapped_column(primary_key=True) + last_run: Mapped[datetime | None] + last_duration: Mapped[float] = mapped_column(server_default=text("0")) + last_error: Mapped[str | None] + + Index( "only_one_system_account", Account.account_type, diff --git a/app/db/session.py b/app/db/session.py index 01ccec2..e9d7453 100644 --- a/app/db/session.py +++ b/app/db/session.py @@ -33,8 +33,12 @@ async def close(self) -> None: L.info("DB engine has been closed") @asynccontextmanager - async def session(self) -> AsyncIterator[AsyncSession]: - """Yield a new database session.""" + async def session(self, *, commit: bool = True) -> AsyncIterator[AsyncSession]: + """Yield a new database session. + + Args: + commit: if True and no errors occurred, commit before closing the session. + """ if not self._engine: err = "DB engine not initialized" raise RuntimeError(err) @@ -50,7 +54,8 @@ async def session(self) -> AsyncIterator[AsyncSession]: await session.rollback() raise else: - await session.commit() + if commit: + await session.commit() database_session_manager = DatabaseSessionManager() diff --git a/app/db/utils.py b/app/db/utils.py index e867eb1..2a3e2fa 100644 --- a/app/db/utils.py +++ b/app/db/utils.py @@ -2,7 +2,10 @@ from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager +from datetime import datetime +import sqlalchemy as sa +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession @@ -22,3 +25,21 @@ async def try_nested( else: if on_success: on_success() + + +async def current_timestamp(db: AsyncSession) -> datetime: + """Return the start datetime of the current transaction. + + The returned value does not change during the transaction. + """ + query = sa.select(func.current_timestamp()) + return (await db.execute(query)).scalar_one() + + +async def clock_timestamp(db: AsyncSession) -> datetime: + """Return the actual current time from the database. + + The returned value changes every time the function is called. + """ + query = sa.select(func.clock_timestamp()) + return (await db.execute(query)).scalar_one() diff --git a/app/repository/group.py b/app/repository/group.py index 31c7673..5cae176 100644 --- a/app/repository/group.py +++ b/app/repository/group.py @@ -10,6 +10,7 @@ from app.repository.ledger import LedgerRepository from app.repository.price import PriceRepository from app.repository.report import ReportRepository +from app.repository.task_registry import TaskRegistryRepository class RepositoryGroup: @@ -24,6 +25,7 @@ def __init__( ledger_repo_class: type[LedgerRepository] = LedgerRepository, price_repo_class: type[PriceRepository] = PriceRepository, report_repo_class: type[ReportRepository] = ReportRepository, + task_registry_class: type[TaskRegistryRepository] = TaskRegistryRepository, ) -> None: """Init the repository group.""" self._db = db @@ -33,6 +35,7 @@ def __init__( self._ledger_repo_class = ledger_repo_class self._price_repo_class = price_repo_class self._report_repo_class = report_repo_class + self._task_registry_class = task_registry_class @property def db(self) -> AsyncSession: @@ -68,3 +71,8 @@ def price(self) -> PriceRepository: def report(self) -> ReportRepository: """Return the report repository.""" return self._report_repo_class(self.db) + + @cached_property + def task_registry(self) -> TaskRegistryRepository: + """Return the task_registry repository.""" + return self._task_registry_class(self.db) diff --git a/app/repository/task_registry.py b/app/repository/task_registry.py new file mode 100644 index 0000000..09fc871 --- /dev/null +++ b/app/repository/task_registry.py @@ -0,0 +1,76 @@ +"""Task registry module.""" + +from datetime import datetime + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pg +from sqlalchemy.exc import DBAPIError + +from app.db.model import TaskRegistry +from app.repository.base import BaseRepository + + +class TaskRegistryRepository(BaseRepository): + """TaskRegistryRepository.""" + + async def populate_task(self, task_name: str) -> TaskRegistry | None: + """Insert a task record if it doesn't exist already. + + Args: + task_name: name of the task. + """ + insert_query = ( + pg.insert(TaskRegistry) + .values(task_name=task_name) + .on_conflict_do_nothing() + .returning(TaskRegistry) + ) + return (await self.db.execute(insert_query)).scalar_one_or_none() + + async def get_locked_task(self, task_name: str) -> TaskRegistry | None: + """Lock and return a record from the task registry, or None if already locked. + + Args: + task_name: name of the task. + """ + select_query = ( + sa.select(TaskRegistry) + .where(TaskRegistry.task_name == task_name) + .with_for_update(nowait=True) + ) + try: + # ensure that the record exists and that it can be locked + return (await self.db.execute(select_query)).scalar_one() + except DBAPIError as ex: + if getattr(ex.orig, "pgcode", None) == "55P03": + # Lock Not Available: the record exists, but it cannot be locked + return None + raise + + async def update_task( + self, + task_name: str, + *, + last_run: datetime, + last_duration: float, + last_error: str | None, + ) -> TaskRegistry: + """Update an existing task in the registry. + + Args: + task_name: name of the task. + last_run: last start time. + last_duration: last duration in seconds. + last_error: traceback from the last task execution, or None. + """ + query = ( + sa.update(TaskRegistry) + .values( + last_run=last_run, + last_duration=last_duration, + last_error=last_error, + ) + .where(TaskRegistry.task_name == task_name) + .returning(TaskRegistry) + ) + return (await self.db.execute(query)).scalar_one() diff --git a/app/schema/domain.py b/app/schema/domain.py index 6119085..2edcd17 100644 --- a/app/schema/domain.py +++ b/app/schema/domain.py @@ -77,7 +77,16 @@ class StartedJob(BaseJob): @dataclass(kw_only=True) -class ChargeLongrunResult: +class TaskResult: + """Result of a generic task.""" + + success: int = 0 + failure: int = 0 + state: dict | None = None + + +@dataclass(kw_only=True) +class ChargeLongrunResult(TaskResult): """Result of charge_longrun.""" unfinished_uncharged: int = 0 @@ -87,20 +96,13 @@ class ChargeLongrunResult: finished_overcharged: int = 0 expired_uncharged: int = 0 expired_charged: int = 0 - failure: int = 0 @dataclass(kw_only=True) -class ChargeOneshotResult: +class ChargeOneshotResult(TaskResult): """Result of charge_oneshot.""" - success: int = 0 - failure: int = 0 - @dataclass(kw_only=True) -class ChargeStorageResult: +class ChargeStorageResult(TaskResult): """Result of charge_storage.""" - - success: int = 0 - failure: int = 0 diff --git a/app/service/charge_longrun.py b/app/service/charge_longrun.py index 375854f..407c9a0 100644 --- a/app/service/charge_longrun.py +++ b/app/service/charge_longrun.py @@ -131,7 +131,7 @@ async def _charge_generic( ) -async def charge_longrun( +async def charge_longrun( # noqa: C901 repos: RepositoryGroup, min_charging_interval: float = 0.0, min_charging_amount: Decimal = D0, @@ -155,11 +155,14 @@ def _on_error() -> None: L.exception("Error processing longrun job {}", job.id) result.failure += 1 + def _on_success() -> None: + result.success += 1 + now = transaction_datetime or utcnow() result = ChargeLongrunResult() jobs = await repos.job.get_longrun_to_be_charged() for job in jobs: - async with try_nested(repos.db, on_error=_on_error): + async with try_nested(repos.db, on_error=_on_error, on_success=_on_success): match job: case StartedJob( last_alive_at=datetime() as last_alive_at, diff --git a/app/service/charge_oneshot.py b/app/service/charge_oneshot.py index f7943b3..3c24a3e 100644 --- a/app/service/charge_oneshot.py +++ b/app/service/charge_oneshot.py @@ -84,7 +84,6 @@ async def charge_oneshot(repos: RepositoryGroup) -> ChargeOneshotResult: Args: repos: repository group instance. - jobs: optional sequence of jobs. """ def _on_error() -> None: diff --git a/app/task/job_charger/base.py b/app/task/job_charger/base.py index bcbf223..e547d5e 100644 --- a/app/task/job_charger/base.py +++ b/app/task/job_charger/base.py @@ -1,9 +1,17 @@ """Abstract task.""" import asyncio +import traceback from abc import ABC, abstractmethod +from app.db.model import TaskRegistry +from app.db.session import database_session_manager +from app.db.utils import current_timestamp from app.logger import L +from app.repository.group import RepositoryGroup +from app.repository.task_registry import TaskRegistryRepository +from app.schema.domain import TaskResult +from app.utils import Timer class BaseTask(ABC): @@ -51,6 +59,10 @@ def _counter(self) -> int: """Return the total number of loops executed.""" return self._success + self._failure + @abstractmethod + async def _prepare(self) -> None: + """Prepare the task before entering the loop.""" + @abstractmethod async def _run_once(self) -> None: """Execute one loop.""" @@ -63,6 +75,7 @@ async def run_forever(self, limit: int = 0) -> None: """ self.logger.info("Starting {}", self.name) await asyncio.sleep(self._initial_delay) + await self._prepare() while True: try: await self._run_once() @@ -76,3 +89,49 @@ async def run_forever(self, limit: int = 0) -> None: if 0 < limit <= self._counter: break await asyncio.sleep(sleep) + + +class RegisteredTask(BaseTask, ABC): + """RegisteredTask.""" + + async def _prepare(self) -> None: + async with database_session_manager.session() as db: + task_registry_repo = TaskRegistryRepository(db=db) + if await task_registry_repo.populate_task(self.name): + self.logger.info("Populating the task registry") + + async def _run_once(self) -> None: + async with database_session_manager.session(commit=False) as db: + task_registry_repo = TaskRegistryRepository(db=db) + if not (task := await task_registry_repo.get_locked_task(self.name)): + self.logger.info("Skipping task because the task registry is locked") + return + error = None + start_timestamp = await current_timestamp(db) + self.logger.info("Lock acquired at {}", start_timestamp) + try: + with Timer() as timer: + # create a new session that is rolled back if any error happens inside the task + async with database_session_manager.session() as task_session: + repos = RepositoryGroup(db=task_session) + await self._run_once_logic(repos, task) + except Exception: + error = traceback.format_exc() + raise + finally: + await task_registry_repo.update_task( + self.name, + last_run=start_timestamp, + last_duration=timer.elapsed, + last_error=error, + ) + await db.commit() + + @abstractmethod + async def _run_once_logic(self, repos: RepositoryGroup, task: TaskRegistry) -> TaskResult: + """Execute the actual task logic call. + + Args: + repos: RepositoryGroup instance. + task: TaskRegistry instance. + """ diff --git a/app/task/job_charger/longrun.py b/app/task/job_charger/longrun.py index 2b0a33e..2308e9c 100644 --- a/app/task/job_charger/longrun.py +++ b/app/task/job_charger/longrun.py @@ -1,13 +1,14 @@ """Longrun job charger.""" from app.config import settings -from app.db.session import database_session_manager +from app.db.model import TaskRegistry from app.repository.group import RepositoryGroup +from app.schema.domain import TaskResult from app.service.charge_longrun import charge_longrun -from app.task.job_charger.base import BaseTask +from app.task.job_charger.base import RegisteredTask -class PeriodicLongrunCharger(BaseTask): +class PeriodicLongrunCharger(RegisteredTask): """PeriodicLongrunCharger.""" def __init__(self, name: str, initial_delay: int = 0) -> None: @@ -19,12 +20,14 @@ def __init__(self, name: str, initial_delay: int = 0) -> None: error_sleep=settings.CHARGE_LONGRUN_ERROR_SLEEP, ) - async def _run_once(self) -> None: # noqa: PLR6301 - async with database_session_manager.session() as db: - repos = RepositoryGroup(db=db) - await charge_longrun( - repos=repos, - min_charging_interval=settings.CHARGE_LONGRUN_MIN_CHARGING_INTERVAL, - min_charging_amount=settings.CHARGE_LONGRUN_MIN_CHARGING_AMOUNT, - expiration_interval=settings.CHARGE_LONGRUN_EXPIRATION_INTERVAL, - ) + async def _run_once_logic( # noqa: PLR6301 + self, + repos: RepositoryGroup, + task: TaskRegistry, # noqa: ARG002 + ) -> TaskResult: + return await charge_longrun( + repos=repos, + min_charging_interval=settings.CHARGE_LONGRUN_MIN_CHARGING_INTERVAL, + min_charging_amount=settings.CHARGE_LONGRUN_MIN_CHARGING_AMOUNT, + expiration_interval=settings.CHARGE_LONGRUN_EXPIRATION_INTERVAL, + ) diff --git a/app/task/job_charger/oneshot.py b/app/task/job_charger/oneshot.py index 520eec8..c5b0320 100644 --- a/app/task/job_charger/oneshot.py +++ b/app/task/job_charger/oneshot.py @@ -1,13 +1,14 @@ """Oneshot job charger.""" from app.config import settings -from app.db.session import database_session_manager +from app.db.model import TaskRegistry from app.repository.group import RepositoryGroup +from app.schema.domain import TaskResult from app.service.charge_oneshot import charge_oneshot -from app.task.job_charger.base import BaseTask +from app.task.job_charger.base import RegisteredTask -class PeriodicOneshotCharger(BaseTask): +class PeriodicOneshotCharger(RegisteredTask): """PeriodicOneshotCharger.""" def __init__(self, name: str, initial_delay: int = 0) -> None: @@ -19,7 +20,9 @@ def __init__(self, name: str, initial_delay: int = 0) -> None: error_sleep=settings.CHARGE_ONESHOT_ERROR_SLEEP, ) - async def _run_once(self) -> None: # noqa: PLR6301 - async with database_session_manager.session() as db: - repos = RepositoryGroup(db=db) - await charge_oneshot(repos=repos) + async def _run_once_logic( # noqa: PLR6301 + self, + repos: RepositoryGroup, + task: TaskRegistry, # noqa: ARG002 + ) -> TaskResult: + return await charge_oneshot(repos=repos) diff --git a/app/task/job_charger/storage.py b/app/task/job_charger/storage.py index 19eb423..e9dc03b 100644 --- a/app/task/job_charger/storage.py +++ b/app/task/job_charger/storage.py @@ -1,13 +1,14 @@ """Storage charger.""" from app.config import settings -from app.db.session import database_session_manager +from app.db.model import TaskRegistry from app.repository.group import RepositoryGroup +from app.schema.domain import TaskResult from app.service.charge_storage import charge_storage -from app.task.job_charger.base import BaseTask +from app.task.job_charger.base import RegisteredTask -class PeriodicStorageCharger(BaseTask): +class PeriodicStorageCharger(RegisteredTask): """PeriodicStorageCharger.""" def __init__(self, name: str, initial_delay: int = 0) -> None: @@ -15,21 +16,27 @@ def __init__(self, name: str, initial_delay: int = 0) -> None: super().__init__( name=name, initial_delay=initial_delay, - loop_sleep=settings.CHARGE_LONGRUN_LOOP_SLEEP, - error_sleep=settings.CHARGE_LONGRUN_ERROR_SLEEP, + loop_sleep=settings.CHARGE_STORAGE_LOOP_SLEEP, + error_sleep=settings.CHARGE_STORAGE_ERROR_SLEEP, ) - async def _run_once(self) -> None: # noqa: PLR6301 - async with database_session_manager.session() as db: - repos = RepositoryGroup(db=db) - # get and charge finished jobs not charged or partially charged - jobs = await repos.job.get_storage_finished_to_be_charged() - await charge_storage(repos=repos, jobs=jobs) - # get and charge running jobs - jobs = await repos.job.get_storage_running() - await charge_storage( - repos=repos, - jobs=jobs, - min_charging_interval=settings.CHARGE_STORAGE_MIN_CHARGING_INTERVAL, - min_charging_amount=settings.CHARGE_STORAGE_MIN_CHARGING_AMOUNT, - ) + async def _run_once_logic( # noqa: PLR6301 + self, + repos: RepositoryGroup, + task: TaskRegistry, # noqa: ARG002 + ) -> TaskResult: + # get and charge finished jobs not charged or partially charged + jobs = await repos.job.get_storage_finished_to_be_charged() + result1 = await charge_storage(repos=repos, jobs=jobs) + # get and charge running jobs + jobs = await repos.job.get_storage_running() + result2 = await charge_storage( + repos=repos, + jobs=jobs, + min_charging_interval=settings.CHARGE_STORAGE_MIN_CHARGING_INTERVAL, + min_charging_amount=settings.CHARGE_STORAGE_MIN_CHARGING_AMOUNT, + ) + return TaskResult( + success=result1.success + result2.success, + failure=result1.failure + result2.failure, + ) diff --git a/app/utils.py b/app/utils.py index f0dceb8..6384093 100644 --- a/app/utils.py +++ b/app/utils.py @@ -2,7 +2,10 @@ import time import uuid +from collections.abc import Callable from datetime import UTC, datetime +from types import TracebackType +from typing import Self def since_unix_epoch() -> int: @@ -18,3 +21,39 @@ def utcnow() -> datetime: def create_uuid() -> uuid.UUID: """Return a new random UUID.""" return uuid.uuid4() + + +class Timer: + """Timer context manager. + + Usage example: + + >>> with Timer() as timer: + ... print(timer.elapsed) + ... print(timer.elapsed) + >>> print(timer.elapsed) + """ + + def __init__(self, timer: Callable[[], float] = time.perf_counter) -> None: + """Init the timer.""" + self._timer = timer + self._started_at = self._timer() + self._stopped_at: float | None = None + + def __enter__(self) -> Self: + """Initialize when entering the context manager.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Cleanup when exiting the context manager.""" + self._stopped_at = self._timer() + + @property + def elapsed(self) -> float: + """Return the elapsed time in seconds.""" + return (self._stopped_at or self._timer()) - self._started_at diff --git a/pyproject.toml b/pyproject.toml index 6cd90d0..2181388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ combine-as-imports = true ] "tests/*.py" = [ "ANN", # Missing type annotation + "ARG002", # Unused method argument "D", # pydocstyle "ERA001", # Found commented-out code "INP001", # Missing `__init__.py` diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py index df70207..b98aef3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import asyncio from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from contextlib import AbstractAsyncContextManager, asynccontextmanager from datetime import datetime from decimal import Decimal +from typing import Protocol from unittest.mock import patch from uuid import UUID @@ -26,6 +27,10 @@ from tests.utils import truncate_tables +class SessionFactory(Protocol): + def __call__(self, *, commit: bool = False) -> AbstractAsyncContextManager[AsyncSession]: ... + + @pytest.fixture(scope="session") def event_loop(): loop = asyncio.new_event_loop() @@ -44,16 +49,27 @@ async def _database_context(): @pytest.fixture -async def db() -> AsyncIterator[AsyncSession]: - async with database_session_manager.session() as session: +async def db_session_factory() -> SessionFactory: + return database_session_manager.session + + +@pytest.fixture +async def db(db_session_factory) -> AsyncIterator[AsyncSession]: + async with db_session_factory(commit=False) as session: + yield session + + +@pytest.fixture +async def db2(db_session_factory) -> AsyncIterator[AsyncSession]: + async with db_session_factory(commit=False) as session: yield session @pytest.fixture(autouse=True) -async def _db_cleanup(db): +async def _db_cleanup(db_session_factory): yield - await db.rollback() - await truncate_tables(db) + async with db_session_factory(commit=True) as session: + await truncate_tables(session) @pytest.fixture diff --git a/tests/constants.py b/tests/constants.py index 2fca9e3..265853c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -5,7 +5,7 @@ @dataclasses.dataclass(frozen=True) -class UUIDS: +class _UUIDS: SYS: UUID VLAB: list[UUID] PROJ: list[UUID] @@ -13,7 +13,7 @@ class UUIDS: JOB: list[UUID] -UUIDS = UUIDS( +UUIDS = _UUIDS( SYS=UUID("00000000-0000-0000-0000-000000000001"), VLAB=[ UUID("1b3bd3f4-3441-41b0-8fae-83a30c133dc2"), diff --git a/tests/db/__init__.py b/tests/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/db/test_utils.py b/tests/db/test_utils.py new file mode 100644 index 0000000..61b81eb --- /dev/null +++ b/tests/db/test_utils.py @@ -0,0 +1,25 @@ +from datetime import datetime + +from app.db import utils as test_module + + +async def test_current_timestamp(db): + result1 = await test_module.current_timestamp(db) + + assert isinstance(result1, datetime) + assert result1.tzname() == "UTC" + + result2 = await test_module.current_timestamp(db) + + assert result2 == result1 + + +async def test_clock_timestamp(db): + result1 = await test_module.clock_timestamp(db) + + assert isinstance(result1, datetime) + assert result1.tzname() == "UTC" + + result2 = await test_module.clock_timestamp(db) + + assert result2 > result1 diff --git a/tests/queue/__init__.py b/tests/queue/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/repository/__init__.py b/tests/repository/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/repository/test_group.py b/tests/repository/test_group.py new file mode 100644 index 0000000..cb5253c --- /dev/null +++ b/tests/repository/test_group.py @@ -0,0 +1,20 @@ +from app.repository import group as test_module +from app.repository.account import AccountRepository +from app.repository.event import EventRepository +from app.repository.job import JobRepository +from app.repository.ledger import LedgerRepository +from app.repository.price import PriceRepository +from app.repository.report import ReportRepository +from app.repository.task_registry import TaskRegistryRepository + + +async def test_repository_group(db): + repos = test_module.RepositoryGroup(db) + + assert isinstance(repos.account, AccountRepository) + assert isinstance(repos.event, EventRepository) + assert isinstance(repos.job, JobRepository) + assert isinstance(repos.ledger, LedgerRepository) + assert isinstance(repos.price, PriceRepository) + assert isinstance(repos.report, ReportRepository) + assert isinstance(repos.task_registry, TaskRegistryRepository) diff --git a/tests/repository/test_task_registry.py b/tests/repository/test_task_registry.py new file mode 100644 index 0000000..9cea430 --- /dev/null +++ b/tests/repository/test_task_registry.py @@ -0,0 +1,56 @@ +from datetime import UTC, datetime + +import pytest +from sqlalchemy.exc import NoResultFound + +from app.db.model import TaskRegistry +from app.repository import task_registry as test_module + + +async def test_lifecycle(db, db2): + repo = test_module.TaskRegistryRepository(db) + repo2 = test_module.TaskRegistryRepository(db2) + task_name = "test_task" + + # check that an error is raised if the record is missing + with pytest.raises(NoResultFound): + await repo.get_locked_task(task_name) + + # populate the table + result = await repo.populate_task(task_name) + assert isinstance(result, TaskRegistry) + assert result.task_name == task_name + assert result.last_run is None + assert result.last_duration == 0 + assert result.last_error is None + + # ensure that the record is visible in the 2nd connection. without commit, + # any other insertion would be locked until the transaction is committed + await db.commit() + + # check that a new insertion is ignored + result = await repo2.populate_task(task_name) + assert result is None + + # lock the record from the 1st connection + result = await repo.get_locked_task(task_name) + assert isinstance(result, TaskRegistry) + assert result.task_name == task_name + assert result.last_run is None + assert result.last_duration == 0 + assert result.last_error is None + + # check that's not possible to get the same task from the 2nd connection + result = await repo2.get_locked_task(task_name) + assert result is None + + # update the task and check the result + last_run = datetime.now(tz=UTC) + result = await repo.update_task( + task_name, last_run=last_run, last_duration=100, last_error=None + ) + assert isinstance(result, TaskRegistry) + assert result.task_name == task_name + assert result.last_run == last_run + assert result.last_duration == 100 + assert result.last_error is None diff --git a/tests/service/__init__.py b/tests/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/service/test_charge_longrun.py b/tests/service/test_charge_longrun.py index 46092f5..1260719 100644 --- a/tests/service/test_charge_longrun.py +++ b/tests/service/test_charge_longrun.py @@ -34,6 +34,7 @@ async def test_charge_longrun(db): result = await test_module.charge_longrun(repos, transaction_datetime=transaction_datetime) assert result == ChargeLongrunResult( unfinished_uncharged=1, + success=1, ) job = await _select_job(db, job_id) assert job.last_charged_at is not None @@ -66,6 +67,7 @@ async def test_charge_longrun(db): result = await test_module.charge_longrun(repos, transaction_datetime=transaction_datetime) assert result == ChargeLongrunResult( unfinished_charged=1, + success=1, ) job = await _select_job(db, job_id) assert job.last_charged_at is not None @@ -102,6 +104,7 @@ async def test_charge_longrun(db): result = await test_module.charge_longrun(repos, transaction_datetime=transaction_datetime) assert result == ChargeLongrunResult( finished_charged=1, + success=1, ) job = await _select_job(db, job_id) assert job.last_charged_at is not None @@ -146,6 +149,7 @@ async def test_charge_longrun_expired_uncharged(db): ) assert result == ChargeLongrunResult( expired_uncharged=1, + success=1, ) job = await _select_job(db, job_id) assert job.last_charged_at == now @@ -198,6 +202,7 @@ async def test_charge_longrun_expired_charged(db): ) assert result == ChargeLongrunResult( expired_charged=1, + success=1, ) expected_amount = 3600 * Decimal("0.01") job = await _select_job(db, job_id) diff --git a/tests/task/__init__.py b/tests/task/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/task/job_charger/__init__.py b/tests/task/job_charger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/task/job_charger/test_base.py b/tests/task/job_charger/test_base.py new file mode 100644 index 0000000..48e659a --- /dev/null +++ b/tests/task/job_charger/test_base.py @@ -0,0 +1,66 @@ +import sqlalchemy as sa + +from app.db.model import TaskRegistry +from app.repository.group import RepositoryGroup +from app.schema.domain import TaskResult +from app.task.job_charger import base as test_module + + +async def _select_task_registry(db, task_name): + query = sa.select(TaskRegistry).where(TaskRegistry.task_name == task_name) + return (await db.execute(query)).scalar_one_or_none() + + +class DummyRegisteredTask(test_module.RegisteredTask): + def __init__(self, *args, succeeding: bool = True, **kwargs) -> None: + self.succeeding = succeeding + super().__init__(*args, **kwargs) + + async def _run_once_logic(self, repos: RepositoryGroup, task: TaskRegistry) -> TaskResult: + if not self.succeeding: + err = "Application error" + raise RuntimeError(err) + return TaskResult(success=100, failure=10) + + +async def test_registered_task_success(db): + task_name = "dummy" + task = DummyRegisteredTask(task_name) + + task_registry = await _select_task_registry(db, task_name=task_name) + assert task_registry is None + + await task.run_forever(limit=1) + + assert task.get_stats() == { + "counter": 1, + "success": 1, + "failure": 0, + } + + task_registry = await _select_task_registry(db, task_name=task_name) + assert isinstance(task_registry, TaskRegistry) + assert task_registry.task_name == task_name + assert task_registry.last_error is None + + +async def test_registered_task_failure(db): + task_name = "dummy" + task = DummyRegisteredTask(task_name, succeeding=False) + + task_registry = await _select_task_registry(db, task_name=task_name) + assert task_registry is None + + await task.run_forever(limit=1) + + assert task.get_stats() == { + "counter": 1, + "success": 0, + "failure": 1, + } + + task_registry = await _select_task_registry(db, task_name=task_name) + assert isinstance(task_registry, TaskRegistry) + assert task_registry.task_name == task_name + assert isinstance(task_registry.last_error, str) + assert "Application error" in task_registry.last_error diff --git a/tests/task/job_charger/test_longrun.py b/tests/task/job_charger/test_longrun.py index 510af7e..1c9d1df 100644 --- a/tests/task/job_charger/test_longrun.py +++ b/tests/task/job_charger/test_longrun.py @@ -1,10 +1,12 @@ from unittest.mock import patch +from app.schema.domain import ChargeLongrunResult from app.task.job_charger import longrun as test_module @patch(f"{test_module.__name__}.charge_longrun") async def test_periodic_longrun_charger_run_forever(mock_charge_longrun): + mock_charge_longrun.return_value = ChargeLongrunResult() task = test_module.PeriodicLongrunCharger(name="test-longrun-charger") await task.run_forever(limit=1) assert mock_charge_longrun.call_count == 1 diff --git a/tests/task/job_charger/test_oneshot.py b/tests/task/job_charger/test_oneshot.py index 3885bf4..7bf0fff 100644 --- a/tests/task/job_charger/test_oneshot.py +++ b/tests/task/job_charger/test_oneshot.py @@ -1,10 +1,12 @@ from unittest.mock import patch +from app.schema.domain import ChargeOneshotResult from app.task.job_charger import oneshot as test_module @patch(f"{test_module.__name__}.charge_oneshot") async def test_periodic_oneshot_charger_run_forever(mock_charge_oneshot): + mock_charge_oneshot.return_value = ChargeOneshotResult() task = test_module.PeriodicOneshotCharger(name="test-oneshot-charger") await task.run_forever(limit=1) assert mock_charge_oneshot.call_count == 1 diff --git a/tests/task/job_charger/test_storage.py b/tests/task/job_charger/test_storage.py index dd36438..33b4887 100644 --- a/tests/task/job_charger/test_storage.py +++ b/tests/task/job_charger/test_storage.py @@ -1,10 +1,12 @@ from unittest.mock import patch +from app.schema.domain import ChargeStorageResult from app.task.job_charger import storage as test_module @patch(f"{test_module.__name__}.charge_storage") async def test_periodic_storage_charger_run_forever(mock_charge_storage): + mock_charge_storage.return_value = ChargeStorageResult() task = test_module.PeriodicStorageCharger(name="test-storage-charger") await task.run_forever(limit=1) assert mock_charge_storage.call_count == 2 diff --git a/tests/task/queue_consumer/__init__.py b/tests/task/queue_consumer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..397a202 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,12 @@ +from app import utils as test_module + + +def test_timer(): + with test_module.Timer() as timer: + elapsed_1 = timer.elapsed + elapsed_2 = timer.elapsed + elapsed_3 = timer.elapsed + elapsed_4 = timer.elapsed + + assert 0 < elapsed_1 < elapsed_2 < elapsed_3 + assert elapsed_3 == elapsed_4 diff --git a/tests/utils.py b/tests/utils.py index 18df98b..d92c0cc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,6 @@ async def truncate_tables(session): query = text(f"""TRUNCATE {",".join(Base.metadata.tables)} RESTART IDENTITY CASCADE""") await session.execute(query) - await session.commit() async def _insert_oneshot_job(db, job_id, reserved_count, reserved_at):