Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ async def post_writer(
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def handle_message(session_message: SessionMessage) -> None:
message = session_message.message
metadata = (
session_message.metadata
Expand Down Expand Up @@ -467,6 +468,10 @@ async def handle_request_async():
else:
await handle_request_async()

async for session_message in write_stream_reader:
async with anyio.create_task_group() as tg_local:
session_message.context.run(tg_local.start_soon, handle_message, session_message)

except Exception:
logger.exception("Error in post_writer") # pragma: no cover
finally:
Expand Down
9 changes: 8 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,14 @@ async def run(
async for message in session.incoming_messages:
logger.debug("Received message: %s", message)

tg.start_soon(
if isinstance(message, RequestResponder) and message.context is not None:
logger.debug("Got a context to propagate, %s", message.context)
context = message.context
else:
context = contextvars.copy_context()

context.run(
tg.start_soon,
self._handle_message,
message,
session,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
to support transport-specific features like resumability.
"""

import contextvars
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from dataclasses import dataclass, field

from mcp.types import JSONRPCMessage, RequestId

Expand Down Expand Up @@ -46,4 +47,5 @@ class SessionMessage:
"""A message with specific metadata for transport-specific features."""

message: JSONRPCMessage
context: contextvars.Context = field(default_factory=contextvars.copy_context)
metadata: MessageMetadata = None
18 changes: 14 additions & 4 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextvars
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
Expand Down Expand Up @@ -77,11 +78,13 @@ def __init__(
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
message_metadata: MessageMetadata = None,
context: contextvars.Context | None = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self.context = context
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
Expand Down Expand Up @@ -330,10 +333,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
try:
async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
elif isinstance(message.message, JSONRPCRequest):

async def handle_message(message: SessionMessage) -> None:
if isinstance(message.message, JSONRPCRequest):
try:
validated_request = self._receive_request_adapter.validate_python(
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
Expand All @@ -346,6 +348,7 @@ async def _receive_loop(self) -> None:
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
context=message.context,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
Expand Down Expand Up @@ -403,6 +406,13 @@ async def _receive_loop(self) -> None:
else: # Response or error
await self._handle_response(message)

async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
else:
async with anyio.create_task_group() as tg:
message.context.run(tg.start_soon, handle_message, message)

except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
Expand Down
107 changes: 107 additions & 0 deletions tests/test_context_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager

import httpx
import pytest
from inline_snapshot import snapshot
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp import Client
from mcp.client.streamable_http import streamable_http_client
from mcp.server import MCPServer

# TODO: remove once https://github.com/modelcontextprotocol/python-sdk/pull/1991 is merged
pytestmark = pytest.mark.filterwarnings("ignore::ResourceWarning")


# TODO: remove once https://github.com/modelcontextprotocol/python-sdk/pull/1991 is merged
@pytest.fixture(autouse=True)
def force_gc_after_test_resource_leak():
yield
import gc

gc.collect()


TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial")
HOST = "testserver"


@contextmanager
def set_test_contextvar(value: str) -> Iterator[None]:
token = TEST_CONTEXTVAR.set(value)
try:
yield
finally:
TEST_CONTEXTVAR.reset(token)


@pytest.fixture
def server() -> MCPServer:
mcp = MCPServer("test_server")

# tool that returns the value of TEST_CONTEXT_VAR.
@mcp.tool()
async def my_tool() -> str:
return TEST_CONTEXTVAR.get()

return mcp


@pytest.mark.anyio
async def test_memory_transport_client_to_server(server: MCPServer):
async with Client(server) as client:
with set_test_contextvar("client_value"):
result = await client.call_tool(name="my_tool")

assert isinstance(result, types.CallToolResult)
assert result.content == snapshot([types.TextContent(text="client_value")])


@pytest.mark.anyio
async def test_streamable_http_asgi_to_mcpserver(server: MCPServer):
mcp_app = server.streamable_http_app(host=HOST)

# Wrap it in a middleware that sets the contextvar
async def middleware_app(scope: Scope, receive: Receive, send: Send):
with set_test_contextvar("from_middleware"):
await mcp_app(scope, receive, send)

async with (
mcp_app.router.lifespan_context(middleware_app),
httpx.ASGITransport(app=middleware_app) as transport,
httpx.AsyncClient(transport=transport) as http_client,
Client(streamable_http_client(f"http://{HOST}/mcp", http_client=http_client)) as client,
):
result = await client.call_tool("my_tool")
assert result.content == snapshot([types.TextContent(text="from_middleware")])


@pytest.mark.anyio
async def test_streamable_http_mcpclient_to_httpx(server: MCPServer):
mcp_app = server.streamable_http_app(host=HOST)

captured_context_var = None

# Intercepts the httpx call and capture the contextvar's value
class ContextCapturingASGITransport(httpx.ASGITransport):
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
nonlocal captured_context_var
captured_context_var = TEST_CONTEXTVAR.get()
return await super().handle_async_request(request)

async with (
mcp_app.router.lifespan_context(mcp_app),
ContextCapturingASGITransport(app=mcp_app) as transport,
httpx.AsyncClient(transport=transport) as http_client,
Client(streamable_http_client(f"http://{HOST}/mcp", http_client=http_client)) as client,
):
with set_test_contextvar("client_value_list"):
await client.list_tools()
assert captured_context_var == snapshot("client_value_list")

with set_test_contextvar("client_value_call_tool"):
await client.call_tool("my_tool")
assert captured_context_var == snapshot("client_value_call_tool")
Loading
Loading