Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [11.2.1]

- Fixes race issues with Refreshing OAuth token

## [11.2.0]

- Adds opentelemetry-javaagent to the core distribution
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ java {
}
}

version = "11.2.0"
version = "11.2.1"

repositories {
mavenCentral()
Expand Down
98 changes: 67 additions & 31 deletions src/main/java/io/supertokens/webserver/api/oauth/OAuthTokenAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
import java.security.spec.InvalidKeySpecException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class OAuthTokenAPI extends WebserverAPI {

Expand Down Expand Up @@ -90,6 +92,32 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I

String authorizationHeader = InputParser.parseStringOrThrowError(input, "authorizationHeader", true);

if (grantType.equals("refresh_token")) {
String refreshTokenForLock = InputParser.parseStringOrThrowError(bodyFromSDK, "refresh_token", false);
NamedLockObject entry = lockMap.computeIfAbsent(refreshTokenForLock, k -> new NamedLockObject());
try {
entry.refCount.incrementAndGet();
synchronized (entry.obj) {
handle(req, resp, authorizationHeader, bodyFromSDK, grantType, iss, accessTokenUpdate,
idTokenUpdate,
useDynamicKey);
}
} finally {
entry.refCount.decrementAndGet();
if (entry.refCount.get() == 0) {
lockMap.remove(refreshTokenForLock, entry);
}
}

} else {
handle(req, resp, authorizationHeader, bodyFromSDK, grantType, iss, accessTokenUpdate, idTokenUpdate,
useDynamicKey);
}
}

private void handle(HttpServletRequest req, HttpServletResponse resp, String authorizationHeader,
JsonObject bodyFromSDK, String grantType, String iss, JsonObject accessTokenUpdate,
JsonObject idTokenUpdate, boolean useDynamicKey) throws ServletException, IOException {
Map<String, String> headers = new HashMap<>();
if (authorizationHeader != null) {
headers.put("Authorization", authorizationHeader);
Expand Down Expand Up @@ -127,34 +155,34 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
formFieldsForTokenIntrospect.put("token", internalRefreshToken);

HttpRequestForOAuthProvider.Response response = OAuthProxyHelper.proxyFormPOST(
main, req, resp,
appIdentifier,
storage,
null, // clientIdToCheck
"/admin/oauth2/introspect", // pathProxy
true, // proxyToAdmin
false, // camelToSnakeCaseConversion
formFieldsForTokenIntrospect,
new HashMap<>() // headers
main, req, resp,
appIdentifier,
storage,
null, // clientIdToCheck
"/admin/oauth2/introspect", // pathProxy
true, // proxyToAdmin
false, // camelToSnakeCaseConversion
formFieldsForTokenIntrospect,
new HashMap<>() // headers
);

if (response == null) {
return; // proxy helper would have sent the error response
return;
}

JsonObject refreshTokenPayload = response.jsonResponse.getAsJsonObject();

try {
OAuth.verifyAndUpdateIntrospectRefreshTokenPayload(main, appIdentifier, storage, refreshTokenPayload, refreshToken, oauthClient.clientId);
} catch (StorageQueryException | TenantOrAppNotFoundException |
FeatureNotEnabledException | InvalidConfigException e) {
FeatureNotEnabledException | InvalidConfigException e) {
throw new ServletException(e);
}

if (!refreshTokenPayload.get("active").getAsBoolean()) {
// this is what ory would return for an invalid token
OAuthProxyHelper.handleOAuthAPIException(resp, new OAuthAPIException(
"token_inactive", "Token is inactive because it is malformed, expired or otherwise invalid. Token validation failed.", 401
"token_inactive", "Token is inactive because it is malformed, expired or otherwise invalid. Token validation failed.", 401
));
return;
}
Expand All @@ -163,20 +191,21 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
}

HttpRequestForOAuthProvider.Response response = OAuthProxyHelper.proxyFormPOST(
main, req, resp,
getAppIdentifier(req),
enforcePublicTenantAndGetPublicTenantStorage(req),
clientId, // clientIdToCheck
"/oauth2/token", // proxyPath
false, // proxyToAdmin
false, // camelToSnakeCaseConversion
formFields,
headers // headers
main, req, resp,
getAppIdentifier(req),
enforcePublicTenantAndGetPublicTenantStorage(req),
clientId, // clientIdToCheck
"/oauth2/token", // proxyPath
false, // proxyToAdmin
false, // camelToSnakeCaseConversion
formFields,
headers // headers
);

if (response != null) {
try {
response.jsonResponse = OAuth.transformTokens(super.main, appIdentifier, storage, response.jsonResponse.getAsJsonObject(), iss, accessTokenUpdate, idTokenUpdate, useDynamicKey);
response.jsonResponse = OAuth.transformTokens(super.main, appIdentifier, storage, response.jsonResponse.getAsJsonObject(),
iss, accessTokenUpdate, idTokenUpdate, useDynamicKey);

if (grantType.equals("client_credentials")) {
try {
Expand Down Expand Up @@ -215,15 +244,15 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
formFieldsForTokenIntrospect.put("token", newRefreshToken);

HttpRequestForOAuthProvider.Response introspectResponse = OAuthProxyHelper.proxyFormPOST(
main, req, resp,
getAppIdentifier(req),
enforcePublicTenantAndGetPublicTenantStorage(req),
null, // clientIdToCheck
"/admin/oauth2/introspect", // pathProxy
true, // proxyToAdmin
false, // camelToSnakeCaseConversion
formFieldsForTokenIntrospect,
new HashMap<>() // headers
main, req, resp,
getAppIdentifier(req),
enforcePublicTenantAndGetPublicTenantStorage(req),
null, // clientIdToCheck
"/admin/oauth2/introspect", // pathProxy
true, // proxyToAdmin
false, // camelToSnakeCaseConversion
formFieldsForTokenIntrospect,
new HashMap<>() // headers
);

if (introspectResponse != null) {
Expand Down Expand Up @@ -288,4 +317,11 @@ private void updateLastActive(AppIdentifier appIdentifier, String sessionHandle)
// ignore
}
}


private static class NamedLockObject {
final Object obj = new Object();
final AtomicInteger refCount = new AtomicInteger(0);
}
private static final ConcurrentHashMap<String, NamedLockObject> lockMap = new ConcurrentHashMap<>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.*;

Expand Down Expand Up @@ -399,6 +403,75 @@ public void testRefreshTokenWithRotationIsDisabledAfter() throws Exception {
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED));
}

@Test
public void testParallelRefreshTokenWithoutRotation() throws Exception {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails without the fix, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

String[] args = {"../"};

TestingProcessManager.TestingProcess process = TestingProcessManager.start(args);
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED));

if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) {
return;
}

FeatureFlag.getInstance(process.getProcess())
.setLicenseKeyAndSyncFeatures(TotpLicenseTest.OPAQUE_KEY_WITH_MFA_FEATURE);
FeatureFlagTestContent.getInstance(process.getProcess())
.setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.OAUTH});

