From 8c15a6a6545e0f9fff3fd7f146219658743eb912 Mon Sep 17 00:00:00 2001 From: Sebastian Briesemeister <11663508+SebastianBr@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:02:30 +0200 Subject: [PATCH] fix: cors middleware mirrors origin in case no initial cookie is present --- starlette/middleware/cors.py | 4 ++-- tests/middleware/test_cors.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 4b8e97bc9..8f1dea685 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -158,9 +158,9 @@ async def send( headers = MutableHeaders(scope=message) headers.update(self.simple_headers) origin = request_headers["Origin"] - has_cookie = "cookie" in request_headers + has_cookie = "cookie" in request_headers or "set-cookie" in headers - # If request includes any cookie headers, then we must respond + # If request or response includes any cookie headers, then we must respond # with the specific origin instead of '*'. if self.allow_all_origins and has_cookie: self.allow_explicit_origin(headers, origin) diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 630361243..0d2bc23c0 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -432,6 +432,29 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-credentials" not in response.headers +def test_cors_credentialed_requests_return_specific_origin_without_initial_cookie( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: + response = PlainTextResponse("Homepage", status_code=200) + response.set_cookie("mycookie", "myvalue", path=None) + return response + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + ) + client = test_client_factory(app) + + # Test credentialed request + headers = {"Origin": "https://example.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers + + def test_cors_vary_header_defaults_to_origin( test_client_factory: TestClientFactory, ) -> None: