Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ async def _run_node_async(
user_id: str,
session_id: str,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
yield_user_message: bool = False,
node: Optional['BaseNode'] = None,
Expand Down Expand Up @@ -512,7 +513,9 @@ async def _run_node_async(

# Append user message to session for history
if new_message:
user_event = await self._append_user_event(ic, new_message)
user_event = await self._append_user_event(
ic, new_message, state_delta=state_delta
)
if yield_user_message and user_event:
yield user_event

Expand Down Expand Up @@ -706,14 +709,26 @@ def _resolve_invocation_id_from_fr(
return invocation_ids.pop()

async def _append_user_event(
self, ic: InvocationContext, content: types.Content
self,
ic: InvocationContext,
content: types.Content,
*,
state_delta: Optional[dict[str, Any]] = None,
) -> Event:
"""Append a user message event to the session and return it."""
event = Event(
invocation_id=ic.invocation_id,
author='user',
content=content,
)
if state_delta:
event = Event(
invocation_id=ic.invocation_id,
author='user',
actions=EventActions(state_delta=state_delta),
content=content,
)
Comment thread
trongthanht3 marked this conversation as resolved.
else:
event = Event(
invocation_id=ic.invocation_id,
author='user',
content=content,
)
# when a paused task delegation is in flight, stamp
# the new user message with that task's isolation_scope so the
# task agent's content-build (scoped to <fc_id>) sees it.
Expand Down Expand Up @@ -989,6 +1004,7 @@ async def run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
state_delta=state_delta,
run_config=run_config,
yield_user_message=yield_user_message,
node=agent_to_run,
Expand All @@ -1008,6 +1024,7 @@ async def run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
state_delta=state_delta,
run_config=run_config,
yield_user_message=yield_user_message,
)
Expand Down
129 changes: 129 additions & 0 deletions tests/unittests/runners/test_runner_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from typing import Any
from typing import AsyncGenerator

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.context import Context
from google.adk.agents.llm_agent import LlmAgent
from google.adk.events.event import Event
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
Expand All @@ -49,6 +51,10 @@ async def _run_impl(
yield f'Echo: {text}'


def _user_message(text: str = 'hello') -> types.Content:
return types.Content(parts=[types.Part(text=text)], role='user')


async def _run_node(node, message='hello'):
"""Run a BaseNode via Runner(node=...) and return (events, ss, session)."""
ss = InMemorySessionService()
Expand Down Expand Up @@ -288,6 +294,129 @@ async def test_yield_user_message_false_by_default():
assert user_events == []


@pytest.mark.asyncio
async def test_node_runner_applies_state_delta_before_base_node_runs():
"""A BaseNode sees run_async state_delta as session state."""

class _StateReaderNode(BaseNode):

async def _run_impl(
self, *, ctx: Context, node_input: Any
) -> AsyncGenerator[Any, None]:
yield f'state:{ctx.state["test_state"]}'

session_service = InMemorySessionService()
runner = Runner(
app_name='test',
node=_StateReaderNode(name='reader'),
session_service=session_service,
)
session = await session_service.create_session(app_name='test', user_id='u')

events: list[Event] = []
async for event in runner.run_async(
user_id='u',
session_id=session.id,
new_message=_user_message(),
state_delta={'test_state': 'must_change'},
):
events.append(event)

updated = await session_service.get_session(
app_name='test', user_id='u', session_id=session.id
)
user_events = [event for event in updated.events if event.author == 'user']

assert [event.output for event in events if event.output is not None] == [
'state:must_change'
]
assert updated.state['test_state'] == 'must_change'
assert user_events[0].actions.state_delta == {'test_state': 'must_change'}


@pytest.mark.asyncio
async def test_node_runner_yields_user_event_with_state_delta():
"""yield_user_message=True yields the user event with state_delta."""

class _NoopNode(BaseNode):

async def _run_impl(
self, *, ctx: Context, node_input: Any
) -> AsyncGenerator[Any, None]:
yield 'done'

session_service = InMemorySessionService()
runner = Runner(
app_name='test',
node=_NoopNode(name='noop'),
session_service=session_service,
)
session = await session_service.create_session(app_name='test', user_id='u')

events: list[Event] = []
async for event in runner.run_async(
user_id='u',
session_id=session.id,
new_message=_user_message(),
state_delta={'test_state': 'must_change'},
yield_user_message=True,
):
events.append(event)

assert events[0].author == 'user'
assert events[0].actions.state_delta == {'test_state': 'must_change'}


@pytest.mark.asyncio
async def test_node_runner_applies_state_delta_before_llm_agent_runs():
"""An LlmAgent callback sees run_async state_delta before model execution."""

captured_state_value = None

def _before_agent_callback(
callback_context: CallbackContext,
) -> types.Content:
nonlocal captured_state_value
captured_state_value = callback_context.state['test_state']
return types.Content(
role='model',
parts=[types.Part(text=f'state:{captured_state_value}')],
)

session_service = InMemorySessionService()
agent = LlmAgent(
name='state_agent',
before_agent_callback=_before_agent_callback,
)
runner = Runner(app_name='test', agent=agent, session_service=session_service)
session = await session_service.create_session(app_name='test', user_id='u')

events: list[Event] = []
async for event in runner.run_async(
user_id='u',
session_id=session.id,
new_message=_user_message(),
state_delta={'test_state': 'must_change'},
):
events.append(event)

updated = await session_service.get_session(
app_name='test', user_id='u', session_id=session.id
)
user_events = [event for event in updated.events if event.author == 'user']
response_texts = [
part.text
for event in events
if event.content
for part in event.content.parts
if part.text
]

assert captured_state_value == 'must_change'
assert 'state:must_change' in response_texts
assert user_events[0].actions.state_delta == {'test_state': 'must_change'}


# ---------------------------------------------------------------------------
# Resume (HITL)
# ---------------------------------------------------------------------------
Expand Down
Loading