if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) {
return;
}

// Create client with token rotation disabled
JsonObject client = createClient(process.getProcess(), false);
JsonObject tokens = completeFlowAndGetTokens(process.getProcess(), client);

String refreshToken = tokens.get("refresh_token").getAsString();

// Setup parallel execution: 16 threads, each making 1000 refresh calls
int numberOfThreads = 16;
int refreshCallsPerThread = 25;
ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
AtomicInteger successCount = new AtomicInteger(0);
AtomicInteger failureCount = new AtomicInteger(0);
List<Exception> exceptions = Collections.synchronizedList(new ArrayList<>());

// Execute refresh token calls in parallel
for (int i = 0; i < numberOfThreads; i++) {
executor.execute(() -> {
for (int j = 0; j < refreshCallsPerThread; j++) {
try {
JsonObject refreshResponse = refreshToken(process.getProcess(), client, refreshToken);
if ("OK".equals(refreshResponse.get("status").getAsString())) {
successCount.incrementAndGet();
} else {
failureCount.incrementAndGet();
exceptions.add(new RuntimeException("Refresh failed: " + refreshResponse.toString()));
}
} catch (Exception e) {
System.out.println(e.getMessage());
failureCount.incrementAndGet();
exceptions.add(e);
}
}
});
}

executor.shutdown();
boolean terminated = executor.awaitTermination(5, TimeUnit.MINUTES);
assertTrue("Executor did not terminate within timeout", terminated);

// Verify all refresh calls succeeded
int totalExpectedCalls = numberOfThreads * refreshCallsPerThread;
assertEquals("All refresh token calls should succeed", totalExpectedCalls, successCount.get());
assertEquals("No refresh token calls should fail", 0, failureCount.get());
assertTrue("No exceptions should occur", exceptions.isEmpty());

process.kill();
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED));
}

private static Map<String, String> splitQuery(URL url) throws UnsupportedEncodingException {
Map<String, String> queryPairs = new LinkedHashMap<>();
String query = url.getQuery();
Expand Down
Loading