Skip to content

Commit

Permalink
Merge pull request #11 from ral-facilities/implement-refresh-endpoint…
Browse files Browse the repository at this point in the history
…-#10

Implement refresh endpoint
  • Loading branch information
VKTB authored Jan 30, 2024
2 parents 04910b8 + 2b96e53 commit affb151
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 48 deletions.
23 changes: 22 additions & 1 deletion ldap_jwt_auth/auth/jwt_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ldap_jwt_auth.core.config import config
from ldap_jwt_auth.core.constants import PRIVATE_KEY, PUBLIC_KEY
from ldap_jwt_auth.core.exceptions import InvalidJWTError
from ldap_jwt_auth.core.exceptions import InvalidJWTError, JWTRefreshError

logger = logging.getLogger()

Expand Down Expand Up @@ -44,6 +44,27 @@ def get_refresh_token(self) -> str:
}
return self._pack_jwt(payload)

def refresh_access_token(self, access_token: str, refresh_token: str):
"""
Refreshes the JWT access token by updating its expiry time, provided that the JWT refresh token is valid.
:param access_token: The JWT access token to refresh.
:param refresh_token: The JWT refresh token.
:raises JWTRefreshError: If the JWT access token cannot be refreshed.
:return: JWT access token with an updated expiry time.
"""
logger.info("Refreshing access token")
self.verify_token(refresh_token)
try:
payload = self._get_jwt_payload(access_token, {"verify_exp": False})
payload["exp"] = datetime.now(timezone.utc) + timedelta(
minutes=config.authentication.access_token_validity_minutes
)
return self._pack_jwt(payload)
except Exception as exc:
message = "Unable to refresh access token"
logger.exception(message)
raise JWTRefreshError(message) from exc

def verify_token(self, token: str) -> Dict[str, Any]:
"""
Verifies that the provided JWT token is valid. It does this by checking that it was signed by the corresponding
Expand Down
6 changes: 6 additions & 0 deletions ldap_jwt_auth/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ class InvalidJWTError(Exception):
"""


class JWTRefreshError(Exception):
"""
Exception raised when JWT access token cannot be refreshed.
"""


class LDAPServerError(Exception):
"""
Exception raised when there is problem with the LDAP server.
Expand Down
3 changes: 2 additions & 1 deletion ldap_jwt_auth/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ldap_jwt_auth.core.config import config
from ldap_jwt_auth.core.logger_setup import setup_logger
from ldap_jwt_auth.routers import login, verify
from ldap_jwt_auth.routers import login, refresh, verify

app = FastAPI(title=config.api.title, description=config.api.description)

Expand Down Expand Up @@ -63,6 +63,7 @@ async def custom_validation_exception_handler(request: Request, exc: RequestVali
)

app.include_router(login.router)
app.include_router(refresh.router)
app.include_router(verify.router)


Expand Down
39 changes: 39 additions & 0 deletions ldap_jwt_auth/routers/refresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Module for providing an API router which defines a route for managing the refreshing/updating of a JWT access token
using a JWT refresh token.
"""
import logging
from typing import Annotated

from fastapi import APIRouter, Body, Cookie, Depends, HTTPException, status
from fastapi.responses import JSONResponse

from ldap_jwt_auth.auth.jwt_handler import JWTHandler
from ldap_jwt_auth.core.exceptions import JWTRefreshError, InvalidJWTError

logger = logging.getLogger()

router = APIRouter(prefix="/refresh", tags=["authentication"])


@router.post(
path="",
summary="Generate an updated JWT access token using the JWT refresh token",
response_description="A JWT access token",
)
def refresh_access_token(
jwt_handler: Annotated[JWTHandler, Depends(JWTHandler)],
token: Annotated[str, Body(description="The JWT access token to refresh", embed=True)],
refresh_token: Annotated[str | None, Cookie(description="The JWT refresh token from an HTTP-only cookie")] = None,
) -> JSONResponse:
# pylint: disable=missing-function-docstring
if refresh_token is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No JWT refresh token found")

