diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index d1961a6..0086bb0 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -177,8 +177,14 @@ async def connect( # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes(ip_type.upper()) - conn_info = await cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) + try: + conn_info = await cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from AlloyDB Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached(instance_uri) + raise logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") # callable to be used for auto IAM authn @@ -202,6 +208,17 @@ def get_authentication_token() -> str: await cache.force_refresh() raise + async def _remove_cached(self, instance_uri: str) -> None: + """Stops all background refreshes and deletes the connection + info cache from the map of caches. + """ + logger.debug( + f"['{instance_uri}']: Removing connection info from cache" + ) + # remove cache from stored caches and close it + cache = self._cache.pop(instance_uri) + await cache.close() + async def __aenter__(self) -> Any: """Enter async context manager by returning Connector object""" return self diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 2d1bfee..25047ba 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -206,8 +206,14 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes(ip_type.upper()) - conn_info = await cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) + try: + conn_info = await cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from AlloyDB Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached(instance_uri) + raise logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") # synchronous drivers are blocking and run using executor @@ -334,6 +340,17 @@ def metadata_exchange( return sock + async def _remove_cached(self, instance_uri: str) -> None: + """Stops all background refreshes and deletes the connection + info cache from the map of caches. + """ + logger.debug( + f"['{instance_uri}']: Removing connection info from cache" + ) + # remove cache from stored caches and close it + cache = self._cache.pop(instance_uri) + await cache.close() + def __enter__(self) -> "Connector": """Enter context manager by returning Connector object""" return self diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ae60035..d6ca4b6 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from aiohttp import ClientResponseError from datetime import datetime from datetime import timedelta from datetime import timezone @@ -206,7 +207,14 @@ def __init__( self._user_agent = f"test-user-agent+{driver}" self._credentials = FakeCredentials() + i = FakeInstance() + # The instances that currently exist and the client can send API requests to. + self.existing_instances = [f"projects/{i.project}/locations/{i.region}/clusters/{i.cluster}/instances/{i.name}"] + async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: + instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}" + if instance_uri not in self.existing_instances: + raise ClientResponseError(None, 404) return self.instance.ip_addrs async def _get_client_certificate( @@ -216,6 +224,9 @@ async def _get_client_certificate( cluster: str, pub_key: str, ) -> Tuple[str, List[str]]: + instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}" + if instance_uri not in self.existing_instances: + raise ClientResponseError(None, 404) root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs() # encode public key to bytes pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 95b6223..4625053 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -19,10 +19,13 @@ from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeCredentials +from mocks import FakeInstance +from aiohttp import ClientResponseError import pytest from google.cloud.alloydb.connector import Connector from google.cloud.alloydb.connector import IPTypes +from google.cloud.alloydb.connector.instance import RefreshAheadCache def test_Connector_init(credentials: FakeCredentials) -> None: @@ -203,3 +206,54 @@ def test_Connector_close_called_multiple_times(credentials: FakeCredentials) -> assert connector._thread.is_alive() is False # call connector.close a second time connector.close() + + +@pytest.mark.asyncio +async def test_Connector_remove_cached_bad_instance( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: + """When a Connector attempts to retrieve connection info for a + non-existent instance, it should delete the instance from + the cache and ensure no background refresh happens (which would be + wasted cycles). + """ + instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance" + with Connector(credentials) as connector: + connector._client = FakeAlloyDBClient(instance = FakeInstance(name = "bad-test-instance")) + # patch db connection creation + with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + cache = RefreshAheadCache(instance_uri, fake_client, connector._keys) + connector._cache[instance_uri] = cache + with pytest.raises(ClientResponseError): + await connector.connect_async(instance_uri, "pg8000") + assert instance_uri not in connector._cache + + +# def test_Connector_remove_cached_no_ip_type( +# fake_credentials: FakeCredentials, fake_client: FakeAlloyDBClient +# ) -> None: +# """When a Connector attempts to connect and preferred IP type is not present, +# it should delete the instance from the cache and ensure no background refresh +# happens (which would be wasted cycles). +# """ +# # set instance to only have public IP +# fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"} +# async with Connector( +# credentials=fake_credentials, loop=asyncio.get_running_loop() +# ) as connector: +# conn_name = "test-project:test-region:test-instance" +# # populate cache +# cache = RefreshAheadCache(conn_name, fake_client_sync, connector._keys) +# connector._cache[conn_name] = cache +# # test instance does not have Private IP, thus should invalidate cache +# with pytest.raises(CloudSQLIPTypeError): +# await connector.connect_async( +# conn_name, +# "pg8000", +# user="my-user", +# password="my-pass", +# ip_type="private", +# ) +# # check that cache has been removed from dict +# assert conn_name not in connector._cache \ No newline at end of file