SIGN IN SIGN UP

The official Python SDK for Model Context Protocol servers and clients

0 0 0 Python
from __future__ import annotations
2024-09-24 22:04:19 +01:00
import anyio
import pytest
from mcp import types
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
from mcp.shared._context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2024-11-11 12:31:36 +00:00
from mcp.types import (
2024-10-21 14:50:44 +01:00
LATEST_PROTOCOL_VERSION,
CallToolResult,
2024-09-24 22:04:19 +01:00
Implementation,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCNotification,
2024-09-24 22:04:19 +01:00
JSONRPCRequest,
JSONRPCResponse,
2026-01-22 14:50:39 +01:00
RequestParamsMeta,
2024-09-24 22:04:19 +01:00
ServerCapabilities,
TextContent,
client_notification_adapter,
client_request_adapter,
2024-09-24 22:04:19 +01:00
)
@pytest.mark.anyio
async def test_client_session_initialize():
    """Run the initialize handshake against a scripted server.

    The client must return the server's InitializeResult unchanged and then send an
    InitializedNotification, which the mock server captures for inspection.
    """
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    initialized_notification = None
    result = None

    async def mock_server():
        nonlocal initialized_notification

        # Step 1: the first client message must be an `initialize` request.
        incoming = await client_to_server_receive.receive()
        raw_request = incoming.message
        assert isinstance(raw_request, JSONRPCRequest)
        wire_request = raw_request.model_dump(by_alias=True, mode="json", exclude_none=True)
        parsed_request = client_request_adapter.validate_python(wire_request)
        assert isinstance(parsed_request, InitializeRequest)

        # Step 2: answer with a fixed result; every capability is explicitly left unset.
        server_reply = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(
                experimental=None,
                logging=None,
                prompts=None,
                resources=None,
                tools=None,
            ),
            server_info=Implementation(name="mock-server", version="0.1.0"),
            instructions="The server instructions.",
        )
        reply_payload = server_reply.model_dump(by_alias=True, mode="json", exclude_none=True)

        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=raw_request.id,
                        result=reply_payload,
                    )
                )
            )

            # Step 3: capture the notification the client sends after the response.
            follow_up = await client_to_server_receive.receive()
            raw_notification = follow_up.message
            assert isinstance(raw_notification, JSONRPCNotification)
            wire_notification = raw_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
            initialized_notification = client_notification_adapter.validate_python(wire_notification)

    # Re-raise any exception delivered to the client's message handler instead of dropping it.
    async def message_handler(  # pragma: no cover
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ) -> None:
        if not isinstance(message, Exception):
            return
        raise message

    async with (
        ClientSession(
            server_to_client_receive,
            client_to_server_send,
            message_handler=message_handler,
        ) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        result = await session.initialize()

    # The client must return exactly what the server sent.
    assert isinstance(result, InitializeResult)
    assert result.protocol_version == LATEST_PROTOCOL_VERSION
    assert isinstance(result.capabilities, ServerCapabilities)
    assert result.server_info == Implementation(name="mock-server", version="0.1.0")
    assert result.instructions == "The server instructions."

    # The post-handshake notification was observed and has the expected type.
    assert initialized_notification
    assert isinstance(initialized_notification, InitializedNotification)
@pytest.mark.anyio
async def test_client_session_custom_client_info():
    """A client_info passed to ClientSession is forwarded verbatim in the initialize request."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    custom_client_info = Implementation(name="test-client", version="1.2.3")
    received_client_info = None

    async def mock_server():
        nonlocal received_client_info

        inbound = (await client_to_server_receive.receive()).message
        assert isinstance(inbound, JSONRPCRequest)
        init_request = client_request_adapter.validate_python(
            inbound.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(init_request, InitializeRequest)
        # Record the identity the client announced.
        received_client_info = init_request.params.client_info

        init_result = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        response = JSONRPCResponse(
            jsonrpc="2.0",
            id=inbound.id,
            result=init_result.model_dump(by_alias=True, mode="json", exclude_none=True),
        )

        async with server_to_client_send:
            await server_to_client_send.send(SessionMessage(response))
            # Consume the client's initialized notification.
            await client_to_server_receive.receive()

    async with (
        ClientSession(
            server_to_client_receive,
            client_to_server_send,
            client_info=custom_client_info,
        ) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        await session.initialize()

    # The server must have seen exactly the custom identity.
    assert received_client_info == custom_client_info
@pytest.mark.anyio
async def test_client_session_default_client_info():
    """Without an explicit client_info, ClientSession announces DEFAULT_CLIENT_INFO."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    received_client_info = None

    def wrap_response(request: JSONRPCRequest, payload: InitializeResult) -> SessionMessage:
        """Package ``payload`` as the JSON-RPC response to ``request``."""
        return SessionMessage(
            JSONRPCResponse(
                jsonrpc="2.0",
                id=request.id,
                result=payload.model_dump(by_alias=True, mode="json", exclude_none=True),
            )
        )

    async def mock_server():
        nonlocal received_client_info

        first = await client_to_server_receive.receive()
        request_msg = first.message
        assert isinstance(request_msg, JSONRPCRequest)
        decoded = client_request_adapter.validate_python(
            request_msg.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(decoded, InitializeRequest)
        received_client_info = decoded.params.client_info

        payload = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        async with server_to_client_send:
            await server_to_client_send.send(wrap_response(request_msg, payload))
            # Consume the client's initialized notification.
            await client_to_server_receive.receive()

    async with (
        ClientSession(server_to_client_receive, client_to_server_send) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        await session.initialize()

    # No client_info was configured, so the library default must be sent.
    assert received_client_info == DEFAULT_CLIENT_INFO
@pytest.mark.anyio
async def test_client_session_version_negotiation_success():
    """Test that the client accepts an older protocol version it still supports.

    The client offers LATEST_PROTOCOL_VERSION; the server answers with "2024-11-05".
    """
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    negotiated_version = "2024-11-05"
    result = None

    async def mock_server():
        message = (await client_to_server_receive.receive()).message
        assert isinstance(message, JSONRPCRequest)
        handshake = client_request_adapter.validate_python(
            message.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(handshake, InitializeRequest)
        # The client always opens by proposing the newest version it knows.
        assert handshake.params.protocol_version == LATEST_PROTOCOL_VERSION

        # Counter-offer an older version that is still supported.
        downgraded = InitializeResult(
            protocol_version=negotiated_version,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=message.id,
                        result=downgraded.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
            # Consume the client's initialized notification.
            await client_to_server_receive.receive()

    async with (
        ClientSession(server_to_client_receive, client_to_server_send) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        result = await session.initialize()

    # The handshake completed on the server's older, supported version.
    assert isinstance(result, InitializeResult)
    assert result.protocol_version == negotiated_version
    assert result.protocol_version in SUPPORTED_PROTOCOL_VERSIONS
@pytest.mark.anyio
async def test_client_session_version_negotiation_failure():
    """Test that initialize() raises RuntimeError when the server picks an unsupported version."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    unsupported_version = "2020-01-01"  # older than anything the client supports

    async def mock_server():
        envelope = await client_to_server_receive.receive()
        rpc_request = envelope.message
        assert isinstance(rpc_request, JSONRPCRequest)
        typed_request = client_request_adapter.validate_python(
            rpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(typed_request, InitializeRequest)

        bogus_result = InitializeResult(
            protocol_version=unsupported_version,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        serialized = bogus_result.model_dump(by_alias=True, mode="json", exclude_none=True)
        # Only the response is sent; no initialized notification is awaited here.
        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=rpc_request.id, result=serialized))
            )

    async with (
        ClientSession(server_to_client_receive, client_to_server_send) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        # The client must reject the server's version choice.
        with pytest.raises(RuntimeError, match="Unsupported protocol version"):
            await session.initialize()
@pytest.mark.anyio
async def test_client_capabilities_default():
    """Without sampling or list_roots callbacks, the client advertises neither capability."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    received_capabilities = None

    async def mock_server():
        nonlocal received_capabilities

        envelope = await client_to_server_receive.receive()
        rpc = envelope.message
        assert isinstance(rpc, JSONRPCRequest)
        as_json = rpc.model_dump(by_alias=True, mode="json", exclude_none=True)
        init = client_request_adapter.validate_python(as_json)
        assert isinstance(init, InitializeRequest)
        received_capabilities = init.params.capabilities

        answer = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.id,
                        result=answer.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
            # Consume the initialized notification that follows the response.
            await client_to_server_receive.receive()

    async with (
        ClientSession(server_to_client_receive, client_to_server_send) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        await session.initialize()

    # By default, neither optional client capability is advertised.
    assert received_capabilities is not None
    assert received_capabilities.sampling is None  # no sampling_callback given
    assert received_capabilities.roots is None  # no list_roots_callback given
@pytest.mark.anyio
async def test_client_capabilities_with_custom_callbacks():
    """Supplying sampling and list_roots callbacks makes the client advertise both capabilities."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    received_capabilities = None

    # Only the presence of these callbacks matters; the mock server never sends them a request.
    async def custom_sampling_callback(  # pragma: no cover
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.ErrorData:
        canned = types.CreateMessageResult(
            role="assistant",
            content=types.TextContent(type="text", text="test"),
            model="test-model",
        )
        return canned

    async def custom_list_roots_callback(  # pragma: no cover
        context: RequestContext[ClientSession],
    ) -> types.ListRootsResult | types.ErrorData:
        no_roots: list[types.Root] = []
        return types.ListRootsResult(roots=no_roots)

    async def mock_server():
        nonlocal received_capabilities

        raw = (await client_to_server_receive.receive()).message
        assert isinstance(raw, JSONRPCRequest)
        handshake = client_request_adapter.validate_python(
            raw.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(handshake, InitializeRequest)
        received_capabilities = handshake.params.capabilities

        reply = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        reply_json = reply.model_dump(by_alias=True, mode="json", exclude_none=True)
        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=raw.id, result=reply_json))
            )
            # Consume the client's initialized notification.
            await client_to_server_receive.receive()

    async with (
        ClientSession(
            server_to_client_receive,
            client_to_server_send,
            sampling_callback=custom_sampling_callback,
            list_roots_callback=custom_list_roots_callback,
        ) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        await session.initialize()

    assert received_capabilities is not None

    # A sampling callback was given, so sampling is advertised, with no tools by default.
    sampling = received_capabilities.sampling
    assert sampling is not None
    assert isinstance(sampling, types.SamplingCapability)
    assert sampling.tools is None

    # A list_roots callback was given, so roots are advertised with list_changed set.
    roots = received_capabilities.roots
    assert roots is not None
    assert isinstance(roots, types.RootsCapability)
    assert roots.list_changed is True
@pytest.mark.anyio
async def test_client_capabilities_with_sampling_tools():
    """Explicit sampling_capabilities with tools are advertised in the initialize request."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    received_capabilities = None

    # Registered only so that sampling is advertised; the mock server never calls it.
    async def custom_sampling_callback(  # pragma: no cover
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.ErrorData:
        body = types.TextContent(type="text", text="test")
        return types.CreateMessageResult(role="assistant", content=body, model="test-model")

    async def mock_server():
        nonlocal received_capabilities

        delivered = await client_to_server_receive.receive()
        jsonrpc_msg = delivered.message
        assert isinstance(jsonrpc_msg, JSONRPCRequest)
        init_req = client_request_adapter.validate_python(
            jsonrpc_msg.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(init_req, InitializeRequest)
        received_capabilities = init_req.params.capabilities

        init_res = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(),
            server_info=Implementation(name="mock-server", version="0.1.0"),
        )
        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_msg.id,
                        result=init_res.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
            # Consume the client's initialized notification.
            await client_to_server_receive.receive()

    async with (
        ClientSession(
            server_to_client_receive,
            client_to_server_send,
            sampling_callback=custom_sampling_callback,
            sampling_capabilities=types.SamplingCapability(tools=types.SamplingToolsCapability()),
        ) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        await session.initialize()

    # Sampling is advertised, and it carries the tools sub-capability that was configured.
    assert received_capabilities is not None
    advertised_sampling = received_capabilities.sampling
    assert advertised_sampling is not None
    assert isinstance(advertised_sampling, types.SamplingCapability)
    assert advertised_sampling.tools is not None
    assert isinstance(advertised_sampling.tools, types.SamplingToolsCapability)
@pytest.mark.anyio
async def test_initialize_result():
    """session.initialize_result is None before the handshake and the server's full result afterwards."""
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    # What the mock server will announce, so the assertions can compare against it.
    expected_capabilities = ServerCapabilities(
        logging=types.LoggingCapability(),
        prompts=types.PromptsCapability(list_changed=True),
        resources=types.ResourcesCapability(subscribe=True, list_changed=True),
        tools=types.ToolsCapability(list_changed=False),
    )
    expected_server_info = Implementation(name="mock-server", version="0.1.0")
    expected_instructions = "Use the tools wisely."

    async def mock_server():
        first_message = (await client_to_server_receive.receive()).message
        assert isinstance(first_message, JSONRPCRequest)
        parsed = client_request_adapter.validate_python(
            first_message.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(parsed, InitializeRequest)

        announced = InitializeResult(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=expected_capabilities,
            server_info=expected_server_info,
            instructions=expected_instructions,
        )
        encoded = announced.model_dump(by_alias=True, mode="json", exclude_none=True)
        async with server_to_client_send:
            await server_to_client_send.send(
                SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=first_message.id, result=encoded))
            )
            # Consume the client's initialized notification.
            await client_to_server_receive.receive()

    async with (
        ClientSession(
            server_to_client_receive,
            client_to_server_send,
        ) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        # Nothing is stored before initialize() runs.
        assert session.initialize_result is None
        tg.start_soon(mock_server)
        await session.initialize()

        stored = session.initialize_result
        assert stored is not None
        assert stored.server_info == expected_server_info
        assert stored.capabilities == expected_capabilities
        assert stored.instructions == expected_instructions
        assert stored.protocol_version == LATEST_PROTOCOL_VERSION
@pytest.mark.anyio
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None):
    """Test that client tool call requests can include metadata.

    The mock server walks the client through initialize -> tools/call -> tools/list. When
    ``meta`` is not None, it checks that ``meta`` appears unchanged under ``_meta`` in the
    tools/call params.
    """
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

    mocked_tool = types.Tool(name="sample_tool", input_schema={})

    async def next_request() -> JSONRPCRequest:
        """Receive the next client message and assert that it is a JSON-RPC request."""
        message = (await client_to_server_receive.receive()).message
        assert isinstance(message, JSONRPCRequest)
        return message

    async def respond(
        request: JSONRPCRequest,
        payload: InitializeResult | CallToolResult | types.ListToolsResult,
    ) -> None:
        """Send ``payload`` to the client as the result for ``request``."""
        await server_to_client_send.send(
            SessionMessage(
                JSONRPCResponse(
                    jsonrpc="2.0",
                    id=request.id,
                    result=payload.model_dump(by_alias=True, mode="json", exclude_none=True),
                )
            )
        )

    async def mock_server():
        # Phase 1: initialize handshake.
        init_request = await next_request()
        typed_init = client_request_adapter.validate_python(
            init_request.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(typed_init, InitializeRequest)
        await respond(
            init_request,
            InitializeResult(
                protocol_version=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                server_info=Implementation(name="mock-server", version="0.1.0"),
            ),
        )
        # Consume the initialized notification; this test does not inspect it.
        await client_to_server_receive.receive()

        # Phase 2: tools/call. Check that the metadata reached the wire.
        call_request = await next_request()
        assert call_request.method == "tools/call"
        if meta is not None:
            assert call_request.params
            assert "_meta" in call_request.params
            assert call_request.params["_meta"] == meta
        await respond(
            call_request,
            CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False),
        )

        # Phase 3: tools/list. The client requires this step to validate the tool output schema.
        list_request = await next_request()
        assert list_request.method == "tools/list"
        await respond(list_request, types.ListToolsResult(tools=[mocked_tool]))

        server_to_client_send.close()

    async with (
        ClientSession(server_to_client_receive, client_to_server_send) as session,
        anyio.create_task_group() as tg,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        tg.start_soon(mock_server)
        await session.initialize()
        await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)