Skip to content

Commit aa14bee

Browse files
committed
Implement vended credential refresh for s3
1 parent a9ad3a3 commit aa14bee

5 files changed

Lines changed: 370 additions & 27 deletions

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,15 @@ class ListViewsResponse(IcebergBaseModel):
382382
identifiers: list[ListViewResponseEntry] = Field()
383383

384384

385+
class Credential(IcebergBaseModel):
386+
prefix: str = Field()
387+
config: dict[str, str] = Field()
388+
389+
390+
class LoadCredentialsResponse(IcebergBaseModel):
391+
credentials: list[Credential] = Field(alias="storage-credentials")
392+
393+
385394
_PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse)
386395

387396

@@ -469,11 +478,13 @@ def _resolve_storage_credentials(storage_credentials: list[StorageCredential], l
469478

470479
return best_match.config if best_match else {}
471480

472-
def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO:
481+
def _load_file_io(
482+
self, properties: Properties = EMPTY_DICT, location: str | None = None, session: Session | None = None
483+
) -> FileIO:
473484
merged_properties = {**self.properties, **properties}
474485
if self._auth_manager:
475486
merged_properties[AUTH_MANAGER] = self._auth_manager
476-
return load_file_io(merged_properties, location)
487+
return load_file_io(merged_properties, location, session)
477488

478489
def supports_server_side_planning(self) -> bool:
479490
"""Check if the catalog supports server-side scan planning."""
@@ -820,6 +831,7 @@ def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response:
820831
io=self._load_file_io(
821832
{**table_response.metadata.properties, **table_response.config, **credential_config},
822833
table_response.metadata_location,
834+
self._session,
823835
),
824836
catalog=self,
825837
config=table_response.config,
@@ -837,6 +849,7 @@ def _response_to_staged_table(self, identifier_tuple: tuple[str, ...], table_res
837849
io=self._load_file_io(
838850
{**table_response.metadata.properties, **table_response.config, **credential_config},
839851
table_response.metadata_location,
852+
self._session,
840853
),
841854
catalog=self,
842855
)

pyiceberg/io/__init__.py

Lines changed: 110 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import logging
3030
import warnings
3131
from abc import ABC, abstractmethod
32+
from datetime import datetime
3233
from io import SEEK_SET
3334
from types import TracebackType
3435
from typing import (
@@ -37,7 +38,11 @@
3738
)
3839
from urllib.parse import urlparse
3940

41+
from requests import HTTPError, Session
42+
43+
from pyiceberg.exceptions import ValidationException
4044
from pyiceberg.typedef import EMPTY_DICT, Properties
45+
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
4146

4247
logger = logging.getLogger(__name__)
4348

@@ -67,6 +72,7 @@
6772
S3_ROLE_SESSION_NAME = "s3.role-session-name"
6873
S3_FORCE_VIRTUAL_ADDRESSING = "s3.force-virtual-addressing"
6974
S3_RETRY_STRATEGY_IMPL = "s3.retry-strategy-impl"
75+
S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms"
7076
HDFS_HOST = "hdfs.host"
7177
HDFS_PORT = "hdfs.port"
7278
HDFS_USER = "hdfs.user"
@@ -99,6 +105,9 @@
99105
GCS_VERSION_AWARE = "gcs.version-aware"
100106
HF_ENDPOINT = "hf.endpoint"
101107
HF_TOKEN = "hf.token"
108+
CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint"
109+
REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled"
110+
CATALOG_URI = "uri"
102111

103112

104113
@runtime_checkable
@@ -258,9 +267,11 @@ class FileIO(ABC):
258267
"""A base class for FileIO implementations."""
259268

260269
properties: Properties
270+
session: Session | None
261271

262-
def __init__(self, properties: Properties = EMPTY_DICT):
272+
def __init__(self, properties: Properties = EMPTY_DICT, session: Session | None = None):
263273
self.properties = properties
274+
self.session = session
264275

265276
@abstractmethod
266277
def new_input(self, location: str) -> InputFile:
@@ -317,15 +328,15 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
317328
}
318329

319330

