3 changes: 3 additions & 0 deletions src/crawlee/_types.py
@@ -661,6 +661,9 @@ class BasicCrawlingContext:
    log: logging.Logger
    """Logger instance."""

    register_deferred_cleanup: Callable[[Callable[[], Coroutine[None, None, None]]], None]
"""Register an async callback to be called after request processing completes (including error handlers)."""

async def get_snapshot(self) -> PageSnapshot:
"""Get snapshot of crawled page."""
return PageSnapshot()
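
A note for reviewers on intended usage: the hook lets a handler or middleware keep a per-request resource alive until after error handlers have run, instead of releasing it when the middleware generator is finalized. A minimal sketch, assuming a handler-owned resource; `open_per_request_resource` and `resource.aclose` are hypothetical placeholders, not crawlee APIs:

from collections.abc import AsyncGenerator

from crawlee.crawlers import BasicCrawlingContext


async def my_middleware(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]:
    resource = await open_per_request_resource()  # hypothetical per-request resource

    async def _release() -> None:
        await resource.aclose()

    # Deferred: runs after request processing completes, including error handlers,
    # rather than when this generator is finalized.
    context.register_deferred_cleanup(_release)
    yield context
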
@@ -320,6 +320,7 @@ async def get_input_state(
            get_key_value_store=result.get_key_value_store,
            use_state=use_state_function,
            log=context.log,
            register_deferred_cleanup=context.register_deferred_cleanup,
        )

        try:
10 changes: 10 additions & 0 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -1413,6 +1413,8 @@ async def __run_task_function(self) -> None:
        proxy_info = await self._get_proxy_info(request, session)
        result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store, request=request)

        deferred_cleanup: list[Callable[[], Awaitable[None]]] = []

        context = BasicCrawlingContext(
            request=result.request,
            session=session,
@@ -1423,6 +1425,7 @@
            get_key_value_store=result.get_key_value_store,
            use_state=self._use_state,
            log=self._logger,
            register_deferred_cleanup=deferred_cleanup.append,
        )
        self._context_result_map[context] = result
@@ -1509,6 +1512,13 @@ async def __run_task_function(self) -> None:
                )
                raise

        finally:
            for cleanup in deferred_cleanup:
                try:
                    await cleanup()
                except Exception:  # noqa: PERF203
                    self._logger.exception('Error in deferred cleanup')
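
A note on the error isolation here: cleanups run sequentially in registration order, and a raising cleanup is logged and swallowed, so it can neither mask the original handler error nor prevent later cleanups from running. A standalone sketch of that behavior, using only the standard library:

import asyncio
import logging
from collections.abc import Awaitable, Callable

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def main() -> None:
    deferred_cleanup: list[Callable[[], Awaitable[None]]] = []

    async def failing_cleanup() -> None:
        raise RuntimeError('boom')

    async def ok_cleanup() -> None:
        logger.info('second cleanup still runs')

    deferred_cleanup.append(failing_cleanup)
    deferred_cleanup.append(ok_cleanup)

    # Same loop shape as in __run_task_function: the failure is logged,
    # and the remaining cleanups still execute.
    for cleanup in deferred_cleanup:
        try:
            await cleanup()
        except Exception:
            logger.exception('Error in deferred cleanup')


asyncio.run(main())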

    async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
        context.request.state = RequestState.BEFORE_NAV
        await self._context_pipeline(
117 changes: 65 additions & 52 deletions src/crawlee/crawlers/_playwright/_playwright_crawler.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import warnings
from datetime import timedelta
@@ -236,6 +237,7 @@ async def _open_page(
            proxy_info=context.proxy_info,
            get_key_value_store=context.get_key_value_store,
            log=context.log,
            register_deferred_cleanup=context.register_deferred_cleanup,
            page=crawlee_page.page,
            block_requests=partial(block_requests, page=crawlee_page.page),
            goto_options=GotoOptions(**self._goto_options),
@@ -296,63 +298,73 @@ async def _navigate(
            The enhanced crawling context with the Playwright-specific features (page, response, enqueue_links,
            infinite_scroll and block_requests).
        """
        async with context.page:
            if context.session:
                session_cookies = context.session.cookies.get_cookies_as_playwright_format()
                await self._update_cookies(context.page, session_cookies)

            if context.request.headers:
                await context.page.set_extra_http_headers(context.request.headers.model_dump())
            # Navigate to the URL and get response.
            if context.request.method != 'GET':
                # Call the notification only once
                warnings.warn(
                    'Using other request methods than GET or adding payloads has a high impact on performance'
                    ' in recent versions of Playwright. Use only when necessary.',
                    category=UserWarning,
                    stacklevel=2,
                )
        # Enter the page context manager, but defer its cleanup (page.close()) so the page stays open
        # during error handler execution.
        await context.page.__aenter__()

            route_handler = self._prepare_request_interceptor(
                method=context.request.method,
                headers=context.request.headers,
                payload=context.request.payload,
            )
        async def _close_page() -> None:
            with contextlib.suppress(Exception):
                await context.page.__aexit__(None, None, None)

            # Set route_handler only for current request
            await context.page.route(context.request.url, route_handler)
        context.register_deferred_cleanup(_close_page)

            try:
                async with self._shared_navigation_timeouts[id(context)] as remaining_timeout:
                    response = await context.page.goto(
                        context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options
                    )
                    context.request.state = RequestState.AFTER_NAV
            except playwright.async_api.TimeoutError as exc:
                raise asyncio.TimeoutError from exc

            if response is None:
                raise SessionError(f'Failed to load the URL: {context.request.url}')

            # Set the loaded URL to the actual URL after redirection.
            context.request.loaded_url = context.page.url

            yield PlaywrightPostNavCrawlingContext(
                request=context.request,
                session=context.session,
                add_requests=context.add_requests,
                send_request=context.send_request,
                push_data=context.push_data,
                use_state=context.use_state,
                proxy_info=context.proxy_info,
                get_key_value_store=context.get_key_value_store,
                log=context.log,
                page=context.page,
                block_requests=context.block_requests,
                goto_options=context.goto_options,
                response=response,
        if context.session:
            session_cookies = context.session.cookies.get_cookies_as_playwright_format()
            await self._update_cookies(context.page, session_cookies)

        if context.request.headers:
            await context.page.set_extra_http_headers(context.request.headers.model_dump())
        # Navigate to the URL and get response.
        if context.request.method != 'GET':
            # Call the notification only once
            warnings.warn(
                'Using other request methods than GET or adding payloads has a high impact on performance'
                ' in recent versions of Playwright. Use only when necessary.',
                category=UserWarning,
                stacklevel=2,
            )

        route_handler = self._prepare_request_interceptor(
            method=context.request.method,
            headers=context.request.headers,
            payload=context.request.payload,
        )

        # Set route_handler only for current request
        await context.page.route(context.request.url, route_handler)

        try:
            async with self._shared_navigation_timeouts[id(context)] as remaining_timeout:
                response = await context.page.goto(
                    context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options
                )
                context.request.state = RequestState.AFTER_NAV
        except playwright.async_api.TimeoutError as exc:
            raise asyncio.TimeoutError from exc

        if response is None:
            raise SessionError(f'Failed to load the URL: {context.request.url}')

        # Set the loaded URL to the actual URL after redirection.
        context.request.loaded_url = context.page.url

        yield PlaywrightPostNavCrawlingContext(
            request=context.request,
            session=context.session,
            add_requests=context.add_requests,
            send_request=context.send_request,
            push_data=context.push_data,
            use_state=context.use_state,
            proxy_info=context.proxy_info,
            get_key_value_store=context.get_key_value_store,
            log=context.log,
            register_deferred_cleanup=context.register_deferred_cleanup,
            page=context.page,
            block_requests=context.block_requests,
            goto_options=context.goto_options,
            response=response,
        )
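
A design note on the manual `__aenter__` above: it is the old `async with context.page:` unrolled, with the `__aexit__` (page close) moved into the deferred-cleanup phase, and `contextlib.suppress(Exception)` keeping a redundant or failing close from raising. The same deferral could also be expressed with `contextlib.AsyncExitStack`; a sketch of that alternative, not the PR's implementation:

import contextlib
from collections.abc import Callable, Coroutine

from playwright.async_api import Page


async def enter_page_deferred(
    page: Page,
    register_deferred_cleanup: Callable[[Callable[[], Coroutine[None, None, None]]], None],
) -> None:
    stack = contextlib.AsyncExitStack()
    await stack.enter_async_context(page)  # equivalent to page.__aenter__()
    # stack.aclose() unwinds the stack, i.e. runs page.__aexit__() / page.close().
    register_deferred_cleanup(stack.aclose)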

    def _create_extract_links_function(self, context: PlaywrightPreNavCrawlingContext) -> ExtractLinksFunction:
        """Create a callback function for extracting links from context.

@@ -508,6 +520,7 @@ async def _create_crawling_context(
            proxy_info=context.proxy_info,
            get_key_value_store=context.get_key_value_store,
            log=context.log,
            register_deferred_cleanup=context.register_deferred_cleanup,
            page=context.page,
            goto_options=context.goto_options,
            response=context.response,
7 changes: 7 additions & 0 deletions tests/unit/crawlers/_basic/test_context_pipeline.py
@@ -41,6 +41,7 @@ async def test_calls_consumer_without_middleware() -> None:
        use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
        register_deferred_cleanup=lambda _: None,
    )

    await pipeline(context, consumer)
@@ -68,6 +69,7 @@ async def middleware_a(context: BasicCrawlingContext) -> AsyncGenerator[Enhanced
            use_state=AsyncMock(),
            get_key_value_store=AsyncMock(),
            log=logging.getLogger(),
            register_deferred_cleanup=context.register_deferred_cleanup,
        )
        events.append('middleware_a_out')
@@ -85,6 +87,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE
            use_state=AsyncMock(),
            get_key_value_store=AsyncMock(),
            log=logging.getLogger(),
            register_deferred_cleanup=context.register_deferred_cleanup,
        )
        events.append('middleware_b_out')
@@ -100,6 +103,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE
        use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
        register_deferred_cleanup=lambda _: None,
    )
    await pipeline(context, consumer)
@@ -126,6 +130,7 @@ async def test_wraps_consumer_errors() -> None:
        use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
        register_deferred_cleanup=lambda _: None,
    )

    with pytest.raises(RequestHandlerError):
@@ -155,6 +160,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC
        use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
        register_deferred_cleanup=lambda _: None,
    )

    with pytest.raises(ContextPipelineInitializationError):
@@ -187,6 +193,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC
        use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
        register_deferred_cleanup=lambda _: None,
    )

    with pytest.raises(ContextPipelineFinalizationError):
36 changes: 34 additions & 2 deletions tests/unit/crawlers/_playwright/test_playwright_crawler.py
@@ -21,7 +21,10 @@
    service_locator,
)
from crawlee.configuration import Configuration
from crawlee.crawlers import PlaywrightCrawler
from crawlee.crawlers import (
    PlaywrightCrawler,
    PlaywrightCrawlingContext,
)
from crawlee.fingerprint_suite import (
    DefaultFingerprintGenerator,
    FingerprintGenerator,
@@ -49,7 +52,6 @@
    from crawlee.browsers._types import BrowserType
    from crawlee.crawlers import (
        BasicCrawlingContext,
        PlaywrightCrawlingContext,
        PlaywrightPostNavCrawlingContext,
        PlaywrightPreNavCrawlingContext,
    )
@@ -1203,3 +1205,33 @@ async def post_nav_hook_2(_context: PlaywrightPostNavCrawlingContext) -> None:
        'post-navigation-hook 2',
        'final handler',
    ]


async def test_error_handler_can_access_page(server_url: URL) -> None:
    """Test that the error handler can access the Page object via PlaywrightCrawlingContext."""

    crawler = PlaywrightCrawler(max_request_retries=2)

    request_handler = mock.AsyncMock(side_effect=RuntimeError('Intentional crash'))
    crawler.router.default_handler(request_handler)

    error_handler_calls: list[str | None] = []

    @crawler.error_handler
    async def error_handler(context: BasicCrawlingContext | PlaywrightCrawlingContext, _error: Exception) -> None:
        error_handler_calls.append(
            await context.page.content() if isinstance(context, PlaywrightCrawlingContext) else None
        )

    failed_handler_calls: list[str | None] = []

    @crawler.failed_request_handler
    async def failed_handler(context: BasicCrawlingContext | PlaywrightCrawlingContext, _error: Exception) -> None:
        failed_handler_calls.append(
            await context.page.content() if isinstance(context, PlaywrightCrawlingContext) else None
        )

    await crawler.run([str(server_url / 'hello-world')])

    assert error_handler_calls == [HELLO_WORLD.decode(), HELLO_WORLD.decode()]
    assert failed_handler_calls == [HELLO_WORLD.decode()]
1 change: 1 addition & 0 deletions tests/unit/test_router.py
@@ -23,6 +23,7 @@ def __init__(self, *, label: str | None) -> None:
            use_state=AsyncMock(),
            get_key_value_store=AsyncMock(),
            log=logging.getLogger(),
            register_deferred_cleanup=lambda _: None,
        )

