diff --git a/asu/fastapi/staticfiles.py b/asu/fastapi/staticfiles.py new file mode 100644 index 00000000..96cfa4a2 --- /dev/null +++ b/asu/fastapi/staticfiles.py @@ -0,0 +1,38 @@ +import os + +from fastapi.responses import FileResponse, Response +from fastapi.staticfiles import StaticFiles as FastApiStaticFiles + +from starlette.staticfiles import PathLike +from starlette.types import Scope + + +class StaticFiles(FastApiStaticFiles): + def __init__( + self, + *, + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, + html: bool = False, + check_dir: bool = True, + follow_symlink: bool = False, + ) -> None: + super().__init__( + directory=directory, + packages=packages, + html=html, + check_dir=check_dir, + follow_symlink=follow_symlink, + ) + + def file_response( + self, + full_path: PathLike, + stat_result: os.stat_result, + scope: Scope, + status_code: int = 200, + ) -> Response: + response = super().file_response(full_path, stat_result, scope, status_code) + if isinstance(response, FileResponse): + response.headers["Content-Type"] = "application/octet-stream" + return response diff --git a/asu/main.py b/asu/main.py index 9a45f929..4dcff09d 100644 --- a/asu/main.py +++ b/asu/main.py @@ -15,6 +15,7 @@ from asu.config import settings from asu.routers import api from asu.util import get_redis_client, parse_feeds_conf, parse_packages_file +from asu.fastapi.staticfiles import StaticFiles as AsuStaticFiles logging.basicConfig(encoding="utf-8", level=settings.log_level) @@ -36,7 +37,9 @@ async def lifespan(app: FastAPI): (settings.public_path / "json").mkdir(parents=True, exist_ok=True) (settings.public_path / "store").mkdir(parents=True, exist_ok=True) -app.mount("/store", StaticFiles(directory=settings.public_path / "store"), name="store") +app.mount( + "/store", AsuStaticFiles(directory=settings.public_path / "store"), name="store" +) app.mount("/static", StaticFiles(directory=base_path / "static"), name="static") templates = Jinja2Templates(directory=base_path / "templates") diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 00000000..5cb0b586 --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,36 @@ +from asu.config import settings + +store_path = settings.public_path / "store" + + +def test_store_content_type_img(client): + store_path.mkdir(parents=True, exist_ok=True) + with open(store_path / "test_store_content_type.img", "w+b"): + pass + response = client.head("/store/test_store_content_type.img") + + assert response.status_code == 200 + + headers = response.headers + assert headers["Content-Type"] == "application/octet-stream" + + +def test_store_content_type_imggz(client): + store_path.mkdir(parents=True, exist_ok=True) + with open(store_path / "test_store_content_type.img.gz", "w+b"): + pass + response = client.head("/store/test_store_content_type.img.gz") + + assert response.status_code == 200 + + headers = response.headers + assert headers["Content-Type"] == "application/octet-stream" + + +def test_store_file_missing(client): + response = client.head("/store/test_store_file_missing.bin") + + assert response.status_code == 404 + + headers = response.headers + assert headers["Content-Type"] != "application/octet-stream"