1- # Copyright 2025 Google LLC
1+ # Copyright 2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
3434import sys
3535import threading
3636import time
37- from typing import Any , AsyncIterator , Iterator , Optional , Tuple , TYPE_CHECKING , Union
37+ from typing import Any , AsyncIterator , Iterator , Optional , TYPE_CHECKING , Tuple , Union
3838from urllib .parse import urlparse
3939from urllib .parse import urlunparse
4040import warnings
4444import google .auth
4545import google .auth .credentials
4646from google .auth .credentials import Credentials
47+ from google .auth .transport import mtls
48+ from google .auth .transport .requests import AuthorizedSession
4749import httpx
4850from pydantic import BaseModel
4951from pydantic import ValidationError
52+ import requests
53+ from requests .structures import CaseInsensitiveDict
5054import tenacity
5155
5256from . import _common
@@ -182,12 +186,6 @@ def join_url_path(base_url: str, path: str) -> str:
182186def load_auth (* , project : Union [str , None ]) -> Tuple [Credentials , str ]:
183187 """Loads google auth credentials and project id."""
184188
185- ## Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false
186- ## to disable bound token sharing. Tracking on
187- ## https://github.com/googleapis/python-genai/issues/1956
188- os .environ ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES' ] = (
189- 'false'
190- )
191189 credentials , loaded_project_id = google .auth .default ( # type: ignore[no-untyped-call]
192190 scopes = ['https://www.googleapis.com/auth/cloud-platform' ],
193191 )
@@ -235,7 +233,12 @@ class HttpResponse:
235233
236234 def __init__ (
237235 self ,
238- headers : Union [dict [str , str ], httpx .Headers , 'CIMultiDictProxy[str]' ],
236+ headers : Union [
237+ dict [str , str ],
238+ httpx .Headers ,
239+ 'CIMultiDictProxy[str]' ,
240+ CaseInsensitiveDict ,
241+ ],
239242 response_stream : Union [Any , str ] = None ,
240243 byte_stream : Union [Any , bytes ] = None ,
241244 ):
@@ -245,6 +248,8 @@ def __init__(
245248 self .headers = {
246249 key : ', ' .join (headers .get_list (key )) for key in headers .keys ()
247250 }
251+ elif isinstance (headers , CaseInsensitiveDict ):
252+ self .headers = {key : value for key , value in headers .items ()}
248253 elif type (headers ).__name__ == 'CIMultiDictProxy' :
249254 self .headers = {
250255 key : ', ' .join (headers .getall (key )) for key in headers .keys ()
@@ -321,15 +326,22 @@ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
321326
322327 def _iter_response_stream (self ) -> Iterator [str ]:
323328 """Iterates over chunks retrieved from the API."""
324- if not isinstance (self .response_stream , httpx .Response ):
329+ if not (
330+ isinstance (self .response_stream , httpx .Response )
331+ or isinstance (self .response_stream , requests .Response )
332+ ):
325333 raise TypeError (
326334 'Expected self.response_stream to be an httpx.Response object, '
327335 f'but got { type (self .response_stream ).__name__ } .'
328336 )
329337
330338 chunk = ''
331339 balance = 0
332- for line in self .response_stream .iter_lines ():
340+ if isinstance (self .response_stream , httpx .Response ):
341+ response_stream = self .response_stream .iter_lines ()
342+ else :
343+ response_stream = self .response_stream .iter_lines (decode_unicode = True )
344+ for line in response_stream :
333345 if not line :
334346 continue
335347
@@ -593,7 +605,10 @@ def __init__(
593605 elif http_options and _common .is_duck_type_of (http_options , HttpOptions ):
594606 validated_http_options = http_options
595607
596- if validated_http_options .base_url_resource_scope and not validated_http_options .base_url :
608+ if (
609+ validated_http_options .base_url_resource_scope
610+ and not validated_http_options .base_url
611+ ):
597612 # base_url_resource_scope is only valid when base_url is set.
598613 raise ValueError (
599614 'base_url must be set when base_url_resource_scope is set.'
@@ -729,8 +744,11 @@ def __init__(
729744 self ._http_options
730745 )
731746 self ._async_httpx_client_args = async_client_args
747+ self .authorized_session : Optional [AuthorizedSession ] = None
732748
733- if self ._http_options .httpx_client :
749+ if self ._use_google_auth_sync ():
750+ self ._httpx_client = None
751+ elif self ._http_options .httpx_client :
734752 self ._httpx_client = self ._http_options .httpx_client
735753 else :
736754 self ._httpx_client = SyncHttpxClient (** client_args )
@@ -759,7 +777,14 @@ def __init__(
759777 self ._retry = tenacity .Retrying (** retry_kwargs )
760778 self ._async_retry = tenacity .AsyncRetrying (** retry_kwargs )
761779
762- async def _get_aiohttp_session (self ) -> 'aiohttp.ClientSession' :
780+ def _use_google_auth_sync (self ) -> Union [bool , None ]:
781+ return self .vertexai and not (
782+ self ._http_options .httpx_client or self ._http_options .client_args
783+ )
784+
785+ async def _get_aiohttp_session (
786+ self ,
787+ ) -> 'aiohttp.ClientSession' :
763788 """Returns the aiohttp client session."""
764789 if (
765790 self ._aiohttp_session is None
@@ -1003,6 +1028,11 @@ def _websocket_base_url(self) -> str:
10031028
10041029 def _access_token (self ) -> str :
10051030 """Retrieves the access token for the credentials."""
1031+ # Set GOOGLE_API_USE_CLIENT_CERTIFICATE to true to enable bound token sharing.
1032+ os .environ ['GOOGLE_API_USE_CLIENT_CERTIFICATE' ] = 'true'
1033+ os .environ ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES' ] = (
1034+ 'true'
1035+ )
10061036 with self ._sync_auth_lock :
10071037 if not self ._credentials :
10081038 self ._credentials , project = load_auth (project = self .project )
@@ -1041,6 +1071,12 @@ async def _get_async_auth_lock(self) -> asyncio.Lock:
10411071
10421072 async def _async_access_token (self ) -> Union [str , Any ]:
10431073 """Retrieves the access token for the credentials asynchronously."""
1074+ # Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false
1075+ # to disable bound token sharing. Tracking on
1076+ # https://github.com/googleapis/python-genai/issues/1956
1077+ os .environ ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES' ] = (
1078+ 'false'
1079+ )
10441080 if not self ._credentials :
10451081 async_auth_lock = await self ._get_async_auth_lock ()
10461082 async with async_auth_lock :
@@ -1190,31 +1226,44 @@ def _request_once(
11901226 else :
11911227 data = http_request .data
11921228
1193- if stream :
1194- httpx_request = self ._httpx_client .build_request (
1195- method = http_request .method ,
1196- url = http_request .url ,
1197- content = data ,
1229+ if self ._use_google_auth_sync ():
1230+ url = str (http_request .url )
1231+ if self .authorized_session is None :
1232+ self .authorized_session = AuthorizedSession ( # type: ignore[no-untyped-call]
1233+ self ._credentials ,
1234+ max_refresh_attempts = 1 ,
1235+ )
1236+ # Application default SSL credentials will be used to configure mtls
1237+ # channel.
1238+ self .authorized_session .configure_mtls_channel () # type: ignore[no-untyped-call]
1239+ if self .authorized_session ._is_mtls and 'googleapis.com' in url :
1240+ if 'sandbox' in url :
1241+ url = url .replace (
1242+ 'sandbox.googleapis.com' , 'mtls.sandbox.googleapis.com'
1243+ )
1244+ else :
1245+ url = url .replace ('googleapis.com' , 'mtls.googleapis.com' )
1246+ response = self .authorized_session .request ( # type: ignore[no-untyped-call]
1247+ method = http_request .method .upper (),
1248+ url = url ,
1249+ data = data ,
11981250 headers = http_request .headers ,
11991251 timeout = http_request .timeout ,
1200- )
1201- response = self ._httpx_client .send (httpx_request , stream = stream )
1202- errors .APIError .raise_for_response (response )
1203- return HttpResponse (
1204- response .headers , response if stream else [response .text ]
1252+ stream = stream ,
12051253 )
12061254 else :
1207- response = self ._httpx_client .request (
1255+ httpx_request = self ._httpx_client .build_request ( # type: ignore[union-attr]
12081256 method = http_request .method ,
12091257 url = http_request .url ,
1210- headers = http_request .headers ,
12111258 content = data ,
1259+ headers = http_request .headers ,
12121260 timeout = http_request .timeout ,
12131261 )
1214- errors .APIError .raise_for_response (response )
1215- return HttpResponse (
1216- response .headers , response if stream else [response .text ]
1217- )
1262+ response = self ._httpx_client .send (httpx_request , stream = stream ) # type: ignore[union-attr]
1263+ errors .APIError .raise_for_response (response )
1264+ return HttpResponse (
1265+ response .headers , response if stream else [response .text ]
1266+ )
12181267
12191268 def _request (
12201269 self ,
@@ -1590,7 +1639,7 @@ def _upload_fd(
15901639 populate_server_timeout_header (upload_headers , timeout_in_seconds )
15911640 retry_count = 0
15921641 while retry_count < MAX_RETRY_COUNT :
1593- response = self ._httpx_client .request (
1642+ response = self ._httpx_client .request ( # type: ignore[union-attr]
15941643 method = 'POST' ,
15951644 url = upload_url ,
15961645 headers = upload_headers ,
@@ -1642,7 +1691,7 @@ def download_file(
16421691 else :
16431692 data = http_request .data
16441693
1645- response = self ._httpx_client .request (
1694+ response = self ._httpx_client .request ( # type: ignore[union-attr]
16461695 method = http_request .method ,
16471696 url = http_request .url ,
16481697 headers = http_request .headers ,
@@ -1956,8 +2005,10 @@ def close(self) -> None:
19562005 """Closes the API client."""
19572006 # Let users close the custom client explicitly by themselves. Otherwise,
19582007 # close the client when the object is garbage collected.
1959- if not self ._http_options .httpx_client :
2008+ if not self ._http_options .httpx_client and self . _httpx_client :
19602009 self ._httpx_client .close ()
2010+ if self .authorized_session :
2011+ self .authorized_session .close () # type: ignore[no-untyped-call]
19612012
19622013 async def aclose (self ) -> None :
19632014 """Closes the API async client."""
@@ -1986,6 +2037,7 @@ def __del__(self) -> None:
19862037 except Exception : # pylint: disable=broad-except
19872038 pass
19882039
2040+
19892041def get_token_from_credentials (
19902042 client : 'BaseApiClient' ,
19912043 credentials : google .auth .credentials .Credentials
@@ -1998,6 +2050,7 @@ def get_token_from_credentials(
19982050 raise RuntimeError ('Could not resolve API token from the environment' )
19992051 return credentials .token # type: ignore[no-any-return]
20002052
2053+
20012054async def async_get_token_from_credentials (
20022055 client : 'BaseApiClient' ,
20032056 credentials : google .auth .credentials .Credentials
0 commit comments