Skip to content

Commit

Permalink
Merge pull request #13 from dnstapir/toml
Browse files Browse the repository at this point in the history
use standard pydantic settings reader
  • Loading branch information
jschlyter authored Aug 12, 2024
2 parents c5cd5ef + 652ae81 commit 9d8e4e0
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ push-container:
docker push $(CONTAINER)

server: $(DEPENDS) clients clients/test.pem
poetry run evrec_server --config example.toml --debug
poetry run evrec_server --debug

test-private.pem:
openssl ecparam -genkey -name prime256v1 -noout -out $@
Expand Down
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ This repository contains the DNS TAPIR Event Receiver, a server component use fo

## Configuration

CLIENTS_DATABASE = "clients"
MQTT_BROKER = "localhost"
MQTT_TOPIC_READ = "events/up/#"
MQTT_TOPIC_WRITE = "verified"
SCHEMA_VALIDATION = true
The default configuration file is `evrec.toml`. Example configuration below:

clients_database = "clients"
schema_validation = true

[mqtt]
broker = "mqtt://localhost"
topic_read = "events/up/#"
topic_write = "verified"
reconnect_interval = 5
4 changes: 4 additions & 0 deletions clients/test.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEMXdg43RVuFdkTxkfB2szubckpnv
GrAAl7d3pN1y9n38EMZoTjSj2/7pLZYwmJak0du5q+DscXtK3sbPkdD6fQ==
-----END PUBLIC KEY-----
39 changes: 16 additions & 23 deletions evrec/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import json
import logging
import logging.config
import os
from pathlib import Path
from typing import Optional

import aiomqtt
from jwcrypto.common import JWKeyNotFound
Expand Down Expand Up @@ -61,24 +59,15 @@ class EvrecServer:
def __init__(self, settings: Settings):
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.settings = settings
if self.settings.mqtt_topic_write is None:
if self.settings.mqtt.topic_write is None:
self.logger.warning("Not publishing verified messages")
self.clients_keys = self.get_clients_keys()
self.message_validator = MessageValidator()

@staticmethod
def create_settings(config_filename: Optional[str]):
config_filename = config_filename or os.environ.get("EVREC_CONFIG")
if config_filename:
logger.info("Reading configuration from %s", config_filename)
return Settings.from_file(config_filename)
else:
return Settings()

@classmethod
def factory(cls, config_filename: Optional[str]):
def factory(cls):
logger.info("Starting Event Receiver version %s", __verbose_version__)
return cls(settings=cls.create_settings(config_filename))
return cls(settings=Settings())

def get_clients_keys(self) -> JWKSet:
res = JWKSet()
Expand All @@ -93,9 +82,14 @@ def get_clients_keys(self) -> JWKSet:
async def run(self):
while True:
try:
async with aiomqtt.Client(self.settings.mqtt_broker) as client:
logging.info("MQTT connected to %s", self.settings.mqtt_broker)
await client.subscribe(self.settings.mqtt_topic_read)
async with aiomqtt.Client(
hostname=self.settings.mqtt.broker.host,
port=self.settings.mqtt.broker.port,
username=self.settings.mqtt.broker.username,
password=self.settings.mqtt.broker.password,
) as client:
logging.info("MQTT connected to %s", self.settings.mqtt.broker)
await client.subscribe(self.settings.mqtt.topic_read)

async for message in client.messages:
self.logger.debug("Received message on %s", message.topic)
Expand All @@ -107,7 +101,7 @@ async def run(self):
self.message_validator.validate_message(
str(message.topic), jws.objects["payload"]
)
if self.settings.mqtt_topic_write:
if self.settings.mqtt.topic_write:
await self.handle_payload(client, message, jws, key)
else:
self.logger.debug("Not publishing verified message")
Expand All @@ -124,9 +118,9 @@ async def run(self):
except aiomqtt.MqttError:
logging.error(
"MQTT connection lost; Reconnecting in %d seconds...",
self.settings.mqtt_reconnect_interval,
self.settings.mqtt.reconnect_interval,
)
await asyncio.sleep(self.settings.mqtt_reconnect_interval)
await asyncio.sleep(self.settings.mqtt.reconnect_interval)

async def handle_payload(
self,
Expand All @@ -135,7 +129,7 @@ async def handle_payload(
jws: JWS,
key: JWK,
) -> None:
new_topic = f"{self.settings.mqtt_topic_write}/{message.topic}"
new_topic = f"{self.settings.mqtt.topic_write}/{message.topic}"
properties = Properties(PacketTypes.PUBLISH)
properties.UserProperty = [("kid", key.kid), ("thumbprint", key.thumbprint())]
await client.publish(
Expand Down Expand Up @@ -169,7 +163,6 @@ def main() -> None:

parser = argparse.ArgumentParser(description="Event Receiver")

parser.add_argument("--config", metavar="filename", help="Configuration file")
parser.add_argument("--debug", action="store_true", help="Enable debugging")
parser.add_argument("--version", action="store_true", help="Show version")

Expand All @@ -186,7 +179,7 @@ def main() -> None:
else:
logging.basicConfig(level=logging.INFO)

app = EvrecServer.factory(args.config)
app = EvrecServer.factory()

asyncio.run(app.run())

Expand Down
57 changes: 35 additions & 22 deletions evrec/settings.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,43 @@
import tomllib
from typing import Annotated, Tuple, Type

from pydantic_settings import BaseSettings
from pydantic import BaseModel, DirectoryPath, Field, UrlConstraints
from pydantic_core import Url
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
TomlConfigSettingsSource,
)

DEFAULT_MQTT_RECONNECT_INTERVAL = 5
MqttUrl = Annotated[
Url,
UrlConstraints(
allowed_schemes=["mqtt", "mqtts"], default_port=1883, host_required=True
),
]


class MqttSettings(BaseModel):
broker: MqttUrl = Field(default="mqtt://localhost")
topic_read: str = Field(default="events/up/#")
topic_write: str | None = None
reconnect_interval: int = Field(default=5)


class Settings(BaseSettings):
clients_database: str
mqtt_broker: str | None
mqtt_topic_read: str
mqtt_topic_write: str | None
mqtt_reconnect_interval: int = DEFAULT_MQTT_RECONNECT_INTERVAL
mqtt: MqttSettings = Field(default=MqttSettings())
clients_database: DirectoryPath = Field(default="clients")
schema_validation: bool = False

model_config = SettingsConfigDict(toml_file="evrec.toml")

@classmethod
def from_file(cls, filename: str):
with open(filename, "rb") as fp:
data = tomllib.load(fp)

return cls(
clients_database=data.get("CLIENTS_DATABASE", "clients"),
mqtt_broker=data.get("MQTT_BROKER"),
mqtt_topic_read=data.get("MQTT_TOPIC_READ", "events/up/#"),
mqtt_topic_write=data.get("MQTT_TOPIC_WRITE"),
mqtt_reconnect_interval=data.get(
"MQTT_RECONNECT_INTERVAL", DEFAULT_MQTT_RECONNECT_INTERVAL
),
schema_validation=data.get("SCHEMA_VALIDATION", False),
)
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (TomlConfigSettingsSource(settings_cls),)
7 changes: 0 additions & 7 deletions example.toml

This file was deleted.

7 changes: 1 addition & 6 deletions tests/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,5 @@


def test_server():
settings = Settings(
clients_database="clients",
mqtt_broker=None,
mqtt_topic_read="read",
mqtt_topic_write="write",
)
settings = Settings()
_ = EvrecServer(settings)

0 comments on commit 9d8e4e0

Please sign in to comment.