Skip to content

Commit

Permalink
add union types validation with pydantic logic
Browse files Browse the repository at this point in the history
  • Loading branch information
felipao-mx committed Oct 13, 2023
1 parent 570d811 commit a0d3831
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 33 deletions.
13 changes: 9 additions & 4 deletions examples/tasks/task_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

from fast_agave.tasks.sqs_tasks import task

Expand All @@ -10,17 +10,22 @@
QUEUE2_URL = 'http://127.0.0.1:4000/123456789012/validator.fifo'


class ValidatorModel(BaseModel):
class User(BaseModel):
name: str
age: int
nick_name: Optional[str]


class Company(BaseModel):
legal_name: str
rfc: str


@task(queue_url=QUEUE_URL, region_name='us-east-1')
async def dummy_task(message) -> None:
print(message)


@task(queue_url=QUEUE2_URL, region_name='us-east-1', validator=ValidatorModel)
async def task_validator(message: ValidatorModel) -> None:
@task(queue_url=QUEUE2_URL, region_name='us-east-1')
async def task_validator(message: Union[User, Company]) -> None:
print(message.dict())
5 changes: 5 additions & 0 deletions fast_agave/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ class FastAgaveViewError(FastAgaveError):
@dataclass
class RetryTask(Exception):
countdown: Optional[int] = None


@dataclass
class TaskDefinitionError(Exception):
error: str
15 changes: 7 additions & 8 deletions fast_agave/tasks/sqs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from functools import wraps
from itertools import count
from json import JSONDecodeError
from typing import AsyncGenerator, Callable, Coroutine, Optional, Type
from typing import AsyncGenerator, Callable, Coroutine

from aiobotocore.httpsession import HTTPClientError
from aiobotocore.session import get_session
from pydantic import BaseModel
from pydantic import BaseModel, validate_arguments

from ..exc import RetryTask

Expand All @@ -25,12 +25,10 @@ async def run_task(
receipt_handle: str,
message_receive_count: int,
max_retries: int,
validator: Optional[Type[BaseModel]] = None,
) -> None:
delete_message = True
try:
data = validator(**body) if validator else body
await task_func(data)
await task_func(body)
except RetryTask as retry:
delete_message = message_receive_count >= max_retries + 1
if not delete_message and retry.countdown and retry.countdown > 0:
Expand Down Expand Up @@ -88,7 +86,6 @@ def task(
visibility_timeout: int = 3600,
max_retries: int = 1,
max_concurrent_tasks: int = 5,
validator: Optional[Type[BaseModel]] = None,
):
def task_builder(task_func: Callable):
@wraps(task_func)
Expand All @@ -108,6 +105,9 @@ async def concurrency_controller(coro: Coroutine) -> None:
can_read.set()

session = get_session()

task_with_validators = validate_arguments(task_func)

async with session.create_client('sqs', region_name) as sqs:
async for message in message_consumer(
queue_url,
Expand All @@ -127,14 +127,13 @@ async def concurrency_controller(coro: Coroutine) -> None:
bg_task = asyncio.create_task(
concurrency_controller(
run_task(
task_func,
task_with_validators,
body,
sqs,
queue_url,
message['ReceiptHandle'],
message_receive_count,
max_retries,
validator,
),
),
name='fast-agave-task',
Expand Down
2 changes: 1 addition & 1 deletion fast_agave/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.13.0'
__version__ = '0.14.0.dev0'
116 changes: 96 additions & 20 deletions tests/tasks/test_sqs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime as dt
import json
import uuid
from typing import Dict, Union
from unittest.mock import AsyncMock, call, patch

import aiobotocore.client
Expand Down Expand Up @@ -32,13 +33,17 @@ async def test_execute_tasks(sqs_client) -> None:
MessageGroupId='1234',
)

async_mock_function = AsyncMock(return_value=None)
async_mock_function = AsyncMock()

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)(async_mock_function)()
)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

Expand All @@ -48,26 +53,28 @@ async def test_execute_tasks(sqs_client) -> None:


@pytest.mark.asyncio
async def test_execute_tasks_validator(sqs_client) -> None:
async_mock_function = AsyncMock(return_value=None)

async def test_execute_tasks_with_validator(sqs_client) -> None:
class Validator(BaseModel):
id: str
name: str

async_mock_function = AsyncMock(return_value=None)

async def my_task(data: Validator) -> None:
await async_mock_function(data)

task_params = dict(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
validator=Validator,
)
# Invalid body, not execute function
await sqs_client.send_message(
MessageBody=json.dumps(dict(foo='bar')),
MessageGroupId='4321',
)
await task(**task_params)(async_mock_function)()
await task(**task_params)(my_task)()
assert async_mock_function.call_count == 0
resp = await sqs_client.receive_message()
assert 'Messages' not in resp
Expand All @@ -78,7 +85,59 @@ class Validator(BaseModel):
MessageBody=test_message.json(),
MessageGroupId='1234',
)
await task(**task_params)(async_mock_function)()
await task(**task_params)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

