@@ -54,11 +54,12 @@ def mock_spanner_ids():
5454 ),
5555 ],
5656)
57- @mock .patch .object (utils , "embed_contents" )
57+ @pytest .mark .asyncio
58+ @mock .patch .object (utils , "embed_contents_async" , autospec = True )
5859@mock .patch .object (client , "get_spanner_client" )
59- def test_similarity_search_knn_success (
60+ async def test_similarity_search_knn_success (
6061 mock_get_spanner_client ,
61- mock_embed_contents ,
62+ mock_embed_contents_async ,
6263 mock_spanner_ids ,
6364 mock_credentials ,
6465 embedding_option_key ,
@@ -77,7 +78,7 @@ def test_similarity_search_knn_success(
7778 mock_get_spanner_client .return_value = mock_spanner_client
7879
7980 if embedding_option_key == "vertex_ai_embedding_model_name" :
80- mock_embed_contents .return_value = [expected_embedding ]
81+ mock_embed_contents_async .return_value = [expected_embedding ]
8182 # execute_sql is called once for the kNN search
8283 mock_snapshot .execute_sql .return_value = iter ([("result1" ,), ("result2" ,)])
8384 else :
@@ -90,7 +91,7 @@ def test_similarity_search_knn_success(
9091 iter ([("result1" ,), ("result2" ,)]),
9192 ]
9293
93- result = search_tool .similarity_search (
94+ result = await search_tool .similarity_search (
9495 project_id = mock_spanner_ids ["project_id" ],
9596 instance_id = mock_spanner_ids ["instance_id" ],
9697 database_id = mock_spanner_ids ["database_id" ],
@@ -111,13 +112,14 @@ def test_similarity_search_knn_success(
111112 assert "@embedding" in sql
112113 assert call_args .kwargs == {"params" : {"embedding" : expected_embedding }}
113114 if embedding_option_key == "vertex_ai_embedding_model_name" :
114- mock_embed_contents .assert_called_once_with (
115+ mock_embed_contents_async .assert_called_once_with (
115116 embedding_option_value , ["test query" ], None
116117 )
117118
118119
120+ @pytest .mark .asyncio
119121@mock .patch .object (client , "get_spanner_client" )
120- def test_similarity_search_ann_success (
122+ async def test_similarity_search_ann_success (
121123 mock_get_spanner_client , mock_spanner_ids , mock_credentials
122124):
123125 """Test similarity_search function with ANN success."""
@@ -139,7 +141,7 @@ def test_similarity_search_ann_success(
139141 mock_spanner_client .instance .return_value = mock_instance
140142 mock_get_spanner_client .return_value = mock_spanner_client
141143
142- result = search_tool .similarity_search (
144+ result = await search_tool .similarity_search (
143145 project_id = mock_spanner_ids ["project_id" ],
144146 instance_id = mock_spanner_ids ["instance_id" ],
145147 database_id = mock_spanner_ids ["database_id" ],
@@ -164,13 +166,14 @@ def test_similarity_search_ann_success(
164166 assert call_args .kwargs == {"params" : {"embedding" : [0.1 , 0.2 , 0.3 ]}}
165167
166168
169+ @pytest .mark .asyncio
167170@mock .patch .object (client , "get_spanner_client" )
168- def test_similarity_search_error (
171+ async def test_similarity_search_error (
169172 mock_get_spanner_client , mock_spanner_ids , mock_credentials
170173):
171174 """Test similarity_search function with a generic error."""
172175 mock_get_spanner_client .side_effect = Exception ("Test Exception" )
173- result = search_tool .similarity_search (
176+ result = await search_tool .similarity_search (
174177 project_id = mock_spanner_ids ["project_id" ],
175178 instance_id = mock_spanner_ids ["instance_id" ],
176179 database_id = mock_spanner_ids ["database_id" ],
@@ -187,11 +190,12 @@ def test_similarity_search_error(
187190 assert "Test Exception" in result ["error_details" ]
188191
189192
190- @mock .patch .object (utils , "embed_contents" )
193+ @pytest .mark .asyncio
194+ @mock .patch .object (utils , "embed_contents_async" )
191195@mock .patch .object (client , "get_spanner_client" )
192- def test_similarity_search_circular_row_fallback_to_string (
196+ async def test_similarity_search_circular_row_fallback_to_string (
193197 mock_get_spanner_client ,
194- mock_embed_contents ,
198+ mock_embed_contents_async ,
195199 mock_spanner_ids ,
196200 mock_credentials ,
197201):
@@ -202,15 +206,15 @@ def test_similarity_search_circular_row_fallback_to_string(
202206 mock_snapshot = MagicMock ()
203207 circular_row = []
204208 circular_row .append (circular_row )
205- mock_embed_contents .return_value = [[0.1 , 0.2 , 0.3 ]]
209+ mock_embed_contents_async .return_value = [[0.1 , 0.2 , 0.3 ]]
206210 mock_snapshot .execute_sql .return_value = iter ([circular_row ])
207211 mock_database .snapshot .return_value .__enter__ .return_value = mock_snapshot
208212 mock_database .database_dialect = DatabaseDialect .GOOGLE_STANDARD_SQL
209213 mock_instance .database .return_value = mock_database
210214 mock_spanner_client .instance .return_value = mock_instance
211215 mock_get_spanner_client .return_value = mock_spanner_client
212216
213- result = search_tool .similarity_search (
217+ result = await search_tool .similarity_search (
214218 project_id = mock_spanner_ids ["project_id" ],
215219 instance_id = mock_spanner_ids ["instance_id" ],
216220 database_id = mock_spanner_ids ["database_id" ],
@@ -228,8 +232,9 @@ def test_similarity_search_circular_row_fallback_to_string(
228232 assert result ["rows" ] == [str (circular_row )]
229233
230234
235+ @pytest .mark .asyncio
231236@mock .patch .object (client , "get_spanner_client" )
232- def test_similarity_search_postgresql_knn_success (
237+ async def test_similarity_search_postgresql_knn_success (
233238 mock_get_spanner_client , mock_spanner_ids , mock_credentials
234239):
235240 """Test similarity_search with PostgreSQL dialect for kNN."""
@@ -249,7 +254,7 @@ def test_similarity_search_postgresql_knn_success(
249254 mock_spanner_client .instance .return_value = mock_instance
250255 mock_get_spanner_client .return_value = mock_spanner_client
251256
252- result = search_tool .similarity_search (
257+ result = await search_tool .similarity_search (
253258 project_id = mock_spanner_ids ["project_id" ],
254259 instance_id = mock_spanner_ids ["instance_id" ],
255260 database_id = mock_spanner_ids ["database_id" ],
@@ -273,8 +278,9 @@ def test_similarity_search_postgresql_knn_success(
273278 assert call_args .kwargs == {"params" : {"p1" : [0.1 , 0.2 , 0.3 ]}}
274279
275280
281+ @pytest .mark .asyncio
276282@mock .patch .object (client , "get_spanner_client" )
277- def test_similarity_search_postgresql_ann_unsupported (
283+ async def test_similarity_search_postgresql_ann_unsupported (
278284 mock_get_spanner_client , mock_spanner_ids , mock_credentials
279285):
280286 """Test similarity_search with unsupported ANN for PostgreSQL dialect."""
@@ -286,7 +292,7 @@ def test_similarity_search_postgresql_ann_unsupported(
286292 mock_spanner_client .instance .return_value = mock_instance
287293 mock_get_spanner_client .return_value = mock_spanner_client
288294
289- result = search_tool .similarity_search (
295+ result = await search_tool .similarity_search (
290296 project_id = mock_spanner_ids ["project_id" ],
291297 instance_id = mock_spanner_ids ["instance_id" ],
292298 database_id = mock_spanner_ids ["database_id" ],
@@ -311,8 +317,9 @@ def test_similarity_search_postgresql_ann_unsupported(
311317 )
312318
313319
320+ @pytest .mark .asyncio
314321@mock .patch .object (client , "get_spanner_client" )
315- def test_similarity_search_gsql_missing_embedding_model_error (
322+ async def test_similarity_search_gsql_missing_embedding_model_error (
316323 mock_get_spanner_client , mock_spanner_ids , mock_credentials
317324):
318325 """Test similarity_search with missing embedding_options for GoogleSQL dialect."""
@@ -324,7 +331,7 @@ def test_similarity_search_gsql_missing_embedding_model_error(
324331 mock_spanner_client .instance .return_value = mock_instance
325332 mock_get_spanner_client .return_value = mock_spanner_client
326333
327- result = search_tool .similarity_search (
334+ result = await search_tool .similarity_search (
328335 project_id = mock_spanner_ids ["project_id" ],
329336 instance_id = mock_spanner_ids ["instance_id" ],
330337 database_id = mock_spanner_ids ["database_id" ],
@@ -348,8 +355,9 @@ def test_similarity_search_gsql_missing_embedding_model_error(
348355 )
349356
350357
358+ @pytest .mark .asyncio
351359@mock .patch .object (client , "get_spanner_client" )
352- def test_similarity_search_pg_missing_embedding_model_error (
360+ async def test_similarity_search_pg_missing_embedding_model_error (
353361 mock_get_spanner_client , mock_spanner_ids , mock_credentials
354362):
355363 """Test similarity_search with missing embedding_options for PostgreSQL dialect."""
@@ -361,7 +369,7 @@ def test_similarity_search_pg_missing_embedding_model_error(
361369 mock_spanner_client .instance .return_value = mock_instance
362370 mock_get_spanner_client .return_value = mock_spanner_client
363371
364- result = search_tool .similarity_search (
372+ result = await search_tool .similarity_search (
365373 project_id = mock_spanner_ids ["project_id" ],
366374 instance_id = mock_spanner_ids ["instance_id" ],
367375 database_id = mock_spanner_ids ["database_id" ],
@@ -427,8 +435,9 @@ def test_similarity_search_pg_missing_embedding_model_error(
427435 ),
428436 ],
429437)
438+ @pytest .mark .asyncio
430439@mock .patch .object (client , "get_spanner_client" )
431- def test_similarity_search_multiple_embedding_options_error (
440+ async def test_similarity_search_multiple_embedding_options_error (
432441 mock_get_spanner_client ,
433442 mock_spanner_ids ,
434443 mock_credentials ,
@@ -443,7 +452,7 @@ def test_similarity_search_multiple_embedding_options_error(
443452 mock_spanner_client .instance .return_value = mock_instance
444453 mock_get_spanner_client .return_value = mock_spanner_client
445454
446- result = search_tool .similarity_search (
455+ result = await search_tool .similarity_search (
447456 project_id = mock_spanner_ids ["project_id" ],
448457 instance_id = mock_spanner_ids ["instance_id" ],
449458 database_id = mock_spanner_ids ["database_id" ],
@@ -461,8 +470,9 @@ def test_similarity_search_multiple_embedding_options_error(
461470 )
462471
463472
473+ @pytest .mark .asyncio
464474@mock .patch .object (client , "get_spanner_client" )
465- def test_similarity_search_output_dimensionality_gsql_error (
475+ async def test_similarity_search_output_dimensionality_gsql_error (
466476 mock_get_spanner_client , mock_spanner_ids , mock_credentials
467477):
468478 """Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name."""
@@ -474,7 +484,7 @@ def test_similarity_search_output_dimensionality_gsql_error(
474484 mock_spanner_client .instance .return_value = mock_instance
475485 mock_get_spanner_client .return_value = mock_spanner_client
476486
477- result = search_tool .similarity_search (
487+ result = await search_tool .similarity_search (
478488 project_id = mock_spanner_ids ["project_id" ],
479489 instance_id = mock_spanner_ids ["instance_id" ],
480490 database_id = mock_spanner_ids ["database_id" ],
@@ -492,8 +502,9 @@ def test_similarity_search_output_dimensionality_gsql_error(
492502 assert "is not supported when" in result ["error_details" ]
493503
494504
505+ @pytest .mark .asyncio
495506@mock .patch .object (client , "get_spanner_client" )
496- def test_similarity_search_unsupported_algorithm_error (
507+ async def test_similarity_search_unsupported_algorithm_error (
497508 mock_get_spanner_client , mock_spanner_ids , mock_credentials
498509):
499510 """Test similarity_search with an unsupported nearest neighbors algorithm."""
@@ -505,7 +516,7 @@ def test_similarity_search_unsupported_algorithm_error(
505516 mock_spanner_client .instance .return_value = mock_instance
506517 mock_get_spanner_client .return_value = mock_spanner_client
507518
508- result = search_tool .similarity_search (
519+ result = await search_tool .similarity_search (
509520 project_id = mock_spanner_ids ["project_id" ],
510521 instance_id = mock_spanner_ids ["instance_id" ],
511522 database_id = mock_spanner_ids ["database_id" ],
0 commit comments