Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,14 @@ func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNot
return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params)))
}

// SendNotification sends a custom notification from the client to the server
// associated with this session. This is typically used for protocol extensions
// that send arbitrary JSON-RPC notifications.
func (cs *ClientSession) SendNotification(ctx context.Context, method string, params any) error {
cp := &customNotificationParams{payload: params}
return handleNotify(ctx, "x-notifications/"+method, newClientRequest(cs, Params(cp)))
}

// Tools provides an iterator for all tools available on the server,
// automatically fetching pages and managing cursors.
// The params argument can set the initial cursor.
Expand Down
75 changes: 75 additions & 0 deletions mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/jsonschema-go/jsonschema"
"encoding/json"
"time"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

type Item struct {
Expand Down Expand Up @@ -543,3 +547,74 @@ func TestClientCapabilitiesOverWire(t *testing.T) {
})
}
}

func TestClientSendNotification(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

client := NewClient(&Implementation{Name: "testClient", Version: "1.0.0"}, nil)
cTrans, sTrans := NewInMemoryTransports()

// Create channels to capture notifications
notifCh := make(chan *jsonrpc.Request, 1) // Using jsonrpc.Request since decode returns that

// Intercept server transport to capture notifications
sConn, err := sTrans.Connect(ctx)
if err != nil {
t.Fatal(err)
}
go func() {
for {
msg, err := sConn.Read(ctx)
if err != nil {
return
}
if req, ok := msg.(*jsonrpc.Request); ok {
if req.Method == "initialize" {
resp, _ := jsonrpc2.NewResponse(req.ID, InitializeResult{
ProtocolVersion: "2024-11-05",
ServerInfo: &Implementation{Name: "testServer", Version: "1.0.0"},
}, nil)
sConn.Write(ctx, resp)
} else if req.Method != "notifications/initialized" && !req.IsCall() {
notifCh <- req
}
}
}
}()

cs, err := client.Connect(ctx, cTrans, nil)
if err != nil {
t.Fatal(err)
}

// Send a custom notification
type myParams struct {
Key string `json:"key"`
Val int `json:"val"`
}

err = cs.SendNotification(ctx, "custom/myNotif", myParams{Key: "hello", Val: 42})
if err != nil {
t.Fatalf("SendNotification failed: %v", err)
}

// Wait for the notification to be received
select {
case req := <-notifCh:
if req.Method != "custom/myNotif" {
t.Errorf("got method %q, want %q", req.Method, "custom/myNotif")
}

var gotParams myParams
if err := json.Unmarshal(req.Params, &gotParams); err != nil {
t.Fatalf("failed to unmarshal params: %v", err)
}

if gotParams.Key != "hello" || gotParams.Val != 42 {
t.Errorf("got params %+v, want {Key: hello, Val: 42}", gotParams)
}
case <-time.After(time.Second * 2):
t.Fatal("timeout waiting for notification")
}
}
8 changes: 8 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,14 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params)))
}

// SendNotification sends a custom notification from the server to the client
// associated with this session. This is typically used for protocol extensions
// that send arbitrary JSON-RPC notifications.
func (ss *ServerSession) SendNotification(ctx context.Context, method string, params any) error {
cp := &customNotificationParams{payload: params}
return handleNotify(ctx, "x-notifications/"+method, newServerRequest(ss, Params(cp)))
}

func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
return &ServerRequest[P]{Session: ss, Params: params}
}
Expand Down
65 changes: 65 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

type testItem struct {
Expand Down Expand Up @@ -1001,3 +1002,67 @@ func TestServerCapabilitiesOverWire(t *testing.T) {
})
}
}

func TestSendNotification(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s := NewServer(testImpl, nil)
cTrans, sTrans := NewInMemoryTransports()

// Create channels to capture notifications
notifCh := make(chan *jsonrpc2.Request, 1)

// Intercept client transport to capture notifications
cConn, err := cTrans.Connect(ctx)
if err != nil {
t.Fatal(err)
}
go func() {
for {
msg, err := cConn.Read(ctx)
if err != nil {
return
}
if req, ok := msg.(*jsonrpc.Request); ok && !req.IsCall() {
// Capture notifications (requests without ID)
notifCh <- req
}
}
}()

ss, err := s.Connect(ctx, sTrans, nil)
if err != nil {
t.Fatal(err)
}

// Send a custom notification
type diffParams struct {
File string `json:"file"`
Hunk int `json:"hunk"`
}

err = ss.SendNotification(ctx, "ide/diffAccepted", diffParams{File: "main.go", Hunk: 2})
if err != nil {
t.Fatalf("SendNotification failed: %v", err)
}

// Wait for the notification to be received
select {
case req := <-notifCh:
if req.Method != "ide/diffAccepted" {
t.Errorf("got method %q, want %q", req.Method, "ide/diffAccepted")
}

var gotParams diffParams
if err := json.Unmarshal(req.Params, &gotParams); err != nil {
t.Fatalf("failed to unmarshal params: %v", err)
}

if gotParams.File != "main.go" || gotParams.Hunk != 2 {
t.Errorf("got params %+v, want {File: main.go, Hunk: 2}", gotParams)
}
case <-time.After(time.Second):
t.Fatal("timeout waiting for notification")
}
}
25 changes: 25 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ func addMiddleware(handlerp *MethodHandler, middleware []Middleware) {
}

func defaultSendingMethodHandler(ctx context.Context, method string, req Request) (Result, error) {
// Custom notifications from SendNotification are prefixed with x-notifications/
if strings.HasPrefix(method, "x-notifications/") {
actualMethod := strings.TrimPrefix(method, "x-notifications/")
return nil, req.GetSession().getConn().Notify(ctx, actualMethod, req.GetParams())
}

info, ok := req.GetSession().sendingMethodInfos()[method]
if !ok {
// This can be called from user code, with an arbitrary value for method.
Expand Down Expand Up @@ -244,6 +250,25 @@ const (
missingParamsOK // params may be missing or null
)

// customNotificationParams wraps arbitrary payload parameters so they can pass
// through the SDK middleware as Params while satisfying the internal interface.
type customNotificationParams struct {
payload any
}

func (c *customNotificationParams) GetMeta() map[string]any { return nil }
func (c *customNotificationParams) SetMeta(map[string]any) {}
func (c *customNotificationParams) isParams() {}

// MarshalJSON delegates JSON marshaling to the wrapped payload payload.
// If payload is nil, it marshals to an empty object "{}".
func (c customNotificationParams) MarshalJSON() ([]byte, error) {
if c.payload == nil {
return []byte("{}"), nil
}
return json.Marshal(c.payload)
}

func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo {
mi := newMethodInfo[P, R](flags)
mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request {
Expand Down