Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from typing import Any, Generic

from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -125,6 +125,59 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"


_EMPTY_MCP_ARGUMENT = object()


def _sanitize_mcp_arguments(
value: Any,
schema: dict[str, Any] | None = None,
*,
required: bool = False,
) -> Any:
"""Remove empty optional payload values before sending to MCP tools."""
if value is None:
return value if required else _EMPTY_MCP_ARGUMENT

if isinstance(value, str):
return value if value != "" or required else _EMPTY_MCP_ARGUMENT

if isinstance(value, list):
if not value:
return value if required else _EMPTY_MCP_ARGUMENT
cleaned_items = []
item_schema = schema.get("items") if isinstance(schema, dict) else None
for item in value:
cleaned_item = _sanitize_mcp_arguments(item, item_schema)
# Preserve list positions. If sanitizing an item would remove it,
# keep the original item instead of reindexing the payload.
if cleaned_item is _EMPTY_MCP_ARGUMENT:
cleaned_items.append(item)
else:
cleaned_items.append(cleaned_item)
return cleaned_items

if isinstance(value, dict):
if not value:
return value if required else _EMPTY_MCP_ARGUMENT

cleaned_dict = {}
properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
required_keys = set(schema.get("required", [])) if isinstance(schema, dict) else set()
for key, item in value.items():
child_schema = properties.get(key) if isinstance(properties, dict) else None
cleaned_item = _sanitize_mcp_arguments(
item,
child_schema,
required=key in required_keys,
)
if cleaned_item is _EMPTY_MCP_ARGUMENT:
continue
cleaned_dict[key] = cleaned_item
return cleaned_dict if cleaned_dict or required else _EMPTY_MCP_ARGUMENT

return value


