Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/jsonrpc2/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var (
// server being temporarily unable to accept any new messages.
ErrServerOverloaded = NewError(-32000, "overloaded")
// ErrUnknown should be used for all non coded errors.
ErrUnknown = NewError(-32001, "unknown error")
ErrUnknown = NewError(-32099, "unknown error")
// ErrServerClosing is returned for calls that arrive while the server is closing.
ErrServerClosing = NewError(-32004, "server is closing")
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
Expand Down
54 changes: 51 additions & 3 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ type ClientOptions struct {
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
KeepAlive time.Duration
// ProtocolVersion is the version of the protocol to use.
// If empty, it defaults to the latest version.
ProtocolVersion string
// GetSessionID is the session ID to use for this client.
//
// If unset, no session ID will be used.
// Incompatible with protocol versions before 2025-11-30.
GetSessionID func() string
}

// bind implements the binder[*ClientSession] interface, so that Clients can
Expand Down Expand Up @@ -113,7 +121,11 @@ func (e unsupportedProtocolVersionError) Error() string {
}

// ClientSessionOptions is reserved for future use.
type ClientSessionOptions struct{}
type ClientSessionOptions struct {
// If Initialize is set, do initialization even when on protocol version
// 2025-11-30 or later.
Initialize bool
}

func (c *Client) capabilities() *ClientCapabilities {
caps := &ClientCapabilities{}
Expand All @@ -134,14 +146,34 @@ func (c *Client) capabilities() *ClientCapabilities {
// when it is no longer needed. However, if the connection is closed by the
// server, calls or notifications will return an error wrapping
// [ErrConnectionClosed].
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) {
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
if err != nil {
return nil, err
}

protocolVersion := c.opts.ProtocolVersion
if protocolVersion == "" {
protocolVersion = latestProtocolVersion
}

if compareProtocolVersions(protocolVersion, protocolVersion20251130) >= 0 && (opts == nil || !opts.Initialize) {
// For protocol versions >= 2025-11-30, skip the initialize handshake.
cs.state.ProtocolVersion = protocolVersion
if c.opts.GetSessionID != nil {
cs.state.SessionID = c.opts.GetSessionID()
}
if hc, ok := cs.mcpConn.(clientConnection); ok {
hc.sessionUpdated(cs.state)
}
if c.opts.KeepAlive > 0 {
cs.startKeepalive(c.opts.KeepAlive)
}
return cs, nil
}

params := &InitializeParams{
ProtocolVersion: latestProtocolVersion,
ProtocolVersion: protocolVersion,
ClientInfo: c.impl,
Capabilities: c.capabilities(),
}
Expand All @@ -155,6 +187,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
}
cs.state.InitializeResult = res
cs.state.ProtocolVersion = res.ProtocolVersion
cs.state.SessionID = res.SessionID
if hc, ok := cs.mcpConn.(clientConnection); ok {
hc.sessionUpdated(cs.state)
}
Expand Down Expand Up @@ -196,17 +230,26 @@ type ClientSession struct {

type clientSessionState struct {
InitializeResult *InitializeResult
ProtocolVersion string
SessionID string
}

func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult }

func (cs *ClientSession) ID() string {
if cs.state.SessionID != "" {
return cs.state.SessionID
}
if c, ok := cs.mcpConn.(hasSessionID); ok {
return c.SessionID()
}
return ""
}

func (cs *ClientSession) ProtocolVersion() string { return cs.state.ProtocolVersion }

func (cs *ClientSession) setProtocolVersion(v string) { cs.state.ProtocolVersion = v }

// Close performs a graceful close of the connection, preventing new requests
// from being handled, and waiting for ongoing requests to return. Close then
// terminates the connection.
Expand Down Expand Up @@ -686,6 +729,11 @@ func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNot
return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params)))
}

// Discover sends a "server/discover" request to the server and returns the result.
func (cs *ClientSession) Discover(ctx context.Context, params *DiscoverParams) (*DiscoverResult, error) {
return handleSend[*DiscoverResult](ctx, methodServerDiscover, newClientRequest(cs, orZero[Params](params)))
}

// Tools provides an iterator for all tools available on the server,
// automatically fetching pages and managing cursors.
// The params argument can set the initial cursor.
Expand Down
2 changes: 1 addition & 1 deletion mcp/mcp_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func Example_lifecycle() {
if err != nil {
log.Fatal(err)
}
clientSession, err := client.Connect(ctx, t2, nil)
clientSession, err := client.Connect(ctx, t2, &mcp.ClientSessionOptions{Initialize: true})
if err != nil {
log.Fatal(err)
}
Expand Down
7 changes: 4 additions & 3 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)
Expand Down Expand Up @@ -159,7 +160,7 @@ func TestEndToEnd(t *testing.T) {
c.AddRoots(&Root{URI: "file://" + rootAbs})

// Connect the client.
cs, err := c.Connect(ctx, ct, nil)
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{Initialize: true})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -405,7 +406,7 @@ func TestEndToEnd(t *testing.T) {
t.Fatal("timed out waiting for log messages")
}
}
if diff := cmp.Diff(want, got); diff != "" {
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(LoggingMessageParams{}, "Meta")); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}
Expand Down Expand Up @@ -760,7 +761,7 @@ func TestMiddleware(t *testing.T) {
c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2"))
c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2"))

cs, err := c.Connect(ctx, ct, nil)
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{Initialize: true})
if err != nil {
t.Fatal(err)
}
Expand Down
25 changes: 25 additions & 0 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,29 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error {
return nil
}

