Skip to content

Commit

Permalink
Merge pull request #9 from dnstapir/schema_validation
Browse files Browse the repository at this point in the history
Schema validation
  • Loading branch information
jschlyter authored Jun 18, 2024
2 parents 198b3d6 + cd6b2af commit bf680a4
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ This repository contains the DNS TAPIR Event Receiver, a server component use fo
MQTT_BROKER = "localhost"
MQTT_TOPIC_READ = "events/up/#"
MQTT_TOPIC_WRITE = "verified"
SCHEMA_VALIDATION = true
58 changes: 58 additions & 0 deletions evrec/schema/new_qname.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"$schema": "http://json-schema.org/schema#",
"type": "object",
"additionalProperties": true,
"required": [
"type",
"version",
"timestamp",
"qname"
],
"properties": {
"version": {
"type": "integer",
"minimum": 0
},
"timestamp": {
"type": "string",
"format": "date-time"
},
"type": {
"const": "new_qname"
},
"initiator": {
"type": "string",
"enum": [
"client",
"resolver"
]
},
"qname": {
"description": "Query Name",
"$ref": "#/$defs/domain_name"
},
"qtype": {
"description": "Query Type",
"type": "integer",
"minimum": 0
},
"qclass": {
"description": "Query Class",
"type": "integer",
"minimum": 0
},
"flags": {
"description": "Flag Field (QR/Opcode/AA/TC/RD/TA/Z/RCODE)",
"type": "integer"
},
"rdlength": {
"type": "integer",
"minimum": 0
}
},
"$defs": {
"domain_name": {
"type": "string"
}
}
}
21 changes: 14 additions & 7 deletions evrec/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from . import __verbose_version__
from .settings import Settings
from .validator import MessageValidator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,10 +59,12 @@

class EvrecServer:
def __init__(self, settings: Settings):
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.settings = settings
if self.settings.mqtt_topic_write is None:
logger.warning("Not publishing verified messages")
self.logger.warning("Not publishing verified messages")
self.clients_keys = self.get_clients_keys()
self.message_validator = MessageValidator()

@staticmethod
def create_settings(config_filename: Optional[str]):
Expand All @@ -83,7 +86,7 @@ def get_clients_keys(self) -> JWKSet:
with open(filename, "rb") as fp:
key = JWK.from_pem(fp.read())
key.kid = filename.name.removesuffix(".pem")
logger.debug("Adding key kid=%s (%s)", key.kid, key.thumbprint())
self.logger.debug("Adding key kid=%s (%s)", key.kid, key.thumbprint())
res.add(key)
return res

Expand All @@ -95,21 +98,25 @@ async def run(self):
await client.subscribe(self.settings.mqtt_topic_read)

async for message in client.messages:
logger.debug("Received message on %s", message.topic)
self.logger.debug("Received message on %s", message.topic)
try:
jws = JWS()
jws.deserialize(message.payload)
key = verify_jws_with_keys(jws, self.clients_keys)
if self.settings.schema_validation:
self.message_validator.validate_message(
str(message.topic), jws.objects["payload"]
)
if self.settings.mqtt_topic_write:
await self.handle_payload(client, message, jws, key)
else:
logger.debug("Not publishing verified message")
self.logger.debug("Not publishing verified message")
except JWKeyNotFound:
logger.warning(
self.logger.warning(
"Dropping unverified message on %s", message.topic
)
except Exception as exc:
logger.error(
self.logger.error(
"Error parsing message on %s",
message.topic,
exc_info=exc,
Expand Down Expand Up @@ -138,7 +145,7 @@ async def handle_payload(
retain=message.retain,
properties=properties,
)
logger.info("Published verified message from %s on %s", key.kid, new_topic)
self.logger.info("Published verified message from %s on %s", key.kid, new_topic)


def verify_jws_with_keys(jws: JWS, keys: JWKSet) -> JWK:
Expand Down
2 changes: 2 additions & 0 deletions evrec/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Settings(BaseSettings):
mqtt_topic_read: str
mqtt_topic_write: str | None
mqtt_reconnect_interval: int = DEFAULT_MQTT_RECONNECT_INTERVAL
schema_validation: bool = False

@classmethod
def from_file(cls, filename: str):
Expand All @@ -25,4 +26,5 @@ def from_file(cls, filename: str):
mqtt_reconnect_interval=data.get(
"MQTT_RECONNECT_INTERVAL", DEFAULT_MQTT_RECONNECT_INTERVAL
),
schema_validation=data.get("SCHEMA_VALIDATION", False),
)
36 changes: 36 additions & 0 deletions evrec/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
import logging
import os
from pathlib import Path

import jsonschema

SCHEMA_DIR = os.path.dirname(__file__) + "/schema"


class MessageValidator:
"""MQTT Message Validator"""

VALIDATOR = jsonschema.Draft202012Validator

def __init__(self) -> None:
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.schemas: dict[str, jsonschema.Validator] = {}
for filename in Path(SCHEMA_DIR).glob("*.json"):
with open(filename) as fp:
schema = json.load(fp)
name = filename.name.removesuffix(".json")
self.schemas[name] = self.VALIDATOR(
schema, format_checker=self.VALIDATOR.FORMAT_CHECKER
)
self.logger.debug("Loaded schema %s from %s", name, filename)

def validate_message(self, topic: str, payload: bytes) -> None:
"""Validate message against schema based on content type"""
content = json.loads(payload)
content_type = content["type"]
if schema := self.schemas.get(content_type):
schema.validate(content)
self.logger.debug(
"Message on %s validated against schema %s", topic, content_type
)
2 changes: 2 additions & 0 deletions example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ CLIENTS_DATABASE = "clients"
MQTT_BROKER = "localhost"
MQTT_TOPIC_READ = "events/up/#"
#MQTT_TOPIC_WRITE = "verified"

SCHEMA_VALIDATION = true
Loading

0 comments on commit bf680a4

Please sign in to comment.