Skip to content

Commit

Permalink
Issue #508 also refresh token on 401 TokenInvalid
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Nov 21, 2023
1 parent 8bf9114 commit bc9e852
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 35 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Follow the official band mapping from Awesome Spectral Indices better.
Allow manually specifying the desired band mapping.
([#485](https://github.com/Open-EO/openeo-python-client/issues/485), [#501](https://github.com/Open-EO/openeo-python-client/issues/501))
- Also attempt to automatically refresh OIDC access token on a `401 TokenInvalid` response (in addition to `403 TokenInvalid`) ([#508](https://github.com/Open-EO/openeo-python-client/issues/508))


### Changed

Expand Down
2 changes: 1 addition & 1 deletion openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def _request():
# Initial request attempt
return _request()
except OpenEoApiError as api_exc:
if api_exc.http_status_code == 403 and api_exc.code == "TokenInvalid":
if api_exc.http_status_code in {401, 403} and api_exc.code == "TokenInvalid":
# Auth token expired: can we refresh?
if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer:
msg = f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
Expand Down
103 changes: 69 additions & 34 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,15 +1836,15 @@ def test_authenticate_oidc_method_client_credentials_from_env(
assert conn.auth.bearer == f"oidc/{expected_provider_id}/" + oidc_mock.state["access_token"]


def _setup_get_me_handler(requests_mock, oidc_mock: OidcMock):
def _setup_get_me_handler(requests_mock, oidc_mock: OidcMock, token_invalid_status_code: int = 403):
def get_me(request: requests.Request, context):
"""handler for `GET /me` (with access_token checking)"""
auth_header = request.headers["Authorization"]
oidc_provider, access_token = re.match(r"Bearer oidc/(?P<p>\w+)/(?P<a>.*)", auth_header).group("p", "a")
try:
user_id = oidc_mock.validate_access_token(access_token)["user_id"]
except LookupError:
context.status_code = 403
context.status_code = token_invalid_status_code
return {"code": "TokenInvalid", "message": "Authorization token has expired or is invalid."}

return {
Expand All @@ -1857,12 +1857,16 @@ def get_me(request: requests.Request, context):
requests_mock.get(API_URL + "me", json=get_me)


@pytest.mark.parametrize(["invalidate"], [
(False,),
(True,),
])
@pytest.mark.parametrize(
["invalidate", "token_invalid_status_code"],
[
(False, 403),
(True, 403),
(True, 401),
],
)
def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token(
requests_mock, refresh_token_store, invalidate, caplog
requests_mock, refresh_token_store, invalidate, token_invalid_status_code, caplog
):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand All @@ -1888,7 +1892,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token
oidc_issuer=oidc_issuer,
expected_fields={"refresh_token": initial_refresh_token}
)
_setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock)
_setup_get_me_handler(
requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code
)
caplog.set_level(logging.INFO)

# Explicit authentication with `authenticate_oidc_refresh_token`
Expand Down Expand Up @@ -1922,12 +1928,12 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token
# Two "refresh_token" auth requests should have happened now
assert [h["grant_type"] for h in oidc_mock.grant_request_history] == ["refresh_token", "refresh_token"]
assert access_token2 != access_token1
assert "OIDC access token expired (403 TokenInvalid)" in caplog.text
assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text
assert "Obtained new access token (grant 'refresh_token')" in caplog.text
else:
assert [h["grant_type"] for h in oidc_mock.grant_request_history] == ["refresh_token"]
assert access_token2 == access_token1
assert "403 TokenInvalid" not in caplog.text
assert "TokenInvalid" not in caplog.text

assert get_me_response == {
"user_id": "john",
Expand All @@ -1936,12 +1942,16 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_refresh_token
}


@pytest.mark.parametrize(["invalidate"], [
(False,),
(True,),
])
@pytest.mark.parametrize(
["invalidate", "token_invalid_status_code"],
[
(False, 403),
(True, 403),
(True, 401),
],
)
def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code(
requests_mock, refresh_token_store, invalidate, caplog, oidc_device_code_flow_checker
requests_mock, refresh_token_store, invalidate, token_invalid_status_code, caplog, oidc_device_code_flow_checker
):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand Down Expand Up @@ -1971,7 +1981,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code(
},
scopes_supported=["openid"],
)
_setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock)
_setup_get_me_handler(
requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code
)
caplog.set_level(logging.INFO)

# Explicit authentication with `authenticate_oidc_device`
Expand Down Expand Up @@ -2013,7 +2025,7 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code(
"refresh_token",
]
assert (access_token2, refresh_token2) != (access_token1, refresh_token1)
assert "OIDC access token expired (403 TokenInvalid)" in caplog.text
assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text
assert "Obtained new access token (grant 'refresh_token')" in caplog.text
else:
assert [h["grant_type"] for h in oidc_mock.grant_request_history] == [
Expand All @@ -2029,8 +2041,15 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_device_code(
}


@pytest.mark.parametrize(
["token_invalid_status_code"],
[
(403,),
(401,),
],
)
def test_authenticate_oidc_auto_renew_expired_access_token_invalid_refresh_token(
requests_mock, refresh_token_store, caplog, oidc_device_code_flow_checker
requests_mock, refresh_token_store, caplog, oidc_device_code_flow_checker, token_invalid_status_code
):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand Down Expand Up @@ -2060,7 +2079,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_invalid_refresh_token
},
scopes_supported=["openid"],
)
_setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock)
_setup_get_me_handler(
requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code
)
caplog.set_level(logging.INFO)

# Explicit authentication with `authenticate_oidc_device`
Expand Down Expand Up @@ -2090,13 +2111,14 @@ def test_authenticate_oidc_auto_renew_expired_access_token_invalid_refresh_token
oidc_mock.expected_fields["refresh_token"] = "sorry-not-accepting-refresh-tokens-now"
oidc_mock.expected_grant_type = "refresh_token"
# Do request that requires auth headers triggers attempt to re-authenticate (which will fail)
assert "403 TokenInvalid" not in caplog.text
assert "TokenInvalid" not in caplog.text
with pytest.raises(
OpenEoApiError, match=re.escape("[403] TokenInvalid: Authorization token has expired or is invalid.")
OpenEoApiError,
match=re.escape(f"[{token_invalid_status_code}] TokenInvalid: Authorization token has expired or is invalid."),
):
conn.describe_account()

assert "OIDC access token expired (403 TokenInvalid)" in caplog.text
assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text
assert "Failed to obtain new access token (grant 'refresh_token')" in caplog.text


Expand Down Expand Up @@ -2151,14 +2173,15 @@ def get_me(request: requests.Request, context):


@pytest.mark.parametrize(
["invalidate"],
["invalidate", "token_invalid_status_code"],
[
(False,),
(True,),
(False, 403),
(True, 403),
(True, 401),
],
)
def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_credentials(
requests_mock, refresh_token_store, invalidate, caplog
requests_mock, refresh_token_store, invalidate, token_invalid_status_code, caplog
):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand All @@ -2176,7 +2199,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden
oidc_issuer=issuer,
)

_setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock)
_setup_get_me_handler(
requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code
)
caplog.set_level(logging.INFO)

# Explicit authentication with `authenticate_oidc_refresh_token`
Expand Down Expand Up @@ -2211,12 +2236,12 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden
"client_credentials",
]
assert access_token2 != access_token1
assert "OIDC access token expired (403 TokenInvalid)" in caplog.text
assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text
assert "Obtained new access token (grant 'client_credentials')" in caplog.text
else:
assert [h["grant_type"] for h in oidc_mock.grant_request_history] == ["client_credentials"]
assert access_token2 == access_token1
assert "403 TokenInvalid" not in caplog.text
assert "TokenInvalid" not in caplog.text

assert get_me_response == {
"user_id": "john",
Expand All @@ -2225,8 +2250,15 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden
}


@pytest.mark.parametrize(
["token_invalid_status_code"],
[
(403,),
(401,),
],
)
def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_credentials_blocked(
requests_mock, refresh_token_store, caplog
requests_mock, refresh_token_store, caplog, token_invalid_status_code
):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand All @@ -2244,7 +2276,9 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden
oidc_issuer=issuer,
)

_setup_get_me_handler(requests_mock=requests_mock, oidc_mock=oidc_mock)
_setup_get_me_handler(
requests_mock=requests_mock, oidc_mock=oidc_mock, token_invalid_status_code=token_invalid_status_code
)
caplog.set_level(logging.INFO)

# Explicit authentication with `authenticate_oidc_refresh_token`
Expand All @@ -2268,13 +2302,14 @@ def test_authenticate_oidc_auto_renew_expired_access_token_initial_client_creden
oidc_mock.invalidate_access_token()
requests_mock.post(oidc_mock.token_endpoint, status_code=401, text="nope")
# Do request that requires auth headers and might trigger re-authentication
assert "403 TokenInvalid" not in caplog.text
assert f"{token_invalid_status_code} TokenInvalid" not in caplog.text
with pytest.raises(
OpenEoApiError, match=re.escape("[403] TokenInvalid: Authorization token has expired or is invalid.")
OpenEoApiError,
match=re.escape(f"[{token_invalid_status_code}] TokenInvalid: Authorization token has expired or is invalid."),
):
conn.describe_account()

assert "OIDC access token expired (403 TokenInvalid)" in caplog.text
assert f"OIDC access token expired ({token_invalid_status_code} TokenInvalid)" in caplog.text
assert "Failed to obtain new access token (grant 'client_credentials')" in caplog.text


Expand Down

0 comments on commit bc9e852

Please sign in to comment.