diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 2bfbca881..82f342a81 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -477,13 +477,8 @@ protected Flowable runAsyncImpl( session.appName(), session.userId(), session.id(), Optional.empty()) .flatMapPublisher( updatedSession -> - runAgentWithFreshSession( - session, - updatedSession, - event, - invocationId, - runConfig, - rootAgent)) + runAgentWithUpdatedSession( + initialContext, updatedSession, event, rootAgent)) .compose(Tracing.withContext(capturedContext)); }); }) @@ -495,19 +490,27 @@ protected Flowable runAsyncImpl( }); } - private Flowable runAgentWithFreshSession( - Session session, - Session updatedSession, - Event event, - String invocationId, - RunConfig runConfig, - BaseAgent rootAgent) { + /** + * Runs the agent with the updated session state. + * + *

This method is called after the user message has been persistent in the session. It creates + * a final {@link InvocationContext} that inherits state from the {@code initialContext} but uses + * the {@code updatedSession} to ensure the agent can access the latest conversation history. + * + * @param initialContext the context from the start of the invocation, used to preserve metadata + * and callback data. + * @param updatedSession the session object containing the latest message. + * @param event the event representing the user message that was just appended. + * @param rootAgent the agent to be executed. + * @return a stream of events from the agent execution and subsequent plugin callbacks. + */ + private Flowable runAgentWithUpdatedSession( + InvocationContext initialContext, Session updatedSession, Event event, BaseAgent rootAgent) { // Create context with updated session for beforeRunCallback InvocationContext contextWithUpdatedSession = - newInvocationContextBuilder(updatedSession) - .invocationId(invocationId) + initialContext.toBuilder() + .session(updatedSession) .agent(this.findAgentToRun(updatedSession, rootAgent)) - .runConfig(runConfig) .userContent(event.content().orElseGet(Content::fromParts)) .build(); @@ -536,7 +539,7 @@ private Flowable runAgentWithFreshSession( .flatMap( registeredEvent -> { // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, session); + copySessionStates(updatedSession, initialContext.session()); return contextWithUpdatedSession .pluginManager() .onEventCallback(contextWithUpdatedSession, registeredEvent) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index b68b6ff5f..36530faf2 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -591,6 +591,29 @@ public void onEventCallback_success() { verify(plugin).onEventCallback(any(), any()); } + @Test + public void callbackContextData_preservedAcrossInvocation() { + String testKey = "testKey"; + String testValue = "testValue"; + + when(plugin.onUserMessageCallback(any(), any())) + .thenAnswer( + invocation -> { + InvocationContext context = invocation.getArgument(0); + context.callbackContextData().put(testKey, testValue); + return Maybe.empty(); + }); + + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.afterRunCallback(contextCaptor.capture())).thenReturn(Completable.complete()); + + var unused = + runner.runAsync("user", session.id(), createContent("test")).toList().blockingGet(); + + assertThat(contextCaptor.getValue().callbackContextData()).containsEntry(testKey, testValue); + } + @Test public void runAsync_withSessionKey_success() { var events =