Skip to content

Commit a8033c1

Browse files
authored
feat: Allow servers to express supported endpoints with ConfigResponse (#2848)
Closes #2847. # Rationale for this change This PR adds support for server endpoint capabilities, aligning with the Java [implementation](https://github.com/apache/iceberg/blob/main/core/src/main/java/org/apache/iceberg/rest/RESTSessionCatalog.java). While working on REST scan planning support, we need to know whether a server supports a specific capability before making any calls, so this PR also adds capability checks to the existing PyIceberg REST catalog operations. The REST catalog now parses the `endpoints` field returned by the config call to determine server capabilities. When a server does not return the `endpoints` field, we fall back to a default set of endpoints, matching the behavior of Java's REST catalog. The view endpoints are conditionally added to that default set based on a config property. ## Are these changes tested? Added unit tests and tested with the Iceberg REST fixture. ## Are there any user-facing changes? Yes: this adds the `view-endpoints-supported` config property and aligns endpoint handling with the Java implementation. cc: @kevinjqliu @Fokko
1 parent fa03e08 commit a8033c1

2 files changed

Lines changed: 287 additions & 18 deletions

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 160 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Union,
2222
)
2323

24-
from pydantic import Field, field_validator
24+
from pydantic import ConfigDict, Field, field_validator
2525
from requests import HTTPError, Session
2626
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
2727

@@ -76,6 +76,39 @@
7676
import pyarrow as pa
7777

7878

79+
class HttpMethod(str, Enum):
    """HTTP verbs used to identify REST catalog endpoints.

    Members are also ``str`` instances, so each one compares equal to its plain string value.
    """

    GET = "GET"
    HEAD = "HEAD"
    POST = "POST"
    DELETE = "DELETE"
84+
85+
86+
class Endpoint(IcebergBaseModel):
    """A REST catalog endpoint: an HTTP method paired with a path template.

    The model is frozen (immutable), which allows instances to be used as members of the
    sets of supported endpoints kept elsewhere in this module.
    """

    model_config = ConfigDict(frozen=True)

    http_method: HttpMethod = Field()
    path: str = Field()

    @field_validator("path", mode="before")
    @classmethod
    def _validate_path(cls, raw_path: str) -> str:
        """Strip surrounding whitespace from the path and reject empty or non-string input.

        Raises:
            ValueError: If the value is not a string or is empty after stripping.
        """
        # A "before" validator receives the raw input, which is not guaranteed to be a str.
        # Raise ValueError so the failure is a validation error rather than an
        # AttributeError from calling ``.strip()`` on an arbitrary object.
        if not isinstance(raw_path, str):
            raise ValueError(f"Invalid path: expected a string, got {type(raw_path).__name__}")
        raw_path = raw_path.strip()
        if not raw_path:
            raise ValueError("Invalid path: empty")
        return raw_path

    def __str__(self) -> str:
        """Return the string representation of the Endpoint class."""
        return f"{self.http_method.value} {self.path}"

    @classmethod
    def from_string(cls, endpoint: str) -> "Endpoint":
        """Parse an endpoint written as ``"<METHOD> <path>"``.

        The method is matched case-insensitively; leading and trailing whitespace is ignored.

        Args:
            endpoint: Text such as ``"GET /v1/{prefix}/namespaces"``.

        Returns:
            The parsed Endpoint.

        Raises:
            ValueError: If the input is not a string, does not consist of exactly two
                whitespace-separated elements, or names an unknown HTTP method.
        """
        # Callers feed this from raw server responses; keep malformed items a ValueError
        # instead of failing with AttributeError on ``.strip()``.
        if not isinstance(endpoint, str):
            raise ValueError(f"Invalid endpoint (expected a string, got {type(endpoint).__name__}): {endpoint!r}")
        elements = endpoint.strip().split(None, 1)
        if len(elements) != 2:
            raise ValueError(f"Invalid endpoint (must consist of two elements separated by a single space): {endpoint}")
        return cls(http_method=HttpMethod(elements[0].upper()), path=elements[1])
