diff --git a/dnstapir/key_resolver.py b/dnstapir/key_resolver.py index 378665b..455d0d6 100644 --- a/dnstapir/key_resolver.py +++ b/dnstapir/key_resolver.py @@ -1,6 +1,8 @@ +import logging +import re from abc import abstractmethod from pathlib import Path -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse import httpx from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey @@ -31,13 +33,22 @@ def key_resolver_from_client_database(client_database: str, key_cache: KeyCache class KeyResolver: + def __init__(self): + self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__) + self.key_id_validator = re.compile(r"^[a-zA-Z0-9_-]+$") + @abstractmethod def resolve_public_key(self, key_id: str) -> PublicKey: pass + def validate_key_id(self, key_id: str) -> None: + if not self.key_id_validator.match(key_id): + raise ValueError(f"Invalid key_id format: {key_id}") + class CacheKeyResolver(KeyResolver): def __init__(self, key_cache: KeyCache | None): + super().__init__() self.key_cache = key_cache @abstractmethod @@ -64,7 +75,9 @@ def __init__(self, client_database_directory: str, key_cache: KeyCache | None = def get_public_key_pem(self, key_id: str) -> bytes: with tracer.start_as_current_span("get_public_key_pem_from_file"): + self.validate_key_id(key_id) filename = Path(self.client_database_directory) / f"{key_id}.pem" + self.logger.debug("Fetching public key for %s from %s", key_id, filename) try: with open(filename, "rb") as fp: return fp.read() @@ -77,10 +90,21 @@ def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = N super().__init__(key_cache=key_cache) self.client_database_base_url = client_database_base_url self._httpx_client: httpx.Client | None = None + self.key_id_pattern = "%s" def get_public_key_pem(self, key_id: str) -> bytes: with tracer.start_as_current_span("get_public_key_pem_from_url"): - public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem") + self.validate_key_id(key_id) + + if self.key_id_pattern in self.client_database_base_url: + public_key_url = self.client_database_base_url.replace(self.key_id_pattern, key_id) + else: + public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem") + + if urlparse(public_key_url).scheme not in ("http", "https"): + raise ValueError(f"Invalid URL constructed: {public_key_url}") + + self.logger.debug("Fetching public key for %s from %s", key_id, public_key_url) try: response = self.httpx_client.get(public_key_url) response.raise_for_status() @@ -91,7 +115,7 @@ def get_public_key_pem(self, key_id: str) -> bytes: @property def httpx_client(self) -> httpx.Client: if self._httpx_client is None: - self._httpx_client = httpx.Client() + self._httpx_client = httpx.Client(headers={"Accept": "application/x-pem-file"}) return self._httpx_client def __enter__(self): diff --git a/pyproject.toml b/pyproject.toml index 8c9d294..c3ceaa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dnstapir" -version = "1.1.0" +version = "1.2.0" description = "DNS TAPIR Python Library" authors = ["Jakob Schlyter "] readme = "README.md" diff --git a/tests/test_key_resolver.py b/tests/test_key_resolver.py index 48528d9..5ce701a 100644 --- a/tests/test_key_resolver.py +++ b/tests/test_key_resolver.py @@ -1,8 +1,35 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 from pytest_httpx import HTTPXMock -from dnstapir.key_resolver import UrlKeyResolver +from dnstapir.key_resolver import FileKeyResolver, UrlKeyResolver + + +def test_file_key_resolver(httpx_mock: HTTPXMock): + key_id = "xyzzy" + public_key = ed25519.Ed25519PrivateKey.generate().public_key() + public_key_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + with TemporaryDirectory(prefix="dnstapir") as directory: + pem_filename = Path(directory) / f"{key_id}.pem" + with open(pem_filename, "wb") as fp: + fp.write(public_key_pem) + + resolver = FileKeyResolver(client_database_directory=directory) + res = resolver.resolve_public_key(key_id) + assert res == public_key + + with pytest.raises(ValueError): + _ = resolver.resolve_public_key("🔐") + + with pytest.raises(KeyError): + _ = resolver.resolve_public_key("unknown") def test_url_key_resolver(httpx_mock: HTTPXMock): @@ -13,11 +40,42 @@ def test_url_key_resolver(httpx_mock: HTTPXMock): ) httpx_mock.add_response(url=f"https://keys/{key_id}.pem", content=public_key_pem) + httpx_mock.add_response(url="https://keys/unknown.pem", status_code=404) resolver = UrlKeyResolver(client_database_base_url="https://keys") res = resolver.resolve_public_key(key_id) assert res == public_key + request = httpx_mock.get_request() + assert request.headers["Accept"] == "application/x-pem-file" + + with pytest.raises(ValueError): + _ = resolver.resolve_public_key("🔐") + + with pytest.raises(KeyError): + _ = resolver.resolve_public_key("unknown") + + +def test_url_key_resolver_pattern(httpx_mock: HTTPXMock): + key_id = "xyzzy" + public_key = ed25519.Ed25519PrivateKey.generate().public_key() + public_key_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + httpx_mock.add_response(url=f"https://nodeman/api/v1/node/{key_id}/public_key", content=public_key_pem) + httpx_mock.add_response(url="https://nodeman/api/v1/node/unknown/public_key", status_code=404) + + resolver = UrlKeyResolver(client_database_base_url="https://nodeman/api/v1/node/%s/public_key") + res = resolver.resolve_public_key(key_id) + assert res == public_key + + request = httpx_mock.get_request() + assert request.headers["Accept"] == "application/x-pem-file" + + with pytest.raises(KeyError): + _ = resolver.resolve_public_key("unknown") + def test_url_key_resolver_contextlib(httpx_mock: HTTPXMock): key_id = "xyzzy" @@ -27,7 +85,11 @@ def test_url_key_resolver_contextlib(httpx_mock: HTTPXMock): ) httpx_mock.add_response(url=f"https://keys/{key_id}.pem", content=public_key_pem) + httpx_mock.add_response(url="https://keys/unknown.pem", status_code=404) with UrlKeyResolver(client_database_base_url="https://keys") as resolver: res = resolver.resolve_public_key(key_id) assert res == public_key + + with pytest.raises(KeyError): + _ = resolver.resolve_public_key("unknown")