diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 5a2b9f5..d1961a6 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -194,7 +194,9 @@ def get_authentication_token() -> str: if enable_iam_auth: kwargs["password"] = get_authentication_token try: - return await connector(ip_address, conn_info.create_ssl_context(), **kwargs) + return await connector( + ip_address, await conn_info.create_ssl_context(), **kwargs + ) except Exception: # we attempt a force refresh, then throw the error await cache.force_refresh() diff --git a/google/cloud/alloydb/connector/connection_info.py b/google/cloud/alloydb/connector/connection_info.py index 93c514b..9cd4876 100644 --- a/google/cloud/alloydb/connector/connection_info.py +++ b/google/cloud/alloydb/connector/connection_info.py @@ -17,9 +17,10 @@ from dataclasses import dataclass import logging import ssl -from tempfile import TemporaryDirectory from typing import Dict, List, Optional, TYPE_CHECKING +from aiofiles.tempfile import TemporaryDirectory + from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.utils import _write_to_file @@ -45,7 +46,7 @@ class ConnectionInfo: expiration: datetime.datetime context: Optional[ssl.SSLContext] = None - def create_ssl_context(self) -> ssl.SSLContext: + async def create_ssl_context(self) -> ssl.SSLContext: """Constructs a SSL/TLS context for the given connection info. Cache the SSL context to ensure we don't read from disk repeatedly when @@ -66,8 +67,8 @@ def create_ssl_context(self) -> ssl.SSLContext: # tmpdir and its contents are automatically deleted after the CA cert # and cert chain are loaded into the SSLcontext. The values # need to be written to files in order to be loaded by the SSLContext - with TemporaryDirectory() as tmpdir: - ca_filename, cert_chain_filename, key_filename = _write_to_file( + async with TemporaryDirectory() as tmpdir: + ca_filename, cert_chain_filename, key_filename = await _write_to_file( tmpdir, self.ca_cert, self.cert_chain, self.key ) context.load_cert_chain(cert_chain_filename, keyfile=key_filename) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index c4e8bae..2d1bfee 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -215,7 +215,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> metadata_partial = partial( self.metadata_exchange, ip_address, - conn_info.create_ssl_context(), + await conn_info.create_ssl_context(), enable_iam_auth, driver, ) diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index a549f70..4e558d3 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -16,11 +16,12 @@ from typing import List, Tuple +import aiofiles from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -def _write_to_file( +async def _write_to_file( dir_path: str, ca_cert: str, cert_chain: List[str], key: rsa.RSAPrivateKey ) -> Tuple[str, str, str]: """ @@ -37,12 +38,12 @@ def _write_to_file( encryption_algorithm=serialization.NoEncryption(), ) - with open(ca_filename, "w+") as ca_out: - ca_out.write(ca_cert) - with open(cert_chain_filename, "w+") as chain_out: - chain_out.write("".join(cert_chain)) - with open(key_filename, "wb") as priv_out: - priv_out.write(key_bytes) + async with aiofiles.open(ca_filename, "w+") as ca_out: + await ca_out.write(ca_cert) + async with aiofiles.open(cert_chain_filename, "w+") as chain_out: + await chain_out.write("".join(cert_chain)) + async with aiofiles.open(key_filename, "wb") as priv_out: + await priv_out.write(key_bytes) return (ca_filename, cert_chain_filename, key_filename) diff --git a/requirements.txt b/requirements.txt index ca256ce..29c89ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiofiles==24.1.0 aiohttp==3.9.5 cryptography==42.0.8 google-auth==2.32.0 diff --git a/setup.py b/setup.py index 90f407d..0033a3e 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ release_status = "Development Status :: 5 - Production/Stable" dependencies = [ + "aiofiles", "aiohttp", "cryptography>=42.0.0", "requests", diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d001a59..45648fa 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import socket import ssl -from tempfile import TemporaryDirectory from threading import Thread from typing import Generator +from aiofiles.tempfile import TemporaryDirectory from mocks import FakeAlloyDBClient from mocks import FakeCredentials from mocks import FakeInstance @@ -42,7 +43,7 @@ def fake_client(fake_instance: FakeInstance) -> FakeAlloyDBClient: return FakeAlloyDBClient(fake_instance) -def start_proxy_server(instance: FakeInstance) -> None: +async def start_proxy_server(instance: FakeInstance) -> None: """Run local proxy server capable of performing metadata exchange""" ip_address = "127.0.0.1" port = 5433 @@ -55,8 +56,8 @@ def start_proxy_server(instance: FakeInstance) -> None: # tmpdir and its contents are automatically deleted after the CA cert # and cert chain are loaded into the SSLcontext. The values # need to be written to files in order to be loaded by the SSLContext - with TemporaryDirectory() as tmpdir: - _, cert_chain_filename, key_filename = _write_to_file( + async with TemporaryDirectory() as tmpdir: + _, cert_chain_filename, key_filename = await _write_to_file( tmpdir, server, [server, root], instance.server_key ) context.load_cert_chain(cert_chain_filename, key_filename) @@ -76,7 +77,15 @@ def start_proxy_server(instance: FakeInstance) -> None: @pytest.fixture(scope="session") def proxy_server(fake_instance: FakeInstance) -> Generator: """Run local proxy server capable of performing metadata exchange""" - thread = Thread(target=start_proxy_server, args=(fake_instance,), daemon=True) + thread = Thread( + target=asyncio.run, + args=( + start_proxy_server( + fake_instance, + ), + ), + daemon=True, + ) thread.start() yield thread thread.join() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 9ee22d6..ae60035 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -370,7 +370,7 @@ def get_preferred_ip(self, ip_type: Any) -> Tuple[str, Any]: f.set_result("10.0.0.1") return f - def create_ssl_context(self) -> None: + async def create_ssl_context(self) -> None: return None async def force_refresh(self) -> None: diff --git a/tests/unit/test_connection_info.py b/tests/unit/test_connection_info.py index d6867d9..4c4cf30 100644 --- a/tests/unit/test_connection_info.py +++ b/tests/unit/test_connection_info.py @@ -29,7 +29,7 @@ from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError -def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None: +async def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None: """ Test to check whether the __init__ method of ConnectionInfo can correctly initialize TLS context. @@ -58,19 +58,19 @@ def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None: fake_instance.ip_addrs, datetime.now(timezone.utc) + timedelta(minutes=10), ) - context = conn_info.create_ssl_context() + context = await conn_info.create_ssl_context() # verify TLS requirements assert context.minimum_version == ssl.TLSVersion.TLSv1_3 -def test_ConnectionInfo_caches_sslcontext() -> None: +async def test_ConnectionInfo_caches_sslcontext() -> None: info = ConnectionInfo(["cert"], "cert", "key".encode(), {}, datetime.now()) # context should default to None assert info.context is None # cache a 'context' info.context = "context" # calling create_ssl_context should no-op with an existing 'context' - info.create_ssl_context() + await info.create_ssl_context() assert info.context == "context"