Skip to content

Commit ae565be

Browse files
GWeale and copybara-github
authored and committed
fix: Preserve thought_signature in LiteLLM tool calls
This change adds logic to extract and re-embed the `thought_signature` field associated with function calls in Gemini models when converting between LiteLLM's ChatCompletionMessageToolCall and ADK's types.Part. Closes #4650. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 882212223
1 parent 22799c0 commit ae565be

File tree

2 files changed

+345
-35
lines changed

2 files changed

+345
-35
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import base64
18+
import binascii
1819
import copy
1920
import importlib.util
2021
import json
@@ -143,6 +144,11 @@
143144
"before a response was recorded)."
144145
)
145146

147+
# Separator LiteLLM uses to embed thought_signature in tool call IDs.
148+
# Gemini's thoughtSignature requirement is documented here:
149+
# https://ai.google.dev/gemini-api/docs/thought-signatures
150+
_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"
151+
146152
_LITELLM_IMPORTED = False
147153
_LITELLM_GLOBAL_SYMBOLS = (
148154
"ChatCompletionAssistantMessage",
@@ -602,6 +608,27 @@ def _extract_cached_prompt_tokens(usage: Any) -> int:
602608
return 0
603609

604610

611+
def _decode_thought_signature(value: Any) -> Optional[bytes]:
612+
"""Safely decodes a thought_signature value to bytes.
613+
614+
Args:
615+
value: A base64 string or raw bytes thought_signature.
616+
617+
Returns:
618+
The decoded bytes, or None if decoding fails.
619+
"""
620+
if isinstance(value, bytes):
621+
return value
622+
try:
623+
return base64.b64decode(value, validate=True)
624+
except (binascii.Error, TypeError, ValueError):
625+
logger.debug(
626+
"Failed to decode thought_signature of type %s.",
627+
type(value).__name__,
628+
)
629+
return None
630+
631+
605632
def _extract_reasoning_tokens(usage: Any) -> int:
606633
"""Extracts reasoning tokens from LiteLLM usage.
607634
@@ -637,6 +664,64 @@ def _extract_reasoning_tokens(usage: Any) -> int:
637664
return 0
638665

639666

667+
def _extract_thought_signature_from_tool_call(
668+
tool_call: ChatCompletionMessageToolCall,
669+
) -> Optional[bytes]:
670+
"""Extracts thought_signature from a litellm tool call if present.
671+
672+
Gemini thinking models attach a thought_signature to function call parts.
673+
See https://ai.google.dev/gemini-api/docs/thought-signatures.
674+
This signature may appear in several locations depending on the
675+
provider path:
676+
1. extra_content.google.thought_signature (OpenAI-compatible API).
677+
2. provider_specific_fields on the tool call or function (Vertex).
678+
3. Embedded in the tool call ID via __thought__ separator.
679+
680+
Args:
681+
tool_call: A litellm tool call object.
682+
683+
Returns:
684+
The thought_signature as bytes, or None if not present.
685+
"""
686+
# Check extra_content.google.thought_signature (OpenAI format)
687+
extra_content = tool_call.get("extra_content")
688+
if isinstance(extra_content, dict):
689+
google_fields = extra_content.get("google")
690+
if isinstance(google_fields, dict):
691+
signature = google_fields.get("thought_signature")
692+
if signature:
693+
return _decode_thought_signature(signature)
694+
695+
# Check provider_specific_fields on the tool call
696+
provider_fields = tool_call.get("provider_specific_fields")
697+
if isinstance(provider_fields, dict):
698+
signature = provider_fields.get("thought_signature")
699+
if signature:
700+
return _decode_thought_signature(signature)
701+
702+
# Check provider_specific_fields on the function
703+
function = tool_call.get("function")
704+
if function:
705+
func_provider_fields = None
706+
if isinstance(function, dict):
707+
func_provider_fields = function.get("provider_specific_fields")
708+
elif hasattr(function, "provider_specific_fields"):
709+
func_provider_fields = function.provider_specific_fields
710+
if isinstance(func_provider_fields, dict):
711+
signature = func_provider_fields.get("thought_signature")
712+
if signature:
713+
return _decode_thought_signature(signature)
714+
715+
# Check if thought signature is embedded in the tool call ID
716+
tool_call_id = tool_call.get("id") or ""
717+
if _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id:
718+
parts = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1)
719+
if len(parts) == 2:
720+
return _decode_thought_signature(parts[1])
721+
722+
return None
723+
724+
640725
async def _content_to_message_param(
641726
content: types.Content,
642727
*,
@@ -706,16 +791,31 @@ async def _content_to_message_param(
706791
reasoning_parts: list[types.Part] = []
707792
for part in content.parts:
708793
if part.function_call:
709-
tool_calls.append(
710-
ChatCompletionAssistantToolCall(
711-
type="function",
712-
id=part.function_call.id,
713-
function=Function(
714-
name=part.function_call.name,
715-
arguments=_safe_json_serialize(part.function_call.args),
716-
),
717-
)
718-
)
794+
tool_call_id = part.function_call.id or ""
795+
tool_call_dict: ChatCompletionAssistantToolCall = {
796+
"type": "function",
797+
"id": tool_call_id,
798+
"function": {
799+
"name": part.function_call.name,
800+
"arguments": _safe_json_serialize(part.function_call.args),
801+
},
802+
}
803+
# Preserve thought_signature for Gemini thinking models.
804+
# LiteLLM's Gemini prompt conversion reads provider_specific_fields,
805+
# while the OpenAI-compatible Gemini endpoint path expects the
806+
# extra_content.google.thought_signature payload to survive.
807+
# See https://ai.google.dev/gemini-api/docs/thought-signatures.
808+
if part.thought_signature:
809+
sig = part.thought_signature
810+
if isinstance(sig, bytes):
811+
sig = base64.b64encode(sig).decode("utf-8")
812+
tool_call_dict["provider_specific_fields"] = {
813+
"thought_signature": sig
814+
}
815+
tool_call_dict["extra_content"] = {
816+
"google": {"thought_signature": sig}
817+
}
818+
tool_calls.append(tool_call_dict)
719819
elif part.thought:
720820
reasoning_parts.append(part)
721821
else:
@@ -1524,11 +1624,14 @@ def _message_to_generate_content_response(
15241624
if tool_calls:
15251625
for tool_call in tool_calls:
15261626
if tool_call.type == "function":
1627+
thought_signature = _extract_thought_signature_from_tool_call(tool_call)
15271628
part = types.Part.from_function_call(
15281629
name=tool_call.function.name,
15291630
args=json.loads(tool_call.function.arguments or "{}"),
15301631
)
15311632
part.function_call.id = tool_call.id
1633+
if thought_signature:
1634+
part.thought_signature = thought_signature
15321635
parts.append(part)
15331636

15341637
return LlmResponse(

0 commit comments

Comments
 (0)