diff --git a/mcp/server.go b/mcp/server.go index 0c1acc24..cec8644f 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -212,7 +212,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan elemZero any // only non-nil if Out is a pointer type outputResolved *jsonschema.Resolved ) - if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { var err error elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) if err != nil { @@ -302,8 +302,8 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we // should have a jsonschema.Zero(schema) helper? func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { - rt := reflect.TypeFor[T]() if *sfield == nil { + rt := reflect.TypeFor[T]() if rt.Kind() == reflect.Pointer { rt = rt.Elem() zero = reflect.Zero(rt).Interface() diff --git a/mcp/server_test.go b/mcp/server_test.go index 7db40738..e46be379 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "encoding/json" "log" "slices" "testing" @@ -487,3 +488,59 @@ func TestAddTool(t *testing.T) { t.Error("bad Out: expected panic") } } + +type schema = jsonschema.Schema + +func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErr bool) { + t.Helper() + th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { + return nil, out, nil + } + gott, goth, err := toolForErr(tool, th) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(wantIn, gott.InputSchema); diff != "" { + t.Errorf("input: mismatch (-want, +got):\n%s", diff) + } + if diff := cmp.Diff(wantOut, gott.OutputSchema); diff != "" { + t.Errorf("output: mismatch (-want, +got):\n%s", diff) + } + ctr := &CallToolRequest{ + Params: &CallToolParamsRaw{ + Arguments: json.RawMessage(in), + }, + } + _, err = goth(context.Background(), ctr) + + if gotErr := err != nil; gotErr != wantErr { + t.Errorf("got error: %t, want error: %t", gotErr, wantErr) + } +} + +func TestToolForSchemas(t *testing.T) { + // Validate that ToolFor handles schemas properly. + + // Infer both schemas. + testToolForSchema[int](t, &Tool{}, "3", true, + &schema{Type: "integer"}, &schema{Type: "boolean"}, false) + // Validate the input schema: expect an error if it's wrong. + // We can't test that the output schema is validated, because it's typed. + testToolForSchema[int](t, &Tool{}, `"x"`, true, + &schema{Type: "integer"}, &schema{Type: "boolean"}, true) + + // Ignore type any for output. + testToolForSchema[int, any](t, &Tool{}, "3", 0, + &schema{Type: "integer"}, nil, false) + // Input is still validated. + testToolForSchema[int, any](t, &Tool{}, `"x"`, 0, + &schema{Type: "integer"}, nil, true) + + // Tool sets input schema: that is what's used. + testToolForSchema[int, any](t, &Tool{InputSchema: &schema{Type: "string"}}, "3", 0, + &schema{Type: "string"}, nil, true) // error: 3 is not a string + + // Tool sets output schema: that is what's used, and validation happens. + testToolForSchema[string, any](t, &Tool{OutputSchema: &schema{Type: "integer"}}, "3", "x", + &schema{Type: "string"}, &schema{Type: "integer"}, true) // error: "x" is not an integer +} diff --git a/mcp/tool.go b/mcp/tool.go index 53a3c7aa..9e757f46 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -9,7 +9,6 @@ import ( "context" "encoding/json" "fmt" - // "log" "github.com/google/jsonschema-go/jsonschema" )