Skip to content

Commit 8cb46fc

Browse files
committed
Add generic tests
1 parent d6a19c6 commit 8cb46fc

File tree

4 files changed

+276
-28
lines changed

4 files changed

+276
-28
lines changed

src/apify_client/clients/base/resource_collection_client.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncIterator, Awaitable
3+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Generator
44
from typing import Any, Generic, Protocol, TypeVar
55

66
from apify_client._utils import parse_date_fields, pluck_data
@@ -85,8 +85,10 @@ async def _list(self, **kwargs: Any) -> ListPage:
8585

8686
return ListPage(parse_date_fields(pluck_data(response.json())))
8787

88-
def _list_paginated(self, **kwargs: Any) -> ListPageProtocol:
89-
def min_for_limit_param(a: int | None, b : int| None) -> int | None:
88+
def _list_iterable(self, **kwargs: Any) -> ListPageProtocol[T]:
89+
"""Return object can be awaited or iterated over."""
90+
91+
def min_for_limit_param(a: int | None, b: int | None) -> int | None:
9092
# API treats 0 as None for limit parameter, in this context API understands 0 as infinity.
9193
if a == 0:
9294
a = None
@@ -97,32 +99,34 @@ def min_for_limit_param(a: int | None, b : int| None) -> int | None:
9799
if b is None:
98100
return a
99101
return min(a, b)
102+
100103
chunk_size = kwargs.pop('chunk_size', None)
101104

102-
list_page_getter = self._list(**{**kwargs, 'limit':min_for_limit_param(kwargs.get('limit'), chunk_size)})
105+
list_page_awaitable = self._list(**{**kwargs, 'limit': min_for_limit_param(kwargs.get('limit'), chunk_size)})
103106

104-
async def async_iterator():
105-
current_page = await list_page_getter
107+
async def async_iterator() -> AsyncIterator[T]:
108+
current_page = await list_page_awaitable
106109
for item in current_page.items:
107110
yield item
108111

109112
offset = kwargs.get('offset') or 0
110113
limit = min(kwargs.get('limit') or current_page.total, current_page.total)
111114

112115
current_offset = offset + len(current_page.items)
113-
remaining_items = min(current_page.total-offset, limit) - len(current_page.items)
114-
while (current_page.items and remaining_items > 0):
115-
new_kwargs = {**kwargs,
116-
'offset': current_offset,
117-
'limit': min_for_limit_param(remaining_items, chunk_size)}
116+
remaining_items = min(current_page.total - offset, limit) - len(current_page.items)
117+
while current_page.items and remaining_items > 0:
118+
new_kwargs = {
119+
**kwargs,
120+
'offset': current_offset,
121+
'limit': min_for_limit_param(remaining_items, chunk_size),
122+
}
118123
current_page = await self._list(**new_kwargs)
119124
for item in current_page.items:
120125
yield item
121126
current_offset += len(current_page.items)
122127
remaining_items -= len(current_page.items)
123128

124-
return ListPageIterable(list_page_getter, async_iterator())
125-
129+
return IterableListPage[T](list_page_awaitable, async_iterator())
126130

