-
Notifications
You must be signed in to change notification settings - Fork 709
Expand file tree
/
Copy path_dataset_client.py
More file actions
335 lines (272 loc) · 11.9 KB
/
_dataset_client.py
File metadata and controls
335 lines (272 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
from __future__ import annotations
from logging import getLogger
from typing import TYPE_CHECKING, Any, cast
from redis.exceptions import RedisError
from typing_extensions import NotRequired, override
from crawlee.errors import StorageWriteError
from crawlee.storage_clients._base import DatasetClient
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata
from ._client_mixin import MetadataUpdateParams, RedisClientMixin
from ._utils import await_redis_response
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from redis.asyncio import Redis
from redis.asyncio.client import Pipeline
logger = getLogger(__name__)
class _DatasetMetadataUpdateParams(MetadataUpdateParams):
    """Keyword arguments understood by the dataset client's metadata update."""

    # Absolute value to store as the dataset's item count.
    new_item_count: NotRequired[int]

    # Amount by which the stored item count is incremented.
    delta_item_count: NotRequired[int]
class RedisDatasetClient(DatasetClient, RedisClientMixin):
    """Redis implementation of the dataset client.

    This client persists dataset items to Redis as a single RedisJSON array. Items are appended to the end of
    the array, so insertion order is preserved.

    The dataset data is stored in Redis using the following key pattern:
    - `datasets:{name}:items` - Redis JSON array containing all dataset items.
    - `datasets:{name}:metadata` - Redis JSON object containing dataset metadata.

    Items must be JSON-serializable dictionaries. Single items or lists of items can be pushed to the dataset.
    Write operations (appending items, updating metadata) are issued through the pipeline returned by
    `_get_pipeline`.
    """

    _DEFAULT_NAME = 'default'
    """Default Dataset name key prefix when none provided."""

    _MAIN_KEY = 'datasets'
    """Main Redis key prefix for Dataset."""

    _CLIENT_TYPE = 'Dataset'
    """Human-readable client type for error messages."""

    def __init__(self, storage_name: str, storage_id: str, redis: Redis) -> None:
        """Initialize a new instance.

        Preferably use the `RedisDatasetClient.open` class method to create a new instance.

        Args:
            storage_name: Internal storage name used for Redis keys.
            storage_id: Unique identifier for the dataset.
            redis: Redis client instance.
        """
        super().__init__(storage_name=storage_name, storage_id=storage_id, redis=redis)

    @property
    def _items_key(self) -> str:
        """Return the Redis key for the items of this dataset."""
        return f'{self._MAIN_KEY}:{self._storage_name}:items'

    @classmethod
    async def open(
        cls,
        *,
        id: str | None,
        name: str | None,
        alias: str | None,
        redis: Redis,
    ) -> RedisDatasetClient:
        """Open or create a new Redis dataset client.

        This method attempts to open an existing dataset from the Redis database. If a dataset with the specified
        ID or name exists, it loads the metadata from the database. If no existing store is found, a new one
        is created.

        Args:
            id: The ID of the dataset. If not provided, a random ID will be generated.
            name: The name of the dataset for named (global scope) storages.
            alias: The alias of the dataset for unnamed (run scope) storages.
            redis: Redis client instance.

        Returns:
            An instance for the opened or created storage client.
        """
        return await cls._open(
            id=id,
            name=name,
            alias=alias,
            redis=redis,
            metadata_model=DatasetMetadata,
            extra_metadata_fields={'item_count': 0},
            instance_kwargs={},
        )

    @override
    async def get_metadata(self) -> DatasetMetadata:
        return await self._get_metadata(DatasetMetadata)

    @override
    async def drop(self) -> None:
        await self._drop(extra_keys=[self._items_key])

    @override
    async def purge(self) -> None:
        await self._purge(
            extra_keys=[self._items_key],
            metadata_kwargs=_DatasetMetadataUpdateParams(
                new_item_count=0, update_accessed_at=True, update_modified_at=True
            ),
        )

    @override
    async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
        """Append one item or a list of items to the end of the dataset.

        Args:
            data: A single item or a list of items. An empty list appends nothing but still
                refreshes the metadata timestamps.

        Raises:
            StorageWriteError: If Redis reports an error while writing.
        """
        items = [data] if isinstance(data, dict) else data

        try:
            async with self._get_pipeline() as pipe:
                # JSON.ARRAPPEND needs at least one value; sending it with none would fail the whole write.
                if items:
                    pipe.json().arrappend(self._items_key, '$', *items)
                await self._update_metadata(
                    pipe,
                    **_DatasetMetadataUpdateParams(
                        update_accessed_at=True, update_modified_at=True, delta_item_count=len(items)
                    ),
                )
        except RedisError as e:
            raise StorageWriteError(e) from e

    @override
    async def get_data(
        self,
        *,
        offset: int = 0,
        limit: int | None = 999_999_999_999,
        clean: bool = False,
        desc: bool = False,
        fields: list[str] | None = None,
        omit: list[str] | None = None,
        unwind: list[str] | None = None,
        skip_empty: bool = False,
        skip_hidden: bool = False,
        flatten: list[str] | None = None,
        view: str | None = None,
    ) -> DatasetItemsListPage:
        """Return one page of dataset items.

        Only `offset`, `limit`, `desc` and `skip_empty` are honoured; the remaining arguments are
        accepted for interface compatibility and only trigger a warning when set.

        Args:
            offset: Number of items to skip (counted from the end when `desc` is true).
            limit: Maximum number of items to return, or `None` for no limit.
            desc: If true, return items newest-first.
            skip_empty: If true, drop empty items from the result.

        Returns:
            The requested page together with the total item count taken from the metadata.
        """
        self._log_unsupported_args(
            'get_data',
            {
                'clean': clean,
                'fields': fields,
                'omit': omit,
                'unwind': unwind,
                'skip_hidden': skip_hidden,
                'flatten': flatten,
                'view': view,
            },
        )

        metadata = await self.get_metadata()
        total = metadata.item_count

        json_path = self._items_json_path(desc=desc, offset=offset, limit=limit)

        data: list[Any] = []
        if json_path is not None:
            fetched = await await_redis_response(self._redis.json().get(self._items_key, json_path))
            if fetched is not None:
                data = fetched

        data = [item for item in data if isinstance(item, dict)]

        if skip_empty:
            data = [item for item in data if item]

        # The slice is fetched in storage order; flip it for newest-first output.
        if desc:
            data = list(reversed(data))

        async with self._get_pipeline() as pipe:
            await self._update_metadata(pipe, **_DatasetMetadataUpdateParams(update_accessed_at=True))

        # `limit=0` is a valid explicit limit, so test against None instead of truthiness, and never
        # report a negative limit when the offset lies past the end of the dataset.
        page_limit = limit if limit is not None else max(total - offset, 0)

        return DatasetItemsListPage(
            count=len(data),
            offset=offset,
            limit=page_limit,
            total=total,
            desc=desc,
            items=data,
        )

    @override
    async def iterate_items(
        self,
        *,
        offset: int = 0,
        limit: int | None = None,
        clean: bool = False,
        desc: bool = False,
        fields: list[str] | None = None,
        omit: list[str] | None = None,
        unwind: list[str] | None = None,
        skip_empty: bool = False,
        skip_hidden: bool = False,
    ) -> AsyncIterator[dict[str, Any]]:
        """Iterate over dataset items one by one.

        Items are fetched from Redis in batches and yielded individually, so the whole dataset is never
        held in memory at once. The item count is read once up front; the iteration range is based on it.
        """
        self._log_unsupported_args(
            'iterate_items',
            {
                'clean': clean,
                'fields': fields,
                'omit': omit,
                'unwind': unwind,
                'skip_hidden': skip_hidden,
            },
        )

        metadata = await self.get_metadata()
        total_items = metadata.item_count

        # Calculate actual range based on parameters
        start_idx = offset
        end_idx = min(total_items, offset + limit) if limit is not None else total_items

        # Update accessed_at timestamp
        async with self._get_pipeline() as pipe:
            await self._update_metadata(pipe, **_DatasetMetadataUpdateParams(update_accessed_at=True))

        # Process items in batches for better network efficiency
        batch_size = 100

        for batch_start in range(start_idx, end_idx, batch_size):
            batch_end = min(batch_start + batch_size, end_idx)

            if desc:
                # Mirror the window around the end of the array: logical position p (newest-first)
                # corresponds to storage index total_items - 1 - p.
                json_path = f'$[{total_items - batch_end}:{total_items - batch_start}]'
            else:
                json_path = f'$[{batch_start}:{batch_end}]'

            batch_items = await await_redis_response(self._redis.json().get(self._items_key, json_path))

            if batch_items is None:
                continue

            # Batches arrive in storage order; reverse each one when newest-first output is requested.
            items_iter = reversed(batch_items) if desc else iter(batch_items)

            for item in items_iter:
                if skip_empty and not item:
                    continue

                yield cast('dict[str, Any]', item)

        # Only reached when the consumer exhausts the iterator: record the access time at completion.
        async with self._get_pipeline() as pipe:
            await self._update_metadata(pipe, **_DatasetMetadataUpdateParams(update_accessed_at=True))

    @override
    async def _create_storage(self, pipeline: Pipeline) -> None:
        """Create the main dataset keys in Redis."""
        # Create an empty JSON array for items
        await await_redis_response(pipeline.json().set(self._items_key, '$', []))

    @override
    async def _specific_update_metadata(
        self,
        pipeline: Pipeline,
        *,
        new_item_count: int | None = None,
        delta_item_count: int | None = None,
        **_kwargs: Any,
    ) -> None:
        """Update the dataset metadata in the database.

        `new_item_count` takes precedence: when both values are given, only the absolute count is written.

        Args:
            pipeline: The Redis pipeline to use for the update.
            new_item_count: If provided, update the item count to this value.
            delta_item_count: If provided, increment the item count by this value.
        """
        if new_item_count is not None:
            await await_redis_response(
                pipeline.json().set(self.metadata_key, '$.item_count', new_item_count, nx=False, xx=True)
            )
        elif delta_item_count is not None:
            await await_redis_response(pipeline.json().numincrby(self.metadata_key, '$.item_count', delta_item_count))

    def _log_unsupported_args(self, method_name: str, candidates: dict[str, Any]) -> None:
        """Log a warning naming every argument in `candidates` that was set to a non-default value.

        Args:
            method_name: Name of the public method, used verbatim in the warning text.
            candidates: Mapping of argument name to the value the caller passed.
        """
        unsupported = {k: v for k, v in candidates.items() if v not in (False, None)}

        if unsupported:
            logger.warning(
                f'The arguments {list(unsupported.keys())} of {method_name} are not supported '
                f'by the {self.__class__.__name__} client.'
            )

    @staticmethod
    def _items_json_path(*, desc: bool, offset: int, limit: int | None) -> str | None:
        """Build the JSONPath slice selecting the requested window of the items array.

        For `desc=True` the window is counted from the end of the array using negative indices; the
        caller is expected to reverse the fetched items.

        Args:
            desc: Whether the window is counted from the end of the array.
            offset: Number of items to skip.
            limit: Maximum number of items in the window, or `None` for no limit.

        Returns:
            The JSONPath expression, or `None` when the window is empty by definition (`limit == 0`)
            and no query is needed.
        """
        if limit == 0:
            # `-0` equals `0` in a slice, so `[-0:]` would select everything; short-circuit instead.
            return None

        if desc:
            if offset == 0:
                # `[:-0]` would be `[:0]` (nothing); with no offset and no limit select every element.
                return '$[*]' if limit is None else f'$[-{limit}:]'
            if limit is None:
                return f'$[:-{offset}]'
            return f'$[-{offset + limit}:-{offset}]'

        if limit is None:
            return f'$[{offset}:]'
        if offset == 0:
            return f'$[:{limit}]'
        return f'$[{offset}:{offset + limit}]'