Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add validate_access_token function to providers #453

Merged
merged 10 commits into from
Oct 5, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

IamMayankThakur marked this conversation as resolved.
Show resolved Hide resolved
## [0.16.4] - 2023-10-05

- Add `validate_access_token` function to providers
- This can be used to verify the access token received from providers.
- Implemented `validate_access_token` for the Github provider.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse:
user_info_map=user_info_map,
require_email=p.get("requireEmail", True),
validate_id_token_payload=None,
validate_access_token=None,
generate_fake_email=None,
validate_access_token=None,
)
)

Expand Down
24 changes: 12 additions & 12 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ def __init__(
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
):
self.third_party_id = third_party_id
self.name = name
Expand All @@ -197,8 +197,8 @@ def __init__(
self.user_info_map = user_info_map
self.require_email = require_email
self.validate_id_token_payload = validate_id_token_payload
self.validate_access_token = validate_access_token
self.generate_fake_email = generate_fake_email
self.validate_access_token = validate_access_token

def to_json(self) -> Dict[str, Any]:
res = {
Expand Down Expand Up @@ -254,15 +254,15 @@ def __init__(
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
):
ProviderClientConfig.__init__(
self,
Expand All @@ -289,8 +289,8 @@ def __init__(
user_info_map,
require_email,
validate_id_token_payload,
validate_access_token,
generate_fake_email,
validate_access_token,
)

def to_json(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -324,15 +324,15 @@ def __init__(
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
):
super().__init__(
third_party_id,
Expand All @@ -349,8 +349,8 @@ def __init__(
user_info_map,
require_email,
validate_id_token_payload,
validate_access_token,
generate_fake_email,
validate_access_token,
)
self.clients = clients

Expand Down
14 changes: 7 additions & 7 deletions supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def get_provider_config_for_client(
user_info_map=config.user_info_map,
require_email=config.require_email,
validate_id_token_payload=config.validate_id_token_payload,
validate_access_token=config.validate_access_token,
generate_fake_email=config.generate_fake_email,
validate_access_token=config.validate_access_token,
)


Expand Down Expand Up @@ -403,7 +403,12 @@ async def get_user_info(
user_context,
)

if access_token is not None and self.config.token_endpoint is not None:
if self.config.validate_access_token is not None and access_token is not None:
await self.config.validate_access_token(
access_token, self.config, user_context
)

if access_token is not None and self.config.user_info_endpoint is not None:
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
query_params: Dict[str, str] = {}

Expand All @@ -422,11 +427,6 @@ async def get_user_info(
self.config.user_info_endpoint, query_params, headers
)

if self.config.validate_access_token is not None and access_token is not None:
await self.config.validate_access_token(
access_token, self.config, user_context
)

user_info_result = get_supertokens_user_info_result_from_raw_user_info(
self.config, raw_user_info_from_provider
)
Expand Down
19 changes: 8 additions & 11 deletions supertokens_python/recipe/thirdparty/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from __future__ import annotations

import base64
import json
from typing import Any, Dict, List, Optional

import requests

from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
from supertokens_python.recipe.thirdparty.providers.utils import (
do_get_request,
do_post_request,
)
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail

from .custom import GenericProvider, NewProvider
Expand Down Expand Up @@ -95,14 +95,11 @@ async def validate_access_token(
"Authorization": f"Basic {basic_auth_token}",
"Content-Type": "application/json",
}
payload = json.dumps({"access_token": access_token})

resp = requests.post(url, headers=headers, data=payload)

if resp.status_code != 200:
try:
body = await do_post_request(url, {"access_token": access_token}, headers)
IamMayankThakur marked this conversation as resolved.
Show resolved Hide resolved
except Exception:
raise ValueError("Invalid access token")

body = resp.json()

if "app" not in body or body["app"]["client_id"] != config.client_id:
if "app" not in body or body["app"].get("client_id") != config.client_id:
raise ValueError("Access token does not belong to your application")
131 changes: 67 additions & 64 deletions tests/thirdparty/test_thirdparty.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import respx
from fastapi import FastAPI
from pytest import fixture, mark
from pytest_mock import MockerFixture
from starlette.testclient import TestClient

from supertokens_python import init
Expand Down Expand Up @@ -106,18 +107,6 @@ async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-a
}


async def get_user_info( # pylint: disable=unused-argument
oauth_tokens: Dict[str, Any],
user_context: Dict[str, Any],
) -> UserInfo:
time = str(datetime.datetime.now())
return UserInfo(
"" + time,
UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True),
RawUserInfoFromProvider({}, {}),
)


async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument
redirect_uri_info: RedirectUriInfo,
user_context: Dict[str, Any],
Expand All @@ -139,7 +128,6 @@ def get_custom_valid_token_provider(provider: Provider) -> Provider:
provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_valid_oauth_tokens
)
provider.get_user_info = get_user_info
return provider


Expand All @@ -153,7 +141,9 @@ async def invalid_access_token( # pylint: disable=unused-argument


async def valid_access_token( # pylint: disable=unused-argument
access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]]
access_token: str,
config: ProviderConfigForClient,
user_context: Optional[Dict[str, Any]],
):
if access_token == "accesstoken":
IamMayankThakur marked this conversation as resolved.
Show resolved Hide resolved
return
Expand Down Expand Up @@ -210,53 +200,66 @@ async def test_signinup_when_validate_access_token_throws(fastapi_client: TestCl
assert res.status_code == 500


# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient):
# st_init_args = {
# **st_init_common_args,
# "recipe_list": [
# session.init(),
# thirdpartyemailpassword.init(
# providers=[
# ProviderInput(
# config=ProviderConfig(
# third_party_id="custom",
# clients=[
# ProviderClientConfig(
# client_id="test",
# client_secret="test-secret",
# scope=["profile", "email"],
# ),
# ],
# authorization_endpoint="https://example.com/oauth/authorize",
# validate_access_token=valid_access_token,
# authorization_endpoint_query_params={
# "response_type": "token", # Changing an existing parameter
# "response_mode": "form", # Adding a new parameter
# "scope": None, # Removing a parameter
# },
# token_endpoint="https://example.com/oauth/token",
# ),
# override=get_custom_valid_token_provider
# )
# ]
# ),
# ],
# }
#
# init(**st_init_args) # type: ignore
# start_st()
#
# res = fastapi_client.post(
# "/auth/signinup",
# json={
# "thirdPartyId": "custom",
# "redirectURIInfo": {
# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
# "redirectURIQueryParams": {
# "code": "abcdefghj",
# },
# },
# }
# )
# assert res.status_code == 200
# assert res.json()["status"] == "OK"
async def test_signinup_works_when_validate_access_token_does_not_throw(
fastapi_client: TestClient, mocker: MockerFixture
):
time = str(datetime.datetime.now())
mocker.patch(
"supertokens_python.recipe.thirdparty.providers.custom.get_supertokens_user_info_result_from_raw_user_info",
return_value=UserInfo(
"" + time,
UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True),
RawUserInfoFromProvider({}, {}),
),
)

st_init_args = {
**st_init_common_args,
"recipe_list": [
session.init(),
thirdpartyemailpassword.init(
providers=[
ProviderInput(
config=ProviderConfig(
third_party_id="custom",
clients=[
ProviderClientConfig(
client_id="test",
client_secret="test-secret",
scope=["profile", "email"],
),
],
authorization_endpoint="https://example.com/oauth/authorize",
validate_access_token=valid_access_token,
authorization_endpoint_query_params={
"response_type": "token", # Changing an existing parameter
"response_mode": "form", # Adding a new parameter
"scope": None, # Removing a parameter
},
token_endpoint="https://example.com/oauth/token",
),
override=get_custom_valid_token_provider,
)
]
),
],
}

init(**st_init_args) # type: ignore
start_st()

res = fastapi_client.post(
"/auth/signinup",
json={
"thirdPartyId": "custom",
"redirectURIInfo": {
"redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
"redirectURIQueryParams": {
"code": "abcdefghj",
},
},
},
)

assert res.status_code == 200
assert res.json()["status"] == "OK"