diff --git a/Makefile b/Makefile index b5d6e89..a1125d1 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ push-container: docker push $(CONTAINER) server: $(DEPENDS) clients clients/test.pem - poetry run evrec_server --config example.toml --debug + poetry run evrec_server --debug test-private.pem: openssl ecparam -genkey -name prime256v1 -noout -out $@ diff --git a/README.md b/README.md index ae283bb..cbc1e32 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,13 @@ This repository contains the DNS TAPIR Event Receiver, a server component use fo ## Configuration - CLIENTS_DATABASE = "clients" - MQTT_BROKER = "localhost" - MQTT_TOPIC_READ = "events/up/#" - MQTT_TOPIC_WRITE = "verified" - SCHEMA_VALIDATION = true +The default configuration file is `evrec.toml`. Example configuration below: + + clients_database = "clients" + schema_validation = true + + [mqtt] + broker = "mqtt://localhost" + topic_read = "events/up/#" + topic_write = "verified" + reconnect_interval = 5 diff --git a/clients/test.pem b/clients/test.pem new file mode 100644 index 0000000..d534243 --- /dev/null +++ b/clients/test.pem @@ -0,0 +1,4 @@ +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEMXdg43RVuFdkTxkfB2szubckpnv +GrAAl7d3pN1y9n38EMZoTjSj2/7pLZYwmJak0du5q+DscXtK3sbPkdD6fQ== +-----END PUBLIC KEY----- diff --git a/evrec/server.py b/evrec/server.py index 6913bf5..350ce97 100644 --- a/evrec/server.py +++ b/evrec/server.py @@ -3,9 +3,7 @@ import json import logging import logging.config -import os from pathlib import Path -from typing import Optional import aiomqtt from jwcrypto.common import JWKeyNotFound @@ -61,24 +59,15 @@ class EvrecServer: def __init__(self, settings: Settings): self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__) self.settings = settings - if self.settings.mqtt_topic_write is None: + if self.settings.mqtt.topic_write is None: self.logger.warning("Not publishing verified messages") self.clients_keys = self.get_clients_keys() self.message_validator = MessageValidator() - @staticmethod - def create_settings(config_filename: Optional[str]): - config_filename = config_filename or os.environ.get("EVREC_CONFIG") - if config_filename: - logger.info("Reading configuration from %s", config_filename) - return Settings.from_file(config_filename) - else: - return Settings() - @classmethod - def factory(cls, config_filename: Optional[str]): + def factory(cls): logger.info("Starting Event Receiver version %s", __verbose_version__) - return cls(settings=cls.create_settings(config_filename)) + return cls(settings=Settings()) def get_clients_keys(self) -> JWKSet: res = JWKSet() @@ -93,9 +82,14 @@ def get_clients_keys(self) -> JWKSet: async def run(self): while True: try: - async with aiomqtt.Client(self.settings.mqtt_broker) as client: - logging.info("MQTT connected to %s", self.settings.mqtt_broker) - await client.subscribe(self.settings.mqtt_topic_read) + async with aiomqtt.Client( + hostname=self.settings.mqtt.broker.host, + port=self.settings.mqtt.broker.port, + username=self.settings.mqtt.broker.username, + password=self.settings.mqtt.broker.password, + ) as client: + logging.info("MQTT connected to %s", self.settings.mqtt.broker) + await client.subscribe(self.settings.mqtt.topic_read) async for message in client.messages: self.logger.debug("Received message on %s", message.topic) @@ -107,7 +101,7 @@ async def run(self): self.message_validator.validate_message( str(message.topic), jws.objects["payload"] ) - if self.settings.mqtt_topic_write: + if self.settings.mqtt.topic_write: await self.handle_payload(client, message, jws, key) else: self.logger.debug("Not publishing verified message") @@ -124,9 +118,9 @@ async def run(self): except aiomqtt.MqttError: logging.error( "MQTT connection lost; Reconnecting in %d seconds...", - self.settings.mqtt_reconnect_interval, + self.settings.mqtt.reconnect_interval, ) - await asyncio.sleep(self.settings.mqtt_reconnect_interval) + await asyncio.sleep(self.settings.mqtt.reconnect_interval) async def handle_payload( self, @@ -135,7 +129,7 @@ async def handle_payload( jws: JWS, key: JWK, ) -> None: - new_topic = f"{self.settings.mqtt_topic_write}/{message.topic}" + new_topic = f"{self.settings.mqtt.topic_write}/{message.topic}" properties = Properties(PacketTypes.PUBLISH) properties.UserProperty = [("kid", key.kid), ("thumbprint", key.thumbprint())] await client.publish( @@ -169,7 +163,6 @@ def main() -> None: parser = argparse.ArgumentParser(description="Event Receiver") - parser.add_argument("--config", metavar="filename", help="Configuration file") parser.add_argument("--debug", action="store_true", help="Enable debugging") parser.add_argument("--version", action="store_true", help="Show version") @@ -186,7 +179,7 @@ def main() -> None: else: logging.basicConfig(level=logging.INFO) - app = EvrecServer.factory(args.config) + app = EvrecServer.factory() asyncio.run(app.run()) diff --git a/evrec/settings.py b/evrec/settings.py index 6077719..1849cc0 100644 --- a/evrec/settings.py +++ b/evrec/settings.py @@ -1,30 +1,43 @@ -import tomllib +from typing import Annotated, Tuple, Type -from pydantic_settings import BaseSettings +from pydantic import BaseModel, DirectoryPath, Field, UrlConstraints +from pydantic_core import Url +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + TomlConfigSettingsSource, +) -DEFAULT_MQTT_RECONNECT_INTERVAL = 5 +MqttUrl = Annotated[ + Url, + UrlConstraints( + allowed_schemes=["mqtt", "mqtts"], default_port=1883, host_required=True + ), +] + + +class MqttSettings(BaseModel): + broker: MqttUrl = Field(default="mqtt://localhost") + topic_read: str = Field(default="events/up/#") + topic_write: str | None = None + reconnect_interval: int = Field(default=5) class Settings(BaseSettings): - clients_database: str - mqtt_broker: str | None - mqtt_topic_read: str - mqtt_topic_write: str | None - mqtt_reconnect_interval: int = DEFAULT_MQTT_RECONNECT_INTERVAL + mqtt: MqttSettings = Field(default=MqttSettings()) + clients_database: DirectoryPath = Field(default="clients") schema_validation: bool = False + model_config = SettingsConfigDict(toml_file="evrec.toml") + @classmethod - def from_file(cls, filename: str): - with open(filename, "rb") as fp: - data = tomllib.load(fp) - - return cls( - clients_database=data.get("CLIENTS_DATABASE", "clients"), - mqtt_broker=data.get("MQTT_BROKER"), - mqtt_topic_read=data.get("MQTT_TOPIC_READ", "events/up/#"), - mqtt_topic_write=data.get("MQTT_TOPIC_WRITE"), - mqtt_reconnect_interval=data.get( - "MQTT_RECONNECT_INTERVAL", DEFAULT_MQTT_RECONNECT_INTERVAL - ), - schema_validation=data.get("SCHEMA_VALIDATION", False), - ) + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return (TomlConfigSettingsSource(settings_cls),) diff --git a/example.toml b/example.toml deleted file mode 100644 index 6dcb56c..0000000 --- a/example.toml +++ /dev/null @@ -1,7 +0,0 @@ -CLIENTS_DATABASE = "clients" - -MQTT_BROKER = "localhost" -MQTT_TOPIC_READ = "events/up/#" -#MQTT_TOPIC_WRITE = "verified" - -SCHEMA_VALIDATION = true diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py index abcc19e..e074fc6 100644 --- a/tests/test_mqtt.py +++ b/tests/test_mqtt.py @@ -3,10 +3,5 @@ def test_server(): - settings = Settings( - clients_database="clients", - mqtt_broker=None, - mqtt_topic_read="read", - mqtt_topic_write="write", - ) + settings = Settings() _ = EvrecServer(settings)