diff --git a/src/integrations/prefect-aws/prefect_aws/client_waiter.py b/src/integrations/prefect-aws/prefect_aws/client_waiter.py index db213b3eac99..4717326c9f29 100644 --- a/src/integrations/prefect-aws/prefect_aws/client_waiter.py +++ b/src/integrations/prefect-aws/prefect_aws/client_waiter.py @@ -5,14 +5,79 @@ from botocore.waiter import WaiterModel, create_waiter_with_client from prefect import task +from prefect._internal.compatibility.async_dispatch import async_dispatch from prefect.logging import get_run_logger -from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible +from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_aws.credentials import AwsCredentials @task -@sync_compatible -async def client_waiter( +async def aclient_waiter( + client: str, + waiter_name: str, + aws_credentials: AwsCredentials, + waiter_definition: Optional[Dict[str, Any]] = None, + **waiter_kwargs: Optional[Dict[str, Any]], +): + """ + Asynchronously uses the underlying boto3 waiter functionality. + + Args: + client: The AWS client on which to wait (e.g., 'client_wait', 'ec2', etc). + waiter_name: The name of the waiter to instantiate. + You may also use a custom waiter name, if you supply + an accompanying waiter definition dict. + aws_credentials: Credentials to use for authentication with AWS. + waiter_definition: A valid custom waiter model, as a dict. Note that if + you supply a custom definition, it is assumed that the provided + 'waiter_name' is contained within the waiter definition dict. + **waiter_kwargs: Arguments to pass to the `waiter.wait(...)` method. Will + depend upon the specific waiter being called. + + Example: + Run an ec2 waiter until instance_exists. + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.client_wait import aclient_waiter + + @flow + async def example_client_wait_flow(): + aws_credentials = AwsCredentials( + aws_access_key_id="acccess_key_id", + aws_secret_access_key="secret_access_key" + ) + + await aclient_waiter( + "ec2", + "instance_exists", + aws_credentials + ) + ``` + """ + logger = get_run_logger() + logger.info("Waiting on %s job", client) + + boto_client = aws_credentials.get_boto3_session().client(client) + + if waiter_definition is not None: + # Use user-provided waiter definition + waiter_model = WaiterModel(waiter_definition) + waiter = create_waiter_with_client(waiter_name, waiter_model, boto_client) + elif waiter_name in boto_client.waiter_names: + waiter = boto_client.get_waiter(waiter_name) + else: + raise ValueError( + f"The waiter name, {waiter_name}, is not a valid boto waiter; " + "if using a custom waiter, you must provide a waiter definition" + ) + + await run_sync_in_worker_thread(waiter.wait, **waiter_kwargs) + + +@task +@async_dispatch(aclient_waiter) +def client_waiter( client: str, waiter_name: str, aws_credentials: AwsCredentials, @@ -48,14 +113,11 @@ def example_client_wait_flow(): aws_secret_access_key="secret_access_key" ) - waiter = client_waiter( + client_waiter( "ec2", "instance_exists", aws_credentials ) - - return waiter - example_client_wait_flow() ``` """ logger = get_run_logger() @@ -75,4 +137,4 @@ def example_client_wait_flow(): "if using a custom waiter, you must provide a waiter definition" ) - await run_sync_in_worker_thread(waiter.wait, **waiter_kwargs) + waiter.wait(**waiter_kwargs) diff --git a/src/integrations/prefect-aws/tests/test_client_waiter.py b/src/integrations/prefect-aws/tests/test_client_waiter.py index 6981f0182ba1..c545ee65cb34 100644 --- a/src/integrations/prefect-aws/tests/test_client_waiter.py +++ b/src/integrations/prefect-aws/tests/test_client_waiter.py @@ -2,7 +2,7 @@ import pytest from moto import mock_ec2 -from prefect_aws.client_waiter import client_waiter +from prefect_aws.client_waiter import aclient_waiter, client_waiter from prefect import flow @@ -31,39 +31,92 @@ def mock_client(monkeypatch, mock_waiter): return client_creator_mock -@mock_ec2 -def test_client_waiter_custom(mock_waiter, aws_credentials): - @flow - def test_flow(): - waiter = client_waiter( - "batch", - "JobExists", - aws_credentials, - waiter_definition={"waiters": {"JobExists": ["definition"]}, "version": 2}, - ) - return waiter +class TestClientWaiter: + @mock_ec2 + def test_client_waiter_custom(self, mock_waiter, aws_credentials): + @flow + def test_flow(): + return client_waiter( + "batch", + "JobExists", + aws_credentials, + waiter_definition={ + "waiters": {"JobExists": ["definition"]}, + "version": 2, + }, + ) + + test_flow() + mock_waiter().wait.assert_called_once_with() - test_flow() - mock_waiter().wait.assert_called_once_with() + @mock_ec2 + def test_client_waiter_custom_no_definition(self, mock_waiter, aws_credentials): + @flow + def test_flow(): + return client_waiter("batch", "JobExists", aws_credentials) + with pytest.raises(ValueError, match="The waiter name, JobExists"): + test_flow() -@mock_ec2 -def test_client_waiter_custom_no_definition(mock_waiter, aws_credentials): - @flow - def test_flow(): - waiter = client_waiter("batch", "JobExists", aws_credentials) - return waiter + @mock_ec2 + def test_client_waiter_boto(self, mock_waiter, mock_client, aws_credentials): + @flow + def test_flow(): + return client_waiter("ec2", "instance_exists", aws_credentials) - with pytest.raises(ValueError, match="The waiter name, JobExists"): test_flow() + mock_waiter.wait.assert_called_once_with() + + async def test_client_waiter_async_dispatch( + self, mock_waiter, mock_client, aws_credentials + ): + @flow + async def test_flow(): + return await client_waiter("ec2", "instance_exists", aws_credentials) + + await test_flow() + mock_waiter.wait.assert_called_once_with() + + async def test_client_waiter_force_sync_from_async( + self, mock_waiter, mock_client, aws_credentials + ): + client_waiter("ec2", "instance_exists", aws_credentials, _sync=True) + mock_waiter.wait.assert_called_once_with() + + +class TestClientWaiterAsync: + async def test_client_waiter_explicit_async( + self, mock_waiter, mock_client, aws_credentials + ): + @flow + async def test_flow(): + return await aclient_waiter("ec2", "instance_exists", aws_credentials) + + await test_flow() + mock_waiter.wait.assert_called_once_with() + + async def test_aclient_waiter_custom(self, mock_waiter, aws_credentials): + @flow + async def test_flow(): + return await aclient_waiter( + "batch", + "JobExists", + aws_credentials, + waiter_definition={ + "waiters": {"JobExists": ["definition"]}, + "version": 2, + }, + ) + await test_flow() + mock_waiter().wait.assert_called_once_with() -@mock_ec2 -def test_client_waiter_boto(mock_waiter, mock_client, aws_credentials): - @flow - def test_flow(): - waiter = client_waiter("ec2", "instance_exists", aws_credentials) - return waiter + async def test_aclient_waiter_custom_no_definition( + self, mock_waiter, aws_credentials + ): + @flow + async def test_flow(): + return await aclient_waiter("batch", "JobExists", aws_credentials) - test_flow() - mock_waiter.wait.assert_called_once_with() + with pytest.raises(ValueError, match="The waiter name, JobExists"): + await test_flow()