Skip to content

Commit

Permalink
Refactor user model and config handling (#433)
Browse files Browse the repository at this point in the history
* UserModelCOnfig with __getattr__
* UserModelConfig with @Property decorator
* Config prepared
  • Loading branch information
michalkrzem authored Nov 18, 2024
1 parent acad24d commit 3d7784e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 39 deletions.
52 changes: 14 additions & 38 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import asynccontextmanager
import os

from enum import Enum, auto
import uvicorn # type: ignore
from pathlib import Path
from fastapi import (
Expand All @@ -26,7 +25,7 @@

from .config import app_config
from .util import mquery_version
from .db import Database
from .db import Database, UserRole
from .lib.yaraparse import parse_yara
from .plugins import PluginManager
from .lib.ursadb import UrsaDb
Expand Down Expand Up @@ -71,24 +70,6 @@ def with_plugins() -> Iterable[PluginManager]:
plugins.cleanup()


# See docs/users.md for documentation on the permission model.
# Enum values are meaningless and may change. Make sure to not store them
# anywhere (for storing/transfer use role names instead).
class UserRole(Enum):
# "role groups", used to grant a collection of "action roles"
nobody = auto() # no permissions granted
user = auto() # can run yara queries and read the state
admin = auto() # can manage the system (and do everything else)

# "action roles", used to give permission to a specific thing
can_manage_all_queries = auto()
can_manage_queries = auto()
can_list_all_queries = auto()
can_list_queries = auto()
can_view_queries = auto()
can_download_files = auto()


class User:
def __init__(self, token: Optional[Dict]) -> None:
self.__token = token
Expand All @@ -114,8 +95,8 @@ def roles(self, client_id: Optional[str]) -> List[UserRole]:


async def current_user(authorization: Optional[str] = Header(None)) -> User:
auth_enabled = db.get_mquery_config_key("auth_enabled")
if not auth_enabled or auth_enabled == "false":
auth_enabled = db.config.auth_enabled
if not auth_enabled:
return User(None)

if not authorization:
Expand All @@ -134,7 +115,7 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User:

_bearer, token = token_parts

secret = db.get_mquery_config_key("openid_secret")
secret = db.config.openid_secret
if secret is None:
raise RuntimeError("Invalid configuration - missing_openid_secret.")

Expand Down Expand Up @@ -169,9 +150,9 @@ def __init__(self, need_permissions: List[UserRole]) -> None:
self.need_permissions = need_permissions

def __call__(self, user: User = Depends(current_user)):
auth_enabled = db.get_mquery_config_key("auth_enabled")
if not auth_enabled or auth_enabled == "false":
return
auth_enabled = db.config.auth_enabled
if not auth_enabled:
return User(None)

all_roles = get_user_roles(user)
if not any(role in self.need_permissions for role in all_roles):
Expand All @@ -198,15 +179,10 @@ def __call__(self, user: User = Depends(current_user)):
def get_user_roles(user: User) -> List[UserRole]:
"""Get all roles assigned to user, taking into account the
system configuration (like default configured roles)"""
client_id = db.get_mquery_config_key("openid_client_id")
client_id = db.config.openid_client_id
user_roles = user.roles(client_id)
auth_default_roles = db.get_mquery_config_key("auth_default_roles")
if not auth_default_roles:
auth_default_roles = "admin"
default_roles = [
UserRole[role.strip()] for role in auth_default_roles.split(",")
]
all_roles = set(user_roles + default_roles)
auth_default_roles = db.config.auth_default_roles
all_roles = set(user_roles + auth_default_roles)
return sum((expand_role(role) for role in all_roles), [])


Expand Down Expand Up @@ -455,7 +431,7 @@ def query(
]

degenerate_rules = [r.name for r in rules if r.parse().is_degenerate]
allow_slow = db.get_mquery_config_key("query_allow_slow") == "true"
allow_slow = db.config.query_allow_slow
if degenerate_rules and not (allow_slow and data.force_slow_queries):
if allow_slow:
# Warning: "You can force a slow query" literal is used to
Expand Down Expand Up @@ -601,9 +577,9 @@ def query_remove(
def server() -> ServerSchema:
return ServerSchema(
version=mquery_version(),
auth_enabled=db.get_mquery_config_key("auth_enabled"),
openid_url=db.get_mquery_config_key("openid_url"),
openid_client_id=db.get_mquery_config_key("openid_client_id"),
auth_enabled=str(db.config.auth_enabled).lower(),
openid_url=db.config.openid_url,
openid_client_id=db.config.openid_client_id,
about=app_config.mquery.about,
)

Expand Down
60 changes: 59 additions & 1 deletion src/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
import string
from redis import StrictRedis
from enum import Enum
from enum import Enum, auto
from rq import Queue # type: ignore
from sqlmodel import (
Session,
Expand Down Expand Up @@ -40,10 +40,64 @@ class TaskType(Enum):
COMMAND = "command"


# See docs/users.md for documentation on the permission model.
# Enum values are meaningless and may change. Make sure to not store them
# anywhere (for storing/transfer use role names instead).
class UserRole(Enum):
# "role groups", used to grant a collection of "action roles"
nobody = auto() # no permissions granted
user = auto() # can run yara queries and read the state
admin = auto() # can manage the system (and do everything else)

# "action roles", used to give permission to a specific thing
can_manage_all_queries = auto()
can_manage_queries = auto()
can_list_all_queries = auto()
can_list_queries = auto()
can_view_queries = auto()
can_download_files = auto()


# Type alias for Job ids
JobId = str


class UserModelConfig:
def __init__(self, db_instance):
self.db = db_instance

@property
def auth_default_roles(self) -> List[UserRole]:
auth_default_roles = self.db.get_mquery_config_key(
"auth_default_roles"
)
if auth_default_roles is None:
auth_default_roles = "admin"
return [
UserRole[role.strip()] for role in auth_default_roles.split(",")
]

@property
def openid_client_id(self) -> str | None:
return self.db.get_mquery_config_key("openid_client_id")

@property
def query_allow_slow(self) -> bool:
return self.db.get_mquery_config_key("query_allow_slow") == "true"

@property
def auth_enabled(self) -> bool:
return self.db.get_mquery_config_key("auth_enabled") == "true"

@property
def openid_url(self) -> str | None:
return self.db.get_mquery_config_key("openid_url")

@property
def openid_secret(self) -> str | None:
return self.db.get_mquery_config_key("openid_secret")


class Database:
def __init__(self, redis_host: str, redis_port: int) -> None:
self.redis: Any = StrictRedis(
Expand All @@ -57,6 +111,10 @@ def __schedule(self, agent: str, task: Any, *args: Any) -> None:
task, *args, job_timeout=app_config.rq.job_timeout
)

@property
def config(self):
return UserModelConfig(self)

@contextmanager
def session(self):
with Session(self.engine) as session:
Expand Down

0 comments on commit 3d7784e

Please sign in to comment.