|
| 1 | +"""AWS Bedrock adapter for Helion's LLM transport. |
| 2 | +
|
| 3 | +Uses IMDSv2 to fetch instance-role credentials, signs requests with SigV4, |
| 4 | +and wraps the Anthropic Messages payload in the Bedrock invoke format. No |
| 5 | +boto3 / anthropic SDK dependency; uses only the standard library so it |
| 6 | +matches the rest of ``transport.py``. |
| 7 | +
|
| 8 | +Env vars: |
| 9 | +
|
| 10 | +- ``AWS_REGION`` / ``AWS_DEFAULT_REGION``: region for the Bedrock endpoint. |
| 11 | + Falls back to the IMDSv2 region metadata. |
| 12 | +- ``HELION_LLM_ANTHROPIC_THINKING_BUDGET``: integer token budget for |
| 13 | + Claude extended thinking. Opus 4.7 uses ``thinking.type="adaptive"`` |
| 14 | + plus ``output_config.effort="high"``; older Opus models use |
| 15 | + ``thinking.type="enabled"`` with ``budget_tokens=N`` and force |
| 16 | + ``temperature=1.0``. ``max_tokens`` is raised to at least |
| 17 | + ``budget + 4096`` so the response has room after thinking. |
| 18 | +- ``HELION_LLM_ANTHROPIC_REASONING_EFFORT``: optional override for the |
| 19 | + Opus 4.7 ``output_config.effort`` value; default ``"high"``. |
| 20 | +""" |
| 21 | + |
| 22 | +from __future__ import annotations |
| 23 | + |
| 24 | +from dataclasses import dataclass |
| 25 | +from datetime import datetime |
| 26 | +from datetime import timezone |
| 27 | +import hashlib |
| 28 | +import hmac |
| 29 | +import json |
| 30 | +import os |
| 31 | +import threading |
| 32 | +from typing import Any |
| 33 | +from urllib import request as urllib_request |
| 34 | + |
| 35 | +_IMDS_BASE = "http://169.254.169.254/latest" |
| 36 | +_IMDS_TOKEN_TTL_S = 21600 # 6h, max allowed |
| 37 | + |
| 38 | +_cred_lock = threading.Lock() |
| 39 | +_cached_creds: BedrockCredentials | None = None |
| 40 | + |
| 41 | + |
| 42 | +@dataclass(frozen=True) |
| 43 | +class BedrockCredentials: |
| 44 | + access_key_id: str |
| 45 | + secret_access_key: str |
| 46 | + session_token: str | None |
| 47 | + expiration_epoch: float # seconds since epoch; 0 for static creds |
| 48 | + |
| 49 | + |
| 50 | +def _imds_token() -> str: |
| 51 | + """Fetch an IMDSv2 session token.""" |
| 52 | + req = urllib_request.Request( |
| 53 | + f"{_IMDS_BASE}/api/token", |
| 54 | + method="PUT", |
| 55 | + headers={"X-aws-ec2-metadata-token-ttl-seconds": str(_IMDS_TOKEN_TTL_S)}, |
| 56 | + ) |
| 57 | + with urllib_request.urlopen(req, timeout=2) as resp: |
| 58 | + return resp.read().decode() |
| 59 | + |
| 60 | + |
| 61 | +def _imds_get(path: str, token: str) -> str: |
| 62 | + req = urllib_request.Request( |
| 63 | + f"{_IMDS_BASE}/{path}", |
| 64 | + headers={"X-aws-ec2-metadata-token": token}, |
| 65 | + ) |
| 66 | + with urllib_request.urlopen(req, timeout=2) as resp: |
| 67 | + return resp.read().decode() |
| 68 | + |
| 69 | + |
| 70 | +def _parse_expiration(expiration: str) -> float: |
| 71 | + """Parse ISO-8601 `Z` timestamp into epoch seconds.""" |
| 72 | + if expiration.endswith("Z"): |
| 73 | + expiration = expiration[:-1] + "+00:00" |
| 74 | + return datetime.fromisoformat(expiration).timestamp() |
| 75 | + |
| 76 | + |
| 77 | +def _fetch_imds_credentials() -> BedrockCredentials: |
| 78 | + token = _imds_token() |
| 79 | + role = _imds_get("meta-data/iam/security-credentials/", token).strip() |
| 80 | + if not role: |
| 81 | + raise RuntimeError( |
| 82 | + "IMDSv2 returned no IAM role; Bedrock requires instance credentials" |
| 83 | + ) |
| 84 | + raw = _imds_get(f"meta-data/iam/security-credentials/{role}", token) |
| 85 | + data = json.loads(raw) |
| 86 | + if data.get("Code") != "Success": |
| 87 | + raise RuntimeError(f"IMDSv2 credential fetch failed: {data.get('Code')}") |
| 88 | + return BedrockCredentials( |
| 89 | + access_key_id=data["AccessKeyId"], |
| 90 | + secret_access_key=data["SecretAccessKey"], |
| 91 | + session_token=data.get("Token"), |
| 92 | + expiration_epoch=_parse_expiration(data["Expiration"]), |
| 93 | + ) |
| 94 | + |
| 95 | + |
| 96 | +def _load_env_credentials() -> BedrockCredentials | None: |
| 97 | + """Allow AWS_* env vars to override IMDS for local testing.""" |
| 98 | + ak = os.environ.get("AWS_ACCESS_KEY_ID") |
| 99 | + sk = os.environ.get("AWS_SECRET_ACCESS_KEY") |
| 100 | + if not ak or not sk: |
| 101 | + return None |
| 102 | + return BedrockCredentials( |
| 103 | + access_key_id=ak, |
| 104 | + secret_access_key=sk, |
| 105 | + session_token=os.environ.get("AWS_SESSION_TOKEN"), |
| 106 | + expiration_epoch=0.0, |
| 107 | + ) |
| 108 | + |
| 109 | + |
| 110 | +def get_credentials() -> BedrockCredentials: |
| 111 | + """Return cached credentials, refreshing before expiry.""" |
| 112 | + global _cached_creds |
| 113 | + with _cred_lock: |
| 114 | + now = datetime.now(timezone.utc).timestamp() |
| 115 | + if _cached_creds is not None: |
| 116 | + if ( |
| 117 | + _cached_creds.expiration_epoch == 0.0 |
| 118 | + or _cached_creds.expiration_epoch - now > 300 |
| 119 | + ): |
| 120 | + return _cached_creds |
| 121 | + creds = _load_env_credentials() or _fetch_imds_credentials() |
| 122 | + _cached_creds = creds |
| 123 | + return creds |
| 124 | + |
| 125 | + |
| 126 | +def resolve_region(explicit: str | None = None) -> str: |
| 127 | + if explicit: |
| 128 | + return explicit |
| 129 | + for name in ("AWS_REGION", "AWS_DEFAULT_REGION"): |
| 130 | + if (val := os.environ.get(name)) is not None: |
| 131 | + return val |
| 132 | + # Fall back to IMDS if available. |
| 133 | + try: |
| 134 | + token = _imds_token() |
| 135 | + return _imds_get("meta-data/placement/region", token).strip() |
| 136 | + except Exception as e: # noqa: BLE001 |
| 137 | + raise RuntimeError( |
| 138 | + "Could not resolve AWS region; set AWS_REGION or run on an EC2 " |
| 139 | + "instance reachable from IMDSv2." |
| 140 | + ) from e |
| 141 | + |
| 142 | + |
| 143 | +def bedrock_endpoint(model_id: str, region: str) -> str: |
| 144 | + """Build the Bedrock-runtime invoke URL for a given model. |
| 145 | +
|
| 146 | + Bedrock model IDs contain ``:`` (e.g. ``...-v1:0``), which must be |
| 147 | + URL-encoded in both the request URL and the SigV4 canonical path. |
| 148 | + """ |
| 149 | + from urllib.parse import quote |
| 150 | + |
| 151 | + encoded_id = quote(model_id, safe="") |
| 152 | + return ( |
| 153 | + f"https://bedrock-runtime.{region}.amazonaws.com" |
| 154 | + f"/model/{encoded_id}/invoke" |
| 155 | + ) |
| 156 | + |
| 157 | + |
| 158 | +def _thinking_budget() -> int | None: |
| 159 | + raw = os.environ.get("HELION_LLM_ANTHROPIC_THINKING_BUDGET") |
| 160 | + if raw is None: |
| 161 | + return None |
| 162 | + try: |
| 163 | + val = int(raw) |
| 164 | + except ValueError as e: |
| 165 | + raise RuntimeError( |
| 166 | + f"HELION_LLM_ANTHROPIC_THINKING_BUDGET must be an int, got {raw!r}" |
| 167 | + ) from e |
| 168 | + if val <= 0: |
| 169 | + return None |
| 170 | + return val |
| 171 | + |
| 172 | + |
| 173 | +_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",) |
| 174 | + |
| 175 | + |
| 176 | +def _model_uses_adaptive_thinking(model: str) -> bool: |
| 177 | + m = model.lower() |
| 178 | + return any(key in m for key in _ADAPTIVE_THINKING_MODELS) |
| 179 | + |
| 180 | + |
| 181 | +def build_bedrock_payload( |
| 182 | + *, |
| 183 | + messages: list[dict[str, str]], |
| 184 | + max_output_tokens: int, |
| 185 | + system_prompt: str, |
| 186 | + model: str = "", |
| 187 | +) -> dict[str, Any]: |
| 188 | + """Build the Anthropic-Messages-over-Bedrock invoke body.""" |
| 189 | + from .transport import anthropic_messages_from_history |
| 190 | + |
| 191 | + payload: dict[str, Any] = { |
| 192 | + "anthropic_version": "bedrock-2023-05-31", |
| 193 | + "max_tokens": max_output_tokens, |
| 194 | + "messages": anthropic_messages_from_history(messages), |
| 195 | + } |
| 196 | + if system_prompt: |
| 197 | + payload["system"] = system_prompt |
| 198 | + |
| 199 | + budget = _thinking_budget() |
| 200 | + if budget is None: |
| 201 | + return payload |
| 202 | + |
| 203 | + response_headroom = 4096 |
| 204 | + payload["max_tokens"] = max(payload["max_tokens"], budget + response_headroom) |
| 205 | + # Extended thinking requires temperature=1.0 per Anthropic contract. |
| 206 | + payload["temperature"] = 1.0 |
| 207 | + |
| 208 | + if _model_uses_adaptive_thinking(model): |
| 209 | + effort = os.environ.get("HELION_LLM_ANTHROPIC_REASONING_EFFORT", "high") |
| 210 | + payload["thinking"] = {"type": "adaptive"} |
| 211 | + payload["output_config"] = {"effort": effort} |
| 212 | + else: |
| 213 | + payload["thinking"] = {"type": "enabled", "budget_tokens": budget} |
| 214 | + return payload |
| 215 | + |
| 216 | + |
| 217 | +# ---------------------------------------------------------------------------- |
| 218 | +# SigV4 signing (Bedrock runtime) |
| 219 | +# ---------------------------------------------------------------------------- |
| 220 | + |
| 221 | + |
| 222 | +def _hmac_sha256(key: bytes, data: bytes) -> bytes: |
| 223 | + return hmac.new(key, data, hashlib.sha256).digest() |
| 224 | + |
| 225 | + |
| 226 | +def _sig_key(secret: str, date: str, region: str, service: str) -> bytes: |
| 227 | + k_date = _hmac_sha256(("AWS4" + secret).encode(), date.encode()) |
| 228 | + k_region = _hmac_sha256(k_date, region.encode()) |
| 229 | + k_service = _hmac_sha256(k_region, service.encode()) |
| 230 | + return _hmac_sha256(k_service, b"aws4_request") |
| 231 | + |
| 232 | + |
| 233 | +def sigv4_headers( |
| 234 | + *, |
| 235 | + method: str, |
| 236 | + url: str, |
| 237 | + body: bytes, |
| 238 | + region: str, |
| 239 | + creds: BedrockCredentials, |
| 240 | + service: str = "bedrock", |
| 241 | +) -> dict[str, str]: |
| 242 | + """Return the HTTP headers for a SigV4-signed Bedrock request.""" |
| 243 | + from urllib.parse import quote, urlparse |
| 244 | + |
| 245 | + parsed = urlparse(url) |
| 246 | + host = parsed.netloc |
| 247 | + # SigV4 canonical path: URI-encoded absolute path. Because the URL we |
| 248 | + # hand to urlopen already contains a percent-encoded model ID (e.g. |
| 249 | + # %3A), that percent must be re-encoded for the canonical string or |
| 250 | + # AWS rejects the signature. quote(..., safe="/") leaves slashes alone |
| 251 | + # and percent-encodes everything else, so `%3A` becomes `%253A`. |
| 252 | + path = quote(parsed.path, safe="/") if parsed.path else "/" |
| 253 | + query = parsed.query # Bedrock invoke uses empty query |
| 254 | + |
| 255 | + now = datetime.now(timezone.utc) |
| 256 | + amz_date = now.strftime("%Y%m%dT%H%M%SZ") |
| 257 | + date_stamp = now.strftime("%Y%m%d") |
| 258 | + |
| 259 | + payload_hash = hashlib.sha256(body).hexdigest() |
| 260 | + canonical_headers_list = [ |
| 261 | + ("content-type", "application/json"), |
| 262 | + ("host", host), |
| 263 | + ("x-amz-date", amz_date), |
| 264 | + ] |
| 265 | + if creds.session_token: |
| 266 | + canonical_headers_list.append(("x-amz-security-token", creds.session_token)) |
| 267 | + canonical_headers_list.sort(key=lambda kv: kv[0]) |
| 268 | + canonical_headers = "".join(f"{k}:{v}\n" for k, v in canonical_headers_list) |
| 269 | + signed_headers = ";".join(k for k, _ in canonical_headers_list) |
| 270 | + |
| 271 | + canonical_request = ( |
| 272 | + f"{method}\n{path}\n{query}\n{canonical_headers}\n" |
| 273 | + f"{signed_headers}\n{payload_hash}" |
| 274 | + ) |
| 275 | + |
| 276 | + credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" |
| 277 | + string_to_sign = ( |
| 278 | + "AWS4-HMAC-SHA256\n" |
| 279 | + f"{amz_date}\n" |
| 280 | + f"{credential_scope}\n" |
| 281 | + f"{hashlib.sha256(canonical_request.encode()).hexdigest()}" |
| 282 | + ) |
| 283 | + |
| 284 | + signing_key = _sig_key(creds.secret_access_key, date_stamp, region, service) |
| 285 | + signature = hmac.new( |
| 286 | + signing_key, string_to_sign.encode(), hashlib.sha256 |
| 287 | + ).hexdigest() |
| 288 | + |
| 289 | + auth_header = ( |
| 290 | + f"AWS4-HMAC-SHA256 " |
| 291 | + f"Credential={creds.access_key_id}/{credential_scope}, " |
| 292 | + f"SignedHeaders={signed_headers}, " |
| 293 | + f"Signature={signature}" |
| 294 | + ) |
| 295 | + |
| 296 | + headers = { |
| 297 | + "content-type": "application/json", |
| 298 | + "host": host, |
| 299 | + "x-amz-date": amz_date, |
| 300 | + "authorization": auth_header, |
| 301 | + } |
| 302 | + if creds.session_token: |
| 303 | + headers["x-amz-security-token"] = creds.session_token |
| 304 | + return headers |
0 commit comments