try:
access_token = jwt_handler.refresh_access_token(token, refresh_token)
return JSONResponse(content=access_token)
except (InvalidJWTError, JWTRefreshError) as exc:
message = "Unable to refresh access token"
logger.exception(message)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message) from exc
153 changes: 107 additions & 46 deletions test/unit/auth/test_jwt_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,49 @@
import pytest

from ldap_jwt_auth.auth.jwt_handler import JWTHandler
from ldap_jwt_auth.core.exceptions import InvalidJWTError
from ldap_jwt_auth.core.exceptions import InvalidJWTError, JWTRefreshError

VALID_ACCESS_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoyNTM0MDIzMDA3OTl9.bagU2Wix8wKzydVU_L3Z"
"ZuuMAxGxV4OTuZq_kS2Fuwm839_8UZOkICnPTkkpvsm1je0AWJaIXLGgwEa5zUjpG6lTrMMmzR9Zi63F0NXpJqQqoOZpTBMYBaggsXqFkdsv-yAKUZ"
"8MfjCEyk3UZ4PXZmEcUZcLhKcXZr4kYJPjio2e5WOGpdjK6q7s-iHGs9DQFT_IoCnw9CkyOKwYdgpB35hIGHkNjiwVSHpyKbFQvzJmIv5XCTSRYqq0"
"1fldh-QYuZqZeuaFidKbLRH610o2-1IfPMUr-yPtj5PZ-AaX-XTLkuMqdVMCk0_jeW9Os2BPtyUDkpcu1fvW3_S6_dK3nQ"
)

VALID_REFRESH_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI1MzQwMjMwMDc5OX0.h4Hv_sq4-ika1rpuRx7k3pp0cF_BZ65WVSbIHS7oh9SjPpGHt"
"GhVHU1IJXzFtyA9TH-68JpAZ24Dm6bXbH6VJKoc7RCbmJXm44ufN32ga7jDqXH340oKvi_wdhEHaCf2HXjzsHHD7_D6XIcxU71v2W5_j8Vuwpr3SdX"
"6ea_yLIaCDWynN6FomPtUepQAOg3c7DdKohbJD8WhKIDV8UKuLtFdRBfN4HEK5nNs0JroROPhcYM9L_JIQZpdI0c83fDFuXQC-cAygzrSnGJ6O4DyS"
"cNL3VBNSmNTBtqYOs1szvkpvF9rICPgbEEJnbS6g5kmGld3eioeuDJIxeQglSbxog"
)

EXPIRED_ACCESS_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjotNjIxMzU1OTY4MDB9.G_cfC8PNYE5yERyyQNRk"
"9mTmDusU_rEPgm7feo2lWQF6QMNnf8PUN-61FfMNRVE0QDSvAmIMMNEOa8ma0JHZARafgnYJfn1_FSJSoRxC740GpG8EFSWrpM-dQXnoD263V9FlK-"
"On6IbhF-4Rh9MdoxNyZk2Lj7NvCzJ7gbgbgYM5-sJXLxB-I5LfMfuYM3fx2cRixZFA153l46tFzcMVBrAiBxl_LdyxTIOPfHF0UGlaW2UtFi02gyBU"
"4E4wTOqPc4t_CSi1oBSbY7h9O63i8IU99YsOCdvZ7AD3ePxyM1xJR7CFHycg9Z_IDouYnJmXpTpbFMMl7SjME3cVMfMrAQ"
)

