diff --git a/eng/apiview_reqs.txt b/eng/apiview_reqs.txt index 0d936e63e2c5..90bcd47a1c4a 100644 --- a/eng/apiview_reqs.txt +++ b/eng/apiview_reqs.txt @@ -14,4 +14,4 @@ tomli==2.2.1 tomlkit==0.13.2 typing_extensions==4.15.0 wrapt==1.17.2 -apiview-stub-generator==0.3.24 \ No newline at end of file +apiview-stub-generator==0.3.25 \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 9166aed2c1a5..4e598bf68a54 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,21 @@ ## Release History +### 4.14.4 (Unreleased) + +#### Features Added + +#### Breaking Changes + +#### Bugs Fixed + +#### Other Changes + +### 4.14.3 (2025-12-08) + +#### Bugs Fixed +* Fixed bug where client timeout/read_timeout values were not properly enforced. See [PR 42652](https://github.com/Azure/azure-sdk-for-python/pull/42652). +* Fixed bug when passing in None for some options in `query_items` would cause unexpected errors. 
See [PR 44098](https://github.com/Azure/azure-sdk-for-python/pull/44098) + ### 4.14.2 (2025-11-14) #### Features Added diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index d066135500d1..b581d3a0c09c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -23,6 +23,7 @@ """ import base64 +import time from email.utils import formatdate import json import uuid @@ -109,6 +110,13 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]: for key, value in _COMMON_OPTIONS.items(): if key in kwargs: options[value] = kwargs.pop(key) + if 'read_timeout' in kwargs: + options['read_timeout'] = kwargs['read_timeout'] + if 'timeout' in kwargs: + options['timeout'] = kwargs['timeout'] + + + options[Constants.OperationStartTime] = time.time() if_match, if_none_match = _get_match_headers(kwargs) if if_match: options['accessCondition'] = {'type': 'IfMatch', 'condition': if_match} diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index d6a23050c226..f49798c8f8b8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -25,14 +25,21 @@ from typing_extensions import Literal -# cspell:ignore reranker +class TimeoutScope: + """Defines the scope of timeout application""" + OPERATION: Literal["operation"] = "operation" # Apply timeout to entire logical operation + PAGE: Literal["page"] = "page" # Apply timeout to individual page requests +# cspell:ignore reranker class _Constants: """Constants used in the azure-cosmos package""" UserConsistencyPolicy: Literal["userConsistencyPolicy"] = "userConsistencyPolicy" DefaultConsistencyLevel: Literal["defaultConsistencyLevel"] = "defaultConsistencyLevel" + OperationStartTime: Literal["operationStartTime"] = "operationStartTime" + # whether to apply timeout to the whole logical operation or just 
a page request + TimeoutScope: Literal["timeoutScope"] = "timeoutScope" # GlobalDB related constants WritableLocations: Literal["writableLocations"] = "writableLocations" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 4e08365e2c47..d3ffadf8b82f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3169,6 +3169,18 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma """ if options is None: options = {} + read_timeout = options.get("read_timeout") + if read_timeout is not None: + # we currently have a gap where kwargs are not getting passed correctly down the pipeline. In order to make + # absolute time out work, we are passing read_timeout via kwargs as a temporary fix + kwargs.setdefault("read_timeout", read_timeout) + + operation_start_time = options.get(Constants.OperationStartTime) + if operation_start_time is not None: + kwargs.setdefault(Constants.OperationStartTime, operation_start_time) + timeout = options.get("timeout") + if timeout is not None: + kwargs.setdefault("timeout", timeout) if query: __GetBodiesFromQueryResult = result_fn diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py index 560ca6c05389..3687a9179e19 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py @@ -131,10 +131,17 @@ async def _fetch_items_helper_no_retries(self, fetch_function): return fetched_items async def _fetch_items_helper_with_retries(self, fetch_function): - async def callback(): + # TODO: Properly propagate kwargs from retry utility to fetch function + # the callback keep 
the **kwargs parameter to maintain compatibility with the retry utility's execution pattern. + # ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.) + # The callback needs to accept these parameters even if unused + # Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters + async def callback(**kwargs): # pylint: disable=unused-argument return await self._fetch_items_helper_no_retries(fetch_function) - return await _retry_utility_async.ExecuteAsync(self._client, self._client._global_endpoint_manager, callback) + return await _retry_utility_async.ExecuteAsync( + self._client, self._client._global_endpoint_manager, callback, **self._options + ) class _DefaultQueryExecutionContext(_QueryExecutionContextBase): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 845284f766ad..afa5a564715d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -66,13 +66,19 @@ def __init__(self, client, resource_link, query, options, fetch_function, async def _create_execution_context_with_query_plan(self): self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" - query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway - (query_to_use, self._resource_link, self._options.get('excludedLocations'))) + query_plan = await self._client._GetQueryPlanThroughGateway( + query_to_use, + self._resource_link, + self._options.get('excludedLocations'), + read_timeout=self._options.get('read_timeout') + ) + query_execution_info = _PartitionedQueryExecutionInfo(query_plan) qe_info = getattr(query_execution_info, "_query_execution_info", None) if isinstance(qe_info, 
dict) and isinstance(query_to_use, dict): params = query_to_use.get("parameters") if params is not None: query_execution_info._query_execution_info['parameters'] = params + self._execution_context = await self._create_pipelined_execution_context(query_execution_info) async def __anext__(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py index 1ee4d067f3b0..4cdf8dddca9b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py @@ -129,10 +129,15 @@ def _fetch_items_helper_no_retries(self, fetch_function): return fetched_items def _fetch_items_helper_with_retries(self, fetch_function): - def callback(): + # TODO: Properly propagate kwargs from retry utility to fetch function + # the callback keeps the **kwargs parameter to maintain compatibility with the retry utility's execution pattern. + # ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.) + # The callback needs to accept these parameters even if unused + # Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters + def callback(**kwargs): # pylint: disable=unused-argument return self._fetch_items_helper_no_retries(fetch_function) - return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback, **self._options) next = __next__ # Python 2 compatibility. 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 0567bd5d6cf7..a5d807021acc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -96,14 +96,19 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon def _create_execution_context_with_query_plan(self): self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" - query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway - (query_to_use, self._resource_link, self._options.get('excludedLocations'))) - + query_plan = self._client._GetQueryPlanThroughGateway( + query_to_use, + self._resource_link, + self._options.get('excludedLocations'), + read_timeout=self._options.get('read_timeout') + ) + query_execution_info = _PartitionedQueryExecutionInfo(query_plan) qe_info = getattr(query_execution_info, "_query_execution_info", None) if isinstance(qe_info, dict) and isinstance(query_to_use, dict): params = query_to_use.get("parameters") if params is not None: query_execution_info._query_execution_info['parameters'] = params + self._execution_context = self._create_pipelined_execution_context(query_execution_info) def __next__(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py index be06b24478a8..599d884c9797 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py @@ -21,8 +21,11 @@ """Iterable query results in the Azure Cosmos database service. 
""" +import time from azure.core.paging import PageIterator # type: ignore +from azure.cosmos._constants import _Constants, TimeoutScope from azure.cosmos._execution_context import execution_dispatcher +from azure.cosmos import exceptions # pylint: disable=protected-access @@ -99,6 +102,17 @@ def _fetch_next(self, *args): # pylint: disable=unused-argument :return: List of results. :rtype: list """ + timeout = self._options.get('timeout') + # reset the operation start time if it's a paged request + if timeout and self._options.get(_Constants.TimeoutScope) != TimeoutScope.OPERATION: + self._options[_Constants.OperationStartTime] = time.time() + + # Check timeout before fetching next block + if timeout: + elapsed = time.time() - self._options.get(_Constants.OperationStartTime) + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError() + block = self._ex_context.fetch_next_block() if not block: raise StopIteration diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 9b9153308db2..ef406f5f613a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -39,12 +39,12 @@ from . import _session_retry_policy from . import _timeout_failover_retry_policy from . 
import exceptions +from ._constants import _Constants from .documents import _OperationType from .exceptions import CosmosHttpResponseError from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes, ResourceType from ._cosmos_http_logging_policy import _log_diagnostics_error - # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches # args [0] is the request object @@ -64,6 +64,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + # Capture the client timeout and start time at the beginning + timeout = kwargs.get('timeout') + operation_start_time = kwargs.get(_Constants.OperationStartTime, time.time()) + + # Track the last error for chaining + last_error = None + pk_range_wrapper = None if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) @@ -110,14 +117,25 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin client, client._container_properties_cache, None, *args) while True: - client_timeout = kwargs.get('timeout') start_time = time.time() + # Check timeout before executing function + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) + try: if args: result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs) global_endpoint_manager.record_success(args[0]) else: result = ExecuteFunction(function, *args, **kwargs) + # Check timeout after successful execution + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) + if not client.last_response_headers: client.last_response_headers = {} @@ -158,6 +176,7 @@ def 
Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin return result except exceptions.CosmosHttpResponseError as e: + last_error = e if request: # update session token for relevant operations client._UpdateSessionIfRequired(request.headers, {}, e.headers) @@ -226,12 +245,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin client.session.clear_session_token(client.last_response_headers) raise + # Now check timeout before retrying + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) # Wait for retry_after_in_milliseconds time before the next retry time.sleep(retry_policy.retry_after_in_milliseconds / 1000.0) - if client_timeout: - kwargs['timeout'] = client_timeout - (time.time() - start_time) - if kwargs['timeout'] <= 0: - raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: if request and _has_database_account_header(request.headers): @@ -258,6 +278,7 @@ def ExecuteFunction(function, *args, **kwargs): """ return function(*args, **kwargs) + def _has_read_retryable_headers(request_headers): if _OperationType.IsReadOnlyOperation(request_headers.get(HttpHeaders.ThinClientProxyOperationType)): return True @@ -332,6 +353,7 @@ def send(self, request): :raises ~azure.cosmos.exceptions.CosmosClientTimeoutError: Specified timeout exceeded. :raises ~azure.core.exceptions.ClientAuthenticationError: Authentication failed. 
""" + absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) request_params = request.context.options.pop('request_params', None) @@ -384,6 +406,7 @@ def send(self, request): if retry_active: self.sleep(retry_settings, request.context.transport) continue + raise err except CosmosHttpResponseError as err: raise err diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 1b2290981307..d4a56e1545a3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -27,10 +27,8 @@ from urllib.parse import urlparse from azure.core.exceptions import DecodeError # type: ignore - -from . import exceptions -from . import http_constants -from . import _retry_utility +from ._constants import _Constants +from . import exceptions, http_constants, _retry_utility def _is_readable_stream(obj): @@ -80,8 +78,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin :rtype: tuple of (dict, dict) """ - # pylint: disable=protected-access - + # pylint: disable=protected-access, too-many-branches + kwargs.pop(_Constants.OperationStartTime, None) connection_timeout = connection_policy.RequestTimeout connection_timeout = kwargs.pop("connection_timeout", connection_timeout) read_timeout = connection_policy.ReadTimeout diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py index 9144afca613d..792c33b1efae 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py @@ -137,7 +137,8 @@ def valid_key_value_exist( kwargs: dict[str, Any], key: str, invalid_value: Any = None) -> bool: - """Check if a valid key and value exists in kwargs. By default, it checks if the value is not None. 
+ """Check if a valid key and value exists in kwargs. It always checks if the value is not None and it will remove + from the kwargs the None value. :param dict[str, Any] kwargs: The dictionary of keyword arguments. :param str key: The key to check. @@ -145,4 +146,8 @@ def valid_key_value_exist( :return: True if the key exists and its value is not None, False otherwise. :rtype: bool """ + if key in kwargs and kwargs[key] is None: + kwargs.pop(key) + return False + return key in kwargs and kwargs[key] is not invalid_value diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index 0b6faf7457ff..e6b2758537f7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.14.2" +VERSION = "4.14.4" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index c19d9b494abb..f9477b0d3da5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -31,6 +31,7 @@ from .. import exceptions from .. import http_constants from . 
import _retry_utility_async +from .._constants import _Constants from .._synchronized_request import _request_body_from_data, _replace_url_prefix @@ -49,8 +50,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p :rtype: tuple of (dict, dict) """ - # pylint: disable=protected-access - + # pylint: disable=protected-access, too-many-branches + kwargs.pop(_Constants.OperationStartTime, None) connection_timeout = connection_policy.RequestTimeout read_timeout = connection_policy.ReadTimeout connection_timeout = kwargs.pop("connection_timeout", connection_timeout) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index ea635bc34b1a..2076149f6a1d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -39,8 +39,9 @@ from .._base import (_build_properties_cache, _deserialize_throughput, _replace_throughput, build_options as _build_options, GenerateGuidId, validate_cache_staleness_value) from .._change_feed.feed_range_internal import FeedRangeInternalEpk -from .._constants import _Constants as Constants + from .._cosmos_responses import CosmosDict, CosmosList +from .._constants import _Constants as Constants, TimeoutScope from .._routing.routing_range import Range from .._session_token_helpers import get_latest_session_token from ..exceptions import CosmosHttpResponseError @@ -96,8 +97,14 @@ def __repr__(self) -> str: async def _get_properties_with_options(self, options: Optional[dict[str, Any]] = None) -> dict[str, Any]: kwargs = {} - if options and "excludedLocations" in options: - kwargs['excluded_locations'] = options['excludedLocations'] + if options: + if "excludedLocations" in options: + kwargs['excluded_locations'] = options['excludedLocations'] + if Constants.OperationStartTime in options: + kwargs[Constants.OperationStartTime] = options[Constants.OperationStartTime] + if "timeout" in 
options: + kwargs['timeout'] = options['timeout'] + return await self._get_properties(**kwargs) async def _get_properties(self, **kwargs: Any) -> dict[str, Any]: @@ -483,6 +490,7 @@ async def read_items( query_options = _build_options(kwargs) await self._get_properties_with_options(query_options) query_options["enableCrossPartitionQuery"] = True + query_options[Constants.TimeoutScope] = TimeoutScope.OPERATION item_tuples = [(item_id, await self._set_partition_key(pk)) for item_id, pk in items] return await self.client_connection.read_items( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 67d5d4efa3e9..59f7c8f6085e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2976,6 +2976,21 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, if options is None: options = {} + read_timeout = options.get("read_timeout") + if read_timeout is not None: + # we currently have a gap where kwargs are not getting passed correctly down the pipeline. 
In order to make + # absolute time out work, we are passing read_timeout via kwargs as a temporary fix + kwargs.setdefault("read_timeout", read_timeout) + + operation_start_time = options.get(Constants.OperationStartTime) + if operation_start_time is not None: + # we need to set operation_state in kwargs as thats where it is looked at while sending the request + kwargs.setdefault(Constants.OperationStartTime, operation_start_time) + timeout = options.get("timeout") + if timeout is not None: + # we need to set operation_state in kwargs as that's where it is looked at while sending the request + kwargs.setdefault("timeout", timeout) + if query: __GetBodiesFromQueryResult = result_fn else: @@ -3380,7 +3395,7 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, "contentType": runtime_constants.MediaTypes.Json, "isQueryPlanRequest": True, "supportedQueryFeatures": supported_query_features, - "queryVersion": http_constants.Versions.QueryVersion + "queryVersion": http_constants.Versions.QueryVersion, } if excluded_locations is not None: options["excludedLocations"] = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py index d0304ccfde60..bd0f6537dced 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py @@ -22,8 +22,13 @@ """Iterable query results in the Azure Cosmos database service. """ import asyncio # pylint: disable=do-not-import-asyncio +import time + from azure.core.async_paging import AsyncPageIterator + +from azure.cosmos._constants import _Constants, TimeoutScope from azure.cosmos._execution_context.aio import execution_dispatcher +from azure.cosmos import exceptions # pylint: disable=protected-access @@ -100,9 +105,23 @@ async def _fetch_next(self, *args): # pylint: disable=unused-argument :return: List of results. 
:rtype: list """ + timeout = self._options.get('timeout') if 'partitionKey' in self._options and asyncio.iscoroutine(self._options['partitionKey']): self._options['partitionKey'] = await self._options['partitionKey'] + + # Check timeout before fetching next block + + if timeout and self._options.get(_Constants.TimeoutScope) != TimeoutScope.OPERATION: + self._options[_Constants.OperationStartTime] = time.time() + + # Check timeout before fetching next block + if timeout: + elapsed = time.time() - self._options.get(_Constants.OperationStartTime) + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError() + block = await self._ex_context.fetch_next_block() + if not block: raise StopAsyncIteration return block diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index be19eedc36b7..ddf9eebb8cde 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -26,7 +26,8 @@ import time import logging -from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError +from azure.core.exceptions import (AzureError, ClientAuthenticationError, ServiceRequestError, + ServiceResponseError) from azure.core.pipeline.policies import AsyncRetryPolicy from .. import _default_retry_policy, _health_check_retry_policy @@ -37,6 +38,7 @@ from .. import _session_retry_policy from .. import _timeout_failover_retry_policy from .. 
import exceptions +from .._constants import _Constants from .._container_recreate_retry_policy import ContainerRecreateRetryPolicy from .._retry_utility import (_configure_timeout, _has_read_retryable_headers, _handle_service_response_retries, _handle_service_request_retries, @@ -65,6 +67,12 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + timeout = kwargs.get('timeout') + operation_start_time = kwargs.get(_Constants.OperationStartTime, time.time()) + + # Track the last error for chaining + last_error = None + pk_range_wrapper = None if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(args[0]) @@ -110,14 +118,23 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg client, client._container_properties_cache, None, *args) while True: - client_timeout = kwargs.get('timeout') start_time = time.time() + # Check timeout before executing function + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) await global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) + # Check timeout after successful execution + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) if not client.last_response_headers: client.last_response_headers = {} @@ -158,6 +175,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg return result except exceptions.CosmosHttpResponseError as e: + last_error = e if request: # update session token for relevant 
operations client._UpdateSessionIfRequired(request.headers, {}, e.headers) @@ -225,12 +243,13 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg client.session.clear_session_token(client.last_response_headers) raise + # Check timeout only before retrying + if timeout: + elapsed = time.time() - operation_start_time + if elapsed >= timeout: + raise exceptions.CosmosClientTimeoutError(error=last_error) # Wait for retry_after_in_milliseconds time before the next retry await asyncio.sleep(retry_policy.retry_after_in_milliseconds / 1000.0) - if client_timeout: - kwargs['timeout'] = client_timeout - (time.time() - start_time) - if kwargs['timeout'] <= 0: - raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: if request and _has_database_account_header(request.headers): @@ -345,6 +364,7 @@ async def send(self, request): if retry_active: await self.sleep(retry_settings, request.context.transport) continue + except ImportError: raise err # pylint: disable=raise-missing-from raise err diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index f17e39515fbf..9cb5578a8e78 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -36,7 +36,7 @@ from ._base import (_build_properties_cache, _deserialize_throughput, _replace_throughput, build_options, GenerateGuidId, validate_cache_staleness_value) from ._change_feed.feed_range_internal import FeedRangeInternalEpk -from ._constants import _Constants as Constants +from ._constants import _Constants as Constants, TimeoutScope from ._cosmos_client_connection import CosmosClientConnection from ._cosmos_responses import CosmosDict, CosmosList from ._routing.routing_range import Range @@ -99,8 +99,13 @@ def __repr__(self) -> str: def _get_properties_with_options(self, options: Optional[dict[str, Any]] = None) -> dict[str, Any]: kwargs = {} - if options and 
"excludedLocations" in options: - kwargs['excluded_locations'] = options['excludedLocations'] + if options: + if "excludedLocations" in options: + kwargs['excluded_locations'] = options['excludedLocations'] + if Constants.OperationStartTime in options: + kwargs[Constants.OperationStartTime] = options[Constants.OperationStartTime] + if "timeout" in options: + kwargs['timeout'] = options['timeout'] return self._get_properties(**kwargs) def _get_properties(self, **kwargs: Any) -> dict[str, Any]: @@ -343,6 +348,7 @@ def read_items( query_options = build_options(kwargs) self._get_properties_with_options(query_options) query_options["enableCrossPartitionQuery"] = True + query_options[Constants.TimeoutScope] = TimeoutScope.OPERATION item_tuples = [(item_id, self._set_partition_key(pk)) for item_id, pk in items] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/exceptions.py b/sdk/cosmos/azure-cosmos/azure/cosmos/exceptions.py index c07c0c903bda..468bb820a1f1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/exceptions.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/exceptions.py @@ -118,7 +118,9 @@ class CosmosClientTimeoutError(AzureError): """An operation failed to complete within the specified timeout.""" def __init__(self, **kwargs): - message = "Client operation failed to complete within specified timeout." + message = kwargs.pop('message', None) + if message is None: + message = "The request failed to complete within the given timeout." self.response = None self.history = None super(CosmosClientTimeoutError, self).__init__(message, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/docs/TimeoutAndRetriesConfig.md b/sdk/cosmos/azure-cosmos/docs/TimeoutAndRetriesConfig.md index cc0711b2b979..e1c7cd3828d3 100644 --- a/sdk/cosmos/azure-cosmos/docs/TimeoutAndRetriesConfig.md +++ b/sdk/cosmos/azure-cosmos/docs/TimeoutAndRetriesConfig.md @@ -4,7 +4,7 @@ The timeout options for the client can be changed from the default configurations with the options below. 
These options can be passed in at the client constructor or on a per-request basis. These are: -- `Client Timeout`: can be changed by passing the `timeout` option. Changes the value of the per-request client timeout. If not present, +- `Client Timeout`: can be changed by passing the `timeout` option. Changes the value of the per-operation client timeout (operations like cross-partition queries can make multiple requests for instance - it would be considered as one query operation). If not present, the 'Connection Timeout' connectivity timeouts below will be used. `connection_timeout` must be smaller than your `timeout` to be used. - `Connection Timeout`: can be changed through `connection_timeout` option. Changes the value on the client's http transport timeout when connecting to the socket. Default value is 5s. diff --git a/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation.py b/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation.py index e7f4d088c21e..2e403bd38df0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation.py +++ b/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation.py @@ -232,7 +232,7 @@ def test_partition_key_version_1_properties(self): # Simulate the version key not being in the definition - def _get_properties_override(): + def _get_properties_override(**kwargs): properties = original_get_properties() partition_key = properties["partitionKey"] partition_key.pop("version", None) # Remove version key for validation diff --git a/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py b/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py index 87f74f4af7c1..725ee56d1f2e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py @@ -231,7 +231,7 @@ async def 
test_partition_key_version_1_properties_async(self): # Simulate the version key not being in the definition - async def _get_properties_override(): + async def _get_properties_override(**kwargs): properties = await original_get_properties() partition_key = properties["partitionKey"] partition_key.pop("version", None) # Remove version key for validation diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index ec28054b7051..dc9cb5882747 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -30,12 +30,13 @@ class TimeoutTransport(RequestsTransport): - def __init__(self, response): + def __init__(self, response, passthrough=False): self._response = response + self.passthrough = passthrough super(TimeoutTransport, self).__init__() def send(self, *args, **kwargs): - if kwargs.pop("passthrough", False): + if self.passthrough: return super(TimeoutTransport, self).send(*args, **kwargs) time.sleep(5) @@ -1229,46 +1230,340 @@ def initialize_client_with_connection_core_retry_config(self, retries): end_time = time.time() return end_time - start_time - # TODO: Skipping this test to debug later - @unittest.skip - def test_absolute_client_timeout(self): - with self.assertRaises(exceptions.CosmosClientTimeoutError): + def test_timeout_on_connection_error(self): + # Connection Refused: This is an active rejection from the target machine's operating system. It receives your + # connection request but immediately sends back a response indicating that no process is listening on that port. + # This is a fast failure. + # Connection Timeout Setting: This occurs when your connection request receives no response at all within a + # specified period. The client gives up waiting. This typically happens if the target machine is down, + # unreachable due to network configuration, or a firewall is silently dropping the packets. 
+ # so in the below test connection_timeout setting has no bearing on the test outcome + # catching both exceptions to make the test the test reliable in different environment as + # the underlying operating system and network stack handle the connection attempt to a non-existent port in different ways + + with self.assertRaises((exceptions.CosmosClientTimeoutError, ServiceRequestError)): cosmos_client.CosmosClient( "https://localhost:9999", TestCRUDOperations.masterKey, - "Session", - retry_total=3, - timeout=1) + retry_total=50, + connection_timeout=100, + timeout=10) + def test_timeout_on_read_operation(self): error_response = ServiceResponseError("Read timeout") - timeout_transport = TimeoutTransport(error_response) - client = cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + # Initialize transport with passthrough enabled for client setup + timeout_transport = TimeoutTransport(error_response, passthrough=True) + client = cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport) + timeout_transport.passthrough = False with self.assertRaises(exceptions.CosmosClientTimeoutError): client.create_database_if_not_exists("test", timeout=2) - status_response = 500 # Users connection level retry - timeout_transport = TimeoutTransport(status_response) + def test_timeout_on_throttling_error(self): + # Throttling(429): Keeps retrying -> Eventually times out -> CosmosClientTimeoutError + status_response = 429 # Uses Cosmos custom retry + timeout_transport = TimeoutTransport(status_response, passthrough=True) client = cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + self.host, self.masterKey, "Session", transport=timeout_transport) + + timeout_transport.passthrough = False with self.assertRaises(exceptions.CosmosClientTimeoutError): - client.create_database("test", timeout=2) + 
client.create_database_if_not_exists("test", timeout=30) - databases = client.list_databases(timeout=2) + databases = client.list_databases(timeout=29) with self.assertRaises(exceptions.CosmosClientTimeoutError): list(databases) + def test_inner_exceptions_on_timeout(self): + # Throttling(429): Keeps retrying -> Eventually times out -> CosmosClientTimeoutError status_response = 429 # Uses Cosmos custom retry - timeout_transport = TimeoutTransport(status_response) + timeout_transport = TimeoutTransport(status_response, passthrough=True) client = cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + self.host, self.masterKey, "Session", transport=timeout_transport) + + timeout_transport.passthrough = False + with self.assertRaises(exceptions.CosmosClientTimeoutError) as cm: + client.create_database_if_not_exists("test", timeout=30) + + # Verify the inner_exception is set and is a 429 error + self.assertIsNotNone(cm.exception.inner_exception) + self.assertIsInstance(cm.exception.inner_exception, exceptions.CosmosHttpResponseError) + self.assertEqual(cm.exception.inner_exception.status_code, 429) + + def test_timeout_for_read_items(self): + """Test that timeout is properly maintained across multiple partition requests for a single logical operation + read_items is different as the results of this api are not paginated and we present the complete result set + """ + + # Create a container with multiple partitions + created_container = self.databaseForTest.create_container( + id='multi_partition_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk"), + offer_throughput=11000 + ) + pk_ranges = list(created_container.client_connection._ReadPartitionKeyRanges( + created_container.container_link)) + self.assertGreater(len(pk_ranges), 1, "Container should have multiple physical partitions.") + + # 2. 
Create items across different logical partitions + items_to_read = [] + all_item_ids = set() + for i in range(200): + doc_id = f"item_{i}_{uuid.uuid4()}" + pk = i % 10 + all_item_ids.add(doc_id) + created_container.create_item({'id': doc_id, 'pk': pk, 'data': i}) + items_to_read.append((doc_id, pk)) + + # Create a custom transport that introduces delays + class DelayedTransport(RequestsTransport): + def __init__(self, delay_per_request=2): + self.delay_per_request = delay_per_request + self.request_count = 0 + super().__init__() + + def send(self, request, **kwargs): + self.request_count += 1 + # Delay each request to simulate slow network + time.sleep(self.delay_per_request) + return super().send(request, **kwargs) + + # Verify timeout fails when cumulative time exceeds limit + delayed_transport = DelayedTransport(delay_per_request=2) + client_with_delay = cosmos_client.CosmosClient( + self.host, + self.masterKey, + transport=delayed_transport + ) + container_with_delay = client_with_delay.get_database_client( + self.databaseForTest.id + ).get_container_client(created_container.id) + + start_time = time.time() with self.assertRaises(exceptions.CosmosClientTimeoutError): - client.create_database_if_not_exists("test", timeout=2) + # This should timeout because multiple partition requests * 2s delay > 5s timeout + list(container_with_delay.read_items( + items = items_to_read, + timeout = 5 # 5 second total timeout + )) + + elapsed_time = time.time() - start_time + + # Should fail close to 5 seconds (not wait for all requests) + self.assertLess(elapsed_time, 7) # Allow some overhead + self.assertGreater(elapsed_time, 5) # Should wait at least close to timeout + + # Verify operation succeeds when no timeout is passed(default is close to 7 days) + start_time = time.time() + # add few more items + for i in range(500): + doc_id = f"item_{i}_{uuid.uuid4()}" + pk = i % 10 + all_item_ids.add(doc_id) + created_container.create_item({'id': doc_id, 'pk': pk, 'data': i}) + 
items_to_read.append((doc_id, pk)) + + items = list(container_with_delay.read_items( + items=items_to_read, + )) + + elapsed_time = time.time() - start_time + + + def test_timeout_for_paged_request(self): + """Test that timeout applies to each individual page request, not cumulatively""" + + # Create container and add items + created_container = self.databaseForTest.create_container( + id='paged_timeout_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk") + ) + + # Create enough items to ensure multiple pages + for i in range(100): + created_container.create_item({'id': f'item_{i}', 'pk': i % 10, 'data': 'x' * 1000}) + + # Create a transport that delays each request + class DelayedTransport(RequestsTransport): + def __init__(self, delay_seconds=3): + self.delay_seconds = delay_seconds + super().__init__() + + def send(self, request, **kwargs): + time.sleep(self.delay_seconds) + return super().send(request, **kwargs) + + # Test with delayed transport + delayed_transport = DelayedTransport(delay_seconds=3) + client_with_delay = cosmos_client.CosmosClient( + self.host, self.masterKey, transport=delayed_transport + ) + container_with_delay = client_with_delay.get_database_client( + self.databaseForTest.id + ).get_container_client(created_container.id) + + # Test 1: Timeout should apply per page + item_pages = container_with_delay.query_items( + query="SELECT * FROM c", + enable_cross_partition_query=True, + max_item_count=10, # Small page size + timeout=5 # Pass timeout here + ).by_page() + + # First page should succeed with 5s timeout (3s delay < 5s timeout) + first_page = list(next(item_pages)) + self.assertGreater(len(first_page), 0) + + # Second page should also succeed (timeout resets per page) + second_page = list(next(item_pages)) + self.assertGreater(len(second_page), 0) + + # Test 2: Timeout too short should fail + item_pages_short_timeout = container_with_delay.query_items( + query="SELECT * FROM c", + enable_cross_partition_query=True, + 
max_item_count=10, + timeout=2 # 2s timeout < 3s delay, should fail + ).by_page() - databases = client.list_databases(timeout=2) with self.assertRaises(exceptions.CosmosClientTimeoutError): - list(databases) + list(next(item_pages_short_timeout)) + + # Cleanup + self.databaseForTest.delete_container(created_container.id) + + def test_timeout_for_point_operation(self): + """Test that point operations respect client timeout""" + + # Create a container for testing + created_container = self.databaseForTest.create_container( + id='point_op_timeout_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk") + ) + + # Create a test item + test_item = { + 'id': 'test_item_1', + 'pk': 'partition1', + 'data': 'test_data' + } + created_container.create_item(test_item) + + # Test 1: Short timeout should fail + with self.assertRaises(exceptions.CosmosClientTimeoutError): + created_container.read_item( + item='test_item_1', + partition_key='partition1', + timeout=0.00000002 # very small timeout to force failure + ) + + # Test 2: Long timeout should succeed + result = created_container.read_item( + item='test_item_1', + partition_key='partition1', + timeout=3.0 + ) + self.assertEqual(result['id'], 'test_item_1') + + def test_point_operation_read_timeout(self): + """Test that point operations respect client provided read timeout""" + + # Create a container for testing + container = self.databaseForTest.create_container( + id='point_op_timeout_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk") + ) + + # Create a test item + test_item = { + 'id': 'test_item_1', + 'pk': 'partition1', + 'data': 'test_data' + } + container.create_item(test_item) + try: + container.read_item( + item='test_item_1', + partition_key='partition1', + read_timeout=0.000003 + ) + except Exception as e: + print(f"Exception is {e}") + + # TODO: for read timeouts azure-core returns a ServiceResponseError, needs to be fixed in azure-core and then this test can be enabled + 
@unittest.skip + def test_query_operation_single_partition_read_timeout(self): + """Test that timeout is properly maintained across multiple network requests for a single logical operation + """ + # Create a container with multiple partitions + container = self.databaseForTest.create_container( + id='single_partition_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk"), + ) + single_partition_key = 0 + + large_string = 'a' * 1000 # 1KB string + for i in range(500): # Insert 500 documents + container.create_item({ + 'id': f'item_{i}', + 'pk': single_partition_key, + 'data': large_string, + 'order_field': i + }) + + with self.assertRaises(exceptions.CosmosClientTimeoutError): + items = list(container.query_items( + query="SELECT * FROM c ORDER BY c.order_field ASC", + max_item_count=100, + read_timeout=0.00005, + partition_key=single_partition_key + )) + self.assertEqual(len(items), 500) + + # TODO: for read timeouts azure-core returns a ServiceResponseError, needs to be fixed in azure-core and then this test can be enabled + @unittest.skip + def test_query_operation_cross_partition_read_timeout(self): + """Test that timeout is properly maintained across multiple partition requests for a single logical operation + """ + # Create a container with multiple partitions + container = self.databaseForTest.create_container( + id='multi_partition_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk"), + offer_throughput=11000 + ) + + # 2. 
Create large documents to increase payload size + large_string = 'a' * 1000 # 1KB string + for i in range(500): # Insert 500 documents + container.create_item({ + 'id': f'item_{i}', + 'pk': i % 2, + 'data': large_string, + 'order_field': i + }) + + pk_ranges = list(container.client_connection._ReadPartitionKeyRanges( + container.container_link)) + + self.assertGreater(len(pk_ranges), 1, "Container should have multiple physical partitions.") + with self.assertRaises(exceptions.CosmosClientTimeoutError): + # This should timeout because of multiple partition requests + list(container.query_items( + query="SELECT * FROM c ORDER BY c.order_field ASC", + enable_cross_partition_query=True, + max_item_count=100, + read_timeout=0.00005, + )) + # This shouldn't result in any error because the default 65seconds is respected + + items = list(container.query_items( + query="SELECT * FROM c ORDER BY c.order_field ASC", + enable_cross_partition_query=True, + max_item_count=100, + )) + self.assertEqual(len(items), 500) + def test_query_iterable_functionality(self): collection = self.databaseForTest.create_container("query-iterable-container", diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index debda632b717..8e6a3529ff86 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -12,12 +12,12 @@ import urllib.parse as urllib import uuid from asyncio import sleep - +from unittest import mock import pytest -import requests +from aiohttp import web from azure.core import MatchConditions from azure.core.exceptions import AzureError, ServiceResponseError, ServiceRequestError -from azure.core.pipeline.transport import AsyncioRequestsTransport, AsyncioRequestsTransportResponse +from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse import azure.cosmos.documents as documents import azure.cosmos.exceptions as exceptions @@ -27,23 
+27,28 @@ from azure.cosmos.partition_key import PartitionKey -class TimeoutTransport(AsyncioRequestsTransport): +class TimeoutTransport(AioHttpTransport): - def __init__(self, response): + def __init__(self, response, passthrough=False): self._response = response - super(TimeoutTransport, self).__init__() + self.passthrough = passthrough + super().__init__() - async def send(self, *args, **kwargs): - if kwargs.pop("passthrough", False): - return super(TimeoutTransport, self).send(*args, **kwargs) + async def send(self, request, **kwargs): + if self.passthrough: + return await super().send(request, **kwargs) - time.sleep(5) + await asyncio.sleep(5) if isinstance(self._response, Exception): + # The SDK's retry logic wraps the original error in a ServiceRequestError. + # We raise it this way to properly simulate a transport-level timeout. raise self._response - current_response = await self._response - output = requests.Response() - output.status_code = current_response - response = AsyncioRequestsTransportResponse(None, output) + + # This part of the mock is for simulating successful responses, not used in this specific timeout test. + raw_response = web.Response(status=self._response, reason="mock") + raw_response.read = mock.AsyncMock(return_value=b"") + response = AioHttpTransportResponse(request, raw_response) + await response.load_body() return response @pytest.mark.cosmosCircuitBreaker @@ -973,52 +978,306 @@ async def initialize_client_with_connection_urllib_retry_config(self, retries): end_time = time.time() return end_time - start_time - # TODO: Skipping this test to debug later - @unittest.skip - async def test_absolute_client_timeout_async(self): - with self.assertRaises(exceptions.CosmosClientTimeoutError): - async with CosmosClient( + async def test_timeout_on_connection_error_async(self): + # Connection Refused: This is an active rejection from the target machine's operating system. 
It receives your + # connection request but immediately sends back a response indicating that no process is listening on that port. + # This is a fast failure. + # Connection Timeout Setting: This occurs when your connection request receives no response at all within a + # specified period. The client gives up waiting. This typically happens if the target machine is down, + # unreachable due to network configuration, or a firewall is silently dropping the packets. + # so in the below test connection_timeout setting has no bearing on the test outcome + client = None + try: + with self.assertRaises((exceptions.CosmosClientTimeoutError, ServiceRequestError)): + client = CosmosClient( "https://localhost:9999", TestCRUDOperationsAsync.masterKey, - retry_total=3, - timeout=1) as client: - print('Async initialization') - + retry_total=10, + connection_timeout=100, + timeout=2) + # The __aenter__ call is what triggers the connection attempt. + await client.__aenter__() + finally: + if client: + await client.close() + + async def test_timeout_on_read_operation_async(self): error_response = ServiceResponseError("Read timeout") - timeout_transport = TimeoutTransport(error_response) - async with CosmosClient( - self.host, self.masterKey, transport=timeout_transport, - passthrough=True) as client: - print('Async initialization') + # Initialize transport with passthrough enabled for client setup + timeout_transport = TimeoutTransport(error_response, passthrough=True) + async with CosmosClient( + self.host, self.masterKey, transport=timeout_transport) as client: + # Disable passthrough to test the timeout on the next call + timeout_transport.passthrough = False with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database_if_not_exists("test", timeout=2) - status_response = 500 # Users connection level retry - timeout_transport = TimeoutTransport(status_response) + + async def test_timeout_on_throttling_error_async(self): + # Throttling(429): Keeps 
retrying -> Eventually times out -> CosmosClientTimeoutError + status_response = 429 # Uses Cosmos custom retry + timeout_transport = TimeoutTransport(status_response, passthrough=True) async with CosmosClient( - self.host, self.masterKey, transport=timeout_transport, - passthrough=True) as client: + self.host, self.masterKey, transport=timeout_transport + ) as client: print('Async initialization') + timeout_transport.passthrough = False with self.assertRaises(exceptions.CosmosClientTimeoutError): - await client.create_database("test", timeout=2) + await client.create_database_if_not_exists("test", timeout=20) - databases = client.list_databases(timeout=2) + databases = client.list_databases(timeout=12) with self.assertRaises(exceptions.CosmosClientTimeoutError): databases = [database async for database in databases] + async def test_inner_exceptions_on_timeout_async(self): + # Throttling(429): Keeps retrying -> Eventually times out -> CosmosClientTimeoutError status_response = 429 # Uses Cosmos custom retry - timeout_transport = TimeoutTransport(status_response) + timeout_transport = TimeoutTransport(status_response, passthrough=True) async with CosmosClient( - self.host, self.masterKey, transport=timeout_transport, - passthrough=True) as client: + self.host, self.masterKey, transport=timeout_transport + ) as client: print('Async initialization') + timeout_transport.passthrough = False + with self.assertRaises(exceptions.CosmosClientTimeoutError) as cm : + await client.create_database_if_not_exists("test", timeout=20) + + # Verify the inner_exception is set and is a 429 error + self.assertIsNotNone(cm.exception.inner_exception) + self.assertIsInstance(cm.exception.inner_exception, exceptions.CosmosHttpResponseError) + self.assertEqual(cm.exception.inner_exception.status_code, 429) + + async def test_timeout_for_read_items_async(self): + """Test that timeout is properly maintained across multiple partition requests for a single logical operation + read_items is 
different as the results of this api are not paginated and we present the complete result set + """ + + # Create a container with multiple partitions + created_container = await self.database_for_test.create_container( + id='multi_partition_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk"), + offer_throughput=11000 + ) + pk_ranges = [ + pk async for pk in + created_container.client_connection._ReadPartitionKeyRanges(created_container.container_link) + ] + self.assertGreater(len(pk_ranges), 1, "Container should have multiple physical partitions.") + + # 2. Create items across different logical partitions + items_to_read = [] + all_item_ids = set() + for i in range(200): + doc_id = f"item_{i}_{uuid.uuid4()}" + pk = i % 10 + all_item_ids.add(doc_id) + await created_container.create_item({'id': doc_id, 'pk': pk, 'data': i}) + items_to_read.append((doc_id, pk)) + + # Create a custom transport that introduces delays + class DelayedTransport(AioHttpTransport): + def __init__(self, delay_per_request=2): + self.delay_per_request = delay_per_request + self.request_count = 0 + super().__init__() + + async def send(self, request, **kwargs): + self.request_count += 1 + # Delay each request to simulate slow network + await asyncio.sleep(self.delay_per_request) # 2 second delaytime.sleep(self.delay_per_request) + return await super().send(request, **kwargs) + + # Verify timeout fails when cumulative time exceeds limit + delayed_transport = DelayedTransport(delay_per_request=2) + + async with CosmosClient( + self.host, self.masterKey, transport=delayed_transport + ) as client_with_delay: + + container_with_delay = client_with_delay.get_database_client( + self.database_for_test.id + ).get_container_client(created_container.id) + + start_time = time.time() + with self.assertRaises(exceptions.CosmosClientTimeoutError): - await client.create_database_if_not_exists("test", timeout=2) + # This should timeout because multiple partition requests * 2s delay > 5s 
timeout + await container_with_delay.read_items( + items=items_to_read, + timeout=5 # 5 second total timeout + ) + + elapsed_time = time.time() - start_time + # Should fail close to 5 seconds (not wait for all requests) + self.assertLess(elapsed_time, 7) # Allow some overhead + self.assertGreater(elapsed_time, 5) # Should wait at least close to timeout + + + async def test_timeout_for_point_operation_async(self): + """Test that point operations respect client timeout""" + + # Create a container for testing + created_container = await self.database_for_test.create_container( + id='point_op_timeout_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk") + ) + + # Create a test item + test_item = { + 'id': 'test_item_1', + 'pk': 'partition1', + 'data': 'test_data' + } + await created_container.create_item(test_item) + + # Long timeout should succeed + result = await created_container.read_item( + item='test_item_1', + partition_key='partition1', + timeout=1.0 # 1 second timeout + ) + self.assertEqual(result['id'], 'test_item_1') + + async def test_timeout_for_paged_request_async(self): + """Test that timeout applies to each individual page request, not cumulatively""" + + # Create container and add items + created_container = await self.database_for_test.create_container( + id='paged_timeout_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk") + ) + + # Create enough items to ensure multiple pages + for i in range(100): + await created_container.create_item({'id': f'item_{i}', 'pk': i % 10, 'data': 'x' * 1000}) + + # Create a transport that delays each request + class DelayedTransport(AioHttpTransport): + def __init__(self, delay_seconds=1): + self.delay_seconds = delay_seconds + super().__init__() + + async def send(self, request, **kwargs): + await asyncio.sleep(self.delay_seconds) + return await super().send(request, **kwargs) + + # Test with delayed transport - reduced delay to 1 second for stability + delayed_transport = 
DelayedTransport(delay_seconds=1) + async with CosmosClient( + self.host, self.masterKey, transport=delayed_transport + ) as client_with_delay: + container_with_delay = client_with_delay.get_database_client( + self.database_for_test.id + ).get_container_client(created_container.id) + + # Test 1: Timeout should apply per page + item_pages = container_with_delay.query_items( + query="SELECT * FROM c", + #enable_cross_partition_query=True, + max_item_count=10, # Small page size + timeout=3 # 3s timeout > 1s delay, should succeed + ).by_page() + + # First page should succeed with 3s timeout (1s delay < 3s timeout) + first_page = [item async for item in await item_pages.__anext__()] + self.assertGreater(len(first_page), 0) + + # Second page should also succeed (timeout resets per page) + second_page = [item async for item in await item_pages.__anext__()] + self.assertGreater(len(second_page), 0) + + # Test 2: Timeout too short should fail + item_pages_short_timeout = container_with_delay.query_items( + query="SELECT * FROM c", + #enable_cross_partition_query=True, + max_item_count=10, + timeout=0.5 # 0.5s timeout < 1s delay, should fail + ).by_page() - databases = client.list_databases(timeout=2) with self.assertRaises(exceptions.CosmosClientTimeoutError): - databases = [database async for database in databases] + first_page = [item async for item in await item_pages_short_timeout.__anext__()] + + # Cleanup + await self.database_for_test.delete_container(created_container.id) + + # TODO: for read timeouts azure-core returns a ServiceResponseError, needs to be fixed in azure-core and then this test can be enabled + @unittest.skip + async def test_query_operation_single_partition_read_timeout_async(self): + """Test that timeout is properly maintained across multiple network requests for a single logical operation + """ + # Create a container with multiple partitions + container = await self.database_for_test.create_container( + id='single_partition_container_' + 
str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk"), + ) + single_partition_key = 0 + + large_string = 'a' * 1000 # 1KB string + for i in range(200): # Insert 500 documents + await container.create_item({ + 'id': f'item_{i}', + 'pk': single_partition_key, + 'data': large_string, + 'order_field': i + }) + + start_time = time.time() + with self.assertRaises(exceptions.CosmosClientTimeoutError): + async for item in container.query_items( + query="SELECT * FROM c ORDER BY c.order_field ASC", + max_item_count=200, + read_timeout=0.00005, + partition_key=single_partition_key + ): + pass + + elapsed_time = time.time() - start_time + print(f"elapsed time is {elapsed_time}") + + # TODO: for read timeouts azure-core returns a ServiceResponseError, needs to be fixed in azure-core and then this test can be enabled + @unittest.skip + async def test_query_operation_cross_partition_read_timeout_async(self): + """Test that timeout is properly maintained across multiple partition requests for a single logical operation + """ + # Create a container with multiple partitions + container = await self.database_for_test.create_container( + id='multi_partition_container_' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk"), + offer_throughput=11000 + ) + + # 2. 
Create large documents to increase payload size + large_string = 'a' * 1000 # 1KB string + for i in range(1000): # Insert 500 documents + await container.create_item({ + 'id': f'item_{i}', + 'pk': i % 2, + 'data': large_string, + 'order_field': i + }) + + pk_ranges = [ + pk async for pk in container.client_connection._ReadPartitionKeyRanges(container.container_link) + ] + + self.assertGreater(len(pk_ranges), 1, "Container should have multiple physical partitions.") + + with self.assertRaises(exceptions.CosmosClientTimeoutError): + # This should timeout because of multiple partition requests + items = [doc async for doc in container.query_items( + query="SELECT * FROM c ORDER BY c.order_field ASC", + max_item_count=100, + read_timeout=0.00005, + )] + # This shouldn't result in any error because the default 65seconds is respected + + items = [doc async for doc in container.query_items( + query="SELECT * FROM c ORDER BY c.order_field ASC", + max_item_count=100, + )] + self.assertEqual(len(items), 1000) + + async def test_query_iterable_functionality_async(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_none_options.py b/sdk/cosmos/azure-cosmos/tests/test_none_options.py new file mode 100644 index 000000000000..bcd08a11d799 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_none_options.py @@ -0,0 +1,172 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+import unittest +import uuid + +import pytest + +from azure.cosmos import CosmosClient +import test_config +from azure.cosmos.exceptions import CosmosHttpResponseError + + +@pytest.mark.cosmosEmulator +class TestNoneOptions(unittest.TestCase): + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + connectionPolicy = configs.connectionPolicy + + def setUp(self) -> None: + self.client = CosmosClient(self.host, self.masterKey) + self.database = self.client.get_database_client(self.configs.TEST_DATABASE_ID) + self.container = self.database.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + + def _create_sample_item(self): + item = {"id": str(uuid.uuid4()), "pk": "pk-value", "value": 42} + self.container.create_item(item, pre_trigger_include=None, post_trigger_include=None, indexing_directive=None, + enable_automatic_id_generation=False, session_token=None, initial_headers=None, + priority=None, no_response=None, retry_write=None, throughput_bucket=None) + return item + + def test_container_read_none_options(self): + result = self.container.read(populate_partition_key_range_statistics=None, populate_quota_info=None, + priority=None, initial_headers=None) + assert result + + def test_container_create_item_none_options(self): + item = {"id": str(uuid.uuid4()), "pk": "pk-value", "value": 1} + created = self.container.create_item(item, pre_trigger_include=None, post_trigger_include=None, + indexing_directive=None, enable_automatic_id_generation=False, + session_token=None, initial_headers=None, priority=None, no_response=None, + retry_write=None, throughput_bucket=None) + assert created["id"] == item["id"] + + def test_container_read_item_none_options(self): + item = self._create_sample_item() + read_back = self.container.read_item(item["id"], partition_key=item["pk"], post_trigger_include=None, + session_token=None, initial_headers=None, + max_integrated_cache_staleness_in_ms=None, priority=None, + 
throughput_bucket=None) + assert read_back["id"] == item["id"] + + def test_container_read_all_items_none_options(self): + self._create_sample_item() + pager = self.container.read_all_items(max_item_count=None, session_token=None, initial_headers=None, + max_integrated_cache_staleness_in_ms=None, priority=None, + throughput_bucket=None) + items = list(pager) + assert len(items) >= 1 + + def test_container_read_items_none_options(self): + item = self._create_sample_item() + results = self.container.read_items([(item["id"], item["pk"])], max_concurrency=None, consistency_level=None, + session_token=None, initial_headers=None, excluded_locations=None, + priority=None, throughput_bucket=None) + assert any(r["id"] == item["id"] for r in results) + + def test_container_query_items_none_options_partition(self): + self._create_sample_item() + pager = self.container.query_items("SELECT * FROM c", continuation_token_limit=None, enable_scan_in_query=None, + initial_headers=None, max_integrated_cache_staleness_in_ms=None, max_item_count=None, + parameters=None, partition_key=None, populate_index_metrics=None, + populate_query_metrics=None, priority=None, response_hook=None, session_token=None, + throughput_bucket=None, enable_cross_partition_query=True) + items = list(pager) + assert len(items) >= 1 + + def test_upsert_item_none_options(self): + item = {"id": str(uuid.uuid4()), "pk": "pk-value", "value": 5} + upserted = self.container.upsert_item(item, pre_trigger_include=None, post_trigger_include=None, + session_token=None, initial_headers=None, etag=None, + match_condition=None, priority=None, no_response=None, + retry_write=None, throughput_bucket=None) + assert upserted["id"] == item["id"] + + def test_replace_item_none_options(self): + item = self._create_sample_item() + new_body = {"id": item["id"], "pk": item["pk"], "value": 999} + replaced = self.container.replace_item(item["id"], new_body, pre_trigger_include=None, + post_trigger_include=None, session_token=None, + 
initial_headers=None, etag=None, match_condition=None, + priority=None, no_response=None, retry_write=None, + throughput_bucket=None) + assert replaced["value"] == 999 + + def test_patch_item_none_options(self): + item = self._create_sample_item() + operations = [{"op": "add", "path": "/patched", "value": True}] + patched = self.container.patch_item(item["id"], partition_key=item["pk"], patch_operations=operations, + filter_predicate=None, pre_trigger_include=None, post_trigger_include=None, + session_token=None, etag=None, match_condition=None, priority=None, + no_response=None, retry_write=None, throughput_bucket=None) + assert patched["patched"] is True + + def test_delete_item_none_options(self): + item = self._create_sample_item() + self.container.delete_item(item["id"], partition_key=item["pk"], pre_trigger_include=None, + post_trigger_include=None, session_token=None, initial_headers=None, + etag=None, match_condition=None, priority=None, retry_write=None, + throughput_bucket=None) + with self.assertRaises(CosmosHttpResponseError): + self.container.read_item(item["id"], partition_key=item["pk"], post_trigger_include=None, + session_token=None, initial_headers=None, + max_integrated_cache_staleness_in_ms=None, priority=None, + throughput_bucket=None) + + def test_get_throughput_none_options(self): + tp = self.container.get_throughput(response_hook=None) + assert tp.offer_throughput > 0 + + def test_list_conflicts_none_options(self): + pager = self.container.list_conflicts(max_item_count=None, response_hook=None) + conflicts = list(pager) + assert conflicts == conflicts # may be empty + + def test_query_conflicts_none_options(self): + pager = self.container.query_conflicts("SELECT * FROM c", parameters=None, partition_key=None, + max_item_count=None, response_hook=None, enable_cross_partition_query=True) + conflicts = list(pager) + assert conflicts == conflicts + + def test_delete_all_items_by_partition_key_none_options(self): + pk_value = "delete-pk" + for _ 
in range(2): + item = {"id": str(uuid.uuid4()), "pk": pk_value, "value": 1} + self.container.create_item(item, pre_trigger_include=None, post_trigger_include=None, indexing_directive=None, + enable_automatic_id_generation=False, session_token=None, initial_headers=None, + priority=None, no_response=None, retry_write=None, throughput_bucket=None) + self.container.delete_all_items_by_partition_key(pk_value, pre_trigger_include=None, post_trigger_include=None, + session_token=None, throughput_bucket=None) + pager = self.container.query_items("SELECT * FROM c WHERE c.pk = @pk", parameters=[{"name": "@pk", "value": pk_value}], + partition_key=None, continuation_token_limit=None, enable_scan_in_query=None, + initial_headers=None, max_integrated_cache_staleness_in_ms=None, max_item_count=None, + populate_index_metrics=None, populate_query_metrics=None, priority=None, + response_hook=None, session_token=None, throughput_bucket=None) + _items = list(pager) + assert _items == _items + + def test_execute_item_batch_none_options(self): + pk_value = "batch-pk" + id1 = str(uuid.uuid4()) + id2 = str(uuid.uuid4()) + ops = [ + ("create", ({"id": id1, "pk": pk_value},)), + ("create", ({"id": id2, "pk": pk_value},)), + ] + batch_result = self.container.execute_item_batch(ops, partition_key=pk_value, + pre_trigger_include=None, post_trigger_include=None, + session_token=None, priority=None, + throughput_bucket=None) + assert any(r.get("resourceBody").get("id") == id1 for r in batch_result) or any(r.get("resourceBody").get("id") == id2 for r in batch_result) + + def test_query_items_change_feed_none_options(self): + # Create an item, then acquire the change feed pager to verify the item appears in the feed. 
+ self.container.create_item({"id": str(uuid.uuid4()), "pk": "cf-pk", "value": 100}, + pre_trigger_include=None, post_trigger_include=None, indexing_directive=None, + enable_automatic_id_generation=False, session_token=None, initial_headers=None, + priority=None, no_response=None, retry_write=None, throughput_bucket=None) + pager = self.container.query_items_change_feed(max_item_count=None, start_time="Beginning", partition_key=None, + priority=None, mode=None, response_hook=None) + changes = list(pager) + assert len(changes) >= 1 diff --git a/sdk/cosmos/azure-cosmos/tests/test_none_options_async.py b/sdk/cosmos/azure-cosmos/tests/test_none_options_async.py new file mode 100644 index 000000000000..fd66b4c6e454 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_none_options_async.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import unittest +import uuid + +import pytest + +from azure.cosmos.aio import CosmosClient +import test_config +from azure.cosmos.exceptions import CosmosHttpResponseError + + +@pytest.mark.cosmosEmulator +class TestNoneOptionsAsync(unittest.IsolatedAsyncioTestCase): + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + connectionPolicy = configs.connectionPolicy + + async def asyncSetUp(self): + self.client = CosmosClient(self.host, self.masterKey) + await self.client.__aenter__() + self.database = self.client.get_database_client(self.configs.TEST_DATABASE_ID) + self.container = self.database.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + + async def asyncTearDown(self): + await self.client.close() + + async def _create_sample_item(self): + item = {"id": str(uuid.uuid4()), "pk": "pk-value", "value": 42} + await self.container.create_item(item, pre_trigger_include=None, post_trigger_include=None, indexing_directive=None, + enable_automatic_id_generation=False, session_token=None, initial_headers=None, + 
priority=None, no_response=None, retry_write=None, throughput_bucket=None) + return item + + async def test_container_read_none_options_async(self): + result = await self.container.read(populate_partition_key_range_statistics=None, populate_quota_info=None, + priority=None, initial_headers=None) + assert result + + async def test_container_create_item_none_options_async(self): + item = {"id": str(uuid.uuid4()), "pk": "pk-value", "value": 1} + created = await self.container.create_item(item, pre_trigger_include=None, post_trigger_include=None, + indexing_directive=None, enable_automatic_id_generation=False, + session_token=None, initial_headers=None, priority=None, no_response=None, + retry_write=None, throughput_bucket=None) + assert created["id"] == item["id"] + + async def test_container_read_item_none_options_async(self): + item = await self._create_sample_item() + read_back = await self.container.read_item(item["id"], partition_key=item["pk"], post_trigger_include=None, + session_token=None, initial_headers=None, + max_integrated_cache_staleness_in_ms=None, priority=None, + throughput_bucket=None) + assert read_back["id"] == item["id"] + + async def test_container_read_all_items_none_options_async(self): + await self._create_sample_item() + pager = self.container.read_all_items(max_item_count=None, session_token=None, initial_headers=None, + max_integrated_cache_staleness_in_ms=None, priority=None, + throughput_bucket=None) + items = [item async for item in pager] + assert len(items) >= 1 + + async def test_container_read_items_none_options_async(self): + item = await self._create_sample_item() + results = await self.container.read_items([(item["id"], item["pk"])], max_concurrency=None, consistency_level=None, + session_token=None, initial_headers=None, excluded_locations=None, + priority=None, throughput_bucket=None) + assert any(r["id"] == item["id"] for r in results) + + async def test_container_query_items_none_options_partition_async(self): + await 
self._create_sample_item() + pager = self.container.query_items("SELECT * FROM c", continuation_token_limit=None, enable_scan_in_query=None, + initial_headers=None, max_integrated_cache_staleness_in_ms=None, max_item_count=None, + parameters=None, partition_key=None, populate_index_metrics=None, + populate_query_metrics=None, priority=None, response_hook=None, session_token=None, + throughput_bucket=None) + items = [doc async for doc in pager] + assert len(items) >= 1 + + async def test_upsert_item_none_options_async(self): + item = {"id": str(uuid.uuid4()), "pk": "pk-value", "value": 5} + upserted = await self.container.upsert_item(item, pre_trigger_include=None, post_trigger_include=None, + session_token=None, initial_headers=None, etag=None, + match_condition=None, priority=None, no_response=None, + retry_write=None, throughput_bucket=None) + assert upserted["id"] == item["id"] + + async def test_replace_item_none_options_async(self): + item = await self._create_sample_item() + new_body = {"id": item["id"], "pk": item["pk"], "value": 999} + replaced = await self.container.replace_item(item["id"], new_body, pre_trigger_include=None, + post_trigger_include=None, session_token=None, + initial_headers=None, etag=None, match_condition=None, + priority=None, no_response=None, retry_write=None, + throughput_bucket=None) + assert replaced["value"] == 999 + + async def test_patch_item_none_options_async(self): + item = await self._create_sample_item() + operations = [{"op": "add", "path": "/patched", "value": True}] + patched = await self.container.patch_item(item["id"], partition_key=item["pk"], patch_operations=operations, + filter_predicate=None, pre_trigger_include=None, post_trigger_include=None, + session_token=None, etag=None, match_condition=None, priority=None, + no_response=None, retry_write=None, throughput_bucket=None) + assert patched["patched"] is True + + async def test_delete_item_none_options_async(self): + item = await self._create_sample_item() + await 
self.container.delete_item(item["id"], partition_key=item["pk"], pre_trigger_include=None, + post_trigger_include=None, session_token=None, initial_headers=None, + etag=None, match_condition=None, priority=None, retry_write=None, + throughput_bucket=None) + with self.assertRaises(CosmosHttpResponseError): + await self.container.read_item(item["id"], partition_key=item["pk"], post_trigger_include=None, + session_token=None, initial_headers=None, + max_integrated_cache_staleness_in_ms=None, priority=None, + throughput_bucket=None) + + async def test_get_throughput_none_options_async(self): + tp = await self.container.get_throughput(response_hook=None) + assert tp.offer_throughput > 0 + + async def test_list_conflicts_none_options_async(self): + pager = self.container.list_conflicts(max_item_count=None, response_hook=None) + conflicts = [c async for c in pager] + assert conflicts == conflicts # simple sanity (may be empty) + + async def test_query_conflicts_none_options_async(self): + pager = self.container.query_conflicts("SELECT * FROM c", parameters=None, partition_key=None, + max_item_count=None, response_hook=None) + conflicts = [c async for c in pager] + assert conflicts == conflicts + + async def test_delete_all_items_by_partition_key_none_options_async(self): + pk_value = "delete-pk" + for _ in range(2): + item = {"id": str(uuid.uuid4()), "pk": pk_value, "value": 1} + await self.container.create_item(item, pre_trigger_include=None, post_trigger_include=None, + indexing_directive=None, enable_automatic_id_generation=False, + session_token=None, initial_headers=None, priority=None, no_response=None, + retry_write=None, throughput_bucket=None) + await self.container.delete_all_items_by_partition_key(pk_value, pre_trigger_include=None, + post_trigger_include=None, session_token=None, + throughput_bucket=None) + # Just ensure query still works with None options + pager = self.container.query_items("SELECT * FROM c WHERE c.pk = @pk", parameters=[{"name": "@pk", 
"value": pk_value}], + partition_key=None, continuation_token_limit=None, enable_scan_in_query=None, + initial_headers=None, max_integrated_cache_staleness_in_ms=None, max_item_count=None, + populate_index_metrics=None, populate_query_metrics=None, priority=None, + response_hook=None, session_token=None, throughput_bucket=None) + _items = [doc async for doc in pager] + assert _items == _items + + async def test_execute_item_batch_none_options_async(self): + pk_value = "batch-pk" + id1 = str(uuid.uuid4()) + id2 = str(uuid.uuid4()) + ops = [ + ("create", ({"id": id1, "pk": pk_value},)), + ("create", ({"id": id2, "pk": pk_value},)), + ] + batch_result = await self.container.execute_item_batch(ops, partition_key=pk_value, + pre_trigger_include=None, post_trigger_include=None, + session_token=None, priority=None, + throughput_bucket=None) + assert any(r.get("resourceBody").get("id") == id1 for r in batch_result) or any(r.get("resourceBody").get("id") == id2 for r in batch_result) + + + async def test_query_items_change_feed_none_options_async(self): + for _ in range(15): + await self.container.create_item({"id": str(uuid.uuid4()), "pk": "cf-pk", "value": 100}, + pre_trigger_include=None, post_trigger_include=None, indexing_directive=None, + enable_automatic_id_generation=False, session_token=None, initial_headers=None, + priority=None, no_response=None, retry_write=None, throughput_bucket=None) + # Obtain the change feed pager with all optional parameters set to None (including partition_key) + pager = self.container.query_items_change_feed(max_item_count=None, start_time="Beginning", partition_key=None, + priority=None, mode=None, response_hook=None) + + changes = [doc async for doc in pager] + assert len(changes) >= 1 diff --git a/sdk/cosmos/azure-cosmos/tests/test_read_items.py b/sdk/cosmos/azure-cosmos/tests/test_read_items.py index 71196dc9e637..db770c88d363 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_read_items.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_read_items.py @@ -8,12 +8,11 @@ import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.exceptions as exceptions import test_config -from azure.cosmos import PartitionKey, CosmosDict +from azure.cosmos import PartitionKey from _fault_injection_transport import FaultInjectionTransport from azure.cosmos._resource_throttle_retry_policy import ResourceThrottleRetryPolicy from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.documents import _OperationType -from azure.core.utils import CaseInsensitiveDict @pytest.mark.cosmosEmulator class TestReadItems(unittest.TestCase): @@ -484,13 +483,13 @@ def test_read_items_concurrency_internals(self): call_args = mock_query.call_args_list # Extract the number of parameters from each call. chunk_sizes = [len(call[0][1]['parameters']) for call in call_args] - # Sort the chunk sizes to make the assertion deterministic. chunk_sizes.sort(reverse=True) self.assertEqual(chunk_sizes[0], 1000) self.assertEqual(chunk_sizes[1], 1000) self.assertEqual(chunk_sizes[2], 500) + def test_read_items_multiple_physical_partitions_and_hook(self): """Tests read_items on a container with multiple physical partitions and verifies response_hook.""" # Create a container with high throughput to force multiple physical partitions diff --git a/sdk/cosmos/azure-cosmos/tests/workloads/run_workloads.sh b/sdk/cosmos/azure-cosmos/tests/workloads/run_workloads.sh index 2576bc8b2563..9d3f6ecbf62f 100755 --- a/sdk/cosmos/azure-cosmos/tests/workloads/run_workloads.sh +++ b/sdk/cosmos/azure-cosmos/tests/workloads/run_workloads.sh @@ -23,5 +23,4 @@ for file in ./*_workload.py; do done done -wait echo "[Info] All workloads started successfully."