Skip to content
Open
Binary file added main_mypy.txt
Binary file not shown.
Binary file added pr_mypy.txt
Binary file not shown.
12 changes: 9 additions & 3 deletions src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _is_remote_response(self, event: Event) -> bool:

def _construct_message_parts_from_session(
self, ctx: InvocationContext
) -> tuple[list[A2APart], Optional[str]]:
) -> tuple[list[A2APart], Optional[str], Optional[str]]:
"""Construct A2A message parts from session events.

Args:
Expand All @@ -391,6 +391,7 @@ def _construct_message_parts_from_session(
"""
message_parts: list[A2APart] = []
context_id = None
task_id = None

events_to_process = []
for event in reversed(ctx.session.events):
Expand All @@ -400,6 +401,10 @@ def _construct_message_parts_from_session(
if event.custom_metadata:
metadata = event.custom_metadata
context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")
# Always forward task_id if present. The remote agent (server) owns task
# lifecycle and will reject or ignore a stale task_id if the task is no
# longer open. Filtering by state client-side is error-prone.
task_id = metadata.get(A2A_METADATA_PREFIX + "task_id")
# Historical note: this behavior originally always applied, regardless
# of whether the agent was stateful or stateless. However, only stateful
# agents can be expected to have previous events in the remote session.
Expand Down Expand Up @@ -427,7 +432,7 @@ def _construct_message_parts_from_session(
else:
logger.warning("Failed to convert part to A2A format: %s", part)

return message_parts, context_id
return message_parts, context_id, task_id

async def _handle_a2a_response(
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
Expand Down Expand Up @@ -624,7 +629,7 @@ async def _run_async_impl(
# Create A2A request for function response or regular message
a2a_request = self._create_a2a_request_for_user_function_response(ctx)
if not a2a_request:
message_parts, context_id = self._construct_message_parts_from_session(
message_parts, context_id, task_id = self._construct_message_parts_from_session(
ctx
)

Expand All @@ -645,6 +650,7 @@ async def _run_async_impl(
parts=message_parts,
role="user",
context_id=context_id,
task_id=task_id,
)

logger.debug(build_a2a_request_log(a2a_request))
Expand Down
38 changes: 19 additions & 19 deletions tests/unittests/agents/test_remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def test_construct_message_parts_from_session_success(self):
mock_a2a_part = Mock()
self.mock_genai_part_converter.return_value = mock_a2a_part

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -649,7 +649,7 @@ def test_construct_message_parts_from_session_success_multiple_parts(self):
mock_a2a_part2,
]

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand All @@ -660,7 +660,7 @@ def test_construct_message_parts_from_session_empty_events(self):
"""Test message parts construction with empty events."""
self.mock_session.events = []

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -718,7 +718,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
Expand Down Expand Up @@ -768,7 +768,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 3
Expand Down Expand Up @@ -823,7 +823,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
Expand Down Expand Up @@ -954,7 +954,7 @@ def mock_converter(part):

self.mock_genai_part_converter.side_effect = mock_converter

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -1373,12 +1373,12 @@ def test_construct_message_parts_from_session_success(self):
mock_convert.return_value = mock_event

with patch.object(
self.agent, "_genai_part_converter"
self.agent, "_genai_part_converter"
) as mock_convert_part:
mock_a2a_part = Mock()
mock_convert_part.return_value = mock_a2a_part

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand All @@ -1390,7 +1390,7 @@ def test_construct_message_parts_from_session_empty_events(self):
"""Test message parts construction with empty events."""
self.mock_session.events = []

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -1966,10 +1966,7 @@ async def test_run_async_impl_no_message_parts(self):
with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
mock_construct.return_value = (
[],
None,
) # Tuple with empty parts and no context_id
mock_construct.return_value =([], None, None) # Tuple with empty parts and no context_id

events = []
async for event in self.agent._run_async_impl(self.mock_context):
Expand Down Expand Up @@ -1999,7 +1996,8 @@ async def test_run_async_impl_successful_request(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
) # Tuple with parts and context_id
None,
) # Tuple with parts and context_id , no task_id

# Mock A2A client
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
Expand Down Expand Up @@ -2071,6 +2069,7 @@ async def test_run_async_impl_a2a_client_error(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client that throws an exception
Expand Down Expand Up @@ -2138,6 +2137,7 @@ async def test_run_async_impl_with_meta_provider(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2242,10 +2242,8 @@ async def test_run_async_impl_no_message_parts(self):
with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
mock_construct.return_value = (
[],
None,
) # Tuple with empty parts and no context_id
mock_construct.return_value = ([], None, None)
# Tuple with empty parts and no context_id

events = []
async for event in self.agent._run_async_impl(self.mock_context):
Expand Down Expand Up @@ -2275,6 +2273,7 @@ async def test_run_async_impl_successful_request(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2349,6 +2348,7 @@ async def test_run_async_impl_a2a_client_error(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client that throws an exception
Expand Down