// DiscoverParams is sent from the client to the server to request information
// about the server's capabilities and other metadata.
type DiscoverParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
}

func (*DiscoverParams) isParams() {}

// DiscoverResult is the server's response to a server/discover request.
type DiscoverResult struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
ProtocolVersion string `json:"protocolVersion"`
ServerInfo *Implementation `json:"serverInfo"`
Capabilities *ServerCapabilities `json:"capabilities"`
Instructions string `json:"instructions,omitempty"`
}

func (*DiscoverResult) isResult() {}

type GetPromptParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Expand Down Expand Up @@ -406,6 +429,7 @@ type InitializeResult struct {
// support this version, it must disconnect.
ProtocolVersion string `json:"protocolVersion"`
ServerInfo *Implementation `json:"serverInfo"`
SessionID string `json:"sessionId,omitempty"`
}

func (*InitializeResult) isResult() {}
Expand Down Expand Up @@ -1162,4 +1186,5 @@ const (
methodSubscribe = "resources/subscribe"
notificationToolListChanged = "notifications/tools/list_changed"
methodUnsubscribe = "resources/unsubscribe"
methodServerDiscover = "server/discover"
)
74 changes: 61 additions & 13 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ type ServerOptions struct {
// even if no tools have been registered.
HasTools bool

// ProtocolVersion is the version of the protocol to use.
// If empty, it defaults to the latest version.
ProtocolVersion string

// GetSessionID provides the next session ID to use for an incoming request.
// If nil, a default randomly generated ID will be used.
//
Expand Down Expand Up @@ -980,6 +984,18 @@ func (ss *ServerSession) ID() string {
return ""
}

func (ss *ServerSession) ProtocolVersion() string {
protocolVersion := ss.server.opts.ProtocolVersion
if protocolVersion == "" {
return latestProtocolVersion
}
return protocolVersion
}

func (ss *ServerSession) setProtocolVersion(v string) {
ss.server.opts.ProtocolVersion = v
}

// Ping pings the client.
func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {
_, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params)))
Expand Down Expand Up @@ -1086,6 +1102,7 @@ var serverMethodInfos = map[string]methodInfo{
methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0),
methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0),
methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0),
methodServerDiscover: newServerMethodInfo(serverSessionMethod((*ServerSession).discover), missingParamsOK),
notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK),
notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK),
notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK),
Expand Down Expand Up @@ -1117,17 +1134,23 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
ss.mu.Lock()
initialized := ss.state.InitializeParams != nil
protocolVersion := ss.server.opts.ProtocolVersion
if protocolVersion == "" {
protocolVersion = latestProtocolVersion
}
ss.mu.Unlock()

// From the spec:
// "The client SHOULD NOT send requests other than pings before the server
// has responded to the initialize request."
switch req.Method {
case methodInitialize, methodPing, notificationInitialized:
default:
if !initialized {
ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method)
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
if compareProtocolVersions(protocolVersion, protocolVersion20251130) < 0 {
switch req.Method {
case methodInitialize, methodPing, notificationInitialized:
default:
if !initialized {
ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method)
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
}
}
}

Expand All @@ -1154,21 +1177,46 @@ func (ss *ServerSession) InitializeParams() *InitializeParams {
}

func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
if params == nil {
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
protocolVersion := ss.server.opts.ProtocolVersion
if protocolVersion == "" {
protocolVersion = latestProtocolVersion
}

// For older protocol versions, the initialize handshake is required.
if compareProtocolVersions(protocolVersion, protocolVersion20251130) < 0 {
if params == nil {
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
}
ss.updateState(func(state *ServerSessionState) {
state.InitializeParams = params
})
} else {
// For protocol versions >= 2025-11-30, the initialize handshake is optional.
// If params are provided, we process them.
if params != nil {
ss.updateState(func(state *ServerSessionState) {
state.InitializeParams = params
})
}
}
ss.updateState(func(state *ServerSessionState) {
state.InitializeParams = params
})

s := ss.server
return &InitializeResult{
// TODO(rfindley): alter behavior when falling back to an older version:
// reject unsupported features.
ProtocolVersion: negotiatedVersion(params.ProtocolVersion),
Capabilities: s.capabilities(),
Instructions: s.opts.Instructions,
ServerInfo: s.impl,
SessionID: ss.ID(),
}, nil
}

func (ss *ServerSession) discover(ctx context.Context, req *DiscoverParams) (*DiscoverResult, error) {
s := ss.server
return &DiscoverResult{
ProtocolVersion: ss.ProtocolVersion(),
ServerInfo: s.impl,
Capabilities: s.capabilities(),
Instructions: s.opts.Instructions,
}, nil
}

Expand Down
Loading