EXPIRED_REFRESH_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOi02MjEzNTU5NjgwMH0.Er0A8dvdZi7o1FK3b-Te2IkUjDJZjI0aANsP7bbAbeITPRnR0"
"YEhavmuLT1zaoALQjUzfSgtH0s3I-YbUr2ssqG1DnKh83uts3J2_EXIXQZBeuZisCW1nN1LC2nsR6o4HQEsbMsINjJviHeMWS8nRC06XXpN1WFPaGB"
"xXkLFeDWb3SXiirZ79m7lUBwQvVzpfeA337e_AejG45mtadgfW3xpDCw-6sVVIA-cuzruxnjRKAzJrw_goA9X4MukRXbnzou2mgkxFKs_-6hdTFDI-"
"B47wYqalP6KC5nqzjrCpvjmukgM-DN0uAhm2TUzUmE5EXtRLEYMRqsSmog4hYq1Nw"
)

EXPECTED_ACCESS_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNzA1NDg1OTAwfQ.aWJ8T8RGHF93YhRSP9nOAD"
"EKY9nFjVIDu7RQhPGiMpvhgdpPBP17VQPbJ6Smt8mG1TjLXjquJZaDQRF7syrJd8ESDo-lh3ef-cMWg2hWZpbtpQaPaNHLAAMrjZo97qLxrBjeOKjY"
"ggqwKMr-7g_LlB--z9GiQrLJVhpGxAXjnTy9VSrioZIU7OE9L9tUyOI7LGjY0X2znWQ3Loy5sMwCP_SeFHBPolKXiErKeLItriaxYNEc5l5VXD2wsK"
"G9L8dDZZwe4BSU2eyT_2hhPTrVNfI8-J1KtwpLywC0NfS0Vaksy4HG2IbH8hpl6gaLZhtr2C5_0H_IpkTsvm_Zsnzhbg"
)

EXPECTED_REFRESH_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MDYwOTA0MDB9.IHua0NcHiLOz7vamvcR4lxt-t51_UgzIQzho5vYK2UdHjG-bA5Sk"
"9YhHQy480UK4FiIKohpb8G70OwmsSCjzxvbo41MZKdz3z0z_4-L0_LSGLGGmxbvPaHy6_SI8qI1f7KOAD6T3OU1zIFTcyoREEN2uNRyjMnGcQzh72d"
"NkRAFEF3um4S2WVL0mwQ6ZltAjCiA2R8o5Eu3Aq67lkbq00ml69rfecT1JXiAfjrnW0J64COJDbQ9kVCNM1YrpqLBmROHMOOw9o7Qz1h78LbtKarVk"
"VGaPIxhdZsWKjZwDD-6h15NZuKTAmcPUaucx6Dd4uCjJHld1BNsfKfX_81G03g"
)


def mock_datetime_now() -> datetime:
Expand All @@ -23,52 +65,81 @@ def test_get_access_token(datetime_mock):
"""
Test getting an access token.
"""
expected_access_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNzA1NDg1OTAwfQ.aWJ8T8RGHF93YhRSP9"
"nOADEKY9nFjVIDu7RQhPGiMpvhgdpPBP17VQPbJ6Smt8mG1TjLXjquJZaDQRF7syrJd8ESDo-lh3ef-cMWg2hWZpbtpQaPaNHLAAMrjZo97qLx"
"rBjeOKjYggqwKMr-7g_LlB--z9GiQrLJVhpGxAXjnTy9VSrioZIU7OE9L9tUyOI7LGjY0X2znWQ3Loy5sMwCP_SeFHBPolKXiErKeLItriaxYN"
"Ec5l5VXD2wsKG9L8dDZZwe4BSU2eyT_2hhPTrVNfI8-J1KtwpLywC0NfS0Vaksy4HG2IbH8hpl6gaLZhtr2C5_0H_IpkTsvm_Zsnzhbg"
)
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
access_token = jwt_handler.get_access_token("username")

assert access_token == expected_access_token
assert access_token == EXPECTED_ACCESS_TOKEN