320-
def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None:
331+
def _import_file_io(io_impl: str, properties: Properties, session: Session | None = None) -> FileIO | None:
321332
try:
322333
path_parts = io_impl.split(".")
323334
if len(path_parts) < 2:
324335
raise ValueError(f"py-io-impl should be full path (module.CustomFileIO), got: {io_impl}")
325336
module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1]
326337
module = importlib.import_module(module_name)
327338
class_ = getattr(module, class_name)
328-
return class_(properties)
339+
return class_(properties, session)
329340
except ModuleNotFoundError:
330341
logger.warning(f"Could not initialize FileIO: {io_impl}", exc_info=logger.isEnabledFor(logging.DEBUG))
331342
return None
@@ -334,45 +345,134 @@ def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None:
334345
PY_IO_IMPL = "py-io-impl"
335346

336347

337-
def _infer_file_io_from_scheme(path: str, properties: Properties) -> FileIO | None:
348+
def _infer_file_io_from_scheme(path: str, properties: Properties, session: Session | None = None) -> FileIO | None:
338349
parsed_url = urlparse(path)
339350
if parsed_url.scheme:
340351
if file_ios := SCHEMA_TO_FILE_IO.get(parsed_url.scheme):
341352
for file_io_path in file_ios:
342-
if file_io := _import_file_io(file_io_path, properties):
353+
if file_io := _import_file_io(file_io_path, properties, session):
343354
return file_io
344355
else:
345356
warnings.warn(f"No preferred file implementation for scheme: {parsed_url.scheme}", stacklevel=2)
346357
return None
347358

348359

349-
def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO:
360+
def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None, session: Session | None = None) -> FileIO:
350361
# First look for the py-io-impl property to directly load the class
351362
if io_impl := properties.get(PY_IO_IMPL):
352-
if file_io := _import_file_io(io_impl, properties):
363+
if file_io := _import_file_io(io_impl, properties, session):
353364
logger.info("Loaded FileIO: %s", io_impl)
354365
return file_io
355366
else:
356367
raise ValueError(f"Could not initialize FileIO: {io_impl}")
357368

358369
# Check the table location
359370
if location:
360-
if file_io := _infer_file_io_from_scheme(location, properties):
371+
if file_io := _infer_file_io_from_scheme(location, properties, session):
361372
return file_io
362373

363374
# Look at the schema of the warehouse
364375
if warehouse_location := properties.get(WAREHOUSE):
365-
if file_io := _infer_file_io_from_scheme(warehouse_location, properties):
376+
if file_io := _infer_file_io_from_scheme(warehouse_location, properties, session):
366377
return file_io
367378

368379
try:
369380
# Default to PyArrow
370381
logger.info("Defaulting to PyArrow FileIO")
371382
from pyiceberg.io.pyarrow import PyArrowFileIO
372383

