Skip to content

Commit b2db501

Browse files
drop REST mentions and trim verbose comments/docstrings
1 parent 53c9bf9 commit b2db501

14 files changed

Lines changed: 80 additions & 310 deletions

File tree

dev/provision.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,7 @@
396396
spark.sql(f"ALTER TABLE {catalog_name}.default.test_empty_scan_ordered_str WRITE ORDERED BY id")
397397
spark.sql(f"INSERT INTO {catalog_name}.default.test_empty_scan_ordered_str VALUES 'a', 'c'")
398398

399-
# Encrypted Iceberg table written via Spark, read back via PyIceberg in tests/integration/test_encryption.py.
400-
# Only the Hive catalog is configured with a Java-side KMS (encryption.kms-impl=UnitestKMS); the REST catalog
401-
# image does not ship UnitestKMS so we limit this fixture to Hive.
399+
# Encrypted table in the Hive catalog, written via Spark; read back via PyIceberg in tests/integration/test_encryption.py.
402400
spark.sql("""
403401
CREATE OR REPLACE TABLE hive.default.test_encrypted (id bigint, data string, value float)
404402
USING iceberg

pyiceberg/encryption/ciphers.py

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""AES-GCM encryption/decryption primitives and AGS1 stream decryption."""
17+
"""AES-GCM primitives and Iceberg AGS1 stream decryption."""
1818

1919
from __future__ import annotations
2020

@@ -28,58 +28,37 @@
2828

2929

3030
def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes:
31-
"""Encrypt using AES-GCM. Returns nonce || ciphertext || tag."""
3231
nonce = os.urandom(NONCE_LENGTH)
33-
aesgcm = AESGCM(key)
34-
ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, aad)
35-
return nonce + ciphertext_with_tag
32+
return nonce + AESGCM(key).encrypt(nonce, plaintext, aad)
3633

3734

3835
def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes:
39-
"""Decrypt AES-GCM data in format: nonce || ciphertext || tag."""
4036
if len(ciphertext) < NONCE_LENGTH + GCM_TAG_LENGTH:
4137
raise ValueError(f"Ciphertext too short: {len(ciphertext)} bytes")
42-
nonce = ciphertext[:NONCE_LENGTH]
43-
encrypted_data = ciphertext[NONCE_LENGTH:]
44-
aesgcm = AESGCM(key)
45-
return aesgcm.decrypt(nonce, encrypted_data, aad)
38+
return AESGCM(key).decrypt(ciphertext[:NONCE_LENGTH], ciphertext[NONCE_LENGTH:], aad)
4639

4740

48-
# AGS1 stream constants
4941
GCM_STREAM_MAGIC = b"AGS1"
50-
GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 block size
42+
GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 little-endian block size
5143

5244

5345
def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes:
54-
"""Construct per-block AAD for AGS1 stream encryption.
55-
56-
Format: aad_prefix || block_index (4 bytes, little-endian).
57-
"""
58-
index_bytes = struct.pack("<I", block_index)
59-
if not aad_prefix:
60-
return index_bytes
61-
return aad_prefix + index_bytes
46+
return aad_prefix + struct.pack("<I", block_index)
6247

6348

