Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
## Release History

### 4.14.3 (Unreleased)

#### Bugs Fixed
* Fixed bug where client timeout/read_timeout values were not properly enforced. See [PR 42652](https://github.com/Azure/azure-sdk-for-python/pull/42652).
* Fixed bug when passing in None for some option in `query_items` would cause unexpected errors. See [PR 44098](https://github.com/Azure/azure-sdk-for-python/pull/44098)

### 4.14.2 (2025-11-14)

#### Features Added
Expand Down
8 changes: 8 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""

import base64
import time
from email.utils import formatdate
import json
import uuid
Expand Down Expand Up @@ -109,6 +110,13 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]:
for key, value in _COMMON_OPTIONS.items():
if key in kwargs:
options[value] = kwargs.pop(key)
if 'read_timeout' in kwargs:
options['read_timeout'] = kwargs['read_timeout']
if 'timeout' in kwargs:
options['timeout'] = kwargs['timeout']


options[Constants.OperationStartTime] = time.time()
if_match, if_none_match = _get_match_headers(kwargs)
if if_match:
options['accessCondition'] = {'type': 'IfMatch', 'condition': if_match}
Expand Down
9 changes: 8 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@

from typing_extensions import Literal

# cspell:ignore reranker
class TimeoutScope:
"""Defines the scope of timeout application"""
OPERATION: Literal["operation"] = "operation" # Apply timeout to entire logical operation
PAGE: Literal["page"] = "page" # Apply timeout to individual page requests

# cspell:ignore reranker

class _Constants:
"""Constants used in the azure-cosmos package"""

UserConsistencyPolicy: Literal["userConsistencyPolicy"] = "userConsistencyPolicy"
DefaultConsistencyLevel: Literal["defaultConsistencyLevel"] = "defaultConsistencyLevel"
OperationStartTime: Literal["operationStartTime"] = "operationStartTime"
# whether to apply timeout to the whole logical operation or just a page request
TimeoutScope: Literal["timeoutScope"] = "timeoutScope"

# GlobalDB related constants
WritableLocations: Literal["writableLocations"] = "writableLocations"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3169,6 +3169,18 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma
"""
if options is None:
options = {}
read_timeout = options.get("read_timeout")
if read_timeout is not None:
# we currently have a gap where kwargs are not getting passed correctly down the pipeline. In order to make
# absolute time out work, we are passing read_timeout via kwargs as a temporary fix
kwargs.setdefault("read_timeout", read_timeout)

operation_start_time = options.get(Constants.OperationStartTime)
if operation_start_time is not None:
kwargs.setdefault(Constants.OperationStartTime, operation_start_time)
timeout = options.get("timeout")
if timeout is not None:
kwargs.setdefault("timeout", timeout)

if query:
__GetBodiesFromQueryResult = result_fn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,17 @@ async def _fetch_items_helper_no_retries(self, fetch_function):
return fetched_items

async def _fetch_items_helper_with_retries(self, fetch_function):
async def callback():
# TODO: Properly propagate kwargs from retry utility to fetch function
# the callback keep the **kwargs parameter to maintain compatibility with the retry utility's execution pattern.
# ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.)
# The callback need to accept these parameters even if unused
# Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters
async def callback(**kwargs): # pylint: disable=unused-argument
return await self._fetch_items_helper_no_retries(fetch_function)

return await _retry_utility_async.ExecuteAsync(self._client, self._client._global_endpoint_manager, callback)
return await _retry_utility_async.ExecuteAsync(
self._client, self._client._global_endpoint_manager, callback, **self._options
)


class _DefaultQueryExecutionContext(_QueryExecutionContextBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,19 @@ def __init__(self, client, resource_link, query, options, fetch_function,
async def _create_execution_context_with_query_plan(self):
self._fetched_query_plan = True
query_to_use = self._query if self._query is not None else "Select * from root r"
query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway
(query_to_use, self._resource_link, self._options.get('excludedLocations')))
query_plan = await self._client._GetQueryPlanThroughGateway(
query_to_use,
self._resource_link,
self._options.get('excludedLocations'),
read_timeout=self._options.get('read_timeout')
)
query_execution_info = _PartitionedQueryExecutionInfo(query_plan)
qe_info = getattr(query_execution_info, "_query_execution_info", None)
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
params = query_to_use.get("parameters")
if params is not None:
query_execution_info._query_execution_info['parameters'] = params

self._execution_context = await self._create_pipelined_execution_context(query_execution_info)

async def __anext__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,15 @@ def _fetch_items_helper_no_retries(self, fetch_function):
return fetched_items

def _fetch_items_helper_with_retries(self, fetch_function):
def callback():
# TODO: Properly propagate kwargs from retry utility to fetch function
# the callback keep the **kwargs parameter to maintain compatibility with the retry utility's execution pattern.
# ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.)
# The callback need to accept these parameters even if unused
# Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters
def callback(**kwargs): # pylint: disable=unused-argument
return self._fetch_items_helper_no_retries(fetch_function)

return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback)
return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback, **self._options)

next = __next__ # Python 2 compatibility.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,19 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon
def _create_execution_context_with_query_plan(self):
self._fetched_query_plan = True
query_to_use = self._query if self._query is not None else "Select * from root r"
query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway
(query_to_use, self._resource_link, self._options.get('excludedLocations')))

query_plan = self._client._GetQueryPlanThroughGateway(
query_to_use,
self._resource_link,
self._options.get('excludedLocations'),
read_timeout=self._options.get('read_timeout')
)
query_execution_info = _PartitionedQueryExecutionInfo(query_plan)
qe_info = getattr(query_execution_info, "_query_execution_info", None)
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
params = query_to_use.get("parameters")
if params is not None:
query_execution_info._query_execution_info['parameters'] = params

self._execution_context = self._create_pipelined_execution_context(query_execution_info)

def __next__(self):
Expand Down
14 changes: 14 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@

"""Iterable query results in the Azure Cosmos database service.
"""
import time
from azure.core.paging import PageIterator # type: ignore
from azure.cosmos._constants import _Constants, TimeoutScope
from azure.cosmos._execution_context import execution_dispatcher
from azure.cosmos import exceptions

# pylint: disable=protected-access

Expand Down Expand Up @@ -99,6 +102,17 @@ def _fetch_next(self, *args): # pylint: disable=unused-argument
:return: List of results.
:rtype: list
"""
timeout = self._options.get('timeout')
# reset the operation start time if it's a paged request
if timeout and self._options.get(_Constants.TimeoutScope) != TimeoutScope.OPERATION:
self._options[_Constants.OperationStartTime] = time.time()

# Check timeout before fetching next block
if timeout:
elapsed = time.time() - self._options.get(_Constants.OperationStartTime)
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError()

block = self._ex_context.fetch_next_block()
if not block:
raise StopIteration
Expand Down
35 changes: 29 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
from . import _session_retry_policy
from . import _timeout_failover_retry_policy
from . import exceptions
from ._constants import _Constants
from .documents import _OperationType
from .exceptions import CosmosHttpResponseError
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes, ResourceType
from ._cosmos_http_logging_policy import _log_diagnostics_error


# pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches

# args [0] is the request object
Expand All @@ -64,6 +64,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
:returns: the result of running the passed in function as a (result, headers) tuple
:rtype: tuple of (dict, dict)
"""
# Capture the client timeout and start time at the beginning
timeout = kwargs.get('timeout')
operation_start_time = kwargs.get(_Constants.OperationStartTime, time.time())

# Track the last error for chaining
last_error = None

pk_range_wrapper = None
if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]):
pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0])
Expand Down Expand Up @@ -110,14 +117,25 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
client, client._container_properties_cache, None, *args)

while True:
client_timeout = kwargs.get('timeout')
start_time = time.time()
# Check timeout before executing function
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)

try:
if args:
result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs)
global_endpoint_manager.record_success(args[0])
else:
result = ExecuteFunction(function, *args, **kwargs)
# Check timeout after successful execution
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)

if not client.last_response_headers:
client.last_response_headers = {}

Expand Down Expand Up @@ -158,6 +176,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin

return result
except exceptions.CosmosHttpResponseError as e:
last_error = e
if request:
# update session token for relevant operations
client._UpdateSessionIfRequired(request.headers, {}, e.headers)
Expand Down Expand Up @@ -226,12 +245,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
client.session.clear_session_token(client.last_response_headers)
raise

# Now check timeout before retrying
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)
# Wait for retry_after_in_milliseconds time before the next retry
time.sleep(retry_policy.retry_after_in_milliseconds / 1000.0)
if client_timeout:
kwargs['timeout'] = client_timeout - (time.time() - start_time)
if kwargs['timeout'] <= 0:
raise exceptions.CosmosClientTimeoutError()

except ServiceRequestError as e:
if request and _has_database_account_header(request.headers):
Expand All @@ -258,6 +278,7 @@ def ExecuteFunction(function, *args, **kwargs):
"""
return function(*args, **kwargs)


def _has_read_retryable_headers(request_headers):
if _OperationType.IsReadOnlyOperation(request_headers.get(HttpHeaders.ThinClientProxyOperationType)):
return True
Expand Down Expand Up @@ -332,6 +353,7 @@ def send(self, request):
:raises ~azure.cosmos.exceptions.CosmosClientTimeoutError: Specified timeout exceeded.
:raises ~azure.core.exceptions.ClientAuthenticationError: Authentication failed.
"""

absolute_timeout = request.context.options.pop('timeout', None)
per_request_timeout = request.context.options.pop('connection_timeout', 0)
request_params = request.context.options.pop('request_params', None)
Expand Down Expand Up @@ -384,6 +406,7 @@ def send(self, request):
if retry_active:
self.sleep(retry_settings, request.context.transport)
continue

raise err
except CosmosHttpResponseError as err:
raise err
Expand Down
10 changes: 4 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@

from urllib.parse import urlparse
from azure.core.exceptions import DecodeError # type: ignore

from . import exceptions
from . import http_constants
from . import _retry_utility
from ._constants import _Constants
from . import exceptions, http_constants, _retry_utility


def _is_readable_stream(obj):
Expand Down Expand Up @@ -80,8 +78,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
:rtype: tuple of (dict, dict)

"""
# pylint: disable=protected-access

# pylint: disable=protected-access, too-many-branches
kwargs.pop(_Constants.OperationStartTime, None)
connection_timeout = connection_policy.RequestTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
read_timeout = connection_policy.ReadTimeout
Expand Down
7 changes: 6 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,17 @@ def valid_key_value_exist(
kwargs: dict[str, Any],
key: str,
invalid_value: Any = None) -> bool:
"""Check if a valid key and value exists in kwargs. By default, it checks if the value is not None.
"""Check if a valid key and value exists in kwargs. It always checks if the value is not None and it will remove
from the kwargs the None value.

:param dict[str, Any] kwargs: The dictionary of keyword arguments.
:param str key: The key to check.
:param Any invalid_value: The value that is considered invalid. Default is None.
:return: True if the key exists and its value is not None, False otherwise.
:rtype: bool
"""
if key in kwargs and kwargs[key] is None:
kwargs.pop(key)
return False

return key in kwargs and kwargs[key] is not invalid_value
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

VERSION = "4.14.2"
VERSION = "4.14.3"
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .. import exceptions
from .. import http_constants
from . import _retry_utility_async
from .._constants import _Constants
from .._synchronized_request import _request_body_from_data, _replace_url_prefix


Expand All @@ -49,8 +50,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p
:rtype: tuple of (dict, dict)

"""
# pylint: disable=protected-access

# pylint: disable=protected-access, too-many-branches
kwargs.pop(_Constants.OperationStartTime, None)
connection_timeout = connection_policy.RequestTimeout
read_timeout = connection_policy.ReadTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
Expand Down
14 changes: 11 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from .._base import (_build_properties_cache, _deserialize_throughput, _replace_throughput,
build_options as _build_options, GenerateGuidId, validate_cache_staleness_value)
from .._change_feed.feed_range_internal import FeedRangeInternalEpk
from .._constants import _Constants as Constants

from .._cosmos_responses import CosmosDict, CosmosList
from .._constants import _Constants as Constants, TimeoutScope
from .._routing.routing_range import Range
from .._session_token_helpers import get_latest_session_token
from ..exceptions import CosmosHttpResponseError
Expand Down Expand Up @@ -96,8 +97,14 @@ def __repr__(self) -> str:

async def _get_properties_with_options(self, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
kwargs = {}
if options and "excludedLocations" in options:
kwargs['excluded_locations'] = options['excludedLocations']
if options:
if "excludedLocations" in options:
kwargs['excluded_locations'] = options['excludedLocations']
if Constants.OperationStartTime in options:
kwargs[Constants.OperationStartTime] = options[Constants.OperationStartTime]
if "timeout" in options:
kwargs['timeout'] = options['timeout']

return await self._get_properties(**kwargs)

async def _get_properties(self, **kwargs: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -483,6 +490,7 @@ async def read_items(
query_options = _build_options(kwargs)
await self._get_properties_with_options(query_options)
query_options["enableCrossPartitionQuery"] = True
query_options[Constants.TimeoutScope] = TimeoutScope.OPERATION

item_tuples = [(item_id, await self._set_partition_key(pk)) for item_id, pk in items]
return await self.client_connection.read_items(
Expand Down
Loading
Loading