373-
return PyArrowFileIO(properties)
384+
return PyArrowFileIO(properties, session)
374385
except ModuleNotFoundError as e:
375386
raise ModuleNotFoundError(
376387
"Could not load a FileIO, please consider installing one: "
377388
'pip3 install "pyiceberg[pyarrow]", for more options refer to the docs.'
378389
) from e
390+
391+
392+
def _extract_s3_credentials(properties: Properties) -> Properties:
393+
"""Extract only S3 credential keys from properties, normalizing AWS_ prefixes to S3_."""
394+
creds: Properties = {}
395+
if access_key := get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID):
396+
creds[S3_ACCESS_KEY_ID] = access_key
397+
if secret_key := get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY):
398+
creds[S3_SECRET_ACCESS_KEY] = secret_key
399+
if session_token := get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN):
400+
creds[S3_SESSION_TOKEN] = session_token
401+
if expiry := get_first_property_value(properties, S3_SESSION_TOKEN_EXPIRES_AT_MS):
402+
creds[S3_SESSION_TOKEN_EXPIRES_AT_MS] = expiry
403+
return creds
404+
405+
406+
def _credential_from_properties(properties: Properties) -> Properties:
407+
"""Retrieve current S3 credentials from properties returns empty if expired."""
408+
access_key = get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID)
409+
secret_access_key = get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY)
410+
session_token = get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN)
411+
expiration_ms = property_as_int(properties, S3_SESSION_TOKEN_EXPIRES_AT_MS)
412+
413+
if not access_key or not secret_access_key or not session_token or not expiration_ms:
414+
return EMPTY_DICT
415+
416+
expiresAt = datetime.fromtimestamp(expiration_ms / 1000)
417+
prefetchAt = (expiresAt - datetime.now()).total_seconds()
418+
419+
if prefetchAt > 300:
420+
return EMPTY_DICT
421+
422+
return {
423+
S3_ACCESS_KEY_ID: access_key,
424+
S3_SECRET_ACCESS_KEY: secret_access_key,
425+
S3_SESSION_TOKEN: session_token,
426+
S3_SESSION_TOKEN_EXPIRES_AT_MS: expiration_ms,
427+
}
428+
429+
430+
def _credential_refresh_endpoint(properties: Properties) -> str:
431+
"""Build credential refresh endpoint from properties."""
432+
catalog_uri = get_first_property_value(properties, CATALOG_URI)
433+
credentials_path = get_first_property_value(properties, CREDENTIALS_ENDPOINT)
434+
435+
if catalog_uri is None:
436+
raise ValidationException("Invalid catalog endpoint: None")
437+
438+
if credentials_path is None:
439+
raise ValidationException("Invalid credentials endpoint: None")
440+
441+
return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/")
442+
443+
444+
def _get_or_refresh_credentials(properties: Properties, session: Session | None) -> Properties:
445+
"""Retrieve current S3 credentials from properties, refreshing them if they are close to expiration."""
446+
refresh_enabled = property_as_bool(properties, REFRESH_CREDENTIALS_ENABLED, False)
447+
if not refresh_enabled or session is None:
448+
return _extract_s3_credentials(properties)
449+
450+
# Returns empty if credentials missing or not yet expiring
451+
creds = _credential_from_properties(properties)
452+
453+
if not creds:
454+
return _extract_s3_credentials(properties)
455+
456+
from pyiceberg.catalog.rest import LoadCredentialsResponse
457+
from pyiceberg.catalog.rest.response import _handle_non_200_response
458+
459+
load_response: LoadCredentialsResponse | None = None
460+
461+
try:
462+
http_response = session.get(_credential_refresh_endpoint(properties))
463+
http_response.raise_for_status()
464+
load_response = LoadCredentialsResponse.model_validate_json(http_response.text)
465+
except HTTPError as exc:
466+
_handle_non_200_response(exc, {})
467+
468+
if load_response is None:
469+
raise ValidationException("Load credential response is None")
470+
471+
if not load_response.credentials:
472+
raise ValueError("Invalid S3 Credentials: empty")
473+
474+
if len(load_response.credentials) > 1:
475+
raise ValueError("Invalid S3 Credentials: only one S3 credential should exist")
476+
477+
credentials = load_response.credentials[0].config
478+
return _extract_s3_credentials(credentials)

pyiceberg/io/fsspec.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import requests
3535
from fsspec import AbstractFileSystem
3636
from fsspec.implementations.local import LocalFileSystem
37-
from requests import HTTPError
37+
from requests import HTTPError, Session
3838

3939
from pyiceberg.catalog import TOKEN, URI
4040
from pyiceberg.catalog.rest.auth import AUTH_MANAGER
@@ -88,6 +88,7 @@
8888
InputStream,
8989
OutputFile,
9090
OutputStream,
91+
_get_or_refresh_credentials,
9192
)
9293
from pyiceberg.typedef import Properties
9394
from pyiceberg.types import strtobool
@@ -165,14 +166,16 @@ def _file(_: Properties) -> LocalFileSystem:
165166
return LocalFileSystem(auto_mkdir=True)
166167

167168

168-
def _s3(properties: Properties) -> AbstractFileSystem:
169+
def _s3(properties: Properties, session: Session | None = None) -> AbstractFileSystem:
169170
from s3fs import S3FileSystem
170171

