Skip to content

Commit f4f1d6a

Browse files
committed
mcp: prototype implementation for SEP 1442
This PR is the result of spending a few hours going through and attempting to implement SEP 1442 (modelcontextprotocol/modelcontextprotocol#1442). Implemented: - skipping initialization - new protocol version and session ID treatment - server/discover - unsupported version errors - new _meta fields - client capability embedding Still TODO: - client-initiated streams - ergonomic and documentation for handling capabilities from application code - many, many more tests (and corresponding bug fixes) This is very much quick and dirty, though I learned a lot about the SEP in the process. Tests pass, albeit largely because I configure the client to keep initialization, or downgrade the protocol version. Notably, server->client requests do work in the context of an uninitialized session. Additional observations will be noted in PR comments.
1 parent d297272 commit f4f1d6a

File tree

14 files changed

+386
-105
lines changed

14 files changed

+386
-105
lines changed

internal/jsonrpc2/wire.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ var (
3232
// server being temporarily unable to accept any new messages.
3333
ErrServerOverloaded = NewError(-32000, "overloaded")
3434
// ErrUnknown should be used for all non coded errors.
35-
ErrUnknown = NewError(-32001, "unknown error")
35+
ErrUnknown = NewError(-32099, "unknown error")
3636
// ErrServerClosing is returned for calls that arrive while the server is closing.
3737
ErrServerClosing = NewError(-32004, "server is closing")
3838
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.

mcp/client.go

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ type ClientOptions struct {
7777
// If the peer fails to respond to pings originating from the keepalive check,
7878
// the session is automatically closed.
7979
KeepAlive time.Duration
80+
// ProtocolVersion is the version of the protocol to use.
81+
// If empty, it defaults to the latest version.
82+
ProtocolVersion string
83+
// GetSessionID is the session ID to use for this client.
84+
//
85+
// If unset, no session ID will be used.
86+
// Incompatible with protocol versions before 2025-11-30.
87+
GetSessionID func() string
8088
}
8189

8290
// bind implements the binder[*ClientSession] interface, so that Clients can
@@ -113,7 +121,11 @@ func (e unsupportedProtocolVersionError) Error() string {
113121
}
114122

115123
// ClientSessionOptions is reserved for future use.
116-
type ClientSessionOptions struct{}
124+
type ClientSessionOptions struct {
125+
// If Initialize is set, do initialization even when on protocol version
126+
// 2025-11-30 or later.
127+
Initialize bool
128+
}
117129

118130
func (c *Client) capabilities() *ClientCapabilities {
119131
caps := &ClientCapabilities{}
@@ -134,14 +146,34 @@ func (c *Client) capabilities() *ClientCapabilities {
134146
// when it is no longer needed. However, if the connection is closed by the
135147
// server, calls or notifications will return an error wrapping
136148
// [ErrConnectionClosed].
137-
func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) {
149+
func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) {
138150
cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil)
139151
if err != nil {
140152
return nil, err
141153
}
142154

155+
protocolVersion := c.opts.ProtocolVersion
156+
if protocolVersion == "" {
157+
protocolVersion = latestProtocolVersion
158+
}
159+
160+
if compareProtocolVersions(protocolVersion, protocolVersion20251130) >= 0 && (opts == nil || !opts.Initialize) {
161+
// For protocol versions >= 2025-11-30, skip the initialize handshake.
162+
cs.state.ProtocolVersion = protocolVersion
163+
if c.opts.GetSessionID != nil {
164+
cs.state.SessionID = c.opts.GetSessionID()
165+
}
166+
if hc, ok := cs.mcpConn.(clientConnection); ok {
167+
hc.sessionUpdated(cs.state)
168+
}
169+
if c.opts.KeepAlive > 0 {
170+
cs.startKeepalive(c.opts.KeepAlive)
171+
}
172+
return cs, nil
173+
}
174+
143175
params := &InitializeParams{
144-
ProtocolVersion: latestProtocolVersion,
176+
ProtocolVersion: protocolVersion,
145177
ClientInfo: c.impl,
146178
Capabilities: c.capabilities(),
147179
}
@@ -155,6 +187,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
155187
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
156188
}
157189
cs.state.InitializeResult = res
190+
cs.state.ProtocolVersion = res.ProtocolVersion
191+
cs.state.SessionID = res.SessionID
158192
if hc, ok := cs.mcpConn.(clientConnection); ok {
159193
hc.sessionUpdated(cs.state)
160194
}
@@ -196,17 +230,26 @@ type ClientSession struct {
196230

197231
type clientSessionState struct {
198232
InitializeResult *InitializeResult
233+
ProtocolVersion string
234+
SessionID string
199235
}
200236

201237
func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult }
202238

203239
func (cs *ClientSession) ID() string {
240+
if cs.state.SessionID != "" {
241+
return cs.state.SessionID
242+
}
204243
if c, ok := cs.mcpConn.(hasSessionID); ok {
205244
return c.SessionID()
206245
}
207246
return ""
208247
}
209248

249+
func (cs *ClientSession) ProtocolVersion() string { return cs.state.ProtocolVersion }
250+
251+
func (cs *ClientSession) setProtocolVersion(v string) { cs.state.ProtocolVersion = v }
252+
210253
// Close performs a graceful close of the connection, preventing new requests
211254
// from being handled, and waiting for ongoing requests to return. Close then
212255
// terminates the connection.
@@ -686,6 +729,11 @@ func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNot
686729
return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params)))
687730
}
688731

