Skip to content

Commit df46c8d

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: support mTLS in sync default CUJs through google-auth migration (except custom args and custom httpx.Client)
PiperOrigin-RevId: 884659552
1 parent 6c3379f commit df46c8d

4 files changed

Lines changed: 125 additions & 44 deletions

File tree

google/genai/_api_client.py

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
3434
import sys
3535
import threading
3636
import time
37-
from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union
37+
from typing import Any, AsyncIterator, Iterator, Optional, TYPE_CHECKING, Tuple, Union
3838
from urllib.parse import urlparse
3939
from urllib.parse import urlunparse
4040
import warnings
@@ -44,9 +44,13 @@
4444
import google.auth
4545
import google.auth.credentials
4646
from google.auth.credentials import Credentials
47+
from google.auth.transport import mtls
48+
from google.auth.transport.requests import AuthorizedSession
4749
import httpx
4850
from pydantic import BaseModel
4951
from pydantic import ValidationError
52+
import requests
53+
from requests.structures import CaseInsensitiveDict
5054
import tenacity
5155

5256
from . import _common
@@ -182,12 +186,6 @@ def join_url_path(base_url: str, path: str) -> str:
182186
def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
183187
"""Loads google auth credentials and project id."""
184188

185-
## Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false
186-
## to disable bound token sharing. Tracking on
187-
## https://github.com/googleapis/python-genai/issues/1956
188-
os.environ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES'] = (
189-
'false'
190-
)
191189
credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
192190
scopes=['https://www.googleapis.com/auth/cloud-platform'],
193191
)
@@ -235,7 +233,12 @@ class HttpResponse:
235233

