Skip to content

Commit a82637f

Browse files
committed
Evict streamable HTTP sessions after failed keep-alive pings
1 parent d1ef187 commit a82637f

4 files changed

Lines changed: 381 additions & 4 deletions

File tree

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import io.modelcontextprotocol.spec.HttpHeaders;
2525
import io.modelcontextprotocol.spec.McpError;
2626
import io.modelcontextprotocol.spec.McpSchema;
27+
import io.modelcontextprotocol.spec.McpSession;
2728
import io.modelcontextprotocol.spec.McpStreamableServerSession;
2829
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
2930
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
@@ -87,6 +88,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
8788

8889
public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
8990

91+
private static final int KEEP_ALIVE_FAILURE_THRESHOLD = 3;
92+
9093
/**
9194
* The endpoint URI where clients should send their JSON-RPC messages. Defaults to
9295
* "/mcp".
@@ -107,6 +110,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
107110
*/
108111
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
109112

113+
private final ConcurrentHashMap<String, Integer> keepAliveFailureCounts = new ConcurrentHashMap<>();
114+
110115
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;
111116

112117
/**
@@ -158,6 +163,8 @@ private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, S
158163
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values()))
159164
.initialDelay(keepAliveInterval)
160165
.interval(keepAliveInterval)
166+
.onSuccess(this::resetKeepAliveFailures)
167+
.onFailure(this::handleKeepAliveFailure)
161168
.build();
162169

163170
this.keepAliveScheduler.start();
@@ -231,8 +238,10 @@ public Mono<Void> closeGracefully() {
231238
});
232239

233240
this.sessions.clear();
241+
this.keepAliveFailureCounts.clear();
234242
}).then().doOnSuccess(v -> {
235243
sessions.clear();
244+
keepAliveFailureCounts.clear();
236245
logger.debug("Graceful shutdown completed");
237246
if (this.keepAliveScheduler != null) {
238247
this.keepAliveScheduler.shutdown();
@@ -445,6 +454,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
445454
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
446455
.startSession(initializeRequest);
447456
this.sessions.put(init.session().getId(), init.session());
457+
this.keepAliveFailureCounts.remove(init.session().getId());
448458

449459
try {
450460
McpSchema.InitializeResult initResult = init.initResult().block();
@@ -614,6 +624,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
614624
try {
615625
session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
616626
this.sessions.remove(sessionId);
627+
this.keepAliveFailureCounts.remove(sessionId);
617628
response.setStatus(HttpServletResponse.SC_OK);
618629
}
619630
catch (Exception e) {
@@ -640,6 +651,42 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m
640651
return;
641652
}
642653

654+
void resetKeepAliveFailures(McpSession session) {
655+
if (session instanceof McpStreamableServerSession streamableSession) {
656+
String sessionId = streamableSession.getId();
657+
if (this.sessions.get(sessionId) == streamableSession) {
658+
this.keepAliveFailureCounts.remove(sessionId);
659+
}
660+
}
661+
}
662+
663+
void handleKeepAliveFailure(McpSession session, Throwable error) {
664+
if (!(session instanceof McpStreamableServerSession streamableSession)) {
665+
return;
666+
}
667+
668+
String sessionId = streamableSession.getId();
669+
if (this.sessions.get(sessionId) != streamableSession) {
670+
return;
671+
}
672+
673+
int failures = this.keepAliveFailureCounts.merge(sessionId, 1, Integer::sum);
674+
if (failures < KEEP_ALIVE_FAILURE_THRESHOLD) {
675+
logger.debug("Keep-alive ping failed for session {} ({}/{} consecutive failures): {}", sessionId, failures,
676+
KEEP_ALIVE_FAILURE_THRESHOLD, error.getMessage());
677+
return;
678+
}
679+
680+
if (this.sessions.remove(sessionId, streamableSession)) {
681+
this.keepAliveFailureCounts.remove(sessionId);
682+
streamableSession.close();
683+
logger.info("Evicted session {} after {} failed keep-alive attempts", sessionId, failures);
684+
}
685+
else {
686+
this.keepAliveFailureCounts.remove(sessionId);
687+
}
688+
}
689+
643690
/**
644691
* Sends an SSE event to a client with a specific ID.
645692
* @param writer The writer to send the event through
@@ -748,6 +795,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
748795
catch (Exception e) {
749796
logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
750797
HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
798+
HttpServletStreamableServerTransportProvider.this.keepAliveFailureCounts.remove(this.sessionId);
751799
this.asyncContext.complete();
752800
}
753801
finally {

mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import java.time.Duration;
88
import java.util.concurrent.atomic.AtomicBoolean;
9+
import java.util.function.BiConsumer;
10+
import java.util.function.Consumer;
911
import java.util.function.Supplier;
1012

1113
import org.slf4j.Logger;
@@ -57,6 +59,10 @@ public class KeepAliveScheduler {
5759
/** Supplier for reactive McpSession instances */
5860
private final Supplier<Flux<McpSession>> mcpSessions;
5961

