Skip to content

Commit

Permalink
Prevent charger tasks from running concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFicarelli committed Oct 8, 2024
1 parent bec330e commit 66b01c3
Show file tree
Hide file tree
Showing 38 changed files with 561 additions and 69 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ format: # Run formatters
lint: ## Run linters
pdm run python -m ruff format --check
pdm run python -m ruff check
pdm run python -m mypy app
pdm run python -m mypy app tests

build: ## Build the Docker image
docker compose build app
Expand Down
50 changes: 50 additions & 0 deletions alembic/versions/20240909_154014_1c15c81ce8a8_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""empty message
Revision ID: 1c15c81ce8a8
Revises: b56e853bb310
Create Date: 2024-09-09 15:40:14.071295
"""

from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "1c15c81ce8a8"
down_revision: str | None = "b56e853bb310"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"task_registry",
sa.Column("task_name", sa.String(), nullable=False),
sa.Column("last_run", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_duration", sa.Float(), server_default=sa.text("0"), nullable=False),
sa.Column("last_error", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("task_name", name=op.f("pk_task_registry")),
)
op.drop_column("job", "updated_at")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"job",
sa.Column(
"updated_at",
postgresql.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
autoincrement=False,
nullable=False,
),
)
op.drop_table("task_registry")
# ### end Alembic commands ###
12 changes: 11 additions & 1 deletion app/db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class Job(Base):
service_type: Mapped[ServiceType]
service_subtype: Mapped[ServiceSubtype]
created_at: Mapped[CREATED_AT]
updated_at: Mapped[UPDATED_AT]
reserved_at: Mapped[datetime | None]
started_at: Mapped[datetime | None]
last_alive_at: Mapped[datetime | None]
Expand Down Expand Up @@ -130,6 +129,17 @@ class Price(Base):
updated_at: Mapped[UPDATED_AT]


class TaskRegistry(Base):
"""TaskRegistry table."""

__tablename__ = "task_registry"

task_name: Mapped[str] = mapped_column(primary_key=True)
last_run: Mapped[datetime | None]
last_duration: Mapped[float] = mapped_column(server_default=text("0"))
last_error: Mapped[str | None]


Index(
"only_one_system_account",
Account.account_type,
Expand Down
11 changes: 8 additions & 3 deletions app/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ async def close(self) -> None:
L.info("DB engine has been closed")

@asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
"""Yield a new database session."""
async def session(self, *, commit: bool = True) -> AsyncIterator[AsyncSession]:
"""Yield a new database session.
Args:
commit: if True and no errors occurred, commit before closing the session.
"""
if not self._engine:
err = "DB engine not initialized"
raise RuntimeError(err)
Expand All @@ -50,7 +54,8 @@ async def session(self) -> AsyncIterator[AsyncSession]:
await session.rollback()
raise
else:
await session.commit()
if commit:
await session.commit()


database_session_manager = DatabaseSessionManager()
21 changes: 21 additions & 0 deletions app/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession


Expand All @@ -22,3 +25,21 @@ async def try_nested(
else:
if on_success:
on_success()


async def current_timestamp(db: AsyncSession) -> datetime:
"""Return the start datetime of the current transaction.
The returned value does not change during the transaction.
"""
query = sa.select(func.current_timestamp())
return (await db.execute(query)).scalar_one()


async def clock_timestamp(db: AsyncSession) -> datetime:
"""Return the actual current time from the database.
The returned value changes every time the function is called.
"""
query = sa.select(func.clock_timestamp())
return (await db.execute(query)).scalar_one()
8 changes: 8 additions & 0 deletions app/repository/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.repository.ledger import LedgerRepository
from app.repository.price import PriceRepository
from app.repository.report import ReportRepository
from app.repository.task_registry import TaskRegistryRepository


class RepositoryGroup:
Expand All @@ -24,6 +25,7 @@ def __init__(
ledger_repo_class: type[LedgerRepository] = LedgerRepository,
price_repo_class: type[PriceRepository] = PriceRepository,
report_repo_class: type[ReportRepository] = ReportRepository,
task_registry_class: type[TaskRegistryRepository] = TaskRegistryRepository,
) -> None:
"""Init the repository group."""
self._db = db
Expand All @@ -33,6 +35,7 @@ def __init__(
self._ledger_repo_class = ledger_repo_class
self._price_repo_class = price_repo_class
self._report_repo_class = report_repo_class
self._task_registry_class = task_registry_class

@property
def db(self) -> AsyncSession:
Expand Down Expand Up @@ -68,3 +71,8 @@ def price(self) -> PriceRepository:
def report(self) -> ReportRepository:
"""Return the report repository."""
return self._report_repo_class(self.db)

@cached_property
def task_registry(self) -> TaskRegistryRepository:
"""Return the task_registry repository."""
return self._task_registry_class(self.db)
76 changes: 76 additions & 0 deletions app/repository/task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Task registry module."""

