diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 3e5545d..7d58b8e 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -16,6 +16,7 @@ import asyncio import logging +import re from typing import Tuple, TYPE_CHECKING from google.cloud.alloydb.connector.exceptions import RefreshError @@ -33,6 +34,27 @@ logger = logging.getLogger(name=__name__) +INSTANCE_URI_REGEX = re.compile( + "projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)" +) + + +def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]: + # should take form "projects//locations//clusters//instances/" + if INSTANCE_URI_REGEX.fullmatch(instance_uri) is None: + raise ValueError( + "Arg `instance_uri` must have " + "format: projects//locations//clusters//instances/, projects/:/locations//clusters//instances/" + f"got {instance_uri}." + ) + instance_uri_split = INSTANCE_URI_REGEX.split(instance_uri) + return ( + instance_uri_split[1], + instance_uri_split[3], + instance_uri_split[4], + instance_uri_split[5], + ) + class Instance: """ @@ -56,21 +78,11 @@ def __init__( keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]], ) -> None: # validate and parse instance_uri - instance_uri_split = instance_uri.split("/") - # should take form "projects//locations//clusters//instances/" - if len(instance_uri_split) == 8: - self._instance_uri = instance_uri - self._project = instance_uri_split[1] - self._region = instance_uri_split[3] - self._cluster = instance_uri_split[5] - self._name = instance_uri_split[7] - else: - raise ValueError( - "Arg `instance_uri` must have " - "format: projects//locations//clusters//instances/, " - f"got {instance_uri}." - ) + self._project, self._region, self._cluster, self._name = _parse_instance_uri( + instance_uri + ) + self._instance_uri = instance_uri self._client = client self._keys = keys self._refresh_rate_limiter = AsyncRateLimiter( diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index cf518e9..1a0d491 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -14,17 +14,54 @@ import asyncio from datetime import datetime, timedelta +from typing import Tuple import aiohttp from mocks import FakeAlloyDBClient import pytest from google.cloud.alloydb.connector.exceptions import RefreshError -from google.cloud.alloydb.connector.instance import Instance +from google.cloud.alloydb.connector.instance import _parse_instance_uri, Instance from google.cloud.alloydb.connector.refresh import _is_valid, RefreshResult from google.cloud.alloydb.connector.utils import generate_keys +@pytest.mark.parametrize( + "instance_uri, expected", + [ + ( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + ("test-project", "test-region", "test-cluster", "test-instance"), + ), + ( + "projects/test-domain:test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + ( + "test-domain:test-project", + "test-region", + "test-cluster", + "test-instance", + ), + ), + ], +) +def test_parse_instance_uri( + instance_uri: str, expected: Tuple[str, str, str, str] +) -> None: + """ + Test that _parse_instance_uri works correctly on + normal instance uri and domain-scoped projects. + """ + assert expected == _parse_instance_uri(instance_uri) + + +def test_parse_bad_instance_uri() -> None: + """ + Tests that ValueError is thrown for bad instance uri. + """ + with pytest.raises(ValueError): + _parse_instance_uri("test-project:test-instance") + + @pytest.mark.asyncio async def test_Instance_init() -> None: """