Skip to content

Commit

Permalink
Ensure key identifiers are sane and verify URLs
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Dec 20, 2024
1 parent b68e81d commit 4bbce0f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
15 changes: 14 additions & 1 deletion dnstapir/key_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import re
from abc import abstractmethod
from pathlib import Path
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse

import httpx
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
Expand Down Expand Up @@ -34,11 +35,16 @@ def key_resolver_from_client_database(client_database: str, key_cache: KeyCache
class KeyResolver:
def __init__(self):
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.key_id_validator = re.compile(r"^[a-zA-Z0-9_-]+$")

@abstractmethod
def resolve_public_key(self, key_id: str) -> PublicKey:
pass

def validate_key_id(self, key_id: str) -> None:
if not self.key_id_validator.match(key_id):
raise ValueError(f"Invalid key_id format: {key_id}")


class CacheKeyResolver(KeyResolver):
def __init__(self, key_cache: KeyCache | None):
Expand Down Expand Up @@ -69,6 +75,7 @@ def __init__(self, client_database_directory: str, key_cache: KeyCache | None =

def get_public_key_pem(self, key_id: str) -> bytes:
with tracer.start_as_current_span("get_public_key_pem_from_file"):
self.validate_key_id(key_id)
filename = Path(self.client_database_directory) / f"{key_id}.pem"
self.logger.debug("Fetching public key for %s from %s", key_id, filename)
try:
Expand All @@ -87,10 +94,16 @@ def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = N

def get_public_key_pem(self, key_id: str) -> bytes:
with tracer.start_as_current_span("get_public_key_pem_from_url"):
self.validate_key_id(key_id)

if self.key_id_pattern in self.client_database_base_url:
public_key_url = self.client_database_base_url.replace(self.key_id_pattern, key_id)
else:
public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem")

if urlparse(public_key_url).scheme not in ("http", "https"):
raise ValueError(f"Invalid URL constructed: {public_key_url}")

self.logger.debug("Fetching public key for %s from %s", key_id, public_key_url)
try:
response = self.httpx_client.get(public_key_url)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_key_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def test_url_key_resolver(httpx_mock: HTTPXMock):
request = httpx_mock.get_request()
assert request.headers["Accept"] == "application/x-pem-file"

with pytest.raises(ValueError):
_ = resolver.resolve_public_key("🔐")

with pytest.raises(KeyError):
_ = resolver.resolve_public_key("unknown")

Expand Down

0 comments on commit 4bbce0f

Please sign in to comment.