Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,53 @@ if __name__ == "__main__":

</details>

<details>
<summary>☁️ <strong>AWS Bedrock LLM Support with Flexible Authentication</strong></summary>

Crawl4AI supports AWS Bedrock models with multiple authentication methods. Credentials are resolved with priority: **explicit code > environment variables > AWS profile > IAM role**.

```python
# Method 1: IAM Role (recommended for EC2/ECS/Lambda)
LLMConfig(
provider="bedrock/us.anthropic.claude-sonnet-4-6",
provider_config={
"aws_region_name": "us-east-1",
}
)

# Method 2: AWS Profile
LLMConfig(
provider="bedrock/us.anthropic.claude-sonnet-4-6",
provider_config={
"aws_region_name": "us-east-1",
"aws_profile_name": "my-profile",
}
)

# Method 3: Explicit Credentials (highest priority)
LLMConfig(
provider="bedrock/us.anthropic.claude-sonnet-4-6",
provider_config={
"aws_region_name": "us-east-1",
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "...",
}
)

# Method 4: Environment Variables (automatic fallback)
# Set: export AWS_REGION=us-east-1
# export AWS_ACCESS_KEY_ID=AKIA...
# export AWS_SECRET_ACCESS_KEY=...
# or Set: export AWS_REGION=us-east-1
# export AWS_PROFILE=...
LLMConfig(
provider="bedrock/us.anthropic.claude-sonnet-4-6",
# Automatically uses AWS environment variables
)
```

</details>

<details>
<summary>🤖 <strong>Using Your own Browser with Custom User Profile</strong></summary>

