diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 67453624c..728fb088d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -83,6 +83,8 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.auth.middleware.auth_context import auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.lowlevel.func_inspection import create_call_wrapper @@ -723,6 +725,7 @@ async def _handle_request( logger.debug("Dispatching request of type %s", type(req).__name__) token = None + auth_token = None try: # Extract request context and close_sse_stream from message metadata request_data = None @@ -743,6 +746,14 @@ async def _handle_request( task_metadata = None if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) + if request_data is not None: + scope = getattr(request_data, "scope", None) + if isinstance(scope, dict): + scope_dict = cast(dict[str, Any], scope) + user = scope_dict.get("user") + if isinstance(user, AuthenticatedUser): + auth_token = auth_context_var.set(user) + token = request_ctx.set( RequestContext( message.request_id, @@ -775,6 +786,8 @@ async def _handle_request( response = types.ErrorData(code=0, message=str(err), data=None) finally: # Reset the global state after we are done + if auth_token is not None: + auth_context_var.reset(auth_token) if token is not None: # pragma: no branch request_ctx.reset(token) diff --git a/tests/server/lowlevel/test_auth_context_from_request.py b/tests/server/lowlevel/test_auth_context_from_request.py new file mode 100644 index 000000000..970f4ebe1 --- /dev/null +++ b/tests/server/lowlevel/test_auth_context_from_request.py @@ -0,0 +1,69 @@ +from unittest.mock import AsyncMock, Mock + +import pytest +from starlette.requests import Request +from starlette.types import Scope + +import mcp.types as types +from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken +from mcp.server.lowlevel.server import Server +from mcp.server.session import ServerSession +from mcp.shared.message import ServerMessageMetadata +from mcp.shared.session import RequestResponder + + +@pytest.mark.anyio +async def test_handle_request_sets_auth_context_from_request() -> None: + server = Server("test-server") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="echo_access_token", + description="Return access token", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[types.TextContent]: + assert name == "echo_access_token" + access_token = get_access_token() + token = access_token.token if access_token else "" + return [types.TextContent(type="text", text=token)] + + access_token = AccessToken(token="test-token", client_id="client", scopes=["test"]) + user = AuthenticatedUser(access_token) + headers: list[tuple[bytes, bytes]] = [] + scope: Scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": headers, + "user": user, + } + request = Request(scope) + + message = Mock(spec=RequestResponder) + message.request_id = "req-1" + message.request_meta = None + message.message_metadata = ServerMessageMetadata(request_context=request) + message.respond = AsyncMock() + + session = Mock(spec=ServerSession) + session.client_params = None + + call_request = types.CallToolRequest(params=types.CallToolRequestParams(name="echo_access_token", arguments={})) + + await server._handle_request(message, call_request, session, {}, raise_exceptions=False) + + assert auth_context_var.get() is None + assert message.respond.called + response = message.respond.call_args.args[0] + assert isinstance(response.root, types.CallToolResult) + content = response.root.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "test-token" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 731dd20dd..02a626ece 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,6 +21,8 @@ from httpx_sse import ServerSentEvent from pydantic import AnyUrl from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.routing import Mount @@ -32,6 +34,9 @@ streamablehttp_client, # pyright: ignore[reportDeprecated] ) from mcp.server import Server +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken, TokenVerifier from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -1520,6 +1525,71 @@ def run_context_aware_server(port: int): # pragma: no cover server_instance.run() +class AuthTokenServerTest(Server): # pragma: no cover + def __init__(self): + super().__init__("AuthTokenServer") + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_access_token", + description="Return the current access token", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @self.call_tool() + async def handle_call_tool(name: str, _args: dict[str, Any]) -> list[TextContent]: + assert name == "echo_access_token" + access_token = get_access_token() + assert access_token is not None + return [TextContent(type="text", text=access_token.token)] + + +def run_auth_token_server(port: int) -> None: # pragma: no cover + """Run the auth token test server.""" + server = AuthTokenServerTest() + + class AcceptAllTokenVerifier(TokenVerifier): + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken( + token=token, + client_id="test-client", + scopes=["test"], + ) + + token_verifier = AcceptAllTokenVerifier() + + session_manager = StreamableHTTPSessionManager( + app=server, + event_store=None, + json_response=False, + ) + + middleware = [ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)), + Middleware(AuthContextMiddleware), + ] + + app = Starlette( + debug=True, + routes=[Mount("/mcp", app=session_manager.handle_request)], + middleware=middleware, + lifespan=lambda app: session_manager.run(), + ) + + server_instance = uvicorn.Server( + config=uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="error", + ) + ) + server_instance.run() + + @pytest.fixture def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: """Start the context-aware server in a separate process.""" @@ -1537,6 +1607,22 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: print("Context-aware server process failed to terminate") +@pytest.fixture +def auth_token_server(basic_server_port: int) -> Generator[None, None, None]: + """Start the auth token server in a separate process.""" + proc = multiprocessing.Process(target=run_auth_token_server, args=(basic_server_port,), daemon=True) + proc.start() + + wait_for_server(basic_server_port) + + yield + + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): # pragma: no cover + print("Auth token server process failed to terminate") + + @pytest.mark.anyio async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" @@ -1571,6 +1657,34 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: assert headers_data.get("x-trace-id") == "trace-123" +@pytest.mark.anyio +async def test_streamablehttp_refreshes_access_token(auth_token_server: None, basic_server_url: str) -> None: + """Ensure refreshed bearer tokens are used for subsequent requests.""" + token_a = "token-a" + token_b = "token-b" + + async with create_mcp_http_client(headers={"Authorization": f"Bearer {token_a}"}) as httpx_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + tool_result = await session.call_tool("echo_access_token", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == token_a + + httpx_client.headers["Authorization"] = f"Bearer {token_b}" + tool_result = await session.call_tool("echo_access_token", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == token_b + + @pytest.mark.anyio async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients."""