From 58ef2d2538e3a332defefa15d19b1384db9027b1 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Wed, 4 Dec 2024 09:46:03 -0800 Subject: [PATCH] feat: improve aiohttp client error messages (#400) This change is similar to https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/pull/1201, except for AlloyDB. --- google/cloud/alloydb/connector/client.py | 34 +++++- requirements-test.txt | 1 + tests/unit/test_client.py | 137 +++++++++++++++++++++++ 3 files changed, 166 insertions(+), 6 deletions(-) diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index 6edb9ee4..59e923a4 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -124,8 +124,20 @@ async def _get_metadata( url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo" - resp = await self._client.get(url, headers=headers, raise_for_status=True) - resp_dict = await resp.json() + resp = await self._client.get(url, headers=headers) + # try to get response json for better error message + try: + resp_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = resp_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() # Remove trailing period from PSC DNS name. psc_dns = resp_dict.get("pscDnsName") @@ -175,10 +187,20 @@ async def _get_client_certificate( "useMetadataExchange": self._use_metadata, } - resp = await self._client.post( - url, headers=headers, json=data, raise_for_status=True - ) - resp_dict = await resp.json() + resp = await self._client.post(url, headers=headers, json=data) + # try to get response json for better error message + try: + resp_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = resp_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() return (resp_dict["caCert"], resp_dict["pemCertificateChain"]) diff --git a/requirements-test.txt b/requirements-test.txt index 11bf97eb..86ac93ac 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,3 +6,4 @@ pytest-asyncio==0.24.0 pytest-cov==6.0.0 pytest-aiohttp==1.0.5 SQLAlchemy[asyncio]==2.0.36 +aioresponses==0.7.7 diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 3da68079..e4b2fdbb 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,7 +15,9 @@ import json from typing import Any, Optional +from aiohttp import ClientResponseError from aiohttp import web +from aioresponses import aioresponses from mocks import FakeCredentials import pytest @@ -138,6 +140,75 @@ async def test__get_metadata_with_psc( } +async def test__get_metadata_error( + credentials: FakeCredentials, +) -> None: + """ + Test that AlloyDB API error messages are raised for _get_metadata. + """ + # mock AlloyDB API calls with exceptions + client = AlloyDBClient( + alloydb_api_endpoint="https://alloydb.googleapis.com", + quota_project=None, + credentials=credentials, + ) + get_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance/connectionInfo" + resp_body = { + "error": { + "code": 403, + "message": "AlloyDB API has not been used in project 123456789 before or it is disabled", + } + } + with aioresponses() as mocked: + mocked.get( + get_url, + status=403, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_metadata( + "my-project", "my-region", "my-cluster", "my-instance" + ) + assert exc_info.value.status == 403 + assert ( + exc_info.value.message + == "AlloyDB API has not been used in project 123456789 before or it is disabled" + ) + await client.close() + + +async def test__get_metadata_error_parsing_json( + credentials: FakeCredentials, +) -> None: + """ + Test that aiohttp default error messages are raised when _get_metadata gets + a bad JSON response. + """ + # mock AlloyDB API calls with exceptions + client = AlloyDBClient( + alloydb_api_endpoint="https://alloydb.googleapis.com", + quota_project=None, + credentials=credentials, + ) + get_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance/connectionInfo" + resp_body = ["error"] # invalid json + with aioresponses() as mocked: + mocked.get( + get_url, + status=403, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_metadata( + "my-project", "my-region", "my-cluster", "my-instance" + ) + assert exc_info.value.status == 403 + assert exc_info.value.message == "Forbidden" + await client.close() + + @pytest.mark.asyncio async def test__get_client_certificate( client: Any, credentials: FakeCredentials @@ -157,6 +228,72 @@ async def test__get_client_certificate( assert cert_chain[2] == "This is the root cert" +async def test__get_client_certificate_error( + credentials: FakeCredentials, +) -> None: + """ + Test that AlloyDB API error messages are raised for _get_client_certificate. + """ + # mock AlloyDB API calls with exceptions + client = AlloyDBClient( + alloydb_api_endpoint="https://alloydb.googleapis.com", + quota_project=None, + credentials=credentials, + ) + post_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster:generateClientCertificate" + resp_body = { + "error": { + "code": 404, + "message": "The AlloyDB instance does not exist.", + } + } + with aioresponses() as mocked: + mocked.post( + post_url, + status=404, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_client_certificate( + "my-project", "my-region", "my-cluster", "" + ) + assert exc_info.value.status == 404 + assert exc_info.value.message == "The AlloyDB instance does not exist." + await client.close() + + +async def test__get_client_certificate_error_parsing_json( + credentials: FakeCredentials, +) -> None: + """ + Test that aiohttp default error messages are raised when + _get_client_certificate gets a bad JSON response. + """ + # mock AlloyDB API calls with exceptions + client = AlloyDBClient( + alloydb_api_endpoint="https://alloydb.googleapis.com", + quota_project=None, + credentials=credentials, + ) + post_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster:generateClientCertificate" + resp_body = ["error"] # invalid json + with aioresponses() as mocked: + mocked.post( + post_url, + status=404, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_client_certificate( + "my-project", "my-region", "my-cluster", "" + ) + assert exc_info.value.status == 404 + assert exc_info.value.message == "Not Found" + await client.close() + + @pytest.mark.asyncio async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None: """