110+
111+
79112
class Endpoints:
80113
get_config: str = "config"
81114
list_namespaces: str = "namespaces"
@@ -86,7 +119,7 @@ class Endpoints:
86119
namespace_exists: str = "namespaces/{namespace}"
87120
list_tables: str = "namespaces/{namespace}/tables"
88121
create_table: str = "namespaces/{namespace}/tables"
89-
register_table = "namespaces/{namespace}/register"
122+
register_table: str = "namespaces/{namespace}/register"
90123
load_table: str = "namespaces/{namespace}/tables/{table}"
91124
update_table: str = "namespaces/{namespace}/tables/{table}"
92125
drop_table: str = "namespaces/{namespace}/tables/{table}"
@@ -100,6 +133,61 @@ class Endpoints:
100133
fetch_scan_tasks: str = "namespaces/{namespace}/tables/{table}/tasks"
101134

102135

136+
API_PREFIX = "/v1/{prefix}"


def _v1_endpoint(method: HttpMethod, template: str) -> Endpoint:
    """Build the Endpoint for ``template`` rooted at ``API_PREFIX``."""
    return Endpoint(http_method=method, path=f"{API_PREFIX}/{template}")


class Capability:
    """Named Endpoint constants for the REST operations this client can check for support."""

    # Namespace operations
    V1_LIST_NAMESPACES = _v1_endpoint(HttpMethod.GET, Endpoints.list_namespaces)
    V1_LOAD_NAMESPACE = _v1_endpoint(HttpMethod.GET, Endpoints.load_namespace_metadata)
    V1_NAMESPACE_EXISTS = _v1_endpoint(HttpMethod.HEAD, Endpoints.namespace_exists)
    V1_UPDATE_NAMESPACE = _v1_endpoint(HttpMethod.POST, Endpoints.update_namespace_properties)
    V1_CREATE_NAMESPACE = _v1_endpoint(HttpMethod.POST, Endpoints.create_namespace)
    V1_DELETE_NAMESPACE = _v1_endpoint(HttpMethod.DELETE, Endpoints.drop_namespace)

    # Table operations
    V1_LIST_TABLES = _v1_endpoint(HttpMethod.GET, Endpoints.list_tables)
    V1_LOAD_TABLE = _v1_endpoint(HttpMethod.GET, Endpoints.load_table)
    V1_TABLE_EXISTS = _v1_endpoint(HttpMethod.HEAD, Endpoints.table_exists)
    V1_CREATE_TABLE = _v1_endpoint(HttpMethod.POST, Endpoints.create_table)
    V1_UPDATE_TABLE = _v1_endpoint(HttpMethod.POST, Endpoints.update_table)
    V1_DELETE_TABLE = _v1_endpoint(HttpMethod.DELETE, Endpoints.drop_table)
    V1_RENAME_TABLE = _v1_endpoint(HttpMethod.POST, Endpoints.rename_table)
    V1_REGISTER_TABLE = _v1_endpoint(HttpMethod.POST, Endpoints.register_table)

    # View operations
    V1_LIST_VIEWS = _v1_endpoint(HttpMethod.GET, Endpoints.list_views)
    V1_VIEW_EXISTS = _v1_endpoint(HttpMethod.HEAD, Endpoints.view_exists)
    V1_DELETE_VIEW = _v1_endpoint(HttpMethod.DELETE, Endpoints.drop_view)

    # Scan planning operations
    V1_SUBMIT_TABLE_SCAN_PLAN = _v1_endpoint(HttpMethod.POST, Endpoints.plan_table_scan)
    V1_TABLE_SCAN_PLAN_TASKS = _v1_endpoint(HttpMethod.POST, Endpoints.fetch_scan_tasks)
161+
162+
163+
# Fallback capability set for legacy servers whose ConfigResponse carries no endpoints.
# Limited to namespace and table operations.
DEFAULT_ENDPOINTS: frozenset[Endpoint] = frozenset(
    {
        Capability.V1_LIST_NAMESPACES,
        Capability.V1_LOAD_NAMESPACE,
        Capability.V1_CREATE_NAMESPACE,
        Capability.V1_UPDATE_NAMESPACE,
        Capability.V1_DELETE_NAMESPACE,
        Capability.V1_LIST_TABLES,
        Capability.V1_LOAD_TABLE,
        Capability.V1_CREATE_TABLE,
        Capability.V1_UPDATE_TABLE,
        Capability.V1_DELETE_TABLE,
        Capability.V1_RENAME_TABLE,
        Capability.V1_REGISTER_TABLE,
    }
)

