Skip to content

Commit ea0777c

Browse files
committed
credential vending impl
1 parent 5da8186 commit ea0777c

3 files changed

Lines changed: 322 additions & 0 deletions

File tree

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
#
17+
from datetime import datetime
18+
19+
from pydantic import Field
20+
from requests import HTTPError, Session
21+
22+
from pyiceberg.catalog import URI
23+
from pyiceberg.catalog.rest.response import _handle_non_200_response
24+
from pyiceberg.catalog.rest.scan_planning import StorageCredential
25+
from pyiceberg.exceptions import ValidationException
26+
from pyiceberg.io import (
27+
AWS_ACCESS_KEY_ID,
28+
AWS_SECRET_ACCESS_KEY,
29+
AWS_SESSION_TOKEN,
30+
S3_ACCESS_KEY_ID,
31+
S3_SECRET_ACCESS_KEY,
32+
S3_SESSION_TOKEN,
33+
)
34+
from pyiceberg.typedef import IcebergBaseModel, Properties
35+
from pyiceberg.utils.properties import get_first_property_value
36+
37+
S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms"
38+
CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint"
39+
REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled"
40+
41+
42+
class LoadCredentialsResponse(IcebergBaseModel):
43+
credentials: list[StorageCredential] = Field(alias="storage-credentials")
44+
45+
46+
class VendedCredentialsProvider:
47+
_session: Session
48+
_properties: Properties
49+
50+
def __init__(self, session: Session, properties: Properties):
51+
self._session = session
52+
self._properties = properties
53+
54+
def _extract_s3_credentials_from(self, props: Properties) -> tuple[str | None, str | None, str | None, str | None]:
55+
"""Extract only S3 credentials from properties."""
56+
access_key = get_first_property_value(props, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID)
57+
secret_key = get_first_property_value(props, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY)
58+
session_token = get_first_property_value(props, S3_SESSION_TOKEN, AWS_SESSION_TOKEN)
59+
expiry = get_first_property_value(props, S3_SESSION_TOKEN_EXPIRES_AT_MS)
60+
61+
return access_key, secret_key, session_token, expiry
62+
63+
def _to_credentials_property_map(
64+
self, access_key: str | None, secret_key: str | None, session_token: str | None, expiry: str | None
65+
) -> Properties:
66+
return {
67+
S3_ACCESS_KEY_ID: access_key,
68+
S3_SECRET_ACCESS_KEY: secret_key,
69+
S3_SESSION_TOKEN: session_token,
70+
S3_SESSION_TOKEN_EXPIRES_AT_MS: expiry,
71+
}
72+
73+
def needs_refresh(self) -> bool:
74+
"""Return True if the S3 session token expires within 300s."""
75+
expiry = get_first_property_value(self._properties, S3_SESSION_TOKEN_EXPIRES_AT_MS)
76+
if expiry is None:
77+
return False
78+
expires_at = datetime.fromtimestamp(int(expiry) / 1000)
79+
seconds_remaining = (expires_at - datetime.now()).total_seconds()
80+
return seconds_remaining < 300
81+
82+
def _build_refresh_endpoint(self) -> str:
83+
"""Build credential refresh endpoint from properties."""
84+
catalog_uri = get_first_property_value(self._properties, URI)
85+
credentials_path = get_first_property_value(self._properties, CREDENTIALS_ENDPOINT)
86+
87+
if catalog_uri is None:
88+
raise ValidationException("Invalid catalog endpoint: None")
89+
90+
if credentials_path is None:
91+
raise ValidationException("Invalid credentials endpoint: None")
92+
93+
return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/")
94+
95+
def _get_new_credentials(self) -> LoadCredentialsResponse | None:
96+
try:
97+
http_response = self._session.get(self._build_refresh_endpoint())
98+
http_response.raise_for_status()
99+
return LoadCredentialsResponse.model_validate_json(http_response.text)
100+
except HTTPError as exc:
101+
_handle_non_200_response(exc, {})
102+
return None
103+
104+
def get_credentials(self) -> Properties:
105+
"""Retrieve current S3 credentials, refreshing from the endpoint if near expiry."""
106+
access_key, secret_key, session_token, expiry = self._extract_s3_credentials_from(self._properties)
107+
108+
if not self.needs_refresh():
109+
return self._to_credentials_property_map(access_key, secret_key, session_token, expiry)
110+
111+
creds = self._get_new_credentials()
112+
113+
if creds is None:
114+
raise ValidationException("Load credential response is None")
115+
if not creds.credentials:
116+
raise ValueError("Invalid S3 Credentials: empty")
117+
if len(creds.credentials) > 1:
118+
raise ValueError("Invalid S3 Credentials: only one S3 credential should exists")
119+
120+
updated_creds = self._extract_s3_credentials_from(creds.credentials[0].config)
121+
updated_map = self._to_credentials_property_map(*updated_creds)
122+
123+
# Update internal properties with new credentials
124+
self._properties = {**self._properties, **updated_map}
125+
126+
return updated_map

