Skip to content
57 changes: 56 additions & 1 deletion google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,44 @@ def _load_json_from_response(cls, response: Any) -> Any:
)


def _extract_retry_info_delay_seconds(
    api_error: errors.APIError,
) -> Optional[float]:
  """Returns the server-suggested retry delay carried by a 429 error, if any.

  Only errors with code 429 and status 'RESOURCE_EXHAUSTED' are inspected.
  The error details are searched, first under an 'error' wrapper and then at
  the top level, for a `google.rpc.RetryInfo` entry whose 'retryDelay' is a
  string made of a number followed by 's' (for example '21.943984799s').

  Args:
    api_error: The API error to inspect.

  Returns:
    The first usable delay in seconds (finite and non-negative), or None when
    the error is not a 429 RESOURCE_EXHAUSTED or carries no usable RetryInfo.
  """
  if api_error.code != 429 or api_error.status != 'RESOURCE_EXHAUSTED':
    return None

  if not isinstance(api_error.details, dict):
    return None

  # The details list may sit under an 'error' wrapper or at the top level.
  for path in (['error', 'details'], ['details']):
    details = _common.get_value_by_path(api_error.details, path)
    if not isinstance(details, list):
      continue

    for detail in details:
      if not isinstance(detail, dict):
        continue
      detail_type = detail.get('@type')
      if (
          not isinstance(detail_type, str)
          or not detail_type.endswith('google.rpc.RetryInfo')
      ):
        continue
      retry_delay = _common.get_value_by_path(detail, ['retryDelay'])
      if not isinstance(retry_delay, str):
        continue
      retry_delay = retry_delay.strip()
      if not retry_delay.endswith('s'):
        continue
      try:
        retry_delay_seconds = float(retry_delay[:-1])
      except ValueError:
        continue
      # float() also parses 'inf', 'infinity' and 'nan'. An infinite delay
      # must never be handed to the sleeper, and NaN fails both comparisons,
      # so only finite, non-negative values are accepted.
      if 0 <= retry_delay_seconds < float('inf'):
        return retry_delay_seconds
  return None