# View operations, added to the fallback set only when the VIEW_ENDPOINTS_SUPPORTED property is enabled.
VIEW_ENDPOINTS: frozenset[Endpoint] = frozenset(
    {
        Capability.V1_LIST_VIEWS,
        Capability.V1_DELETE_VIEW,
    }
)
)
189+
190+
103191
class IdentifierKind(Enum):
104192
TABLE = "table"
105193
VIEW = "view"
@@ -134,6 +222,10 @@ class IdentifierKind(Enum):
134222
CUSTOM = "custom"
135223
REST_SCAN_PLANNING_ENABLED = "rest-scan-planning-enabled"
136224
REST_SCAN_PLANNING_ENABLED_DEFAULT = False
225+
# for backwards compatibility with older REST servers where it can be assumed that a particular
226+
# server supports view endpoints but doesn't send the "endpoints" field in the ConfigResponse
227+
VIEW_ENDPOINTS_SUPPORTED = "view-endpoints-supported"
228+
VIEW_ENDPOINTS_SUPPORTED_DEFAULT = False
137229

138230
NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8)
139231

@@ -180,6 +272,14 @@ class RegisterTableRequest(IcebergBaseModel):
180272
class ConfigResponse(IcebergBaseModel):
    """Response model for the catalog config call."""

    defaults: Properties | None = Field(default_factory=dict)
    overrides: Properties | None = Field(default_factory=dict)
    # None when the response carries no endpoints field.
    endpoints: set[Endpoint] | None = Field(default=None)

    @field_validator("endpoints", mode="before")
    @classmethod
    def _parse_endpoints(cls, v: list[str | Endpoint] | None) -> set[Endpoint] | None:
        """Convert the raw endpoints value into a set of Endpoint objects.

        Args:
            v: ``None``, or an iterable whose items are ``"<METHOD> <path>"`` strings or
                already-built Endpoint instances.

        Returns:
            ``None`` if ``v`` is ``None``; otherwise the set of parsed endpoints.

        Raises:
            ValueError: Propagated from ``Endpoint.from_string`` for malformed strings.
        """
        if v is None:
            return None
        # This validator sees the raw input, so it also runs when the model is built in code
        # from Endpoint objects; pass those through instead of treating them as strings.
        return {item if isinstance(item, Endpoint) else Endpoint.from_string(item) for item in v}
183283

184284

185285
class ListNamespaceResponse(IcebergBaseModel):
@@ -218,6 +318,7 @@ class ListViewsResponse(IcebergBaseModel):
218318
class RestCatalog(Catalog):
219319
uri: str
220320
_session: Session
321+
_supported_endpoints: set[Endpoint]
221322

222323
def __init__(self, name: str, **properties: str):
223324
"""Rest Catalog.
@@ -279,7 +380,9 @@ def is_rest_scan_planning_enabled(self) -> bool:
279380
Returns:
280381
True if enabled, False otherwise.
281382
"""
282-
return property_as_bool(self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT)
383+
return Capability.V1_SUBMIT_TABLE_SCAN_PLAN in self._supported_endpoints and property_as_bool(
384+
self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT
385+
)
283386

