Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add union types validation with pydantic #193

Merged
merged 6 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions examples/tasks/task_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

from fast_agave.tasks.sqs_tasks import task

Expand All @@ -10,17 +10,22 @@
QUEUE2_URL = 'http://127.0.0.1:4000/123456789012/validator.fifo'


class ValidatorModel(BaseModel):
class User(BaseModel):
name: str
age: int
nick_name: Optional[str]


class Company(BaseModel):
legal_name: str
rfc: str


@task(queue_url=QUEUE_URL, region_name='us-east-1')
async def dummy_task(message) -> None:
print(message)


@task(queue_url=QUEUE2_URL, region_name='us-east-1', validator=ValidatorModel)
async def task_validator(message: ValidatorModel) -> None:
@task(queue_url=QUEUE2_URL, region_name='us-east-1')
async def task_validator(message: Union[User, Company]) -> None:
print(message.dict())
15 changes: 7 additions & 8 deletions fast_agave/tasks/sqs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from functools import wraps
from itertools import count
from json import JSONDecodeError
from typing import AsyncGenerator, Callable, Coroutine, Optional, Type
from typing import AsyncGenerator, Callable, Coroutine

from aiobotocore.httpsession import HTTPClientError
from aiobotocore.session import get_session
from pydantic import BaseModel
from pydantic import validate_arguments

from ..exc import RetryTask

Expand All @@ -25,12 +25,10 @@ async def run_task(
receipt_handle: str,
message_receive_count: int,
max_retries: int,
validator: Optional[Type[BaseModel]] = None,
) -> None:
delete_message = True
try:
data = validator(**body) if validator else body
await task_func(data)
await task_func(body)
except RetryTask as retry:
delete_message = message_receive_count >= max_retries + 1
if not delete_message and retry.countdown and retry.countdown > 0:
Expand Down Expand Up @@ -88,7 +86,6 @@ def task(
visibility_timeout: int = 3600,
max_retries: int = 1,
max_concurrent_tasks: int = 5,
validator: Optional[Type[BaseModel]] = None,
):
def task_builder(task_func: Callable):
@wraps(task_func)
Expand All @@ -108,6 +105,9 @@ async def concurrency_controller(coro: Coroutine) -> None:
can_read.set()

session = get_session()

task_with_validators = validate_arguments(task_func)

async with session.create_client('sqs', region_name) as sqs:
async for message in message_consumer(
queue_url,
Expand All @@ -127,14 +127,13 @@ async def concurrency_controller(coro: Coroutine) -> None:
bg_task = asyncio.create_task(
concurrency_controller(
run_task(
task_func,
task_with_validators,
body,
sqs,
queue_url,
message['ReceiptHandle'],
message_receive_count,
max_retries,
validator,
),
),
name='fast-agave-task',
Expand Down
2 changes: 1 addition & 1 deletion fast_agave/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.13.0'
__version__ = '0.14.0'
125 changes: 100 additions & 25 deletions tests/tasks/test_sqs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime as dt
import json
import uuid
from typing import Dict, Union
from unittest.mock import AsyncMock, call, patch

import aiobotocore.client
Expand Down Expand Up @@ -32,13 +33,17 @@ async def test_execute_tasks(sqs_client) -> None:
MessageGroupId='1234',
)

async_mock_function = AsyncMock(return_value=None)
async_mock_function = AsyncMock()

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)(async_mock_function)()
)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

Expand All @@ -48,26 +53,28 @@ async def test_execute_tasks(sqs_client) -> None:


@pytest.mark.asyncio
async def test_execute_tasks_validator(sqs_client) -> None:
async_mock_function = AsyncMock(return_value=None)

async def test_execute_tasks_with_validator(sqs_client) -> None:
class Validator(BaseModel):
id: str
name: str

async_mock_function = AsyncMock(return_value=None)

async def my_task(data: Validator) -> None:
await async_mock_function(data)

task_params = dict(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
validator=Validator,
)
# Invalid body, not execute function
await sqs_client.send_message(
MessageBody=json.dumps(dict(foo='bar')),
MessageGroupId='4321',
)
await task(**task_params)(async_mock_function)()
await task(**task_params)(my_task)()
assert async_mock_function.call_count == 0
resp = await sqs_client.receive_message()
assert 'Messages' not in resp
Expand All @@ -78,7 +85,59 @@ class Validator(BaseModel):
MessageBody=test_message.json(),
MessageGroupId='1234',
)
await task(**task_params)(async_mock_function)()
await task(**task_params)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

resp = await sqs_client.receive_message()
assert 'Messages' not in resp
assert len(BACKGROUND_TASKS) == 0


@pytest.mark.asyncio
async def test_execute_tasks_with_union_validator(sqs_client) -> None:
class User(BaseModel):
id: str
name: str

class Company(BaseModel):
id: str
legal_name: str
rfc: str

async_mock_function = AsyncMock(return_value=None)

