Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,11 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str:
for attempt in range(2):
try:
result = await self.session.call_tool(tool_name, arguments=filtered_kwargs, meta=otel_meta) # type: ignore
if result.isError:
raise ToolExecutionException(parser(result))
return parser(result)
except ToolExecutionException:
raise
except ClosedResourceError as cl_ex:
if attempt == 0:
# First attempt failed, try reconnecting
Expand Down
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_skills.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,10 @@ async def _read_skill_resource(self, skill_name: str, resource_name: str, **kwar
try:
if inspect.iscoroutinefunction(resource.function):
result = (
await resource.function(**kwargs) if resource._accepts_kwargs else await resource.function()
await resource.function(**kwargs) if resource._accepts_kwargs else await resource.function() # pyright: ignore[reportPrivateUsage]
)
else:
result = resource.function(**kwargs) if resource._accepts_kwargs else resource.function()
result = resource.function(**kwargs) if resource._accepts_kwargs else resource.function() # pyright: ignore[reportPrivateUsage]
return str(result)
except Exception as exc:
logger.exception("Failed to read resource '%s' from skill '%s'", resource_name, skill_name)
Expand Down
148 changes: 146 additions & 2 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from agent_framework import (
Content,
FunctionInvocationContext,
FunctionMiddleware,
MCPStdioTool,
MCPStreamableHTTPTool,
MCPWebsocketTool,
Expand All @@ -30,6 +32,7 @@
_prepare_message_for_mcp,
logger,
)
from agent_framework._middleware import FunctionMiddlewarePipeline
from agent_framework.exceptions import ToolException, ToolExecutionException

# Integration test skip condition
Expand Down Expand Up @@ -898,6 +901,147 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
await func.invoke(param="test_value")


async def test_mcp_tool_call_tool_raises_on_is_error():
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""

class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(
content=[types.TextContent(type="text", text="Something went wrong")],
isError=True,
)
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]

with pytest.raises(ToolExecutionException, match="Something went wrong"):
await func.invoke(param="test_value")


async def test_mcp_tool_call_tool_succeeds_when_is_error_false():
"""Test that call_tool returns normally when MCP returns isError=False."""

class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(
content=[types.TextContent(type="text", text="Success")],
isError=False,
)
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]
result = await func.invoke(param="test_value")
assert result == "Success"


async def test_mcp_tool_is_error_propagates_through_function_middleware():
"""Test that MCP isError=True propagates as ToolExecutionException through function middleware."""
error_seen_in_middleware = False

class ErrorCheckMiddleware(FunctionMiddleware):
async def process(self, context: FunctionInvocationContext, call_next):
nonlocal error_seen_in_middleware
try:
await call_next()
except ToolExecutionException:
error_seen_in_middleware = True
raise

class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(
content=[types.TextContent(type="text", text="MCP error occurred")],
isError=True,
)
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]

middleware_pipeline = FunctionMiddlewarePipeline(ErrorCheckMiddleware())

middleware_context = FunctionInvocationContext(
function=func,
arguments={"param": "test_value"},
)

with pytest.raises(ToolExecutionException, match="MCP error occurred"):
await middleware_pipeline.execute(
middleware_context,
lambda ctx: func.invoke(arguments=ctx.arguments),
)

assert error_seen_in_middleware, "Middleware should have seen the ToolExecutionException"


async def test_local_mcp_server_prompt_execution():
"""Test prompt execution through MCP server."""

Expand Down Expand Up @@ -2098,7 +2242,7 @@ async def restore_session(*, reset=False):
tool._tools_loaded = True

# First call should work - connection is valid
mock_session.call_tool.return_value = MagicMock(content=[])
mock_session.call_tool.return_value = types.CallToolResult(content=[])
result = await tool.call_tool("test_tool", arg1="value1")
assert result is not None

Expand All @@ -2111,7 +2255,7 @@ async def call_tool_with_error(*args, **kwargs):
call_count += 1
if call_count == 1:
raise ClosedResourceError
return MagicMock(content=[])
return types.CallToolResult(content=[])

mock_session.call_tool = call_tool_with_error

Expand Down