From 427a8dcf357597df27b2509b1ac436caf7708300 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Fri, 15 Nov 2024 00:34:22 +0100 Subject: [PATCH] Fix issue with middleware args passing (#2752) --- starlette/applications.py | 6 ++--- starlette/middleware/__init__.py | 13 +++++----- starlette/routing.py | 6 ++--- tests/middleware/test_base.py | 4 +-- tests/test_applications.py | 44 ++++++++++++++++++++++++++++++++ 5 files changed, 58 insertions(+), 15 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 0feae72e4..aae38f588 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -10,7 +10,7 @@ from typing_extensions import ParamSpec from starlette.datastructures import State, URLPath -from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware @@ -96,7 +96,7 @@ def build_middleware_stack(self) -> ASGIApp: app = self.router for cls, args, kwargs in reversed(middleware): - app = cls(app=app, *args, **kwargs) + app = cls(app, *args, **kwargs) return app @property @@ -123,7 +123,7 @@ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: def add_middleware( self, - middleware_class: type[_MiddlewareClass[P]], + middleware_class: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 8566aac08..8e0a54edb 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -8,21 +8,19 @@ else: # pragma: no cover from typing_extensions import ParamSpec -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp P = ParamSpec("P") -class _MiddlewareClass(Protocol[P]): - def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover +class _MiddlewareFactory(Protocol[P]): + def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover class Middleware: def __init__( self, - cls: type[_MiddlewareClass[P]], + cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -38,5 +36,6 @@ def __repr__(self) -> str: class_name = self.__class__.__name__ args_strings = [f"{value!r}" for value in self.args] option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] - args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings) + name = getattr(self.cls, "__name__", "") + args_repr = ", ".join([name] + args_strings + option_strings) return f"{class_name}({args_repr})" diff --git a/starlette/routing.py b/starlette/routing.py index 1504ef50a..3b3c52968 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -236,7 +236,7 @@ def __init__( if middleware is not None: for cls, args, kwargs in reversed(middleware): - self.app = cls(app=self.app, *args, **kwargs) + self.app = cls(self.app, *args, **kwargs) if methods is None: self.methods = None @@ -328,7 +328,7 @@ def __init__( if middleware is not None: for cls, args, kwargs in reversed(middleware): - self.app = cls(app=self.app, *args, **kwargs) + self.app = cls(self.app, *args, **kwargs) self.path_regex, self.path_format, self.param_convertors = compile_path(path) @@ -388,7 +388,7 @@ def __init__( self.app = self._base_app if middleware is not None: for cls, args, kwargs in reversed(middleware): - self.app = cls(app=self.app, *args, **kwargs) + self.app = cls(self.app, *args, **kwargs) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}") diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 041cc7ce2..fa0cba479 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -10,7 +10,7 @@ from starlette.applications import Starlette from starlette.background import BackgroundTask -from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request from starlette.responses import PlainTextResponse, Response, StreamingResponse @@ -232,7 +232,7 @@ async def dispatch( ) def test_contextvars( test_client_factory: TestClientFactory, - middleware_cls: type[_MiddlewareClass[Any]], + middleware_cls: _MiddlewareFactory[Any], ) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating diff --git a/tests/test_applications.py b/tests/test_applications.py index 056044438..29c011a29 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from contextlib import asynccontextmanager from pathlib import Path @@ -533,6 +535,48 @@ def get_app() -> ASGIApp: assert SimpleInitializableMiddleware.counter == 2 +def test_middleware_args(test_client_factory: TestClientFactory) -> None: + calls: list[str] = [] + + class MiddlewareWithArgs: + def __init__(self, app: ASGIApp, arg: str) -> None: + self.app = app + self.arg = arg + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + calls.append(self.arg) + await self.app(scope, receive, send) + + app = Starlette() + app.add_middleware(MiddlewareWithArgs, "foo") + app.add_middleware(MiddlewareWithArgs, "bar") + + with test_client_factory(app): + pass + + assert calls == ["bar", "foo"] + + +def test_middleware_factory(test_client_factory: TestClientFactory) -> None: + calls: list[str] = [] + + def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp: + async def _app(scope: Scope, receive: Receive, send: Send) -> None: + calls.append(arg) + await app(scope, receive, send) + + return _app + + app = Starlette() + app.add_middleware(_middleware_factory, arg="foo") + app.add_middleware(_middleware_factory, arg="bar") + + with test_client_factory(app): + pass + + assert calls == ["bar", "foo"] + + def test_lifespan_app_subclass() -> None: # This test exists to make sure that subclasses of Starlette # (like FastAPI) are compatible with the types hints for Lifespan