From 1d2825c4f99d66dcb5c7ce0f220a995ac49c5ed5 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Tue, 17 Feb 2026 13:11:09 +0000 Subject: [PATCH] mcp: simplify and unify unit tests introduced for sampling with tools. --- mcp/protocol_test.go | 581 ++++++++++++++++++++++++------------------- mcp/sampling_test.go | 216 +++++----------- 2 files changed, 394 insertions(+), 403 deletions(-) diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index b33e5f50..20f2bb87 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -7,7 +7,6 @@ package mcp import ( "encoding/json" "maps" - "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -489,11 +488,12 @@ func TestCompleteResult(t *testing.T) { } } +// TODO: merge the following 4 tests into content_test.go. func TestToolUseContent_MarshalJSON(t *testing.T) { tests := []struct { name string content *ToolUseContent - want *ToolUseContent + want string }{ { name: "basic tool use", @@ -506,15 +506,7 @@ func TestToolUseContent_MarshalJSON(t *testing.T) { "y": 2.0, }, }, - want: &ToolUseContent{ - ID: "tool_123", - Name: "calculator", - Input: map[string]any{ - "operation": "add", - "x": 1.0, - "y": 2.0, - }, - }, + want: `{"type":"tool_use","id":"tool_123","name":"calculator","input":{"operation":"add","x":1,"y":2}}`, }, { name: "nil input marshals as empty object", @@ -523,29 +515,17 @@ func TestToolUseContent_MarshalJSON(t *testing.T) { Name: "no_args_tool", Input: nil, }, - want: &ToolUseContent{ - ID: "tool_456", - Name: "no_args_tool", - Input: map[string]any{}, - }, + want: `{"type":"tool_use","id":"tool_456","name":"no_args_tool","input":{}}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - data, err := tt.content.MarshalJSON() + got, err := tt.content.MarshalJSON() if err != nil { t.Fatalf("MarshalJSON() error = %v", err) } - wire := &wireContent{} - if err := json.Unmarshal(data, wire); err != nil { - t.Fatalf("Unmarshal wire error = %v", err) - } - got, err := contentFromWire(wire, map[string]bool{"tool_use": true}) - if err != nil { - t.Fatalf("contentFromWire() error = %v", err) - } - if diff := cmp.Diff(tt.want, got); diff != "" { + if diff := cmp.Diff(tt.want, string(got)); diff != "" { t.Errorf("mismatch (-want +got):\n%s", diff) } }) @@ -635,11 +615,11 @@ func TestToolResultContent_MarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - data, err := tt.content.MarshalJSON() + got, err := tt.content.MarshalJSON() if err != nil { t.Fatalf("MarshalJSON() error = %v", err) } - if diff := cmp.Diff(tt.want, string(data)); diff != "" { + if diff := cmp.Diff(tt.want, string(got)); diff != "" { t.Errorf("mismatch (-want +got):\n%s", diff) } }) @@ -648,44 +628,25 @@ func TestToolResultContent_MarshalJSON(t *testing.T) { func TestToolResultContent_UnmarshalJSON(t *testing.T) { tests := []struct { - name string - json string - wantID string - wantErr bool - checkFn func(t *testing.T, got *ToolResultContent) + name string + json string + want *ToolResultContent }{ { - name: "basic tool result", - json: `{"type":"tool_result","toolUseId":"tool_123","content":[{"type":"text","text":"42"}],"isError":false}`, - wantID: "tool_123", - checkFn: func(t *testing.T, got *ToolResultContent) { - if len(got.Content) != 1 { - t.Fatalf("len(Content) = %d, want 1", len(got.Content)) - } - tc, ok := got.Content[0].(*TextContent) - if !ok { - t.Fatalf("Content[0] type = %T, want *TextContent", got.Content[0]) - } - if tc.Text != "42" { - t.Errorf("Text = %v, want 42", tc.Text) - } + name: "basic tool result", + json: `{"type":"tool_result","toolUseId":"tool_123","content":[{"type":"text","text":"42"}],"isError":false}`, + want: &ToolResultContent{ + ToolUseID: "tool_123", + Content: []Content{&TextContent{Text: "42"}}, + IsError: false, }, }, { - name: "image nested content", - json: `{"type":"tool_result","toolUseId":"t1","content":[{"type":"image","mimeType":"image/png","data":"YWJj"}]}`, - wantID: "t1", - checkFn: func(t *testing.T, got *ToolResultContent) { - if len(got.Content) != 1 { - t.Fatalf("len(Content) = %d, want 1", len(got.Content)) - } - img, ok := got.Content[0].(*ImageContent) - if !ok { - t.Fatalf("Content[0] type = %T, want *ImageContent", got.Content[0]) - } - if img.MIMEType != "image/png" { - t.Errorf("MIMEType = %v, want image/png", img.MIMEType) - } + name: "image nested content", + json: `{"type":"tool_result","toolUseId":"t1","content":[{"type":"image","mimeType":"image/png","data":"YWJj"}]}`, + want: &ToolResultContent{ + ToolUseID: "t1", + Content: []Content{&ImageContent{MIMEType: "image/png", Data: []byte("abc")}}, }, }, } @@ -697,21 +658,11 @@ func TestToolResultContent_UnmarshalJSON(t *testing.T) { t.Fatalf("Unmarshal() error = %v", err) } got, err := contentFromWire(wire, map[string]bool{"tool_result": true}) - if (err != nil) != tt.wantErr { - t.Fatalf("contentFromWire() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - return - } - result, ok := got.(*ToolResultContent) - if !ok { - t.Fatalf("type = %T, want *ToolResultContent", got) - } - if result.ToolUseID != tt.wantID { - t.Errorf("ToolUseID = %v, want %v", result.ToolUseID, tt.wantID) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) } - if tt.checkFn != nil { - tt.checkFn(t, result) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } }) } @@ -719,79 +670,86 @@ func TestToolResultContent_UnmarshalJSON(t *testing.T) { func TestSamplingMessage_UnmarshalJSON(t *testing.T) { tests := []struct { - name string - json string - wantRole Role - wantType string // expected Content type name + name string + json string + want *SamplingMessage }{ { - name: "tool_use content", - json: `{"content":{"type":"tool_use","id":"tool_1","name":"calc","input":{}},"role":"assistant"}`, - wantRole: "assistant", - wantType: "*mcp.ToolUseContent", + name: "tool_use content", + json: `{"content":{"type":"tool_use","id":"tool_1","name":"calc","input":{}},"role":"assistant"}`, + want: &SamplingMessage{ + Role: "assistant", + Content: &ToolUseContent{ + ID: "tool_1", + Name: "calc", + Input: map[string]any{}, + }, + }, }, { - name: "tool_result content", - json: `{"content":{"type":"tool_result","toolUseId":"tool_1","content":[{"type":"text","text":"42"}]},"role":"user"}`, - wantRole: "user", - wantType: "*mcp.ToolResultContent", + name: "tool_result content", + json: `{"content":{"type":"tool_result","toolUseId":"tool_1","content":[{"type":"text","text":"42"}]},"role":"user"}`, + want: &SamplingMessage{ + Role: "user", + Content: &ToolResultContent{ + ToolUseID: "tool_1", + Content: []Content{&TextContent{Text: "42"}}, + }, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var msg SamplingMessage - if err := json.Unmarshal([]byte(tt.json), &msg); err != nil { + var got SamplingMessage + if err := json.Unmarshal([]byte(tt.json), &got); err != nil { t.Fatalf("Unmarshal() error = %v", err) } - if msg.Role != tt.wantRole { - t.Errorf("Role = %v, want %v", msg.Role, tt.wantRole) - } - gotType := reflect.TypeOf(msg.Content).String() - if gotType != tt.wantType { - t.Errorf("Content type = %v, want %v", gotType, tt.wantType) + if diff := cmp.Diff(tt.want, &got); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } }) } } -func TestSamplingCapabilities_MarshalJSON(t *testing.T) { - caps := &SamplingCapabilities{ - Tools: &SamplingToolsCapabilities{}, - Context: &SamplingContextCapabilities{}, - } - data, err := json.Marshal(caps) - if err != nil { - t.Fatalf("Marshal() error = %v", err) - } - var got SamplingCapabilities - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("Unmarshal() error = %v", err) - } - if got.Tools == nil { - t.Error("Tools capability should not be nil") - } - if got.Context == nil { - t.Error("Context capability should not be nil") +func TestSamplingCapabilities_MarshalUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + caps *SamplingCapabilities + json string + }{ + { + name: "WithCapabilities", + caps: &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, + Context: &SamplingContextCapabilities{}, + }, + json: `{"context":{},"tools":{}}`, + }, + { + name: "EmptyCapabilities", + caps: &SamplingCapabilities{}, + json: `{}`, + }, } -} -func TestSamplingCapabilities_UnmarshalJSON(t *testing.T) { - // Empty struct should marshal/unmarshal correctly (backward compatibility). - caps := &SamplingCapabilities{} - data, err := json.Marshal(caps) - if err != nil { - t.Fatalf("Marshal() error = %v", err) - } - var got SamplingCapabilities - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("Unmarshal() error = %v", err) - } - if got.Tools != nil { - t.Error("Tools capability should be nil for empty capabilities") - } - if got.Context != nil { - t.Error("Context capability should be nil for empty capabilities") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotJson, err := json.Marshal(tt.caps) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if diff := cmp.Diff(tt.json, string(gotJson)); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + var gotCaps SamplingCapabilities + if err := json.Unmarshal([]byte(tt.json), &gotCaps); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if diff := cmp.Diff(tt.caps, &gotCaps); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) } } @@ -830,19 +788,13 @@ func TestCreateMessageWithToolsParams_MarshalUnmarshalJSON(t *testing.T) { t.Fatalf("Unmarshal() error = %v", err) } - if len(got.Tools) != 1 { - t.Fatalf("len(Tools) = %v, want 1", len(got.Tools)) - } - if got.Tools[0].Name != "calculator" { - t.Errorf("Tools[0].Name = %v, want calculator", got.Tools[0].Name) - } - if got.ToolChoice == nil || got.ToolChoice.Mode != "auto" { - t.Errorf("ToolChoice = %v, want {Mode: auto}", got.ToolChoice) + if diff := cmp.Diff(params, &got); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } } func TestToolChoice_MarshalUnmarshalJSON(t *testing.T) { - choices := []*ToolChoice{ + choices := []ToolChoice{ {Mode: "auto"}, {Mode: "required"}, {Mode: "none"}, @@ -857,39 +809,60 @@ func TestToolChoice_MarshalUnmarshalJSON(t *testing.T) { if err := json.Unmarshal(data, &got); err != nil { t.Fatalf("Unmarshal() error = %v", err) } - if diff := cmp.Diff(*tc, got); diff != "" { + if diff := cmp.Diff(tc, got); diff != "" { t.Errorf("mismatch (-want +got):\n%s", diff) } } } func TestCreateMessageWithToolsResult_MarshalJSON(t *testing.T) { - // Single-element Content marshals as object (not array) for backward compat. - result := &CreateMessageWithToolsResult{ - Model: "test", - Role: "assistant", - Content: []Content{&TextContent{Text: "hello"}}, - } - - data, err := json.Marshal(result) - if err != nil { - t.Fatalf("Marshal() error = %v", err) - } - - var raw map[string]json.RawMessage - if err := json.Unmarshal(data, &raw); err != nil { - t.Fatalf("Unmarshal raw error = %v", err) + tests := []struct { + name string + result *CreateMessageWithToolsResult + want string + }{ + { + name: "single element content", + result: &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + Content: []Content{&TextContent{Text: "hello"}}, + }, + // Single-element Content marshals as object (not array) for backward compat. + want: `{"content":{"type":"text","text":"hello"},"model":"test","role":"assistant"}`, + }, + { + name: "multiple elements content", + result: &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + Content: []Content{ + &TextContent{Text: "thinking..."}, + &ToolUseContent{ + ID: "call_1", + Name: "calculator", + Input: map[string]any{ + "a": 1.0, + "b": 2.0, + }, + }, + }, + }, + // Multiple elements marshal as array. + want: `{"content":[{"type":"text","text":"thinking..."},{"type":"tool_use","id":"call_1","name":"calculator","input":{"a":1,"b":2}}],"model":"test","role":"assistant"}`, + }, } - content := raw["content"] - for _, b := range content { - if b == ' ' || b == '\t' || b == '\n' || b == '\r' { - continue - } - if b == '[' { - t.Errorf("single-element Content marshaled as array, want object") - } - break + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.result) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if diff := cmp.Diff(tt.want, string(got)); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) } } @@ -897,23 +870,56 @@ func TestCreateMessageWithToolsResult_UnmarshalJSON(t *testing.T) { tests := []struct { name string json string - wantLen int + want *CreateMessageWithToolsResult wantErr bool }{ { - name: "single tool_use content", - json: `{"content":{"type":"tool_use","id":"tool_1","name":"calculator","input":{"x":1}},"model":"test-model","role":"assistant","stopReason":"toolUse"}`, - wantLen: 1, + name: "single tool_use content", + json: `{"content":{"type":"tool_use","id":"tool_1","name":"calculator","input":{"x":1}},"model":"test-model","role":"assistant","stopReason":"toolUse"}`, + want: &CreateMessageWithToolsResult{ + Model: "test-model", + Role: "assistant", + StopReason: "toolUse", + Content: []Content{ + &ToolUseContent{ + ID: "tool_1", + Name: "calculator", + Input: map[string]any{ + "x": 1.0, + }, + }, + }, + }, }, { - name: "array of tool_use content", - json: `{"content":[{"type":"tool_use","id":"t1","name":"calc","input":{"x":1}},{"type":"tool_use","id":"t2","name":"search","input":{"q":"hi"}}],"model":"test","role":"assistant","stopReason":"toolUse"}`, - wantLen: 2, + name: "array of tool_use content", + json: `{"content":[{"type":"tool_use","id":"t1","name":"calc","input":{"x":1}},{"type":"tool_use","id":"t2","name":"search","input":{"q":"hi"}}],"model":"test","role":"assistant","stopReason":"toolUse"}`, + want: &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + StopReason: "toolUse", + Content: []Content{ + &ToolUseContent{ + ID: "t1", + Name: "calc", + Input: map[string]any{"x": 1.0}, + }, + &ToolUseContent{ + ID: "t2", + Name: "search", + Input: map[string]any{"q": "hi"}, + }, + }, + }, }, { - name: "empty array", - json: `{"content":[],"model":"m","role":"assistant"}`, - wantLen: 0, + name: "empty array", + json: `{"content":[],"model":"m","role":"assistant"}`, + want: &CreateMessageWithToolsResult{ + Model: "m", + Role: "assistant", + Content: []Content{}, + }, }, { name: "null content", @@ -937,100 +943,165 @@ func TestCreateMessageWithToolsResult_UnmarshalJSON(t *testing.T) { if tt.wantErr { return } - if len(got.Content) != tt.wantLen { - t.Errorf("len(Content) = %d, want %d", len(got.Content), tt.wantLen) + if diff := cmp.Diff(tt.want, &got); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } }) } } func TestSamplingMessageV2_MarshalJSON(t *testing.T) { - msg := &SamplingMessageV2{ - Role: "user", - Content: []Content{}, - } - data, err := json.Marshal(msg) - if err != nil { - t.Fatalf("Marshal() error = %v", err) - } - var got SamplingMessageV2 - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("Unmarshal() error = %v", err) + tests := []struct { + name string + msg *SamplingMessageV2 + want string + }{ + { + name: "empty content", + msg: &SamplingMessageV2{ + Role: "user", + Content: []Content{}, + }, + want: `{"content":[],"role":"user"}`, + }, + { + name: "single content", + msg: &SamplingMessageV2{ + Role: "assistant", + Content: []Content{&TextContent{Text: "hello"}}, + }, + want: `{"content":{"type":"text","text":"hello"},"role":"assistant"}`, + }, + { + name: "multiple content", + msg: &SamplingMessageV2{ + Role: "assistant", + // Text + tool_use in the same message (valid per spec for assistant). + Content: []Content{ + &TextContent{Text: "checking weather"}, + &ToolUseContent{ID: "c1", Name: "weather", Input: map[string]any{"city": "SF"}}, + }, + }, + want: `{"content":[{"type":"text","text":"checking weather"},{"type":"tool_use","id":"c1","name":"weather","input":{"city":"SF"}}],"role":"assistant"}`, + }, } - if len(got.Content) != 0 { - t.Errorf("len(Content) = %d, want 0", len(got.Content)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.msg) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if diff := cmp.Diff(tt.want, string(got)); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) } } func TestSamplingMessageV2_UnmarshalJSON(t *testing.T) { - // Text + tool_use in the same message (valid per spec for assistant). - msg := &SamplingMessageV2{ - Role: "assistant", - Content: []Content{ - &TextContent{Text: "Let me check the weather."}, - &ToolUseContent{ID: "c1", Name: "weather", Input: map[string]any{"city": "SF"}}, + tests := []struct { + name string + json string + want *SamplingMessageV2 + }{ + { + name: "single content object", + json: `{"role":"user","content":{"type":"text","text":"hello"}}`, + want: &SamplingMessageV2{ + Role: "user", + Content: []Content{ + &TextContent{Text: "hello"}, + }, + }, + }, + { + name: "multiple content", + json: `{"role":"assistant","content":[{"type":"text","text":"Let me check the weather."},{"type":"tool_use","id":"c1","name":"weather","input":{"city":"SF"}}]}`, + want: &SamplingMessageV2{ + Role: "assistant", + // Text + tool_use in the same message (valid per spec for assistant). + Content: []Content{ + &TextContent{Text: "Let me check the weather."}, + &ToolUseContent{ID: "c1", Name: "weather", Input: map[string]any{"city": "SF"}}, + }, + }, }, } - data, err := json.Marshal(msg) - if err != nil { - t.Fatalf("Marshal() error = %v", err) - } - var got SamplingMessageV2 - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("Unmarshal() error = %v", err) - } - if len(got.Content) != 2 { - t.Fatalf("len(Content) = %d, want 2", len(got.Content)) - } - if _, ok := got.Content[0].(*TextContent); !ok { - t.Errorf("Content[0] type = %T, want *TextContent", got.Content[0]) - } - if _, ok := got.Content[1].(*ToolUseContent); !ok { - t.Errorf("Content[1] type = %T, want *ToolUseContent", got.Content[1]) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got SamplingMessageV2 + if err := json.Unmarshal([]byte(tt.json), &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if diff := cmp.Diff(tt.want, &got); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) } } func TestToBase_Conversion(t *testing.T) { - // Single-content messages convert successfully. - params := &CreateMessageWithToolsParams{ - MaxTokens: 1000, - Messages: []*SamplingMessageV2{ - {Role: "user", Content: []Content{&TextContent{Text: "hello"}}}, + tests := []struct { + name string + params *CreateMessageWithToolsParams + want *CreateMessageParams + wantErr bool + }{ + { + name: "Single content", + params: &CreateMessageWithToolsParams{ + MaxTokens: 1000, + Messages: []*SamplingMessageV2{ + {Role: "user", Content: []Content{&TextContent{Text: "hello"}}}, + }, + Tools: []*Tool{{Name: "calc"}}, + ToolChoice: &ToolChoice{Mode: "auto"}, + }, + want: &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + {Role: "user", Content: &TextContent{Text: "hello"}}, + }, + }, }, - Tools: []*Tool{{Name: "calc"}}, - ToolChoice: &ToolChoice{Mode: "auto"}, - } - base, err := params.toBase() - if err != nil { - t.Fatalf("toBase() error = %v", err) - } - if base.MaxTokens != 1000 { - t.Errorf("MaxTokens = %d, want 1000", base.MaxTokens) - } - if tc, ok := base.Messages[0].Content.(*TextContent); !ok || tc.Text != "hello" { - t.Errorf("Messages[0].Content = %v, want TextContent{hello}", base.Messages[0].Content) - } - - // Multi-content messages return an error. - params2 := &CreateMessageWithToolsParams{ - MaxTokens: 1000, - Messages: []*SamplingMessageV2{ - {Role: "assistant", Content: []Content{ - &ToolUseContent{ID: "c1", Name: "calc", Input: map[string]any{}}, - &ToolUseContent{ID: "c2", Name: "search", Input: map[string]any{}}, - }}, + { + name: "Multi content", + params: &CreateMessageWithToolsParams{ + MaxTokens: 1000, + Messages: []*SamplingMessageV2{ + {Role: "assistant", Content: []Content{ + &ToolUseContent{ID: "c1", Name: "calc", Input: map[string]any{}}, + &ToolUseContent{ID: "c2", Name: "search", Input: map[string]any{}}, + }}, + }, + }, + wantErr: true, }, } - if _, err := params2.toBase(); err == nil { - t.Error("toBase() should return error for multi-content message") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.params.toBase() + if (err != nil) != tt.wantErr { + t.Fatalf("toBase() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) } } func TestToWithTools_Conversion(t *testing.T) { tests := []struct { - name string - result *CreateMessageResult - wantLen int + name string + result *CreateMessageResult + want *CreateMessageWithToolsResult }{ { name: "with content", @@ -1040,7 +1111,12 @@ func TestToWithTools_Conversion(t *testing.T) { Content: &TextContent{Text: "hello"}, StopReason: "endTurn", }, - wantLen: 1, + want: &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + Content: []Content{&TextContent{Text: "hello"}}, + StopReason: "endTurn", + }, }, { name: "nil content", @@ -1048,23 +1124,18 @@ func TestToWithTools_Conversion(t *testing.T) { Model: "test", Role: "assistant", }, - wantLen: 0, + want: &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { wt := tt.result.toWithTools() - if wt.Model != tt.result.Model { - t.Errorf("Model = %v, want %v", wt.Model, tt.result.Model) - } - if len(wt.Content) != tt.wantLen { - t.Fatalf("len(Content) = %d, want %d", len(wt.Content), tt.wantLen) - } - if tt.wantLen > 0 { - if tc, ok := wt.Content[0].(*TextContent); !ok || tc.Text != "hello" { - t.Errorf("Content[0] = %v, want TextContent{hello}", wt.Content[0]) - } + if diff := cmp.Diff(tt.want, wt); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) } }) } diff --git a/mcp/sampling_test.go b/mcp/sampling_test.go index 61afce4a..bdffd452 100644 --- a/mcp/sampling_test.go +++ b/mcp/sampling_test.go @@ -10,30 +10,34 @@ import ( "context" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) -func TestSamplingWithTools_Integration(t *testing.T) { +func TestSamplingWithTools_ToolUse(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() // Track what the client received - var receivedParams *CreateMessageWithToolsParams + var gotParams *CreateMessageWithToolsParams + result := &CreateMessageWithToolsResult{ + Model: "test-model", + Role: "assistant", + Content: []Content{ + &ToolUseContent{ + ID: "tool_call_1", + Name: "calculator", + Input: map[string]any{"x": 1.0, "y": 2.0}, + }, + }, + StopReason: "toolUse", + } // Client with tools capability, using CreateMessageWithToolsHandler client := NewClient(testImpl, &ClientOptions{ CreateMessageWithToolsHandler: func(_ context.Context, req *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { - receivedParams = req.Params - // Return a tool use response - return &CreateMessageWithToolsResult{ - Model: "test-model", - Role: "assistant", - Content: []Content{&ToolUseContent{ - ID: "tool_call_1", - Name: "calculator", - Input: map[string]any{"x": 1.0, "y": 2.0}, - }}, - StopReason: "toolUse", - }, nil + gotParams = req.Params + return result, nil }, Capabilities: &ClientCapabilities{ Sampling: &SamplingCapabilities{Tools: &SamplingToolsCapabilities{}}, @@ -54,7 +58,7 @@ func TestSamplingWithTools_Integration(t *testing.T) { defer cs.Close() // Server sends CreateMessageWithTools - result, err := ss.CreateMessageWithTools(ctx, &CreateMessageWithToolsParams{ + params := &CreateMessageWithToolsParams{ MaxTokens: 1000, Messages: []*SamplingMessageV2{ {Role: "user", Content: []Content{&TextContent{Text: "Calculate 1+2"}}}, @@ -73,54 +77,33 @@ func TestSamplingWithTools_Integration(t *testing.T) { }, }, ToolChoice: &ToolChoice{Mode: "auto"}, - }) + } + gotResult, err := ss.CreateMessageWithTools(ctx, params) if err != nil { t.Fatalf("CreateMessageWithTools() error = %v", err) } - // Verify client received the tools - if receivedParams == nil { - t.Fatal("client did not receive params") - } - if len(receivedParams.Tools) != 1 { - t.Errorf("client received %d tools, want 1", len(receivedParams.Tools)) - } - if receivedParams.Tools[0].Name != "calculator" { - t.Errorf("tool name = %v, want calculator", receivedParams.Tools[0].Name) - } - if receivedParams.ToolChoice == nil || receivedParams.ToolChoice.Mode != "auto" { - t.Errorf("tool choice mode = %v, want auto", receivedParams.ToolChoice) + // Verify client received the params + if diff := cmp.Diff(params, gotParams); diff != "" { + t.Errorf("CreateMessageWithToolsParams mismatch (-want +got):\n%s", diff) } // Verify server received the tool use response - if result.StopReason != "toolUse" { - t.Errorf("StopReason = %v, want toolUse", result.StopReason) - } - if len(result.Content) != 1 { - t.Fatalf("len(Content) = %d, want 1", len(result.Content)) - } - toolUse, ok := result.Content[0].(*ToolUseContent) - if !ok { - t.Fatalf("Content[0] type = %T, want *ToolUseContent", result.Content[0]) - } - if toolUse.ID != "tool_call_1" { - t.Errorf("ToolUse.ID = %v, want tool_call_1", toolUse.ID) - } - if toolUse.Name != "calculator" { - t.Errorf("ToolUse.Name = %v, want calculator", toolUse.Name) + if diff := cmp.Diff(result, gotResult); diff != "" { + t.Errorf("CreateMessageWithToolsResult mismatch (-want +got):\n%s", diff) } } -func TestSamplingWithToolResult_Integration(t *testing.T) { +func TestSamplingWithTools_ToolResult(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() // Track messages received by client - var receivedMessages []*SamplingMessage + var gotParams *CreateMessageParams client := NewClient(testImpl, &ClientOptions{ CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { - receivedMessages = req.Params.Messages + gotParams = req.Params return &CreateMessageResult{ Model: "test-model", Role: "assistant", @@ -145,8 +128,7 @@ func TestSamplingWithToolResult_Integration(t *testing.T) { } defer cs.Close() - // Server sends CreateMessage with tool result in messages - _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + params := &CreateMessageParams{ MaxTokens: 1000, Messages: []*SamplingMessage{ {Role: "user", Content: &TextContent{Text: "Calculate 1+2"}}, @@ -160,43 +142,14 @@ func TestSamplingWithToolResult_Integration(t *testing.T) { Content: []Content{&TextContent{Text: "3"}}, }}, }, - }) + } + _, err = ss.CreateMessage(ctx, params) if err != nil { t.Fatalf("CreateMessage() error = %v", err) } - // Verify client received all messages including tool content - if len(receivedMessages) != 3 { - t.Fatalf("received %d messages, want 3", len(receivedMessages)) - } - - // Check first message is text - if _, ok := receivedMessages[0].Content.(*TextContent); !ok { - t.Errorf("message[0] content type = %T, want *TextContent", receivedMessages[0].Content) - } - - // Check second message is tool use - toolUse, ok := receivedMessages[1].Content.(*ToolUseContent) - if !ok { - t.Fatalf("message[1] content type = %T, want *ToolUseContent", receivedMessages[1].Content) - } - if toolUse.ID != "tool_1" { - t.Errorf("toolUse.ID = %v, want tool_1", toolUse.ID) - } - - // Check third message is tool result - toolResult, ok := receivedMessages[2].Content.(*ToolResultContent) - if !ok { - t.Fatalf("message[2] content type = %T, want *ToolResultContent", receivedMessages[2].Content) - } - if toolResult.ToolUseID != "tool_1" { - t.Errorf("toolResult.ToolUseID = %v, want tool_1", toolResult.ToolUseID) - } - if len(toolResult.Content) != 1 { - t.Fatalf("toolResult.Content len = %d, want 1", len(toolResult.Content)) - } - if tc, ok := toolResult.Content[0].(*TextContent); !ok || tc.Text != "3" { - t.Errorf("toolResult.Content[0] = %v, want TextContent with '3'", toolResult.Content[0]) + if diff := cmp.Diff(params, gotParams); diff != "" { + t.Errorf("CreateMessageParams mismatch (-want +got):\n%s", diff) } } @@ -233,14 +186,12 @@ func TestSamplingToolsCapabilities(t *testing.T) { // Check server sees client capabilities caps := ss.InitializeParams().Capabilities - if caps.Sampling == nil { - t.Fatal("client should advertise sampling capability") - } - if caps.Sampling.Tools == nil { - t.Error("client should advertise sampling.tools capability") + want := &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, + Context: &SamplingContextCapabilities{}, } - if caps.Sampling.Context == nil { - t.Error("client should advertise sampling.context capability") + if diff := cmp.Diff(want, caps.Sampling); diff != "" { + t.Errorf("SamplingCapabilities mismatch (-want +got):\n%s", diff) } }) @@ -269,14 +220,9 @@ func TestSamplingToolsCapabilities(t *testing.T) { // Check server sees client capabilities caps := ss.InitializeParams().Capabilities - if caps.Sampling == nil { - t.Fatal("client should advertise sampling capability") - } - if caps.Sampling.Tools != nil { - t.Error("client should NOT advertise sampling.tools capability") - } - if caps.Sampling.Context != nil { - t.Error("client should NOT advertise sampling.context capability") + want := &SamplingCapabilities{} + if diff := cmp.Diff(want, caps.Sampling); diff != "" { + t.Errorf("SamplingCapabilities mismatch (-want +got):\n%s", diff) } }) @@ -304,24 +250,24 @@ func TestSamplingToolsCapabilities(t *testing.T) { defer cs.Close() caps := ss.InitializeParams().Capabilities - if caps.Sampling == nil { - t.Fatal("client should advertise sampling capability") + want := &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, } - if caps.Sampling.Tools == nil { - t.Error("client should infer sampling.tools capability from CreateMessageWithToolsHandler") + if diff := cmp.Diff(want, caps.Sampling); diff != "" { + t.Errorf("SamplingCapabilities mismatch (-want +got):\n%s", diff) } }) } -func TestSamplingToolResultWithError_Integration(t *testing.T) { +func TestSamplingWithTools_ToolResultWithError(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() - var receivedMessages []*SamplingMessage + var gotParams *CreateMessageParams client := NewClient(testImpl, &ClientOptions{ CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { - receivedMessages = req.Params.Messages + gotParams = req.Params return &CreateMessageResult{ Model: "test-model", Role: "assistant", @@ -346,9 +292,7 @@ func TestSamplingToolResultWithError_Integration(t *testing.T) { } defer cs.Close() - // Server sends CreateMessage with error tool result, preceded by - // the original request and tool use for a more realistic scenario. - _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + params := &CreateMessageParams{ MaxTokens: 1000, Messages: []*SamplingMessage{ {Role: "user", Content: &TextContent{Text: "Divide 1 by 0"}}, @@ -363,43 +307,34 @@ func TestSamplingToolResultWithError_Integration(t *testing.T) { IsError: true, }}, }, - }) + } + _, err = ss.CreateMessage(ctx, params) if err != nil { t.Fatalf("CreateMessage() error = %v", err) } - if len(receivedMessages) != 3 { - t.Fatalf("received %d messages, want 3", len(receivedMessages)) - } - - toolResult, ok := receivedMessages[2].Content.(*ToolResultContent) - if !ok { - t.Fatalf("content type = %T, want *ToolResultContent", receivedMessages[2].Content) - } - if !toolResult.IsError { - t.Error("IsError should be true") - } - if toolResult.ToolUseID != "tool_1" { - t.Errorf("ToolUseID = %v, want tool_1", toolResult.ToolUseID) + if diff := cmp.Diff(params, gotParams); diff != "" { + t.Errorf("CreateMessageParams mismatch (-want +got):\n%s", diff) } } -func TestParallelToolCalls_Integration(t *testing.T) { +func TestSamplingWithTools_ParallelToolCalls(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() + result := &CreateMessageWithToolsResult{ + Model: "test-model", + Role: "assistant", + Content: []Content{ + &ToolUseContent{ID: "call_1", Name: "weather", Input: map[string]any{"city": "SF"}}, + &ToolUseContent{ID: "call_2", Name: "weather", Input: map[string]any{"city": "NY"}}, + }, + StopReason: "toolUse", + } // Client returns parallel tool use results client := NewClient(testImpl, &ClientOptions{ CreateMessageWithToolsHandler: func(_ context.Context, req *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { - return &CreateMessageWithToolsResult{ - Model: "test-model", - Role: "assistant", - Content: []Content{ - &ToolUseContent{ID: "call_1", Name: "weather", Input: map[string]any{"city": "SF"}}, - &ToolUseContent{ID: "call_2", Name: "weather", Input: map[string]any{"city": "NY"}}, - }, - StopReason: "toolUse", - }, nil + return result, nil }, Capabilities: &ClientCapabilities{ Sampling: &SamplingCapabilities{Tools: &SamplingToolsCapabilities{}}, @@ -419,7 +354,7 @@ func TestParallelToolCalls_Integration(t *testing.T) { } defer cs.Close() - result, err := ss.CreateMessageWithTools(ctx, &CreateMessageWithToolsParams{ + gotResult, err := ss.CreateMessageWithTools(ctx, &CreateMessageWithToolsParams{ MaxTokens: 1000, Messages: []*SamplingMessageV2{ {Role: "user", Content: []Content{&TextContent{Text: "Weather in SF and NY"}}}, @@ -432,23 +367,8 @@ func TestParallelToolCalls_Integration(t *testing.T) { t.Fatalf("CreateMessageWithTools() error = %v", err) } - if len(result.Content) != 2 { - t.Fatalf("len(Content) = %d, want 2", len(result.Content)) - } - for i, c := range result.Content { - tu, ok := c.(*ToolUseContent) - if !ok { - t.Fatalf("Content[%d] type = %T, want *ToolUseContent", i, c) - } - if tu.Name != "weather" { - t.Errorf("Content[%d].Name = %v, want weather", i, tu.Name) - } - } - if result.Content[0].(*ToolUseContent).ID != "call_1" { - t.Errorf("Content[0].ID = %v, want call_1", result.Content[0].(*ToolUseContent).ID) - } - if result.Content[1].(*ToolUseContent).ID != "call_2" { - t.Errorf("Content[1].ID = %v, want call_2", result.Content[1].(*ToolUseContent).ID) + if diff := cmp.Diff(result, gotResult); diff != "" { + t.Errorf("CreateMessageWithToolsResult mismatch (-want +got):\n%s", diff) } }