From 669d6c3313e8f3fe266b529bf1a73bbc24bd75bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jaros=C5=82aw=20Jedynak?= Date: Thu, 2 Mar 2023 17:18:24 +0000 Subject: [PATCH] Fix a deadlock in the plugin manager (#355) * Fix a deadlock in the plugin manager I'm not even sure why the deadlock is there, but it happens on the server. Removing the global plugin manager is a good idea anyway: at a small cost in performance, we get rid of the ugly manual locking. --- src/app.py | 51 +++++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/app.py b/src/app.py index 4ab27586..1bdb3d4c 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,5 @@ from lib.ursadb import UrsaDb import os -from threading import Lock import uvicorn # type: ignore from config import app_config @@ -11,7 +10,6 @@ HTTPException, Depends, Header, - BackgroundTasks, ) # type: ignore from starlette.requests import Request # type: ignore from starlette.responses import Response, FileResponse, StreamingResponse # type: ignore @@ -50,28 +48,16 @@ db = Database(app_config.redis.host, app_config.redis.port) app = FastAPI() -plugins = PluginManager(app_config.mquery.plugins, db) -plugin_lock = Lock() -def use_plugins(background: BackgroundTasks) -> None: - """Acquires a plugin_lock, and releases it after cleanup and returning a response. - This function should be called by every API endpoint that uses plugins. +def with_plugins() -> Iterable[PluginManager]: + """Cleans up plugins after processing.""" - This lock is necessary, because nothing in the plugins API makes it obvious that - they should be thread-safe - so we assume that they're not. - """ - - def release_and_cleanup(): - try: - # Hopefully this won't crash... - plugins.cleanup() - finally: - # ...but just in case it does, we absolutely have to release the lock. 
- plugin_lock.release() - - plugin_lock.acquire() - background.add_task(release_and_cleanup) + plugins = PluginManager(app_config.mquery.plugins, db) + try: + yield plugins + finally: + plugins.cleanup() class User: @@ -326,9 +312,14 @@ def backend_status_datasets() -> BackendStatusDatasetsSchema: @app.get( "/api/download", tags=["stable"], - dependencies=[Depends(can_download_files), Depends(use_plugins)], + dependencies=[Depends(can_download_files)], ) -def download(job_id: str, ordinal: int, file_path: str) -> Response: +def download( + job_id: str, + ordinal: int, + file_path: str, + plugins: PluginManager = Depends(with_plugins), +) -> Response: """ Sends a file from given `file_path`. This path should come from results of one of the previous searches. @@ -366,7 +357,13 @@ def download_hashes(job_id: str) -> Response: return Response(hashes + "\n") -def zip_files(matches: List[Dict[Any, Any]]) -> Iterable[bytes]: +def zip_files( + plugins: PluginManager, matches: List[Dict[str, Any]] +) -> Iterable[bytes]: + """Adds all the samples to a zip archive (replacing original filename + with sha256) and returns it as a stream of bytes.""" + plugins = PluginManager(app_config.mquery.plugins, db) + with tempfile.NamedTemporaryFile() as writer: with open(writer.name, "rb") as reader: with zipfile.ZipFile(writer, mode="w") as zipwriter: @@ -385,9 +382,11 @@ def zip_files(matches: List[Dict[Any, Any]]) -> Iterable[bytes]: "/api/download/files/{job_id}", dependencies=[Depends(is_user), Depends(can_download_files)], ) -async def download_files(job_id: str) -> StreamingResponse: +async def download_files( + job_id: str, plugins: PluginManager = Depends(with_plugins) +) -> StreamingResponse: matches = db.get_job_matches(job_id).matches - return StreamingResponse(zip_files(matches)) + return StreamingResponse(zip_files(plugins, matches)) @app.post(