6449
def decrypt_ags1_stream(key: bytes, encrypted_data: bytes, aad_prefix: bytes) -> bytes:
65-
"""Decrypt an entire AGS1 stream and return the plaintext.
66-
67-
AGS1 format:
68-
- Header: "AGS1" (4 bytes) + plain_block_size (4 bytes LE)
69-
- Blocks: each block is nonce(12) + ciphertext(up to 1MB) + tag(16)
70-
- Each block's AAD = aad_prefix + block_index (4 bytes LE)
50+
"""Decrypt an Iceberg AGS1 stream.
7151
52+
Layout: "AGS1" (4) | plain_block_size LE (4) | zero or more {nonce(12) | cipher | tag(16)} blocks.
53+
Each block's AAD is `aad_prefix || block_index_le32`.
7254
"""
7355
if len(encrypted_data) < GCM_STREAM_HEADER_LENGTH:
7456
raise ValueError(f"AGS1 stream too short: {len(encrypted_data)} bytes")
75-
76-
magic = encrypted_data[:4]
77-
if magic != GCM_STREAM_MAGIC:
78-
raise ValueError(f"Invalid AGS1 magic: {magic!r}, expected {GCM_STREAM_MAGIC!r}")
57+
if encrypted_data[:4] != GCM_STREAM_MAGIC:
58+
raise ValueError(f"Invalid AGS1 magic: {encrypted_data[:4]!r}")
7959

8060
plain_block_size = struct.unpack_from("<I", encrypted_data, 4)[0]
8161
cipher_block_size = plain_block_size + NONCE_LENGTH + GCM_TAG_LENGTH
82-
8362
stream_data = encrypted_data[GCM_STREAM_HEADER_LENGTH:]
8463
if not stream_data:
8564
return b""
@@ -88,28 +67,15 @@ def decrypt_ags1_stream(key: bytes, encrypted_data: bytes, aad_prefix: bytes) ->
8867
result = bytearray()
8968
offset = 0
9069
block_index = 0
91-
9270
while offset < len(stream_data):
93-
# Determine this block's cipher size
94-
remaining = len(stream_data) - offset
95-
if remaining >= cipher_block_size:
96-
block_cipher_size = cipher_block_size
97-
else:
98-
block_cipher_size = remaining
99-
71+
block_cipher_size = min(cipher_block_size, len(stream_data) - offset)
10072
if block_cipher_size < NONCE_LENGTH + GCM_TAG_LENGTH:
101-
raise ValueError(
102-
f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes (minimum {NONCE_LENGTH + GCM_TAG_LENGTH})"
103-
)
104-
105-
block_data = stream_data[offset : offset + block_cipher_size]
106-
nonce = block_data[:NONCE_LENGTH]
107-
ciphertext_with_tag = block_data[NONCE_LENGTH:]
108-
109-
aad = stream_block_aad(aad_prefix, block_index)
110-
plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, aad)
111-
result.extend(plaintext)
73+
raise ValueError(f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes")
11274

75+
block = stream_data[offset : offset + block_cipher_size]
76+
result.extend(
77+
aesgcm.decrypt(block[:NONCE_LENGTH], block[NONCE_LENGTH:], stream_block_aad(aad_prefix, block_index))
78+
)
11379
offset += block_cipher_size
11480
block_index += 1
11581

pyiceberg/encryption/io.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""InputFile implementation backed by in-memory bytes."""
17+
"""In-memory InputFile/InputStream used to wrap decrypted Avro buffers for AvroFile."""
1818

1919
from __future__ import annotations
2020

@@ -25,8 +25,6 @@
2525

2626

2727
class BytesInputStream(InputStream):
28-
"""InputStream implementation backed by a bytes buffer."""
29-
3028
def __init__(self, data: bytes) -> None:
3129
self._buffer = io.BytesIO(data)
3230

@@ -45,7 +43,6 @@ def close(self) -> None:
4543
self._buffer.close()
4644

4745
def __enter__(self) -> BytesInputStream:
48-
"""Enter the context manager."""
4946
return self
5047

5148
def __exit__(
@@ -54,23 +51,15 @@ def __exit__(
5451
excinst: BaseException | None,
5552
exctb: TracebackType | None,
5653
) -> None:
57-
"""Exit the context manager and close the stream."""
5854
self.close()
5955

6056

6157
class BytesInputFile(InputFile):
62-
"""InputFile implementation backed by in-memory bytes.
63-
64-
Used to wrap decrypted data so that it can be read by
65-
AvroFile and other readers that expect an InputFile.
66-
"""
67-
6858
def __init__(self, location: str, data: bytes) -> None:
6959
super().__init__(location)
7060
self._data = data
7161

7262
def __len__(self) -> int:
73-
"""Return the length of the underlying data."""
7463
return len(self._data)
7564

7665
def exists(self) -> bool:

pyiceberg/encryption/key_metadata.py

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,10 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""StandardKeyMetadata Avro serialization.
17+
"""StandardKeyMetadata Avro codec.
1818
19-
Wire format: ``0x01 version byte || Avro-encoded fields``
20-
21-
Avro schema:
22-
- encryption_key: bytes (required)
23-
- aad_prefix: union[null, bytes] (optional)
24-
- file_length: union[null, long] (optional)
19+
Wire: ``0x01 version`` || encryption_key (bytes) || aad_prefix (union[null,bytes])
20+
|| file_length (union[null,long]).
2521
"""
2622

2723
from __future__ import annotations
@@ -32,7 +28,6 @@
3228

3329

3430
def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]:
35-
"""Read a zigzag-encoded Avro long from data at offset. Returns (value, new_offset)."""
3631
result = 0
3732
shift = 0
3833
while True:
@@ -44,12 +39,10 @@ def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]:
4439
if (b & 0x80) == 0:
4540
break
4641
shift += 7
47-
# Zigzag decode
4842
return (result >> 1) ^ -(result & 1), offset
4943