pyiceberg/io/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@
3232
from io import SEEK_SET
3333
from types import TracebackType
3434
from typing import (
35+
TYPE_CHECKING,
3536
Protocol,
3637
runtime_checkable,
3738
)
39+
40+
if TYPE_CHECKING:
41+
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
3842
from urllib.parse import urlparse
3943

4044
from pyiceberg.typedef import EMPTY_DICT, Properties
@@ -291,6 +295,13 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
291295
FileNotFoundError: When the file at the provided location does not exist.
292296
"""
293297

298+
def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: # noqa: B027
299+
"""Inject a credentials provider for refreshing vended storage credentials.
300+
301+
Args:
302+
provider (VendedCredentialsProvider): A concrete type of VendedCredentialsProvider (e.g S3VendedCredentialsProvider)
303+
"""
304+
294305

295306
LOCATION = "location"
296307
WAREHOUSE = "warehouse"
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import time
19+
from unittest.mock import MagicMock
20+
21+
import pytest
22+
23+
from pyiceberg.catalog.rest.credentials_provider import (
24+
CREDENTIALS_ENDPOINT,
25+
LoadCredentialsResponse,
26+
VendedCredentialsProvider,
27+
)
28+
from pyiceberg.catalog.rest.scan_planning import StorageCredential
29+
30+
CATALOG_URI = "http://localhost:8181"
31+
CREDENTIALS_PATH = "v1/credentials"
32+
33+
BASE_PROPS = {
34+
"uri": CATALOG_URI,
35+
CREDENTIALS_ENDPOINT: CREDENTIALS_PATH,
36+
"s3.access-key-id": "initial-key",
37+
"s3.secret-access-key": "initial-secret",
38+
"s3.session-token": "initial-token",
39+
}
40+
41+
REFRESH_RESPONSE = LoadCredentialsResponse(
42+
credentials=[
43+
StorageCredential(
44+
prefix="s3://",
45+
config={
46+
"s3.access-key-id": "refreshed-key",
47+
"s3.secret-access-key": "refreshed-secret",
48+
"s3.session-token": "refreshed-token",
49+
},
50+
)
51+
]
52+
)
53+
54+
55+
def _make_session(response: LoadCredentialsResponse = REFRESH_RESPONSE) -> MagicMock:
56+
session = MagicMock()
57+
mock_response = MagicMock()
58+
mock_response.text = response.model_dump_json(by_alias=True)
59+
mock_response.raise_for_status.return_value = None
60+
session.get.return_value = mock_response
61+
return session
62+
63+
64+
def test_get_credentials_no_expiry_returns_static_creds() -> None:
65+
"""When no expiry is set, credentials are returned from properties without an HTTP call."""
66+
session = _make_session()
67+
provider = VendedCredentialsProvider(session, BASE_PROPS)
68+
creds = provider.get_credentials()
69+
70+
session.get.assert_not_called()
71+
assert creds["s3.access-key-id"] == "initial-key"
72+
assert creds["s3.secret-access-key"] == "initial-secret"
73+
assert creds["s3.session-token"] == "initial-token"
74+
75+
76+
def test_get_credentials_far_expiry_returns_static_creds() -> None:
77+
"""When expiry is far in the future (>300s), no refresh is triggered."""
78+
far_future_ms = str(int((time.time() + 3600) * 1000)) # expires in 1 hour
79+
props = {**BASE_PROPS, "s3.session-token-expires-at-ms": far_future_ms}
80+
session = _make_session()
81+
provider = VendedCredentialsProvider(session, props)
82+
creds = provider.get_credentials()
83+
84+
session.get.assert_not_called()
85+
assert creds["s3.access-key-id"] == "initial-key"
86+
87+
88+
def test_get_credentials_near_expiry_calls_refresh_endpoint() -> None:
89+
"""When expiry is within 300s, the refresh endpoint is called and new creds returned."""
90+
near_expiry_ms = str(int((time.time() + 60) * 1000)) # expires in 60s
91+
props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms}
92+
session = _make_session()
93+
provider = VendedCredentialsProvider(session, props)
94+
creds = provider.get_credentials()
95+
96+
session.get.assert_called_once_with(f"{CATALOG_URI}/{CREDENTIALS_PATH}")
97+
assert creds["s3.access-key-id"] == "refreshed-key"
98+
assert creds["s3.secret-access-key"] == "refreshed-secret"
99+
assert creds["s3.session-token"] == "refreshed-token"
100+
101+
102+
def test_get_credentials_raises_on_empty_credentials() -> None:
103+
"""An empty credentials list in the refresh response raises ValueError."""
104+
near_expiry_ms = str(int((time.time() + 60) * 1000))
105+
props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms}
106+
empty_response = LoadCredentialsResponse(credentials=[])
107+
provider = VendedCredentialsProvider(_make_session(empty_response), props)
108+
109+
with pytest.raises(ValueError, match="empty"):
110+
provider.get_credentials()
111+
112+
113+
def test_get_credentials_raises_on_multiple_credentials() -> None:
114+
"""More than one credential in the refresh response raises ValueError."""
115+
near_expiry_ms = str(int((time.time() + 60) * 1000))
116+
props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms}
117+
multi_response = LoadCredentialsResponse(
118+
credentials=[
119+
StorageCredential(prefix="s3://", config={}),
120+
StorageCredential(prefix="s3://b", config={}),
121+
]
122+
)
123+
provider = VendedCredentialsProvider(_make_session(multi_response), props)
124+
125+
with pytest.raises(ValueError, match="only one"):
126+
provider.get_credentials()
127+
128+
129+
def test_build_refresh_endpoint_strips_trailing_slash() -> None:
130+
props = {**BASE_PROPS, "uri": "http://localhost:8181/"}
131+
provider = VendedCredentialsProvider(MagicMock(), props)
132+
assert provider._build_refresh_endpoint() == f"http://localhost:8181/{CREDENTIALS_PATH}"
133+
134+
135+
def test_build_refresh_endpoint_raises_without_uri() -> None:
136+
props = {CREDENTIALS_ENDPOINT: CREDENTIALS_PATH}
137+
provider = VendedCredentialsProvider(MagicMock(), props)
138+
139+
from pyiceberg.exceptions import ValidationException
140+
141+
with pytest.raises(ValidationException):
142+
provider._build_refresh_endpoint()
143+
144+
145+
def test_needs_refresh_true_when_near_expiry() -> None:
146+
near_expiry_ms = str(int((time.time() + 60) * 1000))
147+
provider = VendedCredentialsProvider(MagicMock(), {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms})
148+
assert provider.needs_refresh() is True
149+
150+
151+
def test_needs_refresh_false_when_far_expiry() -> None:
152+
far_expiry_ms = str(int((time.time() + 3600) * 1000))
153+
provider = VendedCredentialsProvider(MagicMock(), {**BASE_PROPS, "s3.session-token-expires-at-ms": far_expiry_ms})
154+
assert provider.needs_refresh() is False
155+
156+
157+
def test_needs_refresh_false_when_no_expiry() -> None:
158+
provider = VendedCredentialsProvider(MagicMock(), BASE_PROPS)
159+
assert provider.needs_refresh() is False
160+
161+
162+
def test_get_credentials_updates_internal_properties_after_refresh() -> None:
163+
"""After a refresh, _properties holds the new expiry so needs_refresh() sees the updated state."""
164+
far_future_ms = str(int((time.time() + 3600) * 1000))
165+
refreshed_response = LoadCredentialsResponse(
166+
credentials=[
167+
StorageCredential(
168+
prefix="s3://",
169+
config={
170+
"s3.access-key-id": "new-key",
171+
"s3.secret-access-key": "new-secret",
172+
"s3.session-token": "new-token",
173+
"s3.session-token-expires-at-ms": far_future_ms,
174+
},
175+
)
176+
]
177+
)
178+
near_expiry_ms = str(int((time.time() + 60) * 1000))
179+
props = {**BASE_PROPS, "s3.session-token-expires-at-ms": near_expiry_ms}
180+
provider = VendedCredentialsProvider(_make_session(refreshed_response), props)
181+
182+
assert provider.needs_refresh() is True
183+
provider.get_credentials()
184+
assert provider.needs_refresh() is False
185+
assert provider._properties["s3.session-token-expires-at-ms"] == far_future_ms

0 commit comments

Comments
 (0)