|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import base64 |
| 18 | +import binascii |
18 | 19 | import copy |
19 | 20 | import importlib.util |
20 | 21 | import json |
|
143 | 144 | "before a response was recorded)." |
144 | 145 | ) |
145 | 146 |
|
# Separator LiteLLM uses to embed thought_signature in tool call IDs.
# Gemini's thoughtSignature requirement is documented here:
# https://ai.google.dev/gemini-api/docs/thought-signatures
# NOTE: _extract_thought_signature_from_tool_call splits a tool call ID on
# this marker as a last-resort fallback when the signature is not carried
# in provider_specific_fields or extra_content.
_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"
| 151 | + |
146 | 152 | _LITELLM_IMPORTED = False |
147 | 153 | _LITELLM_GLOBAL_SYMBOLS = ( |
148 | 154 | "ChatCompletionAssistantMessage", |
@@ -602,6 +608,27 @@ def _extract_cached_prompt_tokens(usage: Any) -> int: |
602 | 608 | return 0 |
603 | 609 |
|
604 | 610 |
|
| 611 | +def _decode_thought_signature(value: Any) -> Optional[bytes]: |
| 612 | + """Safely decodes a thought_signature value to bytes. |
| 613 | +
|
| 614 | + Args: |
| 615 | + value: A base64 string or raw bytes thought_signature. |
| 616 | +
|
| 617 | + Returns: |
| 618 | + The decoded bytes, or None if decoding fails. |
| 619 | + """ |
| 620 | + if isinstance(value, bytes): |
| 621 | + return value |
| 622 | + try: |
| 623 | + return base64.b64decode(value, validate=True) |
| 624 | + except (binascii.Error, TypeError, ValueError): |
| 625 | + logger.debug( |
| 626 | + "Failed to decode thought_signature of type %s.", |
| 627 | + type(value).__name__, |
| 628 | + ) |
| 629 | + return None |
| 630 | + |
| 631 | + |
605 | 632 | def _extract_reasoning_tokens(usage: Any) -> int: |
606 | 633 | """Extracts reasoning tokens from LiteLLM usage. |
607 | 634 |
|
@@ -637,6 +664,64 @@ def _extract_reasoning_tokens(usage: Any) -> int: |
637 | 664 | return 0 |
638 | 665 |
|
639 | 666 |
|
def _extract_thought_signature_from_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> Optional[bytes]:
  """Extracts thought_signature from a litellm tool call if present.

  Gemini thinking models attach a thought_signature to function call parts.
  See https://ai.google.dev/gemini-api/docs/thought-signatures.
  This signature may appear in several locations depending on the
  provider path:
  1. extra_content.google.thought_signature (OpenAI-compatible API).
  2. provider_specific_fields on the tool call or function (Vertex).
  3. Embedded in the tool call ID via __thought__ separator.

  Args:
    tool_call: A litellm tool call object.

  Returns:
    The thought_signature as bytes, or None if not present.
  """
  # Location 1: extra_content.google.thought_signature (OpenAI format).
  extra = tool_call.get("extra_content")
  if isinstance(extra, dict):
    google = extra.get("google")
    if isinstance(google, dict) and google.get("thought_signature"):
      return _decode_thought_signature(google["thought_signature"])

  # Location 2: provider_specific_fields on the tool call itself.
  fields = tool_call.get("provider_specific_fields")
  if isinstance(fields, dict) and fields.get("thought_signature"):
    return _decode_thought_signature(fields["thought_signature"])

  # Location 3: provider_specific_fields hanging off the function, which
  # may be either a dict or an object depending on the litellm version.
  function = tool_call.get("function")
  if function:
    if isinstance(function, dict):
      fields = function.get("provider_specific_fields")
    else:
      fields = getattr(function, "provider_specific_fields", None)
    if isinstance(fields, dict) and fields.get("thought_signature"):
      return _decode_thought_signature(fields["thought_signature"])

  # Location 4: signature embedded in the tool call ID after the separator.
  call_id = tool_call.get("id") or ""
  _, sep, encoded = call_id.partition(_THOUGHT_SIGNATURE_SEPARATOR)
  if sep:
    return _decode_thought_signature(encoded)

  return None
| 724 | + |
640 | 725 | async def _content_to_message_param( |
641 | 726 | content: types.Content, |
642 | 727 | *, |
@@ -706,16 +791,31 @@ async def _content_to_message_param( |
706 | 791 | reasoning_parts: list[types.Part] = [] |
707 | 792 | for part in content.parts: |
708 | 793 | if part.function_call: |
709 | | - tool_calls.append( |
710 | | - ChatCompletionAssistantToolCall( |
711 | | - type="function", |
712 | | - id=part.function_call.id, |
713 | | - function=Function( |
714 | | - name=part.function_call.name, |
715 | | - arguments=_safe_json_serialize(part.function_call.args), |
716 | | - ), |
717 | | - ) |
718 | | - ) |
| 794 | + tool_call_id = part.function_call.id or "" |
| 795 | + tool_call_dict: ChatCompletionAssistantToolCall = { |
| 796 | + "type": "function", |
| 797 | + "id": tool_call_id, |
| 798 | + "function": { |
| 799 | + "name": part.function_call.name, |
| 800 | + "arguments": _safe_json_serialize(part.function_call.args), |
| 801 | + }, |
| 802 | + } |
| 803 | + # Preserve thought_signature for Gemini thinking models. |
| 804 | + # LiteLLM's Gemini prompt conversion reads provider_specific_fields, |
| 805 | + # while the OpenAI-compatible Gemini endpoint path expects the |
| 806 | + # extra_content.google.thought_signature payload to survive. |
| 807 | + # See https://ai.google.dev/gemini-api/docs/thought-signatures. |
| 808 | + if part.thought_signature: |
| 809 | + sig = part.thought_signature |
| 810 | + if isinstance(sig, bytes): |
| 811 | + sig = base64.b64encode(sig).decode("utf-8") |
| 812 | + tool_call_dict["provider_specific_fields"] = { |
| 813 | + "thought_signature": sig |
| 814 | + } |
| 815 | + tool_call_dict["extra_content"] = { |
| 816 | + "google": {"thought_signature": sig} |
| 817 | + } |
| 818 | + tool_calls.append(tool_call_dict) |
719 | 819 | elif part.thought: |
720 | 820 | reasoning_parts.append(part) |
721 | 821 | else: |
@@ -1524,11 +1624,14 @@ def _message_to_generate_content_response( |
1524 | 1624 | if tool_calls: |
1525 | 1625 | for tool_call in tool_calls: |
1526 | 1626 | if tool_call.type == "function": |
| 1627 | + thought_signature = _extract_thought_signature_from_tool_call(tool_call) |
1527 | 1628 | part = types.Part.from_function_call( |
1528 | 1629 | name=tool_call.function.name, |
1529 | 1630 | args=json.loads(tool_call.function.arguments or "{}"), |
1530 | 1631 | ) |
1531 | 1632 | part.function_call.id = tool_call.id |
| 1633 | + if thought_signature: |
| 1634 | + part.thought_signature = thought_signature |
1532 | 1635 | parts.append(part) |
1533 | 1636 |
|
1534 | 1637 | return LlmResponse( |
|
0 commit comments