236234
def __init__(
237235
self,
238-
headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
236+
headers: Union[
237+
dict[str, str],
238+
httpx.Headers,
239+
'CIMultiDictProxy[str]',
240+
CaseInsensitiveDict,
241+
],
239242
response_stream: Union[Any, str] = None,
240243
byte_stream: Union[Any, bytes] = None,
241244
):
@@ -245,6 +248,8 @@ def __init__(
245248
self.headers = {
246249
key: ', '.join(headers.get_list(key)) for key in headers.keys()
247250
}
251+
elif isinstance(headers, CaseInsensitiveDict):
252+
self.headers = {key: value for key, value in headers.items()}
248253
elif type(headers).__name__ == 'CIMultiDictProxy':
249254
self.headers = {
250255
key: ', '.join(headers.getall(key)) for key in headers.keys()
@@ -321,15 +326,22 @@ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
321326

322327
def _iter_response_stream(self) -> Iterator[str]:
323328
"""Iterates over chunks retrieved from the API."""
324-
if not isinstance(self.response_stream, httpx.Response):
329+
if not (
330+
isinstance(self.response_stream, httpx.Response)
331+
or isinstance(self.response_stream, requests.Response)
332+
):
325333
raise TypeError(
326334
'Expected self.response_stream to be an httpx.Response object, '
327335
f'but got {type(self.response_stream).__name__}.'
328336
)
329337

330338
chunk = ''
331339
balance = 0
332-
for line in self.response_stream.iter_lines():
340+
if isinstance(self.response_stream, httpx.Response):
341+
response_stream = self.response_stream.iter_lines()
342+
else:
343+
response_stream = self.response_stream.iter_lines(decode_unicode=True)
344+
for line in response_stream:
333345
if not line:
334346
continue
335347

@@ -593,7 +605,10 @@ def __init__(
593605
elif http_options and _common.is_duck_type_of(http_options, HttpOptions):
594606
validated_http_options = http_options
595607

596-
if validated_http_options.base_url_resource_scope and not validated_http_options.base_url:
608+
if (
609+
validated_http_options.base_url_resource_scope
610+
and not validated_http_options.base_url
611+
):
597612
# base_url_resource_scope is only valid when base_url is set.
598613
raise ValueError(
599614
'base_url must be set when base_url_resource_scope is set.'
@@ -729,8 +744,11 @@ def __init__(
729744
self._http_options
730745
)
731746
self._async_httpx_client_args = async_client_args
747+
self.authorized_session: Optional[AuthorizedSession] = None
732748

733-
if self._http_options.httpx_client:
749+
if self._use_google_auth_sync():
750+
self._httpx_client = None
751+
elif self._http_options.httpx_client:
734752
self._httpx_client = self._http_options.httpx_client
735753
else:
736754
self._httpx_client = SyncHttpxClient(**client_args)
@@ -759,7 +777,14 @@ def __init__(
759777
self._retry = tenacity.Retrying(**retry_kwargs)
760778
self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
761779

762-
async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession':
780+
def _use_google_auth_sync(self) -> Union[bool, None]:
781+
return self.vertexai and not (
782+
self._http_options.httpx_client or self._http_options.client_args
783+
)
784+
785+
async def _get_aiohttp_session(
786+
self,
787+
) -> 'aiohttp.ClientSession':
763788
"""Returns the aiohttp client session."""
764789
if (
765790
self._aiohttp_session is None
@@ -1003,6 +1028,11 @@ def _websocket_base_url(self) -> str:
10031028

10041029
def _access_token(self) -> str:
10051030
"""Retrieves the access token for the credentials."""
1031+
# Set GOOGLE_API_USE_CLIENT_CERTIFICATE to true to enable bound token sharing.
1032+
os.environ['GOOGLE_API_USE_CLIENT_CERTIFICATE'] = 'true'
1033+
os.environ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES'] = (
1034+
'true'
1035+
)
10061036
with self._sync_auth_lock:
10071037
if not self._credentials:
10081038
self._credentials, project = load_auth(project=self.project)
@@ -1041,6 +1071,12 @@ async def _get_async_auth_lock(self) -> asyncio.Lock:
10411071

10421072
async def _async_access_token(self) -> Union[str, Any]:
10431073
"""Retrieves the access token for the credentials asynchronously."""
1074+
# Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false
1075+
# to disable bound token sharing. Tracking on
1076+
# https://github.com/googleapis/python-genai/issues/1956
1077+
os.environ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES'] = (
1078+
'false'
1079+
)
10441080
if not self._credentials:
10451081
async_auth_lock = await self._get_async_auth_lock()
10461082
async with async_auth_lock:
@@ -1190,31 +1226,44 @@ def _request_once(
11901226
else:
11911227
data = http_request.data
11921228

1193-
if stream:
1194-
httpx_request = self._httpx_client.build_request(
1195-
method=http_request.method,
1196-
url=http_request.url,
1197-
content=data,
1229+
if self._use_google_auth_sync():
1230+
url = str(http_request.url)
1231+
if self.authorized_session is None:
1232+
self.authorized_session = AuthorizedSession( # type: ignore[no-untyped-call]
1233+
self._credentials,
1234+
max_refresh_attempts=1,
1235+
)
1236+
# Application default SSL credentials will be used to configure mtls
1237+
# channel.
1238+
self.authorized_session.configure_mtls_channel() # type: ignore[no-untyped-call]
1239+
if self.authorized_session._is_mtls and 'googleapis.com' in url:
1240+
if 'sandbox' in url:
1241+
url = url.replace(
1242+
'sandbox.googleapis.com', 'mtls.sandbox.googleapis.com'
1243+
)
1244+
else:
1245+
url = url.replace('googleapis.com', 'mtls.googleapis.com')
1246+
response = self.authorized_session.request( # type: ignore[no-untyped-call]
1247+
method=http_request.method.upper(),
1248+
url=url,
1249+
data=data,
11981250
headers=http_request.headers,
11991251
timeout=http_request.timeout,
1200-
)
1201-
response = self._httpx_client.send(httpx_request, stream=stream)
1202-
errors.APIError.raise_for_response(response)
1203-
return HttpResponse(
1204-
response.headers, response if stream else [response.text]
1252+
stream=stream,
12051253
)
12061254
else:
1207-
response = self._httpx_client.request(
1255+
httpx_request = self._httpx_client.build_request( # type: ignore[union-attr]
12081256
method=http_request.method,
12091257
url=http_request.url,
1210-
headers=http_request.headers,
12111258
content=data,
1259+
headers=http_request.headers,
12121260
timeout=http_request.timeout,
12131261
)
1214-
errors.APIError.raise_for_response(response)
1215-
return HttpResponse(
1216-
response.headers, response if stream else [response.text]
1217-
)
1262+
response = self._httpx_client.send(httpx_request, stream=stream) # type: ignore[union-attr]
1263+
errors.APIError.raise_for_response(response)
1264+
return HttpResponse(
1265+
response.headers, response if stream else [response.text]
1266+
)
12181267

12191268
def _request(
12201269
self,
@@ -1590,7 +1639,7 @@ def _upload_fd(
15901639
populate_server_timeout_header(upload_headers, timeout_in_seconds)
15911640
retry_count = 0
15921641
while retry_count < MAX_RETRY_COUNT:
1593-
response = self._httpx_client.request(
1642+
response = self._httpx_client.request( # type: ignore[union-attr]
15941643
method='POST',
15951644
url=upload_url,
15961645
headers=upload_headers,
@@ -1642,7 +1691,7 @@ def download_file(
16421691
else:
16431692
data = http_request.data
16441693

1645-
response = self._httpx_client.request(
1694+
response = self._httpx_client.request( # type: ignore[union-attr]
16461695
method=http_request.method,
16471696
url=http_request.url,
16481697
headers=http_request.headers,
@@ -1956,8 +2005,10 @@ def close(self) -> None:
19562005
"""Closes the API client."""
19572006
# Let users close the custom client explicitly by themselves. Otherwise,
19582007
# close the client when the object is garbage collected.
1959-
if not self._http_options.httpx_client:
2008+
if not self._http_options.httpx_client and self._httpx_client:
19602009
self._httpx_client.close()
2010+
if self.authorized_session:
2011+
self.authorized_session.close() # type: ignore[no-untyped-call]
19612012

19622013
async def aclose(self) -> None:
19632014
"""Closes the API async client."""
@@ -1986,6 +2037,7 @@ def __del__(self) -> None:
19862037
except Exception: # pylint: disable=broad-except
19872038
pass
19882039

2040+
19892041
def get_token_from_credentials(
19902042
client: 'BaseApiClient',
19912043
credentials: google.auth.credentials.Credentials
@@ -1998,6 +2050,7 @@ def get_token_from_credentials(
19982050
raise RuntimeError('Could not resolve API token from the environment')
19992051
return credentials.token # type: ignore[no-any-return]
20002052

2053+
20012054
async def async_get_token_from_credentials(
20022055
client: 'BaseApiClient',
20032056
credentials: google.auth.credentials.Credentials

google/genai/errors.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
1919
import httpx
2020
import json
21-
import websockets
21+
import requests
2222
from . import _common
2323

2424

@@ -30,7 +30,11 @@
3030
class APIError(Exception):
3131
"""General errors raised by the GenAI API."""
3232
code: int
33-
response: Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
33+
response: Union[
34+
requests.Response,
35+
'ReplayResponse',
36+
httpx.Response,
37+
]
3438

3539
status: Optional[str] = None
3640
message: Optional[str] = None
@@ -40,7 +44,11 @@ def __init__(
4044
code: int,
4145
response_json: Any,
4246
response: Optional[
43-
Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
47+
Union[
48+
requests.Response,
49+
'ReplayResponse',
50+
httpx.Response,
51+
]
4452
] = None,
4553
):
4654
if isinstance(response_json, list) and len(response_json) == 1:
@@ -112,7 +120,7 @@ def _to_replay_record(self) -> _common.StringDict:
112120

113121
@classmethod
114122
def raise_for_response(
115-
cls, response: Union['ReplayResponse', httpx.Response]
123+
cls, response: Union['ReplayResponse', httpx.Response, requests.Response]
116124
) -> None:
117125
"""Raises an error with detailed error message if the response has an error status."""
118126
if response.status_code == 200:
@@ -128,6 +136,16 @@ def raise_for_response(
128136
'message': message,
129137
'status': response.reason_phrase,
130138
}
139+
elif isinstance(response, requests.Response):
140+
try:
141+
# do not do any extra muanipulation on the response.
142+
# return the raw response json as is.
143+
response_json = response.json()
144+
except requests.exceptions.JSONDecodeError:
145+
response_json = {
146+
'message': response.text,
147+
'status': response.reason,
148+
}
131149
else:
132150
response_json = response.body_segments[0].get('error', {})
133151

@@ -139,7 +157,11 @@ def raise_error(
139157
status_code: int,
140158
response_json: Any,
141159
response: Optional[
142-
Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
160+
Union[
161+
'ReplayResponse',
162+
httpx.Response,
163+
requests.Response,
164+
]
143165
],
144166
) -> None:
145167
"""Raises an appropriate APIError subclass based on the status code.

google/genai/tests/client/test_client_close.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_close_httpx_client():
4343
vertexai=True,
4444
project='test_project',
4545
location='global',
46+
http_options=api_client.HttpOptions(client_args={'max_redirects': 10}),
4647
)
4748
client.close()
4849
assert client._api_client._httpx_client.is_closed
@@ -55,6 +56,7 @@ def test_httpx_client_context_manager():
5556
vertexai=True,
5657
project='test_project',
5758
location='global',
59+
http_options=api_client.HttpOptions(client_args={'max_redirects': 10}),
5860
) as client:
5961
pass
6062
assert not client._api_client._httpx_client.is_closed

google/genai/tests/client/test_client_initialization.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import concurrent.futures
2121
import logging
2222
import os
23+
import requests
2324
import ssl
2425
from unittest import mock
2526

@@ -1332,13 +1333,16 @@ def refresh_side_effect(request):
13321333
mock_creds.refresh = mock_refresh
13331334

13341335
# Mock the actual request to avoid network calls
1335-
mock_httpx_response = httpx.Response(
1336-
status_code=200,
1337-
headers={},
1338-
text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}',
1336+
mock_http_response = requests.Response()
1337+
mock_http_response.status_code = 200
1338+
mock_http_response.headers = {}
1339+
mock_http_response._content = (
1340+
b'{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}'
1341+
)
1342+
mock_request = mock.Mock(return_value=mock_http_response)
1343+
monkeypatch.setattr(
1344+
google.auth.transport.requests.AuthorizedSession, "request", mock_request
13391345
)
1340-
mock_request = mock.Mock(return_value=mock_httpx_response)
1341-
monkeypatch.setattr(api_client.SyncHttpxClient, "request", mock_request)
13421346

13431347
client = Client(
13441348
vertexai=True, project="fake_project_id", location="fake-location"

0 commit comments

Comments
 (0)