Skip to content

Commit

Permalink
chore: Add sub-progress reporter
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Nov 22, 2024
1 parent 0ead71f commit cd90688
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 19 deletions.
98 changes: 80 additions & 18 deletions src/ai/backend/common/bgtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import uuid
import weakref
from collections import defaultdict
from collections.abc import Mapping
from datetime import datetime
from typing import (
Annotated,
Any,
AsyncIterator,
Awaitable,
Expand All @@ -20,10 +23,18 @@
Type,
TypeAlias,
Union,
cast,
)

from aiohttp import web
from aiohttp_sse import sse_response
from dateutil.tz import tzutc
from pydantic import (
BaseModel,
Field,
PlainSerializer,
field_serializer,
)
from redis.asyncio import Redis
from redis.asyncio.client import Pipeline

Expand All @@ -50,46 +61,72 @@
MAX_BGTASK_ARCHIVE_PERIOD: Final = 86400 # 24 hours


# Numeric type that pydantic dumps as a decimal string in JSON mode
# (Redis hash values and SSE payloads are strings); Python-mode dumps
# keep the raw int/float.
# Fix: pass `str` directly instead of wrapping it in a redundant lambda.
NumSerializedToStr = Annotated[
    int | float, PlainSerializer(str, return_type=str, when_used="json")
]


class ProgressModel(BaseModel):
    """
    Snapshot of a background task's progress, dumped (in JSON mode) into the
    per-task Redis hash ``bgtask.{task_id}`` by ``ProgressReporter._update()``.

    JSON-mode serialization turns every field into a string so the whole model
    can be stored as a flat Redis hash mapping.
    """

    current: NumSerializedToStr = Field()  # progress done so far
    total: NumSerializedToStr = Field()  # total amount of work
    msg: str = Field(default="")  # optional human-readable status message
    last_update: NumSerializedToStr = Field()  # epoch seconds of the last report
    last_update_datetime: datetime = Field()  # same instant as an aware datetime
    subreporter_task_ids: list[uuid.UUID] = Field()  # task ids of registered sub-reporters

    @field_serializer("subreporter_task_ids", when_used="json")
    def stringify_task_ids(self, subreporter_task_ids: list[uuid.UUID], _info: Any) -> str:
        # Flatten the UUID list to a comma-separated string so it fits in a
        # Redis hash field.  NOTE(review): an empty list serializes to "".
        return ",".join([str(val) for val in subreporter_task_ids])

    @field_serializer("last_update_datetime", when_used="json")
    def stringify_dt(self, last_update_datetime: datetime, _info: Any) -> str:
        # ISO-8601 text form for storage/transport.
        return last_update_datetime.isoformat()


class ProgressReporter:
event_producer: Final[EventProducer]
task_id: Final[uuid.UUID]
total_progress: Union[int, float]
current_progress: Union[int, float]
subreporters: dict[uuid.UUID, ProgressReporter]
cool_down_seconds: float | None

def __init__(
    self,
    event_dispatcher: EventProducer,
    task_id: uuid.UUID,
    current_progress: int = 0,
    total_progress: int = 0,
    subreporters: dict[uuid.UUID, ProgressReporter] | None = None,
    cool_down_seconds: float | None = None,
) -> None:
    """
    Initialize a progress reporter for the background task *task_id*.

    :param event_dispatcher: producer used to publish ``BgtaskUpdatedEvent``s.
    :param task_id: id of the background task this reporter belongs to.
    :param current_progress: starting progress value (default 0).
    :param total_progress: total units of work (default 0).
    :param subreporters: optional pre-registered sub-reporters keyed by their
        task id; a fresh dict is created when omitted (avoids the shared
        mutable-default pitfall).
    :param cool_down_seconds: minimum interval between non-forced reports;
        ``None`` disables throttling.
    """
    self.event_producer = event_dispatcher
    self.task_id = task_id
    self.current_progress = current_progress
    self.total_progress = total_progress
    self.subreporters = subreporters if subreporters is not None else {}
    self.cool_down_seconds = cool_down_seconds