284387
def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager:
285388
"""Create the LegacyOAuth2AuthManager by fetching required properties.
@@ -327,6 +430,18 @@ def url(self, endpoint: str, prefixed: bool = True, **kwargs: Any) -> str:
327430

328431
return url + endpoint.format(**kwargs)
329432

433+
def _check_endpoint(self, endpoint: Endpoint) -> None:
434+
"""Check if an endpoint is supported by the server.
435+
436+
Args:
437+
endpoint: The endpoint to check against the set of supported endpoints
438+
439+
Raises:
440+
NotImplementedError: If the endpoint is not supported.
441+
"""
442+
if endpoint not in self._supported_endpoints:
443+
raise NotImplementedError(f"Server does not support endpoint: {endpoint}")
444+
330445
@property
331446
def auth_url(self) -> str:
332447
self._warn_oauth_tokens_deprecation()
@@ -384,6 +499,17 @@ def _fetch_config(self) -> None:
384499
# Update URI based on overrides
385500
self.uri = config[URI]
386501

502+
# Determine supported endpoints
503+
endpoints = config_response.endpoints
504+
if endpoints:
505+
self._supported_endpoints = set(endpoints)
506+
else:
507+
# Use default endpoints for legacy servers that don't return endpoints
508+
self._supported_endpoints = set(DEFAULT_ENDPOINTS)
509+
# Conditionally add view endpoints based on config
510+
if property_as_bool(self.properties, VIEW_ENDPOINTS_SUPPORTED, VIEW_ENDPOINTS_SUPPORTED_DEFAULT):
511+
self._supported_endpoints.update(VIEW_ENDPOINTS)
512+
387513
def _identifier_to_validated_tuple(self, identifier: str | Identifier) -> Identifier:
388514
identifier_tuple = self.identifier_to_tuple(identifier)
389515
if len(identifier_tuple) <= 1:
@@ -503,6 +629,7 @@ def _create_table(
503629
properties: Properties = EMPTY_DICT,
504630
stage_create: bool = False,
505631
) -> TableResponse:
632+
self._check_endpoint(Capability.V1_CREATE_TABLE)
506633
iceberg_schema = self._convert_schema_if_needed(
507634
schema,
508635
int(properties.get(TableProperties.FORMAT_VERSION, TableProperties.DEFAULT_FORMAT_VERSION)), # type: ignore
@@ -591,6 +718,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) -
591718
Raises:
592719
TableAlreadyExistsError: If the table already exists
593720
"""
721+
self._check_endpoint(Capability.V1_REGISTER_TABLE)
594722
namespace_and_table = self._split_identifier_for_path(identifier)
595723
request = RegisterTableRequest(
596724
name=namespace_and_table["table"],
@@ -611,6 +739,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) -
611739

612740
@retry(**_RETRY_ARGS)
613741
def list_tables(self, namespace: str | Identifier) -> list[Identifier]:
742+
self._check_endpoint(Capability.V1_LIST_TABLES)
614743
namespace_tuple = self._check_valid_namespace_identifier(namespace)
615744
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
616745
response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
@@ -622,6 +751,7 @@ def list_tables(self, namespace: str | Identifier) -> list[Identifier]:
622751

623752
@retry(**_RETRY_ARGS)
624753
def load_table(self, identifier: str | Identifier) -> Table:
754+
self._check_endpoint(Capability.V1_LOAD_TABLE)
625755
params = {}
626756
if mode := self.properties.get(SNAPSHOT_LOADING_MODE):
627757
if mode in {"all", "refs"}:
@@ -642,6 +772,7 @@ def load_table(self, identifier: str | Identifier) -> Table:
642772

