2929import logging
3030import warnings
3131from abc import ABC , abstractmethod
32+ from datetime import datetime
3233from io import SEEK_SET
3334from types import TracebackType
3435from typing import (
3738)
3839from urllib .parse import urlparse
3940
41+ from requests import HTTPError , Session
42+
43+ from pyiceberg .exceptions import ValidationException
4044from pyiceberg .typedef import EMPTY_DICT , Properties
45+ from pyiceberg .utils .properties import get_first_property_value , property_as_bool , property_as_int
4146
4247logger = logging .getLogger (__name__ )
4348
6772S3_ROLE_SESSION_NAME = "s3.role-session-name"
6873S3_FORCE_VIRTUAL_ADDRESSING = "s3.force-virtual-addressing"
6974S3_RETRY_STRATEGY_IMPL = "s3.retry-strategy-impl"
75+ S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms"
7076HDFS_HOST = "hdfs.host"
7177HDFS_PORT = "hdfs.port"
7278HDFS_USER = "hdfs.user"
99105GCS_VERSION_AWARE = "gcs.version-aware"
100106HF_ENDPOINT = "hf.endpoint"
101107HF_TOKEN = "hf.token"
108+ CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint"
109+ REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled"
110+ CATALOG_URI = "uri"
102111
103112
104113@runtime_checkable
@@ -258,9 +267,11 @@ class FileIO(ABC):
258267 """A base class for FileIO implementations."""
259268
260269 properties : Properties
270+ session : Session | None
261271
262- def __init__ (self , properties : Properties = EMPTY_DICT ):
272+ def __init__ (self , properties : Properties = EMPTY_DICT , session : Session | None = None ):
263273 self .properties = properties
274+ self .session = session
264275
265276 @abstractmethod
266277 def new_input (self , location : str ) -> InputFile :
@@ -317,15 +328,15 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
317328}
318329
319330
320- def _import_file_io (io_impl : str , properties : Properties ) -> FileIO | None :
331+ def _import_file_io (io_impl : str , properties : Properties , session : Session | None = None ) -> FileIO | None :
321332 try :
322333 path_parts = io_impl .split ("." )
323334 if len (path_parts ) < 2 :
324335 raise ValueError (f"py-io-impl should be full path (module.CustomFileIO), got: { io_impl } " )
325336 module_name , class_name = "." .join (path_parts [:- 1 ]), path_parts [- 1 ]
326337 module = importlib .import_module (module_name )
327338 class_ = getattr (module , class_name )
328- return class_ (properties )
339+ return class_ (properties , session )
329340 except ModuleNotFoundError :
330341 logger .warning (f"Could not initialize FileIO: { io_impl } " , exc_info = logger .isEnabledFor (logging .DEBUG ))
331342 return None
@@ -334,45 +345,134 @@ def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None:
334345PY_IO_IMPL = "py-io-impl"
335346
336347
337- def _infer_file_io_from_scheme (path : str , properties : Properties ) -> FileIO | None :
348+ def _infer_file_io_from_scheme (path : str , properties : Properties , session : Session | None = None ) -> FileIO | None :
338349 parsed_url = urlparse (path )
339350 if parsed_url .scheme :
340351 if file_ios := SCHEMA_TO_FILE_IO .get (parsed_url .scheme ):
341352 for file_io_path in file_ios :
342- if file_io := _import_file_io (file_io_path , properties ):
353+ if file_io := _import_file_io (file_io_path , properties , session ):
343354 return file_io
344355 else :
345356 warnings .warn (f"No preferred file implementation for scheme: { parsed_url .scheme } " , stacklevel = 2 )
346357 return None
347358
348359
349- def load_file_io (properties : Properties = EMPTY_DICT , location : str | None = None ) -> FileIO :
360+ def load_file_io (properties : Properties = EMPTY_DICT , location : str | None = None , session : Session | None = None ) -> FileIO :
350361 # First look for the py-io-impl property to directly load the class
351362 if io_impl := properties .get (PY_IO_IMPL ):
352- if file_io := _import_file_io (io_impl , properties ):
363+ if file_io := _import_file_io (io_impl , properties , session ):
353364 logger .info ("Loaded FileIO: %s" , io_impl )
354365 return file_io
355366 else :
356367 raise ValueError (f"Could not initialize FileIO: { io_impl } " )
357368
358369 # Check the table location
359370 if location :
360- if file_io := _infer_file_io_from_scheme (location , properties ):
371+ if file_io := _infer_file_io_from_scheme (location , properties , session ):
361372 return file_io
362373
363374 # Look at the schema of the warehouse
364375 if warehouse_location := properties .get (WAREHOUSE ):
365- if file_io := _infer_file_io_from_scheme (warehouse_location , properties ):
376+ if file_io := _infer_file_io_from_scheme (warehouse_location , properties , session ):
366377 return file_io
367378
368379 try :
369380 # Default to PyArrow
370381 logger .info ("Defaulting to PyArrow FileIO" )
371382 from pyiceberg .io .pyarrow import PyArrowFileIO
372383
373- return PyArrowFileIO (properties )
384+ return PyArrowFileIO (properties , session )
374385 except ModuleNotFoundError as e :
375386 raise ModuleNotFoundError (
376387 "Could not load a FileIO, please consider installing one: "
377388 'pip3 install "pyiceberg[pyarrow]", for more options refer to the docs.'
378389 ) from e
390+
391+
392+ def _extract_s3_credentials (properties : Properties ) -> Properties :
393+ """Extract only S3 credential keys from properties, normalizing AWS_ prefixes to S3_."""
394+ creds : Properties = {}
395+ if access_key := get_first_property_value (properties , S3_ACCESS_KEY_ID , AWS_ACCESS_KEY_ID ):
396+ creds [S3_ACCESS_KEY_ID ] = access_key
397+ if secret_key := get_first_property_value (properties , S3_SECRET_ACCESS_KEY , AWS_SECRET_ACCESS_KEY ):
398+ creds [S3_SECRET_ACCESS_KEY ] = secret_key
399+ if session_token := get_first_property_value (properties , S3_SESSION_TOKEN , AWS_SESSION_TOKEN ):
400+ creds [S3_SESSION_TOKEN ] = session_token
401+ if expiry := get_first_property_value (properties , S3_SESSION_TOKEN_EXPIRES_AT_MS ):
402+ creds [S3_SESSION_TOKEN_EXPIRES_AT_MS ] = expiry
403+ return creds
404+
405+
406+ def _credential_from_properties (properties : Properties ) -> Properties :
407+ """Retrieve current S3 credentials from properties returns empty if expired."""
408+ access_key = get_first_property_value (properties , S3_ACCESS_KEY_ID , AWS_ACCESS_KEY_ID )
409+ secret_access_key = get_first_property_value (properties , S3_SECRET_ACCESS_KEY , AWS_SECRET_ACCESS_KEY )
410+ session_token = get_first_property_value (properties , S3_SESSION_TOKEN , AWS_SESSION_TOKEN )
411+ expiration_ms = property_as_int (properties , S3_SESSION_TOKEN_EXPIRES_AT_MS )
412+
413+ if not access_key or not secret_access_key or not session_token or not expiration_ms :
414+ return EMPTY_DICT
415+
416+ expiresAt = datetime .fromtimestamp (expiration_ms / 1000 )
417+ prefetchAt = (expiresAt - datetime .now ()).total_seconds ()
418+
419+ if prefetchAt > 300 :
420+ return EMPTY_DICT
421+
422+ return {
423+ S3_ACCESS_KEY_ID : access_key ,
424+ S3_SECRET_ACCESS_KEY : secret_access_key ,
425+ S3_SESSION_TOKEN : session_token ,
426+ S3_SESSION_TOKEN_EXPIRES_AT_MS : expiration_ms ,
427+ }
428+
429+
430+ def _credential_refresh_endpoint (properties : Properties ) -> str :
431+ """Build credential refresh endpoint from properties."""
432+ catalog_uri = get_first_property_value (properties , CATALOG_URI )
433+ credentials_path = get_first_property_value (properties , CREDENTIALS_ENDPOINT )
434+
435+ if catalog_uri is None :
436+ raise ValidationException ("Invalid catalog endpoint: None" )
437+
438+ if credentials_path is None :
439+ raise ValidationException ("Invalid credentials endpoint: None" )
440+
441+ return str (catalog_uri ).rstrip ("/" ) + "/" + str (credentials_path ).lstrip ("/" )
442+
443+
444+ def _get_or_refresh_credentials (properties : Properties , session : Session | None ) -> Properties :
445+ """Retrieve current S3 credentials from properties, refreshing them if they are close to expiration."""
446+ refresh_enabled = property_as_bool (properties , REFRESH_CREDENTIALS_ENABLED , False )
447+ if not refresh_enabled or session is None :
448+ return _extract_s3_credentials (properties )
449+
450+ # Returns empty if credentials missing or not yet expiring
451+ creds = _credential_from_properties (properties )
452+
453+ if not creds :
454+ return _extract_s3_credentials (properties )
455+
456+ from pyiceberg .catalog .rest import LoadCredentialsResponse
457+ from pyiceberg .catalog .rest .response import _handle_non_200_response
458+
459+ load_response : LoadCredentialsResponse | None = None
460+
461+ try :
462+ http_response = session .get (_credential_refresh_endpoint (properties ))
463+ http_response .raise_for_status ()
464+ load_response = LoadCredentialsResponse .model_validate_json (http_response .text )
465+ except HTTPError as exc :
466+ _handle_non_200_response (exc , {})
467+
468+ if load_response is None :
469+ raise ValidationException ("Load credential response is None" )
470+
471+ if not load_response .credentials :
472+ raise ValueError ("Invalid S3 Credentials: empty" )
473+
474+ if len (load_response .credentials ) > 1 :
475+ raise ValueError ("Invalid S3 Credentials: only one S3 credential should exist" )
476+
477+ credentials = load_response .credentials [0 ].config
478+ return _extract_s3_credentials (credentials )
0 commit comments