Skip to content

Commit 1b1b2b2

Browse files
authored
fix: race issue with oauth refresh (#1199)
* fix: race issue with oauth refresh * fix: review comment * fix: remove print
1 parent e2ce582 commit 1b1b2b2

File tree

4 files changed

+145
-32
lines changed

4 files changed

+145
-32
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
99

10+
## [11.2.1]
11+
12+
- Fixes race issues with Refreshing OAuth token
13+
1014
## [11.2.0]
1115

1216
- Adds opentelemetry-javaagent to the core distribution

build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ java {
2727
}
2828
}
2929

30-
version = "11.2.0"
30+
version = "11.2.1"
3131

3232
repositories {
3333
mavenCentral()

src/main/java/io/supertokens/webserver/api/oauth/OAuthTokenAPI.java

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
import java.security.spec.InvalidKeySpecException;
6363
import java.util.HashMap;
6464
import java.util.Map;
65+
import java.util.concurrent.ConcurrentHashMap;
66+
import java.util.concurrent.atomic.AtomicInteger;
6567

6668
public class OAuthTokenAPI extends WebserverAPI {
6769

@@ -90,6 +92,32 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
9092

9193
String authorizationHeader = InputParser.parseStringOrThrowError(input, "authorizationHeader", true);
9294

95+
if (grantType.equals("refresh_token")) {
96+
String refreshTokenForLock = InputParser.parseStringOrThrowError(bodyFromSDK, "refresh_token", false);
97+
NamedLockObject entry = lockMap.computeIfAbsent(refreshTokenForLock, k -> new NamedLockObject());
98+
try {
99+
entry.refCount.incrementAndGet();
100+
synchronized (entry.obj) {
101+
handle(req, resp, authorizationHeader, bodyFromSDK, grantType, iss, accessTokenUpdate,
102+
idTokenUpdate,
103+
useDynamicKey);
104+
}
105+
} finally {
106+
entry.refCount.decrementAndGet();
107+
if (entry.refCount.get() == 0) {
108+
lockMap.remove(refreshTokenForLock, entry);
109+
}
110+
}
111+
112+
} else {
113+
handle(req, resp, authorizationHeader, bodyFromSDK, grantType, iss, accessTokenUpdate, idTokenUpdate,
114+
useDynamicKey);
115+
}
116+
}
117+
118+
private void handle(HttpServletRequest req, HttpServletResponse resp, String authorizationHeader,
119+
JsonObject bodyFromSDK, String grantType, String iss, JsonObject accessTokenUpdate,
120+
JsonObject idTokenUpdate, boolean useDynamicKey) throws ServletException, IOException {
93121
Map<String, String> headers = new HashMap<>();
94122
if (authorizationHeader != null) {
95123
headers.put("Authorization", authorizationHeader);
@@ -127,34 +155,34 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
127155
formFieldsForTokenIntrospect.put("token", internalRefreshToken);
128156

129157
HttpRequestForOAuthProvider.Response response = OAuthProxyHelper.proxyFormPOST(
130-
main, req, resp,
131-
appIdentifier,
132-
storage,
133-
null, // clientIdToCheck
134-
"/admin/oauth2/introspect", // pathProxy
135-
true, // proxyToAdmin
136-
false, // camelToSnakeCaseConversion
137-
formFieldsForTokenIntrospect,
138-
new HashMap<>() // headers
158+
main, req, resp,
159+
appIdentifier,
160+
storage,
161+
null, // clientIdToCheck
162+
"/admin/oauth2/introspect", // pathProxy
163+
true, // proxyToAdmin
164+
false, // camelToSnakeCaseConversion
165+
formFieldsForTokenIntrospect,
166+
new HashMap<>() // headers
139167
);
140168

141169
if (response == null) {
142-
return; // proxy helper would have sent the error response
170+
return;
143171
}
144172

145173
JsonObject refreshTokenPayload = response.jsonResponse.getAsJsonObject();
146174

147175
try {
148176
OAuth.verifyAndUpdateIntrospectRefreshTokenPayload(main, appIdentifier, storage, refreshTokenPayload, refreshToken, oauthClient.clientId);
149177
} catch (StorageQueryException | TenantOrAppNotFoundException |
150-
FeatureNotEnabledException | InvalidConfigException e) {
178+
FeatureNotEnabledException | InvalidConfigException e) {
151179
throw new ServletException(e);
152180
}
153181

154182
if (!refreshTokenPayload.get("active").getAsBoolean()) {
155183
// this is what ory would return for an invalid token
156184
OAuthProxyHelper.handleOAuthAPIException(resp, new OAuthAPIException(
157-
"token_inactive", "Token is inactive because it is malformed, expired or otherwise invalid. Token validation failed.", 401
185+
"token_inactive", "Token is inactive because it is malformed, expired or otherwise invalid. Token validation failed.", 401
158186
));
159187
return;
160188
}
@@ -163,20 +191,21 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
163191
}
164192

165193
HttpRequestForOAuthProvider.Response response = OAuthProxyHelper.proxyFormPOST(
166-
main, req, resp,
167-
getAppIdentifier(req),
168-
enforcePublicTenantAndGetPublicTenantStorage(req),
169-
clientId, // clientIdToCheck
170-
"/oauth2/token", // proxyPath
171-
false, // proxyToAdmin
172-
false, // camelToSnakeCaseConversion
173-
formFields,
174-
headers // headers
194+
main, req, resp,
195+
getAppIdentifier(req),
196+
enforcePublicTenantAndGetPublicTenantStorage(req),
197+
clientId, // clientIdToCheck
198+
"/oauth2/token", // proxyPath
199+
false, // proxyToAdmin
200+
false, // camelToSnakeCaseConversion
201+
formFields,
202+
headers // headers
175203
);
176204

177205
if (response != null) {
178206
try {
179-
response.jsonResponse = OAuth.transformTokens(super.main, appIdentifier, storage, response.jsonResponse.getAsJsonObject(), iss, accessTokenUpdate, idTokenUpdate, useDynamicKey);
207+
response.jsonResponse = OAuth.transformTokens(super.main, appIdentifier, storage, response.jsonResponse.getAsJsonObject(),
208+
iss, accessTokenUpdate, idTokenUpdate, useDynamicKey);
180209

181210
if (grantType.equals("client_credentials")) {
182211
try {
@@ -215,15 +244,15 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
215244
formFieldsForTokenIntrospect.put("token", newRefreshToken);
216245

217246
HttpRequestForOAuthProvider.Response introspectResponse = OAuthProxyHelper.proxyFormPOST(
218-
main, req, resp,
219-
getAppIdentifier(req),
220-
enforcePublicTenantAndGetPublicTenantStorage(req),
221-
null, // clientIdToCheck
222-
"/admin/oauth2/introspect", // pathProxy
223-
true, // proxyToAdmin
224-
false, // camelToSnakeCaseConversion
225-
formFieldsForTokenIntrospect,
226-
new HashMap<>() // headers
247+
main, req, resp,
248+
getAppIdentifier(req),
249+
enforcePublicTenantAndGetPublicTenantStorage(req),
250+
null, // clientIdToCheck
251+
"/admin/oauth2/introspect", // pathProxy
252+
true, // proxyToAdmin
253+
false, // camelToSnakeCaseConversion
254+
formFieldsForTokenIntrospect,
255+
new HashMap<>() // headers
227256
);
228257

229258
if (introspectResponse != null) {
@@ -288,4 +317,11 @@ private void updateLastActive(AppIdentifier appIdentifier, String sessionHandle)
288317
// ignore
289318
}
290319
}
320+
321+
322+
private static class NamedLockObject {
323+
final Object obj = new Object();
324+
final AtomicInteger refCount = new AtomicInteger(0);
325+
}
326+
private static final ConcurrentHashMap<String, NamedLockObject> lockMap = new ConcurrentHashMap<>();
291327
}

src/test/java/io/supertokens/test/oauth/api/TestRefreshTokenFlowWithTokenRotationOptions.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
import java.net.URL;
4141
import java.net.URLDecoder;
4242
import java.util.*;
43+
import java.util.concurrent.ExecutorService;
44+
import java.util.concurrent.Executors;
45+
import java.util.concurrent.TimeUnit;
46+
import java.util.concurrent.atomic.AtomicInteger;
4347

4448
import static org.junit.Assert.*;
4549

@@ -399,6 +403,75 @@ public void testRefreshTokenWithRotationIsDisabledAfter() throws Exception {
399403
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED));
400404
}
401405

406+
@Test
407+
public void testParallelRefreshTokenWithoutRotation() throws Exception {
408+
String[] args = {"../"};
409+
410+
TestingProcessManager.TestingProcess process = TestingProcessManager.start(args);
411+
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED));
412+
413+
if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) {
414+
return;
415+
}
416+
417+
FeatureFlag.getInstance(process.getProcess())
418+
.setLicenseKeyAndSyncFeatures(TotpLicenseTest.OPAQUE_KEY_WITH_MFA_FEATURE);
419+
FeatureFlagTestContent.getInstance(process.getProcess())
420+
.setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.OAUTH});
421+
422+
if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) {
423+
return;
424+
}
425+
426+
// Create client with token rotation disabled
427+
JsonObject client = createClient(process.getProcess(), false);
428+
JsonObject tokens = completeFlowAndGetTokens(process.getProcess(), client);
429+
430+
String refreshToken = tokens.get("refresh_token").getAsString();
431+
432+
// Setup parallel execution: 16 threads, each making 1000 refresh calls
433+
int numberOfThreads = 16;
434+
int refreshCallsPerThread = 25;
435+
ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
436+
AtomicInteger successCount = new AtomicInteger(0);
437+
AtomicInteger failureCount = new AtomicInteger(0);
438+
List<Exception> exceptions = Collections.synchronizedList(new ArrayList<>());
439+
440+
// Execute refresh token calls in parallel
441+
for (int i = 0; i < numberOfThreads; i++) {
442+
executor.execute(() -> {
443+
for (int j = 0; j < refreshCallsPerThread; j++) {
444+
try {
445+
JsonObject refreshResponse = refreshToken(process.getProcess(), client, refreshToken);
446+
if ("OK".equals(refreshResponse.get("status").getAsString())) {
447+
successCount.incrementAndGet();
448+
} else {
449+
failureCount.incrementAndGet();
450+
exceptions.add(new RuntimeException("Refresh failed: " + refreshResponse.toString()));
451+
}
452+
} catch (Exception e) {
453+
System.out.println(e.getMessage());
454+
failureCount.incrementAndGet();
455+
exceptions.add(e);
456+
}
457+
}
458+
});
459+
}
460+
461+
executor.shutdown();
462+
boolean terminated = executor.awaitTermination(5, TimeUnit.MINUTES);
463+
assertTrue("Executor did not terminate within timeout", terminated);
464+
465+
// Verify all refresh calls succeeded
466+
int totalExpectedCalls = numberOfThreads * refreshCallsPerThread;
467+
assertEquals("All refresh token calls should succeed", totalExpectedCalls, successCount.get());
468+
assertEquals("No refresh token calls should fail", 0, failureCount.get());
469+
assertTrue("No exceptions should occur", exceptions.isEmpty());
470+
471+
process.kill();
472+
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED));
473+
}
474+
402475
private static Map<String, String> splitQuery(URL url) throws UnsupportedEncodingException {
403476
Map<String, String> queryPairs = new LinkedHashMap<>();
404477
String query = url.getQuery();

0 commit comments

Comments
 (0)