diff --git a/src/agents/agent.py b/src/agents/agent.py index 5d700ebaa3..8cfb7c4952 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -150,6 +150,13 @@ class MCPConfig(TypedDict): default_tool_error_function. """ + include_server_in_tool_names: NotRequired[bool] + """If True, MCP tool names exposed to the model are prefixed with the server name + (e.g. ``my_server__my_tool``) so that tools from different servers with the same + name do not collide. The original MCP tool name is still used when invoking the + server. Defaults to False. + """ + @dataclass class AgentBase(Generic[TContext]): @@ -186,12 +193,14 @@ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[ failure_error_function = self.mcp_config.get( "failure_error_function", default_tool_error_function ) + include_server_in_tool_names = self.mcp_config.get("include_server_in_tool_names", False) return await MCPUtil.get_all_function_tools( self.mcp_servers, convert_schemas_to_strict, run_context, self, failure_error_function=failure_error_function, + include_server_in_tool_names=include_server_in_tool_names, ) async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 33bea065c5..6da0b28beb 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -5,6 +5,7 @@ import functools import inspect import json +import re from collections.abc import Awaitable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Protocol, Union @@ -174,6 +175,19 @@ def create_static_tool_filter( return filter_dict +_SERVER_NAME_SANITIZE_RE = re.compile(r"[^a-zA-Z0-9_]") + + +def _sanitize_server_name(name: str) -> str: + """Sanitize an MCP server name so it is safe for use as a tool name prefix. + + Replaces any character that is not alphanumeric or underscore with an underscore. + Falls back to ``server`` if the result would be empty. + """ + sanitized = _SERVER_NAME_SANITIZE_RE.sub("_", name).strip("_") + return sanitized or "server" + + class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" @@ -207,9 +221,10 @@ async def get_all_function_tools( run_context: RunContextWrapper[Any], agent: AgentBase, failure_error_function: ToolErrorFunction | None = default_tool_error_function, + include_server_in_tool_names: bool = False, ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" - tools = [] + tools: list[Tool] = [] tool_names: set[str] = set() for server in servers: server_tools = await cls.get_function_tools( @@ -219,12 +234,18 @@ async def get_all_function_tools( agent, failure_error_function=failure_error_function, ) + + if include_server_in_tool_names: + prefix = _sanitize_server_name(server.name) + for tool in server_tools: + if isinstance(tool, FunctionTool): + tool.name = f"{prefix}__{tool.name}" + server_tool_names = {tool.name for tool in server_tools} - if len(server_tool_names & tool_names) > 0: - raise UserError( - f"Duplicate tool names found across MCP servers: " - f"{server_tool_names & tool_names}" - ) + duplicates = server_tool_names & tool_names + if duplicates: + raise UserError(f"Duplicate tool names found across MCP servers: {duplicates}") + tool_names.update(server_tool_names) tools.extend(server_tools) diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index c992e25e03..a35ae08ae8 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -1455,3 +1455,149 @@ def test_to_function_tool_description_falls_back_to_mcp_title(): assert function_tool.description == "Search Docs" assert function_tool._mcp_title == "Search Docs" + + +@pytest.mark.asyncio +async def test_duplicate_tool_names_raises_by_default(): + """Default behavior: duplicate tool names across servers raises UserError.""" + server_a = FakeMCPServer(server_name="server_a") + server_a.add_tool("run", {"type": "object", "properties": {}}) + + server_b = FakeMCPServer(server_name="server_b") + server_b.add_tool("run", {"type": "object", "properties": {}}) + + agent = Agent(name="test", instructions="test") + run_context = RunContextWrapper(context=None) + + with pytest.raises(AgentsException, match="Duplicate tool names"): + await MCPUtil.get_all_function_tools( + [server_a, server_b], + convert_schemas_to_strict=False, + run_context=run_context, + agent=agent, + ) + + +@pytest.mark.asyncio +async def test_include_server_in_tool_names_avoids_collision(): + """With include_server_in_tool_names=True, duplicate names are prefixed and no error.""" + server_a = FakeMCPServer(server_name="server_a") + server_a.add_tool("run", {"type": "object", "properties": {}}) + + server_b = FakeMCPServer(server_name="server_b") + server_b.add_tool("run", {"type": "object", "properties": {}}) + + agent = Agent(name="test", instructions="test") + run_context = RunContextWrapper(context=None) + + tools = await MCPUtil.get_all_function_tools( + [server_a, server_b], + convert_schemas_to_strict=False, + run_context=run_context, + agent=agent, + include_server_in_tool_names=True, + ) + + tool_names = [t.name for t in tools] + assert "server_a__run" in tool_names + assert "server_b__run" in tool_names + assert len(tool_names) == 2 + + +@pytest.mark.asyncio +async def test_include_server_in_tool_names_invokes_with_original_name(): + """Prefixed tools still invoke the MCP server using the original tool name.""" + server = FakeMCPServer(server_name="my_server") + server.add_tool("do_thing", {"type": "object", "properties": {}}) + + agent = Agent(name="test", instructions="test") + run_context = RunContextWrapper(context=None) + + tools = await MCPUtil.get_all_function_tools( + [server], + convert_schemas_to_strict=False, + run_context=run_context, + agent=agent, + include_server_in_tool_names=True, + ) + + assert len(tools) == 1 + func_tool = tools[0] + assert isinstance(func_tool, FunctionTool) + assert func_tool.name == "my_server__do_thing" + + # Invoke the tool and verify the server received the original name. + tool_context = ToolContext( + context=None, + tool_name="my_server__do_thing", + tool_call_id="test_call", + tool_arguments="{}", + ) + await func_tool.on_invoke_tool(tool_context, "{}") + assert server.tool_calls == ["do_thing"] + + +@pytest.mark.asyncio +async def test_include_server_in_tool_names_sanitizes_server_name(): + """Server names with special characters are sanitized for the prefix.""" + server = FakeMCPServer(server_name="my-cool.server/v2") + server.add_tool("action", {"type": "object", "properties": {}}) + + agent = Agent(name="test", instructions="test") + run_context = RunContextWrapper(context=None) + + tools = await MCPUtil.get_all_function_tools( + [server], + convert_schemas_to_strict=False, + run_context=run_context, + agent=agent, + include_server_in_tool_names=True, + ) + + func_tool = tools[0] + assert isinstance(func_tool, FunctionTool) + assert func_tool.name == "my_cool_server_v2__action" + + +@pytest.mark.asyncio +async def test_include_server_in_tool_names_empty_server_name_fallback(): + """Empty or all-special-character server names fall back to 'server'.""" + server = FakeMCPServer(server_name="---") + server.add_tool("action", {"type": "object", "properties": {}}) + + agent = Agent(name="test", instructions="test") + run_context = RunContextWrapper(context=None) + + tools = await MCPUtil.get_all_function_tools( + [server], + convert_schemas_to_strict=False, + run_context=run_context, + agent=agent, + include_server_in_tool_names=True, + ) + + func_tool = tools[0] + assert isinstance(func_tool, FunctionTool) + assert func_tool.name == "server__action" + + +@pytest.mark.asyncio +async def test_include_server_in_tool_names_detects_sanitized_collision(): + """Servers whose names sanitize to the same prefix still raise on collision.""" + server_a = FakeMCPServer(server_name="a-b") + server_a.add_tool("run", {"type": "object", "properties": {}}) + + server_b = FakeMCPServer(server_name="a_b") + server_b.add_tool("run", {"type": "object", "properties": {}}) + + agent = Agent(name="test", instructions="test") + run_context = RunContextWrapper(context=None) + + with pytest.raises(AgentsException, match="Duplicate tool names"): + await MCPUtil.get_all_function_tools( + [server_a, server_b], + convert_schemas_to_strict=False, + run_context=run_context, + agent=agent, + include_server_in_tool_names=True, + )