Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: session refresh loop if expired token is passed in headers #73

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.5.1] - 2024-06-13

### Changes

- Fixed session refresh loop caused by passing an expired access token in the Authorization header.

## [0.5.0] - 2024-06-06

### Changes
Expand Down
2 changes: 1 addition & 1 deletion app/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
apply plugin: 'com.android.library'
apply plugin: 'maven-publish'
def publishVersionID = "0.5.0"
def publishVersionID = "0.5.1"

android {
compileSdkVersion 32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,25 @@
public class SuperTokensCustomHttpURLConnection extends HttpURLConnection {
HttpURLConnection original;
Context applicationContext;
private boolean wasAuthHeaderRemovedInitially = false;

public SuperTokensCustomHttpURLConnection(HttpURLConnection original, Context applicationContext) {
super(original.getURL());
this.original = original;
this.applicationContext = applicationContext;
}

public SuperTokensCustomHttpURLConnection(HttpURLConnection original, Context applicationContext, boolean wasAuthHeaderRemovedInitially) {
super(original.getURL());
this.original = original;
this.applicationContext = applicationContext;
this.wasAuthHeaderRemovedInitially = wasAuthHeaderRemovedInitially;
}

public boolean getWasAuthHeaderRemovedInitially() {
return wasAuthHeaderRemovedInitially;
}

@Override
public void disconnect() {
original.disconnect();
Expand Down Expand Up @@ -249,9 +261,18 @@ public void setDefaultUseCaches(boolean defaultusecaches) {
}

private boolean shouldAllowSettingAuthHeader(String value) {
// This check ensures that if the authorization header was removed initially (because it matched the local access token),
// it remains removed in subsequent retries even after the session is refreshed.
// This prevents the use of an expired access token which would no longer match the updated local access token.
if (wasAuthHeaderRemovedInitially) {
return false;
}


String accessToken = Utils.getTokenForHeaderAuth(Utils.TokenType.ACCESS, applicationContext);
String refreshToken = Utils.getTokenForHeaderAuth(Utils.TokenType.REFRESH, applicationContext);
if (accessToken != null && refreshToken != null && value.equals("Bearer " + accessToken)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check with ignore case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

wasAuthHeaderRemovedInitially = true;
// We ignore the attempt to set the header because it matches the existing access token
// which will get added by the SDK
return false;
Expand All @@ -275,6 +296,14 @@ public void setRequestProperty(String key, String value, boolean force) {
original.setRequestProperty(key, value);
}

// Sets the authorization header without performing the "shouldAllowSettingAuthHeader" check.
// This bypass is necessary because the "shouldAllowSettingAuthHeader" function tracks whether
// setting the auth header was disallowed, which is only intended for custom headers set by the user,
// not for headers set by our library code.
public void setRequestPropertyIgnoringOverride(String key, String value) {
original.setRequestProperty(key, value);
}

public void setRequestProperty(String key, String value) {
setRequestProperty(key, value, false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.net.HttpURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -37,16 +36,21 @@ public class SuperTokensHttpURLConnection {
private static final ReentrantReadWriteLock refreshAPILock = new ReentrantReadWriteLock();

private static void setAuthorizationHeaderIfRequired(SuperTokensCustomHttpURLConnection connection, Context context) {
Map<String, String> headersToSet = Utils.getAuthorizationHeaderIfRequired(context);
for (Map.Entry<String, String> entry: headersToSet.entrySet()) {
connection.setRequestProperty(entry.getKey(), entry.getValue(), true);
String authHeader = Utils.getAuthorizationHeaderIfExists(false, context);

// NOTE: We do not check for existing Auth headers here because they are added after this function runs.
// The `setRequestProperty` method in SuperTokensCustomHttpURLConnection is overridden to prevent users from adding
// an auth header that matches the locally stored access token.
if (authHeader != null) {
connection.setRequestPropertyIgnoringOverride("Authorization", authHeader);
}
}

private static void setAuthorizationHeaderIfRequiredForRefresh(HttpURLConnection connection, Context context) {
Map<String, String> headersToSet = Utils.getAuthorizationHeaderIfRequired(true, context);
for (Map.Entry<String, String> entry: headersToSet.entrySet()) {
connection.setRequestProperty(entry.getKey(), entry.getValue());
String authHeader = Utils.getAuthorizationHeaderIfExists(true, context);
// NOTE: Checking for an existing auth header is not necessary for a refresh API call.
if (authHeader != null) {
connection.setRequestProperty("Authorization", authHeader);
}
}

Expand Down Expand Up @@ -130,16 +134,17 @@ public static HttpURLConnection newRequest(URL url, PreConnectCallback preConnec

try {
int sessionRefreshAttempts = 0;
HttpURLConnection connection;
SuperTokensCustomHttpURLConnection customConnection = null;
while (true) {
HttpURLConnection connection;
SuperTokensCustomHttpURLConnection customConnection;
Utils.LocalSessionState preRequestLocalSessionState;
int responseCode;
// TODO: write comment as to why we have this lock here. Do we also have this lock for iOS and website package?
refreshAPILock.readLock().lock();
try {
boolean wasAuthHeaderRemovedInitially = customConnection != null && customConnection.getWasAuthHeaderRemovedInitially();
connection = (HttpURLConnection) url.openConnection();
customConnection = new SuperTokensCustomHttpURLConnection(connection, applicationContext);
customConnection = new SuperTokensCustomHttpURLConnection(connection, applicationContext, wasAuthHeaderRemovedInitially);

// Add antiCSRF token, if present in storage, to the request headers
preRequestLocalSessionState = Utils.getLocalSessionState(applicationContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ public class SuperTokensInterceptor implements Interceptor {
private static final Object refreshTokenLock = new Object();
private static final ReentrantReadWriteLock refreshAPILock = new ReentrantReadWriteLock();

private Request removeAuthHeaderIfMatchesLocalToken(Request request, Request.Builder builder, Context context) {

// Returns true authorization header in the provided request matches the current local access token.
// This is used to decide whether the authorization header should be removed before making the request.
private boolean shouldRemoveAuthHeader(Request request, Context context) {
String originalHeader = request.header("Authorization");

if (originalHeader == null) {
Expand All @@ -46,18 +49,20 @@ private Request removeAuthHeaderIfMatchesLocalToken(Request request, Request.Bui
String refreshToken = Utils.getTokenForHeaderAuth(Utils.TokenType.REFRESH, context);

if (accessToken != null && refreshToken != null && originalHeader.equals("Bearer " + accessToken)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shuold use ignore case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

builder.removeHeader("Authorization");
builder.removeHeader("authorization");
return true;
}
}

return builder.build();
return false;
}

private static Request setAuthorizationHeaderIfRequired(Request.Builder builder, Context context, boolean addRefreshToken) {
Map<String, String> headersToSet = Utils.getAuthorizationHeaderIfRequired(addRefreshToken, context);
for (Map.Entry<String, String> entry: headersToSet.entrySet()) {
builder.header(entry.getKey(), entry.getValue());

private static Request setAuthorizationHeaderIfRequired(Request request, Request.Builder builder, Context context, boolean addRefreshToken) {
String authHeader = Utils.getAuthorizationHeaderIfExists(addRefreshToken, context);
boolean hasExistingAuthHeader = request.header("Authorization") != null || request.header("authorization") != null;

if (authHeader != null && !hasExistingAuthHeader) {
builder.header("Authorization", authHeader);
}

return builder.build();
Expand Down Expand Up @@ -98,6 +103,7 @@ public Response intercept(@NotNull Chain chain) throws IOException {
}

try {
boolean wasAuthHeaderRemovedInitially = false;
int sessionRefreshAttempts = 0;
while (true) {
Request.Builder requestBuilder = chain.request().newBuilder();
Expand All @@ -120,8 +126,19 @@ public Response intercept(@NotNull Chain chain) throws IOException {
request = request.newBuilder().header("rid", "anti-csrf").build();
}

request = removeAuthHeaderIfMatchesLocalToken(request, request.newBuilder(), applicationContext);
request = setAuthorizationHeaderIfRequired(request.newBuilder(), applicationContext, false);
// Check if the Authorization header should be removed
// This is necessary to ensure that if the auth header was removed initially,
// it remains removed in subsequent retries even if the token has changed.
if (wasAuthHeaderRemovedInitially || shouldRemoveAuthHeader(request, applicationContext)) {
Request.Builder builder = request.newBuilder();
builder.removeHeader("Authorization");
builder.removeHeader("authorization");
request = builder.build();

wasAuthHeaderRemovedInitially = true;
}

request = setAuthorizationHeaderIfRequired(request, request.newBuilder(), applicationContext, false);

response = makeRequest(chain, request);
Utils.saveTokenFromHeaders(response, applicationContext);
Expand Down Expand Up @@ -207,6 +224,8 @@ private static Utils.Unauthorised onUnauthorisedResponse(Utils.LocalSessionState
Request.Builder refreshRequestBuilder = new Request.Builder();
refreshRequestBuilder.url(SuperTokens.refreshTokenUrl);
refreshRequestBuilder.method("POST", new FormBody.Builder().build());

Request refreshRequest = refreshRequestBuilder.build();

if (preRequestLocalSessionState.status == Utils.LocalSessionStateStatus.EXISTS) {
String antiCSRFToken = AntiCSRF.getToken(applicationContext, preRequestLocalSessionState.lastAccessTokenUpdate);
Expand All @@ -220,7 +239,7 @@ private static Utils.Unauthorised onUnauthorisedResponse(Utils.LocalSessionState
refreshRequestBuilder.header("fdi-version", Utils.join(Version.supported_fdi, ","));
refreshRequestBuilder.header("st-auth-mode", SuperTokens.config.tokenTransferMethod);

refreshRequestBuilder = setAuthorizationHeaderIfRequired(refreshRequestBuilder, applicationContext, true).newBuilder();
refreshRequestBuilder = setAuthorizationHeaderIfRequired(refreshRequest, refreshRequestBuilder, applicationContext, true).newBuilder();

Map<String, String> customRefreshHeaders = SuperTokens.config.customHeaderMapper.getRequestHeaders(CustomHeaderProvider.RequestType.REFRESH);
if (customRefreshHeaders != null) {
Expand All @@ -229,7 +248,7 @@ private static Utils.Unauthorised onUnauthorisedResponse(Utils.LocalSessionState
}
}

Request refreshRequest = refreshRequestBuilder.build();
refreshRequest = refreshRequestBuilder.build();
refreshResponse = makeRequest(chain, refreshRequest);

Utils.saveTokenFromHeaders(refreshResponse, applicationContext);
Expand Down
21 changes: 6 additions & 15 deletions app/src/main/java/com/supertokens/session/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,6 @@ public static String getTokenForHeaderAuth(TokenType tokenType, Context context)
return getFromStorage(name, context);
}

public static Map<String, String> getAuthorizationHeaderIfRequired(Context context) {
return getAuthorizationHeaderIfRequired(false, context);
}

// Checks if a key exists in a map regardless of case
public static <T> T getIgnoreCase(Map<String, T> map, String key) {
for (Map.Entry<String, T> entry : map.entrySet()) {
Expand All @@ -332,33 +328,28 @@ public static <T> T getIgnoreCase(Map<String, T> map, String key) {
return null;
}

public static Map<String, String> getAuthorizationHeaderIfRequired(boolean addRefreshToken, Context context) {
// We set the Authorization header even if the tokenTransferMethod preference
public static String getAuthorizationHeaderIfExists(boolean addRefreshToken, Context context) {
// We return the Authorization header even if the tokenTransferMethod preference
// set in the config is cookies
// since the active session may be using cookies. By default, we want to allow
// users to continue these sessions.
// The new session preference should be applied at the start of the next
// session, if the backend allows it.
Map<String, String> headers = new HashMap<>();
String accessToken = getTokenForHeaderAuth(TokenType.ACCESS, context);
String refreshToken = getTokenForHeaderAuth(TokenType.REFRESH, context);

// We don't always need the refresh token because that's only required by the
// refresh call
// Still, we only add the Authorization header if both are present, because we
// Still, we only return the Authorization header if both are present, because we
// are planning to add an option to expose the
// access token to the frontend while using cookie based auth - so that users
// can get the access token to use
if (accessToken != null && refreshToken != null) {
if (getIgnoreCase(headers, "Authorization") != null) {
// no-op
} else {
String tokenToAdd = addRefreshToken ? refreshToken : accessToken;
headers.put("Authorization", "Bearer " + tokenToAdd);
}
String tokenToAdd = addRefreshToken ? refreshToken : accessToken;
return "Bearer " + tokenToAdd;
}

return headers;
return null;
}

public static void fireSessionUpdateEventsIfNecessary(
Expand Down
2 changes: 1 addition & 1 deletion examples/with-thirdparty/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencyResolutionManagement {
Add the following to your app level `build.gradle`

```gradle
implementation("com.github.supertokens:supertokens-android:0.5.0")
implementation("com.github.supertokens:supertokens-android:0.5.1")
implementation ("com.google.android.gms:play-services-auth:20.7.0")
implementation("com.squareup.retrofit2:retrofit:2.9.0")
implementation("net.openid:appauth:0.11.1")
Expand Down
2 changes: 1 addition & 1 deletion examples/with-thirdparty/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies {
implementation("androidx.appcompat:appcompat:1.6.1")
implementation("com.google.android.material:material:1.8.0")
implementation("androidx.constraintlayout:constraintlayout:2.1.4")
implementation("com.github.supertokens:supertokens-android:0.5.0")
implementation("com.github.supertokens:supertokens-android:0.5.1")
implementation ("com.google.android.gms:play-services-auth:20.7.0")
implementation("com.squareup.retrofit2:retrofit:2.9.0")
implementation("net.openid:appauth:0.11.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ public void doAction(HttpURLConnection con) throws IOException {
}

@Test
public void httpUrlConnection_testThatAuthHeaderIsNotIgnoredEvenIfItMatchesTheStoredAccessToken() throws Exception {
public void httpUrlConnection_testThatAuthHeaderIsNotIgnoredIfItDoesntMatchTheStoredAccessToken() throws Exception {
com.example.TestUtils.startST();
new SuperTokens.Builder(context, Constants.apiDomain).build();

Expand All @@ -931,9 +931,6 @@ public void doAction(HttpURLConnection con) throws IOException {

loginRequestConnection.disconnect();

Thread.sleep(5000);
Utils.setToken(Utils.TokenType.ACCESS, "myOwnHeHe", context);

HttpURLConnection connection = SuperTokensHttpURLConnection.newRequest(new URL(baseCustomAuthUrl), new SuperTokensHttpURLConnection.PreConnectCallback() {
@Override
public void doAction(HttpURLConnection con) throws IOException {
Expand Down Expand Up @@ -1104,4 +1101,63 @@ public void doAction(HttpURLConnection con) throws IOException {
throw new Exception("Expected session refresh endpoint to be called 0 times but it was called " + sessionRefreshCalledCount + " times");
}
}

@Test
public void httpUrlConnection_shouldNotEndUpInRefreshLoopIfExpiredAccessTokenWasPassedInHeaders() throws Exception{
com.example.TestUtils.startST(1, true, 144000);
new SuperTokens.Builder(context, Constants.apiDomain).build();

//login request
HttpURLConnection loginRequestConnection = SuperTokensHttpURLConnection.newRequest(new URL(loginAPIURL), new SuperTokensHttpURLConnection.PreConnectCallback() {
@Override
public void doAction(HttpURLConnection con) throws IOException {
con.setDoOutput(true);
con.setRequestMethod("POST");
con.setRequestProperty("Accept", "application/json");
con.setRequestProperty("Content-Type", "application/json");

JsonObject bodyJson = new JsonObject();
bodyJson.addProperty("userId", Constants.userId);

OutputStream outputStream = con.getOutputStream();
outputStream.write(bodyJson.toString().getBytes(StandardCharsets.UTF_8));
outputStream.close();
}
});

if (loginRequestConnection.getResponseCode() != 200) {
throw new Exception("Login request failed");
}

loginRequestConnection.disconnect();

String expiredAccessToken = Utils.getTokenForHeaderAuth(Utils.TokenType.ACCESS, context);

// wait for access token expiry
Thread.sleep(2000);

int sessionRefreshCalledCount = com.example.TestUtils.getRefreshTokenCounter();
if (sessionRefreshCalledCount != 0) {
throw new Exception("Expected session refresh endpoint to be called 0 times but it was called " + sessionRefreshCalledCount + " times");
}

HttpURLConnection userInfoRequestConnection = SuperTokensHttpURLConnection.newRequest(new URL(userInfoAPIURL), new SuperTokensHttpURLConnection.PreConnectCallback() {
@Override
public void doAction(HttpURLConnection con) throws IOException {
con.setRequestMethod("GET");
con.setRequestProperty("Authorization", "Bearer " + expiredAccessToken);
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
}
});

if (userInfoRequestConnection.getResponseCode() != 200) {
throw new Exception("userInfo api failed");
}

sessionRefreshCalledCount = com.example.TestUtils.getRefreshTokenCounter();
if (sessionRefreshCalledCount != 1) {
throw new Exception("Expected session refresh endpoint to be called 1 time but it was called " + sessionRefreshCalledCount + " times");
}

userInfoRequestConnection.disconnect();
}
}
Loading
Loading