2025-08-11 19:56:37 +02:00
|
|
|
from typing import Any
|
|
|
|
|
|
2024-09-24 22:04:19 +01:00
|
|
|
import anyio
|
|
|
|
|
import pytest
|
|
|
|
|
|
2026-02-03 17:42:29 +01:00
|
|
|
from mcp import types
|
2024-11-11 12:31:36 +00:00
|
|
|
from mcp.client.session import ClientSession
|
2026-02-12 15:55:54 +00:00
|
|
|
from mcp.server import Server, ServerRequestContext
|
2024-12-09 16:16:47 +00:00
|
|
|
from mcp.server.lowlevel import NotificationOptions
|
2024-11-11 20:05:51 +00:00
|
|
|
from mcp.server.models import InitializationOptions
|
2024-11-11 20:17:39 +00:00
|
|
|
from mcp.server.session import ServerSession
|
2026-01-26 14:37:44 +01:00
|
|
|
from mcp.shared.exceptions import MCPError
|
2025-05-02 14:29:00 +01:00
|
|
|
from mcp.shared.message import SessionMessage
|
2025-03-24 14:14:14 +00:00
|
|
|
from mcp.shared.session import RequestResponder
|
2024-11-11 12:31:36 +00:00
|
|
|
from mcp.types import (
|
2024-09-24 22:04:19 +01:00
|
|
|
ClientNotification,
|
2025-07-07 10:26:22 +02:00
|
|
|
CompletionsCapability,
|
2024-09-24 22:04:19 +01:00
|
|
|
InitializedNotification,
|
2024-11-06 22:50:37 +00:00
|
|
|
PromptsCapability,
|
|
|
|
|
ResourcesCapability,
|
2024-10-11 11:06:02 +01:00
|
|
|
ServerCapabilities,
|
2024-09-24 22:04:19 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_server_session_initialize():
    """A ClientSession and a ServerSession wired together by in-memory streams complete the initialize handshake."""
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

    async def message_handler(  # pragma: no cover
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ) -> None:
        # Re-raise errors delivered to the client so they fail the test instead of being dropped.
        if isinstance(message, Exception):
            raise message

    saw_initialized = False

    async def serve() -> None:
        nonlocal saw_initialized
        init_options = InitializationOptions(
            server_name="mcp",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, init_options) as server_session:
            async for incoming in server_session.incoming_messages:  # pragma: no branch
                if isinstance(incoming, Exception):  # pragma: no cover
                    raise incoming
                # Stop serving as soon as the client confirms the handshake.
                is_initialized = isinstance(incoming, ClientNotification) and isinstance(
                    incoming, InitializedNotification
                )
                if is_initialized:  # pragma: no branch
                    saw_initialized = True
                    return

    try:
        async with (
            ClientSession(
                server_to_client_receive,
                client_to_server_send,
                message_handler=message_handler,
            ) as client_session,
            anyio.create_task_group() as tg,
        ):
            tg.start_soon(serve)
            await client_session.initialize()
    except anyio.ClosedResourceError:  # pragma: no cover
        pass

    assert saw_initialized
|
2024-10-11 11:06:02 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_server_capabilities():
    """Server.get_capabilities advertises prompts/resources/completions only when a matching handler is registered."""
    notification_options = NotificationOptions()
    experimental_capabilities: dict[str, Any] = {}

    async def noop_list_prompts(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListPromptsResult:
        raise NotImplementedError

    async def noop_list_resources(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListResourcesResult:
        raise NotImplementedError

    async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult:
        raise NotImplementedError

    def check(actual: object, expected: object) -> None:
        # An absent capability must be exactly None; a present one must compare equal to the expected model.
        if expected is None:
            assert actual is None
        else:
            assert actual == expected

    prompts_cap = PromptsCapability(list_changed=False)
    resources_cap = ResourcesCapability(subscribe=False, list_changed=False)

    # Each row: keyword handlers for Server(...), then the expected prompts, resources and completions capability.
    scenarios: list[tuple[dict[str, Any], object, object, object]] = [
        ({}, None, None, None),
        ({"on_list_prompts": noop_list_prompts}, prompts_cap, None, None),
        (
            {"on_list_prompts": noop_list_prompts, "on_list_resources": noop_list_resources},
            prompts_cap,
            resources_cap,
            None,
        ),
        (
            {
                "on_list_prompts": noop_list_prompts,
                "on_list_resources": noop_list_resources,
                "on_completion": noop_completion,
            },
            prompts_cap,
            resources_cap,
            CompletionsCapability(),
        ),
    ]

    for handlers, want_prompts, want_resources, want_completions in scenarios:
        caps = Server("test", **handlers).get_capabilities(notification_options, experimental_capabilities)
        check(caps.prompts, want_prompts)
        check(caps.resources, want_resources)
        check(caps.completions, want_completions)
|
2025-05-15 18:33:31 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_server_session_initialize_with_older_protocol_version():
    """The server accepts an initialize request for the older 2024-11-05 protocol and echoes that version back."""
    legacy_version = "2024-11-05"
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    initialized_seen = False
    negotiated_version: str | None = None

    async def serve() -> None:
        nonlocal initialized_seen
        init_options = InitializationOptions(
            server_name="mcp",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, init_options) as server_session:
            async for incoming in server_session.incoming_messages:  # pragma: no branch
                if isinstance(incoming, Exception):  # pragma: no cover
                    raise incoming
                # Finish once the client's "initialized" notification arrives.
                is_initialized = isinstance(incoming, types.ClientNotification) and isinstance(
                    incoming, InitializedNotification
                )
                if is_initialized:  # pragma: no branch
                    initialized_seen = True
                    return

    async def legacy_client() -> None:
        nonlocal negotiated_version

        # Hand-rolled initialize request that asks for the older protocol revision.
        init_params = types.InitializeRequestParams(
            protocol_version=legacy_version,
            capabilities=types.ClientCapabilities(),
            client_info=types.Implementation(name="test-client", version="1.0.0"),
        )
        init_request = types.JSONRPCRequest(
            jsonrpc="2.0",
            id=1,
            method="initialize",
            params=init_params.model_dump(by_alias=True, mode="json", exclude_none=True),
        )
        await client_to_server_send.send(SessionMessage(init_request))

        # The server must answer with the same (older) version it was offered.
        reply = await server_to_client_receive.receive()
        assert isinstance(reply.message, types.JSONRPCResponse)
        negotiated_version = types.InitializeResult.model_validate(reply.message.result).protocol_version
        assert negotiated_version == legacy_version

        initialized_note = types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
        await client_to_server_send.send(SessionMessage(initialized_note))

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
        anyio.create_task_group() as tg,
    ):
        tg.start_soon(serve)
        tg.start_soon(legacy_client)

    assert initialized_seen
    assert negotiated_version == legacy_version
|
2025-09-01 23:37:36 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_ping_request_before_initialization():
    """A ping that arrives before the initialize handshake is still answered by the server."""
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    got_reply = False
    reply_id = None

    async def serve() -> None:
        init_options = InitializationOptions(
            server_name="mcp",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, init_options) as server_session:
            async for incoming in server_session.incoming_messages:  # pragma: no branch
                if isinstance(incoming, Exception):  # pragma: no cover
                    raise incoming
                # The only request the client sends is the pre-initialization ping; answer it and stop.
                if isinstance(incoming, RequestResponder) and isinstance(
                    incoming.request, types.PingRequest
                ):  # pragma: no branch
                    with incoming:
                        await incoming.respond(types.EmptyResult())
                    return

    async def ping_before_init() -> None:
        nonlocal got_reply, reply_id

        ping = types.JSONRPCRequest(jsonrpc="2.0", id=42, method="ping")
        await client_to_server_send.send(SessionMessage(ping))

        reply = await server_to_client_receive.receive()
        assert isinstance(reply.message, types.JSONRPCResponse)
        got_reply = True
        reply_id = reply.message.id

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
        anyio.create_task_group() as tg,
    ):
        tg.start_soon(serve)
        tg.start_soon(ping_before_init)

    assert got_reply
    assert reply_id == 42
|
|
|
|
|
|
|
|
|
|
|
2025-11-23 04:58:14 +00:00
|
|
|
@pytest.mark.anyio
async def test_create_message_tool_result_validation():
    """create_message rejects malformed tool_use/tool_result histories and lets well-formed ones through."""
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        init_options = InitializationOptions(
            server_name="test",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, init_options) as session:
            # Pretend the client advertised sampling.tools so that tool-bearing requests are permitted.
            session._client_params = types.InitializeRequestParams(
                protocol_version=types.LATEST_PROTOCOL_VERSION,
                capabilities=types.ClientCapabilities(
                    sampling=types.SamplingCapability(tools=types.SamplingToolsCapability())
                ),
                client_info=types.Implementation(name="test", version="1.0"),
            )

            tool = types.Tool(name="test_tool", input_schema={"type": "object"})
            text = types.TextContent(type="text", text="hello")
            tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})
            tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[])

            async def rejected(pattern: str, messages: list[types.SamplingMessage], **extra: Any) -> None:
                """Assert that create_message refuses ``messages`` with a ValueError matching ``pattern``."""
                with pytest.raises(ValueError, match=pattern):
                    await session.create_message(messages=messages, max_tokens=100, **extra)

            async def accepted(messages: list[types.SamplingMessage], **extra: Any) -> None:
                """Assert that create_message's validation lets ``messages`` through.

                Validation runs before the request is sent; the request then waits indefinitely for a client
                reply that never arrives, so a short timeout cancels it once validation has passed. A validation
                error is not a cancellation, so it would propagate and fail the test.
                """
                with anyio.move_on_after(0.01):  # pragma: no branch
                    await session.create_message(messages=messages, max_tokens=100, **extra)

            # Case 1: a message holding a tool_result may not also hold other content.
            await rejected(
                "only tool_result content",
                [
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="assistant", content=tool_use),
                    types.SamplingMessage(role="user", content=[tool_result, text]),
                ],
                tools=[tool],
            )

            # Case 2: a tool_result as the very first message has nothing to answer.
            await rejected(
                "requires a previous message",
                [types.SamplingMessage(role="user", content=tool_result)],
                tools=[tool],
            )

            # Case 3: a tool_result whose previous message carries no tool_use.
            await rejected(
                "do not match any tool_use",
                [
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="user", content=tool_result),
                ],
                tools=[tool],
            )

            # Case 4: the tool_result id differs from the preceding tool_use id.
            await rejected(
                "ids of tool_result blocks and tool_use blocks",
                [
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="assistant", content=tool_use),
                    types.SamplingMessage(
                        role="user",
                        content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
                    ),
                ],
                tools=[tool],
            )

            # Case 5: a text-only history with tools supplied contains no tool_result and is valid.
            await accepted([types.SamplingMessage(role="user", content=text)], tools=[tool])

            # Case 6: a tool_result that answers the preceding tool_use is valid.
            await accepted(
                [
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="assistant", content=tool_use),
                    types.SamplingMessage(role="user", content=tool_result),
                ],
                tools=[tool],
            )

            # Case 7: validation still applies when `tools` is omitted
            # (a tool-loop continuation may drop tools while still carrying a tool_result).
            await rejected(
                "do not match any tool_use",
                [
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="user", content=tool_result),
                ],
            )

            # Case 8: an empty history bypasses the tool_use/tool_result checks entirely.
            await accepted([])