async def update(
self,
increment: Union[int, float] = 0,
message: str | None = None,
) -> None:
self.current_progress += increment
# keep the state as local variables because they might be changed
# due to interleaving at await statements below.
current, total = self.current_progress, self.total_progress
self._report_time = time.time()

def register_subreporter(self, reporter: ProgressReporter) -> None:
    """Track *reporter* as a sub-reporter; the first registration for a task id wins."""
    self.subreporters.setdefault(reporter.task_id, reporter)

async def _update(self, data: ProgressModel, force: bool = False) -> None:
now = time.time()
if not force and (
self.cool_down_seconds is not None and now - self._report_time < self.cool_down_seconds
):
return
redis_producer = self.event_producer.redis_client

async def _pipe_builder(r: Redis) -> Pipeline:
pipe = r.pipeline(transaction=False)
tracker_key = f"bgtask.{self.task_id}"
await pipe.hset(
tracker_key,
mapping={
"current": str(current),
"total": str(total),
"msg": message or "",
"last_update": str(time.time()),
},
mapping=cast(Mapping[str | bytes, str], data.model_dump(mode="json")),
)
await pipe.expire(tracker_key, MAX_BGTASK_ARCHIVE_PERIOD)
return pipe
Expand All @@ -98,11 +135,33 @@ async def _pipe_builder(r: Redis) -> Pipeline:
await self.event_producer.produce_event(
BgtaskUpdatedEvent(
self.task_id,
message=message,
current_progress=current,
total_progress=total,
message=data.msg,
current_progress=data.current,
total_progress=data.total,
subreporter_task_ids=data.subreporter_task_ids,
),
)
self._report_time = now

async def update(
    self,
    increment: Union[int, float] = 0,
    message: str | None = None,
    force: bool = False,
) -> None:
    """
    Advance the progress counter by *increment* and publish the new state.

    :param increment: amount to add to the current progress.
    :param message: optional status message (stored as "" when omitted).
    :param force: when True, bypass the cool-down throttling in ``_update()``.
    """
    self.current_progress += increment
    timestamp = time.time()
    snapshot = ProgressModel(
        current=self.current_progress,
        total=self.total_progress,
        msg="" if not message else message,
        last_update=timestamp,
        last_update_datetime=datetime.now(tzutc()),
        subreporter_task_ids=[*self.subreporters.keys()],
    )
    await self._update(snapshot, force=force)


BackgroundTask = Callable[Concatenate[ProgressReporter, ...], Awaitable[str | None]]
Expand Down Expand Up @@ -159,6 +218,9 @@ async def push_bgtask_events(
case BgtaskUpdatedEvent():
body["current_progress"] = event.current_progress
body["total_progress"] = event.total_progress
body["subreporter_task_id"] = [
str(id) for id in event.subreporter_task_ids
]
await resp.send(json.dumps(body), event=event.name, retry=5)
case BgtaskDoneEvent():
if extra_data:
Expand Down
5 changes: 4 additions & 1 deletion src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,15 @@ class BgtaskUpdatedEvent(AbstractEvent):
task_id: uuid.UUID = attrs.field()
current_progress: float = attrs.field()
total_progress: float = attrs.field()
subreporter_task_ids: list[uuid.UUID] = attrs.field()
message: Optional[str] = attrs.field(default=None)

def serialize(self) -> tuple:
    """Pack the event into a plain tuple for wire transport (UUIDs as strings)."""
    subreporter_ids = tuple(map(str, self.subreporter_task_ids))
    return (
        str(self.task_id),
        self.current_progress,
        self.total_progress,
        subreporter_ids,
        self.message,
    )

Expand All @@ -626,7 +628,8 @@ def deserialize(cls, value: tuple):
uuid.UUID(value[0]),
value[1],
value[2],
value[3],
list(uuid.UUID(v) for v in value[3]),
value[4],
)


Expand Down

0 comments on commit cd90688

Please sign in to comment.