class MCPClient:
def __init__(self) -> None:
# Initialize session and client objects
Expand Down Expand Up @@ -347,6 +400,21 @@ async def call_tool_with_reconnect(
anyio.ClosedResourceError: raised after reconnection failure
"""

tool_schema = next(
(tool.inputSchema for tool in self.tools if tool.name == tool_name),
None,
)
sanitized_arguments = _sanitize_mcp_arguments(arguments, tool_schema)
if sanitized_arguments is _EMPTY_MCP_ARGUMENT:
sanitized_arguments = {}
if sanitized_arguments != arguments:
logger.debug(
"Sanitized MCP tool %s arguments from %s to %s",
tool_name,
arguments,
sanitized_arguments,
Comment on lines +410 to +415
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 suggestion (security): Debug logging of full arguments may expose sensitive data and be expensive

This logs both the raw and sanitized arguments, which can leak credentials/PII into logs and adds overhead for large payloads. Consider redacting known sensitive fields, truncating large values, or guarding this behind a more verbose/debug-only flag so it’s not enabled in typical deployments.

)

@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
Expand All @@ -361,7 +429,7 @@ async def _call_with_retry():
try:
return await self.session.call_tool(
name=tool_name,
arguments=arguments,
arguments=sanitized_arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
Expand Down
127 changes: 127 additions & 0 deletions tests/unit/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import importlib.util
import logging
import sys
import types
from pathlib import Path
from typing import Generic, TypeVar
from unittest.mock import AsyncMock

import pytest

REPO_ROOT = Path(__file__).resolve().parents[2]
MCP_CLIENT_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/mcp_client.py"


def load_mcp_client_module():
    """Load ``astrbot/core/agent/mcp_client.py`` against stubbed dependencies.

    The module under test is executed from its file path while stand-in
    modules for the ``astrbot`` package tree, ``anyio`` and ``mcp`` are
    temporarily registered in ``sys.modules``. Every ``sys.modules`` entry
    this helper touches is put back to its previous state (or removed) before
    returning, so the incomplete stubs cannot be picked up by imports made in
    other tests of the same pytest process.

    Returns:
        The freshly executed ``astrbot.core.agent.mcp_client`` module object.
        It keeps the references to the stub objects it imported while loading.
    """
    missing = object()  # marks "was not present" in the snapshots below

    package_names = [
        "astrbot",
        "astrbot.core",
        "astrbot.core.agent",
        "astrbot.core.utils",
    ]

    stubs = {}

    def stub_module(name, **attributes):
        """Create a stub module carrying ``attributes`` and record it."""
        stub = types.ModuleType(name)
        vars(stub).update(attributes)
        stubs[name] = stub
        return stub

    stub_module("astrbot.core.utils.log_pipe", LogPipe=type("LogPipe", (), {}))

    context_type_var = TypeVar("TContext")

    class ContextWrapper(Generic[context_type_var]):
        pass

    stub_module(
        "astrbot.core.agent.run_context",
        TContext=context_type_var,
        ContextWrapper=ContextWrapper,
    )
    stub_module("astrbot.core.agent.tool", FunctionTool=type("FunctionTool", (), {}))
    stub_module(
        "anyio",
        ClosedResourceError=type("ClosedResourceError", (Exception,), {}),
    )
    stub_module(
        "mcp",
        Tool=type("Tool", (), {}),
        ClientSession=type("ClientSession", (), {}),
        ListToolsResult=type("ListToolsResult", (), {}),
        StdioServerParameters=type("StdioServerParameters", (), {}),
        stdio_client=lambda *args, **kwargs: None,
        types=types.SimpleNamespace(
            LoggingMessageNotificationParams=type(
                "LoggingMessageNotificationParams", (), {}
            ),
            CallToolResult=type("CallToolResult", (), {}),
        ),
    )
    stub_module("mcp.client")
    stub_module("mcp.client.sse", sse_client=lambda *args, **kwargs: None)
    stub_module(
        "mcp.client.streamable_http",
        streamablehttp_client=lambda *args, **kwargs: None,
    )

    spec = importlib.util.spec_from_file_location(
        "astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH
    )
    assert spec and spec.loader
    module = importlib.util.module_from_spec(spec)

    # Snapshot everything that is about to be replaced so it can be restored.
    touched_names = [*package_names, *stubs, spec.name]
    saved_modules = {name: sys.modules.get(name, missing) for name in touched_names}
    preexisting_astrbot = sys.modules.get("astrbot")
    saved_logger = getattr(preexisting_astrbot, "logger", missing)

    try:
        for name in package_names:
            if name not in sys.modules:
                package = types.ModuleType(name)
                package.__path__ = []
                sys.modules[name] = package

        sys.modules["astrbot"].logger = logging.getLogger("astrbot-test")
        sys.modules.update(stubs)
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)
    finally:
        # Restore global import state even when executing the module fails.
        for name, previous in saved_modules.items():
            if previous is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
        # An ``astrbot`` module that existed beforehand had its ``logger``
        # attribute overwritten above; put the original back as well.
        if preexisting_astrbot is not None:
            if saved_logger is not missing:
                preexisting_astrbot.logger = saved_logger
            elif hasattr(preexisting_astrbot, "logger"):
                del preexisting_astrbot.logger

    return module


def test_sanitize_mcp_arguments_removes_nested_empty_collections():
    """Without a schema, keys whose nested content is entirely empty are dropped."""
    sanitize = load_mcp_client_module()._sanitize_mcp_arguments
    payload = {
        "query": "hello",
        "filters": {"tags": [], "scope": {}},
        "metadata": {"owner": "", "visibility": None},
    }

    result = sanitize(payload)

    assert result == {"query": "hello"}


@pytest.mark.asyncio
async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments():
    """A payload made only of empty values reaches the session as ``{}``."""
    module = load_mcp_client_module()
    timeout = module.timedelta(seconds=1)

    mcp_client = module.MCPClient()
    call_tool_mock = AsyncMock(return_value="ok")
    mcp_client.session = types.SimpleNamespace(call_tool=call_tool_mock)

    outcome = await mcp_client.call_tool_with_reconnect(
        tool_name="search",
        arguments={"filters": {}, "query": ""},
        read_timeout_seconds=timeout,
    )

    assert outcome == "ok"
    # The session must be awaited exactly once, with the emptied arguments.
    mcp_client.session.call_tool.assert_awaited_once_with(
        name="search",
        arguments={},
        read_timeout_seconds=timeout,
    )
61 changes: 61 additions & 0 deletions tests/unit/test_mcp_client_sanitization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from astrbot.core.agent.mcp_client import _sanitize_mcp_arguments


def test_sanitize_mcp_arguments_drops_empty_optional_object_fields():
    """Empty strings and lists under non-required keys are removed."""
    tool_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "area": {"type": "string"},
            "floor": {"type": "string"},
            "domain": {"type": "array", "items": {"type": "string"}},
            "device_class": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["name"],
    }
    payload = {
        "name": "demo",
        "area": "",
        "floor": "",
        "domain": ["light"],
        "device_class": [],
    }
    expected = {"name": "demo", "domain": ["light"]}

    result = _sanitize_mcp_arguments(payload, tool_schema)

    assert result == expected


def test_sanitize_mcp_arguments_preserves_required_empty_values():
    """Keys listed in ``required`` survive even when their values are empty."""
    metadata_schema = {
        "type": "object",
        "properties": {
            "note": {"type": "string"},
        },
    }
    tool_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "tags": {"type": "array", "items": {"type": "string"}},
            "metadata": metadata_schema,
        },
        "required": ["name", "tags", "metadata"],
    }
    payload = {
        "name": "",
        "tags": [],
        "metadata": {},
    }

    result = _sanitize_mcp_arguments(payload, tool_schema)

    assert result == payload


def test_sanitize_mcp_arguments_preserves_list_positions():
    """An empty string inside a list is kept so the other items do not shift."""
    array_schema = {"type": "array", "items": {"type": "string"}}
    items = ["alpha", "", "omega"]

    result = _sanitize_mcp_arguments(items, array_schema)

    assert result == ["alpha", "", "omega"]
22 changes: 15 additions & 7 deletions tests/unit/test_tool_google_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from pathlib import Path
from typing import Generic, TypeVar

import pytest

REPO_ROOT = Path(__file__).resolve().parents[2]
TOOL_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/tool.py"


def load_tool_module():
def load_tool_module(monkeypatch: pytest.MonkeyPatch):
package_names = [
"astrbot",
"astrbot.core",
Expand All @@ -21,13 +23,17 @@ def load_tool_module():
if name not in sys.modules:
module = types.ModuleType(name)
module.__path__ = []
sys.modules[name] = module
monkeypatch.setitem(sys.modules, name, module)

message_result_module = types.ModuleType(
"astrbot.core.message.message_event_result"
)
message_result_module.MessageEventResult = type("MessageEventResult", (), {})
sys.modules[message_result_module.__name__] = message_result_module
monkeypatch.setitem(
sys.modules,
message_result_module.__name__,
message_result_module,
)

run_context_module = types.ModuleType("astrbot.core.agent.run_context")
run_context_module.TContext = TypeVar("TContext")
Expand All @@ -36,20 +42,22 @@ class ContextWrapper(Generic[run_context_module.TContext]):
pass

run_context_module.ContextWrapper = ContextWrapper
sys.modules[run_context_module.__name__] = run_context_module
monkeypatch.setitem(sys.modules, run_context_module.__name__, run_context_module)

spec = importlib.util.spec_from_file_location(
"astrbot.core.agent.tool", TOOL_MODULE_PATH
)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
monkeypatch.setitem(sys.modules, spec.name, module)
spec.loader.exec_module(module)
return module


def test_google_schema_fills_missing_array_items_with_string_schema():
tool_module = load_tool_module()
def test_google_schema_fills_missing_array_items_with_string_schema(
monkeypatch: pytest.MonkeyPatch,
):
tool_module = load_tool_module(monkeypatch)
FunctionTool = tool_module.FunctionTool
ToolSet = tool_module.ToolSet

Expand Down