from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pg
from sqlalchemy.exc import DBAPIError

from app.db.model import TaskRegistry
from app.repository.base import BaseRepository


class TaskRegistryRepository(BaseRepository):
"""TaskRegistryRepository."""

async def populate_task(self, task_name: str) -> TaskRegistry | None:
"""Insert a task record if it doesn't exist already.
Args:
task_name: name of the task.
"""
insert_query = (
pg.insert(TaskRegistry)
.values(task_name=task_name)
.on_conflict_do_nothing()
.returning(TaskRegistry)
)
return (await self.db.execute(insert_query)).scalar_one_or_none()

async def get_locked_task(self, task_name: str) -> TaskRegistry | None:
"""Lock and return a record from the task registry, or None if already locked.
Args:
task_name: name of the task.
"""
select_query = (
sa.select(TaskRegistry)
.where(TaskRegistry.task_name == task_name)
.with_for_update(nowait=True)
)
try:
# ensure that the record exists and that it can be locked
return (await self.db.execute(select_query)).scalar_one()
except DBAPIError as ex:
if getattr(ex.orig, "pgcode", None) == "55P03":
# Lock Not Available: the record exists, but it cannot be locked
return None
raise

async def update_task(
self,
task_name: str,
*,
last_run: datetime,
last_duration: float,
last_error: str | None,
) -> TaskRegistry:
"""Update an existing task in the registry.
Args:
task_name: name of the task.
last_run: last start time.
last_duration: last duration in seconds.
last_error: traceback from the last task execution, or None.
"""
query = (
sa.update(TaskRegistry)
.values(
last_run=last_run,
last_duration=last_duration,
last_error=last_error,
)
.where(TaskRegistry.task_name == task_name)
.returning(TaskRegistry)
)
return (await self.db.execute(query)).scalar_one()
22 changes: 12 additions & 10 deletions app/schema/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,16 @@ class StartedJob(BaseJob):


@dataclass(kw_only=True)
class ChargeLongrunResult:
class TaskResult:
"""Result of a generic task."""

success: int = 0
failure: int = 0
state: dict | None = None


@dataclass(kw_only=True)
class ChargeLongrunResult(TaskResult):
"""Result of charge_longrun."""

unfinished_uncharged: int = 0
Expand All @@ -87,20 +96,13 @@ class ChargeLongrunResult:
finished_overcharged: int = 0
expired_uncharged: int = 0
expired_charged: int = 0
failure: int = 0


@dataclass(kw_only=True)
class ChargeOneshotResult:
class ChargeOneshotResult(TaskResult):
"""Result of charge_oneshot."""

success: int = 0
failure: int = 0


@dataclass(kw_only=True)
class ChargeStorageResult:
class ChargeStorageResult(TaskResult):
"""Result of charge_storage."""

success: int = 0
failure: int = 0
7 changes: 5 additions & 2 deletions app/service/charge_longrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def _charge_generic(
)


async def charge_longrun(
async def charge_longrun( # noqa: C901
repos: RepositoryGroup,
min_charging_interval: float = 0.0,
min_charging_amount: Decimal = D0,
Expand All @@ -155,11 +155,14 @@ def _on_error() -> None:
L.exception("Error processing longrun job {}", job.id)
result.failure += 1

def _on_success() -> None:
result.success += 1

now = transaction_datetime or utcnow()
result = ChargeLongrunResult()
jobs = await repos.job.get_longrun_to_be_charged()
for job in jobs:
async with try_nested(repos.db, on_error=_on_error):
async with try_nested(repos.db, on_error=_on_error, on_success=_on_success):
match job:
case StartedJob(
last_alive_at=datetime() as last_alive_at,
Expand Down
1 change: 0 additions & 1 deletion app/service/charge_oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ async def charge_oneshot(repos: RepositoryGroup) -> ChargeOneshotResult:
Args:
repos: repository group instance.
jobs: optional sequence of jobs.
"""

def _on_error() -> None:
Expand Down
Loading

0 comments on commit 66b01c3

Please sign in to comment.