def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
"""Returns the retry args for the given http retry options.

Expand All @@ -498,11 +536,28 @@ def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
exp_base=options.exp_base or _RETRY_EXP_BASE,
jitter=options.jitter or _RETRY_JITTER,
)
fallback_wait = wait

def wait_with_retry_info(retry_state: tenacity.RetryCallState) -> float:
  """Returns the wait before the next attempt, preferring server RetryInfo.

  When the last attempt failed with an APIError that carries a usable
  RetryInfo delay, that delay plus one second is returned. In every other
  case the configured fallback wait strategy decides.
  """
  outcome = retry_state.outcome
  if outcome is None or not outcome.failed:
    return fallback_wait(retry_state)

  error = outcome.exception()
  if not isinstance(error, errors.APIError):
    return fallback_wait(retry_state)

  server_delay = _extract_retry_info_delay_seconds(error)
  if server_delay is None:
    return fallback_wait(retry_state)

  # Add one second because RetryInfo delay can be truncated.
  return server_delay + 1

# Preserve standard attributes.
wait_with_retry_info.initial = wait.initial # type: ignore
wait_with_retry_info.max = wait.max # type: ignore
wait_with_retry_info.exp_base = wait.exp_base # type: ignore
wait_with_retry_info.jitter = wait.jitter # type: ignore
return {
'stop': stop,
'retry': retry,
'reraise': True,
'wait': wait,
'wait': wait_with_retry_info,
'before_sleep': tenacity.before_sleep_log(logger, logging.INFO),
}

Expand Down
101 changes: 99 additions & 2 deletions google/genai/tests/client/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
from collections.abc import Sequence
import datetime
import json
from unittest import mock
import pytest
try:
Expand Down Expand Up @@ -61,11 +62,14 @@ def _final_codes(retried_codes: Sequence[int] = _RETRIED_CODES):
return [code for code in range(100, 600) if code not in retried_codes]


def _httpx_response(code: int, response_json=None):
  """Builds an httpx.Response with the given status code and optional body.

  Args:
    code: HTTP status code of the response; also echoed in the 'status-code'
      header.
    response_json: Optional JSON-serializable object. When given, it is
      serialized with json.dumps and used as the UTF-8 encoded body;
      otherwise the body is empty.

  Returns:
    The constructed httpx.Response.
  """
  content = b''
  if response_json is not None:
    content = json.dumps(response_json).encode('utf-8')
  return httpx.Response(
      status_code=code,
      headers={'status-code': str(code)},
      content=content,
  )


Expand Down Expand Up @@ -144,6 +148,99 @@ def fn():
assert timestamps[4] - timestamps[3] >= datetime.timedelta(seconds=8)


# Retry options shared by the RetryInfo wait tests below: two attempts with
# jitter set to 0 and an initial delay of 0.25s. The tests expect exactly one
# sleep and compare the fallback wait against 0.25.
_RETRY_OPTIONS_NO_JITTER = types.HttpRetryOptions(
attempts=2,
initial_delay=0.25,
max_delay=10,
exp_base=2,
jitter=0,
)


def _resource_exhausted_error_payload(
retry_delay: str,
*,
status: str = 'RESOURCE_EXHAUSTED',
wrapped: bool = True,
):
details = {
'code': 429,
'message': 'Resource exhausted.',
'status': status,
'details': [
{
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
'retryDelay': retry_delay,
}
],
}
if wrapped:
return {'error': details}
return details


def _retry_and_capture_sleep(status_code: int, error_payload: dict[str, object]):
  """Runs an always-failing call under retry and returns the single sleep.

  The call raises the APIError produced for a response with the given status
  code and JSON payload. tenacity's sleep and its random jitter source are
  patched, so no real waiting happens and the wait is deterministic.

  Args:
    status_code: HTTP status code of the simulated response.
    error_payload: JSON body of the simulated response.

  Returns:
    The number of seconds passed to the one sleep call made between attempts.
  """

  def always_fail():
    errors.APIError.raise_for_response(_httpx_response(status_code, error_payload))

  retrier = tenacity.Retrying(
      **api_client.retry_args(_RETRY_OPTIONS_NO_JITTER)
  )
  uniform_patch = mock.patch('tenacity.wait.random.uniform', return_value=0.0)
  sleep_patch = mock.patch('tenacity.nap.time.sleep')
  with uniform_patch, sleep_patch as sleep_mock, pytest.raises(errors.APIError):
    retrier(always_fail)

  # The mock keeps its call records after the patch has been undone.
  assert sleep_mock.call_count == 1
  return sleep_mock.call_args.args[0]


def test_retry_wait_uses_retry_info_for_429_resource_exhausted():
  """A 429 RESOURCE_EXHAUSTED with RetryInfo sleeps for that delay plus 1s."""
  payload = _resource_exhausted_error_payload('21.943984799s')
  slept_seconds = _retry_and_capture_sleep(429, payload)
  assert slept_seconds == pytest.approx(22.943984799)


def test_retry_wait_ignores_retry_info_when_status_not_resource_exhausted():
  """RetryInfo on a 429 whose status is not RESOURCE_EXHAUSTED is ignored."""
  payload = _resource_exhausted_error_payload('9s', status='UNAVAILABLE')
  slept_seconds = _retry_and_capture_sleep(429, payload)
  # The wait equals the configured initial delay.
  assert slept_seconds == 0.25


def test_retry_wait_ignores_retry_info_when_code_not_429():
  """RetryInfo on a response with HTTP status 500 is ignored."""
  payload = _resource_exhausted_error_payload('9s')
  slept_seconds = _retry_and_capture_sleep(500, payload)
  # The wait equals the configured initial delay.
  assert slept_seconds == 0.25


def test_retry_wait_falls_back_on_malformed_retry_delay():
  """A retryDelay that is not '<number>s' falls back to the configured wait."""
  payload = _resource_exhausted_error_payload('invalid-delay')
  slept_seconds = _retry_and_capture_sleep(429, payload)
  # The wait equals the configured initial delay.
  assert slept_seconds == 0.25


def test_retry_wait_supports_error_details_with_or_without_error_wrapper():
  """RetryInfo is honored both under an 'error' wrapper and at the top level."""
  slept_by_wrapping = {
      wrapped: _retry_and_capture_sleep(
          429,
          _resource_exhausted_error_payload('3.5s', wrapped=wrapped),
      )
      for wrapped in (True, False)
  }
  # 3.5s from RetryInfo plus the one-second margin, for both payload shapes.
  assert slept_by_wrapping[True] == pytest.approx(4.5)
  assert slept_by_wrapping[False] == pytest.approx(4.5)


def test_retry_args_enabled_with_custom_values_are_not_overridden():
options = types.HttpRetryOptions(
attempts=10,
Expand Down