Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions matter_server/server/vendor_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from typing import TYPE_CHECKING

from aiohttp import ClientError, ClientSession
from aiohttp import ClientError, ClientSession, ClientTimeout

from ..common.helpers.api import api_command
from ..common.helpers.util import dataclass_from_dict, dataclass_to_dict
Expand All @@ -17,6 +17,7 @@
LOGGER = logging.getLogger(__name__)
PRODUCTION_URL = "https://on.dcl.csa-iot.org"
DATA_KEY_VENDOR_INFO = "vendor_info"
DCL_REQUEST_TIMEOUT = ClientTimeout(total=30)


TEST_VENDOR = VendorInfoModel(
Expand Down Expand Up @@ -71,7 +72,9 @@ async def _fetch_vendors(self) -> None:
LOGGER.info("Fetching the latest vendor info from DCL.")
vendors: dict[int, VendorInfoModel] = {}
try:
async with ClientSession(raise_for_status=True) as session:
async with ClientSession(
raise_for_status=True, timeout=DCL_REQUEST_TIMEOUT
) as session:
page_token: str | None = ""
while page_token is not None:
async with session.get(
Expand All @@ -93,8 +96,10 @@ async def _fetch_vendors(self) -> None:
creator=vendorinfo["creator"],
)
page_token = data.get("pagination", {}).get("next_key", None)
except ClientError as err:
LOGGER.error("Unable to fetch vendor info from DCL: %s", err)
except (ClientError, TimeoutError) as err:
LOGGER.warning(
"Unable to fetch vendor info from DCL: %s", err, exc_info=err
)
else:
LOGGER.info("Fetched %s vendors from DCL.", len(vendors))

Expand Down
73 changes: 73 additions & 0 deletions tests/server/test_vendor_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Test vendor info handling."""

from __future__ import annotations

from typing import Any, Self
from unittest.mock import MagicMock, patch

from matter_server.server.vendor_info import (
DATA_KEY_VENDOR_INFO,
DCL_REQUEST_TIMEOUT,
NABUCASA_VENDOR,
TEST_VENDOR,
VendorInfo,
)


class _TimeoutResponse:
"""Response context manager that times out when JSON is read."""

async def __aenter__(self) -> Self:
"""Enter the response context."""
return self

async def __aexit__(self, *args: Any) -> None:
"""Exit the response context."""

async def json(self) -> dict[str, Any]:
"""Raise a timeout to simulate a stalled DCL response."""
raise TimeoutError


class _FakeClientSession:
"""Client session context manager that records initialization kwargs."""

def __init__(self, call_kwargs: dict[str, Any]) -> None:
"""Initialize the fake session."""
self.call_kwargs = call_kwargs

async def __aenter__(self) -> Self:
"""Enter the session context."""
return self

async def __aexit__(self, *args: Any) -> None:
"""Exit the session context."""

def get(self, *args: Any, **kwargs: Any) -> _TimeoutResponse:
"""Return a response that times out."""
return _TimeoutResponse()


async def test_vendor_info_start_handles_dcl_timeout() -> None:
"""Test vendor info startup continues when the DCL request times out."""
server = MagicMock()
server.storage.get.return_value = {}
vendor_info = VendorInfo(server)
client_session_kwargs: dict[str, Any] = {}

def _client_session(**kwargs: Any) -> _FakeClientSession:
client_session_kwargs.update(kwargs)
return _FakeClientSession(kwargs)

with patch("matter_server.server.vendor_info.ClientSession", _client_session):
await vendor_info.start()

assert client_session_kwargs == {
"raise_for_status": True,
"timeout": DCL_REQUEST_TIMEOUT,
}
server.storage.set.assert_called_once()
storage_key, vendor_data = server.storage.set.call_args.args
assert storage_key == DATA_KEY_VENDOR_INFO
assert TEST_VENDOR.vendor_id in vendor_data
assert NABUCASA_VENDOR.vendor_id in vendor_data