@@ -14,6 +14,9 @@ import kotlinx.atomicfu.update
1414import kotlinx.collections.immutable.persistentMapOf
1515import kotlinx.coroutines.CompletableDeferred
1616import kotlinx.coroutines.ExperimentalCoroutinesApi
17+ import kotlinx.coroutines.flow.MutableStateFlow
18+ import kotlinx.coroutines.flow.first
19+ import kotlinx.coroutines.flow.update
1720import kotlinx.serialization.json.JsonElement
1821
1922private val logger = KotlinLogging .logger {}
@@ -37,6 +40,7 @@ public class Client(
3740 private val _sessions = atomic(persistentMapOf<SessionId , CompletableDeferred <ClientSessionImpl >>())
3841 private val _clientInfo = CompletableDeferred <ClientInfo >()
3942 private val _agentInfo = CompletableDeferred <AgentInfo >()
43+ private val _currentlyInitializingSessionsCount = MutableStateFlow (0 )
4044
4145 init {
4246 // Set up request handlers for incoming agent requests
@@ -158,15 +162,18 @@ public class Client(
158162 * @return a [ClientSession] instance for the new session
159163 */
160164 public suspend fun newSession (sessionParameters : SessionCreationParameters , operationsFactory : ClientOperationsFactory ): ClientSession {
161- val newSessionResponse = AcpMethod .AgentMethods .SessionNew (protocol,
162- NewSessionRequest (
163- sessionParameters.cwd,
164- sessionParameters.mcpServers,
165- sessionParameters._meta
165+ return withInitializingSession {
166+ val newSessionResponse = AcpMethod .AgentMethods .SessionNew (
167+ protocol,
168+ NewSessionRequest (
169+ sessionParameters.cwd,
170+ sessionParameters.mcpServers,
171+ sessionParameters._meta
172+ )
166173 )
167- )
168- val sessionId = newSessionResponse.sessionId
169- return createSession(sessionId, sessionParameters, newSessionResponse, operationsFactory)
174+ val sessionId = newSessionResponse.sessionId
175+ return @withInitializingSession createSession( sessionId, sessionParameters, newSessionResponse, operationsFactory)
176+ }
170177 }
171178
172179 /* *
@@ -180,15 +187,18 @@ public class Client(
180187 * @return a [ClientSession] instance for the new session
181188 */
182189 public suspend fun loadSession (sessionId : SessionId , sessionParameters : SessionCreationParameters , operationsFactory : ClientOperationsFactory ): ClientSession {
183- val loadSessionResponse = AcpMethod .AgentMethods .SessionLoad (protocol,
184- LoadSessionRequest (
185- sessionId,
186- sessionParameters.cwd,
187- sessionParameters.mcpServers,
188- sessionParameters._meta
189- ))
190-
191- return createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory)
190+ return withInitializingSession {
191+ val loadSessionResponse = AcpMethod .AgentMethods .SessionLoad (
192+ protocol,
193+ LoadSessionRequest (
194+ sessionId,
195+ sessionParameters.cwd,
196+ sessionParameters.mcpServers,
197+ sessionParameters._meta
198+ )
199+ )
200+ return @withInitializingSession createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory)
201+ }
192202 }
193203
194204 private suspend fun createSession (sessionId : SessionId , sessionParameters : SessionCreationParameters , sessionResponse : AcpCreatedSessionResponse , factory : ClientOperationsFactory ): ClientSession {
@@ -215,7 +225,27 @@ public class Client(
215225 return completableDeferred.getCompleted()
216226 }
217227
218- private suspend fun getSessionOrThrow (sessionId : SessionId ): ClientSessionImpl = (_sessions .value[sessionId] ? : acpFail(" Session $sessionId not found" )).await()
228+ private suspend fun getSessionOrThrow (sessionId : SessionId ): ClientSessionImpl {
229+ _sessions .value[sessionId]?.let {
230+ return it.await()
231+ }
232+ // try to wait for all pending sessions to initialize
233+ _currentlyInitializingSessionsCount .first { it == 0 }
234+ // try to get the session again
235+ _sessions .value[sessionId]?.let {
236+ return it.await()
237+ }
238+ acpFail(" Session $sessionId not found" )
239+ }
240+
241+ private suspend fun <T > withInitializingSession (block : suspend () -> T ): T {
242+ _currentlyInitializingSessionsCount .update { it + 1 }
243+ try {
244+ return block()
245+ } finally {
246+ _currentlyInitializingSessionsCount .update { it - 1 }
247+ }
248+ }
219249}
220250
221251private inline fun <reified TInterface > sessionMethodNotFound (method : AcpMethod ): Nothing {
0 commit comments