From d14617bf01383cd61846ecb39cf3be44fd20c89a Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Thu, 11 Jan 2024 12:17:24 -0700 Subject: [PATCH] feat: add support for asyncpg (#199) Co-authored-by: Jack Wotherspoon --- README.md | 130 +++++++++++- google/cloud/alloydb/connector/__init__.py | 3 +- .../alloydb/connector/async_connector.py | 153 ++++++++++++++ google/cloud/alloydb/connector/asyncpg.py | 63 ++++++ mypy.ini | 3 + tests/system/test_asyncpg_connection.py | 47 +++++ tests/unit/mocks.py | 23 ++- tests/unit/test_async_connector.py | 193 ++++++++++++++++++ 8 files changed, 610 insertions(+), 5 deletions(-) create mode 100644 google/cloud/alloydb/connector/async_connector.py create mode 100644 google/cloud/alloydb/connector/asyncpg.py create mode 100644 tests/system/test_asyncpg_connection.py create mode 100644 tests/unit/test_async_connector.py diff --git a/README.md b/README.md index a5e848f..5e9a22f 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,22 @@ Currently supported drivers are: You can install this library with `pip install`: +### pg8000 + ```sh pip install "google-cloud-alloydb-connector[pg8000]" ``` +See [Synchronous Driver Usage](#synchronous-driver-usage) for details. + +### asyncpg + +```sh +pip install "google-cloud-alloydb-connector[asyncpg]" +``` + +See [Async Driver Usage](#async-driver-usage) for details. + ### APIs and Services This package requires the following to connect successfully: @@ -70,7 +82,7 @@ This package provides several functions for authorizing and encrypting connections. These functions are used with your database driver to connect to your AlloyDB instance. -AlloyDB supports network connectivity through private, internal IP addresses only. +AlloyDB supports network connectivity through private, internal IP addresses only. This package must be run in an environment that is connected to the [VPC Network][vpc] that hosts your AlloyDB private IP address. @@ -79,7 +91,7 @@ Please see [Configuring AlloyDB Connectivity][alloydb-connectivity] for more det [vpc]: https://cloud.google.com/vpc/docs/vpc [alloydb-connectivity]: https://cloud.google.com/alloydb/docs/configure-connectivity -### How to use this Connector +### Synchronous Driver Usage To connect to AlloyDB using the connector, inititalize a `Connector` object and call it's `connect` method with the proper input parameters. @@ -151,7 +163,7 @@ To close the `Connector` object's background resources, call it's `close()` meth connector.close() ``` -### Using Connector as a Context Manager +### Synchronous Context Manager The `Connector` object can also be used as a context manager in order to automatically close and cleanup resources, removing the need for explicit @@ -202,6 +214,118 @@ with pool.connect() as db_conn: print(row) ``` +### Async Driver Usage + +The AlloyDB Connector is compatible with [asyncio][] to improve the speed and +efficiency of database connections through concurrency. The `AsyncConnector` +currently supports the following asyncio database drivers: + +- [asyncpg](https://magicstack.github.io/asyncpg) + +[asyncio]: https://docs.python.org/3/library/asyncio.html + +```python +import asyncpg + +import sqlalchemy +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from google.cloud.alloydb.connector import AsyncConnector + +async def init_connection_pool(connector: AsyncConnector) -> AsyncEngine: + # initialize Connector object for connections to AlloyDB + async def getconn() -> asyncpg.Connection: + conn: asyncpg.Connection = await connector.connect( + "projects//locations//clusters//instances/", + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ) + return conn + + # The AlloyDB Python Connector can be used along with SQLAlchemy using the + # 'async_creator' argument to 'create_async_engine' + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + return pool + +async def main(): + connector = AsyncConnector() + + # initialize connection pool + pool = await init_connection_pool(connector) + + # example query + async with pool.connect() as conn: + await conn.execute(sqlalchemy.text("SELECT NOW()")) + + # dispose of connection pool + await pool.dispose() + + # close Connector + await connector.close() + +``` + +For more details on additional arguments with an `asyncpg.Connection`, please +visit the [official documentation][asyncpg-docs]. + + +[asyncpg-docs]: https://magicstack.github.io/asyncpg/current/api/index.html + +### Async Context Manager + +The `AsyncConnector` also may be used as an async context manager, removing the +need for explicit calls to `connector.close()` to cleanup resources. + +```python +import asyncio +import asyncpg + +import sqlalchemy +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from google.cloud.alloydb.connector import AsyncConnector + +async def init_connection_pool(connector: AsyncConnector) -> AsyncEngine: + # initialize Connector object for connections to AlloyDB + async def getconn() -> asyncpg.Connection: + conn: asyncpg.Connection = await connector.connect( + "projects//locations//clusters//instances/", + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ) + return conn + + # The AlloyDB Python Connector can be used along with SQLAlchemy using the + # 'async_creator' argument to 'create_async_engine' + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + return pool + +async def main(): + # initialize Connector object for connections to AlloyDB + async with AsyncConnector() as connector: + # initialize connection pool + pool = await init_connection_pool(connector) + + # example query + async with pool.connect() as conn: + await conn.execute(sqlalchemy.text("SELECT NOW()")) + + # dispose of connection pool + await pool.dispose() +``` + ## Support policy ### Major version lifecycle diff --git a/google/cloud/alloydb/connector/__init__.py b/google/cloud/alloydb/connector/__init__.py index c4e0891..7ee1498 100644 --- a/google/cloud/alloydb/connector/__init__.py +++ b/google/cloud/alloydb/connector/__init__.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud.alloydb.connector.async_connector import AsyncConnector from google.cloud.alloydb.connector.connector import Connector from google.cloud.alloydb.connector.version import __version__ -__all__ = ["__version__", "Connector"] +__all__ = ["__version__", "Connector", "AsyncConnector"] diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py new file mode 100644 index 0000000..1c36801 --- /dev/null +++ b/google/cloud/alloydb/connector/async_connector.py @@ -0,0 +1,153 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from types import TracebackType +from typing import Any, Dict, Optional, Type, TYPE_CHECKING + +from google.auth import default +from google.auth.credentials import with_scopes_if_required + +import google.cloud.alloydb.connector.asyncpg as asyncpg +from google.cloud.alloydb.connector.client import AlloyDBClient +from google.cloud.alloydb.connector.instance import Instance +from google.cloud.alloydb.connector.utils import generate_keys + +if TYPE_CHECKING: + from google.auth.credentials import Credentials + + +class AsyncConnector: + """A class to configure and create connections to Cloud SQL instances + asynchronously. + + Args: + credentials (google.auth.credentials.Credentials): + A credentials object created from the google-auth Python library. + If not specified, Application Default Credentials are used. + quota_project (str): The Project ID for an existing Google Cloud + project. The project specified is used for quota and + billing purposes. + Defaults to None, picking up project from environment. + alloydb_api_endpoint (str): Base URL to use when calling + the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". + """ + + def __init__( + self, + credentials: Optional[Credentials] = None, + quota_project: Optional[str] = None, + alloydb_api_endpoint: str = "https://alloydb.googleapis.com", + ) -> None: + self._instances: Dict[str, Instance] = {} + # initialize default params + self._quota_project = quota_project + self._alloydb_api_endpoint = alloydb_api_endpoint + # initialize credentials + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + if credentials: + self._credentials = with_scopes_if_required(credentials, scopes=scopes) + # otherwise use application default credentials + else: + self._credentials, _ = default(scopes=scopes) + self._keys = asyncio.create_task(generate_keys()) + self._client: Optional[AlloyDBClient] = None + + async def connect( + self, + instance_uri: str, + driver: str, + **kwargs: Any, + ) -> Any: + """ + Asynchronously prepares and returns a database connection object. + + Starts tasks to refresh the certificates and get + AlloyDB instance IP address. Creates a secure TLS connection + to establish connection to AlloyDB instance. + + Args: + instance_uri (str): The instance URI of the AlloyDB instance. + ex. projects//locations//clusters//instances/ + driver (str): A string representing the database driver to connect + with. Supported drivers are asyncpg. + **kwargs: Pass in any database driver-specific arguments needed + to fine tune connection. + + Returns: + connection: A DBAPI connection to the specified AlloyDB instance. + """ + if self._client is None: + # lazy init client as it has to be initialized in async context + self._client = AlloyDBClient( + self._alloydb_api_endpoint, + self._quota_project, + self._credentials, + ) + + # use existing connection info if possible + if instance_uri in self._instances: + instance = self._instances[instance_uri] + else: + instance = Instance(instance_uri, self._client, self._keys) + self._instances[instance_uri] = instance + + connect_func = { + "asyncpg": asyncpg.connect, + } + # only accept supported database drivers + try: + connector = connect_func[driver] + except KeyError: + raise ValueError(f"Driver '{driver}' is not a supported database driver.") + + # Host and ssl options come from the certificates and instance IP + # address so we don't want the user to specify them. + kwargs.pop("host", None) + kwargs.pop("ssl", None) + kwargs.pop("port", None) + + # get connection info for AlloyDB instance + ip_address, context = await instance.connection_info() + + try: + return await connector(ip_address, context, **kwargs) + except Exception: + # we attempt a force refresh, then throw the error + await instance.force_refresh() + raise + + async def __aenter__(self) -> Any: + """Enter async context manager by returning Connector object""" + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit async context manager by closing Connector""" + await self.close() + + async def close(self) -> None: + """Helper function to cancel Instances' tasks + and close client.""" + await asyncio.gather( + *[instance.close() for instance in self._instances.values()] + ) + if self._client: + await self._client.close() diff --git a/google/cloud/alloydb/connector/asyncpg.py b/google/cloud/alloydb/connector/asyncpg.py new file mode 100644 index 0000000..2652ee4 --- /dev/null +++ b/google/cloud/alloydb/connector/asyncpg.py @@ -0,0 +1,63 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ssl +from typing import Any, TYPE_CHECKING + +SERVER_PROXY_PORT = 5433 + +if TYPE_CHECKING: + import asyncpg + + +async def connect( + ip_address: str, ctx: ssl.SSLContext, **kwargs: Any +) -> "asyncpg.Connection": + """Helper function to create an asyncpg DB-API connection object. + + :type ip_address: str + :param ip_address: A string containing an IP address for the AlloyDB + instance. + + :type ctx: ssl.SSLContext + :param ctx: An SSLContext object created from the AlloyDB server CA + cert and ephemeral cert. + + :type kwargs: Any + :param kwargs: Keyword arguments for establishing asyncpg connection + object to AlloyDB instance. + + :rtype: asyncpg.Connection + :returns: An asyncpg.Connection object to an AlloyDB instance. + """ + try: + import asyncpg + except ImportError: + raise ImportError( + 'Unable to import module "asyncpg." Please install and try again.' + ) + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password") + + return await asyncpg.connect( + user=user, + database=db, + password=passwd, + host=ip_address, + port=SERVER_PROXY_PORT, + ssl=ctx, + direct_tls=True, + **kwargs, + ) diff --git a/mypy.ini b/mypy.ini index c37fb80..4095da9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -8,3 +8,6 @@ ignore_missing_imports = True [mypy-pg8000] ignore_missing_imports = True + +[mypy-asyncpg] +ignore_missing_imports = True diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py new file mode 100644 index 0000000..22eda9a --- /dev/null +++ b/tests/system/test_asyncpg_connection.py @@ -0,0 +1,47 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import asyncpg +import pytest +import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine + +from google.cloud.alloydb.connector import AsyncConnector + + +@pytest.mark.asyncio +async def test_connection_with_asyncpg() -> None: + async with AsyncConnector() as connector: + + async def getconn() -> asyncpg.Connection: + conn: asyncpg.Connection = await connector.connect( + os.environ["ALLOYDB_INSTANCE_URI"], + "asyncpg", + user=os.environ["ALLOYDB_USER"], + password=os.environ["ALLOYDB_PASS"], + db=os.environ["ALLOYDB_DB"], + ) + return conn + + # create SQLAlchemy connection pool + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + execution_options={"isolation_level": "AUTOCOMMIT"}, + ) + async with pool.connect() as conn: + res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() + assert res[0] == 1 diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index f55814e..b7ff722 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from datetime import datetime from datetime import timedelta from datetime import timezone @@ -156,6 +157,7 @@ class FakeAlloyDBClient: def __init__(self) -> None: self.instance = FakeInstance() + self.closed = False async def _get_metadata(*args: Any, **kwargs: Any) -> str: return "127.0.0.1" @@ -191,4 +193,23 @@ async def _get_client_certificate( return (ca_cert, [client_cert, intermediate_cert, root_cert]) async def close(self) -> None: - pass + self.closed = True + + +class FakeConnectionInfo: + """Fake connection info class that doesn't perform a refresh""" + + def __init__(self) -> None: + self._close_called = False + self._force_refresh_called = False + + def connection_info(self) -> Tuple[str, Any]: + f = asyncio.Future() + f.set_result(("10.0.0.1", None)) + return f + + async def force_refresh(self) -> None: + self._force_refresh_called = True + + async def close(self) -> None: + self._close_called = True diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py new file mode 100644 index 0000000..37685d9 --- /dev/null +++ b/tests/unit/test_async_connector.py @@ -0,0 +1,193 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from mock import patch +from mocks import FakeAlloyDBClient +from mocks import FakeConnectionInfo +from mocks import FakeCredentials +import pytest + +from google.cloud.alloydb.connector import AsyncConnector + +ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com" + + +@pytest.mark.asyncio +async def test_AsyncConnector_init(credentials: FakeCredentials) -> None: + """ + Test to check whether the __init__ method of AsyncConnector + properly sets default attributes. + """ + connector = AsyncConnector(credentials) + assert connector._quota_project is None + assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT + assert connector._client is None + assert connector._credentials == credentials + await connector.close() + + +@pytest.mark.asyncio +async def test_AsyncConnector_context_manager( + credentials: FakeCredentials, +) -> None: + """ + Test to check whether the __init__ method of AsyncConnector + properly sets defaults as context manager. + """ + async with AsyncConnector(credentials) as connector: + assert connector._quota_project is None + assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT + assert connector._client is None + assert connector._credentials == credentials + + +TEST_INSTANCE_NAME = "/".join( + [ + "projects", + "PROJECT", + "locations", + "REGION", + "clusters", + "CLUSTER_NAME", + "instances", + "INSTANCE_NAME", + ], +) + + +@pytest.mark.asyncio +async def test_connect_and_close(credentials: FakeCredentials) -> None: + """ + Test that connector.connect calls asyncpg.connect and cleans up + """ + with patch("google.cloud.alloydb.connector.asyncpg.connect") as connect: + # patch db connection creation and return plain future + future = asyncio.Future() + future.set_result(True) + connect.return_value = future + + connector = AsyncConnector(credentials) + connector._client = FakeAlloyDBClient() + connection = await connector.connect( + TEST_INSTANCE_NAME, + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ) + await connector.close() + + # check connection is returned + assert connection.result() is True + # outside of context manager check close cleaned up + assert connector._client.closed is True + + +@pytest.mark.asyncio +async def test_force_refresh(credentials: FakeCredentials) -> None: + """ + Test that any failed connection results in a force refresh. + """ + with patch( + "google.cloud.alloydb.connector.asyncpg.connect", + side_effect=Exception("connection failed"), + ): + connector = AsyncConnector(credentials) + connector._client = FakeAlloyDBClient() + + # Prepare cached connection info to avoid the need for two calls + fake = FakeConnectionInfo() + connector._instances[TEST_INSTANCE_NAME] = fake + + with pytest.raises(Exception): + await connector.connect( + TEST_INSTANCE_NAME, + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ) + + assert fake._force_refresh_called is True + + +@pytest.mark.asyncio +async def test_close_stops_instance(credentials: FakeCredentials) -> None: + """ + Test that any connected instances are closed when the connector is + closed. + """ + connector = AsyncConnector(credentials) + connector._client = FakeAlloyDBClient() + # Simulate connection + fake = FakeConnectionInfo() + connector._instances[TEST_INSTANCE_NAME] = fake + + await connector.close() + + assert fake._close_called is True + + +@pytest.mark.asyncio +async def test_context_manager_connect_and_close( + credentials: FakeCredentials, +) -> None: + """ + Test that connector.connect calls asyncpg.connect and cleans up using the + async context manager + """ + with patch("google.cloud.alloydb.connector.asyncpg.connect") as connect: + fake_client = FakeAlloyDBClient() + async with AsyncConnector(credentials) as connector: + connector._client = fake_client + + # patch db connection creation + future = asyncio.Future() + future.set_result(True) + connect.return_value = future + + connection = await connector.connect( + TEST_INSTANCE_NAME, + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ) + + # check connection is returned + assert connection.result() is True + # outside of context manager check close cleaned up + assert fake_client.closed is True + + +@pytest.mark.asyncio +async def test_connect_unsupported_driver( + credentials: FakeCredentials, +) -> None: + """ + Test that connector.connect errors with unsupported database driver. + """ + client = FakeAlloyDBClient() + async with AsyncConnector(credentials) as connector: + connector._client = client + # try to connect using unsupported driver, should raise ValueError + with pytest.raises(ValueError) as exc_info: + await connector.connect(TEST_INSTANCE_NAME, "bad_driver") + # assert custom error message for unsupported driver is present + assert ( + exc_info.value.args[0] + == "Driver 'bad_driver' is not a supported database driver." + )