Skip to content

Commit

Permalink
Enable internal CA and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Dec 16, 2024
1 parent 493ecc6 commit 4376d8e
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 46 deletions.
57 changes: 39 additions & 18 deletions nodeman/internal_ca.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Self

from cryptography import x509
Expand Down Expand Up @@ -32,32 +33,47 @@ class InternalCertificateAuthority(CertificateAuthorityClient):

def __init__(
self,
ca_certificate: x509.Certificate,
ca_private_key: PrivateKey,
issuer_ca_certificate: x509.Certificate,
issuer_ca_private_key: PrivateKey,
root_ca_certificate: x509.Certificate | None = None,
validity: timedelta | None = None,
time_skew: timedelta | None = None,
):
self.ca_certificate = ca_certificate
self.ca_private_key = ca_private_key
self.issuer_ca_certificate = issuer_ca_certificate
self.issuer_ca_private_key = issuer_ca_private_key
self.root_ca_certificate = root_ca_certificate or issuer_ca_certificate
self.time_skew = time_skew or timedelta(minutes=10)
self.validity = validity or timedelta(minutes=10)
self.signature_hash_algorithm = get_hash_algorithm_from_key(self.ca_private_key)
self.signature_hash_algorithm = get_hash_algorithm_from_key(self.issuer_ca_private_key)

@classmethod
def load(
cls,
ca_certificate_file: str,
ca_private_key_file: str,
issuer_ca_certificate_file: Path,
issuer_ca_private_key_file: Path,
root_ca_certificate_file: Path | None = None,
validity: timedelta | None = None,
time_skew: timedelta | None = None,
) -> Self:
with open(ca_certificate_file, "rb") as fp:
ca_certificate = x509.load_pem_x509_certificate(fp.read())

with open(ca_private_key_file, "rb") as fp:
ca_private_key = load_pem_private_key(fp.read(), password=None)

return cls(ca_certificate=ca_certificate, ca_private_key=ca_private_key, validity=validity, time_skew=time_skew)
with open(issuer_ca_certificate_file, "rb") as fp:
issuer_ca_certificate = x509.load_pem_x509_certificate(fp.read())

with open(issuer_ca_private_key_file, "rb") as fp:
issuer_ca_private_key = load_pem_private_key(fp.read(), password=None)

if root_ca_certificate_file:
with open(root_ca_certificate_file, "rb") as fp:
root_ca_certificate = x509.load_pem_x509_certificate(fp.read())
else:
root_ca_certificate = None

return cls(
issuer_ca_certificate=issuer_ca_certificate,
issuer_ca_private_key=issuer_ca_private_key,
root_ca_certificate=root_ca_certificate,
validity=validity,
time_skew=time_skew,
)

def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> CertificateInformation:
"""Sign CSR with CA private key"""
Expand All @@ -66,7 +82,7 @@ def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> Certificat

builder = x509.CertificateBuilder()
builder = builder.subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, name)]))
builder = builder.issuer_name(self.ca_certificate.subject)
builder = builder.issuer_name(self.issuer_ca_certificate.subject)
builder = builder.not_valid_before(now - self.time_skew)
builder = builder.not_valid_after(now + self.validity)
builder = builder.serial_number(x509.random_serial_number())
Expand All @@ -77,14 +93,19 @@ def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> Certificat

builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), critical=False)
builder = builder.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key(self.ca_certificate.public_key()), critical=False
x509.AuthorityKeyIdentifier.from_issuer_public_key(self.issuer_ca_certificate.public_key()), critical=False
)
builder = builder.add_extension(x509.SubjectAlternativeName([x509.DNSName(name)]), critical=False)
builder = builder.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)

certificate = builder.sign(private_key=self.ca_private_key, algorithm=self.signature_hash_algorithm)
certificate = builder.sign(private_key=self.issuer_ca_private_key, algorithm=self.signature_hash_algorithm)

if self.root_ca_certificate != self.issuer_ca_certificate:
cert_chain = [certificate, self.issuer_ca_certificate]
else:
cert_chain = [certificate]

return CertificateInformation(cert_chain=[certificate], ca_cert=self.ca_certificate)
return CertificateInformation(cert_chain=cert_chain, ca_cert=self.root_ca_certificate)
21 changes: 19 additions & 2 deletions nodeman/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
from contextlib import asynccontextmanager
from datetime import timedelta

import mongoengine
import uvicorn
Expand All @@ -15,7 +16,8 @@
from dnstapir.opentelemetry import configure_opentelemetry

from . import OPENAPI_METADATA, __verbose_version__
from .settings import Settings, StepSettings
from .internal_ca import InternalCertificateAuthority
from .settings import InternalCaSettings, Settings, StepSettings
from .step import StepClient
from .x509 import CertificateAuthorityClient