127131
async def _create(self, resource: dict) -> dict:
128132
response = await self.http_client.call(
@@ -149,25 +153,21 @@ async def _get_or_create(
149153
return parse_date_fields(pluck_data(response.json()))
150154

151155

152-
class ListPageProtocol(Protocol[T]):
153-
def __aiter__(self) -> AsyncIterator[T]: ...
154-
def __await__(self) -> ListPage[T]: ...
156+
class ListPageProtocol(Protocol[T], AsyncIterable[T], Awaitable[ListPage[T]]):
157+
"""Protocol for an object that can be both awaited and asynchronously iterated over."""
155158

156159

157-
class ListPageIterable(Generic[T]):
160+
class IterableListPage(Generic[T]):
161+
"""Can be awaited to get ListPage with items or asynchronously iterated over to get individual items."""
162+
158163
def __init__(self, awaitable: Awaitable[ListPage[T]], async_iterator: AsyncIterator[T]) -> None:
159164
self._awaitable = awaitable
160165
self._async_iterator = async_iterator
161166

162-
def __aiter__(self):
167+
def __aiter__(self) -> AsyncIterator[T]:
168+
"""Return an asynchronous iterator over the items from API, possibly doing multiple API calls."""
163169
return self._async_iterator
164170

165-
def __await__(self):
171+
def __await__(self) -> Generator[Any, Any, ListPage[T]]:
172+
"""Return an awaitable that resolves to the ListPage doing exactly one API call."""
166173
return self._awaitable.__await__()
167-
168-
169-
"""
170-
async def __anext__(self) -> T:
171-
async for item in self._async_iterator:
172-
print(item)
173-
"""

src/apify_client/clients/resource_clients/actor_collection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def list(
150150
offset: int | None = None,
151151
desc: bool | None = None,
152152
sort_by: Literal['createdAt', 'stats.lastRunStartedAt'] | None = 'createdAt',
153-
chunk_size: int | None = None,
154153
) -> ListPageProtocol[dict]:
155154
"""List the Actors the user has created or used.
156155
@@ -166,7 +165,7 @@ def list(
166165
Returns:
167166
The list of available Actors matching the specified filters.
168167
"""
169-
return self._list_paginated(my=my, limit=limit, offset=offset, desc=desc, sortBy=sort_by, chunk_size=chunk_size)
168+
return self._list_iterable(my=my, limit=limit, offset=offset, desc=desc, sortBy=sort_by)
170169

171170
async def create(
172171
self,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
from apify_client import ApifyClientAsync
4+
5+
6+
@pytest.mark.parametrize(
7+
'factory_name',
8+
[
9+
'actors',
10+
'datasets',
11+
],
12+
)
13+
async def test_client_list_iterable_total_count(apify_client_async: ApifyClientAsync, factory_name: str) -> None:
14+
"""Basic test of client list methods on real API.
15+
16+
More detailed tests are in unit tets.
17+
"""
18+
client = getattr(apify_client_async, factory_name)()
19+
list_response = await client.list()
20+
all_items = [item async for item in client.list()]
21+
assert len(all_items) == list_response.total
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import dataclasses
2+
from typing import Any, Literal
3+
from unittest import mock
4+
from unittest.mock import Mock
5+
6+
import pytest
7+
from _pytest.mark import ParameterSet
8+
9+
from apify_client import ApifyClient, ApifyClientAsync
10+
from apify_client.clients import (
11+
ActorCollectionClient,
12+
BaseClient,
13+
BaseClientAsync,
14+
BuildCollectionClient,
15+
DatasetCollectionClient,
16+
KeyValueStoreCollectionClient,
17+
RequestQueueCollectionClient,
18+
ScheduleCollectionClient,
19+
StoreCollectionClient,
20+
TaskCollectionClient,
21+
WebhookCollectionClient,
22+
WebhookDispatchCollectionClient,
23+
)
24+
25+
CollectionClient = (
26+
ActorCollectionClient
27+
| BuildCollectionClient
28+
| ScheduleCollectionClient
29+
| TaskCollectionClient
30+
| WebhookCollectionClient
31+
| WebhookDispatchCollectionClient
32+
| DatasetCollectionClient
33+
| KeyValueStoreCollectionClient
34+
| RequestQueueCollectionClient
35+
| StoreCollectionClient
36+
)
37+
38+
39+
def create_items(start: int, end: int) -> list[dict[str, int]]:
40+
step = -1 if end < start else 1
41+
return [{'id': i, 'key': i} for i in range(start, end, step)]
42+
43+
44+
def mocked_api_pagination_logic(*_: Any, **kwargs: Any) -> dict:
45+
"""This function is a placeholder representing the mocked API pagination logic.
46+
47+
It simulates paginated responses from an API only to a limited extend to test iteration logic in client.
48+
Returned items are only placeholders that enable keeping track of their index on platform.
49+
50+
There are 2500 normal items in the collection and additional 100 extra items.
51+
Items are simple objects with incrementing attributes for easy verification.
52+
"""
53+
params = kwargs.get('params', {})
54+
normal_items = 2500
55+
extra_items = 100 # additional items, for example unnamed
56+
max_items_per_page = 1000
57+
58+
total_items = (normal_items + extra_items) if params.get('unnamed') else normal_items
59+
offset = params.get('offset') or 0
60+
limit = params.get('limit') or 0
61+
assert offset >= 0, 'Invalid offset send to API'
62+
assert limit >= 0, 'Invalid limit send to API'
63+
64+
# Ordered all items in the mocked platform.
65+
items = create_items(total_items, 0) if params.get('desc', False) else create_items(0, total_items)
66+
lower_index = min(offset, total_items)
67+
upper_index = min(offset + (limit or total_items), total_items)
68+
count = min(upper_index - lower_index, max_items_per_page)
69+
70+
response = Mock()
71+
response.json = lambda: {
72+
'data': {
73+
'total': total_items,
74+
'count': count,
75+
'offset': offset,
76+
'limit': limit or count,
77+
'desc': params.get('desc', False),
78+
'items': items[lower_index : min(upper_index, lower_index + max_items_per_page)],
79+
}
80+
}
81+
82+
return response
83+
84+
85+
@dataclasses.dataclass
86+
class TestCase:
87+
id: str
88+
inputs: dict
89+
expected_items: list[dict[str, int]]
90+
supported_clients: set[str]
91+
92+
def __hash__(self) -> int:
93+
return hash(self.id)
94+
95+
def supports(self, client: BaseClient | BaseClientAsync) -> bool:
96+
return client.__class__.__name__.replace('Async', '') in self.supported_clients
97+
98+
99+
# Prepare supported testcases for different clients
100+
COLLECTION_CLIENTS = {
101+
'ActorCollectionClient',
102+
'BuildCollectionClient',
103+
'ScheduleCollectionClient',
104+
'TaskCollectionClient',
105+
'WebhookCollectionClient',
106+
'WebhookDispatchCollectionClient',
107+
'DatasetCollectionClient',
108+
'KeyValueStoreCollectionClient',
109+
'RequestQueueCollectionClient',
110+
'StoreCollectionClient',
111+
}
112+
113+
NO_OPTIONS_CLIENTS = {
114+
'ActorEnvVarCollectionClient',
115+
'ActorVersionClient',
116+
}
117+
118+
STORAGE_CLIENTS = {
119+
'DatasetClient',
120+
'KeyValueStoreClient',
121+
'RequestQueueClient',
122+
}
123+
124+
ALL_CLIENTS = COLLECTION_CLIENTS | NO_OPTIONS_CLIENTS | STORAGE_CLIENTS
125+
126+
TEST_CASES = {
127+
TestCase('No options', {}, create_items(0, 2500), ALL_CLIENTS),
128+
TestCase('Limit', {'limit': 1100}, create_items(0, 1100), ALL_CLIENTS - NO_OPTIONS_CLIENTS),
129+
TestCase('Out of range limit', {'limit': 3000}, create_items(0, 2500), ALL_CLIENTS - NO_OPTIONS_CLIENTS),
130+
TestCase('Offset', {'offset': 1000}, create_items(1000, 2500), ALL_CLIENTS - NO_OPTIONS_CLIENTS),
131+
TestCase(
132+
'Offset and limit', {'offset': 1000, 'limit': 1100}, create_items(1000, 2100), ALL_CLIENTS - NO_OPTIONS_CLIENTS
133+
),
134+
TestCase('Out of range offset', {'offset': 3000}, [], ALL_CLIENTS - NO_OPTIONS_CLIENTS),
135+
TestCase(
136+
'Offset, limit, descending',
137+
{'offset': 1000, 'limit': 1100, 'desc': True},
138+
create_items(1500, 400),
139+
ALL_CLIENTS - NO_OPTIONS_CLIENTS - {'StoreCollectionClient'},
140+
),
141+
TestCase(
142+
'Offset, limit, descending, unnamed',
143+
{'offset': 50, 'limit': 1100, 'desc': True, 'unnamed': True},
144+
create_items(2550, 1450),
145+
{'DatasetCollectionClient', 'KeyValueStoreCollectionClient', 'RequestQueueCollectionClient'},
146+
),
147+
TestCase(
148+
'Offset, limit, descending, chunkSize',
149+
{'offset': 50, 'limit': 1100, 'desc': True, 'chunk_size': 100},
150+
create_items(1500, 400),
151+
{'DatasetClient'},
152+
),
153+
TestCase('Exclusive start key', {'exclusive_start_key': 1000}, create_items(1001, 2500), {'KeyValueStoreClient'}),
154+
TestCase('Exclusive start id', {'exclusive_start_id': 1000}, create_items(1001, 2500), {'RequestQueueClient'}),
155+
}
156+
157+
158+
def generate_test_params(
159+
client_set: Literal['collection', 'kvs', 'rq', 'dataset'], *, async_clients: bool = False
160+
) -> list[ParameterSet]:
161+
# Different clients support different options and thus different scenarios
162+
client = ApifyClientAsync(token='') if async_clients else ApifyClient(token='')
163+
164+
clients: set[BaseClient | BaseClientAsync]
165+
166+
match client_set:
167+
case 'collection':
168+
clients = {
169+
client.actors(),
170+
client.schedules(),
171+
client.tasks(),
172+
client.webhooks(),
173+
client.webhook_dispatches(),
174+
client.store(),
175+
client.datasets(),
176+
client.key_value_stores(),
177+
client.request_queues(),
178+
client.actor('some-id').builds(),
179+
client.actor('some-id').versions(),
180+
client.actor('some-id').version('some-version').env_vars(),
181+
}
182+
case 'kvs':
183+
clients = {client.key_value_store('some-id')}
184+
case 'rq':
185+
clients = {client.request_queue('some-id')}
186+
case 'dataset':
187+
clients = {client.dataset('some-id')}
188+
case _:
189+
raise ValueError(f'Unknown client set: {client_set}')
190+
191+
return [
192+
pytest.param(
193+
test_case.inputs, test_case.expected_items, client, id=f'{client.__class__.__name__}:{test_case.id}'
194+
)
195+
for test_case in TEST_CASES
196+
for client in clients
197+
if test_case.supports(client)
198+
]
199+
200+
201+
@pytest.mark.parametrize(
202+
('inputs', 'expected_items', 'client'), generate_test_params(client_set='collection', async_clients=True)
203+
)
204+
async def test_client_list_iterable_async(
205+
client: CollectionClient, inputs: dict, expected_items: list[dict[str, int]]
206+
) -> None:
207+
with mock.patch.object(client.http_client, 'call', side_effect=mocked_api_pagination_logic):
208+
returned_items = [item async for item in client.list(**inputs)]
209+
210+
if inputs == {}:
211+
list_response = await client.list(**inputs)
212+
assert len(returned_items) == list_response.total
213+
214+
assert returned_items == expected_items
215+
216+
217+
@pytest.mark.parametrize(
218+
('inputs', 'expected_items', 'client'), generate_test_params(client_set='collection', async_clients=False)
219+
)
220+
def test_client_list_iterable(client: BaseClientAsync, inputs: dict, expected_items: list[dict[str, int]]) -> None:
221+
with mock.patch.object(client.http_client, 'call', side_effect=mocked_api_pagination_logic):
222+
returned_items = [item for item in client.list(**inputs)] # noqa: C416 list needed for assertion
223+
224+
if inputs == {}:
225+
list_response = client.list(**inputs)
226+
assert len(returned_items) == list_response.total
227+
228+
assert returned_items == expected_items

0 commit comments

Comments
 (0)