62+
private final Consumer<McpSession> onSuccess;
63+
64+
private final BiConsumer<McpSession, Throwable> onFailure;
65+
6066
/**
6167
* Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a
6268
* supplier for McpSession instances.
@@ -66,11 +72,14 @@ public class KeepAliveScheduler {
6672
* @param mcpSessions Supplier for McpSession instances
6773
*/
6874
KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval,
69-
Supplier<Flux<McpSession>> mcpSessions) {
75+
Supplier<Flux<McpSession>> mcpSessions, Consumer<McpSession> onSuccess,
76+
BiConsumer<McpSession, Throwable> onFailure) {
7077
this.scheduler = scheduler;
7178
this.initialDelay = initialDelay;
7279
this.interval = interval;
7380
this.mcpSessions = mcpSessions;
81+
this.onSuccess = onSuccess;
82+
this.onFailure = onFailure;
7483
}
7584

7685
/**
@@ -92,8 +101,12 @@ public Disposable start() {
92101
.doOnNext(tick -> {
93102
this.mcpSessions.get()
94103
.flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)
95-
.doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session,
96-
e.getMessage()))
104+
.doOnSuccess(result -> this.notifySuccess(session))
105+
.doOnError(e -> {
106+
logger.warn("Failed to send keep-alive ping to session {}: {}", session,
107+
e.getMessage());
108+
this.notifyFailure(session, e);
109+
})
97110
.onErrorComplete())
98111
.subscribe();
99112
})
@@ -131,6 +144,24 @@ public boolean isRunning() {
131144
return this.isRunning.get();
132145
}
133146

147+
private void notifySuccess(McpSession session) {
148+
try {
149+
this.onSuccess.accept(session);
150+
}
151+
catch (Exception e) {
152+
logger.warn("Keep-alive success callback failed for session {}: {}", session, e.getMessage());
153+
}
154+
}
155+
156+
private void notifyFailure(McpSession session, Throwable error) {
157+
try {
158+
this.onFailure.accept(session, error);
159+
}
160+
catch (Exception e) {
161+
logger.warn("Keep-alive failure callback failed for session {}: {}", session, e.getMessage());
162+
}
163+
}
164+
134165
/**
135166
* Shuts down the scheduler and releases resources.
136167
*/
@@ -154,6 +185,12 @@ public static class Builder {
154185

155186
private Supplier<Flux<McpSession>> mcpSessions;
156187

188+
private Consumer<McpSession> onSuccess = session -> {
189+
};
190+
191+
private BiConsumer<McpSession, Throwable> onFailure = (session, error) -> {
192+
};
193+
157194
/**
158195
* Creates a new Builder instance with a supplier for McpSession instances.
159196
* @param mcpSessions The supplier for McpSession instances
@@ -204,12 +241,34 @@ public Builder interval(Duration interval) {
204241
return this;
205242
}
206243

244+
/**
245+
* Sets the callback invoked after a keep-alive ping completes successfully.
246+
* @param onSuccess The success callback. Must not be null.
247+
* @return This builder instance for method chaining
248+
*/
249+
public Builder onSuccess(Consumer<McpSession> onSuccess) {
250+
Assert.notNull(onSuccess, "OnSuccess callback must not be null");
251+
this.onSuccess = onSuccess;
252+
return this;
253+
}
254+
255+
/**
256+
* Sets the callback invoked after a keep-alive ping fails.
257+
* @param onFailure The failure callback. Must not be null.
258+
* @return This builder instance for method chaining
259+
*/
260+
public Builder onFailure(BiConsumer<McpSession, Throwable> onFailure) {
261+
Assert.notNull(onFailure, "OnFailure callback must not be null");
262+
this.onFailure = onFailure;
263+
return this;
264+
}
265+
207266
/**
208267
* Builds and returns a new KeepAliveScheduler instance.
209268
* @return A new KeepAliveScheduler configured with the builder's settings
210269
*/
211270
public KeepAliveScheduler build() {
212-
return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions);
271+
return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions, onSuccess, onFailure);
213272
}
214273

