Skip to content

Commit

Permalink
HTTPBearer token is set, Auth button not shown on /api/docs (#67)
Browse files Browse the repository at this point in the history
Refactor Oauth2-lib to allow token pass to method, token can be extracted from the request via the default HttpBearerExtractor() or a custom implementation.

* WIP: HTTPBearer token is set

* Fix: Authenticate refactor for GraphQL and other

* Fix: raise 403 on empty token

* Fix: Add test with token extractor

* Fix: changed bool to False auto_error

* Fix artefact package

* Linting

* Fix: Linting issues

* Fix: space removal black check .

* Refactor: make it more abstract

* Fix: linting

* Test: fix tests

* Fix: linting imports

* Fix: linting return type

* Add init to HTTPBearerExtractor

* Fix: Black --check

* Fix typing decorator and removed unused code

* fix: black

* Update oauth2_lib/fastapi.py

Co-authored-by: Mark Moes <[email protected]>

* Update oauth2_lib/fastapi.py

Co-authored-by: Mark Moes <[email protected]>

* Fix: ruff

* Rollback Typing checks

* Add: typing ignore on decorator tests

* bump version

* bump version

---------

Co-authored-by: Peter Boers <[email protected]>
Co-authored-by: Mark Moes <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2024
1 parent a55dfd7 commit 0e3e5f4
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 2.2.0
current_version = 2.3.0
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
Expand Down
2 changes: 1 addition & 1 deletion oauth2_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

"""This is the SURF Oauth2 module that interfaces with the oauth2 setup."""

__version__ = "2.2.0"
__version__ = "2.3.0"
56 changes: 21 additions & 35 deletions oauth2_lib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from fastapi import HTTPException
from fastapi.requests import Request
from fastapi.security.http import HTTPBearer
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from httpx import AsyncClient, NetworkError
from pydantic import BaseModel
from starlette.requests import ClientDisconnect, HTTPConnection
Expand Down Expand Up @@ -126,7 +126,7 @@ class Authentication(ABC):
"""

@abstractmethod
async def authenticate(self, request: HTTPConnection, token: str | None = None) -> dict | None:
async def authenticate(self, request: Request, token: str | None = None) -> dict | None:
"""Authenticate the user."""
pass

Expand All @@ -142,17 +142,24 @@ async def extract(self, request: Request) -> str | None:
pass


class HttpBearerExtractor(IdTokenExtractor):
class HttpBearerExtractor(HTTPBearer, IdTokenExtractor):
"""Extracts bearer tokens using FastAPI's HTTPBearer.
Specifically designed for HTTP Authorization header token extraction.
"""

def __init__(self, auto_error: bool = False):
super().__init__(auto_error=auto_error)

async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]:
"""Extract the Authorization header from the request."""
return await super().__call__(request)

async def extract(self, request: Request) -> str | None:
http_bearer = HTTPBearer(auto_error=False)
credential = await http_bearer(request)
"""Extract the token from the Authorization header in the request."""
http_auth_credentials = await super().__call__(request)

return credential.credentials if credential else None
return http_auth_credentials.credentials if http_auth_credentials else None


class OIDCAuth(Authentication):
Expand All @@ -168,11 +175,7 @@ def __init__(
resource_server_id: str,
resource_server_secret: str,
oidc_user_model_cls: type[OIDCUserModel],
id_token_extractor: IdTokenExtractor | None = None,
):
if not id_token_extractor:
self.id_token_extractor = HttpBearerExtractor()

self.openid_url = openid_url
self.openid_config_url = openid_config_url
self.resource_server_id = resource_server_id
Expand All @@ -181,7 +184,7 @@ def __init__(

self.openid_config: OIDCConfig | None = None

async def authenticate(self, request: HTTPConnection, token: str | None = None) -> OIDCUserModel | None:
async def authenticate(self, request: Request, token: str | None = None) -> OIDCUserModel | None:
"""Return the OIDC user from OIDC introspect endpoint.
This is used as a security module in Fastapi projects
Expand All @@ -197,33 +200,16 @@ async def authenticate(self, request: HTTPConnection, token: str | None = None)
if not oauth2lib_settings.OAUTH2_ACTIVE:
return None

async with AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT) as async_client:
await self.check_openid_config(async_client)

# Handle WebSocket requests separately only to check for token presence.
if isinstance(request, WebSocket):
if token is None:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Not authenticated",
)
token_or_extracted_id_token = token
else:
request = cast(Request, request)

if await self.is_bypassable_request(request):
return None
if await self.is_bypassable_request(request):
return None

if token is None:
extracted_id_token = await self.id_token_extractor.extract(request)
if not extracted_id_token:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
if not token:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")

token_or_extracted_id_token = extracted_id_token
else:
token_or_extracted_id_token = token
async with AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT) as async_client:
await self.check_openid_config(async_client)

user_info: OIDCUserModel = await self.userinfo(async_client, token_or_extracted_id_token)
user_info: OIDCUserModel = await self.userinfo(async_client, token)
logger.debug("OIDCUserModel object.", user_info=user_info)
return user_info