@patch("ldap_jwt_auth.auth.jwt_handler.datetime")
def test_get_refresh_token(datetime_mock):
"""
Test getting a refresh token.
"""
expected_refresh_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MDYwOTA0MDB9.IHua0NcHiLOz7vamvcR4lxt-t51_UgzIQzho5vYK2UdHjG-b"
"A5Sk9YhHQy480UK4FiIKohpb8G70OwmsSCjzxvbo41MZKdz3z0z_4-L0_LSGLGGmxbvPaHy6_SI8qI1f7KOAD6T3OU1zIFTcyoREEN2uNRyjMn"
"GcQzh72dNkRAFEF3um4S2WVL0mwQ6ZltAjCiA2R8o5Eu3Aq67lkbq00ml69rfecT1JXiAfjrnW0J64COJDbQ9kVCNM1YrpqLBmROHMOOw9o7Qz"
"1h78LbtKarVkVGaPIxhdZsWKjZwDD-6h15NZuKTAmcPUaucx6Dd4uCjJHld1BNsfKfX_81G03g"
)
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
refresh_token = jwt_handler.get_refresh_token()

assert refresh_token == expected_refresh_token
assert refresh_token == EXPECTED_REFRESH_TOKEN


@patch("ldap_jwt_auth.auth.jwt_handler.datetime")
def test_refresh_access_token(datetime_mock):
"""
Test refreshing an expired access token with a valid refresh token.
"""
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
access_token = jwt_handler.refresh_access_token(EXPIRED_ACCESS_TOKEN, VALID_REFRESH_TOKEN)

assert access_token == EXPECTED_ACCESS_TOKEN


@patch("ldap_jwt_auth.auth.jwt_handler.datetime")
def test_refresh_access_token_with_valid_access_token(datetime_mock):
"""
Test refreshing a valid access token with a valid refresh token.
"""
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
access_token = jwt_handler.refresh_access_token(VALID_ACCESS_TOKEN, VALID_REFRESH_TOKEN)

assert access_token == EXPECTED_ACCESS_TOKEN


def test_refresh_access_token_with_invalid_access_token():
"""
Test refreshing an invalid access token with a valid refresh token.
"""
jwt_handler = JWTHandler()

with pytest.raises(JWTRefreshError) as exc:
jwt_handler.refresh_access_token("invalid", VALID_REFRESH_TOKEN)
assert str(exc.value) == "Unable to refresh access token"


def test_refresh_access_token_with_expired_refresh_token():
"""
Test refreshing an expired access token with an expired refresh token.
"""
jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.refresh_access_token(EXPIRED_ACCESS_TOKEN, EXPIRED_REFRESH_TOKEN)
assert str(exc.value) == "Invalid JWT token"


def test_verify_token_with_access_token():
"""
Test verifying a valid access token.
"""
access_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoyNTM0MDIzMDA3OTl9.bagU2Wix8wKzydVU"
"_L3ZZuuMAxGxV4OTuZq_kS2Fuwm839_8UZOkICnPTkkpvsm1je0AWJaIXLGgwEa5zUjpG6lTrMMmzR9Zi63F0NXpJqQqoOZpTBMYBaggsXqFkd"
"sv-yAKUZ8MfjCEyk3UZ4PXZmEcUZcLhKcXZr4kYJPjio2e5WOGpdjK6q7s-iHGs9DQFT_IoCnw9CkyOKwYdgpB35hIGHkNjiwVSHpyKbFQvzJm"
"Iv5XCTSRYqq01fldh-QYuZqZeuaFidKbLRH610o2-1IfPMUr-yPtj5PZ-AaX-XTLkuMqdVMCk0_jeW9Os2BPtyUDkpcu1fvW3_S6_dK3nQ"
)

jwt_handler = JWTHandler()
payload = jwt_handler.verify_token(access_token)
payload = jwt_handler.verify_token(VALID_ACCESS_TOKEN)

assert payload == {"username": "username", "exp": 253402300799}

