diff --git a/Makefile b/Makefile index fc271666b..53415ae0a 100644 --- a/Makefile +++ b/Makefile @@ -37,5 +37,8 @@ with-django2x: with-flask: pip3 install -e .[flask] +with-litestar: + pip3 install -e .[litestar] + build-docs: rm -rf html && pdoc --html supertokens_python --template-dir docs-templates \ No newline at end of file diff --git a/dev-requirements.txt b/dev-requirements.txt index 4291ad283..f7f88b85c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -85,3 +85,4 @@ uvicorn==0.18.2 Werkzeug==2.0.3 wrapt==1.13.3 zipp==3.7.0 +litestar==2.8.1 \ No newline at end of file diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index ea7ba506a..0656347dd 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -17,6 +17,7 @@ from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest +from supertokens_python.types import SupportedFrameworks from . import supertokens from .recipe_module import RecipeModule @@ -29,7 +30,7 @@ def init( app_info: InputAppInfo, - framework: Literal["fastapi", "flask", "django"], + framework: SupportedFrameworks, supertokens_config: SupertokensConfig, recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]], mode: Optional[Literal["asgi", "wsgi"]] = None, diff --git a/supertokens_python/framework/litestar/__init__.py b/supertokens_python/framework/litestar/__init__.py new file mode 100644 index 000000000..aad679a54 --- /dev/null +++ b/supertokens_python/framework/litestar/__init__.py @@ -0,0 +1,46 @@ +from __future__ import annotations +from typing import Any, Callable, Coroutine, TYPE_CHECKING + +from litestar import Request + +from supertokens_python.framework.litestar.litestar_request import LitestarRequest +from .litestar_middleware import get_middleware + +if TYPE_CHECKING: + from ...recipe.session import SessionRecipe, SessionContainer + from ...recipe.session.interfaces import SessionClaimValidator + from ...types import MaybeAwaitable + +__all__ = ['get_middleware'] + + +def verify_session( + anti_csrf_check: bool | None = None, + session_required: bool = True, + override_global_claim_validators: Callable[ + [list[SessionClaimValidator], SessionContainer, dict[str, Any]], + MaybeAwaitable[list[SessionClaimValidator]], + ] + | None = None, + user_context: None | dict[str, Any] = None, +) -> Callable[..., Coroutine[Any, Any, SessionContainer | None]]: + async def func(request: Request[Any, Any, Any]) -> SessionContainer | None: + custom_request = LitestarRequest(request) + recipe = SessionRecipe.get_instance() + session = await recipe.verify_session( + custom_request, + anti_csrf_check, + session_required, + user_context=user_context or {} + ) + + if session: + custom_request.set_session(session) + elif session_required: + raise RuntimeError("Should never come here") + else: + custom_request.set_session_as_none() + + return custom_request.get_session() + + return func \ No newline at end of file diff --git a/supertokens_python/framework/litestar/framework.py b/supertokens_python/framework/litestar/framework.py new file mode 100644 index 000000000..4ac0b7a3f --- /dev/null +++ b/supertokens_python/framework/litestar/framework.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from supertokens_python.framework.types import Framework + +if TYPE_CHECKING: + from litestar import Request + + +class LitestarFramework(Framework): + def wrap_request(self, unwrapped: Request[Any, Any, Any]): + from supertokens_python.framework.litestar.litestar_request import ( + LitestarRequest, + ) + + return LitestarRequest(unwrapped) \ No newline at end of file diff --git a/supertokens_python/framework/litestar/litestar_middleware.py b/supertokens_python/framework/litestar/litestar_middleware.py new file mode 100644 index 000000000..f5eb9eae8 --- /dev/null +++ b/supertokens_python/framework/litestar/litestar_middleware.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +from litestar.response.base import ASGIResponse + +if TYPE_CHECKING: + from litestar.middleware.base import AbstractMiddleware + + +@lru_cache +def get_middleware() -> type[AbstractMiddleware]: + from supertokens_python import Supertokens + from supertokens_python.exceptions import SuperTokensError + from supertokens_python.framework.litestar.litestar_request import LitestarRequest + from supertokens_python.framework.litestar.litestar_response import LitestarResponse + from supertokens_python.recipe.session import SessionContainer + from supertokens_python.supertokens import manage_session_post_response + from supertokens_python.utils import default_user_context + + from litestar import Response, Request + from litestar.middleware.base import AbstractMiddleware + from litestar.types import Scope, Receive, Send + + class SupertokensMiddleware(AbstractMiddleware): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + st = Supertokens.get_instance() + request = Request[Any, Any, Any](scope, receive, send) + user_context = default_user_context(request) + + try: + result = await st.middleware( + LitestarRequest(request), + LitestarResponse(Response[Any](content=None)), + user_context + ) + except SuperTokensError as e: + result = await st.handle_supertokens_error( + LitestarRequest(request), + e, + LitestarResponse(Response[Any](content=None)), + user_context + ) + + if isinstance(result, LitestarResponse): + if ( + session_container := request.state.get("supertokens") + ) and isinstance(session_container, SessionContainer): + manage_session_post_response(session_container, result, user_context) + + await result.response.to_asgi_response(app=None, request=request)(scope, receive, send) + return + + await self.app(scope, receive, send) + + return SupertokensMiddleware diff --git a/supertokens_python/framework/litestar/litestar_request.py b/supertokens_python/framework/litestar/litestar_request.py new file mode 100644 index 000000000..6b94881c9 --- /dev/null +++ b/supertokens_python/framework/litestar/litestar_request.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from supertokens_python.framework.request import BaseRequest + +if TYPE_CHECKING: + from litestar import Request + from supertokens_python.recipe.session.interfaces import SessionContainer + +try: + from litestar.exceptions import SerializationException +except ImportError: + SerializationException = Exception # type: ignore + + +class LitestarRequest(BaseRequest): + def __init__(self, request: Request[Any, Any, Any]): + super().__init__() + self.request = request + + def get_original_url(self) -> str: + return self.request.url + + def get_query_param(self, key: str, default: str | None = None) -> Any: + return self.request.query_params.get(key, default) # pyright: ignore + + def get_query_params(self) -> dict[str, list[Any]]: + return self.request.query_params.dict() # pyright: ignore + + async def json(self) -> Any: + try: + return await self.request.json() + except SerializationException: + return {} + + def method(self) -> str: + return self.request.method + + def get_cookie(self, key: str) -> str | None: + return self.request.cookies.get(key) + + def get_header(self, key: str) -> str | None: + return self.request.headers.get(key, None) + + def get_session(self) -> SessionContainer | None: + return self.request.state.supertokens + + def set_session(self, session: SessionContainer): + self.request.state.supertokens = session + + def set_session_as_none(self): + self.request.state.supertokens = None + + def get_path(self) -> str: + return self.request.url.path + + async def form_data(self) -> dict[str, list[Any]]: + return (await self.request.form()).dict() \ No newline at end of file diff --git a/supertokens_python/framework/litestar/litestar_response.py b/supertokens_python/framework/litestar/litestar_response.py new file mode 100644 index 000000000..def235824 --- /dev/null +++ b/supertokens_python/framework/litestar/litestar_response.py @@ -0,0 +1,72 @@ +from __future__ import annotations +from typing import Any, TYPE_CHECKING, cast +from typing_extensions import Literal +from supertokens_python.framework.response import BaseResponse + +if TYPE_CHECKING: + from litestar import Response + + +class LitestarResponse(BaseResponse): + def __init__(self, response: Response[Any]): + super().__init__({}) + self.response = response + self.original = response + self.parser_checked = False + self.response_sent = False + self.status_set = False + + def set_html_content(self, content: str): + if not self.response_sent: + body = bytes(content, "utf-8") + self.set_header("Content-Length", str(len(body))) + self.set_header("Content-Type", "text/html") + self.response.content = body + self.response_sent = True + + def set_cookie( + self, + key: str, + value: str, + expires: int, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: str = "lax", + ): + self.response.set_cookie( + key=key, + value=value, + expires=expires, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + samesite=cast(Literal["lax", "strict", "none"], samesite), + ) + + def set_header(self, key: str, value: str): + self.response.set_header(key, value) + + def get_header(self, key: str) -> str | None: + return self.response.headers.get(key, None) + + def remove_header(self, key: str): + del self.response.headers[key] + + def set_status_code(self, status_code: int): + if not self.status_set: + self.response.status_code = status_code + self.status_code = status_code + self.status_set = True + + def set_json_content(self, content: dict[str, Any]): + if not self.response_sent: + from litestar.serialization import encode_json + + body = encode_json(content) + self.set_header("Content-Type", "application/json; charset=utf-8") + self.set_header("Content-Length", str(len(body))) + self.response.content = body + self.response_sent = True \ No newline at end of file diff --git a/supertokens_python/logger.py b/supertokens_python/logger.py index 365ff60ea..2a5c4c3e6 100644 --- a/supertokens_python/logger.py +++ b/supertokens_python/logger.py @@ -35,7 +35,6 @@ def enable_debug_logging(): if debug_env == "1": enable_debug_logging() - def _get_log_timestamp() -> str: return datetime.utcnow().isoformat()[:-3] + "Z" diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 1c7b2f799..654477f90 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -12,10 +12,12 @@ # License for the specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod -from typing import Any, Awaitable, Dict, List, TypeVar, Union +from typing import Any, Awaitable, Dict, List, TypeVar, Union, Literal _T = TypeVar("_T") +SupportedFrameworks = Literal["fastapi", "flask", "django", "litestar"] + class ThirdPartyInfo: def __init__(self, third_party_user_id: str, third_party_id: str): diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index 3a14a5f7d..2c8389744 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -30,7 +30,7 @@ List, TypeVar, Union, - Optional, + Optional ) from urllib.parse import urlparse @@ -40,13 +40,15 @@ from supertokens_python.framework.django.framework import DjangoFramework from supertokens_python.framework.fastapi.framework import FastapiFramework from supertokens_python.framework.flask.framework import FlaskFramework +from supertokens_python.framework.litestar.framework import LitestarFramework from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER from .exceptions import raise_general_exception -from .types import MaybeAwaitable +from .framework.types import Framework +from .types import MaybeAwaitable, SupportedFrameworks _T = TypeVar("_T") @@ -54,10 +56,11 @@ pass -FRAMEWORKS = { +FRAMEWORKS: dict[SupportedFrameworks, Framework] = { "fastapi": FastapiFramework(), "flask": FlaskFramework(), "django": DjangoFramework(), + "litestar": LitestarFramework() } diff --git a/tests/litestar/__init__.py b/tests/litestar/__init__.py new file mode 100644 index 000000000..1c8b13ade --- /dev/null +++ b/tests/litestar/__init__.py @@ -0,0 +1,3 @@ +import nest_asyncio # type: ignore + +nest_asyncio.apply() # type: ignore \ No newline at end of file diff --git a/tests/litestar/test_litestar.py b/tests/litestar/test_litestar.py new file mode 100644 index 000000000..a144e2a5e --- /dev/null +++ b/tests/litestar/test_litestar.py @@ -0,0 +1,1004 @@ +# pyright: reportUnknownMemberType=false, reportGeneralTypeIssues=false +from __future__ import annotations +import json +from typing import Any, Dict, Union + +from litestar import get, post, Litestar, Request, MediaType +from litestar.di import Provide +from litestar.testing import TestClient +from pytest import fixture, mark, skip + +from supertokens_python import InputAppInfo, SupertokensConfig, init +from supertokens_python.framework import BaseRequest +from supertokens_python.framework.litestar import get_middleware +from supertokens_python.querier import Querier +from supertokens_python.recipe import emailpassword, session +from supertokens_python.recipe import thirdparty +from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard.interfaces import RecipeInterface +from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.emailpassword.interfaces import ( + APIInterface as EPAPIInterface, +) +from supertokens_python.recipe.emailpassword.interfaces import APIOptions +from supertokens_python.recipe.passwordless import PasswordlessRecipe, ContactConfig +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.asyncio import ( + create_new_session, + get_session, + refresh_session, +) +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.session.framework.litestar import verify_session +from supertokens_python.recipe.session.interfaces import APIInterface +from supertokens_python.recipe.session.interfaces import APIOptions as SessionAPIOptions +from supertokens_python.utils import is_version_gte +from tests.utils import ( + TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH, + TEST_DRIVER_CONFIG_COOKIE_DOMAIN, + TEST_DRIVER_CONFIG_COOKIE_SAME_SITE, + TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH, + assert_info_clears_tokens, + clean_st, + extract_all_cookies, + extract_info, + get_st_init_args, + reset, + setup_st, + start_st, + create_users, +) + + +def get_token_transfer_method(*args: Any) -> Any: + return "cookie" + + +def override_dashboard_functions(original_implementation: RecipeInterface): + def should_allow_access( + request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + ): + auth_header = request.get_header("authorization") + return auth_header == "Bearer testapikey" + + original_implementation.should_allow_access = should_allow_access + return original_implementation + + +def setup_function(_): + reset() + clean_st() + setup_st() + + +def teardown_function(_): + reset() + clean_st() + + +@fixture(scope="function") +def litestar_test_client() -> TestClient[Litestar]: + @get("/login") + async def login(request: Request[Any, Any, Any]) -> dict[str, Any]: + user_id = "userId" + await create_new_session(request, user_id, {}, {}) + return {"userId": user_id} + + @post("/refresh") + async def custom_refresh(request: Request[Any, Any, Any]) -> dict[str, Any]: + await refresh_session(request) + return {} + + @get("/info") + async def info_get(request: Request[Any, Any, Any]) -> dict[str, Any]: + await get_session(request, True) + return {} + + @get("/custom/info") + def custom_info() -> dict[str, Any]: + return {} + + @get("/handle") + async def handle_get(request: Request[Any, Any, Any]) -> dict[str, Any]: + session: Union[None, SessionContainer] = await get_session(request, True) + if session is None: + raise RuntimeError("Should never come here") + return {"s": session.get_handle()} + + @get( + "/handle-session-optional", + dependencies={"session": Provide(verify_session(session_required=False))}, + ) + def handle_get_optional(session: SessionContainer) -> dict[str, Any]: + + if session is None: + return {"s": "empty session"} + + return {"s": session.get_handle()} + + @post("/logout") + async def custom_logout(request: Request[Any, Any, Any]) -> dict[str, Any]: + session: Union[None, SessionContainer] = await get_session(request, True) + if session is None: + raise RuntimeError("Should never come here") + await session.revoke_session() + return {} + + @post("/create", media_type=MediaType.TEXT) + async def _create(request: Request[Any, Any, Any]) -> str: + await create_new_session(request, "userId", {}, {}) + return "" + + @post("/create-throw") + async def _create_throw(request: Request[Any, Any, Any]) -> None: + await create_new_session(request, "userId", {}, {}) + raise UnauthorisedError("unauthorised") + + app = Litestar( + route_handlers=[ + login, + custom_logout, + custom_refresh, + custom_info, + info_get, + handle_get, + handle_get_optional, + _create, + _create_throw, + ], + middleware=[get_middleware()], + ) + + return TestClient(app) + + +def apis_override_session(param: APIInterface): + param.disable_refresh_post = True + return param + + +def test_login_refresh(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + override=session.InputOverrideConfig(apis=apis_override_session), + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + + with litestar_test_client as client: + response_3 = client.post( + url="/refresh", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sRefreshToken": cookies_1["sRefreshToken"]["value"], + }, + ) + cookies_3 = extract_all_cookies(response_3) + + assert cookies_3["sAccessToken"]["value"] != cookies_1["sAccessToken"]["value"] + assert cookies_3["sRefreshToken"]["value"] != cookies_1["sRefreshToken"]["value"] + assert response_3.headers.get("anti-csrf") is not None + assert cookies_3["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_3["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_3["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_3["sAccessToken"]["httponly"] + assert cookies_3["sRefreshToken"]["httponly"] + assert ( + cookies_3["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_3["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + + +def test_login_logout(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_2 = client.post( + url="/logout", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sAccessToken": cookies_1["sAccessToken"]["value"], + }, + ) + cookies_2 = extract_all_cookies(response_2) + assert response_2.headers.get("anti-csrf") is None + assert cookies_2["sAccessToken"]["value"] == "" + assert cookies_2["sRefreshToken"]["value"] == "" + assert cookies_2["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_2["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_2["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_2["sAccessToken"]["httponly"] + assert cookies_2["sRefreshToken"]["httponly"] + assert ( + cookies_2["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_2["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_2["sAccessToken"]["secure"] is None + assert cookies_2["sRefreshToken"]["secure"] is None + + +def test_login_info(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_2 = client.get( + url="/info", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sAccessToken": cookies_1["sAccessToken"]["value"], + }, + ) + cookies_2 = extract_all_cookies(response_2) + assert not cookies_2 + + +def test_login_handle(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_2 = client.get( + url="/handle", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sAccessToken": cookies_1["sAccessToken"]["value"], + }, + ) + result_dict = json.loads(response_2.content) + assert "s" in result_dict + + +def test_login_refresh_error_handler(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_3 = client.post( + url="/refresh", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + # no cookies + }, + ) + assert response_3.status_code == 401 # not authorized because no refresh tokens + + +def test_custom_response(litestar_test_client: TestClient[Litestar]): + def override_email_password_apis(original_implementation: EPAPIInterface): + original_func = original_implementation.email_exists_get + + async def email_exists_get( + email: str, api_options: APIOptions, user_context: Dict[str, Any] + ): + response_dict = {"custom": True} + api_options.response.set_status_code(203) + api_options.response.set_json_content(response_dict) + return await original_func(email, api_options, user_context) + + original_implementation.email_exists_get = email_exists_get + return original_implementation + + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + emailpassword.init( + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis + ) + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response = client.get( + url="/auth/signup/email/exists?email=test@example.com", + ) + + dict_response = json.loads(response.text) + assert response.status_code == 203 + assert dict_response["custom"] + + +def test_optional_session(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[session.init(get_token_transfer_method=get_token_transfer_method)], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response = client.get( + url="handle-session-optional", + ) + + dict_response = json.loads(response.text) + assert response.status_code == 200 + assert dict_response["s"] == "empty session" + + +@mark.parametrize("token_transfer_method", ["cookie", "header"]) +def test_should_clear_all_response_during_refresh_if_unauthorized( + litestar_test_client: TestClient[Litestar], token_transfer_method: str +): + def override_session_apis(oi: APIInterface): + oi_refresh_post = oi.refresh_post + + async def refresh_post( + api_options: SessionAPIOptions, user_context: Dict[str, Any] + ): + await oi_refresh_post(api_options, user_context) + raise UnauthorisedError("unauthorized", clear_tokens=True) + + oi.refresh_post = refresh_post + return oi + + init( + **get_st_init_args( + [ + session.init( + anti_csrf="VIA_TOKEN", + override=session.InputOverrideConfig(apis=override_session_apis), + ) + ] + ) + ) + start_st() + + with litestar_test_client as client: + res = client.post("/create", headers={"st-auth-mode": token_transfer_method}) + info = extract_info(res) # pyright: ignore + + assert info["accessTokenFromAny"] is not None + assert info["refreshTokenFromAny"] is not None + + headers: Dict[str, Any] = {} + cookies: Dict[str, Any] = {} + + if token_transfer_method == "header": + headers.update({"authorization": f"Bearer {info['refreshTokenFromAny']}"}) + else: + cookies.update( + {"sRefreshToken": info["refreshTokenFromAny"], "sIdRefreshToken": "asdf"} + ) + + if info["antiCsrf"] is not None: + headers.update({"anti-csrf": info["antiCsrf"]}) + + with litestar_test_client as client: + res = client.post("/auth/session/refresh", headers=headers, cookies=cookies) + info = extract_info(res) # pyright: ignore + + assert res.status_code == 401 + assert_info_clears_tokens(info, token_transfer_method) + + +@mark.parametrize("token_transfer_method", ["cookie", "header"]) +def test_revoking_session_after_create_new_session_with_throwing_unauthorized_error( + litestar_test_client: TestClient[Litestar], token_transfer_method: str +): + init( + **get_st_init_args( + [ + session.init( + anti_csrf="VIA_TOKEN", + ) + ] + ) + ) + start_st() + + with litestar_test_client as client: + res = client.post( + "/create-throw", headers={"st-auth-mode": token_transfer_method} + ) + info = extract_info(res) # pyright: ignore + + assert res.status_code == 401 + assert_info_clears_tokens(info, token_transfer_method) + + +@mark.asyncio +async def test_search_with_email_t(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + emailpassword.init(), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(emailpassword=True) + query = {"limit": "10", "email": "t"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 5 + + +@mark.asyncio +async def test_search_with_email_multiple_email_entry( + litestar_test_client: TestClient[Litestar], +): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + emailpassword.init(), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(emailpassword=True) + query = {"limit": "10", "email": "iresh;john"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 1 + + +@mark.asyncio +async def test_search_with_email_iresh(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + emailpassword.init(), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(emailpassword=True) + query = {"limit": "10", "email": "iresh"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 0 + + +@mark.asyncio +async def test_search_with_phone_plus_one(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + PasswordlessRecipe.init( + contact_config=ContactConfig(contact_method="EMAIL"), + flow_type="USER_INPUT_CODE", + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(passwordless=True) + query = {"limit": "10", "phone": "+1"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 3 + + +@mark.asyncio +async def test_search_with_phone_one_bracket( + litestar_test_client: TestClient[Litestar], +): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + PasswordlessRecipe.init( + contact_config=ContactConfig(contact_method="EMAIL"), + flow_type="USER_INPUT_CODE", + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(passwordless=True) + query = {"limit": "10", "phone": "1("} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 0 + + +@mark.asyncio +async def test_search_with_provider_google(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature( + providers=[ + thirdparty.Apple( + client_id="4398792-io.supertokens.example.service", + client_key_id="7M48Y4RYDL", + client_team_id="YWQCXGJRJL", + client_private_key="-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + ), + thirdparty.Google( + client_id="467101b197249757c71f", + client_secret="e97051221f4b6426e8fe8d51486396703012f5bd", + ), + thirdparty.Github( + client_id="1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com", + client_secret="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", + ), + ] + ) + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(thirdparty=True) + query = {"limit": "10", "provider": "google"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 3 + + +@mark.asyncio +async def test_search_with_provider_google_and_phone_1( + litestar_test_client: TestClient[Litestar], +): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + PasswordlessRecipe.init( + contact_config=ContactConfig(contact_method="EMAIL"), + flow_type="USER_INPUT_CODE", + ), + thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature( + providers=[ + thirdparty.Apple( + client_id="4398792-io.supertokens.example.service", + client_key_id="7M48Y4RYDL", + client_team_id="YWQCXGJRJL", + client_private_key="-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + ), + thirdparty.Google( + client_id="467101b197249757c71f", + client_secret="e97051221f4b6426e8fe8d51486396703012f5bd", + ), + thirdparty.Github( + client_id="1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com", + client_secret="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", + ), + ] + ) + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(thirdparty=True, passwordless=True) + query = {"limit": "10", "provider": "google", "phone": "1"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 0