diff --git a/examples/tasks/task_example.py b/examples/tasks/task_example.py index bb3e63f..1a84c95 100644 --- a/examples/tasks/task_example.py +++ b/examples/tasks/task_example.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from fast_agave.tasks.sqs_tasks import task @@ -10,17 +10,22 @@ QUEUE2_URL = 'http://127.0.0.1:4000/123456789012/validator.fifo' -class ValidatorModel(BaseModel): +class User(BaseModel): name: str age: int nick_name: Optional[str] +class Company(BaseModel): + legal_name: str + rfc: str + + @task(queue_url=QUEUE_URL, region_name='us-east-1') async def dummy_task(message) -> None: print(message) -@task(queue_url=QUEUE2_URL, region_name='us-east-1', validator=ValidatorModel) -async def task_validator(message: ValidatorModel) -> None: +@task(queue_url=QUEUE2_URL, region_name='us-east-1') +async def task_validator(message: Union[User, Company]) -> None: print(message.dict()) diff --git a/fast_agave/tasks/sqs_tasks.py b/fast_agave/tasks/sqs_tasks.py index 4ea5102..912bcfa 100644 --- a/fast_agave/tasks/sqs_tasks.py +++ b/fast_agave/tasks/sqs_tasks.py @@ -4,11 +4,11 @@ from functools import wraps from itertools import count from json import JSONDecodeError -from typing import AsyncGenerator, Callable, Coroutine, Optional, Type +from typing import AsyncGenerator, Callable, Coroutine from aiobotocore.httpsession import HTTPClientError from aiobotocore.session import get_session -from pydantic import BaseModel +from pydantic import validate_arguments from ..exc import RetryTask @@ -25,12 +25,10 @@ async def run_task( receipt_handle: str, message_receive_count: int, max_retries: int, - validator: Optional[Type[BaseModel]] = None, ) -> None: delete_message = True try: - data = validator(**body) if validator else body - await task_func(data) + await task_func(body) except RetryTask as retry: delete_message = message_receive_count >= max_retries + 1 if not delete_message and retry.countdown and retry.countdown > 0: @@ -88,7 +86,6 @@ def task( visibility_timeout: int = 3600, max_retries: int = 1, max_concurrent_tasks: int = 5, - validator: Optional[Type[BaseModel]] = None, ): def task_builder(task_func: Callable): @wraps(task_func) @@ -108,6 +105,9 @@ async def concurrency_controller(coro: Coroutine) -> None: can_read.set() session = get_session() + + task_with_validators = validate_arguments(task_func) + async with session.create_client('sqs', region_name) as sqs: async for message in message_consumer( queue_url, @@ -127,14 +127,13 @@ async def concurrency_controller(coro: Coroutine) -> None: bg_task = asyncio.create_task( concurrency_controller( run_task( - task_func, + task_with_validators, body, sqs, queue_url, message['ReceiptHandle'], message_receive_count, max_retries, - validator, ), ), name='fast-agave-task', diff --git a/fast_agave/version.py b/fast_agave/version.py index 2d7893e..ef91994 100644 --- a/fast_agave/version.py +++ b/fast_agave/version.py @@ -1 +1 @@ -__version__ = '0.13.0' +__version__ = '0.14.0' diff --git a/tests/tasks/test_sqs_tasks.py b/tests/tasks/test_sqs_tasks.py index 559ace0..a6a780c 100644 --- a/tests/tasks/test_sqs_tasks.py +++ b/tests/tasks/test_sqs_tasks.py @@ -2,6 +2,7 @@ import datetime as dt import json import uuid +from typing import Dict, Union from unittest.mock import AsyncMock, call, patch import aiobotocore.client @@ -32,13 +33,17 @@ async def test_execute_tasks(sqs_client) -> None: MessageGroupId='1234', ) - async_mock_function = AsyncMock(return_value=None) + async_mock_function = AsyncMock() + + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_called_with(test_message) assert async_mock_function.call_count == 1 @@ -48,26 +53,28 @@ async def test_execute_tasks(sqs_client) -> None: @pytest.mark.asyncio -async def test_execute_tasks_validator(sqs_client) -> None: - async_mock_function = AsyncMock(return_value=None) - +async def test_execute_tasks_with_validator(sqs_client) -> None: class Validator(BaseModel): id: str name: str + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Validator) -> None: + await async_mock_function(data) + task_params = dict( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - validator=Validator, ) # Invalid body, not execute function await sqs_client.send_message( MessageBody=json.dumps(dict(foo='bar')), MessageGroupId='4321', ) - await task(**task_params)(async_mock_function)() + await task(**task_params)(my_task)() assert async_mock_function.call_count == 0 resp = await sqs_client.receive_message() assert 'Messages' not in resp @@ -78,7 +85,59 @@ class Validator(BaseModel): MessageBody=test_message.json(), MessageGroupId='1234', ) - await task(**task_params)(async_mock_function)() + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +@pytest.mark.asyncio +async def test_execute_tasks_with_union_validator(sqs_client) -> None: + class User(BaseModel): + id: str + name: str + + class Company(BaseModel): + id: str + legal_name: str + rfc: str + + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Union[User, Company]) -> None: + await async_mock_function(data) + + task_params = dict( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + ) + # Invalid body, not execute function + test_message = dict(id='ID123', name='Sor Juana Inés de la Cruz') + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='4321', + ) + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + async_mock_function.reset_mock() + test_message = dict(id='ID123', legal_name='FastAgave', rfc='FA') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='54321', + ) + await task(**task_params)(my_task)() async_mock_function.assert_called_with(test_message) assert async_mock_function.call_count == 1 @@ -92,14 +151,18 @@ async def test_not_execute_tasks(sqs_client) -> None: """ Este caso es cuando el queue está vacío. No hay nada que ejecutar """ - async_mock_function = AsyncMock(return_value=None) + async_mock_function = AsyncMock() + + async def my_task(data: Dict) -> None: + await async_mock_function(data) + # No escribimos un mensaje en el queue await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_not_called() resp = await sqs_client.receive_message() assert 'Messages' not in resp @@ -141,6 +204,10 @@ async def mock_create_client(*args, **kwargs): return client async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Dict) -> None: + await async_mock_function(data) + with patch( 'aiobotocore.client.AioClientCreator.create_client', mock_create_client ): @@ -150,7 +217,7 @@ async def mock_create_client(*args, **kwargs): wait_time_seconds=1, visibility_timeout=3, max_retries=1, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_called_once() @@ -175,12 +242,15 @@ async def test_retry_tasks_default_max_retries(sqs_client) -> None: async_mock_function = AsyncMock(side_effect=RetryTask) + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - )(async_mock_function)() + )(my_task)() expected_calls = [call(test_message)] * 2 async_mock_function.assert_has_calls(expected_calls) @@ -207,13 +277,16 @@ async def test_retry_tasks_custom_max_retries(sqs_client) -> None: async_mock_function = AsyncMock(side_effect=RetryTask) + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, max_retries=3, - )(async_mock_function)() + )(my_task)() expected_calls = [call(test_message)] * 4 async_mock_function.assert_has_calls(expected_calls) @@ -243,13 +316,16 @@ async def test_does_not_retry_on_unhandled_exceptions(sqs_client) -> None: side_effect=Exception('something went wrong :(') ) + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, max_retries=3, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_called_with(test_message) assert async_mock_function.call_count == 1 @@ -281,11 +357,10 @@ async def test_retry_tasks_with_countdown(sqs_client) -> None: MessageGroupId='1234', ) - call_times = [] + async_mock_function = AsyncMock(side_effect=RetryTask(countdown=2)) - async def countdown_tester(_): - call_times.append(dt.datetime.now()) - raise RetryTask(countdown=2) + async def countdown_tester(data: Dict): + await async_mock_function(data, dt.datetime.now()) await task( queue_url=sqs_client.queue_url, @@ -294,7 +369,8 @@ async def countdown_tester(_): visibility_timeout=1, )(countdown_tester)() - assert len(call_times) == 2 + call_times = [arg[1] for arg, _ in async_mock_function.call_args_list] + assert async_mock_function.call_count == 2 assert call_times[1] - call_times[0] >= dt.timedelta(seconds=2) resp = await sqs_client.receive_message() assert 'Messages' not in resp @@ -312,14 +388,12 @@ async def test_concurrency_controller( MessageGroupId=message_id, ) - max_running_tasks = 0 + async_mock_function = AsyncMock() - async def task_counter(_) -> None: - nonlocal max_running_tasks + async def task_counter(data: Dict) -> None: await asyncio.sleep(1) running_tasks = len(await get_running_fast_agave_tasks()) - if running_tasks > max_running_tasks: - max_running_tasks = running_tasks + await async_mock_function(running_tasks) await task( queue_url=sqs_client.queue_url, @@ -330,4 +404,5 @@ async def task_counter(_) -> None: max_concurrent_tasks=2, )(task_counter)() - assert max_running_tasks == 2 + running_tasks = [call[0] for call, _ in async_mock_function.call_args_list] + assert max(running_tasks) == 2