diff --git a/dnstapir/key_resolver.py b/dnstapir/key_resolver.py index ff8ed7a..455d0d6 100644 --- a/dnstapir/key_resolver.py +++ b/dnstapir/key_resolver.py @@ -1,7 +1,8 @@ import logging +import re from abc import abstractmethod from pathlib import Path -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse import httpx from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey @@ -34,11 +35,16 @@ def key_resolver_from_client_database(client_database: str, key_cache: KeyCache class KeyResolver: def __init__(self): self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__) + self.key_id_validator = re.compile(r"^[a-zA-Z0-9_-]+$") @abstractmethod def resolve_public_key(self, key_id: str) -> PublicKey: pass + def validate_key_id(self, key_id: str) -> None: + if not self.key_id_validator.match(key_id): + raise ValueError(f"Invalid key_id format: {key_id}") + class CacheKeyResolver(KeyResolver): def __init__(self, key_cache: KeyCache | None): @@ -69,6 +75,7 @@ def __init__(self, client_database_directory: str, key_cache: KeyCache | None = def get_public_key_pem(self, key_id: str) -> bytes: with tracer.start_as_current_span("get_public_key_pem_from_file"): + self.validate_key_id(key_id) filename = Path(self.client_database_directory) / f"{key_id}.pem" self.logger.debug("Fetching public key for %s from %s", key_id, filename) try: @@ -87,10 +94,16 @@ def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = N def get_public_key_pem(self, key_id: str) -> bytes: with tracer.start_as_current_span("get_public_key_pem_from_url"): + self.validate_key_id(key_id) + if self.key_id_pattern in self.client_database_base_url: public_key_url = self.client_database_base_url.replace(self.key_id_pattern, key_id) else: public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem") + + if urlparse(public_key_url).scheme not in ("http", "https"): + raise ValueError(f"Invalid URL constructed: {public_key_url}") + self.logger.debug("Fetching public key for %s from %s", key_id, public_key_url) try: response = self.httpx_client.get(public_key_url) diff --git a/tests/test_key_resolver.py b/tests/test_key_resolver.py index 9958337..8c2a9fe 100644 --- a/tests/test_key_resolver.py +++ b/tests/test_key_resolver.py @@ -46,6 +46,9 @@ def test_url_key_resolver(httpx_mock: HTTPXMock): request = httpx_mock.get_request() assert request.headers["Accept"] == "application/x-pem-file" + with pytest.raises(ValueError): + _ = resolver.resolve_public_key("🔐") + with pytest.raises(KeyError): _ = resolver.resolve_public_key("unknown")