diff --git a/.changeset/fine-kiwis-rush.md b/.changeset/fine-kiwis-rush.md new file mode 100644 index 000000000..b2460fabe --- /dev/null +++ b/.changeset/fine-kiwis-rush.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/client': patch +--- + +fix: validate sampling/createMessage with correct schema diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 8d96ba0bc..82363cef2 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -43,6 +43,7 @@ import { CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, + CreateMessageResultWithToolsSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, @@ -458,8 +459,9 @@ export class Client< return taskValidationResult.data; } - // For non-task requests, validate against CreateMessageResultSchema - const validationResult = safeParse(CreateMessageResultSchema, result); + // For non-task requests, validate against appropriate schema + const schema = params.tools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; + const validationResult = safeParse(schema, result); if (!validationResult.success) { const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); diff --git a/test/integration/test/sampling_tools.test.ts b/test/integration/test/sampling_tools.test.ts new file mode 100644 index 000000000..a98e47196 --- /dev/null +++ b/test/integration/test/sampling_tools.test.ts @@ -0,0 +1,104 @@ +import { Client } from '@modelcontextprotocol/client'; +import { CreateMessageRequestSchema, InMemoryTransport } from '@modelcontextprotocol/core'; +import { Server } from '@modelcontextprotocol/server'; +import { describe, expect, test } from 'vitest'; + +describe('sampling/createMessage with tools', () => { + test('should support returning tool calls when tools are provided', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + sampling: { + tools: {} + } + } + } + ); + + // Implement request handler for sampling/createMessage that returns a tool call + client.setRequestHandler(CreateMessageRequestSchema, async _request => { + return { + model: 'test-model', + role: 'assistant', + stopReason: 'toolUse', + content: [ + { + type: 'tool_use', + id: 'call_1', + name: 'test_tool', + input: { arg: 'value' } + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'use the tool' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }); + + expect(result).toEqual({ + model: 'test-model', + role: 'assistant', + stopReason: 'toolUse', + content: [ + { + type: 'tool_use', + id: 'call_1', + name: 'test_tool', + input: { arg: 'value' } + } + ] + }); + }); + + test('should fail if returning tool calls when tools are NOT provided', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + sampling: {} + } + } + ); + + // Implement request handler for sampling/createMessage that returns a tool call + client.setRequestHandler(CreateMessageRequestSchema, async _request => { + return { + model: 'test-model', + role: 'assistant', + stopReason: 'toolUse', + content: [ + { + type: 'tool_use', + id: 'call_1', + name: 'test_tool', + input: { arg: 'value' } + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // This should fail because the client validation will reject the tool call result + // when validating against CreateMessageResultSchema (since tools were not requested) + await expect( + server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'use the tool' } }], + maxTokens: 100 + // No tools provided + }) + ).rejects.toThrow(); + }); +});