|
2025-11-23 04:58:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_create_message_without_tools_capability():
    """create_message raises MCPError for tool arguments when the client did not advertise sampling.tools."""
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        init_options = InitializationOptions(
            server_name="test",
            server_version="0.1.0",
            capabilities=ServerCapabilities(),
        )
        async with ServerSession(client_to_server_receive, server_to_client_send, init_options) as session:
            # The client supports sampling, but deliberately NOT sampling.tools.
            session._client_params = types.InitializeRequestParams(
                protocol_version=types.LATEST_PROTOCOL_VERSION,
                capabilities=types.ClientCapabilities(sampling=types.SamplingCapability()),
                client_info=types.Implementation(name="test", version="1.0"),
            )

            tool = types.Tool(name="test_tool", input_schema={"type": "object"})
            text = types.TextContent(type="text", text="hello")

            # Both `tools` and `tool_choice` need the missing capability, so each must be refused on its own.
            tool_arguments: tuple[dict[str, Any], ...] = (
                {"tools": [tool]},
                {"tool_choice": types.ToolChoice(mode="auto")},
            )
            for extra in tool_arguments:
                with pytest.raises(MCPError) as exc_info:
                    await session.create_message(
                        messages=[types.SamplingMessage(role="user", content=text)],
                        max_tokens=100,
                        **extra,
                    )
                assert "does not support sampling tools capability" in exc_info.value.error.message
