Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent charger tasks from running concurrently #49

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ format: # Run formatters
lint: ## Run linters
uv run -m ruff format --check
uv run -m ruff check
uv run -m mypy app
uv run -m mypy app tests

build: ## Build the Docker image
docker compose build app
Expand Down
50 changes: 50 additions & 0 deletions alembic/versions/20240909_154014_1c15c81ce8a8_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""empty message

Revision ID: 1c15c81ce8a8
Revises: b56e853bb310
Create Date: 2024-09-09 15:40:14.071295

"""

from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "1c15c81ce8a8"
down_revision: str | None = "b56e853bb310"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"task_registry",
sa.Column("task_name", sa.String(), nullable=False),
sa.Column("last_run", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_duration", sa.Float(), server_default=sa.text("0"), nullable=False),
sa.Column("last_error", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("task_name", name=op.f("pk_task_registry")),
)
op.drop_column("job", "updated_at")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"job",
sa.Column(
"updated_at",
postgresql.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
autoincrement=False,
nullable=False,
),
)
op.drop_table("task_registry")
# ### end Alembic commands ###
12 changes: 11 additions & 1 deletion app/db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class Job(Base):
service_type: Mapped[ServiceType]
service_subtype: Mapped[ServiceSubtype]
created_at: Mapped[CREATED_AT]
updated_at: Mapped[UPDATED_AT]
reserved_at: Mapped[datetime | None]
started_at: Mapped[datetime | None]
last_alive_at: Mapped[datetime | None]
Expand Down Expand Up @@ -130,6 +129,17 @@ class Price(Base):
updated_at: Mapped[UPDATED_AT]


class TaskRegistry(Base):
"""TaskRegistry table."""

__tablename__ = "task_registry"

task_name: Mapped[str] = mapped_column(primary_key=True)
last_run: Mapped[datetime | None]
last_duration: Mapped[float] = mapped_column(server_default=text("0"))
last_error: Mapped[str | None]


Index(
"only_one_system_account",
Account.account_type,
Expand Down
11 changes: 8 additions & 3 deletions app/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ async def close(self) -> None:
L.info("DB engine has been closed")

@asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
"""Yield a new database session."""
async def session(self, *, commit: bool = True) -> AsyncIterator[AsyncSession]:
"""Yield a new database session.

Args:
commit: if True and no errors occurred, commit before closing the session.
"""
if not self._engine:
err = "DB engine not initialized"
raise RuntimeError(err)
Expand All @@ -50,7 +54,8 @@ async def session(self) -> AsyncIterator[AsyncSession]:
await session.rollback()
raise
else:
await session.commit()
if commit:
await session.commit()


database_session_manager = DatabaseSessionManager()
21 changes: 21 additions & 0 deletions app/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession


Expand All @@ -22,3 +25,21 @@ async def try_nested(
else:
if on_success:
on_success()


async def current_timestamp(db: AsyncSession) -> datetime:
"""Return the start datetime of the current transaction.

The returned value does not change during the transaction.
"""
query = sa.select(func.current_timestamp())
return (await db.execute(query)).scalar_one()


async def clock_timestamp(db: AsyncSession) -> datetime:
"""Return the actual current time from the database.

The returned value changes every time the function is called.
"""
query = sa.select(func.clock_timestamp())
return (await db.execute(query)).scalar_one()
8 changes: 8 additions & 0 deletions app/repository/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.repository.ledger import LedgerRepository
from app.repository.price import PriceRepository
from app.repository.report import ReportRepository
from app.repository.task_registry import TaskRegistryRepository


class RepositoryGroup:
Expand All @@ -24,6 +25,7 @@ def __init__(
ledger_repo_class: type[LedgerRepository] = LedgerRepository,
price_repo_class: type[PriceRepository] = PriceRepository,
report_repo_class: type[ReportRepository] = ReportRepository,
task_registry_class: type[TaskRegistryRepository] = TaskRegistryRepository,
) -> None:
"""Init the repository group."""
self._db = db
Expand All @@ -33,6 +35,7 @@ def __init__(
self._ledger_repo_class = ledger_repo_class
self._price_repo_class = price_repo_class
self._report_repo_class = report_repo_class
self._task_registry_class = task_registry_class

@property
def db(self) -> AsyncSession:
Expand Down Expand Up @@ -68,3 +71,8 @@ def price(self) -> PriceRepository:
def report(self) -> ReportRepository:
"""Return the report repository."""
return self._report_repo_class(self.db)

@cached_property
def task_registry(self) -> TaskRegistryRepository:
"""Return the task_registry repository."""
return self._task_registry_class(self.db)
76 changes: 76 additions & 0 deletions app/repository/task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Task registry module."""