5044

5145
def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]:
52-
"""Read Avro bytes (length-prefixed). Returns (bytes_value, new_offset)."""
5346
length, offset = _read_avro_long(data, offset)
5447
if length < 0:
5548
raise ValueError(f"Negative Avro bytes length: {length}")
@@ -61,31 +54,20 @@ def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]:
6154

6255
@dataclass(frozen=True)
6356
class StandardKeyMetadata:
64-
"""Standard key metadata for Iceberg table encryption.
65-
66-
Contains the plaintext encryption key (DEK), AAD prefix, and optional file length.
67-
"""
68-
6957
encryption_key: bytes
7058
aad_prefix: bytes = b""
7159
file_length: int | None = None
7260

7361
@staticmethod
7462
def deserialize(data: bytes) -> StandardKeyMetadata:
75-
"""Deserialize from wire format: ``0x01 version || Avro-encoded fields``."""
7663
if not data:
7764
raise ValueError("Empty key metadata buffer")
78-
79-
version = data[0]
80-
if version != V1:
81-
raise ValueError(f"Unsupported key metadata version: {version}")
82-
65+
if data[0] != V1:
66+
raise ValueError(f"Unsupported key metadata version: {data[0]}")
8367
offset = 1
8468

85-
# Read encryption_key (required bytes)
8669
encryption_key, offset = _read_avro_bytes(data, offset)
8770

88-
# Read aad_prefix (optional: union[null, bytes])
8971
union_index, offset = _read_avro_long(data, offset)
9072
if union_index == 0:
9173
aad_prefix = b""
@@ -94,50 +76,30 @@ def deserialize(data: bytes) -> StandardKeyMetadata:
9476
else:
9577
raise ValueError(f"Invalid union index for aad_prefix: {union_index}")
9678

97-
# Read file_length (optional: union[null, long])
98-
file_length = None
79+
file_length: int | None = None
9980
if offset < len(data):
10081
union_index, offset = _read_avro_long(data, offset)
101-
if union_index == 0:
102-
file_length = None
103-
elif union_index == 1:
82+
if union_index == 1:
10483
file_length, offset = _read_avro_long(data, offset)
105-
else:
84+
elif union_index != 0:
10685
raise ValueError(f"Invalid union index for file_length: {union_index}")
10786

108-
return StandardKeyMetadata(
109-
encryption_key=encryption_key,
110-
aad_prefix=aad_prefix,
111-
file_length=file_length,
112-
)
87+
return StandardKeyMetadata(encryption_key=encryption_key, aad_prefix=aad_prefix, file_length=file_length)
11388

11489
def serialize(self) -> bytes:
115-
"""Serialize to wire format: ``0x01 version || Avro-encoded fields``."""
116-
parts = [bytes([V1])]
117-
118-
# encryption_key (required bytes)
119-
parts.append(_encode_avro_bytes(self.encryption_key))
120-
121-
# aad_prefix (union[null, bytes])
90+
parts = [bytes([V1]), _encode_avro_bytes(self.encryption_key)]
12291
if self.aad_prefix:
123-
parts.append(_encode_avro_long(1)) # union index 1 = bytes
124-
parts.append(_encode_avro_bytes(self.aad_prefix))
92+
parts += [_encode_avro_long(1), _encode_avro_bytes(self.aad_prefix)]
12593
else:
126-
parts.append(_encode_avro_long(0)) # union index 0 = null
127-
128-
# file_length (union[null, long])
94+
parts.append(_encode_avro_long(0))
12995
if self.file_length is not None:
130-
parts.append(_encode_avro_long(1)) # union index 1 = long
131-
parts.append(_encode_avro_long(self.file_length))
96+
parts += [_encode_avro_long(1), _encode_avro_long(self.file_length)]
13297
else:
133-
parts.append(_encode_avro_long(0)) # union index 0 = null
134-
98+
parts.append(_encode_avro_long(0))
13599
return b"".join(parts)
136100