732+
// Discover sends a "server/discover" request to the server and returns the result.
733+
func (cs *ClientSession) Discover(ctx context.Context, params *DiscoverParams) (*DiscoverResult, error) {
734+
return handleSend[*DiscoverResult](ctx, methodServerDiscover, newClientRequest(cs, orZero[Params](params)))
735+
}
736+
689737
// Tools provides an iterator for all tools available on the server,
690738
// automatically fetching pages and managing cursors.
691739
// The params argument can set the initial cursor.

mcp/mcp_example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func Example_lifecycle() {
3737
if err != nil {
3838
log.Fatal(err)
3939
}
40-
clientSession, err := client.Connect(ctx, t2, nil)
40+
clientSession, err := client.Connect(ctx, t2, &mcp.ClientSessionOptions{Initialize: true})
4141
if err != nil {
4242
log.Fatal(err)
4343
}

mcp/mcp_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"time"
2424

2525
"github.com/google/go-cmp/cmp"
26+
"github.com/google/go-cmp/cmp/cmpopts"
2627
"github.com/google/jsonschema-go/jsonschema"
2728
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2829
)
@@ -159,7 +160,7 @@ func TestEndToEnd(t *testing.T) {
159160
c.AddRoots(&Root{URI: "file://" + rootAbs})
160161

161162
// Connect the client.
162-
cs, err := c.Connect(ctx, ct, nil)
163+
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{Initialize: true})
163164
if err != nil {
164165
t.Fatal(err)
165166
}
@@ -405,7 +406,7 @@ func TestEndToEnd(t *testing.T) {
405406
t.Fatal("timed out waiting for log messages")
406407
}
407408
}
408-
if diff := cmp.Diff(want, got); diff != "" {
409+
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(LoggingMessageParams{}, "Meta")); diff != "" {
409410
t.Errorf("mismatch (-want, +got):\n%s", diff)
410411
}
411412
}
@@ -760,7 +761,7 @@ func TestMiddleware(t *testing.T) {
760761
c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2"))
761762
c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2"))
762763

763-
cs, err := c.Connect(ctx, ct, nil)
764+
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{Initialize: true})
764765
if err != nil {
765766
t.Fatal(err)
766767
}

mcp/protocol.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,29 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error {
347347
return nil
348348
}
349349

