SIGN IN SIGN UP

The official Python SDK for Model Context Protocol servers and clients

0 0 0 Python
"""Tests for the unified Client class."""
from __future__ import annotations
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import patch
import anyio
import pytest
from inline_snapshot import snapshot
from mcp import MCPError, types
from mcp.client._memory import InMemoryTransport
from mcp.client.client import Client
from mcp.server import Server, ServerRequestContext
from mcp.server.mcpserver import MCPServer
from mcp.types import (
CallToolResult,
EmptyResult,
GetPromptResult,
ListPromptsResult,
ListResourcesResult,
ListResourceTemplatesResult,
ListToolsResult,
Prompt,
PromptArgument,
PromptMessage,
PromptsCapability,
ReadResourceResult,
Resource,
ResourcesCapability,
ServerCapabilities,
TextContent,
TextResourceContents,
Tool,
ToolsCapability,
)
pytestmark = pytest.mark.anyio
@pytest.fixture
def simple_server() -> Server:
    """Build a low-level ``Server`` whose handlers return fixed, minimal results.

    Registered handlers: list-resources (a single ``memory://test`` resource),
    subscribe / unsubscribe / set-logging-level (each returning ``EmptyResult``),
    and completion (returning an empty list of values).
    """

    async def _list_resources(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> ListResourcesResult:
        only_resource = Resource(uri="memory://test", name="Test Resource", description="A test resource")
        return ListResourcesResult(resources=[only_resource])

    async def _subscribe(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult:
        return EmptyResult()

    async def _unsubscribe(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult:
        return EmptyResult()

    async def _set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult:
        return EmptyResult()

    async def _complete(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult:
        no_values = types.Completion(values=[])
        return types.CompleteResult(completion=no_values)

    return Server(
        name="test_server",
        on_list_resources=_list_resources,
        on_subscribe_resource=_subscribe,
        on_unsubscribe_resource=_unsubscribe,
        on_set_logging_level=_set_logging_level,
        on_completion=_complete,
    )
@pytest.fixture
def app() -> MCPServer:
    """Build an ``MCPServer`` named ``"test"`` exposing one tool, one resource and one prompt.

    The decorated functions' names and docstrings are kept exactly as-is: other
    tests in this module assert them as the registered names and descriptions.
    """
    mcp_server = MCPServer("test")

    @mcp_server.tool()
    def greet(name: str) -> str:
        """Greet someone by name."""
        return f"Hello, {name}!"

    @mcp_server.resource("test://resource")
    def test_resource() -> str:
        """A test resource."""
        return "Test content"

    @mcp_server.prompt()
    def greeting_prompt(name: str) -> str:
        """A greeting prompt."""
        return f"Please greet {name} warmly."

    return mcp_server
async def test_client_is_initialized(app: MCPServer):
    """Entering the client context makes ``initialize_result`` available.

    Checks both the advertised server capabilities and the server name.
    """
    expected_capabilities = snapshot(
        ServerCapabilities(
            experimental={},
            prompts=PromptsCapability(list_changed=False),
            resources=ResourcesCapability(subscribe=False, list_changed=False),
            tools=ToolsCapability(list_changed=False),
        )
    )

    async with Client(app) as client:
        init_result = client.initialize_result
        assert init_result.capabilities == expected_capabilities
        assert init_result.server_info.name == "test"
async def test_client_with_simple_server(simple_server: Server):
    """A ``Client`` built directly from a low-level ``Server`` can list its resources."""
    expected = snapshot(
        ListResourcesResult(
            resources=[Resource(name="Test Resource", uri="memory://test", description="A test resource")]
        )
    )

    async with Client(simple_server) as client:
        listed = await client.list_resources()
        assert listed == expected
async def test_client_send_ping(app: MCPServer):
    """``send_ping`` returns an ``EmptyResult``."""
    async with Client(app) as client:
        pong = await client.send_ping()
        assert pong == snapshot(EmptyResult())
async def test_client_list_tools(app: MCPServer):
    """``list_tools`` reports the ``greet`` tool with its input and output schemas."""
    expected = snapshot(
        ListToolsResult(
            tools=[
                Tool(
                    name="greet",
                    description="Greet someone by name.",
                    input_schema={
                        "properties": {"name": {"title": "Name", "type": "string"}},
                        "required": ["name"],
                        "title": "greetArguments",
                        "type": "object",
                    },
                    output_schema={
                        "properties": {"result": {"title": "Result", "type": "string"}},
                        "required": ["result"],
                        "title": "greetOutput",
                        "type": "object",
                    },
                )
            ]
        )
    )

    async with Client(app) as client:
        listing = await client.list_tools()
        assert listing == expected
async def test_client_call_tool(app: MCPServer):
    """Calling ``greet`` returns the greeting as both text content and structured content."""
    expected = snapshot(
        CallToolResult(
            content=[TextContent(text="Hello, World!")],
            structured_content={"result": "Hello, World!"},
        )
    )

    async with Client(app) as client:
        tool_args = {"name": "World"}
        outcome = await client.call_tool("greet", tool_args)
        assert outcome == expected
async def test_read_resource(app: MCPServer):
    """Reading ``test://resource`` yields its text content with a ``text/plain`` MIME type."""
    expected = snapshot(
        ReadResourceResult(
            contents=[TextResourceContents(uri="test://resource", mime_type="text/plain", text="Test content")]
        )
    )

    async with Client(app) as client:
        fetched = await client.read_resource("test://resource")
        assert fetched == expected
async def test_read_resource_error_propagates():
    """An ``MCPError`` raised inside a server handler surfaces on the client with the same code."""

    async def _always_missing(
        ctx: ServerRequestContext, params: types.ReadResourceRequestParams
    ) -> ReadResourceResult:
        raise MCPError(code=404, message="no resource with that URI was found")

    failing_server = Server("test", on_read_resource=_always_missing)

    async with Client(failing_server) as client:
        with pytest.raises(MCPError) as exc_info:
            await client.read_resource("unknown://example")

        raised = exc_info.value
        assert raised.error.code == 404
async def test_get_prompt(app: MCPServer):
    """``get_prompt`` renders ``greeting_prompt`` with the supplied argument as a single user message."""
    expected = snapshot(
        GetPromptResult(
            description="A greeting prompt.",
            messages=[PromptMessage(role="user", content=TextContent(text="Please greet Alice warmly."))],
        )
    )

    async with Client(app) as client:
        prompt_args = {"name": "Alice"}
        rendered = await client.get_prompt("greeting_prompt", prompt_args)
        assert rendered == expected
def test_client_session_property_before_enter(app: MCPServer):
    """Reading ``session`` on a client that was never entered raises ``RuntimeError``."""
    unentered = Client(app)
    expected_message = "Client must be used within an async context manager"

    with pytest.raises(RuntimeError, match=expected_message):
        # The attribute access itself is what is expected to raise.
        _ = unentered.session
async def test_client_reentry_raises_runtime_error(app: MCPServer):
    """Calling ``__aenter__`` on an already-entered client raises ``RuntimeError``."""
    expected_message = "Client is already entered"

    async with Client(app) as entered:
        with pytest.raises(RuntimeError, match=expected_message):
            await entered.__aenter__()
async def test_client_send_progress_notification():
    """A progress notification sent by the client reaches the server's progress handler.

    The handler records the token and progress value it receives and then sets
    an event. The test waits on that event and compares the recorded payload.
    The wait is bounded by a timeout so that a notification which never reaches
    the handler fails the test instead of hanging the whole run.
    """
    received_from_client = None
    event = anyio.Event()

    async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None:
        nonlocal received_from_client
        received_from_client = {"progress_token": params.progress_token, "progress": params.progress}
        event.set()

    server = Server(name="test_server", on_progress=handle_progress)

    async with Client(server) as client:
        await client.send_progress_notification(progress_token="token123", progress=50.0)

        # The handler firing is the only delivery signal this test has, so cap
        # the wait: fail_after raises TimeoutError if the event is never set.
        with anyio.fail_after(5):
            await event.wait()

        assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0})
async def test_client_subscribe_resource(simple_server: Server):
    """Subscribing to ``memory://test`` returns an ``EmptyResult``."""
    async with Client(simple_server) as client:
        ack = await client.subscribe_resource("memory://test")
        assert ack == snapshot(EmptyResult())
async def test_client_unsubscribe_resource(simple_server: Server):
    """Unsubscribing from ``memory://test`` returns an ``EmptyResult``."""
    async with Client(simple_server) as client:
        ack = await client.unsubscribe_resource("memory://test")
        assert ack == snapshot(EmptyResult())
async def test_client_set_logging_level(simple_server: Server):
    """Setting the logging level to ``"debug"`` returns an ``EmptyResult``."""
    async with Client(simple_server) as client:
        ack = await client.set_logging_level("debug")
        assert ack == snapshot(EmptyResult())
async def test_client_list_resources_with_params(app: MCPServer):
    """``list_resources`` (called here with no arguments) reports the registered test resource."""
    expected = snapshot(
        ListResourcesResult(
            resources=[
                Resource(
                    name="test_resource",
                    uri="test://resource",
                    description="A test resource.",
                    mime_type="text/plain",
                )
            ]
        )
    )

    async with Client(app) as client:
        listing = await client.list_resources()
        assert listing == expected
async def test_client_list_resource_templates(app: MCPServer):
    """``list_resource_templates`` (called with no arguments) returns an empty list for this server."""
    expected = snapshot(ListResourceTemplatesResult(resource_templates=[]))

    async with Client(app) as client:
        templates = await client.list_resource_templates()
        assert templates == expected
async def test_list_prompts(app: MCPServer):
    """``list_prompts`` (called with no arguments) reports ``greeting_prompt`` and its required argument."""
    expected = snapshot(
        ListPromptsResult(
            prompts=[
                Prompt(
                    name="greeting_prompt",
                    description="A greeting prompt.",
                    arguments=[PromptArgument(name="name", required=True)],
                )
            ]
        )
    )

    async with Client(app) as client:
        listing = await client.list_prompts()
        assert listing == expected
async def test_complete_with_prompt_reference(simple_server: Server):
    """Requesting completions for a prompt argument returns the server's (empty) completion list."""
    prompt_ref = types.PromptReference(type="ref/prompt", name="test_prompt")
    partial_argument = {"name": "arg", "value": "test"}
    expected = snapshot(types.CompleteResult(completion=types.Completion(values=[])))

    async with Client(simple_server) as client:
        completion = await client.complete(ref=prompt_ref, argument=partial_argument)
        assert completion == expected
def test_client_with_url_initializes_streamable_http_transport():
    """Constructing a ``Client`` from a URL string calls ``streamable_http_client`` once with that URL."""
    with patch("mcp.client.client.streamable_http_client") as http_client_factory:
        # Construction alone is under test; the client is never entered.
        _ = Client("http://localhost:8000/mcp")

    http_client_factory.assert_called_once_with("http://localhost:8000/mcp")
async def test_client_uses_transport_directly(app: MCPServer):
    """A ``Client`` given an ``InMemoryTransport`` instance can call tools through it."""
    expected = snapshot(
        CallToolResult(
            content=[TextContent(text="Hello, Transport!")],
            structured_content={"result": "Hello, Transport!"},
        )
    )

    async with Client(InMemoryTransport(app)) as client:
        outcome = await client.call_tool("greet", {"name": "Transport"})
        assert outcome == expected
# Context variable used to observe whether the caller's context reaches a handler.
_TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial")


@contextmanager
def _set_test_contextvar(value: str) -> Iterator[None]:
    """Bind ``_TEST_CONTEXTVAR`` to *value* for the duration of the ``with`` body.

    The previous binding is restored on exit, including when the body raises.
    """
    previous_binding = _TEST_CONTEXTVAR.set(value)
    try:
        yield
    finally:
        # Resetting via the token unwinds to whatever was bound before this call,
        # so nested uses restore correctly.
        _TEST_CONTEXTVAR.reset(previous_binding)
async def test_context_propagation():
    """A contextvar set around the client call is visible inside the server's tool handler."""
    ctx_server = MCPServer("test")

    @ctx_server.tool()
    async def check_context() -> str:
        """Return the contextvar value visible to the handler."""
        return _TEST_CONTEXTVAR.get()

    async with Client(ctx_server) as client:
        with _set_test_contextvar("client_value"):
            result = await client.call_tool("check_context", {})

        first_item = result.content[0]
        failure_hint = "Server handler did not see the sender's contextvars.Context"
        assert first_item.text == "client_value", failure_hint  # type: ignore[union-attr]