Skip to content

Commit

Permalink
move and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Nov 28, 2024
1 parent 5f60434 commit 24339d5
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 54 deletions.
2 changes: 1 addition & 1 deletion nodeman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jwcrypto.jwk import JWK
from jwcrypto.jws import JWS

from nodeman.utils import generate_x509_csr, jwk_to_alg
from nodeman.jose import generate_x509_csr, jwk_to_alg


def main() -> None:
Expand Down
16 changes: 16 additions & 0 deletions nodeman/jose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from jwcrypto.jwk import JWK


def jwk_to_alg(key: JWK) -> str:
match (key.kty, key.get("crv")):
case ("RSA", None):
return "RS256"
case ("EC", "P-256"):
return "ES256"
case ("EC", "P-384"):
return "ES384"
case ("OKP", "Ed25519"):
return "EdDSA"
case ("OKP", "Ed448"):
return "EdDSA"
raise ValueError
10 changes: 5 additions & 5 deletions nodeman/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .const import MIME_TYPE_JWK, MIME_TYPE_PEM
from .db_models import TapirNode, TapirNodeSecret
from .models import NodeBootstrapInformation, NodeCollection, NodeConfiguration, NodeInformation, PublicJwk
from .utils import verify_x509_csr
from .x509 import verify_x509_csr

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,16 +194,16 @@ async def enroll_node(
verify_x509_csr(name=name, csr=x509_csr)

try:
step_ca_response = request.app.step_client.sign_csr(x509_csr, name)
ca_response = request.app.ca_client.sign_csr(x509_csr, name)
except Exception as exc:
logger.error("Failed to processes CSR for %s", name)
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error issuing certificate") from exc

x509_certificate = "".join(
[certificate.public_bytes(serialization.Encoding.PEM).decode() for certificate in step_ca_response.cert_chain]
[certificate.public_bytes(serialization.Encoding.PEM).decode() for certificate in ca_response.cert_chain]
)
x509_ca_certificate = step_ca_response.ca_cert.public_bytes(serialization.Encoding.PEM).decode()
x509_ca_url = request.app.step_client.ca_url
x509_ca_certificate = ca_response.ca_cert.public_bytes(serialization.Encoding.PEM).decode()
x509_ca_url = request.app.ca_client.ca_url

node.activated = datetime.now(tz=timezone.utc)
node.save()
Expand Down
5 changes: 3 additions & 2 deletions nodeman/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import OPENAPI_METADATA, __verbose_version__
from .settings import Settings, StepSettings
from .step import StepClient
from .x509 import CertificateAuthorityClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,8 +54,8 @@ def __init__(self, settings: Settings):
else:
self.logger.warning("Starting without trusted keys")

self.step_client: StepClient | None
self.step_client = self.get_step_client(self.settings.step) if self.settings.step else None
self.ca_client: CertificateAuthorityClient | None
self.ca_client = self.get_step_client(self.settings.step) if self.settings.step else None

@staticmethod
def get_step_client(settings: StepSettings) -> StepClient:
Expand Down
16 changes: 5 additions & 11 deletions nodeman/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import tempfile
import time
import uuid
from dataclasses import dataclass
from urllib.parse import urljoin

import httpx
Expand All @@ -12,16 +11,11 @@
from jwcrypto.jwk import JWK
from jwcrypto.jwt import JWT

from .utils import jwk_to_alg
from .jose import jwk_to_alg
from .x509 import CertificateAuthorityClient, CertificateInformation


@dataclass(frozen=True)
class StepSignResponse:
cert_chain: list[x509.Certificate]
ca_cert: x509.Certificate


class StepClient:
class StepClient(CertificateAuthorityClient):
def __init__(self, ca_url: str, ca_fingerprint: str, provisioner_name: str, provisioner_jwk: JWK):
self.ca_url = ca_url
self.ca_fingerprint = ca_fingerprint
Expand All @@ -30,7 +24,7 @@ def __init__(self, ca_url: str, ca_fingerprint: str, provisioner_name: str, prov
self.ca_bundle_filename = self._get_root_ca_cert()
self.token_ttl = 300

def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> StepSignResponse:
def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> CertificateInformation:
csr_pem = csr.public_bytes(encoding=serialization.Encoding.PEM).decode()
token = self._get_token(name)
response = httpx.post(
Expand All @@ -40,7 +34,7 @@ def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> StepSignRe
)
response.raise_for_status()
payload = response.json()
return StepSignResponse(
return CertificateInformation(
cert_chain=[x509.load_pem_x509_certificate(cert.encode()) for cert in payload["certChain"]],
ca_cert=x509.load_pem_x509_certificate(payload["ca"].encode()),
)
Expand Down
31 changes: 15 additions & 16 deletions nodeman/utils.py → nodeman/x509.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import ExtensionOID, NameOID
from jwcrypto.jwk import JWK

type PrivateKey = ec.EllipticCurvePrivateKey


@dataclass(frozen=True)
class CertificateInformation:
cert_chain: list[x509.Certificate]
ca_cert: x509.Certificate


class CertificateAuthorityClient(ABC):
@abstractmethod
def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> CertificateInformation:
pass


def generate_x509_csr(name: str, key: PrivateKey) -> x509.CertificateSigningRequest:
"""Generate X.509 CSR with name and key"""
return (
Expand Down Expand Up @@ -42,18 +56,3 @@ def verify_x509_csr(name: str, csr: x509.CertificateSigningRequest) -> None:
ext = csr.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
if ext.value.get_values_for_type(x509.DNSName) != [name]:
raise ValueError("Invalid SubjectAlternativeName")


def jwk_to_alg(key: JWK) -> str:
match (key.kty, key.get("crv")):
case ("RSA", None):
return "RS256"
case ("EC", "P-256"):
return "ES256"
case ("EC", "P-384"):
return "ES384"
case ("OKP", "Ed25519"):
return "EdDSA"
case ("OKP", "Ed448"):
return "EdDSA"
raise ValueError
7 changes: 4 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from pydantic_settings import SettingsConfigDict

from nodeman.const import MIME_TYPE_JWK, MIME_TYPE_PEM
from nodeman.jose import jwk_to_alg
from nodeman.server import NodemanServer
from nodeman.settings import Settings
from nodeman.utils import generate_x509_csr, jwk_to_alg
from tests.utils import TestStepClient
from nodeman.x509 import generate_x509_csr
from tests.utils import CaTestClient

ADMIN_TEST_NODE_COUNT = 100

Expand All @@ -25,7 +26,7 @@
def get_test_client() -> TestClient:
settings = Settings()
app = NodemanServer(settings)
app.step_client = TestStepClient()
app.ca_client = CaTestClient()
app.connect_mongodb()
return TestClient(app)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from nodeman.settings import StepSettings
from nodeman.step import StepClient
from nodeman.utils import generate_x509_csr, verify_x509_csr
from nodeman.x509 import generate_x509_csr, verify_x509_csr


def test_step_ca() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_x509.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import NameOID

from nodeman.utils import verify_x509_csr
from nodeman.x509 import verify_x509_csr

type PrivateKey = ec.EllipticCurvePrivateKey

Expand Down
44 changes: 30 additions & 14 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,59 @@
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import NameOID

from nodeman.step import StepSignResponse
from nodeman.x509 import CertificateAuthorityClient, CertificateInformation


class TestStepClient:
class CaTestClient(CertificateAuthorityClient):
def __init__(self):
self.ca_name = "ca.example.com"
self.ca_url = "https://ca.example.com"
self.ca_private_key = ec.generate_private_key(ec.SECP256R1())

def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> StepSignResponse:
now = datetime.now(tz=timezone.utc)
one_day = timedelta(days=1)
ca_private_key = ec.generate_private_key(ec.SECP256R1())
validity = timedelta(days=1)

# build CA certificate
builder = x509.CertificateBuilder()
builder = builder.subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.ca_name)]))
builder = builder.issuer_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.ca_name)]))
builder = builder.not_valid_before(now)
builder = builder.not_valid_after(now + one_day)
builder = builder.not_valid_after(now + validity)
builder = builder.serial_number(x509.random_serial_number())
builder = builder.public_key(ca_private_key.public_key())
builder = builder.public_key(self.ca_private_key.public_key())
builder = builder.add_extension(x509.IssuerAlternativeName([x509.DNSName(self.ca_name)]), critical=False)
builder = builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None),
critical=True,
)
ca_certificate = builder.sign(
private_key=ca_private_key,
builder = builder.add_extension(
x509.KeyUsage(
digital_signature=True,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
self.ca_certificate = builder.sign(
private_key=self.ca_private_key,
algorithm=hashes.SHA256(),
)

# build client certificate
def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> CertificateInformation:
"""Sign CSR with CA private key"""
now = datetime.now(tz=timezone.utc)
validity = timedelta(minutes=10)

builder = x509.CertificateBuilder()
builder = builder.subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, name)]))
builder = builder.issuer_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, self.ca_name)]))
builder = builder.not_valid_before(now)
builder = builder.not_valid_after(now + one_day)
builder = builder.not_valid_after(now + validity)
builder = builder.serial_number(x509.random_serial_number())
builder = builder.public_key(csr.public_key())
builder = builder.add_extension(x509.SubjectAlternativeName([x509.DNSName(name)]), critical=False)
Expand All @@ -50,8 +66,8 @@ def sign_csr(self, csr: x509.CertificateSigningRequest, name: str) -> StepSignRe
critical=True,
)
certificate = builder.sign(
private_key=ca_private_key,
private_key=self.ca_private_key,
algorithm=hashes.SHA256(),
)

return StepSignResponse(cert_chain=[certificate], ca_cert=ca_certificate)
return CertificateInformation(cert_chain=[certificate], ca_cert=self.ca_certificate)

0 comments on commit 24339d5

Please sign in to comment.