Expand Down
8 changes: 6 additions & 2 deletions oauth2_lib/strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.info import RootValueType

from oauth2_lib.fastapi import AuthManager, OIDCUserModel
from oauth2_lib.fastapi import AuthManager, HttpBearerExtractor, OIDCUserModel
from oauth2_lib.settings import oauth2lib_settings

logger = structlog.get_logger(__name__)
Expand Down Expand Up @@ -56,7 +56,11 @@ async def get_current_user(self) -> OIDCUserModel | None:
return None

try:
return await self.auth_manager.authentication.authenticate(self.request)
http_bearer_extractor = HttpBearerExtractor(auto_error=False)
http_authorization_credentials = await http_bearer_extractor(self.request)

token = http_authorization_credentials.credentials if http_authorization_credentials else None
return await self.auth_manager.authentication.authenticate(self.request, token)
except HTTPException as exc:
logger.debug("User is not authenticated", status_code=exc.status_code, detail=exc.detail)
return None
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_mock_async_client():
Pass a MockResponse for single or list for multiple sequential HTTP responses.
"""

def _make_mock_async_client(mock_response: MockResponse | list[MockResponse] | None = None):
def _make_mock_async_client(mock_response: MockResponse | list[MockResponse] | None = None) -> AsyncClientMock:
mock_async_client = AsyncMock(spec=AsyncClient)

mock_responses = ([mock_response] if isinstance(mock_response, MockResponse) else mock_response) or []
Expand Down
17 changes: 9 additions & 8 deletions tests/strawberry/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import strawberry
from fastapi import Depends, FastAPI
from starlette.requests import Request
from httpx._client import AsyncClient
from starlette.testclient import TestClient
from strawberry.fastapi import GraphQLRouter

Expand All @@ -17,7 +17,7 @@

async def get_oidc_authentication():
class OIDCAuthMock(OIDCAuth):
async def userinfo(self, request: Request, token: str | None = None) -> OIDCUserModel | None:
async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel:
return user_info_matching

return OIDCAuthMock("openid_url", "openid_url/.well-known/openid-configuration", "id", "secret", OIDCUserModel)
Expand Down Expand Up @@ -47,27 +47,28 @@ class BookType:
class bookNestedAuthType:
title: str

@authenticated_field("test authentication for book author")
@authenticated_field("test authentication for book author") # type: ignore
def author(self) -> str:
return f"{self.title} author"

@strawberry.type
class Query:
@authenticated_field("query book test")
@authenticated_field("query book test") # type: ignore
def book(self) -> BookType:
return BookType(title="test title", author="test author")

@strawberry.field(description="query book nested auth test")
# Known issue with strawberry: https://github.com/strawberry-graphql/strawberry/issues/1929
@strawberry.field(description="query book nested auth test") # type: ignore
def book_nested_auth(self) -> bookNestedAuthType:
return bookNestedAuthType(title="test title")

@authenticated_federated_field("federated field book test")
@authenticated_federated_field("federated field book test") # type: ignore
def federated_book(self) -> BookType:
return BookType(title="test title federated field", author="test author federated field")

@strawberry.type
class Mutation:
@authenticated_mutation_field("mutation test")
@authenticated_mutation_field("mutation test") # type: ignore
def add_book(self, title: str, author: str) -> BookType:
return BookType(title=title, author=author)

Expand All @@ -76,7 +77,7 @@ async def get_context(auth_manager=Depends(get_auth_manger)) -> OauthContext: #

app = FastAPI()
schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema, context_getter=get_context)
graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

return TestClient(app)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_async_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,18 @@ class FakeApiClient:
def __init__(self, *args, **kwargs):
pass

def request(self, method, url, query_params, headers, *args):
def request(
self,
method,
url,
query_params=None,
headers=None,
post_params=None,
body=None,
_preload_content=True,
_request_timeout=None,
):
headers = {} if headers is None else headers
http = urllib3.PoolManager()
response = http.request(method, url, headers=headers)
if not 200 <= response.status <= 299:
Expand Down
10 changes: 3 additions & 7 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,6 @@ async def test_userinfo_success_with_mock(oidc_auth):
assert user["sub"] == "hoi", "User info not retrieved correctly"


def test_oidc_auth_initialization_default_extractor(oidc_auth):
assert isinstance(
oidc_auth.id_token_extractor, HttpBearerExtractor
), "Default ID token extractor should be HttpBearerExtractor"


@pytest.mark.asyncio
async def test_extract_token_success():
request = mock.MagicMock()
Expand All @@ -165,7 +159,9 @@ async def test_authenticate_success(make_mock_async_client, discovery, oidc_auth
request = mock.MagicMock(spec=Request)
request.headers = {"Authorization": "Bearer valid_token"}

user = await oidc_auth.authenticate(request)
http_bearer_extractor = HttpBearerExtractor(auto_error=False)
token = await http_bearer_extractor(request)
user = await oidc_auth.authenticate(request, token)
assert user == user_info_matching, "Authentication failed for a valid token"


Expand Down

0 comments on commit 0e3e5f4

Please sign in to comment.