Skip to content

Commit c68b9b1

Browse files
smaheshwar-pltrSreesh Maheshwarkevinjqliu
authored
Support Location Providers (#1452)
* Skeletal implementation * First attempt at hashing locations * Relocate to table submodule; code and comment improvements * Add unit tests * Remove entropy check * Nit: Prefer `self.table_properties` * Remove special character testing * Add integration tests for writes * Move all `LocationProviders`-related code into locations.py * Nit: tiny for loop refactor * Fix typo * Object storage as default location provider * Update tests/integration/test_writes/test_partitioned_writes.py Co-authored-by: Kevin Liu <kevinjqliu@users.noreply.github.com> * Test entropy in test_object_storage_injects_entropy * Refactor integration tests to use properties and omit when default once * Use a different table property for custom location provision * write.location-provider.py-impl -> write.py-location-provider.impl * Make lint * Move location provider loading into `write_file` for back-compat * Make object storage no longer the default * Add test case for partitioned paths disabled but with no partition special case * Moved constants within ObjectStoreLocationProvider --------- Co-authored-by: Sreesh Maheshwar <smaheshwar@palantir.com> Co-authored-by: Kevin Liu <kevinjqliu@users.noreply.github.com>
1 parent 691740d commit c68b9b1

6 files changed

Lines changed: 355 additions & 8 deletions

File tree

pyiceberg/io/pyarrow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
visit,
137137
visit_with_partner,
138138
)
139+
from pyiceberg.table.locations import load_location_provider
139140
from pyiceberg.table.metadata import TableMetadata
140141
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
141142
from pyiceberg.transforms import TruncateTransform
@@ -2305,6 +2306,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
23052306
property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
23062307
default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT,
23072308
)
2309+
location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties)
23082310

23092311
def write_parquet(task: WriteTask) -> DataFile:
23102312
table_schema = table_metadata.schema()
@@ -2327,7 +2329,10 @@ def write_parquet(task: WriteTask) -> DataFile:
23272329
for batch in task.record_batches
23282330
]
23292331
arrow_table = pa.Table.from_batches(batches)
2330-
file_path = f"{table_metadata.location}/data/{task.generate_data_file_path('parquet')}"
2332+
file_path = location_provider.new_data_location(
2333+
data_file_name=task.generate_data_file_filename("parquet"),
2334+
partition_key=task.partition_key,
2335+
)
23312336
fo = io.new_output(file_path)
23322337
with fo.create(overwrite=True) as fos:
23332338
with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:

pyiceberg/table/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ class TableProperties:
187187
WRITE_PARTITION_SUMMARY_LIMIT = "write.summary.partition-limit"
188188
WRITE_PARTITION_SUMMARY_LIMIT_DEFAULT = 0
189189

190+
WRITE_PY_LOCATION_PROVIDER_IMPL = "write.py-location-provider.impl"
191+
192+
OBJECT_STORE_ENABLED = "write.object-storage.enabled"
193+
OBJECT_STORE_ENABLED_DEFAULT = False
194+
195+
WRITE_OBJECT_STORE_PARTITIONED_PATHS = "write.object-storage.partitioned-paths"
196+
WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT = True
197+
190198
DELETE_MODE = "write.delete.mode"
191199
DELETE_MODE_COPY_ON_WRITE = "copy-on-write"
192200
DELETE_MODE_MERGE_ON_READ = "merge-on-read"
@@ -1613,13 +1621,6 @@ def generate_data_file_filename(self, extension: str) -> str:
16131621
# https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
16141622
return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
16151623

1616-
def generate_data_file_path(self, extension: str) -> str:
1617-
if self.partition_key:
1618-
file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
1619-
return file_path
1620-
else:
1621-
return self.generate_data_file_filename(extension)
1622-
16231624

16241625
@dataclass(frozen=True)
16251626
class AddFileTask:

pyiceberg/table/locations.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import importlib
18+
import logging
19+
from abc import ABC, abstractmethod
20+
from typing import Optional
21+
22+
import mmh3
23+
24+
from pyiceberg.partitioning import PartitionKey
25+
from pyiceberg.table import TableProperties
26+
from pyiceberg.typedef import Properties
27+
from pyiceberg.utils.properties import property_as_bool
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class LocationProvider(ABC):
33+
"""A base class for location providers, that provide data file locations for write tasks."""
34+
35+
table_location: str
36+
table_properties: Properties
37+
38+
def __init__(self, table_location: str, table_properties: Properties):
39+
self.table_location = table_location
40+
self.table_properties = table_properties
41+
42+
@abstractmethod
43+
def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
44+
"""Return a fully-qualified data file location for the given filename.
45+
46+
Args:
47+
data_file_name (str): The name of the data file.
48+
partition_key (Optional[PartitionKey]): The data file's partition key. If None, the data is not partitioned.
49+
50+
Returns:
51+
str: A fully-qualified location URI for the data file.
52+
"""
53+
54+
55+
class SimpleLocationProvider(LocationProvider):
56+
def __init__(self, table_location: str, table_properties: Properties):
57+
super().__init__(table_location, table_properties)
58+
59+
def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
60+
prefix = f"{self.table_location}/data"
61+
return f"{prefix}/{partition_key.to_path()}/{data_file_name}" if partition_key else f"{prefix}/{data_file_name}"
62+
63+
64+
class ObjectStoreLocationProvider(LocationProvider):
65+
HASH_BINARY_STRING_BITS = 20
66+
ENTROPY_DIR_LENGTH = 4
67+
ENTROPY_DIR_DEPTH = 3
68+
69+
_include_partition_paths: bool
70+
71+
def __init__(self, table_location: str, table_properties: Properties):
72+
super().__init__(table_location, table_properties)
73+
self._include_partition_paths = property_as_bool(
74+
self.table_properties,
75+
TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS,
76+
TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT,
77+
)
78+
79+
def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
80+
if self._include_partition_paths and partition_key:
81+
return self.new_data_location(f"{partition_key.to_path()}/{data_file_name}")
82+
83+
prefix = f"{self.table_location}/data"
84+
hashed_path = self._compute_hash(data_file_name)
85+
86+
return (
87+
f"{prefix}/{hashed_path}/{data_file_name}"
88+
if self._include_partition_paths
89+
else f"{prefix}/{hashed_path}-{data_file_name}"
90+
)
91+
92+
@staticmethod
93+
def _compute_hash(data_file_name: str) -> str:
94+
# Bitwise AND to combat sign-extension; bitwise OR to preserve leading zeroes that `bin` would otherwise strip.
95+
top_mask = 1 << ObjectStoreLocationProvider.HASH_BINARY_STRING_BITS
96+
hash_code = mmh3.hash(data_file_name) & (top_mask - 1) | top_mask
97+
return ObjectStoreLocationProvider._dirs_from_hash(bin(hash_code)[-ObjectStoreLocationProvider.HASH_BINARY_STRING_BITS :])
98+
99+
@staticmethod
100+
def _dirs_from_hash(file_hash: str) -> str:
101+
"""Divides hash into directories for optimized orphan removal operation using ENTROPY_DIR_DEPTH and ENTROPY_DIR_LENGTH."""
102+
total_entropy_length = ObjectStoreLocationProvider.ENTROPY_DIR_DEPTH * ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH
103+
104+
hash_with_dirs = []
105+
for i in range(0, total_entropy_length, ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH):
106+
hash_with_dirs.append(file_hash[i : i + ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH])
107+
108+
if len(file_hash) > total_entropy_length:
109+
hash_with_dirs.append(file_hash[total_entropy_length:])
110+
111+
return "/".join(hash_with_dirs)
112+
113+
114+
def _import_location_provider(
115+
location_provider_impl: str, table_location: str, table_properties: Properties
116+
) -> Optional[LocationProvider]:
117+
try:
118+
path_parts = location_provider_impl.split(".")
119+
if len(path_parts) < 2:
120+
raise ValueError(
121+
f"{TableProperties.WRITE_PY_LOCATION_PROVIDER_IMPL} should be full path (module.CustomLocationProvider), got: {location_provider_impl}"
122+
)
123+
module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1]
124+
module = importlib.import_module(module_name)
125+
class_ = getattr(module, class_name)
126+
return class_(table_location, table_properties)
127+
except ModuleNotFoundError:
128+
logger.warning("Could not initialize LocationProvider: %s", location_provider_impl)
129+
return None
130+
131+
132+
def load_location_provider(table_location: str, table_properties: Properties) -> LocationProvider:
133+
table_location = table_location.rstrip("/")
134+
135+
if location_provider_impl := table_properties.get(TableProperties.WRITE_PY_LOCATION_PROVIDER_IMPL):
136+
if location_provider := _import_location_provider(location_provider_impl, table_location, table_properties):
137+
logger.info("Loaded LocationProvider: %s", location_provider_impl)
138+
return location_provider
139+
else:
140+
raise ValueError(f"Could not initialize LocationProvider: {location_provider_impl}")
141+
142+
if property_as_bool(table_properties, TableProperties.OBJECT_STORE_ENABLED, TableProperties.OBJECT_STORE_ENABLED_DEFAULT):
143+
return ObjectStoreLocationProvider(table_location, table_properties)
144+
else:
145+
return SimpleLocationProvider(table_location, table_properties)

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pyiceberg.exceptions import NoSuchTableError
2929
from pyiceberg.partitioning import PartitionField, PartitionSpec
3030
from pyiceberg.schema import Schema
31+
from pyiceberg.table import TableProperties
3132
from pyiceberg.transforms import (
3233
BucketTransform,
3334
DayTransform,
@@ -280,6 +281,44 @@ def test_query_filter_v1_v2_append_null(
280281
assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}"
281282

282283

284+
@pytest.mark.integration
285+
@pytest.mark.parametrize(
286+
"part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"]
287+
)
288+
@pytest.mark.parametrize("format_version", [1, 2])
289+
def test_object_storage_location_provider_excludes_partition_path(
290+
session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int
291+
) -> None:
292+
nested_field = TABLE_SCHEMA.find_field(part_col)
293+
partition_spec = PartitionSpec(
294+
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col)
295+
)
296+
297+
tbl = _create_table(
298+
session_catalog=session_catalog,
299+
identifier=f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}",
300+
# write.object-storage.partitioned-paths defaults to True
301+
properties={"format-version": str(format_version), TableProperties.OBJECT_STORE_ENABLED: True},
302+
data=[arrow_table_with_null],
303+
partition_spec=partition_spec,
304+
)
305+
306+
original_paths = tbl.inspect.data_files().to_pydict()["file_path"]
307+
assert len(original_paths) == 3
308+
309+
# Update props to exclude partitioned paths and append data
310+
with tbl.transaction() as tx:
311+
tx.set_properties({TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS: False})
312+
tbl.append(arrow_table_with_null)
313+
314+
added_paths = set(tbl.inspect.data_files().to_pydict()["file_path"]) - set(original_paths)
315+
assert len(added_paths) == 3
316+
317+
# All paths before the props update should contain the partition, while all paths after should not
318+
assert all(f"{part_col}=" in path for path in original_paths)
319+
assert all(f"{part_col}=" not in path for path in added_paths)
320+
321+
283322
@pytest.mark.integration
284323
@pytest.mark.parametrize(
285324
"spec",

tests/integration/test_writes/test_writes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,33 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
285285
assert [row.deleted_data_files_count for row in rows] == [0, 1, 0, 0, 0]
286286

287287

288+
@pytest.mark.integration
289+
@pytest.mark.parametrize("format_version", [1, 2])
290+
def test_object_storage_data_files(
291+
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
292+
) -> None:
293+
tbl = _create_table(
294+
session_catalog=session_catalog,
295+
identifier="default.object_stored",
296+
properties={"format-version": format_version, TableProperties.OBJECT_STORE_ENABLED: True},
297+
data=[arrow_table_with_null],
298+
)
299+
tbl.append(arrow_table_with_null)
300+
301+
paths = tbl.inspect.data_files().to_pydict()["file_path"]
302+
assert len(paths) == 2
303+
304+
for location in paths:
305+
assert location.startswith("s3://warehouse/default/object_stored/data/")
306+
parts = location.split("/")
307+
assert len(parts) == 11
308+
309+
# Entropy binary directories should have been injected
310+
for dir_name in parts[6:10]:
311+
assert dir_name
312+
assert all(c in "01" for c in dir_name)
313+
314+
288315
@pytest.mark.integration
289316
def test_python_writes_with_spark_snapshot_reads(
290317
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table

0 commit comments

Comments
 (0)