643773
@retry(**_RETRY_ARGS)
644774
def drop_table(self, identifier: str | Identifier, purge_requested: bool = False) -> None:
775+
self._check_endpoint(Capability.V1_DELETE_TABLE)
645776
response = self._session.delete(
646777
self.url(Endpoints.drop_table, prefixed=True, **self._split_identifier_for_path(identifier)),
647778
params={"purgeRequested": purge_requested},
@@ -657,6 +788,7 @@ def purge_table(self, identifier: str | Identifier) -> None:
657788

658789
@retry(**_RETRY_ARGS)
659790
def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table:
791+
self._check_endpoint(Capability.V1_RENAME_TABLE)
660792
payload = {
661793
"source": self._split_identifier_for_json(from_identifier),
662794
"destination": self._split_identifier_for_json(to_identifier),
@@ -692,6 +824,8 @@ def _remove_catalog_name_from_table_request_identifier(self, table_request: Comm
692824

693825
@retry(**_RETRY_ARGS)
694826
def list_views(self, namespace: str | Identifier) -> list[Identifier]:
827+
if Capability.V1_LIST_VIEWS not in self._supported_endpoints:
828+
return []
695829
namespace_tuple = self._check_valid_namespace_identifier(namespace)
696830
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
697831
response = self._session.get(self.url(Endpoints.list_views, namespace=namespace_concat))
@@ -720,6 +854,7 @@ def commit_table(
720854
CommitFailedException: Requirement not met, or a conflict with a concurrent commit.
721855
CommitStateUnknownException: Failed due to an internal exception on the side of the catalog.
722856
"""
857+
self._check_endpoint(Capability.V1_UPDATE_TABLE)
723858
identifier = table.name()
724859
table_identifier = TableIdentifier(namespace=identifier[:-1], name=identifier[-1])
725860
table_request = CommitTableRequest(identifier=table_identifier, requirements=requirements, updates=updates)
@@ -749,6 +884,7 @@ def commit_table(
749884

750885
@retry(**_RETRY_ARGS)
751886
def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None:
887+
self._check_endpoint(Capability.V1_CREATE_NAMESPACE)
752888
namespace_tuple = self._check_valid_namespace_identifier(namespace)
753889
payload = {"namespace": namespace_tuple, "properties": properties}
754890
response = self._session.post(self.url(Endpoints.create_namespace), json=payload)
@@ -759,6 +895,7 @@ def create_namespace(self, namespace: str | Identifier, properties: Properties =
759895

760896
@retry(**_RETRY_ARGS)
761897
def drop_namespace(self, namespace: str | Identifier) -> None:
898+
self._check_endpoint(Capability.V1_DELETE_NAMESPACE)
762899
namespace_tuple = self._check_valid_namespace_identifier(namespace)
763900
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
764901
response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
@@ -769,6 +906,7 @@ def drop_namespace(self, namespace: str | Identifier) -> None:
769906

770907
@retry(**_RETRY_ARGS)
771908
def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]:
909+
self._check_endpoint(Capability.V1_LIST_NAMESPACES)
772910
namespace_tuple = self.identifier_to_tuple(namespace)
773911
response = self._session.get(
774912
self.url(
@@ -786,6 +924,7 @@ def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]:
786924

787925
@retry(**_RETRY_ARGS)
788926
def load_namespace_properties(self, namespace: str | Identifier) -> Properties:
927+
self._check_endpoint(Capability.V1_LOAD_NAMESPACE)
789928
namespace_tuple = self._check_valid_namespace_identifier(namespace)
790929
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
791930
response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
@@ -800,6 +939,7 @@ def load_namespace_properties(self, namespace: str | Identifier) -> Properties:
800939
def update_namespace_properties(
801940
self, namespace: str | Identifier, removals: set[str] | None = None, updates: Properties = EMPTY_DICT
802941
) -> PropertiesUpdateSummary:
942+
self._check_endpoint(Capability.V1_UPDATE_NAMESPACE)
803943
namespace_tuple = self._check_valid_namespace_identifier(namespace)
804944
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
805945
payload = {"removals": list(removals or []), "updates": updates}
@@ -819,6 +959,14 @@ def update_namespace_properties(
819959
def namespace_exists(self, namespace: str | Identifier) -> bool:
820960
namespace_tuple = self._check_valid_namespace_identifier(namespace)
821961
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
962+
# fallback in order to work with older rest catalog implementations
963+
if Capability.V1_NAMESPACE_EXISTS not in self._supported_endpoints:
964+
try:
965+
self.load_namespace_properties(namespace_tuple)
966+
return True
967+
except NoSuchNamespaceError:
968+
return False
969+
822970
response = self._session.head(self.url(Endpoints.namespace_exists, namespace=namespace))
823971

824972
if response.status_code == 404:
@@ -843,6 +991,14 @@ def table_exists(self, identifier: str | Identifier) -> bool:
843991
Returns:
844992
bool: True if the table exists, False otherwise.
845993
"""
994+
# fallback in order to work with older rest catalog implementations
995+
if Capability.V1_TABLE_EXISTS not in self._supported_endpoints:
996+
try:
997+
self.load_table(identifier)
998+
return True
999+
except NoSuchTableError:
1000+
return False
1001+
8461002
response = self._session.head(
8471003
self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier))
8481004
)
@@ -886,6 +1042,7 @@ def view_exists(self, identifier: str | Identifier) -> bool:
8861042

8871043
@retry(**_RETRY_ARGS)
8881044
def drop_view(self, identifier: str) -> None:
1045+
self._check_endpoint(Capability.V1_DELETE_VIEW)
8891046
response = self._session.delete(
8901047
self.url(Endpoints.drop_view, prefixed=True, **self._split_identifier_for_path(identifier, IdentifierKind.VIEW)),
8911048
)

0 commit comments

Comments
 (0)