Skip to content

Commit 2a2c4d8

Browse files
authored
Wait for pending initializing sessions to avoid session not found (#21)
#3
1 parent 79318b6 commit 2a2c4d8

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ import kotlinx.atomicfu.update
1414
import kotlinx.collections.immutable.persistentMapOf
1515
import kotlinx.coroutines.CompletableDeferred
1616
import kotlinx.coroutines.ExperimentalCoroutinesApi
17+
import kotlinx.coroutines.flow.MutableStateFlow
18+
import kotlinx.coroutines.flow.first
19+
import kotlinx.coroutines.flow.update
1720
import kotlinx.serialization.json.JsonElement
1821

1922
private val logger = KotlinLogging.logger {}
@@ -37,6 +40,7 @@ public class Client(
3740
private val _sessions = atomic(persistentMapOf<SessionId, CompletableDeferred<ClientSessionImpl>>())
3841
private val _clientInfo = CompletableDeferred<ClientInfo>()
3942
private val _agentInfo = CompletableDeferred<AgentInfo>()
43+
private val _currentlyInitializingSessionsCount = MutableStateFlow(0)
4044

4145
init {
4246
// Set up request handlers for incoming agent requests
@@ -158,15 +162,18 @@ public class Client(
158162
* @return a [ClientSession] instance for the new session
159163
*/
160164
public suspend fun newSession(sessionParameters: SessionCreationParameters, operationsFactory: ClientOperationsFactory): ClientSession {
161-
val newSessionResponse = AcpMethod.AgentMethods.SessionNew(protocol,
162-
NewSessionRequest(
163-
sessionParameters.cwd,
164-
sessionParameters.mcpServers,
165-
sessionParameters._meta
165+
return withInitializingSession {
166+
val newSessionResponse = AcpMethod.AgentMethods.SessionNew(
167+
protocol,
168+
NewSessionRequest(
169+
sessionParameters.cwd,
170+
sessionParameters.mcpServers,
171+
sessionParameters._meta
172+
)
166173
)
167-
)
168-
val sessionId = newSessionResponse.sessionId
169-
return createSession(sessionId, sessionParameters, newSessionResponse, operationsFactory)
174+
val sessionId = newSessionResponse.sessionId
175+
return@withInitializingSession createSession(sessionId, sessionParameters, newSessionResponse, operationsFactory)
176+
}
170177
}
171178

172179
/**
@@ -180,15 +187,18 @@ public class Client(
180187
* @return a [ClientSession] instance for the new session
181188
*/
182189
public suspend fun loadSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, operationsFactory: ClientOperationsFactory): ClientSession {
183-
val loadSessionResponse = AcpMethod.AgentMethods.SessionLoad(protocol,
184-
LoadSessionRequest(
185-
sessionId,
186-
sessionParameters.cwd,
187-
sessionParameters.mcpServers,
188-
sessionParameters._meta
189-
))
190-
191-
return createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory)
190+
return withInitializingSession {
191+
val loadSessionResponse = AcpMethod.AgentMethods.SessionLoad(
192+
protocol,
193+
LoadSessionRequest(
194+
sessionId,
195+
sessionParameters.cwd,
196+
sessionParameters.mcpServers,
197+
sessionParameters._meta
198+
)
199+
)
200+
return@withInitializingSession createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory)
201+
}
192202
}
193203

194204
private suspend fun createSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, sessionResponse: AcpCreatedSessionResponse, factory: ClientOperationsFactory): ClientSession {
@@ -215,7 +225,27 @@ public class Client(
215225
return completableDeferred.getCompleted()
216226
}
217227

218-
private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl = (_sessions.value[sessionId] ?: acpFail("Session $sessionId not found")).await()
228+
private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl {
229+
_sessions.value[sessionId]?.let {
230+
return it.await()
231+
}
232+
// try to wait for all pending sessions to initialize
233+
_currentlyInitializingSessionsCount.first { it == 0 }
234+
// try to get the session again
235+
_sessions.value[sessionId]?.let {
236+
return it.await()
237+
}
238+
acpFail("Session $sessionId not found")
239+
}
240+
241+
private suspend fun<T> withInitializingSession(block: suspend () -> T): T {
242+
_currentlyInitializingSessionsCount.update { it + 1 }
243+
try {
244+
return block()
245+
} finally {
246+
_currentlyInitializingSessionsCount.update { it - 1 }
247+
}
248+
}
219249
}
220250

221251
private inline fun <reified TInterface> sessionMethodNotFound(method: AcpMethod): Nothing {

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ plugins {
77
private val buildNumber: String? = System.getenv("GITHUB_RUN_NUMBER")
88
private val isReleasePublication = System.getenv("RELEASE_PUBLICATION")?.toBoolean() ?: false
99

10-
private val baseVersion = "0.7.1"
10+
private val baseVersion = "0.7.2"
1111

1212
allprojects {
1313
group = "com.agentclientprotocol"

0 commit comments

Comments
 (0)