Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Litestar framework #483

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,8 @@ with-django2x:
with-flask:
pip3 install -e .[flask]

with-litestar:
pip3 install -e .[litestar]

build-docs:
rm -rf html && pdoc --html supertokens_python --template-dir docs-templates
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,4 @@ uvicorn==0.18.2
Werkzeug==2.0.3
wrapt==1.13.3
zipp==3.7.0
litestar==2.8.1
3 changes: 2 additions & 1 deletion supertokens_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing_extensions import Literal

from supertokens_python.framework.request import BaseRequest
from supertokens_python.types import SupportedFrameworks

from . import supertokens
from .recipe_module import RecipeModule
Expand All @@ -29,7 +30,7 @@

def init(
app_info: InputAppInfo,
framework: Literal["fastapi", "flask", "django"],
framework: SupportedFrameworks,
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]],
mode: Optional[Literal["asgi", "wsgi"]] = None,
Expand Down
46 changes: 46 additions & 0 deletions supertokens_python/framework/litestar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Any, Callable, Coroutine, TYPE_CHECKING

from litestar import Request

from supertokens_python.framework.litestar.litestar_request import LitestarRequest
from .litestar_middleware import get_middleware

if TYPE_CHECKING:
from ...recipe.session import SessionRecipe, SessionContainer
from ...recipe.session.interfaces import SessionClaimValidator
from ...types import MaybeAwaitable

__all__ = ['get_middleware']


def verify_session(
anti_csrf_check: bool | None = None,
session_required: bool = True,
override_global_claim_validators: Callable[
[list[SessionClaimValidator], SessionContainer, dict[str, Any]],
MaybeAwaitable[list[SessionClaimValidator]],
]
| None = None,
user_context: None | dict[str, Any] = None,
) -> Callable[..., Coroutine[Any, Any, SessionContainer | None]]:
async def func(request: Request[Any, Any, Any]) -> SessionContainer | None:
custom_request = LitestarRequest(request)
recipe = SessionRecipe.get_instance()
session = await recipe.verify_session(
custom_request,
anti_csrf_check,
session_required,
user_context=user_context or {}
)

if session:
custom_request.set_session(session)
elif session_required:
raise RuntimeError("Should never come here")
else:
custom_request.set_session_as_none()

return custom_request.get_session()

return func
17 changes: 17 additions & 0 deletions supertokens_python/framework/litestar/framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from supertokens_python.framework.types import Framework

if TYPE_CHECKING:
from litestar import Request


class LitestarFramework(Framework):
def wrap_request(self, unwrapped: Request[Any, Any, Any]):
from supertokens_python.framework.litestar.litestar_request import (
LitestarRequest,
)

return LitestarRequest(unwrapped)
57 changes: 57 additions & 0 deletions supertokens_python/framework/litestar/litestar_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any

from litestar.response.base import ASGIResponse

if TYPE_CHECKING:
from litestar.middleware.base import AbstractMiddleware


@lru_cache
def get_middleware() -> type[AbstractMiddleware]:
from supertokens_python import Supertokens
from supertokens_python.exceptions import SuperTokensError
from supertokens_python.framework.litestar.litestar_request import LitestarRequest
from supertokens_python.framework.litestar.litestar_response import LitestarResponse
from supertokens_python.recipe.session import SessionContainer
from supertokens_python.supertokens import manage_session_post_response
from supertokens_python.utils import default_user_context

from litestar import Response, Request
from litestar.middleware.base import AbstractMiddleware
from litestar.types import Scope, Receive, Send

