2025-08-11 19:56:37 +02:00
|
|
|
from collections.abc import Callable, Generator
|
2025-05-21 14:27:06 -07:00
|
|
|
from contextlib import asynccontextmanager
|
2025-08-11 19:56:37 +02:00
|
|
|
from typing import Any
|
2025-05-21 14:27:06 -07:00
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
|
|
|
|
import pytest
|
2025-08-11 19:56:37 +02:00
|
|
|
from anyio.streams.memory import MemoryObjectSendStream
|
2025-05-21 14:27:06 -07:00
|
|
|
|
|
|
|
|
import mcp.shared.memory
|
|
|
|
|
from mcp.shared.message import SessionMessage
|
2025-08-11 19:56:37 +02:00
|
|
|
from mcp.types import JSONRPCNotification, JSONRPCRequest
|
2025-05-21 14:27:06 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpyMemoryObjectSendStream:
    """Pass-through wrapper around a memory send stream that records what it sends.

    Every message handed to :meth:`send` is appended to ``sent_messages`` and
    then forwarded unchanged to the wrapped stream, so a test can inspect the
    exact outgoing traffic afterwards.
    """

    def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]):
        # The real stream that actually carries messages to the peer.
        self.original_stream = original_stream
        # Every message passed to send(), in call order.
        self.sent_messages: list[SessionMessage] = []

    async def send(self, message: SessionMessage):
        """Record *message*, then forward it to the wrapped stream."""
        # Recorded before awaiting the real send, so the log still contains
        # the message if the underlying send raises.
        self.sent_messages.append(message)
        await self.original_stream.send(message)

    async def aclose(self):
        """Close the wrapped stream."""
        await self.original_stream.aclose()

    async def __aenter__(self):
        """Enter an ``async with`` block; the spy itself is the bound value."""
        return self

    async def __aexit__(self, *args: Any):
        """Leave the ``async with`` block by closing the spy (and thus the wrapped stream)."""
        await self.aclose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamSpyCollection:
    """Pair of client/server write-stream spies plus helpers to query them.

    Each ``get_*`` helper walks the messages a spy has recorded so far and
    returns the JSON-RPC payloads (``message.root``) of the requested kind,
    optionally restricted to one method name. Results keep send order.
    """

    def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryObjectSendStream):
        # Spy wrapping the client's outgoing (write) stream.
        self.client = client_spy
        # Spy wrapping the server's outgoing (write) stream.
        self.server = server_spy

    @staticmethod
    def _select(
        messages: list[SessionMessage],
        kind: type[JSONRPCRequest] | type[JSONRPCNotification],
        method: str | None,
    ) -> list[Any]:
        """Return the ``root`` payloads in *messages* that are instances of *kind*.

        When *method* is not None, only payloads whose ``method`` equals it are kept.
        """
        selected: list[Any] = []
        for session_message in messages:
            payload = session_message.message.root
            if not isinstance(payload, kind):
                continue
            if method is not None and payload.method != method:
                continue
            selected.append(payload)
        return selected

    def clear(self) -> None:
        """Clear all captured messages."""
        for spy in (self.client, self.server):
            spy.sent_messages.clear()

    def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
        """Get client-sent requests, optionally filtered by method."""
        return self._select(self.client.sent_messages, JSONRPCRequest, method)

    def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
        """Get server-sent requests, optionally filtered by method."""
        return self._select(self.server.sent_messages, JSONRPCRequest, method)

    def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
        """Get client-sent notifications, optionally filtered by method."""
        return self._select(self.client.sent_messages, JSONRPCNotification, method)

    def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
        """Get server-sent notifications, optionally filtered by method."""
        return self._select(self.server.sent_messages, JSONRPCNotification, method)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def stream_spy() -> Generator[Callable[[], StreamSpyCollection], None, None]:
    """Spy on the client and server write streams of in-memory MCP sessions.

    While the fixture is active, ``mcp.shared.memory.create_client_server_memory_streams``
    is patched so that each side's write stream is wrapped in a
    ``SpyMemoryObjectSendStream``. The fixture yields a zero-argument factory.
    Call it *after* the streams have been created to get a
    ``StreamSpyCollection`` for the most recently created stream pair. Calling
    it earlier fails an assertion.

    Example usage:
        async def test_something(stream_spy):
            # ... set up server and client ...

            spies = stream_spy()

            # Run some operation that sends messages
            await client.some_operation()

            # Check the messages
            requests = spies.get_client_requests(method="some/method")
            assert len(requests) == 1

            # Clear for the next operation
            spies.clear()
    """
    # Spies from the most recent patched stream creation; None until one happens.
    client_spy: SpyMemoryObjectSendStream | None = None
    server_spy: SpyMemoryObjectSendStream | None = None

    # Grab the real factory before patching so the wrapper can delegate to it.
    original_create_streams = mcp.shared.memory.create_client_server_memory_streams

    @asynccontextmanager
    async def patched_create_streams():
        nonlocal client_spy, server_spy
        async with original_create_streams() as ((client_read, client_write), (server_read, server_write)):
            wrapped_client_write = SpyMemoryObjectSendStream(client_write)
            wrapped_server_write = SpyMemoryObjectSendStream(server_write)
            # Publish the new spies so get_spy_collection() can hand them out.
            client_spy, server_spy = wrapped_client_write, wrapped_server_write
            # Same shape as the real factory, with only the write ends swapped for spies.
            yield (client_read, wrapped_client_write), (server_read, wrapped_server_write)

    def get_spy_collection() -> StreamSpyCollection:
        assert client_spy is not None, "client_spy was not initialized"
        assert server_spy is not None, "server_spy was not initialized"
        return StreamSpyCollection(client_spy, server_spy)

    # NOTE: patch() rebinds the module attribute, so only code that looks the
    # factory up through mcp.shared.memory at call time is affected.
    with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams):
        yield get_spy_collection
|