diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index d4fe1b838..aa62b9f31 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -233,7 +233,7 @@ private Flowable callLlm( callLlmContext) .doOnSubscribe( s -> - Tracing.traceCallLlm( + traceCallLlm( span, context, eventForCallbackUsage.id(), @@ -520,6 +520,7 @@ public Flowable runLive(InvocationContext invocationContext) { .doOnComplete( () -> Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents())) @@ -529,6 +530,7 @@ public Flowable runLive(InvocationContext invocationContext) { span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); @@ -706,6 +708,19 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } + /** + * Traces an LLM call without an associated exception. This is an overload for {@link + * Tracing#traceCallLlm} for successful calls. + */ + private void traceCallLlm( + Span span, + InvocationContext context, + String eventId, + LlmRequest llmRequest, + LlmResponse llmResponse) { + Tracing.traceCallLlm(span, context, eventId, llmRequest, llmResponse, null); + } + private Event buildModelResponseEvent( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { Event.Builder eventBuilder = diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 84a8141ea..0b0e5b4d5 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -178,8 +178,12 @@ public static Maybe handleFunctionCalls( if (events.size() > 1) { return Maybe.just(mergedEvent) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) - .compose(Tracing.trace("tool_response").setParent(parentContext)); + .compose( + Tracing.trace("execute_tool (merged)") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceMergedToolCalls(span, event.id(), event))); } return Maybe.just(mergedEvent); }); @@ -269,10 +273,8 @@ private static Function> getFunctionCallMapper( tool, toolContext, functionCall, - functionArgs, - parentContext) - : callTool( - tool, functionArgs, toolContext, parentContext)) + functionArgs) + : callTool(tool, functionArgs, toolContext)) .compose(Tracing.withContext(parentContext))); return postProcessFunctionResult( @@ -296,8 +298,7 @@ private static Maybe> processFunctionLive( BaseTool tool, ToolContext toolContext, FunctionCall functionCall, - Map args, - Context parentContext) { + Map args) { // Case 1: Handle a call to stopStreaming if (functionCall.name().get().equals("stopStreaming") && args.containsKey("functionName")) { String functionNameToStop = (String) args.get("functionName"); @@ -365,7 +366,7 @@ private static Maybe> processFunctionLive( } // Case 3: Fallback for regular, non-streaming tools - return callTool(tool, args, toolContext, parentContext); + return callTool(tool, args, toolContext); } public static Set getLongRunningFunctionCalls( @@ -426,12 +427,22 @@ private static Maybe postProcessFunctionResult( Event event = buildResponseEvent( tool, finalFunctionResult, toolContext, invocationContext); - Tracing.traceToolResponse(event.id(), event); return Maybe.just(event); }); }) .compose( - Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); + Tracing.trace("execute_tool [" + tool.name() + "]") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceToolExecution( + span, + tool.name(), + tool.description(), + tool.getClass().getSimpleName(), + functionArgs, + event, + null))); } private static Optional mergeParallelFunctionResponseEvents( @@ -579,17 +590,10 @@ private static Maybe> maybeInvokeAfterToolCall( } private static Maybe> callTool( - BaseTool tool, Map args, ToolContext toolContext, Context parentContext) { + BaseTool tool, Map args, ToolContext toolContext) { return tool.runAsync(args, toolContext) .toMaybe() - .doOnSubscribe( - d -> - Tracing.traceToolCall( - tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) .doOnError(t -> Span.current().recordException(t)) - .compose( - Tracing.>trace("tool_call [" + tool.name() + "]") - .setParent(parentContext)) .onErrorResumeNext( e -> Maybe.error( diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 215e317e1..589215073 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -33,6 +33,7 @@ import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; @@ -61,6 +62,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -77,6 +79,11 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final String INVOKE_AGENT_OPERATION = "invoke_agent"; + private static final String EXECUTE_TOOL_OPERATION = "execute_tool"; + private static final String SEND_DATA_OPERATION = "send_data"; + private static final String CALL_LLM_OPERATION = "call_llm"; + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); @@ -134,15 +141,6 @@ public class Tracing { private Tracing() {} - private static void traceWithSpan(String methodName, Consumer traceAction) { - Span span = Span.current(); - if (!span.getSpanContext().isValid()) { - log.trace("{}: No valid span in current context.", methodName); - return; - } - traceAction.accept(span); - } - private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -159,12 +157,6 @@ private static void setInvocationAttributes( } } - private static void setToolExecutionAttributes(Span span) { - span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); - } - private static void setJsonAttribute(Span span, AttributeKey key, Object value) { if (!CAPTURE_MESSAGE_CONTENT_IN_SPANS) { span.setAttribute(key, "{}"); @@ -198,7 +190,7 @@ public static void setTracerForTesting(Tracer tracer) { */ public static void traceAgentInvocation( Span span, String agentName, String agentDescription, InvocationContext invocationContext) { - span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, INVOKE_AGENT_OPERATION); span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); span.setAttribute(GEN_AI_AGENT_NAME, agentName); if (invocationContext.session() != null && invocationContext.session().id() != null) { @@ -207,58 +199,62 @@ public static void traceAgentInvocation( } /** - * Traces tool call arguments. - * - * @param args The arguments to the tool call. - */ - public static void traceToolCall( - String toolName, String toolDescription, String toolType, Map args) { - traceWithSpan( - "traceToolCall", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); - } - - /** - * Traces tool response event. + * Traces a tool execution, including its arguments, response, and any potential error. * - * @param eventId The ID of the event. - * @param functionResponseEvent The function response event. + * @param span The span representing the tool execution. + * @param toolName The name of the tool. + * @param toolDescription The tool's description. + * @param toolType The tool's type (e.g., "FunctionTool"). + * @param args The arguments passed to the tool. + * @param functionResponseEvent The event containing the tool's response, if successful. + * @param error The exception thrown during execution, if any. */ - public static void traceToolResponse(String eventId, Event functionResponseEvent) { - traceWithSpan( - "traceToolResponse", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + public static void traceToolExecution( + Span span, + String toolName, + String toolDescription, + String toolType, + Map args, + @Nullable Event functionResponseEvent, + @Nullable Exception error) { + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + + if (functionResponseEvent != null) { + span.setAttribute(ADK_EVENT_ID, functionResponseEvent.id()); + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + Object finalToolResponse = + (toolResponse instanceof Map) ? toolResponse : ImmutableMap.of("result", toolResponse); + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + } else { + // Set placeholder if no response event is available (e.g., due to an error) + span.setAttribute(GEN_AI_TOOL_CALL_ID, ""); + setJsonAttribute(span, ADK_TOOL_RESPONSE, "{}"); + } - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); + // Also set empty LLM attributes for UI compatibility, like in traceToolResponse + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } } /** @@ -303,8 +299,10 @@ public static void traceCallLlm( InvocationContext invocationContext, String eventId, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + @Nullable Exception error) { span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, CALL_LLM_OPERATION); llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); setInvocationAttributes(span, invocationContext, eventId); @@ -312,6 +310,11 @@ public static void traceCallLlm( setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } + llmRequest .config() .ifPresent( @@ -352,18 +355,45 @@ public static void traceCallLlm( * @param data A list of content objects being sent. */ public static void traceSendData( - InvocationContext invocationContext, String eventId, List data) { - traceWithSpan( - "traceSendData", - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + Span span, InvocationContext invocationContext, String eventId, List data) { + if (!span.getSpanContext().isValid()) { + log.trace("traceSendData: No valid span in current context."); + return; + } + setInvocationAttributes(span, invocationContext, eventId); + span.setAttribute(GEN_AI_OPERATION_NAME, SEND_DATA_OPERATION); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + } + + /** + * Traces merged tool call events. + * + *

Calling this function is not needed for telemetry purposes. This is provided for preventing + * /debug/trace requests (typically sent by web UI). + * + * @param responseEventId The ID of the response event. + * @param functionResponseEvent The merged response event. + */ + public static void traceMergedToolCalls( + Span span, String responseEventId, Event functionResponseEvent) { + if (!span.getSpanContext().isValid()) { + log.trace("traceMergedToolCalls: No valid span in current context."); + return; + } + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_CALL_ID, responseEventId); + span.setAttribute(ADK_TOOL_CALL_ARGS, "N/A"); + span.setAttribute(ADK_EVENT_ID, responseEventId); + setJsonAttribute(span, ADK_TOOL_RESPONSE, functionResponseEvent); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index a9e7a6f8d..e40a83aa0 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -494,12 +494,10 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]"); - List toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]"); + List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); @@ -507,9 +505,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { // The tool calls and responses are children of the first LLM call that produced the function // call. String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); - toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach( - s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index efd565c16..b68b6ff5f 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -1061,21 +1061,16 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test @@ -1101,22 +1096,17 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index b13904934..44877e972 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import com.google.adk.agents.BaseAgent; @@ -100,242 +99,6 @@ public void tearDown() { Tracing.setTracerForTesting(originalTracer); } - @Test - public void testToolCallSpanLinksToParent() { - // Given: Parent span is active - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope scope = parentSpan.makeCurrent()) { - // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - // Simulate tool execution - } finally { - toolCallSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Then: tool_call should be child of parent - SpanData parentSpanData = findSpanByName("parent"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - - // Verify parent-child relationship - assertEquals( - "Tool call should have same trace ID as parent", - parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId()); - - assertParent(parentSpanData, toolCallSpanData); - } - - @Test - public void testToolCallWithoutParentCreatesRootSpan() { - // Given: No parent span active - // When: ADK creates tool_call span with setParent(Context.current()) - try (Scope s = Context.root().makeCurrent()) { - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); - } - } - - // Then: Should create root span (backward compatible) - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(1); - - SpanData toolCallSpanData = spans.get(0); - assertFalse( - "Tool call should be root span when no parent exists", - toolCallSpanData.getParentSpanContext().isValid()); - } - - @Test - public void testNestedSpanHierarchy() { - // Test: parent → invocation → tool_call → tool_response hierarchy - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - - Span invocationSpan = - tracer.spanBuilder("invocation").setParent(Context.current()).startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - - Span toolResponseSpan = - tracer - .spanBuilder("tool_response [testTool]") - .setParent(Context.current()) - .startSpan(); - - toolResponseSpan.end(); - } finally { - toolCallSpan.end(); - } - } finally { - invocationSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Verify complete hierarchy - List spans = openTelemetryRule.getSpans(); - // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response - // [testTool]". - assertThat(spans).hasSize(4); - - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All spans should have same trace ID - for (SpanData span : openTelemetryRule.getSpans()) { - assertEquals( - "All spans should be in same trace", parentTraceId, span.getSpanContext().getTraceId()); - } - - // Verify parent-child relationships - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); - - // invocation should be child of parent - assertParent(parentSpanData, invocationSpanData); - - // tool_call should be child of invocation - assertParent(invocationSpanData, toolCallSpanData); - - // tool_response should be child of tool_call - assertParent(toolCallSpanData, toolResponseSpanData); - } - - @Test - public void testMultipleSpansInParallel() { - // Test: Multiple tool calls in parallel should all link to same parent - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - // Simulate parallel tool calls - Span toolCall1 = - tracer.spanBuilder("tool_call [tool1]").setParent(Context.current()).startSpan(); - Span toolCall2 = - tracer.spanBuilder("tool_call [tool2]").setParent(Context.current()).startSpan(); - Span toolCall3 = - tracer.spanBuilder("tool_call [tool3]").setParent(Context.current()).startSpan(); - - toolCall1.end(); - toolCall2.end(); - toolCall3.end(); - } finally { - parentSpan.end(); - } - - // Verify all tool calls link to same parent - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All tool calls should have same trace ID and parent span ID - List toolCallSpans = - openTelemetryRule.getSpans().stream() - .filter(s -> s.getName().startsWith("tool_call")) - .toList(); - - assertThat(toolCallSpans).hasSize(3); - - toolCallSpans.forEach( - span -> { - assertEquals( - "Tool call should have same trace ID as parent", - parentTraceId, - span.getSpanContext().getTraceId()); - assertParent(parentSpanData, span); - }); - } - - @Test - public void testInvokeAgentSpanLinksToInvocation() { - // Test: invoke_agent span should link to invocation span - - Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span invokeAgentSpan = - tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - // Simulate agent work - } finally { - invokeAgentSpan.end(); - } - } finally { - invocationSpan.end(); - } - - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - - assertParent(invocationSpanData, invokeAgentSpanData); - } - - @Test - public void testCallLlmSpanLinksToAgentRun() { - // Test: call_llm span should link to agent_run span - - Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); - - try (Scope llmScope = callLlmSpan.makeCurrent()) { - // Simulate LLM call - } finally { - callLlmSpan.end(); - } - } finally { - invokeAgentSpan.end(); - } - - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(2); - - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - SpanData callLlmSpanData = findSpanByName("call_llm"); - - assertParent(invokeAgentSpanData, callLlmSpanData); - } - - @Test - public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { - // Test: Simulates creating a span within the scope of a parent - - Span parentSpan = tracer.spanBuilder("invocation").startSpan(); - try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); - agentSpan.end(); - } finally { - parentSpan.end(); - } - - SpanData parentSpanData = findSpanByName("invocation"); - SpanData agentSpanData = findSpanByName("invoke_agent"); - - assertParent(parentSpanData, agentSpanData); - } - @Test public void testTraceFlowable() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -475,8 +238,14 @@ public void testTraceAgentInvocation() { public void testTraceToolCall() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall( - "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of("arg1", "value1"), + null, + null); } finally { span.end(); } @@ -513,7 +282,14 @@ public void testTraceToolResponse() { .build()) .build())) .build(); - Tracing.traceToolResponse("event-1", functionResponseEvent); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of(), + functionResponseEvent, + null); } finally { span.end(); } @@ -524,6 +300,10 @@ public void testTraceToolResponse() { assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); assertEquals( "{\"result\":\"tool-result\"}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); @@ -550,7 +330,8 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm( + span, buildInvocationContext(), "event-1", llmRequest, llmResponse, null); } finally { span.end(); } @@ -559,6 +340,7 @@ public void testTraceCallLlm() { SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("call_llm", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); @@ -581,6 +363,7 @@ public void testTraceSendData() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceSendData( + span, buildInvocationContext(), "event-1", ImmutableList.of(Content.fromParts(Part.fromText("hello")))); @@ -591,6 +374,7 @@ public void testTraceSendData() { assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); + assertEquals("send_data", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); @@ -687,8 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ ├── tool_call [search_flights] - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -716,8 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolCall = findSpanByName("tool_call [search_flights]"); - SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -733,12 +515,28 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ ├── tool_call [search_flights] - assertParent(callLlm1, toolCall); - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); + + // Assert attributes + assertEquals( + "invoke_agent", + invokeAgent.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm1.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "search_flights", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm2.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); } @Test @@ -748,8 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ ├── tool_call [transfer_to_agent] - // │ └── tool_response [transfer_to_agent] + // │ └── execute_tool [transfer_to_agent] // └── invoke_agent AgentB // └── call_llm TestLlm llm = @@ -776,9 +573,8 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); - SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); List callLlmSpans = openTelemetryRule.getSpans().stream() @@ -792,8 +588,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { assertParent(invocation, agentASpan); assertParent(agentASpan, agentACallLlm1); - assertParent(agentACallLlm1, toolCall); - assertParent(agentACallLlm1, toolResponse); + assertParent(agentACallLlm1, executeTool); assertParent(agentASpan, agentBSpan); assertParent(agentBSpan, agentBCallLlm); }