diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py index cda0690df..e9ea0c112 100644 --- a/src/anthropic/lib/bedrock/_client.py +++ b/src/anthropic/lib/bedrock/_client.py @@ -67,17 +67,17 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions: return options -def _infer_region() -> str: +def _infer_region(aws_profile: str | None = None) -> str: """ Infer the AWS region from the environment variables or from the boto3 session if available. """ - aws_region = os.environ.get("AWS_REGION") + aws_region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") if aws_region is None: try: import boto3 - session = boto3.Session() + session = boto3.Session(profile_name=aws_profile) if session.region_name: aws_region = session.region_name except ImportError: @@ -178,7 +178,7 @@ def __init__( self.aws_access_key = aws_access_key - self.aws_region = _infer_region() if aws_region is None else aws_region + self.aws_region = _infer_region(aws_profile) if aws_region is None else aws_region self.aws_profile = aws_profile self.aws_session_token = aws_session_token @@ -343,7 +343,7 @@ def __init__( self.aws_access_key = aws_access_key - self.aws_region = _infer_region() if aws_region is None else aws_region + self.aws_region = _infer_region(aws_profile) if aws_region is None else aws_region self.aws_profile = aws_profile self.aws_session_token = aws_session_token diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 6e45c27f7..2d33fa932 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -246,6 +246,13 @@ def test_api_key_env_mutual_exclusion(monkeypatch: t.Any) -> None: ) +def test_region_infer_from_aws_default_region_env(monkeypatch: t.Any) -> None: + monkeypatch.setenv("AWS_DEFAULT_REGION", "ap-southeast-1") + monkeypatch.delenv("AWS_REGION", raising=False) + client = AnthropicBedrock() + assert client.aws_region == "ap-southeast-1" + + def test_region_infer_from_profile( mock_aws_config: None, # noqa: ARG001 profiles: t.List[AwsConfigProfile], @@ -275,3 +282,23 @@ def test_region_infer_from_specified_profile( client = AnthropicBedrock() assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"] + + +@pytest.mark.parametrize( + "profiles, aws_profile", + [ + pytest.param( + [{"name": "default", "region": "us-east-2"}, {"name": "custom", "region": "us-west-1"}], + "custom", + id="custom profile via aws_profile arg", + ), + ], +) +def test_region_infer_from_aws_profile_arg( + mock_aws_config: None, # noqa: ARG001 + profiles: t.List[AwsConfigProfile], + aws_profile: str, +) -> None: + client = AnthropicBedrock(aws_profile=aws_profile) + + assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]