SIGN IN SIGN UP

The official Python SDK for Model Context Protocol servers and clients

0 0 0 Python
from typing import Any
2024-09-24 22:04:19 +01:00
import anyio
import pytest
from mcp import types
2024-11-11 12:31:36 +00:00
from mcp.client.session import ClientSession
from mcp.server import Server, ServerRequestContext
from mcp.server.lowlevel import NotificationOptions
from mcp.server.models import InitializationOptions
2024-11-11 20:17:39 +00:00
from mcp.server.session import ServerSession
from mcp.shared.exceptions import MCPError
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
2024-11-11 12:31:36 +00:00
from mcp.types import (
2024-09-24 22:04:19 +01:00
ClientNotification,
CompletionsCapability,
2024-09-24 22:04:19 +01:00
InitializedNotification,
PromptsCapability,
ResourcesCapability,
ServerCapabilities,
2024-09-24 22:04:19 +01:00
)
@pytest.mark.anyio
async def test_server_session_initialize():
    """A real ClientSession completes the initialize handshake against a ServerSession.

    The server loop records when it observes the client's ``InitializedNotification``
    and then returns; the test asserts that the notification was seen.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

    # Create a message handler to catch exceptions
    async def message_handler(  # pragma: no cover
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ) -> None:
        # Re-raise any exception delivered to the client so it fails the test.
        if isinstance(message, Exception):
            raise message

    received_initialized = False

    async def run_server():
        nonlocal received_initialized

        async with ServerSession(
            client_to_server_receive,
            server_to_client_send,
            InitializationOptions(
                server_name="mcp",
                server_version="0.1.0",
                capabilities=ServerCapabilities(),
            ),
        ) as server_session:
            async for message in server_session.incoming_messages:  # pragma: no branch
                if isinstance(message, Exception):  # pragma: no cover
                    raise message

                if isinstance(message, ClientNotification) and isinstance(
                    message, InitializedNotification
                ):  # pragma: no branch
                    received_initialized = True
                    return

    # Close all four stream ends once the sessions are done, as the other
    # stream-based tests in this module do, so nothing is left open.
    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        try:
            async with (
                ClientSession(
                    server_to_client_receive,
                    client_to_server_send,
                    message_handler=message_handler,
                ) as client_session,
                anyio.create_task_group() as tg,
            ):
                tg.start_soon(run_server)
                await client_session.initialize()
        except anyio.ClosedResourceError:  # pragma: no cover
            pass

    assert received_initialized
@pytest.mark.anyio
async def test_server_capabilities():
    """``Server.get_capabilities`` advertises exactly the capabilities whose handlers are registered.

    Each scenario registers one more handler (prompts, then resources, then
    completion) and checks the prompts/resources/completions entries.
    """
    notification_options = NotificationOptions()
    experimental_capabilities: dict[str, Any] = {}

    async def noop_list_prompts(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListPromptsResult:
        raise NotImplementedError

    async def noop_list_resources(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListResourcesResult:
        raise NotImplementedError

    async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult:
        raise NotImplementedError

    def check(actual: Any, expected: Any) -> None:
        # An absent capability must be exactly None; a present one must compare equal.
        if expected is None:
            assert actual is None
        else:
            assert actual == expected

    want_prompts = PromptsCapability(list_changed=False)
    want_resources = ResourcesCapability(subscribe=False, list_changed=False)

    # Each row: (handler kwargs for Server, expected prompts, expected resources, expected completions)
    scenarios: list[tuple[dict[str, Any], Any, Any, Any]] = [
        # No capabilities
        ({}, None, None, None),
        # With prompts handler
        ({"on_list_prompts": noop_list_prompts}, want_prompts, None, None),
        # With prompts + resources handlers
        (
            {"on_list_prompts": noop_list_prompts, "on_list_resources": noop_list_resources},
            want_prompts,
            want_resources,
            None,
        ),
        # With prompts + resources + completion handlers
        (
            {
                "on_list_prompts": noop_list_prompts,
                "on_list_resources": noop_list_resources,
                "on_completion": noop_completion,
            },
            want_prompts,
            want_resources,
            CompletionsCapability(),
        ),
    ]

    for handlers, expected_prompts, expected_resources, expected_completions in scenarios:
        server = Server("test", **handlers)
        caps = server.get_capabilities(notification_options, experimental_capabilities)
        check(caps.prompts, expected_prompts)
        check(caps.resources, expected_resources)
        check(caps.completions, expected_completions)
@pytest.mark.anyio
async def test_server_session_initialize_with_older_protocol_version():
    """A client requesting the 2024-11-05 protocol gets that same version echoed back.

    A hand-driven client sends a raw ``initialize`` request, checks the protocol
    version in the server's response, then sends ``notifications/initialized``,
    which the server loop must observe.
    """
    legacy_version = "2024-11-05"

    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    received_initialized = False
    received_protocol_version = None

    async def run_server():
        nonlocal received_initialized

        options = InitializationOptions(
            server_name="mcp",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, options) as server_session:
            async for incoming in server_session.incoming_messages:  # pragma: no branch
                if isinstance(incoming, Exception):  # pragma: no cover
                    raise incoming

                is_initialized = isinstance(incoming, types.ClientNotification) and isinstance(
                    incoming, InitializedNotification
                )
                if is_initialized:  # pragma: no branch
                    received_initialized = True
                    return

    async def mock_client():
        nonlocal received_protocol_version

        # Step 1: an initialize request that advertises the older protocol version.
        init_params = types.InitializeRequestParams(
            protocol_version=legacy_version,
            capabilities=types.ClientCapabilities(),
            client_info=types.Implementation(name="test-client", version="1.0.0"),
        )
        init_request = types.JSONRPCRequest(
            jsonrpc="2.0",
            id=1,
            method="initialize",
            params=init_params.model_dump(by_alias=True, mode="json", exclude_none=True),
        )
        await client_to_server_send.send(SessionMessage(init_request))

        # Step 2: the server must respond with the version the client asked for.
        reply = await server_to_client_receive.receive()
        assert isinstance(reply.message, types.JSONRPCResponse)
        init_result = types.InitializeResult.model_validate(reply.message.result)
        received_protocol_version = init_result.protocol_version
        assert received_protocol_version == legacy_version

        # Step 3: finish the handshake.
        initialized = types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
        await client_to_server_send.send(SessionMessage(initialized))

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
        anyio.create_task_group() as tg,
    ):
        tg.start_soon(run_server)
        tg.start_soon(mock_client)

    assert received_initialized
    assert received_protocol_version == legacy_version
@pytest.mark.anyio
async def test_ping_request_before_initialization():
    """A ``ping`` request is served even though no initialize handshake has happened.

    The raw client sends ping id 42 first thing; the server answers it with an
    empty result and the client checks the response carries the same id.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    ping_response_received = False
    ping_response_id = None

    async def run_server():
        options = InitializationOptions(
            server_name="mcp",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, options) as server_session:
            async for incoming in server_session.incoming_messages:  # pragma: no branch
                if isinstance(incoming, Exception):  # pragma: no cover
                    raise incoming

                # The very first request should be the pre-initialization ping.
                is_ping = isinstance(incoming, RequestResponder) and isinstance(incoming.request, types.PingRequest)
                if is_ping:  # pragma: no branch
                    with incoming:
                        await incoming.respond(types.EmptyResult())
                    return

    async def mock_client():
        nonlocal ping_response_received, ping_response_id

        # No initialize request is sent: ping goes out first.
        ping_request = types.JSONRPCRequest(jsonrpc="2.0", id=42, method="ping")
        await client_to_server_send.send(SessionMessage(ping_request))

        reply = await server_to_client_receive.receive()
        assert isinstance(reply.message, types.JSONRPCResponse)
        ping_response_received = True
        ping_response_id = reply.message.id

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
        anyio.create_task_group() as tg,
    ):
        tg.start_soon(run_server)
        tg.start_soon(mock_client)

    assert ping_response_received
    assert ping_response_id == 42
