Merged · Changes from 3 commits
@@ -43,8 +43,7 @@

 from typing_extensions import Unpack

-from crawlee.crawlers._basic._basic_crawler import _BasicCrawlerOptions
-
+from crawlee.crawlers._basic._basic_crawler import _BasicCrawlerOptions, _DefaultUseState

 TStaticParseResult = TypeVar('TStaticParseResult')
 TStaticSelectResult = TypeVar('TStaticSelectResult')
@@ -389,7 +388,8 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
         # (This static crawl is performed only to evaluate rendering type detection.)
         kvs = await context.get_key_value_store()
         default_value = dict[str, JsonSerializable]()
-        old_state: dict[str, JsonSerializable] = await kvs.get_value(self._CRAWLEE_STATE_KEY, default_value)
+        # This was fragile even before. Out of scope for draft.
+        old_state: dict[str, JsonSerializable] = await kvs.get_value(_DefaultUseState._CRAWLEE_STATE_KEY, default_value)
         old_state_copy = deepcopy(old_state)

         pw_run = await self._crawl_one('client only', context=context)
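Context for the draft comment above: the new default implementation introduced below stores state under a per-instance key of the form CRAWLEE_STATE_<id>, while this snapshot still reads the bare CRAWLEE_STATE key, so the two can now diverge — presumably the fragility the comment acknowledges.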
44 changes: 35 additions & 9 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -10,7 +10,7 @@
 import threading
 import traceback
 from asyncio import CancelledError
-from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
+from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterable, Sequence
 from contextlib import AsyncExitStack, suppress
 from datetime import timedelta
 from functools import partial
@@ -42,6 +42,7 @@
     RequestHandlerRunResult,
     SendRequestFunction,
     SkippedReason,
+    UseStateFunction,
 )
 from crawlee._utils.docs import docs_group
 from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream
@@ -239,6 +240,29 @@ class BasicCrawlerOptions(
     """


+class _DefaultUseState:
+    _next_state_id = 0
+    _CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'
+
+    def __init__(self, get_key_value_store: Callable[[], Awaitable[KeyValueStore]]) -> None:
+        self._get_key_value_store = get_key_value_store
+        self._id = self._next_state_id
+        _DefaultUseState._next_state_id += 1
+
+    async def _use_state(
+        self,
+        default_value: dict[str, JsonSerializable] | None = None,
+    ) -> dict[str, JsonSerializable]:
+        kvs = await self._get_key_value_store()
+        return await kvs.get_auto_saved_value(f'{self._CRAWLEE_STATE_KEY}_{self._id}', default_value)
+
+    def __call__(
+        self,
+        default_value: dict[str, JsonSerializable] | None = None,
+    ) -> Coroutine[None, None, dict[str, JsonSerializable]]:
+        return self._use_state(default_value)
+
+
 @docs_group('Crawlers')
 class BasicCrawler(Generic[TCrawlingContext, TStatisticsState]):
     """A basic web crawler providing a framework for crawling websites.
@@ -264,8 +288,8 @@ class BasicCrawler(Generic[TCrawlingContext, TStatisticsState]):
        - and more.
    """

-    _CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'
     _request_handler_timeout_text = 'Request handler timed out after'
+    __next_id = 0

     def __init__(
         self,
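Two counters coexist after this hunk: BasicCrawler.__next_id (name-mangled to _BasicCrawler__next_id) and _DefaultUseState._next_state_id; only the latter feeds the default state keys, via the _id captured in _DefaultUseState.__init__.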
@@ -297,6 +321,7 @@ def __init__(
         status_message_logging_interval: timedelta = timedelta(seconds=10),
         status_message_callback: Callable[[StatisticsState, StatisticsState | None, str], Awaitable[str | None]]
         | None = None,
+        use_state: UseStateFunction | None = None,
         _context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
         _additional_context_managers: Sequence[AbstractAsyncContextManager] | None = None,
         _logger: logging.Logger | None = None,
@@ -349,6 +374,8 @@ def __init__(
             status_message_logging_interval: Interval for logging the crawler status messages.
             status_message_callback: Allows overriding the default status message. The default status message is
                 provided in the parameters. Returning `None` suppresses the status message.
+            use_state: Callback used to access shared state. Use this only for a custom state implementation,
+                for example when you want to share state between two different crawlers.
             _context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
                 Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
             _additional_context_managers: Additional context managers used throughout the crawler lifecycle.
@@ -431,6 +458,12 @@ def __init__(
         self._use_session_pool = use_session_pool
         self._retry_on_blocked = retry_on_blocked

+        # Set up the state accessor: a user-supplied callback, or the per-instance default.
+        if use_state:
+            self._use_state = use_state
+        else:
+            self._use_state = _DefaultUseState(get_key_value_store=self.get_key_value_store)
+
         # Logging setup
         if configure_logging:
             root_logger = logging.getLogger()
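Conversely, when state should be shared, the same callback can be handed to several crawlers. A hedged sketch of such a callback — it only mirrors the call shape used above (optional default value, awaitable returning the state dict) and keeps state in a plain process-local dict rather than a KeyValueStore; crawler_a and crawler_b are hypothetical:

# Illustration only: a process-local shared-state callback with the same call
# shape as the default (optional default value, awaitable returning the dict).
_shared_state: dict[str, object] = {}

async def shared_use_state(
    default_value: dict[str, object] | None = None,
) -> dict[str, object]:
    # Seed once from the first caller's default; every caller gets the same object.
    if not _shared_state and default_value is not None:
        _shared_state.update(default_value)
    return _shared_state

# Hypothetical wiring: both crawlers would then read and mutate one dict.
# crawler_a = BasicCrawler(use_state=shared_use_state)
# crawler_b = BasicCrawler(use_state=shared_use_state)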
@@ -826,13 +859,6 @@ async def add_requests(
             wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
         )

-    async def _use_state(
-        self,
-        default_value: dict[str, JsonSerializable] | None = None,
-    ) -> dict[str, JsonSerializable]:
-        kvs = await self.get_key_value_store()
-        return await kvs.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value)
-
     async def _save_crawler_state(self) -> None:
         store = await self.get_key_value_store()
         await store.persist_autosaved_values()
2 changes: 2 additions & 0 deletions tests/unit/conftest.py
@@ -12,6 +12,7 @@
 from uvicorn.config import Config

 from crawlee import service_locator
+from crawlee.crawlers._basic._basic_crawler import _DefaultUseState
 from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network
 from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient, ImpitHttpClient
 from crawlee.proxy_configuration import ProxyInfo
@@ -74,6 +75,7 @@ def _prepare_test_env() -> None:
         # Reset global class variables to ensure test isolation.
         KeyValueStore._autosaved_values = {}
         Statistics._Statistics__next_id = 0  # type:ignore[attr-defined] # Mangled attribute
+        _DefaultUseState._next_state_id = 0

     return _prepare_test_env
@@ -19,7 +19,6 @@
     AdaptivePlaywrightCrawler,
     AdaptivePlaywrightCrawlingContext,
     AdaptivePlaywrightPreNavCrawlingContext,
-    BasicCrawler,
     RenderingType,
     RenderingTypePrediction,
     RenderingTypePredictor,
@@ -381,7 +380,7 @@ async def test_adaptive_crawling_result_use_state_isolation(
     crawler = AdaptivePlaywrightCrawler.with_beautifulsoup_static_parser(
         rendering_type_predictor=static_only_predictor_enforce_detection,
     )
-    await key_value_store.set_value(BasicCrawler._CRAWLEE_STATE_KEY, {'counter': 0})
+    await key_value_store.set_value('CRAWLEE_STATE_0', {'counter': 0})
     request_handler_calls = 0

     @crawler.router.default_handler
@@ -398,7 +397,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
     # Request handler was called twice
     assert request_handler_calls == 2
     # Increment of global state happened only once
-    assert (await key_value_store.get_value(BasicCrawler._CRAWLEE_STATE_KEY))['counter'] == 1
+    assert (await key_value_store.get_value('CRAWLEE_STATE_0'))['counter'] == 1


async def test_adaptive_crawling_statistics(test_urls: list[str]) -> None:
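The literal 'CRAWLEE_STATE_0' relies on the conftest reset above: with _DefaultUseState._next_state_id zeroed before each test, the first default state object created in this test gets instance id 0, so its state lands under CRAWLEE_STATE_0.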