Expand All @@ -77,15 +148,8 @@ def test_verify_token_with_refresh_token():
"""
Test verifying a valid refresh token.
"""
refresh_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI1MzQwMjMwMDc5OX0.h4Hv_sq4-ika1rpuRx7k3pp0cF_BZ65WVSbIHS7oh9SjP"
"pGHtGhVHU1IJXzFtyA9TH-68JpAZ24Dm6bXbH6VJKoc7RCbmJXm44ufN32ga7jDqXH340oKvi_wdhEHaCf2HXjzsHHD7_D6XIcxU71v2W5_j8V"
"uwpr3SdX6ea_yLIaCDWynN6FomPtUepQAOg3c7DdKohbJD8WhKIDV8UKuLtFdRBfN4HEK5nNs0JroROPhcYM9L_JIQZpdI0c83fDFuXQC-cAyg"
"zrSnGJ6O4DyScNL3VBNSmNTBtqYOs1szvkpvF9rICPgbEEJnbS6g5kmGld3eioeuDJIxeQglSbxog"
)

jwt_handler = JWTHandler()
payload = jwt_handler.verify_token(refresh_token)
payload = jwt_handler.verify_token(VALID_REFRESH_TOKEN)

assert payload == {"exp": 253402300799}

Expand All @@ -94,33 +158,30 @@ def test_verify_token_with_expired_access_token():
"""
Test verifying an expired access token.
"""
expired_access_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjotNjIxMzU1OTY4MDB9.G_cfC8PNYE5yERyy"
"QNRk9mTmDusU_rEPgm7feo2lWQF6QMNnf8PUN-61FfMNRVE0QDSvAmIMMNEOa8ma0JHZARafgnYJfn1_FSJSoRxC740GpG8EFSWrpM-dQXnoD2"
"63V9FlK-On6IbhF-4Rh9MdoxNyZk2Lj7NvCzJ7gbgbgYM5-sJXLxB-I5LfMfuYM3fx2cRixZFA153l46tFzcMVBrAiBxl_LdyxTIOPfHF0UGla"
"W2UtFi02gyBU4E4wTOqPc4t_CSi1oBSbY7h9O63i8IU99YsOCdvZ7AD3ePxyM1xJR7CFHycg9Z_IDouYnJmXpTpbFMMl7SjME3cVMfMrAQ"
)

jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.verify_token(expired_access_token)
jwt_handler.verify_token(EXPIRED_ACCESS_TOKEN)
assert str(exc.value) == "Invalid JWT token"


def test_verify_token_with_expired_refresh_token():
"""
Test verifying an expired refresh token.
"""
expired_refresh_tokenb = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOi02MjEzNTU5NjgwMH0.Er0A8dvdZi7o1FK3b-Te2IkUjDJZjI0aANsP7bbAbeITP"
"RnR0YEhavmuLT1zaoALQjUzfSgtH0s3I-YbUr2ssqG1DnKh83uts3J2_EXIXQZBeuZisCW1nN1LC2nsR6o4HQEsbMsINjJviHeMWS8nRC06XXp"
"N1WFPaGBxXkLFeDWb3SXiirZ79m7lUBwQvVzpfeA337e_AejG45mtadgfW3xpDCw-6sVVIA-cuzruxnjRKAzJrw_goA9X4MukRXbnzou2mgkxF"
"Ks_-6hdTFDI-B47wYqalP6KC5nqzjrCpvjmukgM-DN0uAhm2TUzUmE5EXtRLEYMRqsSmog4hYq1Nw"
)
jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.verify_token(EXPIRED_REFRESH_TOKEN)
assert str(exc.value) == "Invalid JWT token"


def test_verify_token_with_invalid_token():
"""
Test verifying an invalid access token.
"""
jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.verify_token(expired_refresh_tokenb)
jwt_handler.verify_token("invalid")
assert str(exc.value) == "Invalid JWT token"

0 comments on commit affb151

Please sign in to comment.