From 415563156b810bc2e64701daa0d84b6f6d9adc73 Mon Sep 17 00:00:00 2001
From: Nemi Shah <nemishah1212@gmail.com>
Date: Fri, 29 Sep 2023 12:11:35 +0530
Subject: [PATCH] Refactor based on PR comments

---
 .../recipe/thirdparty/providers/custom.js     | 19 +++++++----------
 .../recipe/thirdparty/providers/google.js     |  7 -------
 lib/ts/recipe/thirdparty/providers/custom.ts  | 21 +++++++------------
 lib/ts/recipe/thirdparty/providers/google.ts  |  8 -------
 4 files changed, 15 insertions(+), 40 deletions(-)

diff --git a/lib/build/recipe/thirdparty/providers/custom.js b/lib/build/recipe/thirdparty/providers/custom.js
index 8d1ee53dc..3d83139fc 100644
--- a/lib/build/recipe/thirdparty/providers/custom.js
+++ b/lib/build/recipe/thirdparty/providers/custom.js
@@ -277,6 +277,13 @@ function NewProvider(input) {
                     });
                 }
             }
+            if (impl.config.validateAccessToken !== undefined && accessToken !== undefined) {
+                await impl.config.validateAccessToken({
+                    accessToken: accessToken,
+                    clientConfig: impl.config,
+                    userContext,
+                });
+            }
             if (accessToken && impl.config.userInfoEndpoint !== undefined) {
                 const headers = {
                     Authorization: "Bearer " + accessToken,
@@ -307,18 +314,6 @@ function NewProvider(input) {
                 );
                 rawUserInfoFromProvider.fromUserInfoAPI = userInfoFromAccessToken;
             }
-            /**
-             * This is intentionally not part of the above if block. This is because the user may want to validate the access
-             * token payload even if the user info API has not been provided by the provider. In this case they would get an
-             * empty object and they can fail if they always expect a non-empty object.
-             */
-            if (impl.config.validateAccessToken !== undefined) {
-                await impl.config.validateAccessToken({
-                    accessToken: accessToken,
-                    clientConfig: impl.config,
-                    userContext,
-                });
-            }
             const userInfoResult = getSupertokensUserInfoResultFromRawUserInfo(impl.config, rawUserInfoFromProvider);
             return {
                 thirdPartyUserId: userInfoResult.thirdPartyUserId,
diff --git a/lib/build/recipe/thirdparty/providers/google.js b/lib/build/recipe/thirdparty/providers/google.js
index 0b9ec3203..db1139473 100644
--- a/lib/build/recipe/thirdparty/providers/google.js
+++ b/lib/build/recipe/thirdparty/providers/google.js
@@ -17,13 +17,6 @@ function Google(input) {
         { included_grant_scopes: "true", access_type: "offline" },
         input.config.authorizationEndpointQueryParams
     );
-    // if (input.config.validateAccessToken === undefined) {
-    //     input.config.validateAccessToken = async ({ accessTokenPayload, clientConfig }) => {
-    //         if (accessTokenPayload.aud !== clientConfig.clientId) {
-    //             throw Error("accessTokenPayload.aud does not match clientId");
-    //         }
-    //     };
-    // }
     const oOverride = input.override;
     input.override = function (originalImplementation) {
         const oGetConfig = originalImplementation.getConfigForClientType;
diff --git a/lib/ts/recipe/thirdparty/providers/custom.ts b/lib/ts/recipe/thirdparty/providers/custom.ts
index b6e2ff849..7a9af3a65 100644
--- a/lib/ts/recipe/thirdparty/providers/custom.ts
+++ b/lib/ts/recipe/thirdparty/providers/custom.ts
@@ -305,6 +305,14 @@ export default function NewProvider(input: ProviderInput): TypeProvider {
                 }
             }
 
+            if (impl.config.validateAccessToken !== undefined && accessToken !== undefined) {
+                await impl.config.validateAccessToken({
+                    accessToken: accessToken,
+                    clientConfig: impl.config,
+                    userContext,
+                });
+            }
+
             if (accessToken && impl.config.userInfoEndpoint !== undefined) {
                 const headers: { [key: string]: string } = {
                     Authorization: "Bearer " + accessToken,
@@ -335,19 +343,6 @@ export default function NewProvider(input: ProviderInput): TypeProvider {
                 rawUserInfoFromProvider.fromUserInfoAPI = userInfoFromAccessToken;
             }
 
-            /**
-             * This is intentionally not part of the above if block. This is because the user may want to validate the access
-             * token payload even if the user info API has not been provided by the provider. In this case they would get an
-             * empty object and they can fail if they always expect a non-empty object.
-             */
-            if (impl.config.validateAccessToken !== undefined) {
-                await impl.config.validateAccessToken({
-                    accessToken: accessToken,
-                    clientConfig: impl.config,
-                    userContext,
-                });
-            }
-
             const userInfoResult = getSupertokensUserInfoResultFromRawUserInfo(impl.config, rawUserInfoFromProvider);
 
             return {
diff --git a/lib/ts/recipe/thirdparty/providers/google.ts b/lib/ts/recipe/thirdparty/providers/google.ts
index 0687f011b..378e8d067 100644
--- a/lib/ts/recipe/thirdparty/providers/google.ts
+++ b/lib/ts/recipe/thirdparty/providers/google.ts
@@ -30,14 +30,6 @@ export default function Google(input: ProviderInput): TypeProvider {
         ...input.config.authorizationEndpointQueryParams,
     };
 
-    // if (input.config.validateAccessToken === undefined) {
-    //     input.config.validateAccessToken = async ({ accessTokenPayload, clientConfig }) => {
-    //         if (accessTokenPayload.aud !== clientConfig.clientId) {
-    //             throw Error("accessTokenPayload.aud does not match clientId");
-    //         }
-    //     };
-    // }
-
     const oOverride = input.override;
 
     input.override = function (originalImplementation) {