350+
// DiscoverParams is sent from the client to the server to request information
351+
// about the server's capabilities and other metadata.
352+
type DiscoverParams struct {
353+
// This property is reserved by the protocol to allow clients and servers to
354+
// attach additional metadata to their responses.
355+
Meta `json:"_meta,omitempty"`
356+
}
357+
358+
func (*DiscoverParams) isParams() {}
359+
360+
// DiscoverResult is the server's response to a server/discover request.
361+
type DiscoverResult struct {
362+
// This property is reserved by the protocol to allow clients and servers to
363+
// attach additional metadata to their responses.
364+
Meta `json:"_meta,omitempty"`
365+
ProtocolVersion string `json:"protocolVersion"`
366+
ServerInfo *Implementation `json:"serverInfo"`
367+
Capabilities *ServerCapabilities `json:"capabilities"`
368+
Instructions string `json:"instructions,omitempty"`
369+
}
370+
371+
func (*DiscoverResult) isResult() {}
372+
350373
type GetPromptParams struct {
351374
// This property is reserved by the protocol to allow clients and servers to
352375
// attach additional metadata to their responses.
@@ -406,6 +429,7 @@ type InitializeResult struct {
406429
// support this version, it must disconnect.
407430
ProtocolVersion string `json:"protocolVersion"`
408431
ServerInfo *Implementation `json:"serverInfo"`
432+
SessionID string `json:"sessionId,omitempty"`
409433
}
410434

411435
func (*InitializeResult) isResult() {}
@@ -1162,4 +1186,5 @@ const (
11621186
methodSubscribe = "resources/subscribe"
11631187
notificationToolListChanged = "notifications/tools/list_changed"
11641188
methodUnsubscribe = "resources/unsubscribe"
1189+
methodServerDiscover = "server/discover"
11651190
)

mcp/server.go

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ type ServerOptions struct {
8989
// even if no tools have been registered.
9090
HasTools bool
9191

92+
// ProtocolVersion is the version of the protocol to use.
93+
// If empty, it defaults to the latest version.
94+
ProtocolVersion string
95+
9296
// GetSessionID provides the next session ID to use for an incoming request.
9397
// If nil, a default randomly generated ID will be used.
9498
//
@@ -980,6 +984,18 @@ func (ss *ServerSession) ID() string {
980984
return ""
981985
}
982986

987+
func (ss *ServerSession) ProtocolVersion() string {
988+
protocolVersion := ss.server.opts.ProtocolVersion
989+
if protocolVersion == "" {
990+
return latestProtocolVersion
991+
}
992+
return protocolVersion
993+
}
994+
995+
func (ss *ServerSession) setProtocolVersion(v string) {
996+
ss.server.opts.ProtocolVersion = v
997+
}
998+
983999
// Ping pings the client.
9841000
func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {
9851001
_, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params)))
@@ -1086,6 +1102,7 @@ var serverMethodInfos = map[string]methodInfo{
10861102
methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0),
10871103
methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0),
10881104
methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0),
1105+
methodServerDiscover: newServerMethodInfo(serverSessionMethod((*ServerSession).discover), missingParamsOK),
10891106
notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK),
10901107
notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK),
10911108
notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK),
@@ -1117,17 +1134,23 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
11171134
func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
11181135
ss.mu.Lock()
11191136
initialized := ss.state.InitializeParams != nil
1137+
protocolVersion := ss.server.opts.ProtocolVersion
1138+
if protocolVersion == "" {
1139+
protocolVersion = latestProtocolVersion
1140+
}
11201141
ss.mu.Unlock()
11211142

11221143
// From the spec:
11231144
// "The client SHOULD NOT send requests other than pings before the server
11241145
// has responded to the initialize request."
1125-
switch req.Method {
1126-
case methodInitialize, methodPing, notificationInitialized:
1127-
default:
1128-
if !initialized {
1129-
ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method)
1130-
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
1146+
if compareProtocolVersions(protocolVersion, protocolVersion20251130) < 0 {
1147+
switch req.Method {
1148+
case methodInitialize, methodPing, notificationInitialized:
1149+
default:
1150+
if !initialized {
1151+
ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method)
1152+
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
1153+
}
11311154
}
11321155
}
11331156

@@ -1154,21 +1177,46 @@ func (ss *ServerSession) InitializeParams() *InitializeParams {
11541177
}
11551178

11561179
func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
1157-
if params == nil {
1158-
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
1180+
protocolVersion := ss.server.opts.ProtocolVersion
1181+
if protocolVersion == "" {
1182+
protocolVersion = latestProtocolVersion
1183+
}
1184+
1185+
// For older protocol versions, the initialize handshake is required.
1186+
if compareProtocolVersions(protocolVersion, protocolVersion20251130) < 0 {
1187+
if params == nil {
1188+
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
1189+
}
1190+
ss.updateState(func(state *ServerSessionState) {
1191+
state.InitializeParams = params
1192+
})
1193+
} else {
1194+
// For protocol versions >= 2025-11-30, the initialize handshake is optional.
1195+
// If params are provided, we process them.
1196+
if params != nil {
1197+
ss.updateState(func(state *ServerSessionState) {
1198+
state.InitializeParams = params
1199+
})
1200+
}
11591201
}
1160-
ss.updateState(func(state *ServerSessionState) {
1161-
state.InitializeParams = params
1162-
})
11631202

11641203
s := ss.server
11651204
return &InitializeResult{
1166-
// TODO(rfindley): alter behavior when falling back to an older version:
1167-
// reject unsupported features.
11681205
ProtocolVersion: negotiatedVersion(params.ProtocolVersion),
11691206
Capabilities: s.capabilities(),
11701207
Instructions: s.opts.Instructions,
11711208
ServerInfo: s.impl,
1209+
SessionID: ss.ID(),
1210+
}, nil
1211+
}
1212+
1213+
func (ss *ServerSession) discover(ctx context.Context, req *DiscoverParams) (*DiscoverResult, error) {
1214+
s := ss.server
1215+
return &DiscoverResult{
1216+
ProtocolVersion: ss.ProtocolVersion(),
1217+
ServerInfo: s.impl,
1218+
Capabilities: s.capabilities(),
1219+
Instructions: s.opts.Instructions,
11721220
}, nil
11731221
}
11741222

0 commit comments

Comments
 (0)