diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index d2e1d61032..75f298c2ba 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -484,18 +484,13 @@ async def _run_on_tool_error_callbacks( tool = _get_tool(function_call, tools_dict) except ValueError as tool_error: tool = BaseTool(name=function_call.name, description='Tool not found') - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - return __build_response_event( - tool, error_response, tool_context, invocation_context - ) - else: - raise tool_error + # Fall through to _run_with_trace so that before_tool_callback and the + # OTel span are created *before* on_tool_error_callback fires. This + # keeps the callback lifecycle balanced (push/pop) and prevents plugins + # like BigQueryAgentAnalyticsPlugin from corrupting their span stacks. + _tool_lookup_error: Exception = tool_error + else: + _tool_lookup_error = None async def _run_with_trace(): nonlocal function_args @@ -520,6 +515,22 @@ async def _run_with_trace(): if function_response: break + # Step 2.5: If the tool was not found (hallucinated), surface the error + # *after* before_tool_callback so the lifecycle stays balanced. + if _tool_lookup_error is not None: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=_tool_lookup_error, + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + else: + raise _tool_lookup_error + # Step 3: Otherwise, proceed calling the tool normally. if function_response is None: try: @@ -715,17 +726,9 @@ async def _run_on_tool_error_callbacks( tool = _get_tool(function_call, tools_dict) except ValueError as tool_error: tool = BaseTool(name=function_call.name, description='Tool not found') - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - return __build_response_event( - tool, error_response, tool_context, invocation_context - ) - raise tool_error + _tool_lookup_error: Exception = tool_error + else: + _tool_lookup_error = None async def _run_with_trace(): nonlocal function_args @@ -755,6 +758,21 @@ async def _run_with_trace(): if function_response: break + # Step 2.5: If the tool was not found (hallucinated), surface the error + # *after* before_tool_callback so the lifecycle stays balanced. + if _tool_lookup_error is not None: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=_tool_lookup_error, + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + raise _tool_lookup_error + # Step 3: Otherwise, proceed calling the tool normally. if function_response is None: try: diff --git a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py index 3c39e2844b..807885ebff 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py @@ -340,5 +340,113 @@ def agent_after_cb(tool, args, tool_context, tool_response): assert part.function_response.response == mock_plugin.after_tool_response +@pytest.mark.asyncio +async def test_hallucinated_tool_fires_before_and_error_callbacks( + mock_tool, mock_plugin +): + """Regression test for https://github.com/google/adk-python/issues/4775. + + When the LLM hallucinates a tool name, on_tool_error_callback used to fire + *before* before_tool_callback, corrupting plugin span stacks (e.g. + BigQueryAgentAnalyticsPlugin's TraceManager). After the fix, both + callbacks should fire in order: before_tool → on_tool_error. + """ + mock_plugin.enable_before_tool_callback = True + mock_plugin.enable_on_tool_error_callback = True + + # Track callback invocation order + call_order = [] + original_before = mock_plugin.before_tool_callback + original_error = mock_plugin.on_tool_error_callback + + async def tracking_before(**kwargs): + call_order.append("before_tool") + return await original_before(**kwargs) + + async def tracking_error(**kwargs): + call_order.append("on_tool_error") + return await original_error(**kwargs) + + mock_plugin.before_tool_callback = tracking_before + mock_plugin.on_tool_error_callback = tracking_error + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + + # Build function call for a non-existent tool (hallucinated name) + function_call = types.FunctionCall( + name="hallucinated_tool_xyz", args={"query": "test"} + ) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + + result_event = await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + # on_tool_error_callback returned a response, so we should get an event + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.on_tool_error_response + + # Verify that before_tool fired BEFORE on_tool_error + assert "before_tool" in call_order + assert "on_tool_error" in call_order + assert call_order.index("before_tool") < call_order.index("on_tool_error") + + +@pytest.mark.asyncio +async def test_hallucinated_tool_raises_when_no_error_callback( + mock_tool, mock_plugin +): + """When a tool is hallucinated and no error callback handles it, ValueError + should propagate — but only after before_tool_callback has had a chance to + run (so plugin stacks remain balanced).""" + mock_plugin.enable_before_tool_callback = False + mock_plugin.enable_on_tool_error_callback = False + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + + function_call = types.FunctionCall( + name="nonexistent_tool", args={} + ) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + + with pytest.raises(ValueError, match="nonexistent_tool"): + await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + if __name__ == "__main__": pytest.main([__file__])