From 77869154e4674199bbb176b00c8e7a6bead7b613 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 09:24:54 +0100 Subject: [PATCH 1/8] Add support for access tokens in WSO2 client --- nens_auth_client/tests/conftest.py | 1 + nens_auth_client/tests/test_wso2.py | 35 ++++++++++++++++++---- nens_auth_client/wso2.py | 45 +++++++++++++++++++++++++---- setup.cfg | 2 +- 4 files changed, 71 insertions(+), 12 deletions(-) diff --git a/nens_auth_client/tests/conftest.py b/nens_auth_client/tests/conftest.py index 82455fe..0d4208a 100644 --- a/nens_auth_client/tests/conftest.py +++ b/nens_auth_client/tests/conftest.py @@ -145,6 +145,7 @@ def access_token_generator(token_generator, access_token_template): def func(**extra_claims): claims = {**access_token_template, **extra_claims} + claims = {k: v for (k, v) in claims.items() if v is not None} return token_generator(**claims) return func diff --git a/nens_auth_client/tests/test_wso2.py b/nens_auth_client/tests/test_wso2.py index 9f95333..c91a0c6 100644 --- a/nens_auth_client/tests/test_wso2.py +++ b/nens_auth_client/tests/test_wso2.py @@ -1,3 +1,6 @@ +from authlib.jose.errors import JoseError +from authlib.oidc.discovery import get_well_known_url +from django.conf import settings from nens_auth_client.wso2 import WSO2AuthClient import pytest @@ -19,9 +22,31 @@ def test_extract_username(claims, expected): assert WSO2AuthClient.extract_username(claims) == expected -def test_parse_access_token_includes_claims(access_token_generator): - with pytest.raises(NotImplementedError) as e: - WSO2AuthClient.parse_access_token(None, access_token_generator()) +@pytest.fixture +def wso2_client(): + return WSO2AuthClient( + "foo", + server_metadata_url=get_well_known_url( + settings.NENS_AUTH_ISSUER, external=True + ), + ) - # error is raised with claims as arg - assert e.value.args[0]["client_id"] == "1234" + +def test_parse_access_token_wso2(access_token_generator, jwks_request, wso2_client): + # disable 'token_use' (not included in WSO2 access token) + claims = wso2_client.parse_access_token( + access_token_generator(email="test@wso2", token_use=None) + ) + + assert claims["email"] == "test@wso2" + + +@pytest.mark.parametrize( + "claims_mod", [{"aud": "abc123"}, {"sub": None}, {"iss": "abc123"}, {"exp": 0}] +) +def test_parse_access_token_wso2_invalid_claims( + claims_mod, access_token_generator, jwks_request, wso2_client +): + token = access_token_generator(**claims_mod) + with pytest.raises(JoseError): + wso2_client.parse_access_token(token) diff --git a/nens_auth_client/wso2.py b/nens_auth_client/wso2.py index 48ffe72..597b4b7 100644 --- a/nens_auth_client/wso2.py +++ b/nens_auth_client/wso2.py @@ -1,4 +1,7 @@ from authlib.integrations.django_client import DjangoOAuth2App +from authlib.jose import JsonWebKey +from authlib.jose import JsonWebToken +from django.conf import settings from django.http.response import HttpResponseRedirect from urllib.parse import urlencode from urllib.parse import urlparse @@ -46,17 +49,47 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False): def parse_access_token(self, token, claims_options=None, leeway=120): """Decode and validate a WSO2 access token and return its payload. - Note: this function just errors with the token claims in the error - message (so that we can figure out how we can actually validate the - token) - Args: token (str): access token (base64 encoded JWT) + Returns: + claims (dict): the token payload + Raises: - NotImplementedError: always + authlib.jose.errors.JoseError: if token is invalid + ValueError: if the key id is not present in the jwks.json """ - raise NotImplementedError(decode_jwt(token)) + + # this is a copy from the _parse_id_token equivalent function + def load_key(header, payload): + jwk_set = self.fetch_jwk_set() + kid = header.get("kid") + try: + return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + except ValueError: + # re-try with new jwk set + jwk_set = self.fetch_jwk_set(force=True) + return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + + metadata = self.load_server_metadata() + claims_options = { + "aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID}, + "iss": {"essential": True, "value": metadata["issuer"]}, + "sub": {"essential": True}, + "scope": {"essential": True}, + **(claims_options or {}), + } + + alg_values = metadata.get("id_token_signing_alg_values_supported") + if not alg_values: + alg_values = ["RS256"] + + claims = JsonWebToken(alg_values).decode( + token, key=load_key, claims_options=claims_options + ) + + claims.validate(leeway=leeway) + return claims @staticmethod def extract_provider_name(claims): diff --git a/setup.cfg b/setup.cfg index 9830613..efa77a3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,6 @@ force_single_line = true [tool:pytest] DJANGO_SETTINGS_MODULE = nens_auth_client.testsettings -addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client +#addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client python_files = test_*.py junit_family = xunit1 From c1c465b84a14cf9a14344ef4268f0b2652acf263 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 09:26:14 +0100 Subject: [PATCH 2/8] Revert change in setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index efa77a3..9830613 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,6 @@ force_single_line = true [tool:pytest] DJANGO_SETTINGS_MODULE = nens_auth_client.testsettings -#addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client +addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client python_files = test_*.py junit_family = xunit1 From 7609c69739d1b9819a126c22cfa83b53d01a9d97 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 13:33:18 +0100 Subject: [PATCH 3/8] PR change --- nens_auth_client/cognito.py | 13 ++++++------- nens_auth_client/wso2.py | 13 ++++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/nens_auth_client/cognito.py b/nens_auth_client/cognito.py index f5863c2..e830092 100644 --- a/nens_auth_client/cognito.py +++ b/nens_auth_client/cognito.py @@ -115,16 +115,15 @@ def parse_access_token(self, token, claims_options=None, leeway=120): ValueError: if the key id is not present in the jwks.json """ - # this is a copy from the _parse_id_token equivalent function - def load_key(header, payload): - jwk_set = self.fetch_jwk_set() - kid = header.get("kid") + # this is a copy from authlib.integrations.base_client.sync_openid.parse_id_token equivalent function + def load_key(header, _): + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) try: - return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + return jwk_set.find_by_kid(header.get("kid")) except ValueError: # re-try with new jwk set - jwk_set = self.fetch_jwk_set(force=True) - return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) + return jwk_set.find_by_kid(header.get("kid")) metadata = self.load_server_metadata() claims_options = { diff --git a/nens_auth_client/wso2.py b/nens_auth_client/wso2.py index 597b4b7..efb1b78 100644 --- a/nens_auth_client/wso2.py +++ b/nens_auth_client/wso2.py @@ -60,16 +60,15 @@ def parse_access_token(self, token, claims_options=None, leeway=120): ValueError: if the key id is not present in the jwks.json """ - # this is a copy from the _parse_id_token equivalent function - def load_key(header, payload): - jwk_set = self.fetch_jwk_set() - kid = header.get("kid") + # this is a copy from authlib.integrations.base_client.sync_openid.parse_id_token equivalent function + def load_key(header, _): + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) try: - return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + return jwk_set.find_by_kid(header.get("kid")) except ValueError: # re-try with new jwk set - jwk_set = self.fetch_jwk_set(force=True) - return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) + return jwk_set.find_by_kid(header.get("kid")) metadata = self.load_server_metadata() claims_options = { From ccf36d963ae763d2e89366d1149731c90a18b34e Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 14:04:25 +0100 Subject: [PATCH 4/8] Remove code duplication by separating out OAuth2Base class --- nens_auth_client/cognito.py | 138 +++++++--------------- nens_auth_client/oauth_base.py | 91 ++++++++++++++ nens_auth_client/tests/test_cognito.py | 3 +- nens_auth_client/tests/test_middleware.py | 93 ++++++++++++--- nens_auth_client/tests/test_oauth_base.py | 64 ++++++++++ nens_auth_client/tests/test_wso2.py | 21 ---- nens_auth_client/wso2.py | 63 +--------- setup.cfg | 2 +- 8 files changed, 280 insertions(+), 195 deletions(-) create mode 100644 nens_auth_client/oauth_base.py create mode 100644 nens_auth_client/tests/test_oauth_base.py diff --git a/nens_auth_client/cognito.py b/nens_auth_client/cognito.py index e830092..535564e 100644 --- a/nens_auth_client/cognito.py +++ b/nens_auth_client/cognito.py @@ -1,6 +1,4 @@ -from authlib.integrations.django_client import DjangoOAuth2App -from authlib.jose import JsonWebKey -from authlib.jose import JsonWebToken +from .oauth_base import BaseOAuthClient from django.conf import settings from django.http.response import HttpResponseRedirect from urllib.parse import urlencode @@ -8,59 +6,7 @@ from urllib.parse import urlunparse -def preprocess_access_token(claims): - """Convert AWS Cognito Access token claims to standard form, inplace. - - AWS Cognito Access tokens are missing the "aud" (audience) claim and - instead put the audience into each scope. - - This function filters the scopes on those that start with the - NENS_AUTH_RESOURCE_SERVER_ID setting. If there is any matching scope, the - "aud" claim will be set. - - The resulting "scope" has no audience(s) in it anymore. - - Args: - claims (dict): payload of the Access Token - - Example: - >>> audience = "https://some/api/" - >>> claims = { - "scope": "https://some/api/users.readwrite https://something/else" - } - >>> preprocess_access_token(claims) - >>> claims - { - "aud": "https://some/api/", - "scopes": "users.readwrite", - ... - } - """ - # Do nothing if there is an already an "aud" claim - if "aud" in claims: - return - - # Get the expected "aud" claim - audience = settings.NENS_AUTH_RESOURCE_SERVER_ID - - # List scopes and chop off the audience from the scope - new_scopes = [] - for scope_item in claims.get("scope", "").split(" "): - if scope_item.startswith(audience): - scope_without_audience = scope_item[len(audience) :] - new_scopes.append(scope_without_audience) - - # Don't set the audience if there are no scopes as Access Token is - # apparently not meant for this server. - if not new_scopes: - return - - # Update the claims inplace - claims["aud"] = audience - claims["scope"] = " ".join(new_scopes) - - -class CognitoOAuthClient(DjangoOAuth2App): +class CognitoOAuthClient(BaseOAuthClient): def logout_redirect(self, request, redirect_uri=None, login_after=False): """Create a redirect to the remote server's logout endpoint @@ -97,56 +43,56 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False): return HttpResponseRedirect(logout_url) - def parse_access_token(self, token, claims_options=None, leeway=120): - """Decode and validate a Cognito access token and return its payload. + def preprocess_access_token(self, claims): + """Convert AWS Cognito Access token claims to standard form, inplace. - Note: this function is based on authlib.DjangoRemoteApp._parse_id_token - to make use of the same server settings and key cache. The token claims - are AWS Cognito specific. + AWS Cognito Access tokens are missing the "aud" (audience) claim and + instead put the audience into each scope. - Args: - token (str): access token (base64 encoded JWT) + This function filters the scopes on those that start with the + NENS_AUTH_RESOURCE_SERVER_ID setting. If there is any matching scope, the + "aud" claim will be set. - Returns: - claims (dict): the token payload + The resulting "scope" has no audience(s) in it anymore. - Raises: - authlib.jose.errors.JoseError: if token is invalid - ValueError: if the key id is not present in the jwks.json - """ + Args: + claims (dict): payload of the Access Token - # this is a copy from authlib.integrations.base_client.sync_openid.parse_id_token equivalent function - def load_key(header, _): - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) - try: - return jwk_set.find_by_kid(header.get("kid")) - except ValueError: - # re-try with new jwk set - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) - return jwk_set.find_by_kid(header.get("kid")) - - metadata = self.load_server_metadata() - claims_options = { - "aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID}, - "iss": {"essential": True, "value": metadata["issuer"]}, - "sub": {"essential": True}, - "scope": {"essential": True}, - **(claims_options or {}), + Example: + >>> audience = "https://some/api/" + >>> claims = { + "scope": "https://some/api/users.readwrite https://something/else" } + >>> preprocess_access_token(claims) + >>> claims + { + "aud": "https://some/api/", + "scopes": "users.readwrite", + ... + } + """ + # Do nothing if there is an already an "aud" claim + if "aud" in claims: + return - alg_values = metadata.get("id_token_signing_alg_values_supported") - if not alg_values: - alg_values = ["RS256"] + # Get the expected "aud" claim + audience = settings.NENS_AUTH_RESOURCE_SERVER_ID - claims = JsonWebToken(alg_values).decode( - token, key=load_key, claims_options=claims_options - ) + # List scopes and chop off the audience from the scope + new_scopes = [] + for scope_item in claims.get("scope", "").split(" "): + if scope_item.startswith(audience): + scope_without_audience = scope_item[len(audience) :] + new_scopes.append(scope_without_audience) - # Preprocess the token (to add the "aud" claim) - preprocess_access_token(claims) + # Don't set the audience if there are no scopes as Access Token is + # apparently not meant for this server. + if not new_scopes: + return - claims.validate(leeway=leeway) - return claims + # Update the claims inplace + claims["aud"] = audience + claims["scope"] = " ".join(new_scopes) @staticmethod def extract_provider_name(claims): diff --git a/nens_auth_client/oauth_base.py b/nens_auth_client/oauth_base.py new file mode 100644 index 0000000..3e8164d --- /dev/null +++ b/nens_auth_client/oauth_base.py @@ -0,0 +1,91 @@ +from authlib.integrations.django_client import DjangoOAuth2App +from authlib.jose import JsonWebKey +from authlib.jose import JsonWebToken +from django.conf import settings + + +class BaseOAuthClient(DjangoOAuth2App): + def logout_redirect(self, request, redirect_uri=None, login_after=False): + """Create a redirect to the remote server's logout endpoint + + Note that unlike with login, there is no standardization for logout. + This function should be written for a specific authorization server. + + Args: + request: The current request + redirect_uri: The absolute url to the logout success view of this app + login_after: whether to show the login screen after logout + + Returns: + HttpResponseRedirect authorization server logout endpoint + """ + raise NotImplementedError() + + def load_key(self, header, payload): + """Load a JSONWebKey from the authorization server given JWT header and payload. + + Source: + authlib.integrations.base_client.sync_openid.parse_id_token + """ + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) + try: + return jwk_set.find_by_kid(header.get("kid")) + except ValueError: + # re-try with new jwk set + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) + return jwk_set.find_by_kid(header.get("kid")) + + def preprocess_access_token(self, claims): + """Convert access token claims to standard form, inplace. + + Args: + claims (dict): payload of the Access Token + """ + + def parse_access_token(self, token, claims_options=None, leeway=120): + """Decode and validate an access token and return its payload. + + Args: + token (str): access token (base64 encoded JWT) + + Returns: + claims (dict): the token payload + + Raises: + authlib.jose.errors.JoseError: if token is invalid + ValueError: if the key id is not present in the jwks.json + """ + metadata = self.load_server_metadata() + claims_options = { + "aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID}, + "iss": {"essential": True, "value": metadata["issuer"]}, + "sub": {"essential": True}, + "scope": {"essential": True}, + **(claims_options or {}), + } + + alg_values = metadata.get("id_token_signing_alg_values_supported") + if not alg_values: + alg_values = ["RS256"] + + claims = JsonWebToken(alg_values).decode( + token, key=self.load_key, claims_options=claims_options + ) + + # Preprocess the token (to add the "aud" claim) + self.preprocess_access_token(claims) + + claims.validate(leeway=leeway) + return claims + + @staticmethod + def extract_provider_name(claims): + """Return provider name from claim and `None` if not found""" + # Also used by backends.py + raise NotImplementedError() + + @staticmethod + def extract_username(claims) -> str: + """Return username from claims""" + # Also used by backends.py + raise NotImplementedError() diff --git a/nens_auth_client/tests/test_cognito.py b/nens_auth_client/tests/test_cognito.py index eb51a1c..a616da1 100644 --- a/nens_auth_client/tests/test_cognito.py +++ b/nens_auth_client/tests/test_cognito.py @@ -1,5 +1,4 @@ from nens_auth_client.cognito import CognitoOAuthClient -from nens_auth_client.cognito import preprocess_access_token import pytest @@ -19,7 +18,7 @@ ) def test_preprocess_access_token(claims, expected, settings): settings.NENS_AUTH_RESOURCE_SERVER_ID = "api/" - preprocess_access_token(claims) + CognitoOAuthClient.preprocess_access_token(None, claims) assert claims == expected diff --git a/nens_auth_client/tests/test_middleware.py b/nens_auth_client/tests/test_middleware.py index 50adf11..f1de47f 100644 --- a/nens_auth_client/tests/test_middleware.py +++ b/nens_auth_client/tests/test_middleware.py @@ -1,3 +1,5 @@ +from authlib.jose.errors import JoseError +from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser from nens_auth_client.middleware import AccessTokenMiddleware @@ -8,16 +10,22 @@ @pytest.fixture -def mocked_middleware(rf, mocker, rq_mocker, jwks_request, settings): - """Mock necessary functions to test AccessTokenMiddleware request""" - # Mock the user association call +def mocked_oauth_client(mocker): + get_oauth_client = mocker.patch("nens_auth_client.middleware.get_oauth_client") + get_oauth_client.return_value.parse_access_token.return_value = {"scope": "foo"} + return get_oauth_client.return_value + + +@pytest.fixture +def mocked_authenticate(mocker): authenticate = mocker.patch("django.contrib.auth.authenticate") authenticate.return_value = UserModel(username="testuser") - # Disable the custom AWS Cognito Access Token mapping - mocker.patch("nens_auth_client.cognito.preprocess_access_token") - # Make a middleware that returns the request as a response - middleware = AccessTokenMiddleware(get_response=lambda x: x) - return middleware + return authenticate + + +@pytest.fixture +def middleware(): + return AccessTokenMiddleware(get_response=lambda x: x) @pytest.fixture @@ -28,16 +36,73 @@ def r(rf): return request -def test_middleware(r, access_token_generator, mocked_middleware): - r.META["HTTP_AUTHORIZATION"] = "Bearer " + access_token_generator() - processed_request = mocked_middleware(r) +def test_middleware( + r, access_token_generator, mocked_oauth_client, mocked_authenticate, middleware +): + token = access_token_generator() + r.META["HTTP_AUTHORIZATION"] = "Bearer " + token + + processed_request = middleware(r) assert processed_request.user.username == "testuser" - assert processed_request.user.oauth2_scope == "readwrite" + assert processed_request.user.oauth2_scope == "foo" + + mocked_oauth_client.parse_access_token.assert_called_once_with( + token, leeway=settings.NENS_AUTH_LEEWAY + ) + mocked_authenticate.assert_called_once_with( + r, claims=mocked_oauth_client.parse_access_token.return_value + ) -def test_middleware_logged_in_user(r, access_token_generator, mocked_middleware): +def test_middleware_logged_in_user( + r, access_token_generator, middleware, mocked_oauth_client +): # An already logged in user (e.g. session cookie) is unchanged r.user = UserModel(username="otheruser") r.META["HTTP_AUTHORIZATION"] = "Bearer " + access_token_generator() - processed_request = mocked_middleware(r) + processed_request = middleware(r) assert processed_request.user.username == "otheruser" + + assert not mocked_oauth_client.parse_access_token.called + + +def test_middleware_no_token(r, middleware, mocked_oauth_client): + middleware(r) + + assert not mocked_oauth_client.parse_access_token.called + + +def test_middleware_invalid_token( + r, access_token_generator, mocked_oauth_client, mocked_authenticate, middleware +): + token = access_token_generator() + r.META["HTTP_AUTHORIZATION"] = "Bearer " + token + + mocked_oauth_client.parse_access_token.side_effect = JoseError() + + processed_request = middleware(r) + assert not processed_request.user.is_authenticated + + mocked_oauth_client.parse_access_token.assert_called_once_with( + token, leeway=settings.NENS_AUTH_LEEWAY + ) + assert not mocked_authenticate.called + + +def test_middleware_no_authentication( + r, access_token_generator, mocked_oauth_client, mocked_authenticate, middleware +): + mocked_authenticate.return_value = None + + token = access_token_generator() + r.META["HTTP_AUTHORIZATION"] = "Bearer " + token + + processed_request = middleware(r) + assert not processed_request.user.is_authenticated + + mocked_oauth_client.parse_access_token.assert_called_once_with( + token, leeway=settings.NENS_AUTH_LEEWAY + ) + mocked_authenticate.assert_called_once_with( + r, claims=mocked_oauth_client.parse_access_token.return_value + ) diff --git a/nens_auth_client/tests/test_oauth_base.py b/nens_auth_client/tests/test_oauth_base.py new file mode 100644 index 0000000..fde5779 --- /dev/null +++ b/nens_auth_client/tests/test_oauth_base.py @@ -0,0 +1,64 @@ +from authlib.jose.errors import JoseError +from authlib.oidc.discovery import get_well_known_url +from django.conf import settings +from nens_auth_client.oauth_base import BaseOAuthClient +from unittest import mock + +import pytest +import time + + +@pytest.fixture +def oauth_client(): + return BaseOAuthClient( + "foo", + server_metadata_url=get_well_known_url( + settings.NENS_AUTH_ISSUER, external=True + ), + ) + + +def test_parse_access_token(access_token_generator, jwks_request, oauth_client): + claims = oauth_client.parse_access_token(access_token_generator(email="test@wso2")) + + assert claims["email"] == "test@wso2" + + +@pytest.mark.parametrize( + "claims_mod", [{"aud": "abc123"}, {"sub": None}, {"iss": "abc123"}, {"exp": 0}] +) +def test_parse_access_token_invalid_claims( + claims_mod, access_token_generator, jwks_request, oauth_client +): + token = access_token_generator(**claims_mod) + with pytest.raises(JoseError): + oauth_client.parse_access_token(token) + + +def test_parse_access_token_preprocess( + access_token_generator, jwks_request, oauth_client +): + # In this example, the preprocess function makes an otherwise invalid token valid + def preprocess(claims): + claims["exp"] = time.time() + 1 + + with mock.patch.object( + oauth_client, "preprocess_access_token", side_effect=preprocess + ): + claims = oauth_client.parse_access_token(access_token_generator(exp=0)) + + assert claims["exp"] > 0 + + +def test_parse_access_token_preprocess_err( + access_token_generator, jwks_request, oauth_client +): + # In this example, the preprocess function makes an otherwise valid token invalid + def preprocess(claims): + claims["exp"] = 0 + + with mock.patch.object( + oauth_client, "preprocess_access_token", side_effect=preprocess + ): + with pytest.raises(JoseError): + oauth_client.parse_access_token(access_token_generator()) diff --git a/nens_auth_client/tests/test_wso2.py b/nens_auth_client/tests/test_wso2.py index c91a0c6..292f55c 100644 --- a/nens_auth_client/tests/test_wso2.py +++ b/nens_auth_client/tests/test_wso2.py @@ -1,4 +1,3 @@ -from authlib.jose.errors import JoseError from authlib.oidc.discovery import get_well_known_url from django.conf import settings from nens_auth_client.wso2 import WSO2AuthClient @@ -30,23 +29,3 @@ def wso2_client(): settings.NENS_AUTH_ISSUER, external=True ), ) - - -def test_parse_access_token_wso2(access_token_generator, jwks_request, wso2_client): - # disable 'token_use' (not included in WSO2 access token) - claims = wso2_client.parse_access_token( - access_token_generator(email="test@wso2", token_use=None) - ) - - assert claims["email"] == "test@wso2" - - -@pytest.mark.parametrize( - "claims_mod", [{"aud": "abc123"}, {"sub": None}, {"iss": "abc123"}, {"exp": 0}] -) -def test_parse_access_token_wso2_invalid_claims( - claims_mod, access_token_generator, jwks_request, wso2_client -): - token = access_token_generator(**claims_mod) - with pytest.raises(JoseError): - wso2_client.parse_access_token(token) diff --git a/nens_auth_client/wso2.py b/nens_auth_client/wso2.py index efb1b78..f385c3d 100644 --- a/nens_auth_client/wso2.py +++ b/nens_auth_client/wso2.py @@ -1,26 +1,11 @@ -from authlib.integrations.django_client import DjangoOAuth2App -from authlib.jose import JsonWebKey -from authlib.jose import JsonWebToken -from django.conf import settings +from .oauth_base import BaseOAuthClient from django.http.response import HttpResponseRedirect from urllib.parse import urlencode from urllib.parse import urlparse from urllib.parse import urlunparse -import base64 -import json - -def decode_jwt(token): - """Decode a JWT without checking its signature""" - # JWT consists of {header}.{payload}.{signature} - _, payload, _ = token.split(".") - # JWT should be padded with = (base64.b64decode expects this) - payload += "=" * (-len(payload) % 4) - return json.loads(base64.b64decode(payload)) - - -class WSO2AuthClient(DjangoOAuth2App): +class WSO2AuthClient(BaseOAuthClient): def logout_redirect(self, request, redirect_uri=None, login_after=False): """Create a redirect to the remote server's logout endpoint @@ -46,50 +31,6 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False): logout_url = urlunparse(auth_url) return HttpResponseRedirect(logout_url) - def parse_access_token(self, token, claims_options=None, leeway=120): - """Decode and validate a WSO2 access token and return its payload. - - Args: - token (str): access token (base64 encoded JWT) - - Returns: - claims (dict): the token payload - - Raises: - authlib.jose.errors.JoseError: if token is invalid - ValueError: if the key id is not present in the jwks.json - """ - - # this is a copy from authlib.integrations.base_client.sync_openid.parse_id_token equivalent function - def load_key(header, _): - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) - try: - return jwk_set.find_by_kid(header.get("kid")) - except ValueError: - # re-try with new jwk set - jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) - return jwk_set.find_by_kid(header.get("kid")) - - metadata = self.load_server_metadata() - claims_options = { - "aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID}, - "iss": {"essential": True, "value": metadata["issuer"]}, - "sub": {"essential": True}, - "scope": {"essential": True}, - **(claims_options or {}), - } - - alg_values = metadata.get("id_token_signing_alg_values_supported") - if not alg_values: - alg_values = ["RS256"] - - claims = JsonWebToken(alg_values).decode( - token, key=load_key, claims_options=claims_options - ) - - claims.validate(leeway=leeway) - return claims - @staticmethod def extract_provider_name(claims): """Return provider name from claim and `None` if not found""" diff --git a/setup.cfg b/setup.cfg index 9830613..cf79883 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,6 @@ force_single_line = true [tool:pytest] DJANGO_SETTINGS_MODULE = nens_auth_client.testsettings -addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client +# addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client python_files = test_*.py junit_family = xunit1 From fdd3905b3ad36ac0fdf1d5f460625b31184c044d Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 14:06:57 +0100 Subject: [PATCH 5/8] Fix rest framework tests --- nens_auth_client/tests/test_rest_framework.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/nens_auth_client/tests/test_rest_framework.py b/nens_auth_client/tests/test_rest_framework.py index 2028f23..05bf5f9 100644 --- a/nens_auth_client/tests/test_rest_framework.py +++ b/nens_auth_client/tests/test_rest_framework.py @@ -13,30 +13,39 @@ def r(): @pytest.fixture -def mocked_authenticator(rf, mocker, rq_mocker, jwks_request, settings): - """Mock necessary functions to test AccessTokenMiddleware request""" - # Mock the user association call +def mocked_oauth_client(mocker): + get_oauth_client = mocker.patch( + "nens_auth_client.rest_framework.authentication.get_oauth_client" + ) + get_oauth_client.return_value.parse_access_token.return_value = {"scope": "foo"} + return get_oauth_client.return_value + + +@pytest.fixture +def mocked_authenticate(mocker): authenticate = mocker.patch("django.contrib.auth.authenticate") authenticate.return_value = UserModel(username="testuser") - # Disable the custom AWS Cognito Access Token mapping - mocker.patch("nens_auth_client.cognito.preprocess_access_token") - # Make a middleware that returns the request as a response + return authenticate + + +@pytest.fixture +def authenticator(): return OAuth2TokenAuthentication() -def test_authentication_class(r, mocked_authenticator, access_token_generator): +def test_authentication_class( + r, authenticator, access_token_generator, mocked_oauth_client, mocked_authenticate +): r.META["HTTP_AUTHORIZATION"] = "Bearer " + access_token_generator() - user, auth = mocked_authenticator.authenticate(r) + user, auth = authenticator.authenticate(r) assert user.username == "testuser" - assert auth.scope == "readwrite" + assert auth.scope == "foo" -def test_authentication_class_no_header(r, mocked_authenticator): - assert mocked_authenticator.authenticate(r) is None +def test_authentication_class_no_header(r, authenticator): + assert authenticator.authenticate(r) is None -def test_authentication_class_no_bearer( - r, mocked_authenticator, access_token_generator -): +def test_authentication_class_no_bearer(r, authenticator, access_token_generator): r.META["HTTP_AUTHORIZATION"] = "Token xxx" - assert mocked_authenticator.authenticate(r) is None + assert authenticator.authenticate(r) is None From d95d326a272a96a1be58407183f6f63a1b83801b Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 14:14:51 +0100 Subject: [PATCH 6/8] Revert setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index cf79883..9830613 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,6 @@ force_single_line = true [tool:pytest] DJANGO_SETTINGS_MODULE = nens_auth_client.testsettings -# addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client +addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client python_files = test_*.py junit_family = xunit1 From baa6bf42cb6a9c8e4a4023b131cb40275e3668a6 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 14:30:00 +0100 Subject: [PATCH 7/8] Fix check on RESOURCE_SERVER_ID --- nens_auth_client/checks.py | 8 ++++++-- nens_auth_client/tests/test_wso2.py | 12 ------------ 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/nens_auth_client/checks.py b/nens_auth_client/checks.py index 602cc70..0d1aac2 100644 --- a/nens_auth_client/checks.py +++ b/nens_auth_client/checks.py @@ -19,11 +19,15 @@ def check_resource_server_id(app_configs=None, **kwargs): "AccessTokenMiddleware is used." ) ] - if not url.endswith("/"): + if ( + settings.NENS_AUTH_OAUTH_BACKEND + == "nens_auth_client.cognito.CognitoOAuthClient" + and not url.endswith("/") + ): return [ Error( "The NENS_AUTH_RESOURCE_SERVER_ID setting needs to end with a " - "slash (because AWS Cognito will automatically add one)." + "slash when using the CognitoOAuthClient." ) ] return [] diff --git a/nens_auth_client/tests/test_wso2.py b/nens_auth_client/tests/test_wso2.py index 292f55c..6897b99 100644 --- a/nens_auth_client/tests/test_wso2.py +++ b/nens_auth_client/tests/test_wso2.py @@ -1,5 +1,3 @@ -from authlib.oidc.discovery import get_well_known_url -from django.conf import settings from nens_auth_client.wso2 import WSO2AuthClient import pytest @@ -19,13 +17,3 @@ def test_extract_provider_name(): ) def test_extract_username(claims, expected): assert WSO2AuthClient.extract_username(claims) == expected - - -@pytest.fixture -def wso2_client(): - return WSO2AuthClient( - "foo", - server_metadata_url=get_well_known_url( - settings.NENS_AUTH_ISSUER, external=True - ), - ) From 406245fac5cdd5d88b5e41596183873780169594 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 14:32:15 +0100 Subject: [PATCH 8/8] Changes [ci skip] --- CHANGES.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index b9d816f..e9bd777 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,8 +5,9 @@ Changelog of nens-auth-client 1.5.1 (unreleased) ------------------ -- Nothing changed yet. +- Added Bearer token parsing to WSO2 client. +- Refactored the Cognito and WSO2 clients so they use the same baseclass. 1.5.0 (2024-02-19) ------------------