Skip to content

Commit 002d704

Browse files
committed
implement better encapsulation to decouple IO from session
1 parent 5da8186 commit 002d704

8 files changed

Lines changed: 552 additions & 6 deletions

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pyiceberg import __version__
3333
from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
3434
from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
35+
from pyiceberg.catalog.rest.credentials_provider import REFRESH_CREDENTIALS_ENABLED, VendedCredentialsProvider
3536
from pyiceberg.catalog.rest.response import _handle_non_200_response
3637
from pyiceberg.catalog.rest.scan_planning import (
3738
FetchScanTasksRequest,
@@ -484,7 +485,10 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | Non
484485
merged_properties = {**self.properties, **properties}
485486
if self._auth_manager:
486487
merged_properties[AUTH_MANAGER] = self._auth_manager
487-
return load_file_io(merged_properties, location)
488+
file_io = load_file_io(merged_properties, location)
489+
if property_as_bool(merged_properties, REFRESH_CREDENTIALS_ENABLED, False):
490+
file_io.set_credentials_provider(VendedCredentialsProvider(self._session, merged_properties))
491+
return file_io
488492

489493
@override
490494
def supports_server_side_planning(self) -> bool:
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
#
17+
from datetime import datetime
18+
19+
from pydantic import Field
20+
from requests import HTTPError, Session
21+
22+
from pyiceberg.catalog import URI
23+
from pyiceberg.catalog.rest.response import _handle_non_200_response
24+
from pyiceberg.catalog.rest.scan_planning import StorageCredential
25+
from pyiceberg.exceptions import ValidationError, ValidationException
26+
from pyiceberg.io import (
27+
AWS_ACCESS_KEY_ID,
28+
AWS_SECRET_ACCESS_KEY,
29+
AWS_SESSION_TOKEN,
30+
S3_ACCESS_KEY_ID,
31+
S3_SECRET_ACCESS_KEY,
32+
S3_SESSION_TOKEN,
33+
)
34+
from pyiceberg.typedef import IcebergBaseModel, Properties
35+
from pyiceberg.utils.properties import get_first_property_value
36+
37+
S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms"
38+
CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint"
39+
REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled"
40+
41+
42+
class LoadCredentialsResponse(IcebergBaseModel):
43+
credentials: list[StorageCredential] = Field(alias="storage-credentials")
44+
45+
46+
class VendedCredentialsProvider:
47+
_session: Session
48+
_properties: Properties
49+
50+
def __init__(self, session: Session, properties: Properties):
51+
self._session = session
52+
self._properties = properties
53+
54+
def _extract_s3_credentials_from(self, props: Properties) -> tuple[str | None, str | None, str | None, str | None]:
55+
"""Extract only S3 credentials from properties."""
56+
access_key = get_first_property_value(props, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID)
57+
secret_key = get_first_property_value(props, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY)
58+
session_token = get_first_property_value(props, S3_SESSION_TOKEN, AWS_SESSION_TOKEN)
59+
expiry = get_first_property_value(props, S3_SESSION_TOKEN_EXPIRES_AT_MS)
60+
61+
return access_key, secret_key, session_token, expiry
62+
63+
def _to_credentials_property_map(
64+
self, access_key: str | None, secret_key: str | None, session_token: str | None, expiry: str | None
65+
) -> Properties:
66+
return {
67+
S3_ACCESS_KEY_ID: access_key,
68+
S3_SECRET_ACCESS_KEY: secret_key,
69+
S3_SESSION_TOKEN: session_token,
70+
S3_SESSION_TOKEN_EXPIRES_AT_MS: expiry,
71+
}
72+
73+
def needs_refresh(self) -> bool:
74+
"""Return True if the S3 session token expires within 300s."""
75+
expiry = get_first_property_value(self._properties, S3_SESSION_TOKEN_EXPIRES_AT_MS)
76+
if expiry is None:
77+
return False
78+
expires_at = datetime.fromtimestamp(int(expiry) / 1000)
79+
seconds_remaining = (expires_at - datetime.now()).total_seconds()
80+
return seconds_remaining < 300
81+
82+
def _build_refresh_endpoint(self) -> str:
83+
"""Build credential refresh endpoint from properties."""
84+
catalog_uri = get_first_property_value(self._properties, URI)
85+
credentials_path = get_first_property_value(self._properties, CREDENTIALS_ENDPOINT)
86+
87+
if catalog_uri is None:
88+
raise ValidationException("Invalid catalog endpoint: None")
89+
90+
if credentials_path is None:
91+
raise ValidationException("Invalid credentials endpoint: None")
92+
93+
return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/")
94+
95+
def _get_new_credentials(self) -> LoadCredentialsResponse | None:
96+
try:
97+
http_response = self._session.get(self._build_refresh_endpoint())
98+
http_response.raise_for_status()
99+
return LoadCredentialsResponse.model_validate_json(http_response.text)
100+
except HTTPError as exc:
101+
_handle_non_200_response(exc, {})
102+
return None
103+
104+
def get_credentials(self) -> Properties:
105+
"""Retrieve current S3 credentials, refreshing from the endpoint if near expiry."""
106+
access_key, secret_key, session_token, expiry = self._extract_s3_credentials_from(self._properties)
107+
108+
if not self.needs_refresh():
109+
return self._to_credentials_property_map(access_key, secret_key, session_token, expiry)
110+
111+
creds = self._get_new_credentials()
112+
113+
if creds is None:
114+
raise ValidationError("Load credential response is None")
115+
if not creds.credentials:
116+
raise ValueError("Invalid S3 Credentials: empty")
117+
if len(creds.credentials) > 1:
118+
raise ValueError("Invalid S3 Credentials: only one S3 credential should exists")
119+
120+
updated_creds = self._extract_s3_credentials_from(creds.credentials[0].config)
121+
updated_map = self._to_credentials_property_map(*updated_creds)
122+
123+
# Update internal properties with new credentials
124+
self._properties = {**self._properties, **updated_map}
125+
126+
return updated_map

pyiceberg/io/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@
3232
from io import SEEK_SET
3333
from types import TracebackType
3434
from typing import (
35+
TYPE_CHECKING,
3536
Protocol,
3637
runtime_checkable,
3738
)
39+
40+
if TYPE_CHECKING:
41+
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
3842
from urllib.parse import urlparse
3943

4044
from pyiceberg.typedef import EMPTY_DICT, Properties
@@ -291,6 +295,13 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
291295
FileNotFoundError: When the file at the provided location does not exist.
292296
"""
293297

298+
def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: # noqa: B027
299+
"""Inject a credentials provider for refreshing vended storage credentials.
300+
301+
Args:
302+
provider (VendedCredentialsProvider): A concrete type of VendedCredentialsProvider (e.g S3VendedCredentialsProvider)
303+
"""
304+
294305

295306
LOCATION = "location"
296307
WAREHOUSE = "warehouse"

pyiceberg/io/fsspec.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from pyiceberg.catalog import TOKEN, URI
4141
from pyiceberg.catalog.rest.auth import AUTH_MANAGER
42+
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
4243
from pyiceberg.exceptions import SignError
4344
from pyiceberg.io import (
4445
ADLS_ACCOUNT_HOST,
@@ -166,9 +167,12 @@ def _file(_: Properties) -> LocalFileSystem:
166167
return LocalFileSystem(auto_mkdir=True)
167168

168169

169-
def _s3(properties: Properties) -> AbstractFileSystem:
170+
def _s3(properties: Properties, cred_provider: VendedCredentialsProvider | None) -> AbstractFileSystem:
170171
from s3fs import S3FileSystem
171172

173+
if cred_provider is not None and cred_provider.needs_refresh():
174+
properties = {**properties, **cred_provider.get_credentials()}
175+
172176
client_kwargs = {
173177
"endpoint_url": properties.get(S3_ENDPOINT),
174178
"aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
@@ -319,6 +323,7 @@ def _hf(properties: Properties) -> AbstractFileSystem:
319323
}
320324

321325
_ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"})
326+
_S3_SCHEMES = frozenset({"s3", "s3a", "s3n"})
322327

323328

324329
class FsspecInputFile(InputFile):
@@ -430,8 +435,12 @@ class FsspecFileIO(FileIO):
430435
def __init__(self, properties: Properties):
431436
self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS)
432437
self._thread_locals = threading.local()
438+
self._credentials_provider: VendedCredentialsProvider | None = None
433439
super().__init__(properties=properties)
434440

441+
def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None:
442+
self._credentials_provider = provider
443+
435444
@override
436445
def new_input(self, location: str) -> FsspecInputFile:
437446
"""Get an FsspecInputFile instance to read bytes from the file at the given location.
@@ -486,9 +495,12 @@ def _get_fs_from_uri(self, uri: "ParseResult") -> AbstractFileSystem:
486495

487496
def get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem:
488497
"""Get a filesystem for a specific scheme, cached per thread."""
489-
if not hasattr(self._thread_locals, "get_fs_cached"):
490-
self._thread_locals.get_fs_cached = lru_cache(self._get_fs)
498+
# If we have available a CredentialProvider and we detect that the tokens need to be refreshed
499+
# then invalidate the cached fileio in order to get a new fileio with the fresh credentials
500+
needs_refresh = self._credentials_provider and self._credentials_provider.needs_refresh()
491501

502+
if not hasattr(self._thread_locals, "get_fs_cached") or needs_refresh:
503+
self._thread_locals.get_fs_cached = lru_cache(self._get_fs)
492504
return self._thread_locals.get_fs_cached(scheme, hostname)
493505

494506
def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem:
@@ -499,6 +511,9 @@ def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSyste
499511
if scheme in _ADLS_SCHEMES:
500512
return _adls(self.properties, hostname)
501513

514+
if scheme in _S3_SCHEMES:
515+
return _s3(self.properties, self._credentials_provider)
516+
502517
return self._scheme_to_fs[scheme](self.properties)
503518

504519
def __getstate__(self) -> dict[str, Any]:

pyiceberg/io/pyarrow.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@
187187
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
188188

189189
if TYPE_CHECKING:
190+
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
190191
from pyiceberg.table import FileScanTask, WriteTask
191192

192193
logger = logging.getLogger(__name__)
@@ -394,8 +395,20 @@ class PyArrowFileIO(FileIO):
394395

395396
def __init__(self, properties: Properties = EMPTY_DICT):
396397
self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs)
398+
self._credentials_provider: VendedCredentialsProvider | None = None
397399
super().__init__(properties=properties)
398400

