Skip to content

Commit 396c278

Browse files
authored
Add max_count to paginated methods (#357)
1 parent 2ce69ac commit 396c278

7 files changed

Lines changed: 208 additions & 46 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ for channel in rocket.channels_list():
5757
rocket.chat_post_message('good news everyone!', channel='GENERAL', alias='Farnsworth')
5858

5959
# Get channel history
60-
rocket.channels_history('GENERAL', count=5)
60+
rocket.channels_history('GENERAL', max_count=5)
6161
```
6262

6363
### Token-Based Authentication

rocketchat_API/APISections/base.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1+
import itertools
12
import re
2-
33
from functools import wraps
4-
54
from json import JSONDecodeError
6-
from typing import Any
7-
5+
from typing import Any, Callable, Generator
86

97
import requests
108

@@ -16,7 +14,40 @@
1614
)
1715

1816

19-
def paginated(data_key):
17+
def _paginated_generator(
18+
self,
19+
func: Callable[..., dict[str, Any]],
20+
data_key: str,
21+
first_data: dict[str, Any],
22+
offset: int,
23+
count: int,
24+
args: tuple[Any, ...],
25+
kwargs: dict[str, Any],
26+
) -> Generator[dict[str, Any], None, None]:
27+
"""Inner generator that yields items from paginated API responses."""
28+
data = first_data
29+
while True:
30+
items = data.get(data_key, [])
31+
if not items:
32+
break
33+
34+
yield from items
35+
36+
# If we got fewer items than requested, we've reached the end
37+
if len(items) < count:
38+
break
39+
40+
offset += count
41+
# Call the original function with pagination parameters
42+
data = func(self, *args, offset=offset, count=count, **kwargs)
43+
44+
45+
def paginated(
46+
data_key: str,
47+
) -> Callable[
48+
[Callable[..., dict[str, Any]]],
49+
Callable[..., Generator[dict[str, Any], None, None]],
50+
]:
2051
"""
2152
Decorator that converts a paginated API method into an iterator.
2253
@@ -28,41 +59,41 @@ def paginated(data_key):
2859
A decorator that wraps the original method to yield items one by one,
2960
automatically handling pagination with offset and count parameters.
3061
62+
Kwargs (handled by the wrapper):
63+
offset: Starting offset for pagination (default: 0)
64+
count: Number of items per page (default: 50)
65+
max_count: Maximum total number of items to return (default: None, returns all)
66+
3167
Example:
3268
@paginated('groups')
3369
def groups_list_all(self, **kwargs):
3470
return self.call_api_get("groups.listAll", kwargs=kwargs)
35-
"""
36-
37-
def decorator(func):
38-
def _generator(self, first_data, offset, count, args, kwargs):
39-
"""Inner generator that yields items from paginated API responses."""
40-
data = first_data
41-
while True:
42-
items = data.get(data_key, [])
43-
if not items:
44-
break
4571
46-
for item in items:
47-
yield item
72+
# Get all groups
73+
list(rocket.groups_list_all())
4874
49-
# If we got fewer items than requested, we've reached the end
50-
if len(items) < count:
51-
break
52-
53-
offset += count
54-
# Call the original function with pagination parameters
55-
data = func(self, *args, offset=offset, count=count, **kwargs)
75+
# Get at most 100 groups
76+
list(rocket.groups_list_all(max_count=100))
77+
"""
5678

79+
def decorator(func):
5780
@wraps(func)
5881
def wrapper(self, *args, **kwargs):
5982
offset = kwargs.pop("offset", 0)
6083
count = kwargs.pop("count", 50)
84+
max_count = kwargs.pop("max_count", None)
6185

6286
# Call the original function eagerly to propagate any exceptions
6387
first_data = func(self, *args, offset=offset, count=count, **kwargs)
6488

65-
return _generator(self, first_data, offset, count, args, kwargs)
89+
items_gen = _paginated_generator(
90+
self, func, data_key, first_data, offset, count, args, kwargs
91+
)
92+
93+
if max_count is not None:
94+
return itertools.islice(items_gen, max_count)
95+
96+
return items_gen
6697

6798
return wrapper
6899

tests/test_channels.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def test_channels_list(logged_rocket):
5252
assert "_id" in channel
5353
assert "name" in channel
5454

55-
iterated_channels_custom = list(logged_rocket.channels_list(count=1))
56-
assert len(iterated_channels_custom) > 0
55+
iterated_channels_custom = list(logged_rocket.channels_list(max_count=1))
56+
assert len(iterated_channels_custom) == 1
5757

5858
for channel in logged_rocket.channels_list():
5959
assert "_id" in channel
@@ -67,8 +67,8 @@ def test_channels_list_joined(logged_rocket):
6767
assert "_id" in channel
6868
assert "name" in channel
6969

70-
iterated_channels_custom = list(logged_rocket.channels_list_joined(count=1))
71-
assert len(iterated_channels_custom) > 0
70+
iterated_channels_custom = list(logged_rocket.channels_list_joined(max_count=1))
71+
assert len(iterated_channels_custom) == 1
7272

7373
for channel in logged_rocket.channels_list_joined():
7474
assert "_id" in channel
@@ -96,8 +96,11 @@ def test_channels_history(logged_rocket):
9696

9797
# Test with custom count parameter
9898
iterated_messages_custom = list(
99-
logged_rocket.channels_history(room_id="GENERAL", count=1)
99+
logged_rocket.channels_history(room_id="GENERAL", max_count=1)
100100
)
101+
102+
assert len(iterated_messages_custom) == 1
103+
101104
for message in iterated_messages_custom:
102105
assert "_id" in message
103106

@@ -384,9 +387,9 @@ def test_channels_members(logged_rocket):
384387

385388
# Test with custom count parameter
386389
iterated_members_custom = list(
387-
logged_rocket.channels_members(room_id="GENERAL", count=1)
390+
logged_rocket.channels_members(room_id="GENERAL", max_count=1)
388391
)
389-
assert len(iterated_members_custom) > 0
392+
assert len(iterated_members_custom) == 1
390393

391394
for member in logged_rocket.channels_members(room_id="GENERAL"):
392395
assert "_id" in member

tests/test_groups.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def test_groups_list_all(logged_rocket):
5656
assert "_id" in group
5757
assert "name" in group
5858

59-
iterated_groups_custom = list(logged_rocket.groups_list_all(count=1))
60-
assert len(iterated_groups_custom) > 0
59+
iterated_groups_custom = list(logged_rocket.groups_list_all(max_count=1))
60+
assert len(iterated_groups_custom) == 1
6161

6262
for group in logged_rocket.groups_list_all():
6363
assert "_id" in group
@@ -363,9 +363,9 @@ def test_groups_members(logged_rocket, test_group_name, test_group_id):
363363

364364
# Test with custom count parameter
365365
iterated_members_custom = list(
366-
logged_rocket.groups_members(room_id=test_group_id, count=1)
366+
logged_rocket.groups_members(room_id=test_group_id, max_count=1)
367367
)
368-
assert len(iterated_members_custom) > 0
368+
assert len(iterated_members_custom) == 1
369369

370370
with pytest.raises(RocketMissingParamException):
371371
logged_rocket.groups_members()

tests/test_paginated.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from typing import Any
2+
3+
from rocketchat_API.APISections.base import paginated
4+
5+
6+
class MockAPI:
7+
def __init__(self, total_items: int) -> None:
8+
self.total_items = total_items
9+
self.call_count = 0
10+
11+
@paginated("items")
12+
def get_items(self, **kwargs: Any) -> Any:
13+
self.call_count += 1
14+
offset = kwargs.get("offset", 0)
15+
count = kwargs.get("count", 50)
16+
start = offset
17+
end = min(offset + count, self.total_items)
18+
items = [{"id": i, "name": f"item_{i}"} for i in range(start, end)]
19+
return {"items": items, "success": True}
20+
21+
22+
def test_basic_pagination():
23+
api = MockAPI(total_items=150)
24+
result = list(api.get_items(count=50))
25+
26+
assert len(result) == 150
27+
assert api.call_count == 4
28+
29+
30+
def test_max_count_less_than_total():
31+
api = MockAPI(total_items=150)
32+
result: list[dict[str, Any]] = list(api.get_items(max_count=100))
33+
34+
assert len(result) == 100
35+
first_item: dict[str, Any] = result[0]
36+
last_item: dict[str, Any] = result[99]
37+
assert first_item.get("id") == 0
38+
assert last_item.get("id") == 99
39+
40+
41+
def test_max_count_greater_than_total():
42+
api = MockAPI(total_items=50)
43+
result = list(api.get_items(max_count=100))
44+
45+
assert len(result) == 50
46+
47+
48+
def test_max_count_exact_page_boundary():
49+
api = MockAPI(total_items=150)
50+
result = list(api.get_items(count=50, max_count=100))
51+
52+
assert len(result) == 100
53+
assert api.call_count == 2
54+
55+
56+
def test_max_count_mid_page():
57+
api = MockAPI(total_items=150)
58+
result = list(api.get_items(count=50, max_count=75))
59+
60+
assert len(result) == 75
61+
assert api.call_count == 2
62+
63+
64+
def test_max_count_one():
65+
api = MockAPI(total_items=150)
66+
result: list[dict[str, Any]] = list(api.get_items(max_count=1))
67+
68+
assert len(result) == 1
69+
first_item: dict[str, Any] = result[0]
70+
assert first_item.get("id") == 0
71+
assert api.call_count == 1
72+
73+
74+
def test_max_count_zero_returns_empty():
75+
api = MockAPI(total_items=150)
76+
result = list(api.get_items(max_count=0))
77+
78+
assert len(result) == 0
79+
80+
81+
def test_max_count_none_returns_all():
82+
api = MockAPI(total_items=75)
83+
result = list(api.get_items(count=50))
84+
85+
assert len(result) == 75
86+
assert api.call_count == 2
87+
88+
89+
def test_offset_with_max_count():
90+
api = MockAPI(total_items=150)
91+
result: list[dict[str, Any]] = list(api.get_items(offset=50, max_count=50))
92+
93+
assert len(result) == 50
94+
first_item: dict[str, Any] = result[0]
95+
last_item: dict[str, Any] = result[49]
96+
assert first_item.get("id") == 50
97+
assert last_item.get("id") == 99
98+
99+
100+
def test_custom_count_with_max_count():
101+
api = MockAPI(total_items=100)
102+
result = list(api.get_items(count=10, max_count=25))
103+
104+
assert len(result) == 25
105+
assert api.call_count == 3
106+
107+
108+
def test_generator_behavior():
109+
api = MockAPI(total_items=100)
110+
gen = api.get_items(max_count=10)
111+
112+
assert hasattr(gen, "__iter__")
113+
assert hasattr(gen, "__next__")
114+
115+
items = []
116+
for item in gen:
117+
items.append(item)
118+
119+
assert len(items) == 10
120+
121+
122+
def test_empty_response():
123+
api = MockAPI(total_items=0)
124+
result = list(api.get_items())
125+
126+
assert len(result) == 0
127+
assert api.call_count == 1
128+
129+
130+
def test_empty_response_with_max_count():
131+
api = MockAPI(total_items=0)
132+
result = list(api.get_items(max_count=100))
133+
134+
assert len(result) == 0
135+
assert api.call_count == 1

tests/test_rooms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def test_rooms_admin_rooms(logged_rocket):
7171
assert "t" in room
7272

7373
# Test with custom count parameter
74-
iterated_rooms_custom = list(logged_rocket.rooms_admin_rooms(count=1))
75-
assert len(iterated_rooms_custom) > 0
74+
iterated_rooms_custom = list(logged_rocket.rooms_admin_rooms(max_count=1))
75+
assert len(iterated_rooms_custom) == 1
7676

7777
rooms_with_filter = list(logged_rocket.rooms_admin_rooms(filter="general"))
7878
assert len(rooms_with_filter) == 1

tests/test_user.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,6 @@ def test_users_list(logged_rocket):
174174
assert "_id" in user
175175
assert "username" in user
176176

177-
iterated_users_custom = list(logged_rocket.users_list(count=1))
178-
assert len(iterated_users_custom) > 0
179-
assert len(iterated_users_custom) == len(iterated_users)
180-
181-
for user in logged_rocket.users_list():
182-
assert "_id" in user
183-
184177

185178
def test_users_set_status(logged_rocket):
186179
logged_rocket.users_set_status(message="working on it", status="online")

0 commit comments

Comments
 (0)