Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 37 additions & 48 deletions tests/client/test_list_methods_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import pytest

import mcp.types as types
from mcp.client._memory import InMemoryTransport
from mcp.client.session import ClientSession
from mcp import Client
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.types import ListToolsRequest, ListToolsResult
Expand Down Expand Up @@ -66,49 +65,43 @@ async def test_list_methods_params_parameter(

See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format
"""
transport = InMemoryTransport(full_featured_server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
spies = stream_spy()

# Test without params (omitted)
method = getattr(session, method_name)
_ = await method()
requests = spies.get_client_requests(method=request_method)
assert len(requests) == 1
assert requests[0].params is None

spies.clear()

# Test with params containing cursor
_ = await method(params=types.PaginatedRequestParams(cursor="from_params"))
requests = spies.get_client_requests(method=request_method)
assert len(requests) == 1
assert requests[0].params is not None
assert requests[0].params["cursor"] == "from_params"

spies.clear()

# Test with empty params
_ = await method(params=types.PaginatedRequestParams())
requests = spies.get_client_requests(method=request_method)
assert len(requests) == 1
# Empty params means no cursor
assert requests[0].params is None or "cursor" not in requests[0].params
async with Client(full_featured_server) as client:
spies = stream_spy()

# Test without params (omitted)
method = getattr(client, method_name)
_ = await method()
requests = spies.get_client_requests(method=request_method)
assert len(requests) == 1
assert requests[0].params is None

spies.clear()

# Test with params containing cursor
_ = await method(params=types.PaginatedRequestParams(cursor="from_params"))
requests = spies.get_client_requests(method=request_method)
assert len(requests) == 1
assert requests[0].params is not None
assert requests[0].params["cursor"] == "from_params"

spies.clear()

# Test with empty params
_ = await method(params=types.PaginatedRequestParams())
requests = spies.get_client_requests(method=request_method)
assert len(requests) == 1
# Empty params means no cursor
assert requests[0].params is None or "cursor" not in requests[0].params


async def test_list_tools_with_strict_server_validation(
full_featured_server: FastMCP,
):
"""Test pagination with a server that validates request format strictly."""
transport = InMemoryTransport(full_featured_server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
result = await session.list_tools(params=types.PaginatedRequestParams())
assert isinstance(result, ListToolsResult)
assert len(result.tools) > 0
async with Client(full_featured_server) as client:
result = await client.list_tools(params=types.PaginatedRequestParams())
assert isinstance(result, ListToolsResult)
assert len(result.tools) > 0


async def test_list_tools_with_lowlevel_server():
Expand All @@ -129,13 +122,9 @@ async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult:
]
)

transport = InMemoryTransport(server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

result = await session.list_tools(params=types.PaginatedRequestParams())
assert result.tools[0].description == "cursor=None"
async with Client(server) as client:
result = await client.list_tools(params=types.PaginatedRequestParams())
assert result.tools[0].description == "cursor=None"

result = await session.list_tools(params=types.PaginatedRequestParams(cursor="page2"))
assert result.tools[0].description == "cursor=page2"
result = await client.list_tools(params=types.PaginatedRequestParams(cursor="page2"))
assert result.tools[0].description == "cursor=page2"
41 changes: 13 additions & 28 deletions tests/client/transports/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from mcp import Client
from mcp.client._memory import InMemoryTransport
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
Expand Down Expand Up @@ -69,42 +70,26 @@ async def test_with_fastmcp(fastmcp_server: FastMCP):

async def test_server_is_running(fastmcp_server: FastMCP):
"""Test that the server is running and responding to requests."""
from mcp.client.session import ClientSession

transport = InMemoryTransport(fastmcp_server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert result is not None
assert result.server_info.name == "test"
async with Client(fastmcp_server) as client:
assert client.server_capabilities is not None


async def test_list_tools(fastmcp_server: FastMCP):
"""Test listing tools through the transport."""
from mcp.client.session import ClientSession

transport = InMemoryTransport(fastmcp_server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
tools_result = await session.list_tools()
assert len(tools_result.tools) > 0
tool_names = [t.name for t in tools_result.tools]
assert "greet" in tool_names
async with Client(fastmcp_server) as client:
tools_result = await client.list_tools()
assert len(tools_result.tools) > 0
tool_names = [t.name for t in tools_result.tools]
assert "greet" in tool_names


async def test_call_tool(fastmcp_server: FastMCP):
"""Test calling a tool through the transport."""
from mcp.client.session import ClientSession

transport = InMemoryTransport(fastmcp_server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
result = await session.call_tool("greet", {"name": "World"})
assert result is not None
assert len(result.content) > 0
assert "Hello, World!" in str(result.content[0])
async with Client(fastmcp_server) as client:
result = await client.call_tool("greet", {"name": "World"})
assert result is not None
assert len(result.content) > 0
assert "Hello, World!" in str(result.content[0])


async def test_raise_exceptions(fastmcp_server: FastMCP):
Expand Down
98 changes: 43 additions & 55 deletions tests/server/test_cancel_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import pytest

import mcp.types as types
from mcp.client._memory import InMemoryTransport
from mcp.client.session import ClientSession
from mcp import Client
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.types import (
Expand Down Expand Up @@ -55,61 +54,50 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[
return [types.TextContent(type="text", text=f"Call number: {call_count}")]
raise ValueError(f"Unknown tool: {name}") # pragma: no cover

transport = InMemoryTransport(server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as client:
await client.initialize()

# First request (will be cancelled)
async def first_request():
try:
await client.send_request(
ClientRequest(
CallToolRequest(
params=CallToolRequestParams(name="test_tool", arguments={}),
)
),
CallToolResult,
)
pytest.fail("First request should have been cancelled") # pragma: no cover
except McpError:
pass # Expected

# Start first request
async with anyio.create_task_group() as tg:
tg.start_soon(first_request)

# Wait for it to start
await ev_first_call.wait()

# Cancel it
assert first_request_id is not None
await client.send_notification(
ClientNotification(
CancelledNotification(
params=CancelledNotificationParams(
request_id=first_request_id,
reason="Testing server recovery",
),
async with Client(server) as client:
# First request (will be cancelled)
async def first_request():
try:
await client.session.send_request(
ClientRequest(
CallToolRequest(
params=CallToolRequestParams(name="test_tool", arguments={}),
)
)
),
CallToolResult,
)

# Second request (should work normally)
result = await client.send_request(
ClientRequest(
CallToolRequest(
params=CallToolRequestParams(name="test_tool", arguments={}),
pytest.fail("First request should have been cancelled") # pragma: no cover
except McpError:
pass # Expected

# Start first request
async with anyio.create_task_group() as tg:
tg.start_soon(first_request)

# Wait for it to start
await ev_first_call.wait()

# Cancel it
assert first_request_id is not None
await client.session.send_notification(
ClientNotification(
CancelledNotification(
params=CancelledNotificationParams(
request_id=first_request_id,
reason="Testing server recovery",
),
)
),
CallToolResult,
)
)

# Verify second request completed successfully
assert len(result.content) == 1
# Type narrowing for pyright
content = result.content[0]
assert content.type == "text"
assert isinstance(content, types.TextContent)
assert content.text == "Call number: 2"
assert call_count == 2
# Second request (should work normally)
result = await client.call_tool("test_tool", {})

# Verify second request completed successfully
assert len(result.content) == 1
# Type narrowing for pyright
content = result.content[0]
assert content.type == "text"
assert isinstance(content, types.TextContent)
assert content.text == "Call number: 2"
assert call_count == 2
42 changes: 16 additions & 26 deletions tests/shared/test_progress_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import mcp.types as types
from mcp.client._memory import InMemoryTransport
from mcp import Client
from mcp.client.session import ClientSession
from mcp.server import Server
from mcp.server.lowlevel import NotificationOptions
Expand Down Expand Up @@ -369,30 +369,20 @@ async def handle_list_tools() -> list[types.Tool]:

# Test with mocked logging
with patch("mcp.shared.session.logging.error", side_effect=mock_log_error):
transport = InMemoryTransport(server)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession( # pragma: no branch
read_stream=read_stream, write_stream=write_stream
) as session:
await session.initialize()
# Send a request with a failing progress callback
result = await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name="progress_tool", arguments={}),
)
),
types.CallToolResult,
progress_callback=failing_progress_callback,
)
async with Client(server) as client:
# Call tool with a failing progress callback
result = await client.call_tool(
"progress_tool",
arguments={},
progress_callback=failing_progress_callback,
)

# Verify the request completed successfully despite the callback failure
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, types.TextContent)
assert content.text == "progress_result"
# Verify the request completed successfully despite the callback failure
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, types.TextContent)
assert content.text == "progress_result"

# Check that a warning was logged for the progress callback exception
assert len(logged_errors) > 0
assert any("Progress callback raised an exception" in warning for warning in logged_errors)
# Check that a warning was logged for the progress callback exception
assert len(logged_errors) > 0
assert any("Progress callback raised an exception" in warning for warning in logged_errors)
Loading