Skip to content

Commit a95d08a

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Autoenable mTLS in environment with bound token (Agent Engine with AgentAuthority) through google-auth migration (except custom client args, custom client or custom ClientSession)
PiperOrigin-RevId: 887052636
1 parent ce86f2b commit a95d08a

8 files changed

Lines changed: 374 additions & 104 deletions

File tree

google/genai/_api_client.py

Lines changed: 196 additions & 66 deletions
Large diffs are not rendered by default.

google/genai/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
148148
stacklevel=5,
149149
)
150150

151-
http_client: httpx.AsyncClient = self._api_client._async_httpx_client
151+
http_client: Optional[httpx.AsyncClient] = (
152+
self._api_client._async_httpx_client
153+
)
152154

153155
async_client_args = self._api_client._http_options.async_client_args or {}
154156
has_custom_transport = 'transport' in async_client_args
@@ -308,7 +310,6 @@ class DebugConfig(pydantic.BaseModel):
308310
)
309311

310312

311-
312313
class Client:
313314
"""Client for making synchronous requests.
314315

google/genai/errors.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,25 @@
1818
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
1919
import httpx
2020
import json
21-
import websockets
21+
import requests
2222
from . import _common
2323

2424

2525
if TYPE_CHECKING:
2626
from .replay_api_client import ReplayResponse
2727
import aiohttp
28+
from google.auth.aio.transport.aiohttp import Response as AsyncAuthorizedSessionResponse
2829

2930

3031
class APIError(Exception):
3132
"""General errors raised by the GenAI API."""
3233
code: int
33-
response: Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
34+
response: Union[
35+
requests.Response,
36+
'ReplayResponse',
37+
httpx.Response,
38+
'AsyncAuthorizedSessionResponse',
39+
]
3440

3541
status: Optional[str] = None
3642
message: Optional[str] = None
@@ -40,7 +46,12 @@ def __init__(
4046
code: int,
4147
response_json: Any,
4248
response: Optional[
43-
Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
49+
Union[
50+
requests.Response,
51+
'ReplayResponse',
52+
httpx.Response,
53+
'AsyncAuthorizedSessionResponse',
54+
]
4455
] = None,
4556
):
4657
if isinstance(response_json, list) and len(response_json) == 1:
@@ -112,7 +123,7 @@ def _to_replay_record(self) -> _common.StringDict:
112123

113124
@classmethod
114125
def raise_for_response(
115-
cls, response: Union['ReplayResponse', httpx.Response]
126+
cls, response: Union['ReplayResponse', httpx.Response, requests.Response]
116127
) -> None:
117128
"""Raises an error with detailed error message if the response has an error status."""
118129
if response.status_code == 200:
@@ -128,6 +139,16 @@ def raise_for_response(
128139
'message': message,
129140
'status': response.reason_phrase,
130141
}
142+
elif isinstance(response, requests.Response):
143+
try:
144+
# do not do any extra muanipulation on the response.
145+
# return the raw response json as is.
146+
response_json = response.json()
147+
except requests.exceptions.JSONDecodeError:
148+
response_json = {
149+
'message': response.text,
150+
'status': response.reason,
151+
}
131152
else:
132153
response_json = response.body_segments[0].get('error', {})
133154

@@ -139,7 +160,11 @@ def raise_error(
139160
status_code: int,
140161
response_json: Any,
141162
response: Optional[
142-
Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
163+
Union[
164+
'ReplayResponse',
165+
httpx.Response,
166+
requests.Response,
167+
]
143168
],
144169
) -> None:
145170
"""Raises an appropriate APIError subclass based on the status code.
@@ -166,12 +191,13 @@ def raise_error(
166191
async def raise_for_async_response(
167192
cls,
168193
response: Union[
169-
'ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'
194+
'ReplayResponse',
195+
httpx.Response,
196+
'aiohttp.ClientResponse',
197+
'AsyncAuthorizedSessionResponse',
170198
],
171199
) -> None:
172200
"""Raises an error with detailed error message if the response has an error status."""
173-
status_code = 0
174-
response_json = None
175201
if isinstance(response, httpx.Response):
176202
if response.status_code == 200:
177203
return
@@ -196,18 +222,23 @@ async def raise_for_async_response(
196222
try:
197223
import aiohttp # pylint: disable=g-import-not-at-top
198224

199-
if isinstance(response, aiohttp.ClientResponse):
200-
if response.status == 200:
225+
# Use a local variable to help Mypy handle the unwrapped response
226+
unwrapped_response: Any = response
227+
if hasattr(unwrapped_response, '_response'):
228+
unwrapped_response = unwrapped_response._response
229+
230+
if isinstance(unwrapped_response, aiohttp.ClientResponse):
231+
if unwrapped_response.status == 200:
201232
return
202233
try:
203-
response_json = await response.json()
234+
response_json = await unwrapped_response.json()
204235
except aiohttp.client_exceptions.ContentTypeError:
205-
message = await response.text()
236+
message = await unwrapped_response.text()
206237
response_json = {
207238
'message': message,
208-
'status': response.reason,
239+
'status': unwrapped_response.reason,
209240
}
210-
status_code = response.status
241+
status_code = unwrapped_response.status
211242
else:
212243
raise ValueError(f'Unsupported response type: {type(response)}')
213244
except ImportError:

google/genai/tests/client/test_client_close.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_close_httpx_client():
4343
vertexai=True,
4444
project='test_project',
4545
location='global',
46+
http_options=api_client.HttpOptions(client_args={'max_redirects': 10}),
4647
)
4748
client.close()
4849
assert client._api_client._httpx_client.is_closed
@@ -55,6 +56,7 @@ def test_httpx_client_context_manager():
5556
vertexai=True,
5657
project='test_project',
5758
location='global',
59+
http_options=api_client.HttpOptions(client_args={'max_redirects': 10}),
5860
) as client:
5961
pass
6062
assert not client._api_client._httpx_client.is_closed
@@ -135,6 +137,9 @@ async def run():
135137
vertexai=True,
136138
project='test_project',
137139
location='global',
140+
http_options=api_client.HttpOptions(
141+
async_client_args={'trust_env': False}
142+
),
138143
).aio
139144
# aiohttp session is created in the first request instead of client
140145
# initialization.
@@ -176,6 +181,9 @@ async def run():
176181
vertexai=True,
177182
project='test_project',
178183
location='global',
184+
http_options=api_client.HttpOptions(
185+
async_client_args={'trust_env': False}
186+
),
179187
).aio as async_client:
180188
# aiohttp session is created in the first request instead of client
181189
# initialization.

