Skip to content

Commit

Permalink
migrate prefect_aws.client_waiter
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Nov 21, 2024
1 parent 567138f commit 0a976be
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 37 deletions.
78 changes: 70 additions & 8 deletions src/integrations/prefect-aws/prefect_aws/client_waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,79 @@
from botocore.waiter import WaiterModel, create_waiter_with_client

from prefect import task
from prefect._internal.compatibility.async_dispatch import async_dispatch
from prefect.logging import get_run_logger
from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect_aws.credentials import AwsCredentials


@task
@sync_compatible
async def client_waiter(
async def aclient_waiter(
client: str,
waiter_name: str,
aws_credentials: AwsCredentials,
waiter_definition: Optional[Dict[str, Any]] = None,
**waiter_kwargs: Optional[Dict[str, Any]],
):
"""
Asynchronously uses the underlying boto3 waiter functionality.
Args:
client: The AWS client on which to wait (e.g., 'client_wait', 'ec2', etc).
waiter_name: The name of the waiter to instantiate.
You may also use a custom waiter name, if you supply
an accompanying waiter definition dict.
aws_credentials: Credentials to use for authentication with AWS.
waiter_definition: A valid custom waiter model, as a dict. Note that if
you supply a custom definition, it is assumed that the provided
'waiter_name' is contained within the waiter definition dict.
**waiter_kwargs: Arguments to pass to the `waiter.wait(...)` method. Will
depend upon the specific waiter being called.
Example:
Run an ec2 waiter until instance_exists.
```python
from prefect import flow
from prefect_aws import AwsCredentials
from prefect_aws.client_wait import aclient_waiter
@flow
async def example_client_wait_flow():
aws_credentials = AwsCredentials(
aws_access_key_id="acccess_key_id",
aws_secret_access_key="secret_access_key"
)
await aclient_waiter(
"ec2",
"instance_exists",
aws_credentials
)
```
"""
logger = get_run_logger()
logger.info("Waiting on %s job", client)

boto_client = aws_credentials.get_boto3_session().client(client)

if waiter_definition is not None:
# Use user-provided waiter definition
waiter_model = WaiterModel(waiter_definition)
waiter = create_waiter_with_client(waiter_name, waiter_model, boto_client)
elif waiter_name in boto_client.waiter_names:
waiter = boto_client.get_waiter(waiter_name)
else:
raise ValueError(
f"The waiter name, {waiter_name}, is not a valid boto waiter; "
"if using a custom waiter, you must provide a waiter definition"
)

await run_sync_in_worker_thread(waiter.wait, **waiter_kwargs)


@task
@async_dispatch(aclient_waiter)
def client_waiter(
client: str,
waiter_name: str,
aws_credentials: AwsCredentials,
Expand Down Expand Up @@ -48,14 +113,11 @@ def example_client_wait_flow():
aws_secret_access_key="secret_access_key"
)
waiter = client_waiter(
client_waiter(
"ec2",
"instance_exists",
aws_credentials
)
return waiter
example_client_wait_flow()
```
"""
logger = get_run_logger()
Expand All @@ -75,4 +137,4 @@ def example_client_wait_flow():
"if using a custom waiter, you must provide a waiter definition"
)

await run_sync_in_worker_thread(waiter.wait, **waiter_kwargs)
waiter.wait(**waiter_kwargs)
111 changes: 82 additions & 29 deletions src/integrations/prefect-aws/tests/test_client_waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from moto import mock_ec2
from prefect_aws.client_waiter import client_waiter
from prefect_aws.client_waiter import aclient_waiter, client_waiter

from prefect import flow

Expand Down Expand Up @@ -31,39 +31,92 @@ def mock_client(monkeypatch, mock_waiter):
return client_creator_mock


@mock_ec2
def test_client_waiter_custom(mock_waiter, aws_credentials):
@flow
def test_flow():
waiter = client_waiter(
"batch",
"JobExists",
aws_credentials,
waiter_definition={"waiters": {"JobExists": ["definition"]}, "version": 2},
)
return waiter
class TestClientWaiter:
@mock_ec2
def test_client_waiter_custom(self, mock_waiter, aws_credentials):
@flow
def test_flow():
return client_waiter(
"batch",
"JobExists",
aws_credentials,
waiter_definition={
"waiters": {"JobExists": ["definition"]},
"version": 2,
},
)

test_flow()
mock_waiter().wait.assert_called_once_with()

test_flow()
mock_waiter().wait.assert_called_once_with()
@mock_ec2
def test_client_waiter_custom_no_definition(self, mock_waiter, aws_credentials):
@flow
def test_flow():
return client_waiter("batch", "JobExists", aws_credentials)

with pytest.raises(ValueError, match="The waiter name, JobExists"):
test_flow()

@mock_ec2
def test_client_waiter_custom_no_definition(mock_waiter, aws_credentials):
@flow
def test_flow():
waiter = client_waiter("batch", "JobExists", aws_credentials)
return waiter
@mock_ec2
def test_client_waiter_boto(self, mock_waiter, mock_client, aws_credentials):
@flow
def test_flow():
return client_waiter("ec2", "instance_exists", aws_credentials)

with pytest.raises(ValueError, match="The waiter name, JobExists"):
test_flow()
mock_waiter.wait.assert_called_once_with()

async def test_client_waiter_async_dispatch(
self, mock_waiter, mock_client, aws_credentials
):
@flow
async def test_flow():
return await client_waiter("ec2", "instance_exists", aws_credentials)

await test_flow()
mock_waiter.wait.assert_called_once_with()

async def test_client_waiter_force_sync_from_async(
self, mock_waiter, mock_client, aws_credentials
):
client_waiter("ec2", "instance_exists", aws_credentials, _sync=True)
mock_waiter.wait.assert_called_once_with()


class TestClientWaiterAsync:
async def test_client_waiter_explicit_async(
self, mock_waiter, mock_client, aws_credentials
):
@flow
async def test_flow():
return await aclient_waiter("ec2", "instance_exists", aws_credentials)

await test_flow()
mock_waiter.wait.assert_called_once_with()

async def test_aclient_waiter_custom(self, mock_waiter, aws_credentials):
@flow
async def test_flow():
return await aclient_waiter(
"batch",
"JobExists",
aws_credentials,
waiter_definition={
"waiters": {"JobExists": ["definition"]},
"version": 2,
},
)

await test_flow()
mock_waiter().wait.assert_called_once_with()

@mock_ec2
def test_client_waiter_boto(mock_waiter, mock_client, aws_credentials):
@flow
def test_flow():
waiter = client_waiter("ec2", "instance_exists", aws_credentials)
return waiter
async def test_aclient_waiter_custom_no_definition(
self, mock_waiter, aws_credentials
):
@flow
async def test_flow():
return await aclient_waiter("batch", "JobExists", aws_credentials)

test_flow()
mock_waiter.wait.assert_called_once_with()
with pytest.raises(ValueError, match="The waiter name, JobExists"):
await test_flow()

0 comments on commit 0a976be

Please sign in to comment.