From bc9e85252aa864db957f4421b6f133d65f4d90f6 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Tue, 21 Nov 2023 17:30:27 +0100 Subject: [PATCH] Issue #508 also refresh token on `401 TokenInvalid` --- CHANGELOG.md | 2 + openeo/rest/connection.py | 2 +- tests/rest/test_connection.py | 103 +++++++++++++++++++++++----------- 3 files changed, 72 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bb1494fb..c9d276189 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Follow the official band mapping from Awesome Spectral Indices better. Allow manually specifying the desired band mapping. ([#485](https://github.com/Open-EO/openeo-python-client/issues/485), [#501](https://github.com/Open-EO/openeo-python-client/issues/501)) +- Also attempt to automatically refresh OIDC access token on a `401 TokenInvalid` response (in addition to `403 TokenInvalid`) ([#508](https://github.com/Open-EO/openeo-python-client/issues/508)) + ### Changed diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 2da1bd84b..a4bef517b 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -756,7 +756,7 @@ def _request(): # Initial request attempt return _request() except OpenEoApiError as api_exc: - if api_exc.http_status_code == 403 and api_exc.code == "TokenInvalid": + if api_exc.http_status_code in {401, 403} and api_exc.code == "TokenInvalid": # Auth token expired: can we refresh? if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer: msg = f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})." diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 922a2e6db..fc0c4bdcb 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -1836,7 +1836,7 @@ def test_authenticate_oidc_method_client_credentials_from_env( assert conn.auth.bearer == f"oidc/{expected_provider_id}/" + oidc_mock.state["access_token"] -def _setup_get_me_handler(requests_mock, oidc_mock: OidcMock): +def _setup_get_me_handler(requests_mock, oidc_mock: OidcMock, token_invalid_status_code: int = 403): def get_me(request: requests.Request, context): """handler for `GET /me` (with access_token checking)""" auth_header = request.headers["Authorization"] @@ -1844,7 +1844,7 @@ def get_me(request: requests.Request, context): try: user_id = oidc_mock.validate_access_token(access_token)["user_id"] except LookupError: - context.status_code = 403 + context.status_code = token_invalid_status_code return {"code": "TokenInvalid", "message": "Authorization token has expired or is invalid."} return { @@ -1857,12 +1857,16 @@ def get_me(request: requests.Request, context): requests_mock.get(API_URL + "me", json=get_me) -@pytest.mark.parametrize(["invalidate"], [ - (False,), - (True,), -]) +@pytest.mark.parametrize( + ["invalidate", "token_invalid_status_code"], + [ + (False, 403), + (True, 403), + (True, 401), + ], +) def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token( - requests_mock, refresh_token_store, invalidate, caplog + requests_mock, refresh_token_store, invalidate, token_invalid_status_code, caplog ): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" @@ -1888,7 +1892,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token oidc_issuer=oidc_issuer, expected_fields={"refresh_token": initial_refresh_token} ) - _setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock) + _setup_get_me_handler( + requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code + ) caplog.set_level(logging.INFO) # Explicit authentication with `authenticate_oidc_refresh_token` @@ -1922,12 +1928,12 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token # Two "refresh_token" auth requests should have happened now assert [h["grant_type"] for h in oidc_mock.grant_request_history] == ["refresh_token", "refresh_token"] assert access_token2 != access_token1 - assert "OIDC access token expired (403 TokenInvalid)" in caplog.text + assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text assert "Obtained new access token (grant 'refresh_token')" in caplog.text else: assert [h["grant_type"] for h in oidc_mock.grant_request_history] == ["refresh_token"] assert access_token2 == access_token1 - assert "403 TokenInvalid" not in caplog.text + assert "TokenInvalid" not in caplog.text assert get_me_response == { "user_id": "john", @@ -1936,12 +1942,16 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token } -@pytest.mark.parametrize(["invalidate"], [ - (False,), - (True,), -]) +@pytest.mark.parametrize( + ["invalidate", "token_invalid_status_code"], + [ + (False, 403), + (True, 403), + (True, 401), + ], +) def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code( - requests_mock, refresh_token_store, invalidate, caplog, oidc_device_code_flow_checker + requests_mock, refresh_token_store, invalidate, token_invalid_status_code, caplog, oidc_device_code_flow_checker ): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" @@ -1971,7 +1981,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code( }, scopes_supported=["openid"], ) - _setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock) + _setup_get_me_handler( + requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code + ) caplog.set_level(logging.INFO) # Explicit authentication with `authenticate_oidc_device` @@ -2013,7 +2025,7 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code( "refresh_token", ] assert (access_token2, refresh_token2) != (access_token1, refresh_token1) - assert "OIDC access token expired (403 TokenInvalid)" in caplog.text + assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text assert "Obtained new access token (grant 'refresh_token')" in caplog.text else: assert [h["grant_type"] for h in oidc_mock.grant_request_history] == [ @@ -2029,8 +2041,15 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code( } +@pytest.mark.parametrize( + ["token_invalid_status_code"], + [ + (403,), + (401,), + ], +) def test_authenticate_oidc_auto_renew_expired_access_token_invalid_refresh_token( - requests_mock, refresh_token_store, caplog, oidc_device_code_flow_checker + requests_mock, refresh_token_store, caplog, oidc_device_code_flow_checker, token_invalid_status_code ): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" @@ -2060,7 +2079,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_invalid_refresh_token }, scopes_supported=["openid"], ) - _setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock) + _setup_get_me_handler( + requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code + ) caplog.set_level(logging.INFO) # Explicit authentication with `authenticate_oidc_device` @@ -2090,13 +2111,14 @@ def test_authenticate_oidc_auto_renew_expired_access_token_invalid_refresh_token oidc_mock.expected_fields["refresh_token"] = "sorry-not-accepting-refresh-tokens-now" oidc_mock.expected_grant_type = "refresh_token" # Do request that requires auth headers triggers attempt to re-authenticate (which will fail) - assert "403 TokenInvalid" not in caplog.text + assert "TokenInvalid" not in caplog.text with pytest.raises( - OpenEoApiError, match=re.escape("[403] TokenInvalid: Authorization token has expired or is invalid.") + OpenEoApiError, + match=re.escape(f"[{token_invalid_status_code}] TokenInvalid: Authorization token has expired or is invalid."), ): conn.describe_account() - assert "OIDC access token expired (403 TokenInvalid)" in caplog.text + assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text assert "Failed to obtain new access token (grant 'refresh_token')" in caplog.text @@ -2151,14 +2173,15 @@ def get_me(request: requests.Request, context): @pytest.mark.parametrize( - ["invalidate"], + ["invalidate", "token_invalid_status_code"], [ - (False,), - (True,), + (False, 403), + (True, 403), + (True, 401), ], ) def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_credentials( - requests_mock, refresh_token_store, invalidate, caplog + requests_mock, refresh_token_store, invalidate, token_invalid_status_code, caplog ): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" @@ -2176,7 +2199,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden oidc_issuer=issuer, ) - _setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock) + _setup_get_me_handler( + requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code + ) caplog.set_level(logging.INFO) # Explicit authentication with `authenticate_oidc_refresh_token` @@ -2211,12 +2236,12 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden "client_credentials", ] assert access_token2 != access_token1 - assert "OIDC access token expired (403 TokenInvalid)" in caplog.text + assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text assert "Obtained new access token (grant 'client_credentials')" in caplog.text else: assert [h["grant_type"] for h in oidc_mock.grant_request_history] == ["client_credentials"] assert access_token2 == access_token1 - assert "403 TokenInvalid" not in caplog.text + assert "TokenInvalid" not in caplog.text assert get_me_response == { "user_id": "john", @@ -2225,8 +2250,15 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden } +@pytest.mark.parametrize( + ["token_invalid_status_code"], + [ + (403,), + (401,), + ], +) def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_credentials_blocked( - requests_mock, refresh_token_store, caplog + requests_mock, refresh_token_store, caplog, token_invalid_status_code ): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) client_id = "myclient" @@ -2244,7 +2276,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden oidc_issuer=issuer, ) - _setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock) + _setup_get_me_handler( + requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code + ) caplog.set_level(logging.INFO) # Explicit authentication with `authenticate_oidc_refresh_token` @@ -2268,13 +2302,14 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden oidc_mock.invalidate_access_token() requests_mock.post(oidc_mock.token_endpoint, status_code=401, text="nope") # Do request that requires auth headers and might trigger re-authentication - assert "403 TokenInvalid" not in caplog.text + assert f"{token_invalid_status_code} TokenInvalid" not in caplog.text with pytest.raises( - OpenEoApiError, match=re.escape("[403] TokenInvalid: Authorization token has expired or is invalid.") + OpenEoApiError, + match=re.escape(f"[{token_invalid_status_code}] TokenInvalid: Authorization token has expired or is invalid."), ): conn.describe_account() - assert "OIDC access token expired (403 TokenInvalid)" in caplog.text + assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text assert "Failed to obtain new access token (grant 'client_credentials')" in caplog.text