@pytest.mark.anyio
async def test_create_message_tool_result_validation():
    """Test tool_use/tool_result validation in create_message.

    The session is given client params that advertise ``sampling.tools``.
    Invalid tool_result/tool_use sequences must raise ``ValueError`` before any
    request is sent. Valid sequences pass validation and then block on the
    unanswered request, so those cases are bounded with ``move_on_after``.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        async with ServerSession(
            client_to_server_receive,
            server_to_client_send,
            InitializationOptions(
                server_name="test",
                server_version="0.1.0",
                capabilities=ServerCapabilities(),
            ),
        ) as session:
            # Set up client params with sampling.tools capability for the test
            # (written directly to the private field; no handshake is performed).
            session._client_params = types.InitializeRequestParams(
                protocol_version=types.LATEST_PROTOCOL_VERSION,
                capabilities=types.ClientCapabilities(
                    sampling=types.SamplingCapability(tools=types.SamplingToolsCapability())
                ),
                client_info=types.Implementation(name="test", version="1.0"),
            )

            tool = types.Tool(name="test_tool", input_schema={"type": "object"})
            text = types.TextContent(type="text", text="hello")
            tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})
            tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[])

            # Case 1: tool_result mixed with other content
            with pytest.raises(ValueError, match="only tool_result content"):
                await session.create_message(
                    messages=[
                        types.SamplingMessage(role="user", content=text),
                        types.SamplingMessage(role="assistant", content=tool_use),
                        types.SamplingMessage(role="user", content=[tool_result, text]),  # mixed!
                    ],
                    max_tokens=100,
                    tools=[tool],
                )

            # Case 2: tool_result without previous message
            with pytest.raises(ValueError, match="requires a previous message"):
                await session.create_message(
                    messages=[types.SamplingMessage(role="user", content=tool_result)],
                    max_tokens=100,
                    tools=[tool],
                )

            # Case 3: tool_result without previous tool_use
            with pytest.raises(ValueError, match="do not match any tool_use"):
                await session.create_message(
                    messages=[
                        types.SamplingMessage(role="user", content=text),
                        types.SamplingMessage(role="user", content=tool_result),
                    ],
                    max_tokens=100,
                    tools=[tool],
                )

            # Case 4: mismatched tool IDs
            with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"):
                await session.create_message(
                    messages=[
                        types.SamplingMessage(role="user", content=text),
                        types.SamplingMessage(role="assistant", content=tool_use),
                        types.SamplingMessage(
                            role="user",
                            content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
                        ),
                    ],
                    max_tokens=100,
                    tools=[tool],
                )

            # Case 5: text-only message with tools (no tool_results) - passes validation
            # Covers has_tool_results=False branch.
            # We use move_on_after because validation happens synchronously before
            # send_request, which would block indefinitely waiting for a response.
            # The timeout lets validation pass, then cancels the blocked send.
            with anyio.move_on_after(0.01):
                await session.create_message(
                    messages=[types.SamplingMessage(role="user", content=text)],
                    max_tokens=100,
                    tools=[tool],
                )

            # Case 6: valid matching tool_result/tool_use IDs - passes validation
            # Covers tool_use_ids == tool_result_ids branch.
            # (see Case 5 comment for move_on_after explanation)
            with anyio.move_on_after(0.01):
                await session.create_message(
                    messages=[
                        types.SamplingMessage(role="user", content=text),
                        types.SamplingMessage(role="assistant", content=tool_use),
                        types.SamplingMessage(role="user", content=tool_result),
                    ],
                    max_tokens=100,
                    tools=[tool],
                )

            # Case 7: validation runs even without `tools` parameter
            # (tool loop continuation may omit tools while containing tool_result)
            with pytest.raises(ValueError, match="do not match any tool_use"):
                await session.create_message(
                    messages=[
                        types.SamplingMessage(role="user", content=text),
                        types.SamplingMessage(role="user", content=tool_result),
                    ],
                    max_tokens=100,
                    # Note: no tools parameter
                )

            # Case 8: empty messages list - skips validation entirely.
            # Covers the path in create_message where there are no messages to
            # validate (bounded by move_on_after for the same reason as Case 5).
            with anyio.move_on_after(0.01):  # pragma: no branch
                await session.create_message(messages=[], max_tokens=100)
@pytest.mark.anyio
async def test_create_message_without_tools_capability():
    """``create_message`` rejects tool arguments when the client lacks ``sampling.tools``.

    The client params advertise sampling but not sampling tools. Passing either
    ``tools`` or ``tool_choice`` must raise ``MCPError`` with a message naming the
    missing capability.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        options = InitializationOptions(
            server_name="test",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, options) as session:
            # Sampling is supported, but the tools sub-capability is deliberately absent.
            session._client_params = types.InitializeRequestParams(
                protocol_version=types.LATEST_PROTOCOL_VERSION,
                capabilities=types.ClientCapabilities(sampling=types.SamplingCapability()),
                client_info=types.Implementation(name="test", version="1.0"),
            )

            tool = types.Tool(name="test_tool", input_schema={"type": "object"})
            text = types.TextContent(type="text", text="hello")

            # Each variant supplies one tool-related argument the client cannot handle.
            tool_argument_variants = [
                {"tools": [tool]},
                {"tool_choice": types.ToolChoice(mode="auto")},
            ]
            for extra_kwargs in tool_argument_variants:
                with pytest.raises(MCPError) as exc_info:
                    await session.create_message(
                        messages=[types.SamplingMessage(role="user", content=text)],
                        max_tokens=100,
                        **extra_kwargs,
                    )
                assert "does not support sampling tools capability" in exc_info.value.error.message