Expand Down Expand Up @@ -62,7 +64,22 @@ def __init__(self, settings: Settings):
self.logger.warning("Starting without users")

self.ca_client: CertificateAuthorityClient | None
self.ca_client = self.get_step_client(self.settings.step) if self.settings.step else None

if self.settings.internal_ca:
self.ca_client = self.get_internal_ca_client(self.settings.internal_ca)
elif self.settings.step:
self.ca_client = self.get_step_client(self.settings.step)
else:
self.ca_client = None

@staticmethod
def get_internal_ca_client(settings: InternalCaSettings) -> InternalCertificateAuthority:
return InternalCertificateAuthority.load(
issuer_ca_certificate_file=settings.issuer_ca_certificate,
issuer_ca_private_key_file=settings.issuer_ca_private_key,
root_ca_certificate_file=settings.root_ca_certificate,
validity=timedelta(seconds=settings.validity),
)

@staticmethod
def get_step_client(settings: StepSettings) -> StepClient:
Expand Down
11 changes: 10 additions & 1 deletion nodeman/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class StepSettings(BaseModel):
provisioner_private_key: FilePath


class InternalCaSettings(BaseModel):
issuer_ca_certificate: FilePath
issuer_ca_private_key: FilePath
root_ca_certificate: FilePath | None = None
validity: int = Field(default=24 * 60 * 60 * 90)


class NodesSettings(BaseModel):
domain: str = Field(default="example.com")
trusted_keys: FilePath | None = Field(default=None)
Expand All @@ -57,10 +64,12 @@ def create(cls, username: str, password: str) -> Self:

class Settings(BaseSettings):
mongodb: MongoDB = Field(default=MongoDB())
step: StepSettings | None = None
otlp: OtlpSettings | None = None
users: list[User] = Field(default=[])

step: StepSettings | None = None
internal_ca: InternalCaSettings | None = None

nodes: NodesSettings = Field(default=NodesSettings())

model_config = SettingsConfigDict(toml_file="nodeman.toml")
Expand Down
12 changes: 9 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from cryptography.x509.oid import NameOID
from fastapi import status
from fastapi.testclient import TestClient
from jwcrypto.jwk import JWK
Expand All @@ -34,16 +35,21 @@
settings = Settings()


def get_ca_client(ca_name: str) -> CertificateAuthorityClient:
def get_ca_client() -> CertificateAuthorityClient:
ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Internal Test CA")])
ca_private_key = ec.generate_private_key(ec.SECP256R1())
ca_certificate = generate_ca_certificate(ca_name, ca_private_key)
validity = timedelta(minutes=10)
return InternalCertificateAuthority(ca_certificate=ca_certificate, ca_private_key=ca_private_key, validity=validity)
return InternalCertificateAuthority(
issuer_ca_certificate=ca_certificate,
issuer_ca_private_key=ca_private_key,
validity=validity,
)


def get_test_client() -> TestClient:
app = NodemanServer(settings)
app.ca_client = get_ca_client("ca.example.com")
app.ca_client = get_ca_client()
app.connect_mongodb()
return TestClient(app)

Expand Down
65 changes: 51 additions & 14 deletions tests/test_internal_ca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile

from cryptography import x509
Expand All @@ -10,10 +11,21 @@
from cryptography.x509.oid import NameOID

from nodeman.internal_ca import InternalCertificateAuthority
from nodeman.x509 import PrivateKey, generate_x509_csr, verify_x509_csr
from nodeman.x509 import CertificateInformation, PrivateKey, generate_x509_csr, verify_x509_csr
from tests.utils import generate_ca_certificate


def _verify_certification_information(res: CertificateInformation) -> None:
store = x509.verification.Store([res.ca_cert])
builder = x509.verification.PolicyBuilder()
builder = builder.store(store)
verifier = builder.build_client_verifier()
peer_certificate = res.cert_chain[0]
untrusted_intermediates = res.cert_chain[1:]
verified_client = verifier.verify(peer_certificate, untrusted_intermediates)
assert verified_client.subjects is not None


def _test_internal_ca(ca_private_key: PrivateKey, verify: bool = True) -> None:
"""Test Internal CA"""

Expand All @@ -22,7 +34,7 @@ def _test_internal_ca(ca_private_key: PrivateKey, verify: bool = True) -> None:

validity = timedelta(minutes=10)
ca_client = InternalCertificateAuthority(
ca_certificate=ca_certificate, ca_private_key=ca_private_key, validity=validity
issuer_ca_certificate=ca_certificate, issuer_ca_private_key=ca_private_key, validity=validity
)

