Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for access tokens in WSO2 client #85

Merged
merged 8 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nens_auth_client/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def access_token_generator(token_generator, access_token_template):

def func(**extra_claims):
claims = {**access_token_template, **extra_claims}
claims = {k: v for (k, v) in claims.items() if v is not None}
return token_generator(**claims)

return func
Expand Down
35 changes: 30 additions & 5 deletions nens_auth_client/tests/test_wso2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from authlib.jose.errors import JoseError
from authlib.oidc.discovery import get_well_known_url
from django.conf import settings
from nens_auth_client.wso2 import WSO2AuthClient

import pytest
Expand All @@ -19,9 +22,31 @@ def test_extract_username(claims, expected):
assert WSO2AuthClient.extract_username(claims) == expected


def test_parse_access_token_includes_claims(access_token_generator):
with pytest.raises(NotImplementedError) as e:
WSO2AuthClient.parse_access_token(None, access_token_generator())
@pytest.fixture
def wso2_client():
return WSO2AuthClient(
"foo",
server_metadata_url=get_well_known_url(
settings.NENS_AUTH_ISSUER, external=True
),
)

# error is raised with claims as arg
assert e.value.args[0]["client_id"] == "1234"

def test_parse_access_token_wso2(access_token_generator, jwks_request, wso2_client):
# disable 'token_use' (not included in WSO2 access token)
claims = wso2_client.parse_access_token(
access_token_generator(email="test@wso2", token_use=None)
)

assert claims["email"] == "test@wso2"


@pytest.mark.parametrize(
"claims_mod", [{"aud": "abc123"}, {"sub": None}, {"iss": "abc123"}, {"exp": 0}]
)
caspervdw marked this conversation as resolved.
Show resolved Hide resolved
def test_parse_access_token_wso2_invalid_claims(
claims_mod, access_token_generator, jwks_request, wso2_client
):
token = access_token_generator(**claims_mod)
with pytest.raises(JoseError):
wso2_client.parse_access_token(token)
45 changes: 39 additions & 6 deletions nens_auth_client/wso2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from authlib.integrations.django_client import DjangoOAuth2App
from authlib.jose import JsonWebKey
from authlib.jose import JsonWebToken
from django.conf import settings
from django.http.response import HttpResponseRedirect
from urllib.parse import urlencode
from urllib.parse import urlparse
Expand Down Expand Up @@ -46,17 +49,47 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False):
def parse_access_token(self, token, claims_options=None, leeway=120):
"""Decode and validate a WSO2 access token and return its payload.

Note: this function just errors with the token claims in the error
message (so that we can figure out how we can actually validate the
token)

Args:
token (str): access token (base64 encoded JWT)

Returns:
claims (dict): the token payload

Raises:
NotImplementedError: always
authlib.jose.errors.JoseError: if token is invalid
ValueError: if the key id is not present in the jwks.json
"""
raise NotImplementedError(decode_jwt(token))

# this is a copy from the _parse_id_token equivalent function
caspervdw marked this conversation as resolved.
Show resolved Hide resolved
def load_key(header, payload):
jwk_set = self.fetch_jwk_set()
kid = header.get("kid")
try:
return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid)
except ValueError:
# re-try with new jwk set
jwk_set = self.fetch_jwk_set(force=True)
return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid)

metadata = self.load_server_metadata()
claims_options = {
"aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID},
"iss": {"essential": True, "value": metadata["issuer"]},
"sub": {"essential": True},
"scope": {"essential": True},
**(claims_options or {}),
}

alg_values = metadata.get("id_token_signing_alg_values_supported")
if not alg_values:
alg_values = ["RS256"]

claims = JsonWebToken(alg_values).decode(
token, key=load_key, claims_options=claims_options
)

claims.validate(leeway=leeway)
return claims

@staticmethod
def extract_provider_name(claims):
Expand Down
Loading