diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 06f3c17..40b2d77 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -39,7 +39,7 @@ ) -def parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]: +def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]: # should take form "projects//locations//clusters//instances/" if INSTANCE_URI_REGEX.fullmatch(instance_uri) is None: raise ValueError( @@ -47,9 +47,7 @@ def parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]: "format: projects//locations//clusters//instances/, projects/:/locations//clusters//instances/" f"got {instance_uri}." ) - instance_uri_split = INSTANCE_URI_REGEX.split(instance_uri) - return ( instance_uri_split[1], instance_uri_split[3], @@ -79,17 +77,15 @@ def __init__( client: AlloyDBClient, keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]], ) -> None: - - self._instance_uri = instance_uri - # validate and parse instance_uri ( self._project, self._region, self._cluster, self._name, - ) = parse_instance_uri(instance_uri) + ) = _parse_instance_uri(instance_uri) + self._instance_uri = instance_uri self._client = client self._keys = keys self._refresh_rate_limiter = AsyncRateLimiter( diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 7970d9c..2a76169 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -14,36 +14,50 @@ import asyncio from datetime import datetime, timedelta +from typing import Tuple import aiohttp from mocks import FakeAlloyDBClient import pytest from google.cloud.alloydb.connector.exceptions import RefreshError -from google.cloud.alloydb.connector.instance import Instance, parse_instance_uri +from google.cloud.alloydb.connector.instance import _parse_instance_uri, Instance from google.cloud.alloydb.connector.refresh import _is_valid, RefreshResult from google.cloud.alloydb.connector.utils import generate_keys -@pytest.mark.asycio -async def test_parser_instance_uri() -> None: +@pytest.mark.parametrize( + "instance_uri, expected", + [ + ( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + ("test-project", "test-region", "test-cluster", "test-instance"), + ), + ( + "projects/test-domain:test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + ( + "test-domain:test-project", + "test-region", + "test-cluster", + "test-instance", + ), + ), + ], +) +def test_parse_instance_uri(instance_uri: str, expected: Tuple[str, str, str]) -> None: """ - Test to check whether the __init__ method of Instance - can tell if the instance URI that's passed in is formatted correctly. + Test that _parse_instance_uri correctly on + normal instance uri and domain-scoped projects. """ - instance_uri = [ - "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", - "projects/google.com:test-project/locations/test-region/clusters/test-cluster/instances/test-instance", - ] + assert expected == _parse_instance_uri(instance_uri) - for uri in instance_uri: - project, region, cluster, name = parse_instance_uri(uri) - assert ( - project in ["test-project", "google.com:test-project"] - and region == "test-region" - and cluster == "test-cluster" - and name == "test-instance" - ) + +def test_parse_bad_instance_uri() -> None: + """ + Tests that ValueError is thrown for bad instance uri. + """ + with pytest.raises(ValueError): + _parse_instance_uri("test-project:test-instance") @pytest.mark.asyncio