Skip to content

Commit 6c3628f

Browse files
Move SigV4 request signing into the AuthManager abstraction
1 parent 006a7fc commit 6c3628f

4 files changed

Lines changed: 618 additions & 381 deletions

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 69 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,15 @@
3131

3232
from pyiceberg import __version__
3333
from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
34-
from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
34+
from pyiceberg.catalog.rest.auth import (
35+
AUTH_MANAGER,
36+
AuthManager,
37+
AuthManagerAdapter,
38+
AuthManagerFactory,
39+
LegacyOAuth2AuthManager,
40+
NoopAuthManager,
41+
SigV4AuthManager,
42+
)
3543
from pyiceberg.catalog.rest.response import _handle_non_200_response
3644
from pyiceberg.catalog.rest.scan_planning import (
3745
FetchScanTasksRequest,
@@ -251,11 +259,11 @@ class ScanPlanningMode(Enum):
251259
CA_BUNDLE = "cabundle"
252260
SSL = "ssl"
253261
SIGV4 = "rest.sigv4-enabled"
262+
SIGV4_AUTH_TYPE = "sigv4"
254263
SIGV4_REGION = "rest.signing-region"
255264
SIGV4_SERVICE = "rest.signing-name"
256265
SIGV4_MAX_RETRIES = "rest.sigv4.max-retries"
257266
SIGV4_MAX_RETRIES_DEFAULT = 10
258-
EMPTY_BODY_SHA256: str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
259267
OAUTH2_SERVER_URI = "oauth2-server-uri"
260268
SNAPSHOT_LOADING_MODE = "snapshot-loading-mode"
261269
AUTH = "auth"
@@ -435,10 +443,49 @@ def _create_session(self) -> Session:
435443
elif ssl_client_cert := ssl_client.get(CERT):
436444
session.cert = ssl_client_cert
437445

446+
self._auth_manager = self._build_auth_manager(session)
447+
session.auth = AuthManagerAdapter(self._auth_manager)
448+
449+
# SigV4 retry is decoupled from signing: mount a plain retry adapter.
450+
if self._is_sigv4_enabled():
451+
from requests.adapters import HTTPAdapter
452+
453+
max_retries = property_as_int(self.properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT)
454+
session.mount(self.uri, HTTPAdapter(max_retries=max_retries))
455+
456+
return session
457+
458+
def _is_sigv4_enabled(self) -> bool:
    """Report whether SigV4 request signing has been requested.

    Signing can be switched on in two ways: through the boolean ``SIGV4``
    catalog property, or by setting the ``type`` entry of the ``AUTH``
    configuration to ``SIGV4_AUTH_TYPE``.

    Returns:
        True if either configuration path enables SigV4, False otherwise.
    """
    legacy_flag_set = property_as_bool(self.properties, SIGV4, False)
    if not legacy_flag_set:
        # Fall back to the structured auth configuration.
        auth_section = self.properties.get(AUTH)
        if auth_section is None:
            return False
        return auth_section.get("type") == SIGV4_AUTH_TYPE
    return True
464+
465+
def _build_auth_manager(self, session: Session) -> AuthManager:
    """Create the AuthManager for this catalog.

    The header-based manager is always built first. When SigV4 is enabled it
    becomes the delegate of a SigV4-signing manager; otherwise it is returned
    as-is.

    Args:
        session: The requests session, forwarded to the delegate builder.

    Returns:
        The AuthManager to install on the session.
    """
    header_manager = self._build_delegate_auth_manager(session)
    if not self._is_sigv4_enabled():
        return header_manager
    return self._build_sigv4_auth_manager(header_manager)
471+
472+
def _build_delegate_auth_manager(self, session: Session) -> AuthManager:
473+
"""Build the header-based AuthManager (the SigV4 delegate, or the manager used directly)."""
438474
if auth_config := self.properties.get(AUTH):
439475
auth_type = auth_config.get("type")
440476
if auth_type is None:
441477
raise ValueError("auth.type must be defined")
478+
479+
if auth_type == SIGV4_AUTH_TYPE:
480+
# The delegate is configured under auth.sigv4.delegate.*
481+
sigv4_config = auth_config.get(SIGV4_AUTH_TYPE, {})
482+
delegate_config = sigv4_config.get("delegate")
483+
if not delegate_config or "type" not in delegate_config:
484+
# No delegate configured: SigV4-only auth, with no header-based delegate.
485+
return NoopAuthManager()
486+
delegate_type = delegate_config["type"]
487+
return AuthManagerFactory.create(delegate_type, delegate_config.get(delegate_type, {}))
488+
442489
auth_type_config = auth_config.get(auth_type, {})
443490
auth_impl = auth_config.get("impl")
444491

@@ -448,17 +495,28 @@ def _create_session(self) -> Session:
448495
if auth_type != CUSTOM and auth_impl:
449496
raise ValueError("auth.impl can only be specified when using custom auth.type")
450497

451-
self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
452-
session.auth = AuthManagerAdapter(self._auth_manager)
453-
else:
454-
self._auth_manager = self._create_legacy_oauth2_auth_manager(session)
455-
session.auth = AuthManagerAdapter(self._auth_manager)
498+
return AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
456499

457-
# Configure SigV4 Request Signing
458-
if property_as_bool(self.properties, SIGV4, False):
459-
self._init_sigv4(session)
500+
return self._create_legacy_oauth2_auth_manager(session)
460501

461-
return session
502+
def _build_sigv4_auth_manager(self, delegate: AuthManager) -> AuthManager:
    """Return a SigV4AuthManager that signs requests on top of ``delegate``.

    A boto3 session is created from the AWS-related catalog properties
    (profile, region, optional botocore session, and static credentials).
    The signing region comes from ``SIGV4_REGION`` and the signing service
    from ``SIGV4_SERVICE``, which defaults to ``"execute-api"``.

    Args:
        delegate: AuthManager supplying header-based auth before signing.

    Returns:
        The SigV4AuthManager wrapping ``delegate``.
    """
    # boto3 is imported here rather than at module level.
    import boto3

    props = self.properties
    session_kwargs = {
        "profile_name": get_first_property_value(props, AWS_PROFILE_NAME),
        "region_name": get_first_property_value(props, AWS_REGION),
        "botocore_session": props.get(BOTOCORE_SESSION),
        "aws_access_key_id": get_first_property_value(props, AWS_ACCESS_KEY_ID),
        "aws_secret_access_key": get_first_property_value(props, AWS_SECRET_ACCESS_KEY),
        "aws_session_token": get_first_property_value(props, AWS_SESSION_TOKEN),
    }
    aws_session = boto3.Session(**session_kwargs)

    signing_region = self.properties.get(SIGV4_REGION)
    signing_service = self.properties.get(SIGV4_SERVICE, "execute-api")
    return SigV4AuthManager(
        delegate=delegate,
        boto_session=aws_session,
        region=signing_region,
        service=signing_service,
    )
462520

463521
@staticmethod
464522
def _resolve_storage_credentials(storage_credentials: list[StorageCredential], location: str | None) -> Properties:
@@ -761,101 +819,6 @@ def _split_identifier_for_json(self, identifier: str | Identifier) -> dict[str,
761819
identifier_tuple = self._identifier_to_validated_tuple(identifier)
762820
return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]}
763821

