"""Tests for the unified Client class.""" from __future__ import annotations from unittest.mock import patch import anyio import pytest from inline_snapshot import snapshot from pydantic import FileUrl from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.client.context import ClientRequestContext from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( CallToolResult, EmptyResult, GetPromptResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, ListToolsResult, Prompt, PromptArgument, PromptMessage, PromptsCapability, ReadResourceResult, Resource, ResourcesCapability, ServerCapabilities, TextContent, TextResourceContents, Tool, ToolsCapability, ) pytestmark = pytest.mark.anyio @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" async def handle_list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> ListResourcesResult: return ListResourcesResult( resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")] ) async def handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: return EmptyResult() async def handle_unsubscribe_resource( ctx: ServerRequestContext, params: types.UnsubscribeRequestParams ) -> EmptyResult: return EmptyResult() async def handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: return EmptyResult() async def handle_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: return types.CompleteResult(completion=types.Completion(values=[])) return Server( name="test_server", on_list_resources=handle_list_resources, on_subscribe_resource=handle_subscribe_resource, on_unsubscribe_resource=handle_unsubscribe_resource, on_set_logging_level=handle_set_logging_level, on_completion=handle_completion, ) @pytest.fixture def app() -> MCPServer: """Create an MCPServer server for testing.""" server = MCPServer("test") @server.tool() def greet(name: str) -> str: """Greet someone by name.""" return f"Hello, {name}!" @server.resource("test://resource") def test_resource() -> str: """A test resource.""" return "Test content" @server.prompt() def greeting_prompt(name: str) -> str: """A greeting prompt.""" return f"Please greet {name} warmly." return server async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: assert client.initialize_result.capabilities == snapshot( ServerCapabilities( experimental={}, prompts=PromptsCapability(list_changed=False), resources=ResourcesCapability(subscribe=False, list_changed=False), tools=ToolsCapability(list_changed=False), ) ) assert client.initialize_result.server_info.name == "test" async def test_client_with_simple_server(simple_server: Server): """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: resources = await client.list_resources() assert resources == snapshot( ListResourcesResult( resources=[Resource(name="Test Resource", uri="memory://test", description="A test resource")] ) ) async def test_client_send_ping(app: MCPServer): async with Client(app) as client: result = await client.send_ping() assert result == snapshot(EmptyResult()) async def test_client_list_tools(app: MCPServer): async with Client(app) as client: result = await client.list_tools() assert result == snapshot( ListToolsResult( tools=[ Tool( name="greet", description="Greet someone by name.", input_schema={ "properties": {"name": {"title": "Name", "type": "string"}}, "required": ["name"], "title": "greetArguments", "type": "object", }, output_schema={ "properties": {"result": {"title": "Result", "type": "string"}}, "required": ["result"], "title": "greetOutput", "type": "object", }, ) ] ) ) async def test_client_call_tool(app: MCPServer): async with Client(app) as client: result = await client.call_tool("greet", {"name": "World"}) assert result == snapshot( CallToolResult( content=[TextContent(text="Hello, World!")], structured_content={"result": "Hello, World!"}, ) ) async def test_read_resource(app: MCPServer): """Test reading a resource.""" async with Client(app) as client: result = await client.read_resource("test://resource") assert result == snapshot( ReadResourceResult( contents=[TextResourceContents(uri="test://resource", mime_type="text/plain", text="Test content")] ) ) async def test_read_resource_error_propagates(): """MCPError raised by a server handler propagates to the client with its code intact.""" async def handle_read_resource( ctx: ServerRequestContext, params: types.ReadResourceRequestParams ) -> ReadResourceResult: raise MCPError(code=404, message="no resource with that URI was found") server = Server("test", on_read_resource=handle_read_resource) async with Client(server) as client: with pytest.raises(MCPError) as exc_info: await client.read_resource("unknown://example") assert exc_info.value.error.code == 404 async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: result = await client.get_prompt("greeting_prompt", {"name": "Alice"}) assert result == snapshot( GetPromptResult( description="A greeting prompt.", messages=[PromptMessage(role="user", content=TextContent(text="Please greet Alice warmly."))], ) ) def test_client_session_property_before_enter(app: MCPServer): """Test that accessing session before context manager raises RuntimeError.""" client = Client(app) with pytest.raises(RuntimeError, match="Client must be used within an async context manager"): client.session async def test_client_reentry_raises_runtime_error(app: MCPServer): """Test that reentering a client raises RuntimeError.""" async with Client(app) as client: with pytest.raises(RuntimeError, match="Client is already entered"): await client.__aenter__() async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: nonlocal received_from_client received_from_client = {"progress_token": params.progress_token, "progress": params.progress} event.set() server = Server(name="test_server", on_progress=handle_progress) async with Client(server) as client: await client.send_progress_notification(progress_token="token123", progress=50.0) await event.wait() assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0}) async def test_client_subscribe_resource(simple_server: Server): async with Client(simple_server) as client: result = await client.subscribe_resource("memory://test") assert result == snapshot(EmptyResult()) async def test_client_unsubscribe_resource(simple_server: Server): async with Client(simple_server) as client: result = await client.unsubscribe_resource("memory://test") assert result == snapshot(EmptyResult()) async def test_client_set_logging_level(simple_server: Server): """Test setting logging level.""" async with Client(simple_server) as client: result = await client.set_logging_level("debug") assert result == snapshot(EmptyResult()) async def test_client_list_resources_with_params(app: MCPServer): """Test listing resources with params parameter.""" async with Client(app) as client: result = await client.list_resources() assert result == snapshot( ListResourcesResult( resources=[ Resource( name="test_resource", uri="test://resource", description="A test resource.", mime_type="text/plain", ) ] ) ) async def test_client_list_resource_templates(app: MCPServer): """Test listing resource templates with params parameter.""" async with Client(app) as client: result = await client.list_resource_templates() assert result == snapshot(ListResourceTemplatesResult(resource_templates=[])) async def test_list_prompts(app: MCPServer): """Test listing prompts with params parameter.""" async with Client(app) as client: result = await client.list_prompts() assert result == snapshot( ListPromptsResult( prompts=[ Prompt( name="greeting_prompt", description="A greeting prompt.", arguments=[PromptArgument(name="name", required=True)], ) ] ) ) async def test_complete_with_prompt_reference(simple_server: Server): """Test getting completions for a prompt argument.""" async with Client(simple_server) as client: ref = types.PromptReference(type="ref/prompt", name="test_prompt") result = await client.complete(ref=ref, argument={"name": "arg", "value": "test"}) assert result == snapshot(types.CompleteResult(completion=types.Completion(values=[]))) def test_client_with_url_initializes_streamable_http_transport(): with patch("mcp.client.client.streamable_http_client") as mock: _ = Client("http://localhost:8000/mcp") mock.assert_called_once_with("http://localhost:8000/mcp") async def test_client_uses_transport_directly(app: MCPServer): transport = InMemoryTransport(app) async with Client(transport) as client: result = await client.call_tool("greet", {"name": "Transport"}) assert result == snapshot( CallToolResult( content=[TextContent(text="Hello, Transport!")], structured_content={"result": "Hello, Transport!"}, ) ) # ─── MRTR (SEP-2322) ──────────────────────────────────────────────────────── async def _mrtr_list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult(tools=[]) async def test_mrtr_single_round_elicitation(): """Server returns IncompleteResult with one elicitation; Client drives retry transparently.""" async def on_call_tool( ctx: ServerRequestContext, params: types.CallToolRequestParams ) -> types.CallToolResult | types.IncompleteResult: units = params.input_responses.get("units") if params.input_responses else None if units is None: return types.IncompleteResult( input_requests={ "units": types.ElicitRequest( params=types.ElicitRequestFormParams( message="Which units?", requested_schema={"type": "object", "properties": {"u": {"type": "string"}}}, ) ) }, ) u = units["content"]["u"] location = params.arguments["location"] if params.arguments else "?" return types.CallToolResult(content=[TextContent(text=f"Weather in {location}: 22°{u}")]) server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: return types.ElicitResult(action="accept", content={"u": "C"}) async with Client(server, elicitation_callback=elicitation_cb) as client: result = await client.call_tool("weather", {"location": "Tokyo"}) assert result == snapshot(CallToolResult(content=[TextContent(text="Weather in Tokyo: 22°C")])) async def test_mrtr_multi_round_with_request_state(): """Two-round elicitation accumulating state in request_state (the ADO-rules SEP example).""" async def on_call_tool( ctx: ServerRequestContext, params: types.CallToolRequestParams ) -> types.CallToolResult | types.IncompleteResult: responses = params.input_responses or {} state = params.request_state if "resolution" not in responses and state is None: return types.IncompleteResult( input_requests={ "resolution": types.ElicitRequest( params=types.ElicitRequestFormParams(message="Resolution?", requested_schema={}) ) }, ) if state is None: resolution = responses["resolution"]["content"]["r"] return types.IncompleteResult( input_requests={ "dup": types.ElicitRequest( params=types.ElicitRequestFormParams(message="Duplicate of?", requested_schema={}) ) }, request_state=f"resolution={resolution}", ) dup = responses["dup"]["content"]["id"] return types.CallToolResult(content=[TextContent(text=f"{state} dup={dup}")]) server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) answers = {"Resolution?": {"r": "Duplicate"}, "Duplicate of?": {"id": "4301"}} async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: assert isinstance(params, types.ElicitRequestFormParams) return types.ElicitResult(action="accept", content=dict(answers[params.message])) async with Client(server, elicitation_callback=elicitation_cb) as client: result = await client.call_tool("update_item", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="resolution=Duplicate dup=4301")])) async def test_mrtr_round_limit_exceeded(): """Server never converges → Client raises after max_mrtr_rounds.""" async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: return types.IncompleteResult(request_state="spin") server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async with Client(server, max_mrtr_rounds=3) as client: with pytest.raises(RuntimeError, match="exceeded 3 rounds"): await client.call_tool("stuck", {}) async def test_mrtr_elicitation_without_callback_raises(): """IncompleteResult with elicitation but no callback → clear error.""" async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: return types.IncompleteResult( input_requests={ "ask": types.ElicitRequest(params=types.ElicitRequestFormParams(message="?", requested_schema={})) }, ) server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async with Client(server) as client: with pytest.raises(RuntimeError, match="no elicitation_callback"): await client.call_tool("ask", {}) async def test_mrtr_sampling_input_request(): """IncompleteResult with a sampling input request is dispatched to sampling_callback.""" async def on_call_tool( ctx: ServerRequestContext, params: types.CallToolRequestParams ) -> types.CallToolResult | types.IncompleteResult: if params.input_responses and "q" in params.input_responses: answer = params.input_responses["q"]["content"]["text"] return types.CallToolResult(content=[TextContent(text=answer)]) return types.IncompleteResult( input_requests={ "q": types.CreateMessageRequest( params=types.CreateMessageRequestParams( messages=[ types.SamplingMessage(role="user", content=types.TextContent(text="Capital of France?")) ], max_tokens=50, ) ) }, ) server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async def sampling_cb( context: ClientRequestContext, params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: return types.CreateMessageResult( role="assistant", content=types.TextContent(text="Paris"), model="test", stop_reason="endTurn" ) async with Client(server, sampling_callback=sampling_cb) as client: result = await client.call_tool("ask", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="Paris")])) async def test_mrtr_list_roots_input_request(): """IncompleteResult with a roots/list input request is dispatched to list_roots_callback.""" async def on_call_tool( ctx: ServerRequestContext, params: types.CallToolRequestParams ) -> types.CallToolResult | types.IncompleteResult: if params.input_responses and "roots" in params.input_responses: n = len(params.input_responses["roots"]["roots"]) return types.CallToolResult(content=[TextContent(text=f"saw {n} roots")]) return types.IncompleteResult(input_requests={"roots": types.ListRootsRequest()}) server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async def list_roots_cb(context: ClientRequestContext) -> types.ListRootsResult: return types.ListRootsResult(roots=[types.Root(uri=FileUrl("file:///a")), types.Root(uri=FileUrl("file:///b"))]) async with Client(server, list_roots_callback=list_roots_cb) as client: result = await client.call_tool("scan", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="saw 2 roots")])) async def test_mrtr_callback_returns_error_data(): """Callback returning ErrorData surfaces as RuntimeError.""" async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: return types.IncompleteResult( input_requests={ "ask": types.ElicitRequest(params=types.ElicitRequestFormParams(message="?", requested_schema={})) }, ) server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ErrorData: return types.ErrorData(code=-1, message="user closed dialog") async with Client(server, elicitation_callback=elicitation_cb) as client: with pytest.raises(RuntimeError, match="user closed dialog"): await client.call_tool("ask", {}) async def test_session_call_tool_raises_on_incomplete(): """ClientSession.call_tool (non-MRTR) raises if server returns IncompleteResult.""" async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: return types.IncompleteResult(request_state="x") server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) async with Client(server) as client: with pytest.raises(RuntimeError, match="Use Client.call_tool"): await client.session.call_tool("stuck", {})