Skip to content
9 changes: 9 additions & 0 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,15 @@ class LlmAgent(BaseAgent):
"""
# Callbacks - End

@override
async def _handle_before_agent_callback(
self, ctx: InvocationContext
) -> Optional[Event]:
event = await super()._handle_before_agent_callback(ctx)
if event is not None:
self.__maybe_save_output_to_state(event)
return event

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
24 changes: 24 additions & 0 deletions tests/unittests/agents/test_llm_agent_output_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
import logging
from unittest.mock import patch

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.genai import types
from pydantic import BaseModel
import pytest

from .. import testing_utils


class MockOutputSchema(BaseModel):
message: str
Expand Down Expand Up @@ -276,3 +279,24 @@ def test_maybe_save_output_to_state_handles_empty_final_chunk_with_schema(
# ASSERT: Because the method should return early, the state_delta
# should remain empty.
assert len(event.actions.state_delta) == 0

@pytest.mark.asyncio
async def test_output_key_saved_when_before_agent_callback_short_circuits(
self,
):
"""Test that output_key is written to session state when
before_agent_callback short-circuits the agent."""

def cache_callback(callback_context: CallbackContext) -> types.Content:
return types.Content(parts=[types.Part.from_text(text="cached answer")])

agent = LlmAgent(
name="test_agent",
output_key="result",
before_agent_callback=cache_callback,
)

runner = testing_utils.InMemoryRunner(agent)
await runner.run_async("hello")

assert runner.session.state.get("result") == "cached answer"