764-
def _init_sigv4(self, session: Session) -> None:
765-
import base64
766-
import hashlib
767-
from urllib import parse
768-
769-
import boto3
770-
from botocore.auth import SigV4Auth
771-
from botocore.awsrequest import AWSRequest
772-
from requests import PreparedRequest
773-
from requests.adapters import HTTPAdapter
774-
775-
class _IcebergSigV4Auth(SigV4Auth):
776-
def canonical_request(self, request: AWSRequest) -> str:
777-
# Override forces hex payload hash in the canonical request even when
778-
# x-amz-content-sha256 header is base64 (see body-hash block below).
779-
# Mirrors botocore <=1.42.x SigV4Auth.canonical_request layout:
780-
# https://github.com/boto/botocore/blob/1.42.85/botocore/auth.py#L622-L637
781-
cr = [request.method.upper()]
782-
path = self._normalize_url_path(parse.urlsplit(request.url).path)
783-
cr.append(path)
784-
cr.append(self.canonical_query_string(request))
785-
headers_to_sign = self.headers_to_sign(request)
786-
cr.append(self.canonical_headers(headers_to_sign) + "\n")
787-
cr.append(self.signed_headers(headers_to_sign))
788-
cr.append(self.payload(request))
789-
return "\n".join(cr)
790-
791-
class SigV4Adapter(HTTPAdapter):
792-
def __init__(self, **properties: str):
793-
self._properties = properties
794-
max_retries = property_as_int(self._properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT)
795-
super().__init__(max_retries=max_retries)
796-
self._boto_session = boto3.Session(
797-
profile_name=get_first_property_value(self._properties, AWS_PROFILE_NAME),
798-
region_name=get_first_property_value(self._properties, AWS_REGION),
799-
botocore_session=self._properties.get(BOTOCORE_SESSION),
800-
aws_access_key_id=get_first_property_value(self._properties, AWS_ACCESS_KEY_ID),
801-
aws_secret_access_key=get_first_property_value(self._properties, AWS_SECRET_ACCESS_KEY),
802-
aws_session_token=get_first_property_value(self._properties, AWS_SESSION_TOKEN),
803-
)
804-
805-
def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: # pylint: disable=W0613
806-
credentials = self._boto_session.get_credentials().get_frozen_credentials()
807-
region = self._properties.get(SIGV4_REGION, self._boto_session.region_name)
808-
service = self._properties.get(SIGV4_SERVICE, "execute-api")
809-
810-
url = str(request.url).split("?")[0]
811-
query = str(parse.urlsplit(request.url).query)
812-
params = dict(parse.parse_qsl(query))
813-
814-
# remove the connection header as it will be updated after signing
815-
if "connection" in request.headers:
816-
del request.headers["connection"]
817-
818-
# Match Iceberg Java's AWS SDK v2 flexible-checksum signing:
819-
# x-amz-content-sha256 header is base64 for non-empty bodies, hex for empty.
820-
# The SigV4 canonical request still uses hex (enforced in _IcebergSigV4Auth above).
821-
# Ref: https://github.com/apache/iceberg/blob/main/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthSession.java
822-
if request.body:
823-
if isinstance(request.body, str):
824-
body_bytes = request.body.encode("utf-8")
825-
elif isinstance(request.body, (bytes, bytearray)):
826-
body_bytes = request.body
827-
else:
828-
raise TypeError(
829-
f"Unsupported request body type for SigV4 signing: "
830-
f"{type(request.body).__name__}; expected str or bytes."
831-
)
832-
content_sha256_header = base64.b64encode(hashlib.sha256(body_bytes).digest()).decode()
833-
else:
834-
content_sha256_header = EMPTY_BODY_SHA256
835-
836-
signing_headers = dict(request.headers)
837-
signing_headers["x-amz-content-sha256"] = content_sha256_header
838-
839-
aws_request = AWSRequest(
840-
method=request.method, url=url, params=params, data=request.body, headers=signing_headers
841-
)
842-
843-
_IcebergSigV4Auth(credentials, service, region).add_auth(aws_request)
844-
845-
original_header = dict(request.headers)
846-
signed_headers = dict(aws_request.headers)
847-
relocated_headers = {}
848-
849-
# relocate headers if there is a conflict with signed headers
850-
for header, value in original_header.items():
851-
if header in signed_headers and signed_headers[header] != value:
852-
relocated_headers[f"Original-{header}"] = value
853-
854-
request.headers.update(relocated_headers)
855-
request.headers.update(signed_headers)
856-
857-
session.mount(self.uri, SigV4Adapter(**self.properties))
858-
859822
def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response: TableResponse) -> Table:
860823
# Per Iceberg spec: storage-credentials take precedence over config
861824
credential_config = self._resolve_storage_credentials(

pyiceberg/catalog/rest/auth.py

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import threading
2222
import time
2323
from abc import ABC, abstractmethod
24-
from functools import cached_property
24+
from functools import cache, cached_property
2525
from typing import Any
2626

2727
import requests
@@ -36,6 +36,37 @@
3636
COLON = ":"
3737
logger = logging.getLogger(__name__)
3838

39+
# SHA-256 of an empty payload. Used as the x-amz-content-sha256 header value for
40+
# empty-body requests, matching Iceberg Java's RESTSigV4AuthSession workaround.
41+
EMPTY_BODY_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
42+
43+
44+
@cache
def _iceberg_sigv4_auth_class() -> type:
    """Build, once, the botocore SigV4Auth subclass used for Iceberg REST signing.

    botocore is an optional dependency, so it is imported only when this
    function is first called; ``@cache`` makes later calls return the same class.
    """
    from urllib.parse import urlsplit

    from botocore.auth import SigV4Auth
    from botocore.awsrequest import AWSRequest

    class _IcebergSigV4Auth(SigV4Auth):
        def canonical_request(self, request: AWSRequest) -> str:
            # This override forces the hex payload hash into the canonical request
            # even when the x-amz-content-sha256 header carries a base64 value
            # (see SigV4AuthManager.sign_request).
            # The layout mirrors botocore <=1.42.x SigV4Auth.canonical_request:
            # https://github.com/boto/botocore/blob/1.42.85/botocore/auth.py#L622-L637
            http_method = request.method.upper()
            normalized_path = self._normalize_url_path(urlsplit(request.url).path)
            query_string = self.canonical_query_string(request)
            to_sign = self.headers_to_sign(request)
            sections = (
                http_method,
                normalized_path,
                query_string,
                self.canonical_headers(to_sign) + "\n",
                self.signed_headers(to_sign),
                self.payload(request),
            )
            return "\n".join(sections)

    return _IcebergSigV4Auth
69+
3970

4071
class AuthManager(ABC):
4172
"""
@@ -48,6 +79,14 @@ class AuthManager(ABC):
4879
def auth_header(self) -> str | None:
4980
"""Return the Authorization header value, or None if not applicable."""
5081

82+
def sign_request(self, request: PreparedRequest) -> PreparedRequest:
    """Hook for signing or otherwise adjusting a fully prepared request.

    The base implementation hands the request back untouched. Subclasses that
    implement request-signing schemes (SigV4, for example) override this
    because they need to inspect the complete request.

    Args:
        request: The prepared HTTP request.

    Returns:
        The request to send; here, the same object that was passed in.
    """
    return request
89+
5190

5291
class NoopAuthManager(AuthManager):
5392
"""Auth Manager implementation with no auth."""
@@ -311,6 +350,91 @@ def auth_header(self) -> str:
311350
return f"Bearer {self._get_token()}"
312351

313352

353+
class SigV4AuthManager(AuthManager):
    """AuthManager that signs requests with AWS SigV4, wrapping a delegate AuthManager.

    Mirrors Iceberg Java's RESTSigV4AuthManager: the delegate AuthManager handles
    header-based auth (e.g. OAuth2), then SigV4 signs the resulting request.
    """

    def __init__(
        self,
        delegate: AuthManager,
        boto_session: Any,
        region: str | None,
        service: str = "execute-api",
    ):
        """Initialize SigV4AuthManager.

        Args:
            delegate: AuthManager that supplies header-based auth before signing.
            boto_session: A boto3.Session used to resolve AWS credentials.
            region: SigV4 signing region; falls back to the boto session's region.
            service: SigV4 signing service name.
        """
        self._delegate = delegate
        self._boto_session = boto_session
        self._region = region
        self._service = service

    def auth_header(self) -> str | None:
        """Return the Authorization header value produced by the delegate, if any."""
        return self._delegate.auth_header()

    def sign_request(self, request: PreparedRequest) -> PreparedRequest:
        """Sign the prepared request with AWS SigV4 and return it.

        The request's headers are modified in place: the ``connection`` header is
        dropped, the signed headers are merged in, and any original header whose
        value conflicts with a signed header is preserved under ``Original-<name>``.

        Args:
            request: The prepared request, already carrying any delegate
                Authorization header.

        Returns:
            The same request object, now carrying the SigV4 headers.

        Raises:
            botocore.exceptions.NoCredentialsError: If the boto session cannot
                resolve any AWS credentials.
            TypeError: If the request body is neither ``str`` nor ``bytes``/``bytearray``.
        """
        import hashlib
        from urllib import parse

        from botocore.awsrequest import AWSRequest
        from botocore.exceptions import NoCredentialsError

        # get_credentials() returns None when no credential source is configured;
        # fail with botocore's dedicated error instead of an opaque AttributeError.
        resolved_credentials = self._boto_session.get_credentials()
        if resolved_credentials is None:
            raise NoCredentialsError()
        credentials = resolved_credentials.get_frozen_credentials()
        region = self._region or self._boto_session.region_name

        url = str(request.url).split("?")[0]
        query = str(parse.urlsplit(request.url).query)
        # keep_blank_values=True: a parameter such as "?foo=" is still sent in the
        # URL, so it must also appear in the signed canonical query string or the
        # signature will not match what the server computes.
        # NOTE(review): dict() still collapses repeated parameter names to the last
        # value; confirm no signed endpoint relies on duplicate query keys.
        params = dict(parse.parse_qsl(query, keep_blank_values=True))

        # remove the connection header as it will be updated after signing
        if "connection" in request.headers:
            del request.headers["connection"]

        # Match Iceberg Java's AWS SDK v2 flexible-checksum signing:
        # x-amz-content-sha256 header is base64 for non-empty bodies, hex for empty.
        # The SigV4 canonical request still uses hex (enforced in _iceberg_sigv4_auth_class).
        # Ref: https://github.com/apache/iceberg/blob/main/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthSession.java
        if request.body:
            if isinstance(request.body, str):
                body_bytes = request.body.encode("utf-8")
            elif isinstance(request.body, (bytes, bytearray)):
                body_bytes = bytes(request.body)
            else:
                raise TypeError(
                    f"Unsupported request body type for SigV4 signing: {type(request.body).__name__}; expected str or bytes."
                )
            content_sha256_header = base64.b64encode(hashlib.sha256(body_bytes).digest()).decode()
        else:
            content_sha256_header = EMPTY_BODY_SHA256

        # Sign a copy of the headers so the original values remain available for
        # the conflict check below.
        signing_headers = dict(request.headers)
        signing_headers["x-amz-content-sha256"] = content_sha256_header

        aws_request = AWSRequest(method=request.method, url=url, params=params, data=request.body, headers=signing_headers)

        _iceberg_sigv4_auth_class()(credentials, self._service, region).add_auth(aws_request)

        original_header = dict(request.headers)
        signed_headers = dict(aws_request.headers)

        # relocate headers if there is a conflict with signed headers
        relocated_headers = {
            f"Original-{header}": value
            for header, value in original_header.items()
            if header in signed_headers and signed_headers[header] != value
        }

        request.headers.update(relocated_headers)
        request.headers.update(signed_headers)
        return request
436+
437+
314438
class AuthManagerAdapter(AuthBase):
315439
"""A `requests.auth.AuthBase` adapter for integrating an `AuthManager` into a `requests.Session`.
316440
@@ -332,17 +456,19 @@ def __init__(self, auth_manager: AuthManager):
332456

333457
def __call__(self, request: PreparedRequest) -> PreparedRequest:
    """
    Apply the AuthManager to the outgoing request: set the Authorization header, then sign.

    Args:
        request (requests.PreparedRequest): The HTTP request being prepared.

    Returns:
        requests.PreparedRequest: The request returned by the AuthManager's sign_request.
    """
    header_value = self.auth_manager.auth_header()
    if header_value:
        request.headers["Authorization"] = header_value
    # The header is set before signing on purpose: a request-signing AuthManager
    # (e.g. SigV4) must see the Authorization header so it can relocate it
    # before signing.
    signed_request = self.auth_manager.sign_request(request)
    return signed_request
346472

347473

348474
class AuthManagerFactory:

0 commit comments

Comments
 (0)