Skip to content

Commit

Permalink
Fix a deadlock in the plugin manager (#355)
Browse files Browse the repository at this point in the history
* Fix a deadlock in the plugin manager

I'm not even sure why the deadlock is there, but it happens on the
server. Removing the global pluginmanager is a good idea anyway,
at a small cost of performance we remove ugly manual locking.
  • Loading branch information
msm-code authored Mar 2, 2023
1 parent f17a874 commit 669d6c3
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from lib.ursadb import UrsaDb
import os
from threading import Lock

import uvicorn # type: ignore
from config import app_config
Expand All @@ -11,7 +10,6 @@
HTTPException,
Depends,
Header,
BackgroundTasks,
) # type: ignore
from starlette.requests import Request # type: ignore
from starlette.responses import Response, FileResponse, StreamingResponse # type: ignore
Expand Down Expand Up @@ -50,28 +48,16 @@

db = Database(app_config.redis.host, app_config.redis.port)
app = FastAPI()
plugins = PluginManager(app_config.mquery.plugins, db)
plugin_lock = Lock()


def use_plugins(background: BackgroundTasks) -> None:
"""Acquires a plugin_lock, and releases it after cleanup and returning a response.
This function should be called by every API endpoint that uses plugins.
def with_plugins() -> Iterable[PluginManager]:
"""Cleans up plugins after processing."""

This lock is necessary, because nothing in the plugins API makes it obvious that
they should be thread-safe - so we assume that they're not.
"""

def release_and_cleanup():
try:
# Hopefully this won't crash...
plugins.cleanup()
finally:
# ...but just in case it does, we absolutely have to release the lock.
plugin_lock.release()

plugin_lock.acquire()
background.add_task(release_and_cleanup)
plugins = PluginManager(app_config.mquery.plugins, db)
try:
yield plugins
finally:
plugins.cleanup()


class User:
Expand Down Expand Up @@ -326,9 +312,14 @@ def backend_status_datasets() -> BackendStatusDatasetsSchema:
@app.get(
"/api/download",
tags=["stable"],
dependencies=[Depends(can_download_files), Depends(use_plugins)],
dependencies=[Depends(can_download_files)],
)
def download(job_id: str, ordinal: int, file_path: str) -> Response:
def download(
job_id: str,
ordinal: int,
file_path: str,
plugins: PluginManager = Depends(with_plugins),
) -> Response:
"""
Sends a file from given `file_path`. This path should come from
results of one of the previous searches.
Expand Down Expand Up @@ -366,7 +357,13 @@ def download_hashes(job_id: str) -> Response:
return Response(hashes + "\n")


def zip_files(matches: List[Dict[Any, Any]]) -> Iterable[bytes]:
def zip_files(
plugins: PluginManager, matches: List[Dict[str, Any]]
) -> Iterable[bytes]:
"""Adds all the samples to a zip archive (replacing original filename
with sha256) and returns it as a stream of bytes."""
plugins = PluginManager(app_config.mquery.plugins, db)

with tempfile.NamedTemporaryFile() as writer:
with open(writer.name, "rb") as reader:
with zipfile.ZipFile(writer, mode="w") as zipwriter:
Expand All @@ -385,9 +382,11 @@ def zip_files(matches: List[Dict[Any, Any]]) -> Iterable[bytes]:
"/api/download/files/{job_id}",
dependencies=[Depends(is_user), Depends(can_download_files)],
)
async def download_files(job_id: str) -> StreamingResponse:
async def download_files(
job_id: str, plugins: PluginManager = Depends(with_plugins)
) -> StreamingResponse:
matches = db.get_job_matches(job_id).matches
return StreamingResponse(zip_files(matches))
return StreamingResponse(zip_files(plugins, matches))


@app.post(
Expand Down

0 comments on commit 669d6c3

Please sign in to comment.