|
|
|
|
|
|
|
|
|
|
|
2025-09-01 23:37:36 +02:00
|
|
|
@pytest.mark.anyio
async def test_other_requests_blocked_before_initialization():
    """Test that non-ping requests are still blocked before initialization.

    A ``prompts/list`` request sent before the initialize handshake must be answered with a JSON-RPC error
    whose code is ``INVALID_PARAMS``.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    error_response_received = False
    error_code = None
    # Set by the client once it has read the server's reply. The server keeps its session open until then,
    # instead of for a fixed sleep that a slow or loaded CI runner could outlast (which made the test flaky).
    client_finished = anyio.Event()

    async def run_server():
        async with ServerSession(
            client_to_server_receive,
            server_to_client_send,
            InitializationOptions(
                server_name="mcp",
                server_version="0.1.0",
                capabilities=ServerCapabilities(),
            ),
        ):
            # The session answers the premature request on its own, so incoming_messages need not be consumed.
            # fail_after turns a reply that never comes into a clear TimeoutError rather than a hang.
            with anyio.fail_after(5):
                await client_finished.wait()

    async def mock_client():
        nonlocal error_response_received, error_code

        # Try to send a non-ping request before initialization
        await client_to_server_send.send(
            SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=1, method="prompts/list"))
        )

        # Wait for the error response
        error_message = await server_to_client_receive.receive()
        if isinstance(error_message.message, types.JSONRPCError):  # pragma: no branch
            error_response_received = True
            error_code = error_message.message.error.code

        # Reply consumed: the server session may now shut down.
        client_finished.set()

    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
        anyio.create_task_group() as tg,
    ):
        tg.start_soon(run_server)
        tg.start_soon(mock_client)

    assert error_response_received
    assert error_code == types.INVALID_PARAMS
|