from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pg
from sqlalchemy.exc import DBAPIError

from app.db.model import TaskRegistry
from app.repository.base import BaseRepository


class TaskRegistryRepository(BaseRepository):
"""TaskRegistryRepository."""

async def populate_task(self, task_name: str) -> TaskRegistry | None:
"""Insert a task record if it doesn't exist already.

Args:
task_name: name of the task.
"""
insert_query = (
pg.insert(TaskRegistry)
.values(task_name=task_name)
.on_conflict_do_nothing()
.returning(TaskRegistry)
)
return (await self.db.execute(insert_query)).scalar_one_or_none()

async def get_locked_task(self, task_name: str) -> TaskRegistry | None:
"""Lock and return a record from the task registry, or None if already locked.

Args:
task_name: name of the task.
"""
select_query = (
sa.select(TaskRegistry)
.where(TaskRegistry.task_name == task_name)
.with_for_update(nowait=True)
)
try:
# ensure that the record exists and that it can be locked
return (await self.db.execute(select_query)).scalar_one()
except DBAPIError as ex:
if getattr(ex.orig, "pgcode", None) == "55P03":
# Lock Not Available: the record exists, but it cannot be locked
return None
raise

async def update_task(
self,
task_name: str,
*,
last_run: datetime,
last_duration: float,
last_error: str | None,
) -> TaskRegistry:
"""Update an existing task in the registry.

Args:
task_name: name of the task.
last_run: last start time.
last_duration: last duration in seconds.
last_error: traceback from the last task execution, or None.
"""
query = (
sa.update(TaskRegistry)
.values(
last_run=last_run,
last_duration=last_duration,
last_error=last_error,
)
.where(TaskRegistry.task_name == task_name)
.returning(TaskRegistry)
)
return (await self.db.execute(query)).scalar_one()
22 changes: 12 additions & 10 deletions app/schema/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,16 @@ class StartedJob(BaseJob):


@dataclass(kw_only=True)
class ChargeLongrunResult:
class TaskResult:
"""Result of a generic task."""

success: int = 0
failure: int = 0
state: dict | None = None


@dataclass(kw_only=True)
class ChargeLongrunResult(TaskResult):
"""Result of charge_longrun."""

unfinished_uncharged: int = 0
Expand All @@ -87,20 +96,13 @@ class ChargeLongrunResult:
finished_overcharged: int = 0
expired_uncharged: int = 0
expired_charged: int = 0
failure: int = 0


@dataclass(kw_only=True)
class ChargeOneshotResult:
class ChargeOneshotResult(TaskResult):
"""Result of charge_oneshot."""

success: int = 0
failure: int = 0


@dataclass(kw_only=True)
class ChargeStorageResult:
class ChargeStorageResult(TaskResult):
"""Result of charge_storage."""

success: int = 0
failure: int = 0
7 changes: 5 additions & 2 deletions app/service/charge_longrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def _charge_generic(
)


async def charge_longrun(
async def charge_longrun( # noqa: C901
repos: RepositoryGroup,
min_charging_interval: float = 0.0,
min_charging_amount: Decimal = D0,
Expand All @@ -155,11 +155,14 @@ def _on_error() -> None:
L.exception("Error processing longrun job {}", job.id)
result.failure += 1

def _on_success() -> None:
result.success += 1

now = transaction_datetime or utcnow()
result = ChargeLongrunResult()
jobs = await repos.job.get_longrun_to_be_charged()
for job in jobs:
async with try_nested(repos.db, on_error=_on_error):
async with try_nested(repos.db, on_error=_on_error, on_success=_on_success):
match job:
case StartedJob(
last_alive_at=datetime() as last_alive_at,
Expand Down
1 change: 0 additions & 1 deletion app/service/charge_oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ async def charge_oneshot(repos: RepositoryGroup) -> ChargeOneshotResult:

Args:
repos: repository group instance.
jobs: optional sequence of jobs.
"""

def _on_error() -> None:
Expand Down
Loading