diff --git a/nodeman/authn.py b/nodeman/authn.py new file mode 100644 index 0000000..f902076 --- /dev/null +++ b/nodeman/authn.py @@ -0,0 +1,28 @@ +import logging +from typing import Annotated + +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials + +logger = logging.getLogger(__name__) + +security = HTTPBasic() + + +def get_current_username( + request: Request, + credentials: Annotated[HTTPBasicCredentials, Depends(security)], +): + if user := request.app.users.get(credentials.username): + if user.verify_password(credentials.password): + return credentials.username + else: + logger.warning("Invalid password for user %s", credentials.username) + else: + logger.warning("Unknown user %s", credentials.username) + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) diff --git a/nodeman/nodes.py b/nodeman/nodes.py index 6db2d65..faa72ec 100644 --- a/nodeman/nodes.py +++ b/nodeman/nodes.py @@ -4,13 +4,12 @@ from typing import Annotated from cryptography import x509 -from cryptography.hazmat.primitives import serialization from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials from jwcrypto.jwk import JWK from jwcrypto.jws import JWS, InvalidJWSSignature from opentelemetry import metrics, trace +from .authn import get_current_username from .const import MIME_TYPE_JWK, MIME_TYPE_PEM from .db_models import TapirNode, TapirNodeSecret from .models import ( @@ -21,7 +20,7 @@ NodeInformation, PublicJwk, ) -from .x509 import verify_x509_csr +from .x509 import process_csr_request logger = logging.getLogger(__name__) @@ -30,64 +29,12 @@ router = APIRouter() -security = HTTPBasic() - - -def get_current_username( - request: Request, - credentials: Annotated[HTTPBasicCredentials, Depends(security)], -): - if user := request.app.users.get(credentials.username): - if user.verify_password(credentials.password): - return credentials.username - else: - logger.warning("Invalid password for user %s", credentials.username) - else: - logger.warning("Unknown user %s", credentials.username) - - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Basic"}, - ) - - -def process_csr(csr: x509.CertificateSigningRequest, name: str, request: Request) -> NodeCertificate: - """Verify CSR and issuer certificate""" - - verify_x509_csr(name=name, csr=csr) - - try: - ca_response = request.app.ca_client.sign_csr(csr, name) - except Exception as exc: - logger.error("Failed to process CSR for %s", name) - raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error issuing certificate") from exc - - x509_certificate = "".join( - [certificate.public_bytes(serialization.Encoding.PEM).decode() for certificate in ca_response.cert_chain] - ) - x509_ca_certificate = ca_response.ca_cert.public_bytes(serialization.Encoding.PEM).decode() - x509_certificate_serial_number = ca_response.cert_chain[0].serial_number - - logger.info( - "Issued certificate for name=%s serial=%d", - name, - x509_certificate_serial_number, - extra={"nodename": name, "x509_certificate_serial_number": x509_certificate_serial_number}, - ) - - return NodeCertificate( - x509_certificate=x509_certificate, - x509_ca_certificate=x509_ca_certificate, - x509_certificate_serial_number=x509_certificate_serial_number, - ) - @router.post( "/api/v1/node", - status_code=201, + status_code=status.HTTP_201_CREATED, responses={ - 201: {"model": NodeBootstrapInformation}, + status.HTTP_201_CREATED: {"model": NodeBootstrapInformation}, }, tags=["backend"], ) @@ -115,8 +62,8 @@ async def create_node( @router.get( "/api/v1/node/{name}", responses={ - 200: {"model": NodeInformation}, - 404: {}, + status.HTTP_200_OK: {"model": NodeInformation}, + status.HTTP_404_NOT_FOUND: {}, }, tags=["backend"], ) @@ -135,8 +82,7 @@ def get_node_information(name: str, username: Annotated[str, Depends(get_current @router.get( "/api/v1/nodes", responses={ - 200: {"model": NodeCollection}, - 404: {}, + status.HTTP_200_OK: {"model": NodeCollection}, }, tags=["backend"], ) @@ -149,7 +95,7 @@ def get_all_nodes(username: Annotated[str, Depends(get_current_username)]) -> No @router.get( "/api/v1/node/{name}/public_key", responses={ - 200: { + status.HTTP_200_OK: { "content": { MIME_TYPE_JWK: { "title": "JWK", @@ -158,7 +104,7 @@ def get_all_nodes(username: Annotated[str, Depends(get_current_username)]) -> No MIME_TYPE_PEM: {"title": "PEM", "schema": {"type": "string"}}, }, }, - 404: {}, + status.HTTP_404_NOT_FOUND: {}, }, tags=["client"], ) @@ -186,9 +132,10 @@ async def get_node_public_key( @router.delete( "/api/v1/node/{name}", + status_code=status.HTTP_204_NO_CONTENT, responses={ - 204: {"description": "Node deleted", "content": None}, - 404: {}, + status.HTTP_204_NO_CONTENT: {"description": "Node deleted", "content": None}, + status.HTTP_404_NOT_FOUND: {}, }, tags=["backend"], ) @@ -213,7 +160,8 @@ def delete_node(name: str, username: Annotated[str, Depends(get_current_username @router.post( "/api/v1/node/{name}/enroll", responses={ - 200: {"model": NodeConfiguration}, + status.HTTP_200_OK: {"model": NodeConfiguration}, + status.HTTP_404_NOT_FOUND: {}, }, tags=["client"], ) @@ -264,7 +212,7 @@ async def enroll_node( # Verify X.509 CSR and issue certificate x509_csr = x509.load_pem_x509_csr(message["x509_csr"].encode()) - node_certificate = process_csr(csr=x509_csr, name=name, request=request) + node_certificate = process_csr_request(csr=x509_csr, name=name, request=request) node.activated = datetime.now(tz=timezone.utc) node.save() @@ -284,7 +232,8 @@ async def enroll_node( @router.post( "/api/v1/node/{name}/renew", responses={ - 200: {"model": NodeCertificate}, + status.HTTP_200_OK: {"model": NodeCertificate}, + status.HTTP_404_NOT_FOUND: {}, }, tags=["client"], ) @@ -319,4 +268,4 @@ async def renew_node( # Verify X.509 CSR and issue certificate x509_csr = x509.load_pem_x509_csr(message["x509_csr"].encode()) - return process_csr(csr=x509_csr, name=name, request=request) + return process_csr_request(csr=x509_csr, name=name, request=request) diff --git a/nodeman/x509.py b/nodeman/x509.py index 8c5c3a5..9135a9a 100644 --- a/nodeman/x509.py +++ b/nodeman/x509.py @@ -1,17 +1,23 @@ +import logging from abc import ABC, abstractmethod from dataclasses import dataclass from cryptography import x509 -from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from cryptography.x509.oid import ExtensionOID, NameOID +from fastapi import HTTPException, Request, status + +from .models import NodeCertificate type PrivateKey = RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey +logger = logging.getLogger(__name__) + @dataclass(frozen=True) class CertificateInformation: @@ -87,3 +93,34 @@ def verify_x509_csr(name: str, csr: x509.CertificateSigningRequest) -> None: san_value = san_ext.value.get_values_for_type(x509.DNSName) if san_value != [name]: raise SubjectAlternativeNameMismatchError(f"Invalid SubjectAlternativeName, got {san_value} expected {name}") + + +def process_csr_request(request: Request, csr: x509.CertificateSigningRequest, name: str) -> NodeCertificate: + """Verify CSR and issue certificate""" + + verify_x509_csr(name=name, csr=csr) + + try: + ca_response = request.app.ca_client.sign_csr(csr, name) + except Exception as exc: + logger.error("Failed to process CSR for %s", name) + raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error issuing certificate") from exc + + x509_certificate = "".join( + [certificate.public_bytes(serialization.Encoding.PEM).decode() for certificate in ca_response.cert_chain] + ) + x509_ca_certificate = ca_response.ca_cert.public_bytes(serialization.Encoding.PEM).decode() + x509_certificate_serial_number = ca_response.cert_chain[0].serial_number + + logger.info( + "Issued certificate for name=%s serial=%d", + name, + x509_certificate_serial_number, + extra={"nodename": name, "x509_certificate_serial_number": x509_certificate_serial_number}, + ) + + return NodeCertificate( + x509_certificate=x509_certificate, + x509_ca_certificate=x509_ca_certificate, + x509_certificate_serial_number=x509_certificate_serial_number, + )