google/genai/tests/client/test_client_initialization.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import concurrent.futures
2121
import logging
2222
import os
23+
import requests
2324
import ssl
2425
from unittest import mock
2526

@@ -1331,18 +1332,32 @@ def refresh_side_effect(request):
13311332
mock_refresh = mock.Mock(side_effect=refresh_side_effect)
13321333
mock_creds.refresh = mock_refresh
13331334

1334-
# Mock the actual request to avoid network calls
1335-
mock_httpx_response = httpx.Response(
1336-
status_code=200,
1337-
headers={},
1338-
text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}',
1339-
)
1340-
mock_request = mock.Mock(return_value=mock_httpx_response)
1341-
monkeypatch.setattr(api_client.SyncHttpxClient, "request", mock_request)
1342-
13431335
client = Client(
13441336
vertexai=True, project="fake_project_id", location="fake-location"
13451337
)
1338+
# Mock the actual request to avoid network calls
1339+
if client._api_client._use_google_auth_sync():
1340+
# Cloud environment enables mTLS and uses requests.Response
1341+
mock_http_response = requests.Response()
1342+
mock_http_response.status_code = 200
1343+
mock_http_response.headers = {}
1344+
mock_http_response._content = (
1345+
b'{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}'
1346+
)
1347+
mock_request = mock.Mock(return_value=mock_http_response)
1348+
monkeypatch.setattr(
1349+
google.auth.transport.requests.AuthorizedSession, "request", mock_request
1350+
)
1351+
else:
1352+
# Non-cloud environment w/o certificates uses httpx.Response
1353+
mock_httpx_response = httpx.Response(
1354+
status_code=200,
1355+
headers={},
1356+
text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}',
1357+
)
1358+
mock_request = mock.Mock(return_value=mock_httpx_response)
1359+
monkeypatch.setattr(api_client.SyncHttpxClient, "send", mock_request)
1360+
13461361
# Reset credentials to test initialization to ensure the sync lock is tested.
13471362
client._api_client._credentials = None
13481363

0 commit comments

Comments
 (0)