diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index d1961a6..fba7488 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -177,8 +177,14 @@ async def connect( # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes(ip_type.upper()) - conn_info = await cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) + try: + conn_info = await cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from AlloyDB API call or IP type, invalidate the + # cache and re-raise the error + await self._remove_cached(instance_uri) + raise logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") # callable to be used for auto IAM authn @@ -202,6 +208,15 @@ def get_authentication_token() -> str: await cache.force_refresh() raise + async def _remove_cached(self, instance_uri: str) -> None: + """Stops all background refreshes and deletes the connection + info cache from the map of caches. + """ + logger.debug(f"['{instance_uri}']: Removing connection info from cache") + # remove cache from stored caches and close it + cache = self._cache.pop(instance_uri) + await cache.close() + async def __aenter__(self) -> Any: """Enter async context manager by returning Connector object""" return self diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 2d1bfee..c4ad299 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -206,8 +206,14 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes(ip_type.upper()) - conn_info = await cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) + try: + conn_info = await cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from AlloyDB API call or IP type, invalidate the + # cache and re-raise the error + await self._remove_cached(instance_uri) + raise logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") # synchronous drivers are blocking and run using executor @@ -334,6 +340,15 @@ def metadata_exchange( return sock + async def _remove_cached(self, instance_uri: str) -> None: + """Stops all background refreshes and deletes the connection + info cache from the map of caches. + """ + logger.debug(f"['{instance_uri}']: Removing connection info from cache") + # remove cache from stored caches and close it + cache = self._cache.pop(instance_uri) + await cache.close() + def __enter__(self) -> "Connector": """Enter context manager by returning Connector object""" return self diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index e2b22b1..0f15087 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -15,6 +15,7 @@ import asyncio from typing import Union +from aiohttp import ClientResponseError from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeConnectionInfo @@ -23,6 +24,8 @@ from google.cloud.alloydb.connector import AsyncConnector from google.cloud.alloydb.connector import IPTypes +from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError +from google.cloud.alloydb.connector.instance import RefreshAheadCache ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com" @@ -294,3 +297,39 @@ async def test_async_connect_bad_ip_type( exc_info.value.args[0] == f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE', 'PSC'." ) + + +async def test_Connector_remove_cached_bad_instance( + credentials: FakeCredentials, +) -> None: + """When a Connector attempts to retrieve connection info for a + non-existent instance, it should delete the instance from + the cache and ensure no background refresh happens (which would be + wasted cycles). + """ + instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance" + async with AsyncConnector(credentials=credentials) as connector: + with pytest.raises(ClientResponseError): + await connector.connect(instance_uri, "asyncpg") + assert instance_uri not in connector._cache + + +async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) -> None: + """When a Connector attempts to connect and preferred IP type is not present, + it should delete the instance from the cache and ensure no background refresh + happens (which would be wasted cycles). + """ + instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance" + # set instance to only have Public IP + fake_client = FakeAlloyDBClient() + fake_client.instance.ip_addrs = {"PUBLIC": "127.0.0.1"} + async with AsyncConnector(credentials=credentials) as connector: + connector._client = fake_client + # populate cache + cache = RefreshAheadCache(instance_uri, fake_client, connector._keys) + connector._cache[instance_uri] = cache + # test instance does not have Private IP, thus should invalidate cache + with pytest.raises(IPTypeNotFoundError): + await connector.connect(instance_uri, "asyncpg", ip_type="private") + # check that cache has been removed from dict + assert instance_uri not in connector._cache diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 95b6223..6afc366 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,6 +16,7 @@ from threading import Thread from typing import Union +from aiohttp import ClientResponseError from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeCredentials @@ -23,6 +24,9 @@ from google.cloud.alloydb.connector import Connector from google.cloud.alloydb.connector import IPTypes +from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError +from google.cloud.alloydb.connector.instance import RefreshAheadCache +from google.cloud.alloydb.connector.utils import generate_keys def test_Connector_init(credentials: FakeCredentials) -> None: @@ -203,3 +207,50 @@ def test_Connector_close_called_multiple_times(credentials: FakeCredentials) -> assert connector._thread.is_alive() is False # call connector.close a second time connector.close() + + +async def test_Connector_remove_cached_bad_instance( + credentials: FakeCredentials, +) -> None: + """When a Connector attempts to retrieve connection info for a + non-existent instance, it should delete the instance from + the cache and ensure no background refresh happens (which would be + wasted cycles). + """ + instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance" + with Connector(credentials) as connector: + connector._keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + generate_keys(), asyncio.get_running_loop() + ), + loop=asyncio.get_running_loop(), + ) + with pytest.raises(ClientResponseError): + await connector.connect_async(instance_uri, "pg8000") + assert instance_uri not in connector._cache + + +async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) -> None: + """When a Connector attempts to connect and preferred IP type is not present, + it should delete the instance from the cache and ensure no background refresh + happens (which would be wasted cycles). + """ + instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance" + # set instance to only have Public IP + fake_client = FakeAlloyDBClient() + fake_client.instance.ip_addrs = {"PUBLIC": "127.0.0.1"} + with Connector(credentials=credentials) as connector: + connector._client = fake_client + connector._keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + generate_keys(), asyncio.get_running_loop() + ), + loop=asyncio.get_running_loop(), + ) + cache = RefreshAheadCache(instance_uri, fake_client, connector._keys) + connector._cache[instance_uri] = cache + # test instance does not have Private IP, thus should invalidate cache + with pytest.raises(IPTypeNotFoundError): + await connector.connect_async(instance_uri, "pg8000", ip_type="private") + # check that cache has been removed from dict + assert instance_uri not in connector._cache