resp = await sqs_client.receive_message()
assert 'Messages' not in resp
assert len(BACKGROUND_TASKS) == 0


@pytest.mark.asyncio
async def test_execute_tasks_with_union_validator(sqs_client) -> None:
class User(BaseModel):
id: str
name: str

class Company(BaseModel):
id: str
legal_name: str
rfc: str

async_mock_function = AsyncMock(return_value=None)

async def my_task(data: Union[User, Company]) -> None:
await async_mock_function(data)

task_params = dict(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)
# Invalid body, not execute function
test_message = dict(id='ID123', name='Sor Juana Inés de la Cruz')
await sqs_client.send_message(
MessageBody=json.dumps(test_message),
MessageGroupId='4321',
)
await task(**task_params)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

resp = await sqs_client.receive_message()
assert 'Messages' not in resp
assert len(BACKGROUND_TASKS) == 0

async_mock_function.reset_mock()
test_message = dict(id='ID123', legal_name='FastAgave', rfc='FA')

await sqs_client.send_message(
MessageBody=json.dumps(test_message),
MessageGroupId='54321',
)
await task(**task_params)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

Expand All @@ -92,14 +151,18 @@ async def test_not_execute_tasks(sqs_client) -> None:
"""
Este caso es cuando el queue está vacío. No hay nada que ejecutar
"""
async_mock_function = AsyncMock(return_value=None)
async_mock_function = AsyncMock()

async def my_task(data: Dict) -> None:
await async_mock_function(data)

# No escribimos un mensaje en el queue
await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)(async_mock_function)()
)(my_task)()
async_mock_function.assert_not_called()
resp = await sqs_client.receive_message()
assert 'Messages' not in resp
Expand Down Expand Up @@ -141,6 +204,10 @@ async def mock_create_client(*args, **kwargs):
return client

async_mock_function = AsyncMock(return_value=None)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

with patch(
'aiobotocore.client.AioClientCreator.create_client', mock_create_client
):
Expand All @@ -150,7 +217,7 @@ async def mock_create_client(*args, **kwargs):
wait_time_seconds=1,
visibility_timeout=3,
max_retries=1,
)(async_mock_function)()
)(my_task)()
async_mock_function.assert_called_once()


Expand All @@ -175,12 +242,15 @@ async def test_retry_tasks_default_max_retries(sqs_client) -> None:

async_mock_function = AsyncMock(side_effect=RetryTask)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)(async_mock_function)()
)(my_task)()

expected_calls = [call(test_message)] * 2
async_mock_function.assert_has_calls(expected_calls)
Expand All @@ -207,13 +277,16 @@ async def test_retry_tasks_custom_max_retries(sqs_client) -> None:

async_mock_function = AsyncMock(side_effect=RetryTask)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
max_retries=3,
)(async_mock_function)()
)(my_task)()

expected_calls = [call(test_message)] * 4
async_mock_function.assert_has_calls(expected_calls)
Expand Down Expand Up @@ -243,13 +316,16 @@ async def test_does_not_retry_on_unhandled_exceptions(sqs_client) -> None:
side_effect=Exception('something went wrong :(')
)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
max_retries=3,
)(async_mock_function)()
)(my_task)()

async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1
Expand Down Expand Up @@ -281,11 +357,10 @@ async def test_retry_tasks_with_countdown(sqs_client) -> None:
MessageGroupId='1234',
)

call_times = []
async_mock_function = AsyncMock(side_effect=RetryTask(countdown=2))

async def countdown_tester(_):
call_times.append(dt.datetime.now())
raise RetryTask(countdown=2)
async def countdown_tester(data: Dict):
await async_mock_function(data, dt.datetime.now())

await task(
queue_url=sqs_client.queue_url,
Expand All @@ -294,7 +369,8 @@ async def countdown_tester(_):
visibility_timeout=1,
)(countdown_tester)()

assert len(call_times) == 2
call_times = [arg[1] for arg, _ in async_mock_function.call_args_list]
assert async_mock_function.call_count == 2
assert call_times[1] - call_times[0] >= dt.timedelta(seconds=2)
resp = await sqs_client.receive_message()
assert 'Messages' not in resp
Expand All @@ -314,7 +390,7 @@ async def test_concurrency_controller(

max_running_tasks = 0

async def task_counter(_) -> None:
async def task_counter(_: Dict) -> None:
nonlocal max_running_tasks
await asyncio.sleep(1)
running_tasks = len(await get_running_fast_agave_tasks())
Expand Down

0 comments on commit a0d3831

Please sign in to comment.