215274
}
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import java.lang.reflect.Field;
8+
import java.time.Duration;
9+
import java.util.Map;
10+
import java.util.concurrent.atomic.AtomicInteger;
11+
12+
import io.modelcontextprotocol.spec.McpSchema;
13+
import io.modelcontextprotocol.spec.McpStreamableServerSession;
14+
import io.modelcontextprotocol.spec.json.gson.GsonMcpJsonMapper;
15+
import org.junit.jupiter.api.Test;
16+
import reactor.core.publisher.Mono;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
/**
21+
* Tests for keep-alive failure eviction in
22+
* {@link HttpServletStreamableServerTransportProvider}.
23+
*/
24+
class HttpServletStreamableServerTransportProviderTests {
25+
26+
@Test
27+
void firstKeepAliveFailureDoesNotEvictSession() throws Exception {
28+
HttpServletStreamableServerTransportProvider provider = createProvider();
29+
TrackingStreamableSession session = createSession("session-1");
30+
putSession(provider, session);
31+
32+
provider.handleKeepAliveFailure(session, new RuntimeException("ping failed"));
33+
34+
assertThat(sessions(provider)).containsEntry("session-1", session);
35+
assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1);
36+
assertThat(session.closeCount()).isZero();
37+
}
38+
39+
@Test
40+
void repeatedKeepAliveFailuresEvictSession() throws Exception {
41+
HttpServletStreamableServerTransportProvider provider = createProvider();
42+
TrackingStreamableSession session = createSession("session-1");
43+
putSession(provider, session);
44+
45+
provider.handleKeepAliveFailure(session, new RuntimeException("first failure"));
46+
provider.handleKeepAliveFailure(session, new RuntimeException("second failure"));
47+
provider.handleKeepAliveFailure(session, new RuntimeException("third failure"));
48+
49+
assertThat(sessions(provider)).doesNotContainKey("session-1");
50+
assertThat(keepAliveFailureCounts(provider)).doesNotContainKey("session-1");
51+
assertThat(session.closeCount()).isOne();
52+
}
53+
54+
@Test
55+
void successfulKeepAliveResetsFailureCount() throws Exception {
56+
HttpServletStreamableServerTransportProvider provider = createProvider();
57+
TrackingStreamableSession session = createSession("session-1");
58+
putSession(provider, session);
59+
60+
provider.handleKeepAliveFailure(session, new RuntimeException("first failure"));
61+
provider.handleKeepAliveFailure(session, new RuntimeException("second failure"));
62+
provider.resetKeepAliveFailures(session);
63+
provider.handleKeepAliveFailure(session, new RuntimeException("failure after success"));
64+
65+
assertThat(sessions(provider)).containsEntry("session-1", session);
66+
assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1);
67+
assertThat(session.closeCount()).isZero();
68+
}
69+
70+
@Test
71+
void successfulKeepAliveFromReplacedSessionDoesNotResetReplacementFailureCount() throws Exception {
72+
HttpServletStreamableServerTransportProvider provider = createProvider();
73+
TrackingStreamableSession oldSession = createSession("session-1");
74+
TrackingStreamableSession replacementSession = createSession("session-1");
75+
putSession(provider, replacementSession);
76+
77+
provider.handleKeepAliveFailure(replacementSession, new RuntimeException("replacement failure"));
78+
provider.resetKeepAliveFailures(oldSession);
79+
80+
assertThat(sessions(provider)).containsEntry("session-1", replacementSession);
81+
assertThat(keepAliveFailureCounts(provider)).containsEntry("session-1", 1);
82+
assertThat(oldSession.closeCount()).isZero();
83+
assertThat(replacementSession.closeCount()).isZero();
84+
}
85+
86+
@Test
87+
void keepAliveFailureDoesNotCloseReplacedSession() throws Exception {
88+
HttpServletStreamableServerTransportProvider provider = createProvider();
89+
TrackingStreamableSession oldSession = createSession("session-1");
90+
TrackingStreamableSession replacementSession = createSession("session-1");
91+
putSession(provider, replacementSession);
92+
93+
provider.handleKeepAliveFailure(oldSession, new RuntimeException("first failure"));
94+
provider.handleKeepAliveFailure(oldSession, new RuntimeException("second failure"));
95+
provider.handleKeepAliveFailure(oldSession, new RuntimeException("third failure"));
96+
97+
assertThat(sessions(provider)).containsEntry("session-1", replacementSession);
98+
assertThat(keepAliveFailureCounts(provider)).doesNotContainKey("session-1");
99+
assertThat(oldSession.closeCount()).isZero();
100+
assertThat(replacementSession.closeCount()).isZero();
101+
}
102+
103+
private HttpServletStreamableServerTransportProvider createProvider() {
104+
return HttpServletStreamableServerTransportProvider.builder().jsonMapper(new GsonMcpJsonMapper()).build();
105+
}
106+
107+
private TrackingStreamableSession createSession(String sessionId) {
108+
return new TrackingStreamableSession(sessionId);
109+
}
110+
111+
private void putSession(HttpServletStreamableServerTransportProvider provider, TrackingStreamableSession session)
112+
throws Exception {
113+
sessions(provider).put(session.getId(), session);
114+
}
115+
116+
@SuppressWarnings("unchecked")
117+
private Map<String, McpStreamableServerSession> sessions(HttpServletStreamableServerTransportProvider provider)
118+
throws Exception {
119+
Field field = HttpServletStreamableServerTransportProvider.class.getDeclaredField("sessions");
120+
field.setAccessible(true);
121+
return (Map<String, McpStreamableServerSession>) field.get(provider);
122+
}
123+
124+
@SuppressWarnings("unchecked")
125+
private Map<String, Integer> keepAliveFailureCounts(HttpServletStreamableServerTransportProvider provider)
126+
throws Exception {
127+
Field field = HttpServletStreamableServerTransportProvider.class.getDeclaredField("keepAliveFailureCounts");
128+
field.setAccessible(true);
129+
return (Map<String, Integer>) field.get(provider);
130+
}
131+
132+
private static class TrackingStreamableSession extends McpStreamableServerSession {
133+
134+
private final AtomicInteger closeCount = new AtomicInteger();
135+
136+
TrackingStreamableSession(String id) {
137+
super(id, McpSchema.ClientCapabilities.builder().build(),
138+
new McpSchema.Implementation("test-client", "1.0.0"), Duration.ofSeconds(5), Map.of(), Map.of(),
139+
Mono::empty);
140+
}
141+
142+
@Override
143+
public void close() {
144+
this.closeCount.incrementAndGet();
145+
}
146+
147+
int closeCount() {
148+
return this.closeCount.get();
149+
}
150+
151+
}
152+
153+
}

0 commit comments

Comments
 (0)