diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index b2032c5325..7243412066 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio from typing import Any from typing import cast from typing import Optional @@ -178,6 +179,14 @@ class InvocationContext(BaseModel): Set to True in callbacks or tools to terminate this invocation.""" + stop_event: Optional[asyncio.Event] = None + """An optional event that consumers can set to stop generation mid-stream. + + When set (``stop_event.set()``), the SSE streaming flow will stop yielding + new chunks and return cleanly. This is useful for implementing a "stop + generating" button in chat UIs. + """ + live_request_queue: Optional[LiveRequestQueue] = None """The queue to receive live requests.""" diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index bd0037bdcb..0ffef98f46 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -749,6 +749,8 @@ async def run_async( ) -> AsyncGenerator[Event, None]: """Runs the flow.""" while True: + if invocation_context.stop_event and invocation_context.stop_event.is_set(): + break last_event = None async with Aclosing(self._run_one_step_async(invocation_context)) as agen: async for event in agen: @@ -829,6 +831,8 @@ async def _run_one_step_async( ) ) as agen: async for llm_response in agen: + if invocation_context.stop_event and invocation_context.stop_event.is_set(): + return # Postprocess after calling the LLM. 
async with Aclosing( self._postprocess_async( diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 8e352794a4..7ef13ba0e4 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -509,6 +509,7 @@ async def run_async( new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, + stop_event: Optional[asyncio.Event] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -526,6 +527,9 @@ async def run_async( new_message: A new message to append to the session. state_delta: Optional state changes to apply to the session. run_config: The run config for the agent. + stop_event: An optional ``asyncio.Event`` that, when set, causes the + streaming flow to stop yielding new chunks and return cleanly. This + is useful for implementing a "stop generating" button in chat UIs. Yields: The events generated by the agent. @@ -601,6 +605,9 @@ async def _run_with_trace( # already final. return + if stop_event is not None: + invocation_context.stop_event = stop_event + async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_async(ctx)) as agen: async for event in agen: diff --git a/tests/unittests/flows/llm_flows/test_stop_event.py b/tests/unittests/flows/llm_flows/test_stop_event.py new file mode 100644 index 0000000000..cfd258416e --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_stop_event.py @@ -0,0 +1,223 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the stop_event cancellation mechanism in SSE streaming.""" + +import asyncio +from typing import AsyncGenerator +from typing_extensions import override + +from google.adk.agents.llm_agent import Agent +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +from pydantic import Field +import pytest + +from ... import testing_utils + + +class BaseLlmFlowForTesting(BaseLlmFlow): + """Test implementation of BaseLlmFlow for testing purposes.""" + + pass + + +class StreamingMockModel(testing_utils.MockModel): + """MockModel that yields multiple chunks per generate_content_async call.""" + + chunks_per_call: list[list[LlmResponse]] = Field(default_factory=list) + call_index: int = -1 + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + self.call_index += 1 + self.requests.append(llm_request) + for chunk in self.chunks_per_call[self.call_index]: + yield chunk + + +@pytest.mark.asyncio +async def test_stop_event_stops_streaming_mid_chunks(): + """Setting stop_event mid-stream should prevent further chunks from being yielded.""" + chunks = [ + LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Hello')] + ), + partial=True, + ), + LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text=' world')] + ), + partial=True, + ), + LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text=' foo')] + ), + partial=True, + ), + LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text=' bar')] + ), + partial=True, + ), + ] + + mock_model = StreamingMockModel(responses=[], chunks_per_call=[chunks]) + + agent = 
Agent(name='test_agent', model=mock_model) + stop_event = asyncio.Event() + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + invocation_context.stop_event = stop_event + + flow = BaseLlmFlowForTesting() + events = [] + async for event in flow.run_async(invocation_context): + events.append(event) + if len(events) == 2: + # Signal stop after receiving 2 chunks + stop_event.set() + + # Should have received exactly 2 chunks (stop was signalled after the 2nd) + assert len(events) == 2 + + +@pytest.mark.asyncio +async def test_stop_event_not_set_yields_all_chunks(): + """When stop_event is provided but never set, all chunks should be yielded.""" + chunks = [ + LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Hello')] + ), + partial=True, + ), + LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text=' world')] + ), + partial=True, + ), + ] + + mock_model = StreamingMockModel(responses=[], chunks_per_call=[chunks]) + + agent = Agent(name='test_agent', model=mock_model) + stop_event = asyncio.Event() + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + invocation_context.stop_event = stop_event + + flow = BaseLlmFlowForTesting() + events = [] + async for event in flow.run_async(invocation_context): + events.append(event) + + assert len(events) == 2 + + +@pytest.mark.asyncio +async def test_stop_event_prevents_next_llm_call(): + """Setting stop_event between LLM calls should prevent the next call.""" + # First LLM call: returns a function call + fc_response = LlmResponse( + content=types.Content( + role='model', + parts=[ + types.Part.from_function_call( + name='my_tool', args={'x': '1'} + ) + ], + ), + partial=False, + ) + # Second LLM call: should NOT be reached + text_response = LlmResponse( + content=types.Content( + role='model', 
parts=[types.Part.from_text(text='Result')] + ), + partial=False, + error_code=types.FinishReason.STOP, + ) + + mock_model = testing_utils.MockModel.create( + responses=[fc_response, text_response] + ) + + def my_tool(x: str) -> str: + return f'result_{x}' + + agent = Agent(name='test_agent', model=mock_model, tools=[my_tool]) + stop_event = asyncio.Event() + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + invocation_context.stop_event = stop_event + + flow = BaseLlmFlowForTesting() + events = [] + async for event in flow.run_async(invocation_context): + events.append(event) + # Stop after first LLM call yields its events + if event.get_function_calls(): + stop_event.set() + + # Should have events from the first LLM call only + # The second LLM call (text_response) should NOT have happened + all_texts = [ + part.text + for e in events + if e.content and e.content.parts + for part in e.content.parts + if part.text + ] + assert 'Result' not in all_texts + + +@pytest.mark.asyncio +async def test_no_stop_event_works_normally(): + """When no stop_event is provided, everything works as before.""" + response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Done')] + ), + partial=False, + error_code=types.FinishReason.STOP, + ) + + mock_model = testing_utils.MockModel.create(responses=[response]) + + agent = Agent(name='test_agent', model=mock_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + # No stop_event set (default None) + + flow = BaseLlmFlowForTesting() + events = [] + async for event in flow.run_async(invocation_context): + events.append(event) + + assert len(events) == 1 + assert events[0].content.parts[0].text == 'Done'