Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eng/apiview_reqs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ tomli==2.2.1
tomlkit==0.13.2
typing_extensions==4.15.0
wrapt==1.17.2
apiview-stub-generator==0.3.24
apiview-stub-generator==0.3.25
16 changes: 16 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
## Release History

### 4.14.4 (Unreleased)

#### Features Added

#### Breaking Changes

#### Bugs Fixed

#### Other Changes

### 4.14.3 (2025-12-08)

#### Bugs Fixed
* Fixed bug where client timeout/read_timeout values were not properly enforced. See [PR 42652](https://github.com/Azure/azure-sdk-for-python/pull/42652).
* Fixed bug when passing in None for some options in `query_items` would cause unexpected errors. See [PR 44098](https://github.com/Azure/azure-sdk-for-python/pull/44098)

### 4.14.2 (2025-11-14)

#### Features Added
Expand Down
8 changes: 8 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""

import base64
import time
from email.utils import formatdate
import json
import uuid
Expand Down Expand Up @@ -109,6 +110,13 @@ def build_options(kwargs: dict[str, Any]) -> dict[str, Any]:
for key, value in _COMMON_OPTIONS.items():
if key in kwargs:
options[value] = kwargs.pop(key)
if 'read_timeout' in kwargs:
options['read_timeout'] = kwargs['read_timeout']
if 'timeout' in kwargs:
options['timeout'] = kwargs['timeout']


options[Constants.OperationStartTime] = time.time()
if_match, if_none_match = _get_match_headers(kwargs)
if if_match:
options['accessCondition'] = {'type': 'IfMatch', 'condition': if_match}
Expand Down
9 changes: 8 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@

from typing_extensions import Literal

# cspell:ignore reranker
class TimeoutScope:
"""Defines the scope of timeout application"""
OPERATION: Literal["operation"] = "operation" # Apply timeout to entire logical operation
PAGE: Literal["page"] = "page" # Apply timeout to individual page requests

# cspell:ignore reranker

class _Constants:
"""Constants used in the azure-cosmos package"""

UserConsistencyPolicy: Literal["userConsistencyPolicy"] = "userConsistencyPolicy"
DefaultConsistencyLevel: Literal["defaultConsistencyLevel"] = "defaultConsistencyLevel"
OperationStartTime: Literal["operationStartTime"] = "operationStartTime"
# whether to apply timeout to the whole logical operation or just a page request
TimeoutScope: Literal["timeoutScope"] = "timeoutScope"

# GlobalDB related constants
WritableLocations: Literal["writableLocations"] = "writableLocations"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3169,6 +3169,18 @@ def __QueryFeed( # pylint: disable=too-many-locals, too-many-statements, too-ma
"""
if options is None:
options = {}
read_timeout = options.get("read_timeout")
if read_timeout is not None:
# we currently have a gap where kwargs are not getting passed correctly down the pipeline. In order to make
# absolute time out work, we are passing read_timeout via kwargs as a temporary fix
kwargs.setdefault("read_timeout", read_timeout)

operation_start_time = options.get(Constants.OperationStartTime)
if operation_start_time is not None:
kwargs.setdefault(Constants.OperationStartTime, operation_start_time)
timeout = options.get("timeout")
if timeout is not None:
kwargs.setdefault("timeout", timeout)

if query:
__GetBodiesFromQueryResult = result_fn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,17 @@ async def _fetch_items_helper_no_retries(self, fetch_function):
return fetched_items

async def _fetch_items_helper_with_retries(self, fetch_function):
async def callback():
# TODO: Properly propagate kwargs from retry utility to fetch function
# the callback keep the **kwargs parameter to maintain compatibility with the retry utility's execution pattern.
# ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.)
# The callback need to accept these parameters even if unused
# Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters
async def callback(**kwargs): # pylint: disable=unused-argument
return await self._fetch_items_helper_no_retries(fetch_function)

return await _retry_utility_async.ExecuteAsync(self._client, self._client._global_endpoint_manager, callback)
return await _retry_utility_async.ExecuteAsync(
self._client, self._client._global_endpoint_manager, callback, **self._options
)


class _DefaultQueryExecutionContext(_QueryExecutionContextBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,19 @@ def __init__(self, client, resource_link, query, options, fetch_function,
async def _create_execution_context_with_query_plan(self):
self._fetched_query_plan = True
query_to_use = self._query if self._query is not None else "Select * from root r"
query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway
(query_to_use, self._resource_link, self._options.get('excludedLocations')))
query_plan = await self._client._GetQueryPlanThroughGateway(
query_to_use,
self._resource_link,
self._options.get('excludedLocations'),
read_timeout=self._options.get('read_timeout')
)
query_execution_info = _PartitionedQueryExecutionInfo(query_plan)
qe_info = getattr(query_execution_info, "_query_execution_info", None)
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
params = query_to_use.get("parameters")
if params is not None:
query_execution_info._query_execution_info['parameters'] = params

self._execution_context = await self._create_pipelined_execution_context(query_execution_info)

async def __anext__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,15 @@ def _fetch_items_helper_no_retries(self, fetch_function):
return fetched_items

def _fetch_items_helper_with_retries(self, fetch_function):
def callback():
# TODO: Properly propagate kwargs from retry utility to fetch function
# the callback keep the **kwargs parameter to maintain compatibility with the retry utility's execution pattern.
# ExecuteAsync passes retry context parameters (timeout, operation start time, logger, etc.)
# The callback need to accept these parameters even if unused
# Removing **kwargs results in a TypeError when ExecuteAsync tries to pass these parameters
def callback(**kwargs): # pylint: disable=unused-argument
return self._fetch_items_helper_no_retries(fetch_function)

return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback)
return _retry_utility.Execute(self._client, self._client._global_endpoint_manager, callback, **self._options)

next = __next__ # Python 2 compatibility.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,19 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon
def _create_execution_context_with_query_plan(self):
self._fetched_query_plan = True
query_to_use = self._query if self._query is not None else "Select * from root r"
query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway
(query_to_use, self._resource_link, self._options.get('excludedLocations')))

query_plan = self._client._GetQueryPlanThroughGateway(
query_to_use,
self._resource_link,
self._options.get('excludedLocations'),
read_timeout=self._options.get('read_timeout')
)
query_execution_info = _PartitionedQueryExecutionInfo(query_plan)
qe_info = getattr(query_execution_info, "_query_execution_info", None)
if isinstance(qe_info, dict) and isinstance(query_to_use, dict):
params = query_to_use.get("parameters")
if params is not None:
query_execution_info._query_execution_info['parameters'] = params

self._execution_context = self._create_pipelined_execution_context(query_execution_info)

def __next__(self):
Expand Down
14 changes: 14 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@

"""Iterable query results in the Azure Cosmos database service.
"""
import time
from azure.core.paging import PageIterator # type: ignore
from azure.cosmos._constants import _Constants, TimeoutScope
from azure.cosmos._execution_context import execution_dispatcher
from azure.cosmos import exceptions

# pylint: disable=protected-access

Expand Down Expand Up @@ -99,6 +102,17 @@ def _fetch_next(self, *args): # pylint: disable=unused-argument
:return: List of results.
:rtype: list
"""
timeout = self._options.get('timeout')
# reset the operation start time if it's a paged request
if timeout and self._options.get(_Constants.TimeoutScope) != TimeoutScope.OPERATION:
self._options[_Constants.OperationStartTime] = time.time()

# Check timeout before fetching next block
if timeout:
elapsed = time.time() - self._options.get(_Constants.OperationStartTime)
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError()

block = self._ex_context.fetch_next_block()
if not block:
raise StopIteration
Expand Down
35 changes: 29 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
from . import _session_retry_policy
from . import _timeout_failover_retry_policy
from . import exceptions
from ._constants import _Constants
from .documents import _OperationType
from .exceptions import CosmosHttpResponseError
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes, ResourceType
from ._cosmos_http_logging_policy import _log_diagnostics_error


# pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches

# args [0] is the request object
Expand All @@ -64,6 +64,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
:returns: the result of running the passed in function as a (result, headers) tuple
:rtype: tuple of (dict, dict)
"""
# Capture the client timeout and start time at the beginning
timeout = kwargs.get('timeout')
operation_start_time = kwargs.get(_Constants.OperationStartTime, time.time())

# Track the last error for chaining
last_error = None

pk_range_wrapper = None
if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]):
pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0])
Expand Down Expand Up @@ -110,14 +117,25 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
client, client._container_properties_cache, None, *args)

while True:
client_timeout = kwargs.get('timeout')
start_time = time.time()
# Check timeout before executing function
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)

try:
if args:
result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs)
global_endpoint_manager.record_success(args[0])
else:
result = ExecuteFunction(function, *args, **kwargs)
# Check timeout after successful execution
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)

if not client.last_response_headers:
client.last_response_headers = {}

Expand Down Expand Up @@ -158,6 +176,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin

return result
except exceptions.CosmosHttpResponseError as e:
last_error = e
if request:
# update session token for relevant operations
client._UpdateSessionIfRequired(request.headers, {}, e.headers)
Expand Down Expand Up @@ -226,12 +245,13 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
client.session.clear_session_token(client.last_response_headers)
raise

# Now check timeout before retrying
if timeout:
elapsed = time.time() - operation_start_time
if elapsed >= timeout:
raise exceptions.CosmosClientTimeoutError(error=last_error)
# Wait for retry_after_in_milliseconds time before the next retry
time.sleep(retry_policy.retry_after_in_milliseconds / 1000.0)
if client_timeout:
kwargs['timeout'] = client_timeout - (time.time() - start_time)
if kwargs['timeout'] <= 0:
raise exceptions.CosmosClientTimeoutError()

except ServiceRequestError as e:
if request and _has_database_account_header(request.headers):
Expand All @@ -258,6 +278,7 @@ def ExecuteFunction(function, *args, **kwargs):
"""
return function(*args, **kwargs)


def _has_read_retryable_headers(request_headers):
if _OperationType.IsReadOnlyOperation(request_headers.get(HttpHeaders.ThinClientProxyOperationType)):
return True
Expand Down Expand Up @@ -332,6 +353,7 @@ def send(self, request):
:raises ~azure.cosmos.exceptions.CosmosClientTimeoutError: Specified timeout exceeded.
:raises ~azure.core.exceptions.ClientAuthenticationError: Authentication failed.
"""

absolute_timeout = request.context.options.pop('timeout', None)
per_request_timeout = request.context.options.pop('connection_timeout', 0)
request_params = request.context.options.pop('request_params', None)
Expand Down Expand Up @@ -384,6 +406,7 @@ def send(self, request):
if retry_active:
self.sleep(retry_settings, request.context.transport)
continue

raise err
except CosmosHttpResponseError as err:
raise err
Expand Down
10 changes: 4 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@

from urllib.parse import urlparse
from azure.core.exceptions import DecodeError # type: ignore

from . import exceptions
from . import http_constants
from . import _retry_utility
from ._constants import _Constants
from . import exceptions, http_constants, _retry_utility


def _is_readable_stream(obj):
Expand Down Expand Up @@ -80,8 +78,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
:rtype: tuple of (dict, dict)

"""
# pylint: disable=protected-access

# pylint: disable=protected-access, too-many-branches
kwargs.pop(_Constants.OperationStartTime, None)
connection_timeout = connection_policy.RequestTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
read_timeout = connection_policy.ReadTimeout
Expand Down
7 changes: 6 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,17 @@ def valid_key_value_exist(
kwargs: dict[str, Any],
key: str,
invalid_value: Any = None) -> bool:
"""Check if a valid key and value exists in kwargs. By default, it checks if the value is not None.
"""Check if a valid key and value exists in kwargs. It always checks if the value is not None and it will remove
from the kwargs the None value.

:param dict[str, Any] kwargs: The dictionary of keyword arguments.
:param str key: The key to check.
:param Any invalid_value: The value that is considered invalid. Default is None.
:return: True if the key exists and its value is not None, False otherwise.
:rtype: bool
"""
if key in kwargs and kwargs[key] is None:
kwargs.pop(key)
return False

return key in kwargs and kwargs[key] is not invalid_value
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

VERSION = "4.14.2"
VERSION = "4.14.4"
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .. import exceptions
from .. import http_constants
from . import _retry_utility_async
from .._constants import _Constants
from .._synchronized_request import _request_body_from_data, _replace_url_prefix


Expand All @@ -49,8 +50,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p
:rtype: tuple of (dict, dict)

"""
# pylint: disable=protected-access

# pylint: disable=protected-access, too-many-branches
kwargs.pop(_Constants.OperationStartTime, None)
connection_timeout = connection_policy.RequestTimeout
read_timeout = connection_policy.ReadTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
Expand Down
Loading
Loading