@@ -30,7 +30,7 @@ async def failed_handler(context: BasicCrawlingContext, error: Exception) -> None:
     )

     # Add the new `Request` to the `Queue`
-    rq = await crawler.get_request_manager()
+    rq = await crawler.open_request_manager()
     await rq.add_request(new_request)

     await crawler.run(['https://crawlee.dev/'])
5 changes: 3 additions & 2 deletions docs/upgrading/upgrading_to_v1.md
@@ -216,6 +216,7 @@ service_locator.set_storage_client(MemoryStorageClient()) # Raises an error
 Explicitly passed services to the crawler can be different from the global ones accessible in `crawlee.service_locator`. `BasicCrawler` no longer causes the global services in `service_locator` to be set to the crawler's explicitly passed services.

 **Before (v0.6):**
+
 ```python
 from crawlee import service_locator
 from crawlee.crawlers import BasicCrawler
@@ -228,7 +229,7 @@ async def main() -> None:
     crawler = BasicCrawler(storage_client=custom_storage_client)

     assert service_locator.get_storage_client() is custom_storage_client
-    assert await crawler.get_dataset() is await Dataset.open()
+    assert await crawler.open_dataset() is await Dataset.open()
 ```
 **Now (v1.0):**

@@ -244,7 +245,7 @@ async def main() -> None:
     crawler = BasicCrawler(storage_client=custom_storage_client)

     assert service_locator.get_storage_client() is not custom_storage_client
-    assert await crawler.get_dataset() is not await Dataset.open()
+    assert await crawler.open_dataset() is not await Dataset.open()
 ```

 This allows two crawlers with different services to run at the same time.
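To make that isolation concrete, here is a minimal sketch assuming the v1.0 semantics described above; `MemoryStorageClient` stands in for any storage client, and the `is not` assertion mirrors the one in the upgrade guide:

```python
import asyncio

from crawlee.crawlers import BasicCrawler
from crawlee.storage_clients import MemoryStorageClient


async def main() -> None:
    # Each crawler is constructed with its own storage client; neither
    # constructor mutates the global services in `crawlee.service_locator`.
    crawler_a = BasicCrawler(storage_client=MemoryStorageClient())
    crawler_b = BasicCrawler(storage_client=MemoryStorageClient())

    # Because each crawler resolves storages against its own client,
    # their default datasets are distinct instances.
    dataset_a = await crawler_a.open_dataset()
    dataset_b = await crawler_b.open_dataset()
    assert dataset_a is not dataset_b


if __name__ == '__main__':
    asyncio.run(main())
```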
@@ -291,7 +291,7 @@ async def get_input_state(
         use_state_function = context.use_state

         # New result is created and injected to newly created context. This is done to ensure isolation of sub crawlers.
-        result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
+        result = RequestHandlerRunResult(key_value_store_getter=self.open_key_value_store)
         context_linked_to_result = BasicCrawlingContext(
             request=deepcopy(context.request),
             session=deepcopy(context.session),
62 changes: 37 additions & 25 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -560,24 +560,36 @@ async def _get_proxy_info(self, request: Request, session: Session | None) -> ProxyInfo | None:
             proxy_tier=None,
         )

-    async def get_request_manager(self) -> RequestManager:
+    async def open_request_manager(self) -> RequestManager:
         """Return the configured request manager. If none is configured, open and return the default request queue."""
         if not self._request_manager:
-            self._request_manager = await RequestQueue.open(
-                storage_client=self._service_locator.get_storage_client(),
-                configuration=self._service_locator.get_configuration(),
-            )
-
+            self._request_manager = await self.open_request_queue()
         return self._request_manager

-    async def get_dataset(
+    async def open_request_queue(

> **Review comment (Collaborator):** This has some funk to it - if the crawler uses a non-default request manager, this will still return the default request queue. If somebody does that, they will probably be surprised that adding requests to this queue does nothing 😁
>
> Perhaps the method could throw if there is a non-default request manager in place?

+        self,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        alias: str | None = None,
+    ) -> RequestQueue:
+        """Return `RequestQueue` with the given ID or name or alias. If none is provided, return the default one."""
+        return await RequestQueue.open(
+            id=id,
+            name=name,
+            alias=alias,
+            storage_client=self._service_locator.get_storage_client(),
+            configuration=self._service_locator.get_configuration(),
+        )
+
+    async def open_dataset(
         self,
         *,
         id: str | None = None,
         name: str | None = None,
         alias: str | None = None,
     ) -> Dataset:
-        """Return the `Dataset` with the given ID or name. If none is provided, return the default one."""
+        """Return `Dataset` with the given ID or name or alias. If none is provided, return the default one."""
         return await Dataset.open(
             id=id,
             name=name,
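For illustration, a hypothetical sketch of the guard suggested in the review comment above. This is not part of the PR; the stand-in class, its `_request_manager` attribute, and the import paths are assumptions mirroring `BasicCrawler`:

```python
# Hypothetical guard for `open_request_queue`, as suggested in the review.
from crawlee.request_loaders import RequestManager
from crawlee.storages import RequestQueue


class _CrawlerSketch:
    """Stand-in for `BasicCrawler`, reduced to the attribute this sketch needs."""

    def __init__(self, request_manager: RequestManager | None = None) -> None:
        self._request_manager = request_manager

    async def open_request_queue(self) -> RequestQueue:
        # A queue opened here would not feed a crawler that already uses a
        # custom request manager, so fail loudly instead of returning it.
        if self._request_manager is not None and not isinstance(self._request_manager, RequestQueue):
            raise RuntimeError(
                'A non-default request manager is configured; requests added '
                'to a separately opened queue would never be crawled.'
            )
        return await RequestQueue.open()
```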
@@ -586,14 +598,14 @@ async def get_dataset(
             configuration=self._service_locator.get_configuration(),
         )

-    async def get_key_value_store(
+    async def open_key_value_store(
         self,
         *,
         id: str | None = None,
         name: str | None = None,
         alias: str | None = None,
     ) -> KeyValueStore:
-        """Return the `KeyValueStore` with the given ID or name. If none is provided, return the default KVS."""
+        """Return `KeyValueStore` with the given ID or name or alias. If none is provided, return the default KVS."""
         return await KeyValueStore.open(
             id=id,
             name=name,
@@ -656,10 +668,10 @@ async def run(
         if self._use_session_pool:
             await self._session_pool.reset_store()

-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()
         if purge_request_queue and isinstance(request_manager, RequestQueue):
             await request_manager.drop()
-            self._request_manager = await RequestQueue.open()
+            self._request_manager = await self.open_request_queue()

         if requests is not None:
             await self.add_requests(requests)
@@ -778,7 +790,7 @@ async def add_requests(
             await asyncio.gather(*skipped_tasks)
             self._logger.warning('Some requests were skipped because they were disallowed based on the robots.txt file')

-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()

         await request_manager.add_requests(
             requests=allowed_requests,
@@ -793,11 +805,11 @@ async def _use_state(
         self,
         default_value: dict[str, JsonSerializable] | None = None,
     ) -> dict[str, JsonSerializable]:
-        kvs = await self.get_key_value_store()
+        kvs = await self.open_key_value_store()
         return await kvs.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value)

     async def _save_crawler_state(self) -> None:
-        store = await self.get_key_value_store()
+        store = await self.open_key_value_store()
         await store.persist_autosaved_values()

     async def get_data(
@@ -887,7 +899,7 @@ async def _push_data(
             dataset_alias: The alias of the `Dataset` (run scope, unnamed storage).
             kwargs: Keyword arguments to be passed to the `Dataset.push_data()` method.
         """
-        dataset = await self.get_dataset(id=dataset_id, name=dataset_name, alias=dataset_alias)
+        dataset = await self.open_dataset(id=dataset_id, name=dataset_name, alias=dataset_alias)
         await dataset.push_data(data, **kwargs)

     def _should_retry_request(self, context: BasicCrawlingContext, error: Exception) -> bool:
@@ -1072,7 +1084,7 @@ async def _handle_request_retries(
         context: TCrawlingContext | BasicCrawlingContext,
         error: Exception,
     ) -> None:
-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()
         request = context.request

         if self._abort_on_error:
@@ -1155,7 +1167,7 @@ async def _handle_skipped_request(
         self, request: Request | str, reason: SkippedReason, *, need_mark: bool = False
     ) -> None:
         if need_mark and isinstance(request, Request):
-            request_manager = await self.get_request_manager()
+            request_manager = await self.open_request_manager()

             await wait_for(
                 lambda: request_manager.mark_request_as_handled(request),
@@ -1241,7 +1253,7 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
         """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
         result = self._context_result_map[context]

-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()
         origin = context.request.loaded_url or context.request.url

         for add_requests_call in result.add_requests_calls:
Expand Down Expand Up @@ -1269,7 +1281,7 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) ->
for push_data_call in result.push_data_calls:
await self._push_data(**push_data_call)

await self._commit_key_value_store_changes(result, get_kvs=self.get_key_value_store)
await self._commit_key_value_store_changes(result, get_kvs=self.open_key_value_store)

@staticmethod
async def _commit_key_value_store_changes(
@@ -1294,7 +1306,7 @@ async def __is_finished_function(self) -> bool:
         if self._keep_alive:
             return False

-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()
         return await request_manager.is_finished()

     async def __is_task_ready_function(self) -> bool:
@@ -1306,11 +1318,11 @@ async def __is_task_ready_function(self) -> bool:
             )
             return False

-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()
         return not await request_manager.is_empty()

     async def __run_task_function(self) -> None:
-        request_manager = await self.get_request_manager()
+        request_manager = await self.open_request_manager()

         request = await wait_for(
             lambda: request_manager.fetch_next_request(),
@@ -1336,7 +1348,7 @@ async def __run_task_function(self) -> None:
         else:
             session = await self._get_session()
             proxy_info = await self._get_proxy_info(request, session)
-            result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
+            result = RequestHandlerRunResult(key_value_store_getter=self.open_key_value_store)

             context = BasicCrawlingContext(
                 request=request,
@@ -1582,7 +1594,7 @@ async def _crawler_state_task(self) -> None:
         ):
             message = f'Experiencing problems, {failed_requests} failed requests since last status update.'
         else:
-            request_manager = await self.get_request_manager()
+            request_manager = await self.open_request_manager()
             total_count = await request_manager.get_total_count()
             if total_count is not None and total_count > 0:
                 pages_info = f'{self._statistics.state.requests_finished}/{total_count}'
28 changes: 20 additions & 8 deletions src/crawlee/storages/_storage_instance_manager.py
@@ -187,21 +187,33 @@ def remove_from_cache(self, storage_instance: Storage) -> None:
         """
         storage_type = type(storage_instance)

-        for storage_client_cache in self._cache_by_storage_client.values():
+        for storage_client_type in self._cache_by_storage_client:
             # Remove from ID cache
-            for additional_key in storage_client_cache.by_id[storage_type][storage_instance.id]:
-                del storage_client_cache.by_id[storage_type][storage_instance.id][additional_key]
+            for additional_key in self._cache_by_storage_client[storage_client_type].by_id[storage_type][
+                storage_instance.id
+            ]:
+                del self._cache_by_storage_client[storage_client_type].by_id[storage_type][storage_instance.id][
+                    additional_key
+                ]
                 break

             # Remove from name cache or alias cache. It can never be in both.
             if storage_instance.name is not None:
-                for additional_key in storage_client_cache.by_name[storage_type][storage_instance.name]:
-                    del storage_client_cache.by_name[storage_type][storage_instance.name][additional_key]
+                for additional_key in self._cache_by_storage_client[storage_client_type].by_name[storage_type][
+                    storage_instance.name
+                ]:
+                    del self._cache_by_storage_client[storage_client_type].by_name[storage_type][storage_instance.name][
+                        additional_key
+                    ]
                     break
             else:
-                for alias_key in storage_client_cache.by_alias[storage_type]:
-                    for additional_key in storage_client_cache.by_alias[storage_type][alias_key]:
-                        del storage_client_cache.by_alias[storage_type][alias_key][additional_key]
+                for alias_key in self._cache_by_storage_client[storage_client_type].by_alias[storage_type]:
+                    for additional_key in self._cache_by_storage_client[storage_client_type].by_alias[storage_type][
+                        alias_key
+                    ]:
+                        del self._cache_by_storage_client[storage_client_type].by_alias[storage_type][alias_key][
+                            additional_key
+                        ]
                         break

     def clear_cache(self) -> None:
@@ -461,7 +461,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:

     await crawler.run(test_urls[:1])

-    dataset = await crawler.get_dataset()
+    dataset = await crawler.open_dataset()
     stored_results = [item async for item in dataset.iterate_items()]

     if error_in_pw_crawler: