Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions acp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"slices"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -633,3 +634,115 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) {
t.Fatalf("newSession after cancel: %v", err)
}
}

// TestPromptWaitsForSessionUpdatesComplete verifies that Prompt() waits for all SessionUpdate
// notification handlers to complete before returning. This ensures that when a server sends
// SessionUpdate notifications followed by a PromptResponse, the client-side Prompt() call will not
// return until all notification handlers have finished processing. This is the expected semantic
// contract: the prompt operation includes all its updates.
func TestPromptWaitsForSessionUpdatesComplete(t *testing.T) {
const numUpdates = 10
const handlerDelay = 50 * time.Millisecond

var (
updateStarted atomic.Int64
updateCompleted atomic.Int64
)

c2aR, c2aW := io.Pipe()
a2cR, a2cW := io.Pipe()

// Client side with SessionUpdate handler that tracks execution
c := NewClientSideConnection(&clientFuncs{
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
return WriteTextFileResponse{}, nil
},
ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) {
return ReadTextFileResponse{Content: "test"}, nil
},
RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) {
return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil
},
SessionUpdateFunc: func(_ context.Context, n SessionNotification) error {
updateStarted.Add(1)
// Simulate processing time
time.Sleep(handlerDelay)
updateCompleted.Add(1)
return nil
},
}, c2aW, a2cR)

// Agent side that sends multiple SessionUpdate notifications before responding
var wg sync.WaitGroup
wg.Add(1)

var ag *AgentSideConnection
ag = NewAgentSideConnection(agentFuncs{
InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) {
return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil
},
NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) {
return NewSessionResponse{SessionId: "test-session"}, nil
},
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
return LoadSessionResponse{}, nil
},
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
return AuthenticateResponse{}, nil
},
PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) {
defer wg.Done()

// Send multiple SessionUpdate notifications
for i := 0; i < numUpdates; i++ {
_ = ag.SessionUpdate(ctx, SessionNotification{
SessionId: p.SessionId,
Update: SessionUpdate{
AgentMessageChunk: &SessionUpdateAgentMessageChunk{
Content: TextBlock("chunk"),
},
},
})
}

// Small delay to ensure notifications are queued
time.Sleep(10 * time.Millisecond)

// Return response (this will unblock client's Prompt() call)
return PromptResponse{StopReason: "end_turn"}, nil
},
CancelFunc: func(context.Context, CancelNotification) error { return nil },
}, a2cW, c2aR)

if _, err := c.Initialize(context.Background(), InitializeRequest{ProtocolVersion: ProtocolVersionNumber}); err != nil {
t.Fatalf("initialize: %v", err)
}
sess, err := c.NewSession(context.Background(), NewSessionRequest{Cwd: "/", McpServers: []McpServer{}})
if err != nil {
t.Fatalf("newSession: %v", err)
}

_, err = c.Prompt(context.Background(), PromptRequest{
SessionId: sess.SessionId,
Prompt: []ContentBlock{TextBlock("test")},
})
if err != nil {
t.Fatalf("prompt: %v", err)
}

wg.Wait()

// Verify the expected behavior: at this point, Prompt() has returned, and all SessionUpdate
// handlers should have completed their processing.
// started := updateStarted.Load() ; Currently unsused but useful for debugging
completed := updateCompleted.Load()

// ASSERT: when Prompt() returns, all SessionUpdate notifications that were sent
// before the PromptResponse must have been fully processed. This is the semantic
// contract: the prompt operation includes all its updates.
if completed != numUpdates {
t.Fatalf("Prompt() returned with only %d/%d SessionUpdate "+
"handlers completed. Expected all handlers to complete before Prompt() "+
"returns.", completed, numUpdates)
}
}
20 changes: 19 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ type Connection struct {
cancel context.CancelCauseFunc

logger *slog.Logger

// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
// for all notifications received before the response to complete processing.
notificationWg sync.WaitGroup
}

func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
Expand Down Expand Up @@ -94,7 +98,11 @@ func (c *Connection) receive() {
case msg.ID != nil && msg.Method == "":
c.handleResponse(&msg)
case msg.Method != "":
go c.handleInbound(&msg)
c.notificationWg.Add(1)
go func(m *anyMessage) {
defer c.notificationWg.Done()
c.handleInbound(m)
}(&msg)
default:
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
}
Expand Down Expand Up @@ -193,6 +201,11 @@ func SendRequest[T any](c *Connection, ctx context.Context, method string, param
return result, err
}

// Wait for all notification handlers that were spawned before this response to complete
// processing. This ensures that when a request returns, all notifications sent by the
// server before the response have been fully processed.
c.notificationWg.Wait()

if resp.Error != nil {
return result, resp.Error
}
Expand Down Expand Up @@ -266,6 +279,11 @@ func (c *Connection) SendRequestNoResult(ctx context.Context, method string, par
return err
}

// Wait for all notification handlers that were spawned before this response to complete
// processing. This ensures that when a request returns, all notifications sent by the
// server before the response have been fully processed.
c.notificationWg.Wait()

if resp.Error != nil {
return resp.Error
}
Expand Down