diff --git a/demo/__init__.py b/demo/__init__.py index 616139d2..b93b1828 100644 --- a/demo/__init__.py +++ b/demo/__init__.py @@ -6,12 +6,12 @@ from fastapi import FastAPI from fastapi.responses import HTMLResponse, PlainTextResponse from fastui import prebuilt_html +from fastui.auth import AuthError from fastui.dev import dev_fastapi_app from httpx import AsyncClient from .auth import router as auth_router from .components_list import router as components_router -from .db import create_db from .forms import router as forms_router from .main import router as main_router from .sse import router as sse_router @@ -20,7 +20,6 @@ @asynccontextmanager async def lifespan(app_: FastAPI): - await create_db() async with AsyncClient() as client: app_.state.httpx_client = client yield @@ -33,6 +32,7 @@ async def lifespan(app_: FastAPI): else: app = FastAPI(lifespan=lifespan) +app.exception_handler(AuthError)(AuthError.fastapi_handle) app.include_router(components_router, prefix='/api/components') app.include_router(sse_router, prefix='/api/components') app.include_router(table_router, prefix='/api/table') diff --git a/demo/auth.py b/demo/auth.py index b9ba8da5..cd41a656 100644 --- a/demo/auth.py +++ b/demo/auth.py @@ -1,49 +1,111 @@ from __future__ import annotations as _annotations -from typing import Annotated +import asyncio +import json +import os +from dataclasses import asdict +from typing import Annotated, Literal, TypeAlias -from fastapi import APIRouter, Depends, Header +from fastapi import APIRouter, Depends, Request from fastui import AnyComponent, FastUI from fastui import components as c +from fastui.auth import GitHubAuthProvider from fastui.events import AuthEvent, GoToEvent, PageEvent from fastui.forms import fastui_form +from httpx import AsyncClient from pydantic import BaseModel, EmailStr, Field, SecretStr -from . import db +from .auth_user import User from .shared import demo_page router = APIRouter() -async def get_user(authorization: Annotated[str, Header()] = '') -> db.User | None: - try: - token = authorization.split(' ', 1)[1] - except IndexError: - return None - else: - return await db.get_user(token) +# this will give an error when making requests to GitHub, but at least the app will run +GITHUB_CLIENT_SECRET = SecretStr(os.getenv('GITHUB_CLIENT_SECRET', 'dummy-secret')) + + +async def get_github_auth(request: Request) -> GitHubAuthProvider: + client: AsyncClient = request.app.state.httpx_client + return GitHubAuthProvider( + httpx_client=client, + github_client_id='9eddf87b27f71f52194a', + github_client_secret=GITHUB_CLIENT_SECRET, + scopes=['user:email'], + ) -@router.get('/login', response_model=FastUI, response_model_exclude_none=True) -def auth_login(user: Annotated[str | None, Depends(get_user)]) -> list[AnyComponent]: +LoginKind: TypeAlias = Literal['password', 'github'] + + +@router.get('/login/{kind}', response_model=FastUI, response_model_exclude_none=True) +async def auth_login( + kind: LoginKind, + user: Annotated[User | None, Depends(User.from_request)], + github_auth: Annotated[GitHubAuthProvider, Depends(get_github_auth)], +) -> list[AnyComponent]: if user is None: return demo_page( - c.Paragraph( - text=( - 'This is a very simple demo of authentication, ' - 'here you can "login" with any email address and password.' - ) + c.LinkList( + links=[ + c.Link( + components=[c.Text(text='Password Login')], + on_click=PageEvent(name='tab', push_path='/auth/login/password', context={'kind': 'password'}), + active='/auth/login/password', + ), + c.Link( + components=[c.Text(text='GitHub Login')], + on_click=PageEvent(name='tab', push_path='/auth/login/github', context={'kind': 'github'}), + active='/auth/login/github', + ), + ], + mode='tabs', + class_name='+ mb-4', + ), + c.ServerLoad( + path='/auth/login/content/{kind}', + load_trigger=PageEvent(name='tab'), + components=await auth_login_content(kind, github_auth), ), - c.Heading(text='Login'), - c.ModelForm(model=LoginForm, submit_url='/api/auth/login'), title='Authentication', ) else: return [c.FireEvent(event=GoToEvent(url='/auth/profile'))] +@router.get('/login/content/{kind}', response_model=FastUI, response_model_exclude_none=True) +async def auth_login_content( + kind: LoginKind, github_auth: Annotated[GitHubAuthProvider, Depends(get_github_auth)] +) -> list[AnyComponent]: + match kind: + case 'password': + return [ + c.Heading(text='Password Login', level=3), + c.Paragraph( + text=( + 'This is a very simple demo of password authentication, ' + 'here you can "login" with any email address and password.' + ) + ), + c.Paragraph(text='(Passwords are not saved and email stored in the browser via a JWT)'), + c.ModelForm(model=LoginForm, submit_url='/api/auth/login'), + ] + case 'github': + auth_url = await github_auth.authorization_url() + return [ + c.Heading(text='GitHub Login', level=3), + c.Paragraph(text='Demo of GitHub authentication.'), + c.Paragraph(text='(Credentials are stored in the browser via a JWT)'), + c.Button(text='Login with GitHub', on_click=GoToEvent(url=auth_url)), + ] + case _: + raise ValueError(f'Invalid kind {kind!r}') + + class LoginForm(BaseModel): - email: EmailStr = Field(title='Email Address', description='Enter whatever value you like') + email: EmailStr = Field( + title='Email Address', description='Enter whatever value you like', json_schema_extra={'autocomplete': 'email'} + ) password: SecretStr = Field( title='Password', description='Enter whatever value you like, password is not checked', @@ -53,19 +115,21 @@ class LoginForm(BaseModel): @router.post('/login', response_model=FastUI, response_model_exclude_none=True) async def login_form_post(form: Annotated[LoginForm, fastui_form(LoginForm)]) -> list[AnyComponent]: - token = await db.create_user(form.email) + user = User(email=form.email, extra={}) + token = user.encode_token() return [c.FireEvent(event=AuthEvent(token=token, url='/auth/profile'))] @router.get('/profile', response_model=FastUI, response_model_exclude_none=True) -async def profile(user: Annotated[db.User | None, Depends(get_user)]) -> list[AnyComponent]: +async def profile(user: Annotated[User | None, Depends(User.from_request)]) -> list[AnyComponent]: if user is None: return [c.FireEvent(event=GoToEvent(url='/auth/login'))] else: - active_count = await db.count_users() return demo_page( - c.Paragraph(text=f'You are logged in as "{user.email}", {active_count} active users right now.'), + c.Paragraph(text=f'You are logged in as "{user.email}".'), c.Button(text='Logout', on_click=PageEvent(name='submit-form')), + c.Heading(text='User Data:', level=3), + c.Code(language='json', text=json.dumps(asdict(user), indent=2)), c.Form( submit_url='/api/auth/logout', form_fields=[c.FormFieldInput(name='test', title='', initial='data', html_type='hidden')], @@ -77,7 +141,26 @@ async def profile(user: Annotated[db.User | None, Depends(get_user)]) -> list[An @router.post('/logout', response_model=FastUI, response_model_exclude_none=True) -async def logout_form_post(user: Annotated[db.User | None, Depends(get_user)]) -> list[AnyComponent]: - if user is not None: - await db.delete_user(user) - return [c.FireEvent(event=AuthEvent(token=False, url='/auth/login'))] +async def logout_form_post() -> list[AnyComponent]: + return [c.FireEvent(event=AuthEvent(token=False, url='/auth/login/password'))] + + +@router.get('/login/github/redirect', response_model=FastUI, response_model_exclude_none=True) +async def github_redirect( + code: str, + state: str | None, + github_auth: Annotated[GitHubAuthProvider, Depends(get_github_auth)], +) -> list[AnyComponent]: + exchange = await github_auth.exchange_code(code, state) + user_info, emails = await asyncio.gather( + github_auth.get_github_user(exchange), github_auth.get_github_user_emails(exchange) + ) + user = User( + email=next((e.email for e in emails if e.primary and e.verified), None), + extra={ + 'github_user_info': user_info.model_dump(), + 'github_emails': [e.model_dump() for e in emails], + }, + ) + token = user.encode_token() + return [c.FireEvent(event=AuthEvent(token=token, url='/auth/profile'))] diff --git a/demo/auth_user.py b/demo/auth_user.py new file mode 100644 index 00000000..e837d7f9 --- /dev/null +++ b/demo/auth_user.py @@ -0,0 +1,38 @@ +import json +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Annotated, Any, Self + +import jwt +from fastapi import Header, HTTPException + +JWT_SECRET = 'secret' + + +@dataclass +class User: + email: str | None + extra: dict[str, Any] + + def encode_token(self) -> str: + return jwt.encode(asdict(self), JWT_SECRET, algorithm='HS256', json_encoder=CustomJsonEncoder) + + @classmethod + async def from_request(cls, authorization: Annotated[str, Header()] = '') -> Self | None: + try: + token = authorization.split(' ', 1)[1] + except IndexError: + return None + + try: + return cls(**jwt.decode(token, JWT_SECRET, algorithms=['HS256'])) + except jwt.DecodeError: + raise HTTPException(status_code=401, detail='Invalid token') + + +class CustomJsonEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + else: + return super().default(obj) diff --git a/demo/db.py b/demo/db.py deleted file mode 100644 index c3932518..00000000 --- a/demo/db.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -import secrets -from contextlib import asynccontextmanager -from dataclasses import dataclass -from datetime import datetime - -import libsql_client - - -@dataclass -class User: - token: str - email: str - last_active: datetime - - -async def get_user(token: str) -> User | None: - async with _connect() as conn: - rs = await conn.execute('select * from users where token = ?', (token,)) - if rs.rows: - await conn.execute('update users set last_active = current_timestamp where token = ?', (token,)) - return User(*rs.rows[0]) - - -async def create_user(email: str) -> str: - async with _connect() as conn: - await _delete_old_users(conn) - token = secrets.token_hex() - await conn.execute('insert into users (token, email) values (?, ?)', (token, email)) - return token - - -async def delete_user(user: User) -> None: - async with _connect() as conn: - await conn.execute('delete from users where token = ?', (user.token,)) - - -async def count_users() -> int: - async with _connect() as conn: - await _delete_old_users(conn) - rs = await conn.execute('select count(*) from users') - return rs.rows[0][0] - - -async def create_db() -> None: - async with _connect() as conn: - rs = await conn.execute("select 1 from sqlite_master where type='table' and name='users'") - if not rs.rows: - await conn.execute(SCHEMA) - - -SCHEMA = """ -create table if not exists users ( - token varchar(255) primary key, - email varchar(255) not null unique, - last_active timestamp not null default current_timestamp -); -""" - - -async def _delete_old_users(conn: libsql_client.Client) -> None: - await conn.execute('delete from users where last_active < datetime(current_timestamp, "-1 hour")') - - -@asynccontextmanager -async def _connect() -> libsql_client.Client: - auth_token = os.getenv('SQLITE_AUTH_TOKEN') - if auth_token: - url = 'libsql://fastui-samuelcolvin.turso.io' - else: - url = 'file:users.db' - async with libsql_client.create_client(url, auth_token=auth_token) as conn: - yield conn diff --git a/demo/main.py b/demo/main.py index 1671b10d..cdba7e22 100644 --- a/demo/main.py +++ b/demo/main.py @@ -37,6 +37,10 @@ def api_index() -> list[AnyComponent]: * `Table` — See [cities table](/table/cities) and [users table](/table/users) * `Pagination` — See the bottom of the [cities table](/table/cities) * `ModelForm` — See [forms](/forms/login) + +Authentication is supported via: +* token based authentication — see [here](/auth/login/password) for an example of password authentication +* GitHub OAuth — see [here](/auth/login/github) for an example of GitHub OAuth login """ return demo_page(c.Markdown(text=markdown)) diff --git a/demo/shared.py b/demo/shared.py index 7b7263c7..8f3f12f1 100644 --- a/demo/shared.py +++ b/demo/shared.py @@ -24,7 +24,7 @@ def demo_page(*components: AnyComponent, title: str | None = None) -> list[AnyCo ), c.Link( components=[c.Text(text='Auth')], - on_click=GoToEvent(url='/auth/login'), + on_click=GoToEvent(url='/auth/login/password'), active='startswith:/auth', ), c.Link( diff --git a/demo/tests.py b/demo/tests.py index a6c02161..0677a332 100644 --- a/demo/tests.py +++ b/demo/tests.py @@ -6,17 +6,21 @@ from . import app -client = TestClient(app) +@pytest.fixture +def client(): + with TestClient(app) as test_client: + yield test_client -def test_index(): + +def test_index(client: TestClient): r = client.get('/') assert r.status_code == 200, r.text assert r.text.startswith('\n') assert r.headers.get('content-type') == 'text/html; charset=utf-8' -def test_api_root(): +def test_api_root(client: TestClient): r = client.get('/api/') assert r.status_code == 200 data = r.json() @@ -52,16 +56,17 @@ def get_menu_links(): """ This is pretty cursory, we just go through the menu and load each page. """ - r = client.get('/api/') - assert r.status_code == 200 - data = r.json() - for link in data[1]['links']: - url = link['onClick']['url'] - yield pytest.param(f'/api{url}', id=url) + with TestClient(app) as client: + r = client.get('/api/') + assert r.status_code == 200 + data = r.json() + for link in data[1]['links']: + url = link['onClick']['url'] + yield pytest.param(f'/api{url}', id=url) @pytest.mark.parametrize('url', get_menu_links()) -def test_menu_links(url: str): +def test_menu_links(client: TestClient, url: str): r = client.get(url) assert r.status_code == 200 data = r.json() diff --git a/pyproject.toml b/pyproject.toml index cae01196..d4518502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,3 +26,15 @@ omit = [ "src/python-fastui/fastui/__main__.py", "src/python-fastui/fastui/generate_typescript.py", ] + +[tool.coverage.report] +precision = 2 +exclude_lines = [ + 'pragma: no cover', + 'raise NotImplementedError', + 'if TYPE_CHECKING:', + 'if typing.TYPE_CHECKING:', + '@overload', + '@typing.overload', + '\(Protocol\):$', +] diff --git a/src/npm-fastui-bootstrap/package.json b/src/npm-fastui-bootstrap/package.json index f48496e0..817b42f5 100644 --- a/src/npm-fastui-bootstrap/package.json +++ b/src/npm-fastui-bootstrap/package.json @@ -1,6 +1,6 @@ { "name": "@pydantic/fastui-bootstrap", - "version": "0.0.15", + "version": "0.0.16", "description": "Boostrap renderer for FastUI", "main": "dist/index.js", "types": "dist/index.d.ts", @@ -29,6 +29,6 @@ "sass": "^1.69.5" }, "peerDependencies": { - "@pydantic/fastui": "0.0.15" + "@pydantic/fastui": "0.0.16" } } diff --git a/src/npm-fastui-bootstrap/src/index.tsx b/src/npm-fastui-bootstrap/src/index.tsx index b2fef51f..b394a3c5 100644 --- a/src/npm-fastui-bootstrap/src/index.tsx +++ b/src/npm-fastui-bootstrap/src/index.tsx @@ -148,7 +148,7 @@ export const classNameGenerator: ClassNameGenerator = ({ if (props.statusCode === 502) { return 'm-3 text-muted' } else { - return 'alert alert-danger m-3' + return 'error-alert alert alert-danger m-3' } } } diff --git a/src/npm-fastui-prebuilt/package.json b/src/npm-fastui-prebuilt/package.json index bd5342b6..00ef82f5 100644 --- a/src/npm-fastui-prebuilt/package.json +++ b/src/npm-fastui-prebuilt/package.json @@ -1,6 +1,6 @@ { "name": "@pydantic/fastui-prebuilt", - "version": "0.0.15", + "version": "0.0.16", "description": "Pre-built files for FastUI", "main": "dist/index.html", "type": "module", diff --git a/src/npm-fastui-prebuilt/src/main.scss b/src/npm-fastui-prebuilt/src/main.scss index 07cbace6..b099f9c8 100644 --- a/src/npm-fastui-prebuilt/src/main.scss +++ b/src/npm-fastui-prebuilt/src/main.scss @@ -116,3 +116,9 @@ h6 { box-shadow: 0 2.5em 0 0; } } + +// make sure alerts aren't hidden behind the navbar +.error-alert { + position: relative; + top: 60px; +} diff --git a/src/npm-fastui/package.json b/src/npm-fastui/package.json index bfafd41d..9979788c 100644 --- a/src/npm-fastui/package.json +++ b/src/npm-fastui/package.json @@ -1,6 +1,6 @@ { "name": "@pydantic/fastui", - "version": "0.0.15", + "version": "0.0.16", "description": "Build better UIs faster.", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/src/npm-fastui/src/components/FireEvent.tsx b/src/npm-fastui/src/components/FireEvent.tsx index 5949c122..7b4aa23c 100644 --- a/src/npm-fastui/src/components/FireEvent.tsx +++ b/src/npm-fastui/src/components/FireEvent.tsx @@ -1,4 +1,4 @@ -import { FC, useEffect, useRef } from 'react' +import { FC, useEffect } from 'react' import type { FireEvent } from '../models' @@ -6,15 +6,12 @@ import { useFireEvent } from '../events' export const FireEventComp: FC = ({ event, message }) => { const { fireEvent } = useFireEvent() - const fireEventRef = useRef(fireEvent) useEffect(() => { - fireEventRef.current = fireEvent - }, [fireEvent]) - - useEffect(() => { - fireEventRef.current(event) - }, [event, fireEventRef]) + // debounce the event so changes to fireEvent (from location changes) don't trigger the event many times + const clear = setTimeout(() => fireEvent(event), 50) + return () => clearTimeout(clear) + }, [fireEvent, event]) return <>{message} } diff --git a/src/npm-fastui/src/components/FormField.tsx b/src/npm-fastui/src/components/FormField.tsx index 0e67959b..1676eb1d 100644 --- a/src/npm-fastui/src/components/FormField.tsx +++ b/src/npm-fastui/src/components/FormField.tsx @@ -24,7 +24,7 @@ interface FormFieldInputProps extends FormFieldInput { } export const FormFieldInputComp: FC = (props) => { - const { name, placeholder, required, htmlType, locked } = props + const { name, placeholder, required, htmlType, locked, autocomplete } = props return (
@@ -38,6 +38,7 @@ export const FormFieldInputComp: FC = (props) => { required={required} disabled={locked} placeholder={placeholder} + autoComplete={autocomplete} aria-describedby={descId(props)} /> diff --git a/src/npm-fastui/src/components/ServerLoad.tsx b/src/npm-fastui/src/components/ServerLoad.tsx index 062c821b..d8730502 100644 --- a/src/npm-fastui/src/components/ServerLoad.tsx +++ b/src/npm-fastui/src/components/ServerLoad.tsx @@ -52,8 +52,12 @@ export const ServerLoadFetch: FC<{ path: string; devReload?: number }> = ({ path useEffect(() => { setTransitioning(true) - const promise = request({ url, expectedStatus: [200, 404] }) - promise.then(([status, data]) => { + let componentUnloaded = false + request({ url, expectedStatus: [200, 404] }).then(([status, data]) => { + if (componentUnloaded) { + setTransitioning(false) + return + } if (status === 200) { setComponentProps(data as FastProps[]) // if there's a fragment, scroll to that ID once the page is loaded @@ -73,7 +77,7 @@ export const ServerLoadFetch: FC<{ path: string; devReload?: number }> = ({ path }) return () => { - promise.then(() => null) + componentUnloaded = true } }, [url, path, request, devReload]) diff --git a/src/npm-fastui/src/models.d.ts b/src/npm-fastui/src/models.d.ts index 461118fe..a7a97bdf 100644 --- a/src/npm-fastui/src/models.d.ts +++ b/src/npm-fastui/src/models.d.ts @@ -339,6 +339,7 @@ export interface FormFieldInput { htmlType?: 'text' | 'date' | 'datetime-local' | 'time' | 'email' | 'url' | 'number' | 'password' | 'hidden' initial?: string | number placeholder?: string + autocomplete?: string type: 'FormFieldInput' } export interface FormFieldTextarea { diff --git a/src/python-fastui/fastui/__init__.py b/src/python-fastui/fastui/__init__.py index 8ace8814..44069064 100644 --- a/src/python-fastui/fastui/__init__.py +++ b/src/python-fastui/fastui/__init__.py @@ -23,7 +23,7 @@ def coerce_to_list(cls, v): return [v] -_PREBUILT_VERSION = '0.0.15' +_PREBUILT_VERSION = '0.0.16' _PREBUILT_CDN_URL = f'https://cdn.jsdelivr.net/npm/@pydantic/fastui-prebuilt@{_PREBUILT_VERSION}/dist/assets' diff --git a/src/python-fastui/fastui/auth/__init__.py b/src/python-fastui/fastui/auth/__init__.py new file mode 100644 index 00000000..23f60222 --- /dev/null +++ b/src/python-fastui/fastui/auth/__init__.py @@ -0,0 +1,3 @@ +from .github import AuthError, GitHubAuthProvider + +__all__ = 'GitHubAuthProvider', 'AuthError' diff --git a/src/python-fastui/fastui/auth/github.py b/src/python-fastui/fastui/auth/github.py new file mode 100644 index 00000000..d53ba69f --- /dev/null +++ b/src/python-fastui/fastui/auth/github.py @@ -0,0 +1,284 @@ +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, AsyncIterator, Dict, List, Tuple, Union, cast +from urllib.parse import urlencode + +from pydantic import BaseModel, SecretStr, TypeAdapter, field_validator + +if TYPE_CHECKING: + import httpx + from fastapi import Request + from fastapi.responses import JSONResponse + + +@dataclass +class GitHubExchangeError: + error: str + error_description: Union[str, None] = None + + +@dataclass +class GitHubExchange: + access_token: str + token_type: str + scope: List[str] + + @field_validator('scope', mode='before') + def check_scope(cls, v: str) -> List[str]: + return [s for s in v.split(',') if s] + + +github_exchange_type = TypeAdapter(Union[GitHubExchange, GitHubExchangeError]) + + +class GithubUser(BaseModel): + login: str + name: Union[str, None] + email: Union[str, None] + avatar_url: str + created_at: datetime + updated_at: datetime + public_repos: int + public_gists: int + followers: int + following: int + company: Union[str, None] + blog: Union[str, None] + location: Union[str, None] + hireable: Union[bool, None] + bio: Union[str, None] + twitter_username: Union[str, None] = None + + +class GitHubEmail(BaseModel): + email: str + primary: bool + verified: bool + visibility: Union[str, None] + + +github_emails_ta = TypeAdapter(List[GitHubEmail]) + + +class GitHubAuthProvider: + """ + For details see https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps. + """ + + def __init__( + self, + httpx_client: 'httpx.AsyncClient', + github_client_id: str, + github_client_secret: SecretStr, + *, + redirect_uri: Union[str, None] = None, + scopes: Union[List[str], None] = None, + state_provider: Union['StateProvider', bool] = True, + exchange_cache_age: Union[timedelta, None] = timedelta(seconds=30), + ): + """ + Arguments: + httpx_client: An instance of `httpx.AsyncClient` to use for making requests to GitHub. + github_client_id: The client ID of the GitHub OAuth app. + github_client_secret: The client secret of the GitHub OAuth app. + redirect_uri: The URL in your app where users will be sent after authorization, if custom + scopes: See https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes + state_provider: If `True`, use a `StateProvider` to generate and validate state parameters for the OAuth + flow, you can also provide an instance directly. + exchange_cache_age: If not `None`, + responses from the access token exchange are cached for the given duration. + """ + self._httpx_client = httpx_client + self._github_client_id = github_client_id + self._github_client_secret = github_client_secret + self._redirect_uri = redirect_uri + self._scopes = scopes + if state_provider is True: + self._state_provider = StateProvider(github_client_secret) + elif state_provider is False: + self._state_provider = None + else: + self._state_provider = state_provider + # cache exchange responses, see `exchange_code` for details + self._exchange_cache_age = exchange_cache_age + + @classmethod + @asynccontextmanager + async def create( + cls, + client_id: str, + client_secret: SecretStr, + *, + redirect_uri: Union[str, None] = None, + state_provider: Union['StateProvider', bool] = True, + exchange_cache_age: Union[timedelta, None] = timedelta(seconds=10), + ) -> AsyncIterator['GitHubAuthProvider']: + """ + Async context manager to create a GitHubAuth instance with a new `httpx.AsyncClient`. + """ + import httpx + + async with httpx.AsyncClient() as client: + yield cls( + client, + client_id, + client_secret, + redirect_uri=redirect_uri, + state_provider=state_provider, + exchange_cache_age=exchange_cache_age, + ) + + async def authorization_url(self) -> str: + """ + See https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#1-request-a-users-github-identity + """ + params = {'client_id': self._github_client_id} + if self._redirect_uri: + params['redirect_uri'] = self._redirect_uri + if self._scopes: + params['scope'] = ' '.join(self._scopes) + if self._state_provider: + params['state'] = await self._state_provider.new_state() + return f'https://github.com/login/oauth/authorize?{urlencode(params)}' + + async def exchange_code(self, code: str, state: Union[str, None] = None) -> GitHubExchange: + """ + Exchange a code for an access token. + + If `self._exchange_cache_age` is not `None` (the default), responses are cached for the given duration to + work around issues with React often sending the same request multiple times in development mode. + """ + if self._exchange_cache_age: + cache_key = f'{code}:{state}' + if exchange := EXCHANGE_CACHE.get(cache_key, self._exchange_cache_age): + return exchange + else: + exchange = await self._exchange_code(code, state) + EXCHANGE_CACHE.set(cache_key, exchange) + return exchange + else: + return await self._exchange_code(code, state) + + async def _exchange_code(self, code: str, state: Union[str, None] = None) -> GitHubExchange: + if self._state_provider: + if state is None: + raise AuthError('Missing GitHub auth state', code='missing_state') + elif not await self._state_provider.check_state(state): + raise AuthError('Invalid GitHub auth state', code='invalid_state') + + params = { + 'client_id': self._github_client_id, + 'client_secret': self._github_client_secret.get_secret_value(), + 'code': code, + } + if self._redirect_uri: + params['redirect_uri'] = self._redirect_uri + + r = await self._httpx_client.post( + 'https://github.com/login/oauth/access_token', + params=params, + headers={'Accept': 'application/json'}, + ) + r.raise_for_status() + exchange_response = github_exchange_type.validate_json(r.content) + if isinstance(exchange_response, GitHubExchangeError): + if exchange_response.error == 'bad_verification_code': + raise AuthError('Invalid GitHub verification code', code=exchange_response.error) + else: + raise RuntimeError(f'Unexpected response from GitHub access token exchange: {r.text}') + else: + return cast(GitHubExchange, exchange_response) + + async def get_github_user(self, exchange: GitHubExchange) -> GithubUser: + """ + See https://docs.github.com/en/rest/users/users#get-the-authenticated-user + """ + headers = self._auth_headers(exchange) + user_response = await self._httpx_client.get('https://api.github.com/user', headers=headers) + user_response.raise_for_status() + return GithubUser.model_validate_json(user_response.content) + + async def get_github_user_emails(self, exchange: GitHubExchange) -> List[GitHubEmail]: + """ + See https://docs.github.com/en/rest/users/emails + """ + headers = self._auth_headers(exchange) + emails_response = await self._httpx_client.get('https://api.github.com/user/emails', headers=headers) + emails_response.raise_for_status() + return github_emails_ta.validate_json(emails_response.content) + + @staticmethod + def _auth_headers(exchange: GitHubExchange) -> Dict[str, str]: + return { + 'Authorization': f'Bearer {exchange.access_token}', + 'Accept': 'application/vnd.github+json', + } + + +class ExchangeCache: + def __init__(self): + self._cache: Dict[str, Tuple[datetime, GitHubExchange]] = {} + + def get(self, key: str, max_age: timedelta) -> Union[GitHubExchange, None]: + self._purge(max_age) + if v := self._cache.get(key): + return v[1] + + def set(self, key: str, value: GitHubExchange) -> None: + self._cache[key] = (datetime.now(), value) + + def _purge(self, max_age: timedelta) -> None: + """ + Remove old items from the exchange cache + """ + min_timestamp = datetime.now() - max_age + to_remove = [k for k, (ts, _) in self._cache.items() if ts < min_timestamp] + for k in to_remove: + del self._cache[k] + + +# exchange cache is a singleton so instantiating a new GitHubAuthProvider reuse the same cache +EXCHANGE_CACHE = ExchangeCache() + + +class AuthError(RuntimeError): + # TODO if we add other providers, this should be shared + + def __init__(self, message: str, *, code: str): + super().__init__(message) + self.code = code + + @staticmethod + def fastapi_handle(_request: 'Request', e: 'AuthError') -> 'JSONResponse': + from fastapi.responses import JSONResponse + + return JSONResponse({'detail': str(e)}, status_code=400) + + +class StateProvider: + """ + This is a simple state provider for the GitHub OAuth flow which uses a JWT to create an unguessable state. + """ + + # TODO if we add other providers, this could be shared + + def __init__(self, secret: SecretStr, max_age: timedelta = timedelta(minutes=5)): + self._secret = secret + self._max_age = max_age + + async def new_state(self) -> str: + import jwt + + data = {'created_at': datetime.now().isoformat()} + return jwt.encode(data, self._secret.get_secret_value(), algorithm='HS256') + + async def check_state(self, state: str) -> bool: + import jwt + + try: + d = jwt.decode(state, self._secret.get_secret_value(), algorithms=['HS256']) + except jwt.DecodeError: + return False + else: + return datetime.fromisoformat(d['created_at']) > datetime.now() - self._max_age diff --git a/src/python-fastui/fastui/components/forms.py b/src/python-fastui/fastui/components/forms.py index 539fb6b5..042c27d9 100644 --- a/src/python-fastui/fastui/components/forms.py +++ b/src/python-fastui/fastui/components/forms.py @@ -31,6 +31,7 @@ class FormFieldInput(BaseFormField): html_type: InputHtmlType = pydantic.Field(default='text', serialization_alias='htmlType') initial: _t.Union[str, float, None] = None placeholder: _t.Union[str, None] = None + autocomplete: _t.Union[str, None] = None type: _t.Literal['FormFieldInput'] = 'FormFieldInput' diff --git a/src/python-fastui/fastui/json_schema.py b/src/python-fastui/fastui/json_schema.py index eb209e6f..e0ee88da 100644 --- a/src/python-fastui/fastui/json_schema.py +++ b/src/python-fastui/fastui/json_schema.py @@ -195,6 +195,7 @@ def json_schema_field_to_field( html_type=input_html_type(schema), required=required, initial=schema.get('default'), + autocomplete=schema.get('autocomplete'), description=schema.get('description'), ) diff --git a/src/python-fastui/requirements/render.txt b/src/python-fastui/requirements/render.txt index 683fa919..3d1fa143 100644 --- a/src/python-fastui/requirements/render.txt +++ b/src/python-fastui/requirements/render.txt @@ -3,4 +3,4 @@ src/python-fastui uvicorn[standard] httpx -libsql-client +PyJWT diff --git a/src/python-fastui/requirements/test.in b/src/python-fastui/requirements/test.in index ba6ee03c..5928ebed 100644 --- a/src/python-fastui/requirements/test.in +++ b/src/python-fastui/requirements/test.in @@ -4,5 +4,4 @@ pytest-pretty dirty-equals pytest-asyncio httpx -# libsql-client is used by demo -libsql-client +PyJWT diff --git a/src/python-fastui/requirements/test.txt b/src/python-fastui/requirements/test.txt index 12d55713..2d8e26df 100644 --- a/src/python-fastui/requirements/test.txt +++ b/src/python-fastui/requirements/test.txt @@ -4,16 +4,10 @@ # # pip-compile --constraint=src/python-fastui/requirements/lint.txt --constraint=src/python-fastui/requirements/pyproject.txt --output-file=src/python-fastui/requirements/test.txt --strip-extras src/python-fastui/requirements/test.in # -aiohttp==3.9.3 - # via libsql-client -aiosignal==1.3.1 - # via aiohttp anyio==4.2.0 # via # -c src/python-fastui/requirements/pyproject.txt # httpx -attrs==23.2.0 - # via aiohttp certifi==2024.2.2 # via # httpcore @@ -22,10 +16,6 @@ coverage==7.4.1 # via -r src/python-fastui/requirements/test.in dirty-equals==0.7.1.post0 # via -r src/python-fastui/requirements/test.in -frozenlist==1.4.1 - # via - # aiohttp - # aiosignal h11==0.14.0 # via httpcore httpcore==1.0.2 @@ -37,25 +27,20 @@ idna==3.6 # -c src/python-fastui/requirements/pyproject.txt # anyio # httpx - # yarl iniconfig==2.0.0 # via pytest -libsql-client==0.3.0 - # via -r src/python-fastui/requirements/test.in markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py -multidict==6.0.5 - # via - # aiohttp - # yarl packaging==23.2 # via pytest pluggy==1.4.0 # via pytest pygments==2.17.2 # via rich +pyjwt==2.8.0 + # via -r src/python-fastui/requirements/test.in pytest==7.4.4 # via # -r src/python-fastui/requirements/test.in @@ -74,9 +59,3 @@ sniffio==1.3.0 # -c src/python-fastui/requirements/pyproject.txt # anyio # httpx -typing-extensions==4.9.0 - # via - # -c src/python-fastui/requirements/pyproject.txt - # libsql-client -yarl==1.9.4 - # via aiohttp diff --git a/src/python-fastui/tests/test_auth_github.py b/src/python-fastui/tests/test_auth_github.py new file mode 100644 index 00000000..1723eed0 --- /dev/null +++ b/src/python-fastui/tests/test_auth_github.py @@ -0,0 +1,235 @@ +from typing import List, Optional + +import httpx +import pytest +from fastapi import FastAPI +from fastui.auth import AuthError, GitHubAuthProvider +from fastui.auth.github import GitHubEmail +from pydantic import SecretStr + + +@pytest.fixture +def github_requests() -> List[str]: + return [] + + +@pytest.fixture +def fake_github_app(github_requests: List[str]) -> FastAPI: + app = FastAPI() + + @app.post('/login/oauth/access_token') + async def access_token(code: str, client_id: str, client_secret: str, redirect_uri: Optional[str] = None): + r = f'/login/oauth/access_token code={code}' + if redirect_uri: + r += f' redirect_uri={redirect_uri}' + github_requests.append(r) + assert client_id == '1234' + assert client_secret == 'secret' + if code == 'good_user': + return {'access_token': 'good_token_user', 'token_type': 'bearer', 'scope': 'user'} + elif code == 'good': + return {'access_token': 'good_token', 'token_type': 'bearer', 'scope': ''} + elif code == 'bad_expected': + return {'error': 'bad_verification_code'} + else: + return {'error': 'bad_code'} + + @app.get('/user') + async def user(): + github_requests.append('/user') + return { + 'login': 'test_user', + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar_url': 'https://example.com/avatar.png', + 'created_at': '2022-01-01T00:00:00Z', + 'updated_at': '2022-01-01T00:00:00Z', + 'public_repos': 0, + 'public_gists': 0, + 'followers': 0, + 'following': 0, + 'company': None, + 'blog': None, + 'location': None, + 'hireable': None, + 'bio': None, + } + + @app.get('/user/emails') + async def user_emails(): + github_requests.append('/user/emails') + return [ + {'email': 'foo@example.com', 'primary': False, 'verified': True, 'visibility': None}, + {'email': 'bar@example.com', 'primary': True, 'verified': True, 'visibility': 'public'}, + ] + + return app + + +@pytest.fixture +async def httpx_client(fake_github_app: FastAPI): + async with httpx.AsyncClient(app=fake_github_app) as client: + yield client + + +@pytest.fixture +async def github_auth_provider(fake_github_app: FastAPI, httpx_client: httpx.AsyncClient): + return GitHubAuthProvider( + httpx_client=httpx_client, + github_client_id='1234', + github_client_secret=SecretStr('secret'), + state_provider=False, + exchange_cache_age=None, + ) + + +async def test_get_auth_url(github_auth_provider: GitHubAuthProvider): + url = await github_auth_provider.authorization_url() + # no state here + assert url == 'https://github.com/login/oauth/authorize?client_id=1234' + + +async def test_exchange_ok(github_auth_provider: GitHubAuthProvider, github_requests: List[str]): + assert github_requests == [] + exchange = await github_auth_provider.exchange_code('good') + assert exchange.access_token == 'good_token' + assert exchange.token_type == 'bearer' + assert exchange.scope == [] + assert github_requests == ['/login/oauth/access_token code=good'] + + +async def test_exchange_ok_user(github_auth_provider: GitHubAuthProvider): + exchange = await github_auth_provider.exchange_code('good_user') + assert exchange.access_token == 'good_token_user' + assert exchange.token_type == 'bearer' + assert exchange.scope == ['user'] + + +async def test_exchange_bad_expected(github_auth_provider: GitHubAuthProvider): + with pytest.raises(AuthError, match='^Invalid GitHub verification code') as exc_info: + await github_auth_provider.exchange_code('bad_expected') + + # request argument is ignored + r = AuthError.fastapi_handle(object(), exc_info.value) + assert r.status_code == 400 + + +async def test_exchange_bad_unexpected(github_auth_provider: GitHubAuthProvider): + with pytest.raises(RuntimeError, match='^Unexpected response from GitHub access token exchange'): + await github_auth_provider.exchange_code('unknown') + + +@pytest.fixture +async def github_auth_provider_state(fake_github_app: FastAPI, httpx_client: httpx.AsyncClient): + return GitHubAuthProvider( + httpx_client=httpx_client, + github_client_id='1234', + github_client_secret=SecretStr('secret'), + state_provider=True, + ) + + +async def test_exchange_no_state(github_auth_provider_state: GitHubAuthProvider): + with pytest.raises(AuthError, match='^Missing GitHub auth state'): + await github_auth_provider_state.exchange_code('good') + + +async def test_exchange_bad_state(github_auth_provider_state: GitHubAuthProvider): + with pytest.raises(AuthError, match='^Invalid GitHub auth state'): + await github_auth_provider_state.exchange_code('good', 'bad_state') + + +async def test_exchange_good_state(github_auth_provider_state: GitHubAuthProvider): + url = await github_auth_provider_state.authorization_url() + assert url.startswith('https://github.com/login/oauth/authorize?client_id=1234&state=') + state = url.rsplit('=', 1)[-1] + + exchange = await github_auth_provider_state.exchange_code('good', state) + assert exchange.access_token == 'good_token' + + +async def test_exchange_bad_state_file_exists(github_auth_provider_state: GitHubAuthProvider): + url = await github_auth_provider_state.authorization_url() + assert url.startswith('https://github.com/login/oauth/authorize?client_id=1234&state=') + + with pytest.raises(AuthError, match='^Invalid GitHub auth state'): + await github_auth_provider_state.exchange_code('good', 'bad_state') + + +async def test_exchange_ok_repeat(github_auth_provider: GitHubAuthProvider, github_requests: List[str]): + assert github_requests == [] + exchange = await github_auth_provider.exchange_code('good') + assert exchange.access_token == 'good_token' + assert exchange.token_type == 'bearer' + assert exchange.scope == [] + assert github_requests == ['/login/oauth/access_token code=good'] + + exchange = await github_auth_provider.exchange_code('good') + assert exchange.access_token == 'good_token' + + assert github_requests == ['/login/oauth/access_token code=good', '/login/oauth/access_token code=good'] + + +async def test_exchange_ok_repeat_cached( + fake_github_app: FastAPI, httpx_client: httpx.AsyncClient, github_requests: List[str] +): + github_auth_provider = GitHubAuthProvider( + httpx_client=httpx_client, + github_client_id='1234', + github_client_secret=SecretStr('secret'), + state_provider=False, + ) + assert github_requests == [] + await github_auth_provider.exchange_code('good') + assert github_requests == ['/login/oauth/access_token code=good'] + await github_auth_provider.exchange_code('good') + assert github_requests == ['/login/oauth/access_token code=good'] # no repeat request to github + await github_auth_provider.exchange_code('good_user') + assert github_requests == ['/login/oauth/access_token code=good', '/login/oauth/access_token code=good_user'] + + +async def test_exchange_redirect_url( + fake_github_app: FastAPI, httpx_client: httpx.AsyncClient, github_requests: List[str] +): + github_auth_provider = GitHubAuthProvider( + httpx_client=httpx_client, + github_client_id='1234', + github_client_secret=SecretStr('secret'), + redirect_uri='/callback', + state_provider=False, + exchange_cache_age=None, + ) + url = await github_auth_provider.authorization_url() + assert url == 'https://github.com/login/oauth/authorize?client_id=1234&redirect_uri=%2Fcallback' + exchange = await github_auth_provider.exchange_code('good') + assert exchange.access_token == 'good_token' + assert github_requests == ['/login/oauth/access_token code=good redirect_uri=/callback'] + + +async def test_get_github_user(github_auth_provider: GitHubAuthProvider, github_requests: List[str]): + assert github_requests == [] + exchange = await github_auth_provider.exchange_code('good') + assert github_requests == ['/login/oauth/access_token code=good'] + user = await github_auth_provider.get_github_user(exchange) + assert user.login == 'test_user' + assert user.name == 'Test User' + assert user.email == 'test@example.com' + + assert github_requests == ['/login/oauth/access_token code=good', '/user'] + + +async def test_get_github_user_emails(github_auth_provider: GitHubAuthProvider, github_requests: List[str]): + assert github_requests == [] + exchange = await github_auth_provider.exchange_code('good') + assert github_requests == ['/login/oauth/access_token code=good'] + emails = await github_auth_provider.get_github_user_emails(exchange) + assert emails == [ + GitHubEmail(email='foo@example.com', primary=False, verified=True, visibility=None), + GitHubEmail(email='bar@example.com', primary=True, verified=True, visibility='public'), + ] + assert github_requests == ['/login/oauth/access_token code=good', '/user/emails'] + + +async def test_create(): + async with GitHubAuthProvider.create('foo', SecretStr('bar')) as provider: + assert isinstance(provider._httpx_client, httpx.AsyncClient)