class SupertokensMiddleware(AbstractMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
st = Supertokens.get_instance()
request = Request[Any, Any, Any](scope, receive, send)
user_context = default_user_context(request)

try:
result = await st.middleware(
LitestarRequest(request),
LitestarResponse(Response[Any](content=None)),
user_context
)
except SuperTokensError as e:
result = await st.handle_supertokens_error(
LitestarRequest(request),
e,
LitestarResponse(Response[Any](content=None)),
user_context
)

if isinstance(result, LitestarResponse):
if (
session_container := request.state.get("supertokens")
) and isinstance(session_container, SessionContainer):
manage_session_post_response(session_container, result, user_context)

await result.response.to_asgi_response(app=None, request=request)(scope, receive, send)
return

await self.app(scope, receive, send)

return SupertokensMiddleware
59 changes: 59 additions & 0 deletions supertokens_python/framework/litestar/litestar_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from supertokens_python.framework.request import BaseRequest

if TYPE_CHECKING:
from litestar import Request
from supertokens_python.recipe.session.interfaces import SessionContainer

try:
from litestar.exceptions import SerializationException
except ImportError:
SerializationException = Exception # type: ignore


class LitestarRequest(BaseRequest):
def __init__(self, request: Request[Any, Any, Any]):
super().__init__()
self.request = request

def get_original_url(self) -> str:
return self.request.url

def get_query_param(self, key: str, default: str | None = None) -> Any:
return self.request.query_params.get(key, default) # pyright: ignore

def get_query_params(self) -> dict[str, list[Any]]:
return self.request.query_params.dict() # pyright: ignore

async def json(self) -> Any:
try:
return await self.request.json()
except SerializationException:
return {}

def method(self) -> str:
return self.request.method

def get_cookie(self, key: str) -> str | None:
return self.request.cookies.get(key)

def get_header(self, key: str) -> str | None:
return self.request.headers.get(key, None)

def get_session(self) -> SessionContainer | None:
return self.request.state.supertokens

def set_session(self, session: SessionContainer):
self.request.state.supertokens = session

def set_session_as_none(self):
self.request.state.supertokens = None

def get_path(self) -> str:
return self.request.url.path

async def form_data(self) -> dict[str, list[Any]]:
return (await self.request.form()).dict()
72 changes: 72 additions & 0 deletions supertokens_python/framework/litestar/litestar_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING, cast
from typing_extensions import Literal
from supertokens_python.framework.response import BaseResponse

if TYPE_CHECKING:
from litestar import Response


class LitestarResponse(BaseResponse):
def __init__(self, response: Response[Any]):
super().__init__({})
self.response = response
self.original = response
self.parser_checked = False
self.response_sent = False
self.status_set = False

def set_html_content(self, content: str):
if not self.response_sent:
body = bytes(content, "utf-8")
self.set_header("Content-Length", str(len(body)))
self.set_header("Content-Type", "text/html")
self.response.content = body
self.response_sent = True

def set_cookie(
self,
key: str,
value: str,
expires: int,
path: str = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: str = "lax",
):
self.response.set_cookie(
key=key,
value=value,
expires=expires,
path=path,
domain=domain,
secure=secure,
httponly=httponly,
samesite=cast(Literal["lax", "strict", "none"], samesite),
)

def set_header(self, key: str, value: str):
self.response.set_header(key, value)

def get_header(self, key: str) -> str | None:
return self.response.headers.get(key, None)

def remove_header(self, key: str):
del self.response.headers[key]

def set_status_code(self, status_code: int):
if not self.status_set:
self.response.status_code = status_code
self.status_code = status_code
self.status_set = True

def set_json_content(self, content: dict[str, Any]):
if not self.response_sent:
from litestar.serialization import encode_json

body = encode_json(content)
self.set_header("Content-Type", "application/json; charset=utf-8")
self.set_header("Content-Length", str(len(body)))
self.response.content = body
self.response_sent = True
1 change: 0 additions & 1 deletion supertokens_python/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def enable_debug_logging():
if debug_env == "1":
enable_debug_logging()


def _get_log_timestamp() -> str:
return datetime.utcnow().isoformat()[:-3] + "Z"

Expand Down
4 changes: 3 additions & 1 deletion supertokens_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# License for the specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Dict, List, TypeVar, Union
from typing import Any, Awaitable, Dict, List, TypeVar, Union, Literal

_T = TypeVar("_T")

SupportedFrameworks = Literal["fastapi", "flask", "django", "litestar"]


class ThirdPartyInfo:
def __init__(self, third_party_user_id: str, third_party_id: str):
Expand Down
9 changes: 6 additions & 3 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
List,
TypeVar,
Union,
Optional,
Optional
)
from urllib.parse import urlparse

Expand All @@ -40,24 +40,27 @@
from supertokens_python.framework.django.framework import DjangoFramework
from supertokens_python.framework.fastapi.framework import FastapiFramework
from supertokens_python.framework.flask.framework import FlaskFramework
from supertokens_python.framework.litestar.framework import LitestarFramework
from supertokens_python.framework.request import BaseRequest
from supertokens_python.framework.response import BaseResponse
from supertokens_python.logger import log_debug_message

from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER
from .exceptions import raise_general_exception
from .types import MaybeAwaitable
from .framework.types import Framework
from .types import MaybeAwaitable, SupportedFrameworks

_T = TypeVar("_T")

if TYPE_CHECKING:
pass


FRAMEWORKS = {
FRAMEWORKS: dict[SupportedFrameworks, Framework] = {
"fastapi": FastapiFramework(),
"flask": FlaskFramework(),
"django": DjangoFramework(),
"litestar": LitestarFramework()
}


Expand Down
3 changes: 3 additions & 0 deletions tests/litestar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import nest_asyncio # type: ignore

nest_asyncio.apply() # type: ignore
Loading