137101

138102
def _encode_avro_long(value: int) -> bytes:
139-
"""Encode a long as zigzag-encoded Avro varint."""
140-
# Zigzag encode
141103
n = (value << 1) ^ (value >> 63)
142104
result = bytearray()
143105
while n & ~0x7F:
@@ -148,5 +110,4 @@ def _encode_avro_long(value: int) -> bytes:
148110

149111

150112
def _encode_avro_bytes(data: bytes) -> bytes:
151-
"""Encode bytes with Avro length prefix."""
152113
return _encode_avro_long(len(data)) + data

pyiceberg/encryption/kms.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""Key Management Service interfaces and implementations."""
18-
1917
from __future__ import annotations
2018

2119
import importlib
@@ -34,15 +32,11 @@
3432

3533

3634
class KeyManagementClient(ABC):
37-
"""Abstract base class for key management operations."""
38-
3935
@abstractmethod
40-
def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes:
41-
"""Wrap (encrypt) a key using the master key identified by wrapping_key_id."""
36+
def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: ...
4237

4338
@abstractmethod
44-
def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes:
45-
"""Unwrap (decrypt) a wrapped key using the master key identified by wrapping_key_id."""
39+
def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: ...
4640

4741
def initialize(self, properties: dict[str, str]) -> None: # noqa: B027
4842
"""Initialize the KMS client from catalog/table properties."""
@@ -55,39 +49,26 @@ def __init__(self, master_keys: dict[str, bytes] | None = None) -> None:
5549
self._master_keys: dict[str, bytes] = dict(master_keys) if master_keys else {}
5650

5751
def initialize(self, properties: dict[str, str]) -> None:
52+
prefix = "encryption.kms.key."
5853
for key, value in properties.items():
59-
if key.startswith("encryption.kms.key."):
60-
key_id = key[len("encryption.kms.key.") :]
61-
self._master_keys[key_id] = bytes.fromhex(value)
54+
if key.startswith(prefix):
55+
self._master_keys[key[len(prefix) :]] = bytes.fromhex(value)
6256

6357
def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes:
64-
master_key = self._master_keys.get(wrapping_key_id)
65-
if master_key is None:
66-
raise ValueError(f"Wrapping key not found: {wrapping_key_id}")
67-
return aes_gcm_encrypt(master_key, key, aad=None)
58+
return aes_gcm_encrypt(self._master(wrapping_key_id), key, aad=None)
6859

6960
def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes:
61+
return aes_gcm_decrypt(self._master(wrapping_key_id), wrapped_key, aad=None)
62+
63+
def _master(self, wrapping_key_id: str) -> bytes:
7064
master_key = self._master_keys.get(wrapping_key_id)
7165
if master_key is None:
7266
raise ValueError(f"Wrapping key not found: {wrapping_key_id}")
73-
return aes_gcm_decrypt(master_key, wrapped_key, aad=None)
67+
return master_key
7468

7569

7670
def load_kms_client(properties: Properties) -> KeyManagementClient | None:
77-
"""Load a KMS client from properties using py-kms-impl.
78-
79-
Follows the same pattern as py-io-impl for FileIO.
80-
81-
The property 'py-kms-impl' should be a fully qualified Python class name
82-
(e.g., 'pyiceberg.encryption.kms.InMemoryKms'). The class must be a
83-
subclass of KeyManagementClient.
84-
85-
Args:
86-
properties: Catalog and/or table properties.
87-
88-
Returns:
89-
An initialized KeyManagementClient, or None if py-kms-impl is not set.
90-
"""
71+
"""Instantiate a KeyManagementClient from a fully-qualified `py-kms-impl` (or return None)."""
9172
kms_impl = properties.get(PY_KMS_IMPL)
9273
if kms_impl is None:
9374
return None

0 commit comments

Comments
 (0)