Skip to content

Commit 14e50b2

Browse files
authored
Fix prompt race when final response was returned before all updates are emitted (#18)
1 parent 431bd9e commit 14e50b2

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import com.agentclientprotocol.model.SessionUpdate
2626
import com.agentclientprotocol.model.StopReason
2727
import com.agentclientprotocol.model.ToolCallId
2828
import com.agentclientprotocol.protocol.invoke
29+
import io.github.oshai.kotlinlogging.KotlinLogging.logger
2930
import kotlinx.coroutines.CancellationException
3031
import kotlinx.coroutines.CompletableDeferred
3132
import kotlinx.coroutines.awaitCancellation
@@ -140,6 +141,80 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
140141
assertEquals(result!!.stopReason, StopReason.END_TURN)
141142
}
142143

144+
@Test
145+
fun `prompt response and update have proper order`() = testWithProtocols { clientProtocol, agentProtocol ->
146+
val client = Client(protocol = clientProtocol)
147+
val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport {
148+
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
149+
return AgentInfo(clientInfo.protocolVersion)
150+
}
151+
152+
override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession {
153+
return object : AgentSession {
154+
override val sessionId: SessionId = SessionId("test-session-id")
155+
156+
override suspend fun prompt(
157+
content: List<ContentBlock>,
158+
_meta: JsonElement?,
159+
): Flow<Event> = flow {
160+
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text(sessionParameters.cwd))))
161+
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 1"))))
162+
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 2"))))
163+
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 3"))))
164+
}
165+
}
166+
}
167+
168+
override suspend fun loadSession(
169+
sessionId: SessionId,
170+
sessionParameters: SessionCreationParameters,
171+
): AgentSession {
172+
TODO("Not yet implemented")
173+
}
174+
})
175+
val testVersion = 10
176+
val clientInfo = ClientInfo(protocolVersion = testVersion)
177+
val agentInfo = client.initialize(clientInfo)
178+
val cwd = "/test/path"
179+
val newSession = client.newSession(SessionCreationParameters(cwd, emptyList())) { _, _ ->
180+
object : ClientSessionOperations {
181+
override suspend fun requestPermissions(
182+
toolCall: SessionUpdate.ToolCallUpdate,
183+
permissions: List<PermissionOption>,
184+
_meta: JsonElement?,
185+
): RequestPermissionResponse {
186+
TODO("Not yet implemented")
187+
}
188+
189+
override suspend fun notify(
190+
notification: SessionUpdate,
191+
_meta: JsonElement?,
192+
) {
193+
TODO("Not yet implemented")
194+
}
195+
}
196+
}
197+
val responses = mutableListOf<String>()
198+
var result: PromptResponse? = null
199+
withTimeout(1000) {
200+
newSession.prompt(listOf()).collect { event ->
201+
when (event) {
202+
is Event.PromptResponseEvent -> {
203+
println( "Received prompt response: ${event.response}" )
204+
result = event.response
205+
responses.add(event.response.stopReason.toString())
206+
}
207+
is Event.SessionUpdateEvent -> {
208+
println( "Received session update: ${(event.update as SessionUpdate.AgentMessageChunk).content}" )
209+
responses.add(((event.update as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text)
210+
}
211+
}
212+
}
213+
}
214+
assertContentEquals(listOf("/test/path", "text 1", "text 2", "text 3", "END_TURN"), responses)
215+
assertEquals(result!!.stopReason, StopReason.END_TURN)
216+
}
217+
143218
@Test
144219
fun `cancel simple prompt from client`() = testWithProtocols { clientProtocol, agentProtocol ->
145220
val client = Client(protocol = clientProtocol)

acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import com.agentclientprotocol.protocol.Protocol
88
import com.agentclientprotocol.protocol.invoke
99
import io.github.oshai.kotlinlogging.KotlinLogging
1010
import kotlinx.atomicfu.atomic
11+
import kotlinx.coroutines.DelicateCoroutinesApi
1112
import kotlinx.coroutines.channels.Channel
1213
import kotlinx.coroutines.flow.Flow
1314
import kotlinx.coroutines.flow.MutableStateFlow
@@ -62,12 +63,20 @@ internal class ClientSessionImpl(
6263
logger.trace { "Sending prompt request: $content" }
6364
val promptResponse = AcpMethod.AgentMethods.SessionPrompt(protocol, PromptRequest(sessionId, content, _meta))
6465
logger.trace { "Received prompt response: $promptResponse" }
65-
send(Event.PromptResponseEvent(promptResponse))
66-
} finally {
66+
67+
// after receiving prompt response we immediately close the current prompt channel
68+
// and then waiting for draining all the updates that were sent during prompt execution
69+
// only after that we emit the PromptResponseEvent to the outbound flow
6770
logger.trace { "Closing prompt channel" }
6871
activePrompt.getAndSet(null)?.updateChannel?.close()
6972
logger.trace { "Waiting for prompt channel to close" }
7073
channelJob.join()
74+
75+
send(Event.PromptResponseEvent(promptResponse))
76+
close()
77+
} finally {
78+
activePrompt.getAndSet(null)?.updateChannel?.close()
79+
channelJob.cancel()
7180
}
7281
}
7382

@@ -123,7 +132,9 @@ internal class ClientSessionImpl(
123132
// }
124133

125134
val promptSession = activePrompt.value
126-
if (promptSession != null) {
135+
@OptIn(DelicateCoroutinesApi::class)
136+
// check for isClosedForSend because the prompt may exist, but the code is waiting for the updates drain
137+
if (promptSession != null && !promptSession.updateChannel.isClosedForSend) {
127138
logger.trace { "Sending update to active prompt: $notification" }
128139
promptSession.updateChannel.send(notification)
129140
}

0 commit comments

Comments
 (0)