@pytest.mark.anyio
async def test_other_requests_blocked_before_initialization():
    """Test that non-ping requests are still blocked before initialization.

    A raw client sends ``prompts/list`` before any handshake and expects a
    JSON-RPC error reply with code ``INVALID_PARAMS``.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    error_response_received = False
    error_code = None
    # Set by the client once it has read the server's reply. The server keeps its
    # session open until then, which is deterministic, unlike a fixed sleep.
    client_done = anyio.Event()

    async def run_server():
        async with ServerSession(
            client_to_server_receive,
            server_to_client_send,
            InitializationOptions(
                server_name="mcp",
                server_version="0.1.0",
                capabilities=ServerCapabilities(),
            ),
        ):
            # The session answers the pre-initialization request itself, so
            # incoming_messages need not be consumed; just stay alive until the
            # client has the reply.
            await client_done.wait()

    async def mock_client():
        nonlocal error_response_received, error_code

        # Try to send a non-ping request before initialization
        await client_to_server_send.send(
            SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=1, method="prompts/list"))
        )

        # Wait for the error response, then release the server.
        error_message = await server_to_client_receive.receive()
        client_done.set()

        # Fail right here, with the actual message type, if the reply is not an error.
        assert isinstance(error_message.message, types.JSONRPCError)
        error_response_received = True
        error_code = error_message.message.error.code

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
        anyio.create_task_group() as tg,
    ):
        tg.start_soon(run_server)
        tg.start_soon(mock_client)

    assert error_response_received
    assert error_code == types.INVALID_PARAMS