172+
creds = _get_or_refresh_credentials(properties, session)
173+
171174
client_kwargs = {
172175
"endpoint_url": properties.get(S3_ENDPOINT),
173-
"aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
174-
"aws_secret_access_key": get_first_property_value(properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
175-
"aws_session_token": get_first_property_value(properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
176+
"aws_access_key_id": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
177+
"aws_secret_access_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
178+
"aws_session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
176179
"region_name": get_first_property_value(properties, S3_REGION, AWS_REGION),
177180
}
178181
config_kwargs = {}
@@ -318,6 +321,7 @@ def _hf(properties: Properties) -> AbstractFileSystem:
318321
}
319322

320323
_ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"})
324+
_S3_SCHEMES = frozenset({"s3", "s3a", "s3n"})
321325

322326

323327
class FsspecInputFile(InputFile):
@@ -419,10 +423,10 @@ def to_input_file(self) -> FsspecInputFile:
419423
class FsspecFileIO(FileIO):
420424
"""A FileIO implementation that uses fsspec."""
421425

422-
def __init__(self, properties: Properties):
426+
def __init__(self, properties: Properties, session: Session | None = None):
423427
self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS)
424428
self._thread_locals = threading.local()
425-
super().__init__(properties=properties)
429+
super().__init__(properties=properties, session=session)
426430

427431
def new_input(self, location: str) -> FsspecInputFile:
428432
"""Get an FsspecInputFile instance to read bytes from the file at the given location.
@@ -488,6 +492,9 @@ def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSyste
488492
if scheme in _ADLS_SCHEMES:
489493
return _adls(self.properties, hostname)
490494

495+
if scheme in _S3_SCHEMES:
496+
return _s3(self.properties, self.session)
497+
491498
return self._scheme_to_fs[scheme](self.properties)
492499

493500
def __getstate__(self) -> dict[str, Any]:

pyiceberg/io/pyarrow.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
FileSystem,
6464
FileType,
6565
)
66+
from requests import Session
6667

6768
from pyiceberg.conversions import to_bytes
6869
from pyiceberg.exceptions import ResolveError
@@ -119,6 +120,7 @@
119120
InputStream,
120121
OutputFile,
121122
OutputStream,
123+
_get_or_refresh_credentials,
122124
)
123125
from pyiceberg.io.fileformat import DataFileStatistics as DataFileStatistics
124126
from pyiceberg.manifest import (
@@ -386,9 +388,9 @@ def to_input_file(self) -> PyArrowFile:
386388
class PyArrowFileIO(FileIO):
387389
fs_by_scheme: Callable[[str, str | None], FileSystem]
388390

389-
def __init__(self, properties: Properties = EMPTY_DICT):
391+
def __init__(self, properties: Properties = EMPTY_DICT, session: Session | None = None):
390392
self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs)
391-
super().__init__(properties=properties)
393+
super().__init__(properties=properties, session=session)
392394

393395
@staticmethod
394396
def parse_location(location: str, properties: Properties = EMPTY_DICT) -> tuple[str, str, str]:
@@ -433,11 +435,13 @@ def _initialize_fs(self, scheme: str, netloc: str | None = None) -> FileSystem:
433435
def _initialize_oss_fs(self) -> FileSystem:
434436
from pyarrow.fs import S3FileSystem
435437

438+
creds = _get_or_refresh_credentials(self.properties, self.session)
439+
436440
client_kwargs: dict[str, Any] = {
437441
"endpoint_override": self.properties.get(S3_ENDPOINT),
438-
"access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
439-
"secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
440-
"session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
442+
"access_key": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
443+
"secret_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
444+
"session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
441445
"region": get_first_property_value(self.properties, S3_REGION, AWS_REGION),
442446
"force_virtual_addressing": property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, True),
443447
}
@@ -480,11 +484,13 @@ def _initialize_s3_fs(self, netloc: str | None) -> FileSystem:
480484
else:
481485
bucket_region = provided_region
482486

487+
creds = _get_or_refresh_credentials(self.properties, self.session)
488+
483489
client_kwargs: dict[str, Any] = {
484490
"endpoint_override": self.properties.get(S3_ENDPOINT),
485-
"access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
486-
"secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
487-
"session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
491+
"access_key": get_first_property_value(creds, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
492+
"secret_key": get_first_property_value(creds, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
493+
"session_token": get_first_property_value(creds, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
488494
"region": bucket_region,
489495
}
490496

0 commit comments

Comments
 (0)