Skip to content

Commit 1bbb651

Browse files
committed
Draft of use_state as input argument
Only for discussion, types ignored for now.
1 parent 31e16d2 commit 1bbb651

5 files changed

Lines changed: 66 additions & 49 deletions

File tree

src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@
4343

4444
from typing_extensions import Unpack
4545

46-
from crawlee.crawlers._basic._basic_crawler import _BasicCrawlerOptions
47-
46+
from crawlee.crawlers._basic._basic_crawler import _BasicCrawlerOptions, _DefaultUseState
4847

4948
TStaticParseResult = TypeVar('TStaticParseResult')
5049
TStaticSelectResult = TypeVar('TStaticSelectResult')
@@ -389,7 +388,8 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
389388
# (This static crawl is performed only to evaluate rendering type detection.)
390389
kvs = await context.get_key_value_store()
391390
default_value = dict[str, JsonSerializable]()
392-
old_state: dict[str, JsonSerializable] = await kvs.get_value(self._CRAWLEE_STATE_KEY, default_value)
391+
# This was fragile even before. Out of scope for this draft.
392+
old_state: dict[str, JsonSerializable] = await kvs.get_value(_DefaultUseState._CRAWLEE_STATE_KEY, default_value)
393393
old_state_copy = deepcopy(old_state)
394394