name = "hostname.example.com"
Expand Down Expand Up @@ -50,19 +62,43 @@ def _test_internal_ca(ca_private_key: PrivateKey, verify: bool = True) -> None:
print(x509_ca_certificate_pem)

if verify:
store = x509.verification.Store([res.ca_cert])
builder = x509.verification.PolicyBuilder()
builder = builder.store(store)
verifier = builder.build_client_verifier()
peer_certificate = res.cert_chain[0]
untrusted_intermediates = res.cert_chain[1:]
verified_client = verifier.verify(peer_certificate, untrusted_intermediates)
assert verified_client.subjects is not None
_verify_certification_information(res)


def test_internal_sub_ca() -> None:
"""Test internal issuer CA with separate root CA"""

root_ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Root Test CA")])
root_ca_private_key = ec.generate_private_key(ec.SECP256R1())
root_ca_certificate = generate_ca_certificate(root_ca_name, root_ca_private_key)

issuer_ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Issuer Test CA")])
issuer_ca_private_key = ec.generate_private_key(ec.SECP256R1())
issuer_ca_certificate = generate_ca_certificate(
issuer_ca_name=issuer_ca_name,
issuer_ca_private_key=issuer_ca_private_key,
root_ca_name=root_ca_name,
root_ca_private_key=root_ca_private_key,
)

validity = timedelta(minutes=10)
ca_client = InternalCertificateAuthority(
issuer_ca_certificate=issuer_ca_certificate,
issuer_ca_private_key=issuer_ca_private_key,
root_ca_certificate=root_ca_certificate,
validity=validity,
)

name = "hostname.example.com"
key = ec.generate_private_key(ec.SECP256R1())
csr = generate_x509_csr(key=key, name=name)

res = ca_client.sign_csr(csr, name)
_verify_certification_information(res)


def test_internal_ca_file() -> None:
ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Internal Test CA")])

ca_private_key = ec.generate_private_key(ec.SECP256R1())
ca_certificate = generate_ca_certificate(ca_name, ca_private_key)

Expand All @@ -74,14 +110,15 @@ def test_internal_ca_file() -> None:
encryption_algorithm=serialization.NoEncryption(),
)
)
ca_private_key_file = fp.name
ca_private_key_file = Path(fp.name)

with NamedTemporaryFile(mode="wb", delete=False, suffix=".pem") as fp:
fp.write(ca_certificate.public_bytes(encoding=serialization.Encoding.PEM))
ca_certificate_file = fp.name
ca_certificate_file = Path(fp.name)

_ = InternalCertificateAuthority.load(
ca_certificate_file=ca_certificate_file, ca_private_key_file=ca_private_key_file
issuer_ca_certificate_file=ca_certificate_file,
issuer_ca_private_key_file=ca_private_key_file,
)

os.unlink(ca_certificate_file)
Expand Down
21 changes: 13 additions & 8 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import datetime, timedelta, timezone

from cryptography import x509
from cryptography.x509.oid import NameOID
from jwcrypto.common import base64url_decode
from jwcrypto.jwk import JWK

Expand All @@ -21,21 +20,24 @@ def rekey(key: JWK) -> JWK:
return JWK.generate(**params)


def generate_ca_certificate(ca_name: x509.Name | str, ca_private_key: PrivateKey) -> x509.Certificate:
def generate_ca_certificate(
issuer_ca_name: x509.Name,
issuer_ca_private_key: PrivateKey,
root_ca_name: x509.Name | None = None,
root_ca_private_key: PrivateKey | None = None,
) -> x509.Certificate:
"""Generate CA Certificate"""

now = datetime.now(tz=timezone.utc)
validity = timedelta(days=1)

ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, ca_name)]) if isinstance(ca_name, str) else ca_name

builder = x509.CertificateBuilder()
builder = builder.subject_name(ca_name)
builder = builder.issuer_name(ca_name)
builder = builder.subject_name(issuer_ca_name)
builder = builder.issuer_name(root_ca_name or issuer_ca_name)
builder = builder.not_valid_before(now)
builder = builder.not_valid_after(now + validity)
builder = builder.serial_number(x509.random_serial_number())
builder = builder.public_key(ca_private_key.public_key())
builder = builder.public_key(issuer_ca_private_key.public_key())
builder = builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None),
critical=True,
Expand All @@ -55,4 +57,7 @@ def generate_ca_certificate(ca_name: x509.Name | str, ca_private_key: PrivateKey
critical=True,
)

return builder.sign(private_key=ca_private_key, algorithm=get_hash_algorithm_from_key(ca_private_key))
return builder.sign(
private_key=root_ca_private_key or issuer_ca_private_key,
algorithm=get_hash_algorithm_from_key(issuer_ca_private_key),
)

0 comments on commit 4376d8e

Please sign in to comment.