From 3d7784e75fee23bf9d49b418f428e7b6dba88e1e Mon Sep 17 00:00:00 2001 From: kvothe Date: Mon, 18 Nov 2024 13:55:43 +0100 Subject: [PATCH] Refactor user model and config handling (#433) * UserModelCOnfig with __getattr__ * UserModelConfig with @property decorator * Config prepared --- src/app.py | 52 +++++++++++++--------------------------------- src/db.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 39 deletions(-) diff --git a/src/app.py b/src/app.py index d0113259..c07e5f8d 100644 --- a/src/app.py +++ b/src/app.py @@ -1,7 +1,6 @@ from contextlib import asynccontextmanager import os -from enum import Enum, auto import uvicorn # type: ignore from pathlib import Path from fastapi import ( @@ -26,7 +25,7 @@ from .config import app_config from .util import mquery_version -from .db import Database +from .db import Database, UserRole from .lib.yaraparse import parse_yara from .plugins import PluginManager from .lib.ursadb import UrsaDb @@ -71,24 +70,6 @@ def with_plugins() -> Iterable[PluginManager]: plugins.cleanup() -# See docs/users.md for documentation on the permission model. -# Enum values are meaningless and may change. Make sure to not store them -# anywhere (for storing/transfer use role names instead). -class UserRole(Enum): - # "role groups", used to grant a collection of "action roles" - nobody = auto() # no permissions granted - user = auto() # can run yara queries and read the state - admin = auto() # can manage the system (and do everything else) - - # "action roles", used to give permission to a specific thing - can_manage_all_queries = auto() - can_manage_queries = auto() - can_list_all_queries = auto() - can_list_queries = auto() - can_view_queries = auto() - can_download_files = auto() - - class User: def __init__(self, token: Optional[Dict]) -> None: self.__token = token @@ -114,8 +95,8 @@ def roles(self, client_id: Optional[str]) -> List[UserRole]: async def current_user(authorization: Optional[str] = Header(None)) -> User: - auth_enabled = db.get_mquery_config_key("auth_enabled") - if not auth_enabled or auth_enabled == "false": + auth_enabled = db.config.auth_enabled + if not auth_enabled: return User(None) if not authorization: @@ -134,7 +115,7 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User: _bearer, token = token_parts - secret = db.get_mquery_config_key("openid_secret") + secret = db.config.openid_secret if secret is None: raise RuntimeError("Invalid configuration - missing_openid_secret.") @@ -169,9 +150,9 @@ def __init__(self, need_permissions: List[UserRole]) -> None: self.need_permissions = need_permissions def __call__(self, user: User = Depends(current_user)): - auth_enabled = db.get_mquery_config_key("auth_enabled") - if not auth_enabled or auth_enabled == "false": - return + auth_enabled = db.config.auth_enabled + if not auth_enabled: + return User(None) all_roles = get_user_roles(user) if not any(role in self.need_permissions for role in all_roles): @@ -198,15 +179,10 @@ def __call__(self, user: User = Depends(current_user)): def get_user_roles(user: User) -> List[UserRole]: """Get all roles assigned to user, taking into account the system configuration (like default configured roles)""" - client_id = db.get_mquery_config_key("openid_client_id") + client_id = db.config.openid_client_id user_roles = user.roles(client_id) - auth_default_roles = db.get_mquery_config_key("auth_default_roles") - if not auth_default_roles: - auth_default_roles = "admin" - default_roles = [ - UserRole[role.strip()] for role in auth_default_roles.split(",") - ] - all_roles = set(user_roles + default_roles) + auth_default_roles = db.config.auth_default_roles + all_roles = set(user_roles + auth_default_roles) return sum((expand_role(role) for role in all_roles), []) @@ -455,7 +431,7 @@ def query( ] degenerate_rules = [r.name for r in rules if r.parse().is_degenerate] - allow_slow = db.get_mquery_config_key("query_allow_slow") == "true" + allow_slow = db.config.query_allow_slow if degenerate_rules and not (allow_slow and data.force_slow_queries): if allow_slow: # Warning: "You can force a slow query" literal is used to @@ -601,9 +577,9 @@ def query_remove( def server() -> ServerSchema: return ServerSchema( version=mquery_version(), - auth_enabled=db.get_mquery_config_key("auth_enabled"), - openid_url=db.get_mquery_config_key("openid_url"), - openid_client_id=db.get_mquery_config_key("openid_client_id"), + auth_enabled=str(db.config.auth_enabled).lower(), + openid_url=db.config.openid_url, + openid_client_id=db.config.openid_client_id, about=app_config.mquery.about, ) diff --git a/src/db.py b/src/db.py index 4c85d1be..4f1eb775 100644 --- a/src/db.py +++ b/src/db.py @@ -8,7 +8,7 @@ import random import string from redis import StrictRedis -from enum import Enum +from enum import Enum, auto from rq import Queue # type: ignore from sqlmodel import ( Session, @@ -40,10 +40,64 @@ class TaskType(Enum): COMMAND = "command" +# See docs/users.md for documentation on the permission model. +# Enum values are meaningless and may change. Make sure to not store them +# anywhere (for storing/transfer use role names instead). +class UserRole(Enum): + # "role groups", used to grant a collection of "action roles" + nobody = auto() # no permissions granted + user = auto() # can run yara queries and read the state + admin = auto() # can manage the system (and do everything else) + + # "action roles", used to give permission to a specific thing + can_manage_all_queries = auto() + can_manage_queries = auto() + can_list_all_queries = auto() + can_list_queries = auto() + can_view_queries = auto() + can_download_files = auto() + + # Type alias for Job ids JobId = str +class UserModelConfig: + def __init__(self, db_instance): + self.db = db_instance + + @property + def auth_default_roles(self) -> List[UserRole]: + auth_default_roles = self.db.get_mquery_config_key( + "auth_default_roles" + ) + if auth_default_roles is None: + auth_default_roles = "admin" + return [ + UserRole[role.strip()] for role in auth_default_roles.split(",") + ] + + @property + def openid_client_id(self) -> str | None: + return self.db.get_mquery_config_key("openid_client_id") + + @property + def query_allow_slow(self) -> bool: + return self.db.get_mquery_config_key("query_allow_slow") == "true" + + @property + def auth_enabled(self) -> bool: + return self.db.get_mquery_config_key("auth_enabled") == "true" + + @property + def openid_url(self) -> str | None: + return self.db.get_mquery_config_key("openid_url") + + @property + def openid_secret(self) -> str | None: + return self.db.get_mquery_config_key("openid_secret") + + class Database: def __init__(self, redis_host: str, redis_port: int) -> None: self.redis: Any = StrictRedis( @@ -57,6 +111,10 @@ def __schedule(self, agent: str, task: Any, *args: Any) -> None: task, *args, job_timeout=app_config.rq.job_timeout ) + @property + def config(self): + return UserModelConfig(self) + @contextmanager def session(self): with Session(self.engine) as session: