Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
51 changes: 34 additions & 17 deletions src/google/adk/flows/llm_flows/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,39 @@ def _is_timestamp_compacted(ts: float) -> bool:
return [event for _, _, event in processed_items]


def filter_rewound_events(events: list[Event]) -> list[Event]:
"""Returns events with those annulled by a rewind removed.

Iterates backward; when a rewind marker is found, skips all events
back to the rewind_before_invocation_id.

Args:
events: The full event list from the session.

Returns:
A new list with rewound events removed, in the original order.
"""
# Pre-compute the first occurrence index of each invocation_id for O(1) lookup.
first_occurrence: dict[str, int] = {}
for idx, event in enumerate(events):
if event.invocation_id not in first_occurrence:
first_occurrence[event.invocation_id] = idx

filtered = []
i = len(events) - 1
while i >= 0:
event = events[i]
if event.actions and event.actions.rewind_before_invocation_id:
rewind_id = event.actions.rewind_before_invocation_id
if rewind_id in first_occurrence and first_occurrence[rewind_id] < i:
i = first_occurrence[rewind_id]
else:
filtered.append(event)
i -= 1
filtered.reverse()
return filtered


def _get_contents(
current_branch: Optional[str],
events: list[Event],
Expand All @@ -430,23 +463,7 @@ def _get_contents(
accumulated_output_transcription = ''

# Filter out events that are annulled by a rewind.
# By iterating backward, when a rewind event is found, we skip all events
# from that point back to the `rewind_before_invocation_id`, thus removing
# them from the history used for the LLM request.
rewind_filtered_events = []
i = len(events) - 1
while i >= 0:
event = events[i]
if event.actions and event.actions.rewind_before_invocation_id:
rewind_invocation_id = event.actions.rewind_before_invocation_id
for j in range(0, i, 1):
if events[j].invocation_id == rewind_invocation_id:
i = j
break
else:
rewind_filtered_events.append(event)
i -= 1
rewind_filtered_events.reverse()
rewind_filtered_events = filter_rewound_events(events)

# Parse the events, leaving the contents and the function calls and
# responses from the current agent.
Expand Down
5 changes: 3 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,8 @@ def _find_agent_to_run(
# the agent that returned the corresponding function call regardless the
# type of the agent. e.g. a remote a2a agent may surface a credential
# request as a special long-running function tool call.
event = find_matching_function_call(session.events)
filtered_events = contents.filter_rewound_events(session.events)
event = find_matching_function_call(filtered_events)
if event and event.author:
return root_agent.find_agent(event.author)

Expand All @@ -1142,7 +1143,7 @@ def _event_filter(event: Event) -> bool:
return False
return True

for event in filter(_event_filter, reversed(session.events)):
for event in filter(_event_filter, reversed(filtered_events)):
if event.author == root_agent.name:
# Found root agent.
return root_agent
Expand Down
92 changes: 92 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from google.adk.cli.utils.agent_loader import AgentLoader
from google.adk.errors.session_not_found_error import SessionNotFoundError
from google.adk.events.event import Event
from google.adk.events.event import EventActions
from google.adk.flows.llm_flows.contents import filter_rewound_events
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
Expand Down Expand Up @@ -642,6 +644,96 @@ def test_is_transferable_across_agent_tree_with_non_llm_agent(self):
assert result is False


def test_find_agent_to_run_ignores_rewound_sub_agent_event():
"""After a rewind, events from the rewound invocation are ignored."""
root_agent = MockLlmAgent("root_agent")
sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=root_agent)
root_agent.sub_agents = [sub_agent1]

runner = Runner(
app_name="test_app",
agent=root_agent,
session_service=InMemorySessionService(),
artifact_service=InMemoryArtifactService(),
)

# sub_agent1 was the last active agent during inv1
sub_agent_event = Event(
invocation_id="inv1",
author="sub_agent1",
content=types.Content(
role="model", parts=[types.Part(text="Sub agent response")]
),
)
# Rewind event that annuls inv1 and everything after it
rewind_event = Event(
invocation_id="inv2",
author="user",
actions=EventActions(rewind_before_invocation_id="inv1"),
)
session = Session(
id="test_session",
user_id="test_user",
app_name="test_app",
events=[sub_agent_event, rewind_event],
)

result = runner._find_agent_to_run(session, root_agent)
assert result == root_agent


def test_find_agent_to_run_ignores_rewound_function_call():
"""After a rewind, a function call from the rewound invocation is not matched."""
root_agent = MockLlmAgent("root_agent")
sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=root_agent)
root_agent.sub_agents = [sub_agent2]

runner = Runner(
app_name="test_app",
agent=root_agent,
session_service=InMemorySessionService(),
artifact_service=InMemoryArtifactService(),
)

function_call = types.FunctionCall(id="func_789", name="test_func", args={})
function_response = types.FunctionResponse(
id="func_789", name="test_func", response={}
)

# sub_agent2 issued a function call in inv1
call_event = Event(
invocation_id="inv1",
author="sub_agent2",
content=types.Content(
role="model", parts=[types.Part(function_call=function_call)]
),
)
# User provides the function response, also in inv1
response_event = Event(
invocation_id="inv1",
author="user",
content=types.Content(
role="user", parts=[types.Part(function_response=function_response)]
),
)
# Rewind event that annuls inv1
rewind_event = Event(
invocation_id="inv2",
author="user",
actions=EventActions(rewind_before_invocation_id="inv1"),
)
session = Session(
id="test_session",
user_id="test_user",
app_name="test_app",
events=[call_event, response_event, rewind_event],
)

# The rewound function call should not be matched; root_agent is returned
result = runner._find_agent_to_run(session, root_agent)
assert result == root_agent


@pytest.mark.asyncio
async def test_run_config_custom_metadata_propagates_to_events():
session_service = InMemorySessionService()
Expand Down