Expand Down
4 changes: 4 additions & 0 deletions crawl4ai/async_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,6 +2028,7 @@ def __init__(
provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
base_url: Optional[str] = None,
provider_config: Optional[Dict] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
Expand All @@ -2041,6 +2042,7 @@ def __init__(
):
"""Configuaration class for LLM provider and API token."""
self.provider = provider
self.provider_config = provider_config or {}
if api_token and not api_token.startswith("env:"):
self.api_token = api_token
elif api_token and api_token.startswith("env:"):
Expand Down Expand Up @@ -2076,6 +2078,7 @@ def from_kwargs(kwargs: dict) -> "LLMConfig":
provider=kwargs.get("provider", DEFAULT_PROVIDER),
api_token=kwargs.get("api_token"),
base_url=kwargs.get("base_url"),
provider_config=kwargs.get("provider_config"),
temperature=kwargs.get("temperature"),
max_tokens=kwargs.get("max_tokens"),
top_p=kwargs.get("top_p"),
Expand All @@ -2093,6 +2096,7 @@ def to_dict(self):
"provider": self.provider,
"api_token": self.api_token,
"base_url": self.base_url,
"provider_config": self.provider_config,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
Expand Down
4 changes: 4 additions & 0 deletions crawl4ai/extraction_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
self.llm_config.api_token,
base_url=self.llm_config.base_url,
json_response=self.force_json_response,
provider_config=self.llm_config.provider_config,
extra_args=self.extra_args,
base_delay=self.llm_config.backoff_base_delay,
max_attempts=self.llm_config.backoff_max_attempts,
Expand Down Expand Up @@ -891,6 +892,7 @@ async def aextract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
self.llm_config.api_token,
base_url=self.llm_config.base_url,
json_response=self.force_json_response,
provider_config=self.llm_config.provider_config,
extra_args=self.extra_args,
base_delay=self.llm_config.backoff_base_delay,
max_attempts=self.llm_config.backoff_max_attempts,
Expand Down Expand Up @@ -1602,6 +1604,7 @@ async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str
json_response=True,
api_token=llm_config.api_token,
base_url=llm_config.base_url,
provider_config=llm_config.provider_config,
)
if usage is not None:
usage.completion_tokens += response.usage.completion_tokens
Expand Down Expand Up @@ -1919,6 +1922,7 @@ async def agenerate_schema(
json_response=True,
api_token=llm_config.api_token,
base_url=llm_config.base_url,
provider_config=llm_config.provider_config,
messages=messages,
extra_args=kwargs,
)
Expand Down
48 changes: 48 additions & 0 deletions crawl4ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,7 @@ def perform_completion_with_backoff(
max_attempts=3,
exponential_factor=2,
messages=None,
provider_config=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1783,6 +1784,29 @@ def perform_completion_with_backoff(
if json_response:
extra_args["response_format"] = {"type": "json_object"}

if provider.startswith("bedrock/"):
pc = provider_config or {}

region = pc.get("aws_region_name") or os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
if region:
extra_args["aws_region_name"] = region

profile = pc.get("aws_profile_name") or os.getenv("AWS_PROFILE")
if profile:
extra_args["aws_profile_name"] = profile

access_key = pc.get("aws_access_key_id") or os.getenv("AWS_ACCESS_KEY_ID")
if access_key:
extra_args["aws_access_key_id"] = access_key

secret_key = pc.get("aws_secret_access_key") or os.getenv("AWS_SECRET_ACCESS_KEY")
if secret_key:
extra_args["aws_secret_access_key"] = secret_key

session_token = pc.get("aws_session_token") or os.getenv("AWS_SESSION_TOKEN")
if session_token:
extra_args["aws_session_token"] = session_token

if kwargs.get("extra_args"):
extra_args.update(kwargs["extra_args"])

Expand Down Expand Up @@ -1841,6 +1865,7 @@ async def aperform_completion_with_backoff(
max_attempts=3,
exponential_factor=2,
messages=None,
provider_config=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1876,6 +1901,29 @@ async def aperform_completion_with_backoff(
if json_response:
extra_args["response_format"] = {"type": "json_object"}

if provider.startswith("bedrock/"):
pc = provider_config or {}

region = pc.get("aws_region_name") or os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
if region:
extra_args["aws_region_name"] = region

profile = pc.get("aws_profile_name") or os.getenv("AWS_PROFILE")
if profile:
extra_args["aws_profile_name"] = profile

access_key = pc.get("aws_access_key_id") or os.getenv("AWS_ACCESS_KEY_ID")
if access_key:
extra_args["aws_access_key_id"] = access_key

secret_key = pc.get("aws_secret_access_key") or os.getenv("AWS_SECRET_ACCESS_KEY")
if secret_key:
extra_args["aws_secret_access_key"] = secret_key

session_token = pc.get("aws_session_token") or os.getenv("AWS_SESSION_TOKEN")
if session_token:
extra_args["aws_session_token"] = session_token

if kwargs.get("extra_args"):
extra_args.update(kwargs["extra_args"])

Expand Down
191 changes: 191 additions & 0 deletions tests/test_completion_provider_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import os
from unittest.mock import AsyncMock, MagicMock

import pytest

from crawl4ai.utils import aperform_completion_with_backoff, perform_completion_with_backoff


@pytest.fixture
def mock_completion(monkeypatch):
mock = MagicMock()
mock_response = MagicMock()
mock_response.usage.completion_tokens = 10
mock_response.usage.prompt_tokens = 20
mock_response.usage.total_tokens = 30
mock_response.usage.completion_tokens_details = None
mock_response.usage.prompt_tokens_details = None
mock.return_value = mock_response
monkeypatch.setattr("litellm.completion", mock)
return mock


@pytest.fixture
def mock_acompletion(monkeypatch):
mock = AsyncMock()
mock_response = MagicMock()
mock_response.usage.completion_tokens = 10
mock_response.usage.prompt_tokens = 20
mock_response.usage.total_tokens = 30
mock_response.usage.completion_tokens_details = None
mock_response.usage.prompt_tokens_details = None
mock.return_value = mock_response
monkeypatch.setattr("litellm.acompletion", mock)
return mock


class TestPerformCompletionWithProviderConfig:
def setup_method(self):
env_vars = [
"AWS_REGION",
"AWS_DEFAULT_REGION",
"AWS_PROFILE",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]
for var in env_vars:
os.environ.pop(var, None)

def test_bedrock_provider_config_passed_to_litellm(self, mock_completion):
provider_config = {
"aws_region_name": "us-west-2",
"aws_profile_name": "dev",
"aws_access_key_id": "AKIA123",
"aws_secret_access_key": "secret123",
"aws_session_token": "token123",
}

perform_completion_with_backoff(
provider="bedrock/anthropic.claude-v2",
prompt_with_variables="test prompt",
api_token=None,
provider_config=provider_config,
)

mock_completion.assert_called_once()
call_kwargs = mock_completion.call_args.kwargs

assert {k: v for k, v in call_kwargs.items() if k.startswith("aws_")} == {
"aws_region_name": "us-west-2",
"aws_profile_name": "dev",
"aws_access_key_id": "AKIA123",
"aws_secret_access_key": "secret123",
"aws_session_token": "token123",
}

def test_bedrock_env_var_fallback(self, mock_completion):
os.environ["AWS_REGION"] = "ap-southeast-1"
os.environ["AWS_ACCESS_KEY_ID"] = "AKIA_ENV"
os.environ["AWS_SECRET_ACCESS_KEY"] = "SECRET_ENV"

perform_completion_with_backoff(
provider="bedrock/anthropic.claude-v2",
prompt_with_variables="test prompt",
api_token=None,
provider_config=None,
)

mock_completion.assert_called_once()
call_kwargs = mock_completion.call_args.kwargs

assert {k: v for k, v in call_kwargs.items() if k.startswith("aws_")} == {
"aws_region_name": "ap-southeast-1",
"aws_access_key_id": "AKIA_ENV",
"aws_secret_access_key": "SECRET_ENV",
}

def test_bedrock_explicit_overrides_env(self, mock_completion):
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "AKIA_ENV"

provider_config = {"aws_region_name": "us-west-2", "aws_access_key_id": "AKIA_EXPLICIT"}

perform_completion_with_backoff(
provider="bedrock/anthropic.claude-v2",
prompt_with_variables="test prompt",
api_token=None,
provider_config=provider_config,
)

mock_completion.assert_called_once()
call_kwargs = mock_completion.call_args.kwargs

assert {k: v for k, v in call_kwargs.items() if k.startswith("aws_")} == {
"aws_region_name": "us-west-2",
"aws_access_key_id": "AKIA_EXPLICIT",
}

def test_non_bedrock_provider_unaffected(self, mock_completion):
provider_config = {"aws_region_name": "us-west-2", "aws_access_key_id": "AKIA123"}

perform_completion_with_backoff(
provider="openai/gpt-4",
prompt_with_variables="test prompt",
api_token="sk-test",
provider_config=provider_config,
)

mock_completion.assert_called_once()
call_kwargs = mock_completion.call_args.kwargs

assert "aws_region_name" not in call_kwargs
assert "aws_access_key_id" not in call_kwargs
assert call_kwargs["api_key"] == "sk-test"


class TestAPerformCompletionWithProviderConfig:
def setup_method(self):
env_vars = [
"AWS_REGION",
"AWS_DEFAULT_REGION",
"AWS_PROFILE",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]
for var in env_vars:
os.environ.pop(var, None)

@pytest.mark.asyncio
async def test_async_bedrock_provider_config_passed(self, mock_acompletion):
provider_config = {"aws_region_name": "us-west-2", "aws_access_key_id": "AKIA123"}

await aperform_completion_with_backoff(
provider="bedrock/anthropic.claude-v2",
prompt_with_variables="test prompt",
api_token=None,
provider_config=provider_config,
)

mock_acompletion.assert_called_once()
call_kwargs = mock_acompletion.call_args.kwargs

assert {k: v for k, v in call_kwargs.items() if k.startswith("aws_")} == {
"aws_region_name": "us-west-2",
"aws_access_key_id": "AKIA123",
}

@pytest.mark.asyncio
async def test_async_bedrock_env_fallback(self, mock_acompletion):
os.environ["AWS_REGION"] = "eu-west-1"
os.environ["AWS_ACCESS_KEY_ID"] = "AKIA_ENV"

await aperform_completion_with_backoff(
provider="bedrock/anthropic.claude-v2",
prompt_with_variables="test prompt",
api_token=None,
provider_config=None,
)

mock_acompletion.assert_called_once()
call_kwargs = mock_acompletion.call_args.kwargs

assert {k: v for k, v in call_kwargs.items() if k.startswith("aws_")} == {
"aws_region_name": "eu-west-1",
"aws_access_key_id": "AKIA_ENV",
}


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading