diff --git a/matter_server/server/vendor_info.py b/matter_server/server/vendor_info.py index a966039d..51898172 100644 --- a/matter_server/server/vendor_info.py +++ b/matter_server/server/vendor_info.py @@ -5,7 +5,7 @@ import logging from typing import TYPE_CHECKING -from aiohttp import ClientError, ClientSession +from aiohttp import ClientError, ClientSession, ClientTimeout from ..common.helpers.api import api_command from ..common.helpers.util import dataclass_from_dict, dataclass_to_dict @@ -17,6 +17,7 @@ LOGGER = logging.getLogger(__name__) PRODUCTION_URL = "https://on.dcl.csa-iot.org" DATA_KEY_VENDOR_INFO = "vendor_info" +DCL_REQUEST_TIMEOUT = ClientTimeout(total=30) TEST_VENDOR = VendorInfoModel( @@ -71,7 +72,9 @@ async def _fetch_vendors(self) -> None: LOGGER.info("Fetching the latest vendor info from DCL.") vendors: dict[int, VendorInfoModel] = {} try: - async with ClientSession(raise_for_status=True) as session: + async with ClientSession( + raise_for_status=True, timeout=DCL_REQUEST_TIMEOUT + ) as session: page_token: str | None = "" while page_token is not None: async with session.get( @@ -93,8 +96,10 @@ async def _fetch_vendors(self) -> None: creator=vendorinfo["creator"], ) page_token = data.get("pagination", {}).get("next_key", None) - except ClientError as err: - LOGGER.error("Unable to fetch vendor info from DCL: %s", err) + except (ClientError, TimeoutError) as err: + LOGGER.warning( + "Unable to fetch vendor info from DCL: %s", err, exc_info=err + ) else: LOGGER.info("Fetched %s vendors from DCL.", len(vendors)) diff --git a/tests/server/test_vendor_info.py b/tests/server/test_vendor_info.py new file mode 100644 index 00000000..e4d56917 --- /dev/null +++ b/tests/server/test_vendor_info.py @@ -0,0 +1,73 @@ +"""Test vendor info handling.""" + +from __future__ import annotations + +from typing import Any, Self +from unittest.mock import MagicMock, patch + +from matter_server.server.vendor_info import ( + DATA_KEY_VENDOR_INFO, + DCL_REQUEST_TIMEOUT, + NABUCASA_VENDOR, + TEST_VENDOR, + VendorInfo, +) + + +class _TimeoutResponse: + """Response context manager that times out when JSON is read.""" + + async def __aenter__(self) -> Self: + """Enter the response context.""" + return self + + async def __aexit__(self, *args: Any) -> None: + """Exit the response context.""" + + async def json(self) -> dict[str, Any]: + """Raise a timeout to simulate a stalled DCL response.""" + raise TimeoutError + + +class _FakeClientSession: + """Client session context manager that records initialization kwargs.""" + + def __init__(self, call_kwargs: dict[str, Any]) -> None: + """Initialize the fake session.""" + self.call_kwargs = call_kwargs + + async def __aenter__(self) -> Self: + """Enter the session context.""" + return self + + async def __aexit__(self, *args: Any) -> None: + """Exit the session context.""" + + def get(self, *args: Any, **kwargs: Any) -> _TimeoutResponse: + """Return a response that times out.""" + return _TimeoutResponse() + + +async def test_vendor_info_start_handles_dcl_timeout() -> None: + """Test vendor info startup continues when the DCL request times out.""" + server = MagicMock() + server.storage.get.return_value = {} + vendor_info = VendorInfo(server) + client_session_kwargs: dict[str, Any] = {} + + def _client_session(**kwargs: Any) -> _FakeClientSession: + client_session_kwargs.update(kwargs) + return _FakeClientSession(kwargs) + + with patch("matter_server.server.vendor_info.ClientSession", _client_session): + await vendor_info.start() + + assert client_session_kwargs == { + "raise_for_status": True, + "timeout": DCL_REQUEST_TIMEOUT, + } + server.storage.set.assert_called_once() + storage_key, vendor_data = server.storage.set.call_args.args + assert storage_key == DATA_KEY_VENDOR_INFO + assert TEST_VENDOR.vendor_id in vendor_data + assert NABUCASA_VENDOR.vendor_id in vendor_data