diff --git a/mcp/client.go b/mcp/client.go index 74900b1c..04c06cfd 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -1106,6 +1106,14 @@ func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNot return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params))) } +// SendNotification sends a custom notification from the client to the server +// associated with this session. This is typically used for protocol extensions +// that send arbitrary JSON-RPC notifications. +func (cs *ClientSession) SendNotification(ctx context.Context, method string, params any) error { + cp := &customNotificationParams{payload: params} + return handleNotify(ctx, "x-notifications/"+method, newClientRequest(cs, Params(cp))) +} + // Tools provides an iterator for all tools available on the server, // automatically fetching pages and managing cursors. // The params argument can set the initial cursor. diff --git a/mcp/client_test.go b/mcp/client_test.go index fc37c3eb..52c0cbcc 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -13,6 +13,10 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "encoding/json" + "time" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type Item struct { @@ -543,3 +547,74 @@ func TestClientCapabilitiesOverWire(t *testing.T) { }) } } + +func TestClientSendNotification(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := NewClient(&Implementation{Name: "testClient", Version: "1.0.0"}, nil) + cTrans, sTrans := NewInMemoryTransports() + + // Create channels to capture notifications + notifCh := make(chan *jsonrpc.Request, 1) // Using jsonrpc.Request since decode returns that + + // Intercept server transport to capture notifications + sConn, err := sTrans.Connect(ctx) + if err != nil { + t.Fatal(err) + } + go func() { + for { + msg, err := sConn.Read(ctx) + if err != nil { + return + } + if req, ok := msg.(*jsonrpc.Request); ok { + if req.Method == "initialize" { + resp, _ := jsonrpc2.NewResponse(req.ID, InitializeResult{ + ProtocolVersion: "2024-11-05", + ServerInfo: &Implementation{Name: "testServer", Version: "1.0.0"}, + }, nil) + sConn.Write(ctx, resp) + } else if req.Method != "notifications/initialized" && !req.IsCall() { + notifCh <- req + } + } + } + }() + + cs, err := client.Connect(ctx, cTrans, nil) + if err != nil { + t.Fatal(err) + } + + // Send a custom notification + type myParams struct { + Key string `json:"key"` + Val int `json:"val"` + } + + err = cs.SendNotification(ctx, "custom/myNotif", myParams{Key: "hello", Val: 42}) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + // Wait for the notification to be received + select { + case req := <-notifCh: + if req.Method != "custom/myNotif" { + t.Errorf("got method %q, want %q", req.Method, "custom/myNotif") + } + + var gotParams myParams + if err := json.Unmarshal(req.Params, &gotParams); err != nil { + t.Fatalf("failed to unmarshal params: %v", err) + } + + if gotParams.Key != "hello" || gotParams.Val != 42 { + t.Errorf("got params %+v, want {Key: hello, Val: 42}", gotParams) + } + case <-time.After(time.Second * 2): + t.Fatal("timeout waiting for notification") + } +} diff --git a/mcp/server.go b/mcp/server.go index e3c03e27..a44d5867 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1079,6 +1079,14 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) } +// SendNotification sends a custom notification from the server to the client +// associated with this session. This is typically used for protocol extensions +// that send arbitrary JSON-RPC notifications. +func (ss *ServerSession) SendNotification(ctx context.Context, method string, params any) error { + cp := &customNotificationParams{payload: params} + return handleNotify(ctx, "x-notifications/"+method, newServerRequest(ss, Params(cp))) +} + func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { return &ServerRequest[P]{Session: ss, Params: params} } diff --git a/mcp/server_test.go b/mcp/server_test.go index e57af1e2..3219d71c 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type testItem struct { @@ -1001,3 +1002,67 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +func TestSendNotification(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := NewServer(testImpl, nil) + cTrans, sTrans := NewInMemoryTransports() + + // Create channels to capture notifications + notifCh := make(chan *jsonrpc2.Request, 1) + + // Intercept client transport to capture notifications + cConn, err := cTrans.Connect(ctx) + if err != nil { + t.Fatal(err) + } + go func() { + for { + msg, err := cConn.Read(ctx) + if err != nil { + return + } + if req, ok := msg.(*jsonrpc.Request); ok && !req.IsCall() { + // Capture notifications (requests without ID) + notifCh <- req + } + } + }() + + ss, err := s.Connect(ctx, sTrans, nil) + if err != nil { + t.Fatal(err) + } + + // Send a custom notification + type diffParams struct { + File string `json:"file"` + Hunk int `json:"hunk"` + } + + err = ss.SendNotification(ctx, "ide/diffAccepted", diffParams{File: "main.go", Hunk: 2}) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + // Wait for the notification to be received + select { + case req := <-notifCh: + if req.Method != "ide/diffAccepted" { + t.Errorf("got method %q, want %q", req.Method, "ide/diffAccepted") + } + + var gotParams diffParams + if err := json.Unmarshal(req.Params, &gotParams); err != nil { + t.Fatalf("failed to unmarshal params: %v", err) + } + + if gotParams.File != "main.go" || gotParams.Hunk != 2 { + t.Errorf("got params %+v, want {File: main.go, Hunk: 2}", gotParams) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for notification") + } +} diff --git a/mcp/shared.go b/mcp/shared.go index f8c92502..667d6d74 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -90,6 +90,12 @@ func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { } func defaultSendingMethodHandler(ctx context.Context, method string, req Request) (Result, error) { + // Custom notifications from SendNotification are prefixed with x-notifications/ + if strings.HasPrefix(method, "x-notifications/") { + actualMethod := strings.TrimPrefix(method, "x-notifications/") + return nil, req.GetSession().getConn().Notify(ctx, actualMethod, req.GetParams()) + } + info, ok := req.GetSession().sendingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. @@ -244,6 +250,25 @@ const ( missingParamsOK // params may be missing or null ) +// customNotificationParams wraps arbitrary payload parameters so they can pass +// through the SDK middleware as Params while satisfying the internal interface. +type customNotificationParams struct { + payload any +} + +func (c *customNotificationParams) GetMeta() map[string]any { return nil } +func (c *customNotificationParams) SetMeta(map[string]any) {} +func (c *customNotificationParams) isParams() {} + +// MarshalJSON delegates JSON marshaling to the wrapped payload payload. +// If payload is nil, it marshals to an empty object "{}". +func (c customNotificationParams) MarshalJSON() ([]byte, error) { + if c.payload == nil { + return []byte("{}"), nil + } + return json.Marshal(c.payload) +} + func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request {