Skip to content

Commit 1dbcecc

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: update Spanner query tools to async functions
PiperOrigin-RevId: 874318392
1 parent 37d52b4 commit 1dbcecc

4 files changed

Lines changed: 77 additions & 67 deletions

File tree

src/google/adk/tools/spanner/query_tool.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import functools
1819
import textwrap
19-
import types
2020
from typing import Callable
2121

2222
from google.auth.credentials import Credentials
@@ -27,7 +27,7 @@
2727
from .settings import SpannerToolSettings
2828

2929

30-
def execute_sql(
30+
async def execute_sql(
3131
project_id: str,
3232
instance_id: str,
3333
database_id: str,
@@ -82,7 +82,8 @@ def execute_sql(
8282
Note:
8383
This is running with Read-Only Transaction for query that only read data.
8484
"""
85-
return utils.execute_sql(
85+
return await asyncio.to_thread(
86+
utils.execute_sql,
8687
project_id,
8788
instance_id,
8889
database_id,
@@ -179,15 +180,10 @@ def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]:
179180

180181
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
181182

182-
execute_sql_wrapper = types.FunctionType(
183-
execute_sql.__code__,
184-
execute_sql.__globals__,
185-
execute_sql.__name__,
186-
execute_sql.__defaults__,
187-
execute_sql.__closure__,
188-
)
189-
functools.update_wrapper(execute_sql_wrapper, execute_sql)
190-
# Update with the new docstring
183+
@functools.wraps(execute_sql)
184+
async def execute_sql_wrapper(*args, **kwargs) -> dict:
185+
return await execute_sql(*args, **kwargs)
186+
191187
execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING
192188
return execute_sql_wrapper
193189

src/google/adk/tools/spanner/search_tool.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import json
1819
from typing import Any
1920
from typing import Dict
@@ -230,7 +231,7 @@ def _generate_sql_for_ann(
230231
"""
231232

232233

233-
def similarity_search(
234+
async def similarity_search(
234235
project_id: str,
235236
instance_id: str,
236237
database_id: str,
@@ -462,13 +463,16 @@ def similarity_search(
462463

463464
# Generate embedding for the query according to the embedding options.
464465
if vertex_ai_embedding_model_name:
465-
embedding = utils.embed_contents(
466-
vertex_ai_embedding_model_name,
467-
[query],
468-
output_dimensionality,
466+
embedding = (
467+
await utils.embed_contents_async(
468+
vertex_ai_embedding_model_name,
469+
[query],
470+
output_dimensionality,
471+
)
469472
)[0]
470473
else:
471-
embedding = _get_embedding_for_query(
474+
embedding = await asyncio.to_thread(
475+
_get_embedding_for_query,
472476
database,
473477
database.database_dialect,
474478
spanner_gsql_embedding_model_name,
@@ -507,30 +511,28 @@ def similarity_search(
507511
else:
508512
params = {_GOOGLESQL_PARAMETER_QUERY_EMBEDDING: embedding}
509513

510-
with database.snapshot() as snapshot:
511-
result_set = snapshot.execute_sql(sql, params=params)
512-
rows = []
513-
result = {}
514-
for row in result_set:
515-
try:
516-
# if the json serialization of the row succeeds, use it as is
517-
json.dumps(row)
518-
except (TypeError, ValueError, OverflowError):
519-
row = str(row)
520-
521-
rows.append(row)
522-
523-
result["status"] = "SUCCESS"
524-
result["rows"] = rows
525-
return result
514+
def _execute_sql():
515+
with database.snapshot() as snapshot:
516+
result_set = snapshot.execute_sql(sql, params=params)
517+
rows = []
518+
for row in result_set:
519+
try:
520+
# If the json serialization of the row succeeds, use it as is
521+
json.dumps(row)
522+
except (TypeError, ValueError, OverflowError):
523+
row = str(row)
524+
rows.append(row)
525+
return {"status": "SUCCESS", "rows": rows}
526+
527+
return await asyncio.to_thread(_execute_sql)
526528
except Exception as ex:
527529
return {
528530
"status": "ERROR",
529531
"error_details": repr(ex),
530532
}
531533

532534

533-
def vector_store_similarity_search(
535+
async def vector_store_similarity_search(
534536
query: str,
535537
credentials: Credentials,
536538
settings: SpannerToolSettings,
@@ -605,7 +607,7 @@ def vector_store_similarity_search(
605607
settings.vector_store_settings.num_leaves_to_search
606608
)
607609

608-
return similarity_search(
610+
return await similarity_search(
609611
project_id=settings.vector_store_settings.project_id,
610612
instance_id=settings.vector_store_settings.instance_id,
611613
database_id=settings.vector_store_settings.database_id,

tests/unittests/tools/spanner/test_search_tool.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ def mock_spanner_ids():
5454
),
5555
],
5656
)
57-
@mock.patch.object(utils, "embed_contents")
57+
@pytest.mark.asyncio
58+
@mock.patch.object(utils, "embed_contents_async", autospec=True)
5859
@mock.patch.object(client, "get_spanner_client")
59-
def test_similarity_search_knn_success(
60+
async def test_similarity_search_knn_success(
6061
mock_get_spanner_client,
61-
mock_embed_contents,
62+
mock_embed_contents_async,
6263
mock_spanner_ids,
6364
mock_credentials,
6465
embedding_option_key,
@@ -77,7 +78,7 @@ def test_similarity_search_knn_success(
7778
mock_get_spanner_client.return_value = mock_spanner_client
7879

7980
if embedding_option_key == "vertex_ai_embedding_model_name":
80-
mock_embed_contents.return_value = [expected_embedding]
81+
mock_embed_contents_async.return_value = [expected_embedding]
8182
# execute_sql is called once for the kNN search
8283
mock_snapshot.execute_sql.return_value = iter([("result1",), ("result2",)])
8384
else:
@@ -90,7 +91,7 @@ def test_similarity_search_knn_success(
9091
iter([("result1",), ("result2",)]),
9192
]
9293

93-
result = search_tool.similarity_search(
94+
result = await search_tool.similarity_search(
9495
project_id=mock_spanner_ids["project_id"],
9596
instance_id=mock_spanner_ids["instance_id"],
9697
database_id=mock_spanner_ids["database_id"],
@@ -111,13 +112,14 @@ def test_similarity_search_knn_success(
111112
assert "@embedding" in sql
112113
assert call_args.kwargs == {"params": {"embedding": expected_embedding}}
113114
if embedding_option_key == "vertex_ai_embedding_model_name":
114-
mock_embed_contents.assert_called_once_with(
115+
mock_embed_contents_async.assert_called_once_with(
115116
embedding_option_value, ["test query"], None
116117
)
117118

118119

120+
@pytest.mark.asyncio
119121
@mock.patch.object(client, "get_spanner_client")
120-
def test_similarity_search_ann_success(
122+
async def test_similarity_search_ann_success(
121123
mock_get_spanner_client, mock_spanner_ids, mock_credentials
122124
):
123125
"""Test similarity_search function with ANN success."""
@@ -139,7 +141,7 @@ def test_similarity_search_ann_success(
139141
mock_spanner_client.instance.return_value = mock_instance
140142
mock_get_spanner_client.return_value = mock_spanner_client
141143

142-
result = search_tool.similarity_search(
144+
result = await search_tool.similarity_search(
143145
project_id=mock_spanner_ids["project_id"],
144146
instance_id=mock_spanner_ids["instance_id"],
145147
database_id=mock_spanner_ids["database_id"],
@@ -164,13 +166,14 @@ def test_similarity_search_ann_success(
164166
assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}}
165167

166168

169+
@pytest.mark.asyncio
167170
@mock.patch.object(client, "get_spanner_client")
168-
def test_similarity_search_error(
171+
async def test_similarity_search_error(
169172
mock_get_spanner_client, mock_spanner_ids, mock_credentials
170173
):
171174
"""Test similarity_search function with a generic error."""
172175
mock_get_spanner_client.side_effect = Exception("Test Exception")
173-
result = search_tool.similarity_search(
176+
result = await search_tool.similarity_search(
174177
project_id=mock_spanner_ids["project_id"],
175178
instance_id=mock_spanner_ids["instance_id"],
176179
database_id=mock_spanner_ids["database_id"],
@@ -187,11 +190,12 @@ def test_similarity_search_error(
187190
assert "Test Exception" in result["error_details"]
188191

189192

190-
@mock.patch.object(utils, "embed_contents")
193+
@pytest.mark.asyncio
194+
@mock.patch.object(utils, "embed_contents_async")
191195
@mock.patch.object(client, "get_spanner_client")
192-
def test_similarity_search_circular_row_fallback_to_string(
196+
async def test_similarity_search_circular_row_fallback_to_string(
193197
mock_get_spanner_client,
194-
mock_embed_contents,
198+
mock_embed_contents_async,
195199
mock_spanner_ids,
196200
mock_credentials,
197201
):
@@ -202,15 +206,15 @@ def test_similarity_search_circular_row_fallback_to_string(
202206
mock_snapshot = MagicMock()
203207
circular_row = []
204208
circular_row.append(circular_row)
205-
mock_embed_contents.return_value = [[0.1, 0.2, 0.3]]
209+
mock_embed_contents_async.return_value = [[0.1, 0.2, 0.3]]
206210
mock_snapshot.execute_sql.return_value = iter([circular_row])
207211
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
208212
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
209213
mock_instance.database.return_value = mock_database
210214
mock_spanner_client.instance.return_value = mock_instance
211215
mock_get_spanner_client.return_value = mock_spanner_client
212216

213-
result = search_tool.similarity_search(
217+
result = await search_tool.similarity_search(
214218
project_id=mock_spanner_ids["project_id"],
215219
instance_id=mock_spanner_ids["instance_id"],
216220
database_id=mock_spanner_ids["database_id"],
@@ -228,8 +232,9 @@ def test_similarity_search_circular_row_fallback_to_string(
228232
assert result["rows"] == [str(circular_row)]
229233

230234

235+
@pytest.mark.asyncio
231236
@mock.patch.object(client, "get_spanner_client")
232-
def test_similarity_search_postgresql_knn_success(
237+
async def test_similarity_search_postgresql_knn_success(
233238
mock_get_spanner_client, mock_spanner_ids, mock_credentials
234239
):
235240
"""Test similarity_search with PostgreSQL dialect for kNN."""
@@ -249,7 +254,7 @@ def test_similarity_search_postgresql_knn_success(
249254
mock_spanner_client.instance.return_value = mock_instance
250255
mock_get_spanner_client.return_value = mock_spanner_client
251256

252-
result = search_tool.similarity_search(
257+
result = await search_tool.similarity_search(
253258
project_id=mock_spanner_ids["project_id"],
254259
instance_id=mock_spanner_ids["instance_id"],
255260
database_id=mock_spanner_ids["database_id"],
@@ -273,8 +278,9 @@ def test_similarity_search_postgresql_knn_success(
273278
assert call_args.kwargs == {"params": {"p1": [0.1, 0.2, 0.3]}}
274279

275280

281+
@pytest.mark.asyncio
276282
@mock.patch.object(client, "get_spanner_client")
277-
def test_similarity_search_postgresql_ann_unsupported(
283+
async def test_similarity_search_postgresql_ann_unsupported(
278284
mock_get_spanner_client, mock_spanner_ids, mock_credentials
279285
):
280286
"""Test similarity_search with unsupported ANN for PostgreSQL dialect."""
@@ -286,7 +292,7 @@ def test_similarity_search_postgresql_ann_unsupported(
286292
mock_spanner_client.instance.return_value = mock_instance
287293
mock_get_spanner_client.return_value = mock_spanner_client
288294

289-
result = search_tool.similarity_search(
295+
result = await search_tool.similarity_search(
290296
project_id=mock_spanner_ids["project_id"],
291297
instance_id=mock_spanner_ids["instance_id"],
292298
database_id=mock_spanner_ids["database_id"],
@@ -311,8 +317,9 @@ def test_similarity_search_postgresql_ann_unsupported(
311317
)
312318

313319

320+
@pytest.mark.asyncio
314321
@mock.patch.object(client, "get_spanner_client")
315-
def test_similarity_search_gsql_missing_embedding_model_error(
322+
async def test_similarity_search_gsql_missing_embedding_model_error(
316323
mock_get_spanner_client, mock_spanner_ids, mock_credentials
317324
):
318325
"""Test similarity_search with missing embedding_options for GoogleSQL dialect."""
@@ -324,7 +331,7 @@ def test_similarity_search_gsql_missing_embedding_model_error(
324331
mock_spanner_client.instance.return_value = mock_instance
325332
mock_get_spanner_client.return_value = mock_spanner_client
326333

327-
result = search_tool.similarity_search(
334+
result = await search_tool.similarity_search(
328335
project_id=mock_spanner_ids["project_id"],
329336
instance_id=mock_spanner_ids["instance_id"],
330337
database_id=mock_spanner_ids["database_id"],
@@ -348,8 +355,9 @@ def test_similarity_search_gsql_missing_embedding_model_error(
348355
)
349356

350357

358+
@pytest.mark.asyncio
351359
@mock.patch.object(client, "get_spanner_client")
352-
def test_similarity_search_pg_missing_embedding_model_error(
360+
async def test_similarity_search_pg_missing_embedding_model_error(
353361
mock_get_spanner_client, mock_spanner_ids, mock_credentials
354362
):
355363
"""Test similarity_search with missing embedding_options for PostgreSQL dialect."""
@@ -361,7 +369,7 @@ def test_similarity_search_pg_missing_embedding_model_error(
361369
mock_spanner_client.instance.return_value = mock_instance
362370
mock_get_spanner_client.return_value = mock_spanner_client
363371

364-
result = search_tool.similarity_search(
372+
result = await search_tool.similarity_search(
365373
project_id=mock_spanner_ids["project_id"],
366374
instance_id=mock_spanner_ids["instance_id"],
367375
database_id=mock_spanner_ids["database_id"],
@@ -427,8 +435,9 @@ def test_similarity_search_pg_missing_embedding_model_error(
427435
),
428436
],
429437
)
438+
@pytest.mark.asyncio
430439
@mock.patch.object(client, "get_spanner_client")
431-
def test_similarity_search_multiple_embedding_options_error(
440+
async def test_similarity_search_multiple_embedding_options_error(
432441
mock_get_spanner_client,
433442
mock_spanner_ids,
434443
mock_credentials,
@@ -443,7 +452,7 @@ def test_similarity_search_multiple_embedding_options_error(
443452
mock_spanner_client.instance.return_value = mock_instance
444453
mock_get_spanner_client.return_value = mock_spanner_client
445454

446-
result = search_tool.similarity_search(
455+
result = await search_tool.similarity_search(
447456
project_id=mock_spanner_ids["project_id"],
448457
instance_id=mock_spanner_ids["instance_id"],
449458
database_id=mock_spanner_ids["database_id"],
@@ -461,8 +470,9 @@ def test_similarity_search_multiple_embedding_options_error(
461470
)
462471

463472

473+
@pytest.mark.asyncio
464474
@mock.patch.object(client, "get_spanner_client")
465-
def test_similarity_search_output_dimensionality_gsql_error(
475+
async def test_similarity_search_output_dimensionality_gsql_error(
466476
mock_get_spanner_client, mock_spanner_ids, mock_credentials
467477
):
468478
"""Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name."""
@@ -474,7 +484,7 @@ def test_similarity_search_output_dimensionality_gsql_error(
474484
mock_spanner_client.instance.return_value = mock_instance
475485
mock_get_spanner_client.return_value = mock_spanner_client
476486

477-
result = search_tool.similarity_search(
487+
result = await search_tool.similarity_search(
478488
project_id=mock_spanner_ids["project_id"],
479489
instance_id=mock_spanner_ids["instance_id"],
480490
database_id=mock_spanner_ids["database_id"],
@@ -492,8 +502,9 @@ def test_similarity_search_output_dimensionality_gsql_error(
492502
assert "is not supported when" in result["error_details"]
493503

494504

505+
@pytest.mark.asyncio
495506
@mock.patch.object(client, "get_spanner_client")
496-
def test_similarity_search_unsupported_algorithm_error(
507+
async def test_similarity_search_unsupported_algorithm_error(
497508
mock_get_spanner_client, mock_spanner_ids, mock_credentials
498509
):
499510
"""Test similarity_search with an unsupported nearest neighbors algorithm."""
@@ -505,7 +516,7 @@ def test_similarity_search_unsupported_algorithm_error(
505516
mock_spanner_client.instance.return_value = mock_instance
506517
mock_get_spanner_client.return_value = mock_spanner_client
507518

508-
result = search_tool.similarity_search(
519+
result = await search_tool.similarity_search(
509520
project_id=mock_spanner_ids["project_id"],
510521
instance_id=mock_spanner_ids["instance_id"],
511522
database_id=mock_spanner_ids["database_id"],

0 commit comments

Comments
 (0)