2025-02-20 10:49:43 +00:00
|
|
|
import pytest
|
|
|
|
|
|
2026-01-16 15:49:26 +00:00
|
|
|
from mcp import Client
|
2025-02-20 10:49:43 +00:00
|
|
|
from mcp.client.session import ClientSession
|
2026-03-04 13:23:02 +00:00
|
|
|
from mcp.server.mcpserver import Context, MCPServer
|
2026-02-03 17:37:38 +01:00
|
|
|
from mcp.shared._context import RequestContext
|
2025-02-20 10:49:43 +00:00
|
|
|
from mcp.types import (
|
|
|
|
|
CreateMessageRequestParams,
|
|
|
|
|
CreateMessageResult,
|
2025-12-02 13:17:45 +00:00
|
|
|
CreateMessageResultWithTools,
|
2025-02-20 10:49:43 +00:00
|
|
|
SamplingMessage,
|
|
|
|
|
TextContent,
|
2025-12-02 13:17:45 +00:00
|
|
|
ToolUseContent,
|
2025-02-20 10:49:43 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sampling_callback():
    """Exercise server-initiated sampling with and without a client sampling callback.

    A server tool calls ``ctx.session.create_message``. When the client was
    built with a sampling callback, the tool receives the callback's result and
    succeeds; when no callback is supplied, the tool call is reported as an
    error carrying a "Sampling not supported" message.
    """
    server = MCPServer("test")
    prompt_text = "Test message for sampling"

    # Fixed reply handed back by the client-side callback.
    canned_reply = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="This is a response from the sampling callback"),
        model="test-model",
        stop_reason="endTurn",
    )

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return canned_reply

    @server.tool("test_sampling")
    async def test_sampling_tool(message: str, ctx: Context) -> bool:
        user_turn = SamplingMessage(role="user", content=TextContent(type="text", text=message))
        sampled = await ctx.session.create_message(messages=[user_turn], max_tokens=100)
        # The server must see exactly what the client callback produced.
        assert sampled == canned_reply
        return True

    # Case 1: client registers a sampling callback, so the tool succeeds.
    async with Client(server, sampling_callback=sampling_callback) as client:
        outcome = await client.call_tool("test_sampling", {"message": prompt_text})
        assert outcome.is_error is False
        first_block = outcome.content[0]
        assert isinstance(first_block, TextContent)
        assert first_block.text == "true"

    # Case 2: no sampling callback, so the same tool call surfaces an error.
    async with Client(server) as client:
        outcome = await client.call_tool("test_sampling", {"message": prompt_text})
        assert outcome.is_error is True
        first_block = outcome.content[0]
        assert isinstance(first_block, TextContent)
        assert first_block.text == "Error executing tool test_sampling: Sampling not supported"
|
2025-12-02 13:17:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_create_message_backwards_compat_single_content():
    """Verify that create_message called without tools yields a single content item.

    The nested tool asserts that the result is a ``CreateMessageResult`` whose
    ``content`` is one ``TextContent`` (not a list) and that it exposes no
    callable ``content_as_list``.
    """
    server = MCPServer("test")

    # The client callback answers with a single text content item.
    llm_reply = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="Hello from LLM"),
        model="test-model",
        stop_reason="endTurn",
    )

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return llm_reply

    @server.tool("test_backwards_compat")
    async def test_tool(message: str, ctx: Context) -> bool:
        # Deliberately no tools argument on this create_message call.
        user_turn = SamplingMessage(role="user", content=TextContent(type="text", text=message))
        reply = await ctx.session.create_message(messages=[user_turn], max_tokens=100)

        # Without tools the plain result type is expected.
        assert isinstance(reply, CreateMessageResult)

        # Key backwards-compat check: content is a single item, not a list.
        body = reply.content
        assert isinstance(body, TextContent)
        assert body.text == "Hello from LLM"

        # content_as_list belongs to the WithTools variant, not to this type.
        assert not (hasattr(reply, "content_as_list") and callable(getattr(reply, "content_as_list", None)))
        return True

    async with Client(server, sampling_callback=sampling_callback) as client:
        outcome = await client.call_tool("test_backwards_compat", {"message": "Test"})
        assert outcome.is_error is False
        first_block = outcome.content[0]
        assert isinstance(first_block, TextContent)
        assert first_block.text == "true"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio
async def test_create_message_result_with_tools_type():
    """Check content_as_list on CreateMessageResultWithTools for single and list content."""
    # Only the type is exercised here; the create_message overload would need
    # client capability setup.
    weather_call = ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"})
    single = CreateMessageResultWithTools(
        role="assistant",
        content=weather_call,
        model="test-model",
        stop_reason="toolUse",
    )

    # A single content item is exposed as a one-element list.
    assert [block.type for block in single.content_as_list] == ["tool_use"]

    # Content supplied as a list keeps its items and their order.
    mixed_blocks = [
        TextContent(type="text", text="Let me check the weather"),
        ToolUseContent(type="tool_use", id="call_456", name="get_weather", input={"city": "NYC"}),
    ]
    multi = CreateMessageResultWithTools(
        role="assistant",
        content=mixed_blocks,
        model="test-model",
        stop_reason="toolUse",
    )
    assert [block.type for block in multi.content_as_list] == ["text", "tool_use"]
|