-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add initial support for async processes (#19)
- Loading branch information
Showing
9 changed files
with
595 additions
and
3 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Async Retsu package.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""Retsu tasks with celery.""" | ||
|
||
from __future__ import annotations

import asyncio

from typing import Any, Optional

import celery

from celery import chain, chord, group
from public import public

from retsu.asyncio.core import AsyncProcess
|
||
|
||
@public
class CeleryAsyncProcess(AsyncProcess):
    """Async Celery Process class.

    Collects Celery signatures from the ``get_chord_tasks``,
    ``get_group_tasks`` and ``get_chain_tasks`` hooks and executes each
    non-empty workflow, gathering all results into one flat list.
    Subclasses override the hooks to supply their signatures.
    """

    async def process(self, *args, task_id: str, **kwargs) -> Any:
        """Define the async process to be executed.

        Parameters
        ----------
        task_id : str
            Identifier of the task being executed; forwarded to every
            hook so the signatures can reference it.

        Returns
        -------
        Any
            Flat list with the results of every executed workflow.
        """
        chord_tasks, chord_callback = await self.get_chord_tasks(
            *args,
            task_id=task_id,
            **kwargs,
        )
        group_tasks = await self.get_group_tasks(
            *args,
            task_id=task_id,
            **kwargs,
        )
        chain_tasks = await self.get_chain_tasks(
            *args,
            task_id=task_id,
            **kwargs,
        )

        # Start the tasks asynchronously.
        results: list[Any] = []

        # NOTE: ``AsyncResult.get()`` blocks the calling thread while
        # Celery waits for the workers, which would stall the event
        # loop inside an ``async def``; run each wait in a worker
        # thread instead.
        if chord_tasks:
            workflow_chord = chord(chord_tasks, chord_callback)
            promise_chord = workflow_chord.apply_async()
            results.extend(await asyncio.to_thread(promise_chord.get))

        if group_tasks:
            workflow_group = group(group_tasks)
            promise_group = workflow_group.apply_async()
            results.extend(await asyncio.to_thread(promise_group.get))

        if chain_tasks:
            workflow_chain = chain(chain_tasks)
            promise_chain = workflow_chain.apply_async()
            results.extend(await asyncio.to_thread(promise_chain.get))

        return results

    async def get_chord_tasks(  # type: ignore
        self, *args, **kwargs
    ) -> tuple[list[celery.Signature], Optional[celery.Signature]]:
        """
        Return the tasks to run with chord.

        Return
        ------
        tuple:
            list of tasks for the chord, and the task to be used as a callback
        """
        chord_tasks: list[celery.Signature] = []
        callback_task = None
        return (chord_tasks, callback_task)

    async def get_group_tasks(  # type: ignore
        self, *args, **kwargs
    ) -> list[celery.Signature]:
        """
        Return the tasks to run with group.

        Return
        ------
        list:
            list of tasks for the group
        """
        group_tasks: list[celery.Signature] = []
        return group_tasks

    async def get_chain_tasks(  # type: ignore
        self, *args, **kwargs
    ) -> list[celery.Signature]:
        """Return the tasks to run with chain."""
        chain_tasks: list[celery.Signature] = []
        return chain_tasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
"""Async core module.""" | ||
|
||
import asyncio | ||
import logging | ||
|
||
from abc import abstractmethod | ||
from datetime import datetime | ||
from typing import Any | ||
from uuid import uuid4 | ||
|
||
from redis import asyncio as aioredis | ||
|
||
from retsu.asyncio.queues import RedisRetsuAsyncQueue | ||
from retsu.asyncio.results import ( | ||
ResultProcessManagerAsync, | ||
create_result_task_manager_async, | ||
) | ||
from retsu.core import Process | ||
from retsu.queues import get_redis_queue_config | ||
|
||
|
||
class AsyncProcess(Process):
    """Main class for handling an async process.

    Requests are pushed onto a Redis-backed queue by ``request`` and
    consumed by ``workers`` concurrent asyncio tasks, each running
    ``run``; results and status metadata are stored through
    ``self.result``.
    """

    def __init__(self, workers: int = 1) -> None:
        """Initialize an async process object.

        Parameters
        ----------
        workers : int
            Number of concurrent worker tasks consuming the queue.
        """
        _klass = self.__class__
        # Fully-qualified class name keeps the queue unique per subclass.
        queue_in_name = f"{_klass.__module__}.{_klass.__qualname__}"

        self._client = aioredis.Redis(**get_redis_queue_config())
        self.active = True
        self.workers = workers
        self.result: ResultProcessManagerAsync = (
            create_result_task_manager_async()
        )
        self.queue_in = RedisRetsuAsyncQueue(queue_in_name)
        self.tasks: list[asyncio.Task] = []

    async def start(self) -> None:
        """Start async tasks."""
        logging.info(f"Starting async process {self.__class__.__name__}")
        for _ in range(self.workers):
            task = asyncio.create_task(self.run())
            self.tasks.append(task)

    async def stop(self) -> None:
        """Stop async tasks."""
        logging.info(f"Stopping async process {self.__class__.__name__}")
        self.active = False
        for task in self.tasks:
            task.cancel()
            try:
                # Ensure the task is properly awaited before moving on.
                await task
            except asyncio.CancelledError:
                logging.info(f"Task {task.get_name()} has been cancelled.")

    async def request(self, *args, **kwargs) -> str:  # type: ignore
        """Feed the queue with data from the request for the process.

        Returns
        -------
        str
            The generated hex task id used to track the request.
        """
        task_id = uuid4().hex
        metadata = {
            "status": "starting",
            "created_at": datetime.now().isoformat(),
            "updated_at": datetime.now().isoformat(),
        }
        await self.result.create(task_id, metadata)  # Ensure this is awaited
        await self.queue_in.put(
            {
                "task_id": task_id,
                "args": args,
                "kwargs": kwargs,
            }
        )
        return task_id

    @abstractmethod
    async def process(self, *args, task_id: str, **kwargs) -> Any:  # type: ignore
        """Define the async process to be executed."""
        raise Exception("`process` not implemented yet.")

    async def prepare_process(self, data: dict[str, Any]) -> None:
        """Call the process with the necessary arguments.

        Updates the task status to ``running`` before the call, and to
        ``completed`` (or ``failed``) afterwards.
        """
        task_id = data.pop("task_id")
        await self.result.metadata.update(task_id, "status", "running")
        try:
            result = await self.process(
                *data["args"],
                task_id=task_id,
                **data["kwargs"],
            )
        except Exception:
            # Without this, a failing task would be left in "running"
            # status forever; record the failure, then let `run` decide
            # what to do with the exception.
            await self.result.metadata.update(task_id, "status", "failed")
            raise
        await self.result.save(task_id, result)
        await self.result.metadata.update(task_id, "status", "completed")

    async def run(self) -> None:
        """Run the async process with data from the queue."""
        while self.active:
            try:
                data = await self.queue_in.get()
                await self.prepare_process(data)
            except asyncio.CancelledError:
                # `current_task()` can return None outside a task.
                current = asyncio.current_task()
                name = current.get_name() if current else "<unknown>"
                logging.info(f"Task {name} cancelled.")
                break  # Break out of the loop if the task is canceled
            except Exception as e:
                # `exception` also logs the traceback for debugging.
                logging.exception(f"Error in process: {e}")
                break  # Break out of the loop on any other exceptions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""Functions for handling queues and their configurations.""" | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
import pickle | ||
|
||
from abc import abstractmethod | ||
from typing import Any | ||
|
||
from public import public | ||
from redis import asyncio as aioredis | ||
|
||
from retsu.queues import BaseRetsuQueue, get_redis_queue_config | ||
|
||
|
||
@public
class BaseRetsuAsyncQueue(BaseRetsuQueue):
    """Abstract base for async queues.

    Declares the coroutine-based put/get contract that concrete async
    queue implementations must fulfil.
    """

    def __init__(self, name: str) -> None:
        """Store the queue name.

        NOTE(review): deliberately does not call ``super().__init__()``
        — presumably to skip the sync base's setup; confirm against
        ``BaseRetsuQueue``.
        """
        self.name = name

    @abstractmethod
    async def put(self, data: Any) -> None:
        """Append ``data`` at the end of the queue."""
        ...

    @abstractmethod
    async def get(self) -> Any:
        """Return the next item from the queue."""
        ...
|
||
|
||
@public
class RedisRetsuAsyncQueue(BaseRetsuAsyncQueue):
    """Async RedisRetsuQueue class.

    Stores pickled items in a Redis list: ``put`` appends with RPUSH
    and ``get`` polls with LPOP until an item arrives.
    """

    def __init__(self, name: str) -> None:
        """Initialize RedisRetsuQueue with async Redis client."""
        # Inherit from the async base (was the sync `BaseRetsuQueue`),
        # so the async put/get contract declared above is enforced.
        super().__init__(name)
        self._client = aioredis.Redis(
            **get_redis_queue_config(),  # Async Redis client configuration
            decode_responses=False,  # payloads are raw pickle bytes
        )

    async def put(self, data: Any) -> None:
        """Put data into the end of the queue asynchronously."""
        await self._client.rpush(self.name, pickle.dumps(data))

    async def get(self) -> Any:
        """Get the next data from the queue asynchronously.

        Polls the Redis list every 100ms until an item is available.
        """
        while True:
            data = await self._client.lpop(self.name)
            if data is None:
                await asyncio.sleep(0.1)  # Non-blocking sleep for 100ms
                continue

            # NOTE(review): `pickle.loads` on queue payloads is only
            # safe while the Redis instance is fully trusted — anyone
            # able to write to this list can execute arbitrary code.
            return pickle.loads(data)
Oops, something went wrong.