401+
def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None:
402+
self._credentials_provider = provider
403+
404+
def _get_fs(self, scheme: str, netloc: str | None) -> FileSystem:
405+
# If we have available a CredentialProvider and we detect that the tokens need to be refreshed
406+
# then invalidate the cached fileio in order to get a new fileio with the fresh credentials
407+
if self._credentials_provider and self._credentials_provider.needs_refresh():
408+
self.properties = {**self.properties, **self._credentials_provider.get_credentials()}
409+
self.fs_by_scheme = lru_cache(self._initialize_fs)
410+
return self.fs_by_scheme(scheme, netloc)
411+
399412
@staticmethod
400413
def parse_location(location: str, properties: Properties = EMPTY_DICT) -> tuple[str, str, str]:
401414
"""Return (scheme, netloc, path) for the given location.
@@ -628,7 +641,7 @@ def new_input(self, location: str) -> PyArrowFile:
628641
"""
629642
scheme, netloc, path = self.parse_location(location, self.properties)
630643
return PyArrowFile(
631-
fs=self.fs_by_scheme(scheme, netloc),
644+
fs=self._get_fs(scheme, netloc),
632645
location=location,
633646
path=path,
634647
buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
@@ -646,7 +659,7 @@ def new_output(self, location: str) -> PyArrowFile:
646659
"""
647660
scheme, netloc, path = self.parse_location(location, self.properties)
648661
return PyArrowFile(
649-
fs=self.fs_by_scheme(scheme, netloc),
662+
fs=self._get_fs(scheme, netloc),
650663
location=location,
651664
path=path,
652665
buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),

0 commit comments

Comments
 (0)