395395
pw_run = await self._crawl_one('client only', context=context)

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import threading
1111
import traceback
1212
from asyncio import CancelledError
13-
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
13+
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterable, Sequence
1414
from contextlib import AsyncExitStack, suppress
1515
from datetime import timedelta
1616
from functools import partial
@@ -42,6 +42,7 @@
4242
RequestHandlerRunResult,
4343
SendRequestFunction,
4444
SkippedReason,
45+
UseStateFunction,
4546
)
4647
from crawlee._utils.docs import docs_group
4748
from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream
@@ -239,6 +240,29 @@ class BasicCrawlerOptions(
239240
"""
240241

241242

243+
class _DefaultUseState:
244+
_next_state_id = 0
245+
_CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'
246+
247+
def __init__(self, get_key_value_store: Awaitable[KeyValueStore]) -> None:
248+
self._get_key_value_store = get_key_value_store
249+
self._id = self._next_state_id
250+
_DefaultUseState._next_state_id += 1
251+
252+
async def _use_state(
253+
self,
254+
default_value: dict[str, JsonSerializable] | None = None,
255+
) -> dict[str, JsonSerializable]:
256+
kvs = await self._get_key_value_store()
257+
return await kvs.get_auto_saved_value(f'{self._CRAWLEE_STATE_KEY}_{self._id}', default_value)
258+
259+
def __call__(
260+
self,
261+
default_value: dict[str, JsonSerializable] | None = None,
262+
) -> Coroutine[None, None, dict[str, JsonSerializable]]:
263+
return self._use_state(default_value)
264+
265+
242266
@docs_group('Crawlers')
243267
class BasicCrawler(Generic[TCrawlingContext, TStatisticsState]):
244268
"""A basic web crawler providing a framework for crawling websites.
@@ -264,7 +288,6 @@ class BasicCrawler(Generic[TCrawlingContext, TStatisticsState]):
264288
- and more.
265289
"""
266290

267-
_CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'
268291
_request_handler_timeout_text = 'Request handler timed out after'
269292
__next_id = 0
270293

@@ -298,7 +321,7 @@ def __init__(
298321
status_message_logging_interval: timedelta = timedelta(seconds=10),
299322
status_message_callback: Callable[[StatisticsState, StatisticsState | None, str], Awaitable[str | None]]
300323
| None = None,
301-
crawler_id: int | None = None,
324+
use_state: UseStateFunction | None = None,
302325
_context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
303326
_additional_context_managers: Sequence[AbstractAsyncContextManager] | None = None,
304327
_logger: logging.Logger | None = None,
@@ -351,22 +374,15 @@ def __init__(
351374
status_message_logging_interval: Interval for logging the crawler status messages.
352375
status_message_callback: Allows overriding the default status message. The default status message is
353376
provided in the parameters. Returning `None` suppresses the status message.
354-
crawler_id: Id of the crawler used for state and statistics tracking. You can use same explicit id to share
355-
state and statistics between two crawlers. By default, each crawler will use own state and statistics.
377+
use_state: Callback used to access shared state. Use only for a custom state implementation, for example when
378+
you want to share state between two different crawlers.
356379
_context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
357380
Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
358381
_additional_context_managers: Additional context managers used throughout the crawler lifecycle.
359382
Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
360383
_logger: A logger instance, typically provided by a subclass, for consistent logging labels.
361384
Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
362385
"""
363-
if crawler_id is None:
364-
# This could look into the set of already used ids, but let's not overengineer this.
365-
self.id = BasicCrawler.__next_id
366-
BasicCrawler.__next_id += 1
367-
else:
368-
self.id = crawler_id
369-
370386
implicit_event_manager_with_explicit_config = False
371387
if not configuration:
372388
configuration = service_locator.get_configuration()
@@ -442,6 +458,12 @@ def __init__(
442458
self._use_session_pool = use_session_pool
443459
self._retry_on_blocked = retry_on_blocked
444460

461+
# Set use state
462+
if use_state:
463+
self._use_state = use_state
464+
else:
465+
self._use_state = _DefaultUseState(get_key_value_store=self.get_key_value_store)
466+
445467
# Logging setup
446468
if configure_logging:
447469
root_logger = logging.getLogger()
@@ -837,13 +859,6 @@ async def add_requests(
837859
wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
838860
)
839861

840-
async def _use_state(
841-
self,
842-
default_value: dict[str, JsonSerializable] | None = None,
843-
) -> dict[str, JsonSerializable]:
844-
kvs = await self.get_key_value_store()
845-
return await kvs.get_auto_saved_value(f'{self._CRAWLEE_STATE_KEY}_{self.id}', default_value)
846-
847862
async def _save_crawler_state(self) -> None:
848863
store = await self.get_key_value_store()
849864
await store.persist_autosaved_values()

tests/unit/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from uvicorn.config import Config
1313

1414
from crawlee import service_locator
15-
from crawlee.crawlers import BasicCrawler
15+
from crawlee.crawlers._basic._basic_crawler import _DefaultUseState
1616
from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network
1717
from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient, ImpitHttpClient
1818
from crawlee.proxy_configuration import ProxyInfo
@@ -75,7 +75,7 @@ def _prepare_test_env() -> None:
7575
# Reset global class variables to ensure test isolation.
7676
KeyValueStore._autosaved_values = {}
7777
Statistics._Statistics__next_id = 0 # type:ignore[attr-defined] # Mangled attribute
78-
BasicCrawler._BasicCrawler__next_id = 0 # type:ignore[attr-defined] # Mangled attribute
78+
_DefaultUseState._next_state_id = 0  # type:ignore[attr-defined]  # Private attribute (single underscore, not name-mangled)
7979

8080
return _prepare_test_env
8181

tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
AdaptivePlaywrightCrawler,
2020
AdaptivePlaywrightCrawlingContext,
2121
AdaptivePlaywrightPreNavCrawlingContext,
22-
BasicCrawler,
2322
RenderingType,
2423
RenderingTypePrediction,
2524
RenderingTypePredictor,
@@ -381,7 +380,7 @@ async def test_adaptive_crawling_result_use_state_isolation(
381380
crawler = AdaptivePlaywrightCrawler.with_beautifulsoup_static_parser(
382381
rendering_type_predictor=static_only_predictor_enforce_detection,
383382
)
384-
await key_value_store.set_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0', {'counter': 0})
383+
await key_value_store.set_value('CRAWLEE_STATE_0', {'counter': 0})
385384
request_handler_calls = 0
386385

387386
@crawler.router.default_handler
@@ -398,7 +397,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
398397
# Request handler was called twice
399398
assert request_handler_calls == 2
400399
# Increment of global state happened only once
401-
assert (await key_value_store.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0'))['counter'] == 1
400+
assert (await key_value_store.get_value('CRAWLEE_STATE_0'))['counter'] == 1
402401

403402

404403
async def test_adaptive_crawling_statistics(test_urls: list[str]) -> None:

tests/unit/crawlers/_basic/test_basic_crawler.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -810,36 +810,39 @@ async def handler(context: BasicCrawlingContext) -> None:
810810
await crawler.run(['https://hello.world'])
811811

812812
kvs = await crawler.get_key_value_store()
813-
value = await kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0')
813+
value = await kvs.get_value('CRAWLEE_STATE_0')
814814

815815
assert value == {'hello': 'world'}
816816

817817

818-
async def test_context_use_state_crawlers_share_state() -> None:
818+
async def test_context_use_state_crawlers_share_custom_state() -> None:
819+
custom_state_dict = {}
820+
821+
async def custom_use_state(default_state: dict[str, JsonSerializable]) -> dict[str, JsonSerializable]:
822+
if not custom_state_dict:
823+
custom_state_dict.update(default_state)
824+
return custom_state_dict
825+
819826
async def handler(context: BasicCrawlingContext) -> None:
820827
state = await context.use_state({'urls': []})
821828
assert isinstance(state['urls'], list)
822829
state['urls'].append(context.request.url)
823830

824-
crawler_1 = BasicCrawler(crawler_id=0, request_handler=handler)
825-
crawler_2 = BasicCrawler(crawler_id=0, request_handler=handler)
831+
crawler_1 = BasicCrawler(use_state=custom_use_state, request_handler=handler)
832+
crawler_2 = BasicCrawler(use_state=custom_use_state, request_handler=handler)
826833

827834
await crawler_1.run(['https://a.com'])
828835
await crawler_2.run(['https://b.com'])
829836

830-
kvs = await KeyValueStore.open()
831-
assert crawler_1.id == crawler_2.id == 0
832-
assert await kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_{crawler_1.id}') == {
833-
'urls': ['https://a.com', 'https://b.com']
834-
}
837+
assert custom_state_dict == {'urls': ['https://a.com', 'https://b.com']}
835838

836839

837840
async def test_crawlers_share_stats() -> None:
838841
async def handler(context: BasicCrawlingContext) -> None:
839842
await context.use_state({'urls': []})
840843

841-
crawler_1 = BasicCrawler(crawler_id=0, request_handler=handler)
842-
crawler_2 = BasicCrawler(crawler_id=0, request_handler=handler, statistics=crawler_1.statistics)
844+
crawler_1 = BasicCrawler(request_handler=handler)
845+
crawler_2 = BasicCrawler(request_handler=handler, statistics=crawler_1.statistics)
843846

844847
result1 = await crawler_1.run(['https://a.com'])
845848
result2 = await crawler_2.run(['https://b.com'])
@@ -862,8 +865,8 @@ async def handler(context: BasicCrawlingContext) -> None:
862865
await crawler_2.run(['https://b.com'])
863866

864867
kvs = await KeyValueStore.open()
865-
assert await kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0') == {'urls': ['https://a.com']}
866-
assert await kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_1') == {'urls': ['https://b.com']}
868+
assert await kvs.get_value('CRAWLEE_STATE_0') == {'urls': ['https://a.com']}
869+
assert await kvs.get_value('CRAWLEE_STATE_1') == {'urls': ['https://b.com']}
867870

868871

869872
async def test_context_handlers_use_state(key_value_store: KeyValueStore) -> None:
@@ -906,7 +909,7 @@ async def handler_three(context: BasicCrawlingContext) -> None:
906909
store = await crawler.get_key_value_store()
907910

908911
# The state in the KVS must match with the last set state
909-
assert (await store.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0')) == {'hello': 'last_world'}
912+
assert (await store.get_value('CRAWLEE_STATE_0')) == {'hello': 'last_world'}
910913

911914

912915
async def test_max_requests_per_crawl() -> None:
@@ -1334,7 +1337,7 @@ async def test_context_use_state_race_condition_in_handlers(key_value_store: Key
13341337

13351338
crawler = BasicCrawler()
13361339
store = await crawler.get_key_value_store()
1337-
await store.set_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0', {'counter': 0})
1340+
await store.set_value('CRAWLEE_STATE_0', {'counter': 0})
13381341
handler_barrier = Barrier(2)
13391342

13401343
@crawler.router.default_handler
@@ -1349,7 +1352,7 @@ async def handler(context: BasicCrawlingContext) -> None:
13491352
store = await crawler.get_key_value_store()
13501353
# Ensure that local state is pushed back to kvs.
13511354
await store.persist_autosaved_values()
1352-
assert (await store.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0'))['counter'] == 2
1355+
assert (await store.get_value('CRAWLEE_STATE_0'))['counter'] == 2
13531356

13541357

13551358
@pytest.mark.run_alone
@@ -1859,7 +1862,7 @@ async def test_crawler_state_persistence(tmp_path: Path) -> None:
18591862
).result()[0]
18601863
# Expected state after first crawler run
18611864
assert first_run_state.requests_finished == 2
1862-
state = await state_kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0')
1865+
state = await state_kvs.get_value('CRAWLEE_STATE_0')
18631866
assert state.get('urls') == ['https://a.placeholder.com', 'https://b.placeholder.com']
18641867

18651868
# Do not reuse the executor to simulate a fresh process to avoid modified class attributes.
@@ -1875,7 +1878,7 @@ async def test_crawler_state_persistence(tmp_path: Path) -> None:
18751878
# 2 requests from first run and 1 request from second run.
18761879
assert second_run_state.requests_finished == 3
18771880

1878-
state = await state_kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0')
1881+
state = await state_kvs.get_value('CRAWLEE_STATE_0')
18791882
assert state.get('urls') == [
18801883
'https://a.placeholder.com',
18811884
'https://b.placeholder.com',
@@ -1912,9 +1915,9 @@ async def test_crawler_state_persistence_2_crawlers_with_migration(tmp_path: Pat
19121915
# Expected state after first crawler run
19131916
assert first_run_states[0].requests_finished == 1
19141917
assert first_run_states[1].requests_finished == 1
1915-
state_0 = await state_kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0')
1918+
state_0 = await state_kvs.get_value('CRAWLEE_STATE_0')
19161919
assert state_0.get('urls') == ['https://a.placeholder.com']
1917-
state_1 = await state_kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_1')
1920+
state_1 = await state_kvs.get_value('CRAWLEE_STATE_1')
19181921
assert state_1.get('urls') == ['https://c.placeholder.com']
19191922

19201923
with ProcessPoolExecutor() as executor:
@@ -1930,9 +1933,9 @@ async def test_crawler_state_persistence_2_crawlers_with_migration(tmp_path: Pat
19301933
# Expected state after first crawler run
19311934
assert second_run_states[0].requests_finished == 2
19321935
assert second_run_states[1].requests_finished == 2
1933-
state_0 = await state_kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_0')
1936+
state_0 = await state_kvs.get_value('CRAWLEE_STATE_0')
19341937
assert state_0.get('urls') == ['https://a.placeholder.com', 'https://b.placeholder.com']
1935-
state_1 = await state_kvs.get_value(f'{BasicCrawler._CRAWLEE_STATE_KEY}_1')
1938+
state_1 = await state_kvs.get_value('CRAWLEE_STATE_1')
19361939
assert state_1.get('urls') == ['https://c.placeholder.com', 'https://d.placeholder.com']
19371940

19381941

0 commit comments

Comments
 (0)