async def my_task(data: Union[User, Company]) -> None:
await async_mock_function(data)

task_params = dict(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)
# Invalid body, not execute function
test_message = dict(id='ID123', name='Sor Juana Inés de la Cruz')
await sqs_client.send_message(
MessageBody=json.dumps(test_message),
MessageGroupId='4321',
)
await task(**task_params)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

resp = await sqs_client.receive_message()
assert 'Messages' not in resp
assert len(BACKGROUND_TASKS) == 0

async_mock_function.reset_mock()
test_message = dict(id='ID123', legal_name='FastAgave', rfc='FA')

await sqs_client.send_message(
MessageBody=json.dumps(test_message),
MessageGroupId='54321',
)
await task(**task_params)(my_task)()
async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1

Expand All @@ -92,14 +151,18 @@ async def test_not_execute_tasks(sqs_client) -> None:
"""
Este caso es cuando el queue está vacío. No hay nada que ejecutar
"""
async_mock_function = AsyncMock(return_value=None)
async_mock_function = AsyncMock()

async def my_task(data: Dict) -> None:
await async_mock_function(data)

# No escribimos un mensaje en el queue
await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)(async_mock_function)()
)(my_task)()
async_mock_function.assert_not_called()
resp = await sqs_client.receive_message()
assert 'Messages' not in resp
Expand Down Expand Up @@ -141,6 +204,10 @@ async def mock_create_client(*args, **kwargs):
return client

async_mock_function = AsyncMock(return_value=None)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

with patch(
'aiobotocore.client.AioClientCreator.create_client', mock_create_client
):
Expand All @@ -150,7 +217,7 @@ async def mock_create_client(*args, **kwargs):
wait_time_seconds=1,
visibility_timeout=3,
max_retries=1,
)(async_mock_function)()
)(my_task)()
async_mock_function.assert_called_once()


Expand All @@ -175,12 +242,15 @@ async def test_retry_tasks_default_max_retries(sqs_client) -> None:

async_mock_function = AsyncMock(side_effect=RetryTask)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
)(async_mock_function)()
)(my_task)()

expected_calls = [call(test_message)] * 2
async_mock_function.assert_has_calls(expected_calls)
Expand All @@ -207,13 +277,16 @@ async def test_retry_tasks_custom_max_retries(sqs_client) -> None:

async_mock_function = AsyncMock(side_effect=RetryTask)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
max_retries=3,
)(async_mock_function)()
)(my_task)()

expected_calls = [call(test_message)] * 4
async_mock_function.assert_has_calls(expected_calls)
Expand Down Expand Up @@ -243,13 +316,16 @@ async def test_does_not_retry_on_unhandled_exceptions(sqs_client) -> None:
side_effect=Exception('something went wrong :(')
)

async def my_task(data: Dict) -> None:
await async_mock_function(data)

await task(
queue_url=sqs_client.queue_url,
region_name=CORE_QUEUE_REGION,
wait_time_seconds=1,
visibility_timeout=1,
max_retries=3,
)(async_mock_function)()
)(my_task)()

async_mock_function.assert_called_with(test_message)
assert async_mock_function.call_count == 1
Expand Down Expand Up @@ -281,11 +357,10 @@ async def test_retry_tasks_with_countdown(sqs_client) -> None:
MessageGroupId='1234',
)

call_times = []
async_mock_function = AsyncMock(side_effect=RetryTask(countdown=2))

async def countdown_tester(_):
call_times.append(dt.datetime.now())
raise RetryTask(countdown=2)
async def countdown_tester(data: Dict):
await async_mock_function(data, dt.datetime.now())

await task(
queue_url=sqs_client.queue_url,
Expand All @@ -294,7 +369,8 @@ async def countdown_tester(_):
visibility_timeout=1,
)(countdown_tester)()

assert len(call_times) == 2
call_times = [arg[1] for arg, _ in async_mock_function.call_args_list]
assert async_mock_function.call_count == 2
assert call_times[1] - call_times[0] >= dt.timedelta(seconds=2)
resp = await sqs_client.receive_message()
assert 'Messages' not in resp
Expand All @@ -312,14 +388,12 @@ async def test_concurrency_controller(
MessageGroupId=message_id,
)

max_running_tasks = 0
async_mock_function = AsyncMock()

async def task_counter(_) -> None:
nonlocal max_running_tasks
async def task_counter(data: Dict) -> None:
await asyncio.sleep(1)
running_tasks = len(await get_running_fast_agave_tasks())
if running_tasks > max_running_tasks:
max_running_tasks = running_tasks
await async_mock_function(running_tasks)

await task(
queue_url=sqs_client.queue_url,
Expand All @@ -330,4 +404,5 @@ async def task_counter(_) -> None:
max_concurrent_tasks=2,
)(task_counter)()

assert max_running_tasks == 2
running_tasks = [call[0] for call, _ in async_mock_function.call_args_list]
assert max(running_tasks) == 2
Loading