diff --git a/acp_test.go b/acp_test.go index 47bc889..915bb87 100644 --- a/acp_test.go +++ b/acp_test.go @@ -5,6 +5,7 @@ import ( "io" "slices" "sync" + "sync/atomic" "testing" "time" ) @@ -633,3 +634,115 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) { t.Fatalf("newSession after cancel: %v", err) } } + +// TestPromptWaitsForSessionUpdatesComplete verifies that Prompt() waits for all SessionUpdate +// notification handlers to complete before returning. This ensures that when a server sends +// SessionUpdate notifications followed by a PromptResponse, the client-side Prompt() call will not +// return until all notification handlers have finished processing. This is the expected semantic +// contract: the prompt operation includes all its updates. +func TestPromptWaitsForSessionUpdatesComplete(t *testing.T) { + const numUpdates = 10 + const handlerDelay = 50 * time.Millisecond + + var ( + updateStarted atomic.Int64 + updateCompleted atomic.Int64 + ) + + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + // Client side with SessionUpdate handler that tracks execution + c := NewClientSideConnection(&clientFuncs{ + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) { + return WriteTextFileResponse{}, nil + }, + ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: "test"}, nil + }, + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil + }, + SessionUpdateFunc: func(_ context.Context, n SessionNotification) error { + updateStarted.Add(1) + // Simulate processing time + time.Sleep(handlerDelay) + updateCompleted.Add(1) + return nil + }, + }, c2aW, a2cR) + + // Agent side that sends multiple SessionUpdate notifications before responding + var wg sync.WaitGroup + wg.Add(1) + + var ag *AgentSideConnection + ag = NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil + }, + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: "test-session"}, nil + }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, nil + }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) { + return AuthenticateResponse{}, nil + }, + PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) { + defer wg.Done() + + // Send multiple SessionUpdate notifications + for i := 0; i < numUpdates; i++ { + _ = ag.SessionUpdate(ctx, SessionNotification{ + SessionId: p.SessionId, + Update: SessionUpdate{ + AgentMessageChunk: &SessionUpdateAgentMessageChunk{ + Content: TextBlock("chunk"), + }, + }, + }) + } + + // Small delay to ensure notifications are queued + time.Sleep(10 * time.Millisecond) + + // Return response (this will unblock client's Prompt() call) + return PromptResponse{StopReason: "end_turn"}, nil + }, + CancelFunc: func(context.Context, CancelNotification) error { return nil }, + }, a2cW, c2aR) + + if _, err := c.Initialize(context.Background(), InitializeRequest{ProtocolVersion: ProtocolVersionNumber}); err != nil { + t.Fatalf("initialize: %v", err) + } + sess, err := c.NewSession(context.Background(), NewSessionRequest{Cwd: "/", McpServers: []McpServer{}}) + if err != nil { + t.Fatalf("newSession: %v", err) + } + + _, err = c.Prompt(context.Background(), PromptRequest{ + SessionId: sess.SessionId, + Prompt: []ContentBlock{TextBlock("test")}, + }) + if err != nil { + t.Fatalf("prompt: %v", err) + } + + wg.Wait() + + // Verify the expected behavior: at this point, Prompt() has returned, and all SessionUpdate + // handlers should have completed their processing. + // started := updateStarted.Load() ; Currently unsused but useful for debugging + completed := updateCompleted.Load() + + // ASSERT: when Prompt() returns, all SessionUpdate notifications that were sent + // before the PromptResponse must have been fully processed. This is the semantic + // contract: the prompt operation includes all its updates. + if completed != numUpdates { + t.Fatalf("Prompt() returned with only %d/%d SessionUpdate "+ + "handlers completed. Expected all handlers to complete before Prompt() "+ + "returns.", completed, numUpdates) + } +} diff --git a/connection.go b/connection.go index 2cebcb9..5e4865b 100644 --- a/connection.go +++ b/connection.go @@ -41,6 +41,10 @@ type Connection struct { cancel context.CancelCauseFunc logger *slog.Logger + + // notificationWg tracks in-flight notification handlers. This ensures SendRequest waits + // for all notifications received before the response to complete processing. + notificationWg sync.WaitGroup } func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection { @@ -94,7 +98,11 @@ func (c *Connection) receive() { case msg.ID != nil && msg.Method == "": c.handleResponse(&msg) case msg.Method != "": - go c.handleInbound(&msg) + c.notificationWg.Add(1) + go func(m *anyMessage) { + defer c.notificationWg.Done() + c.handleInbound(m) + }(&msg) default: c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line)) } @@ -193,6 +201,11 @@ func SendRequest[T any](c *Connection, ctx context.Context, method string, param return result, err } + // Wait for all notification handlers that were spawned before this response to complete + // processing. This ensures that when a request returns, all notifications sent by the + // server before the response have been fully processed. + c.notificationWg.Wait() + if resp.Error != nil { return result, resp.Error } @@ -266,6 +279,11 @@ func (c *Connection) SendRequestNoResult(ctx context.Context, method string, par return err } + // Wait for all notification handlers that were spawned before this response to complete + // processing. This ensures that when a request returns, all notifications sent by the + // server before the response have been fully processed. + c.notificationWg.Wait() + if resp.Error != nil { return resp.Error }