Skip to content

Commit 91508f8

Browse files
POC: Encryption read support for REST catalog
1 parent 39e08a1 commit 91508f8

File tree

17 files changed

+1480
-23
lines changed

17 files changed

+1480
-23
lines changed

pyiceberg/encryption/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Iceberg table encryption support."""

pyiceberg/encryption/ciphers.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""AES-GCM encryption/decryption primitives and AGS1 stream decryption."""
18+
19+
from __future__ import annotations
20+
21+
import os
22+
import struct
23+
24+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
25+
26+
NONCE_LENGTH = 12
27+
GCM_TAG_LENGTH = 16
28+
29+
30+
def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes:
    """Encrypt ``plaintext`` with AES-GCM under a freshly generated random nonce.

    Args:
        key: AES key handed to ``AESGCM``.
        plaintext: Data to encrypt.
        aad: Optional additional authenticated data passed through to ``AESGCM.encrypt``.

    Returns:
        A single byte string laid out as ``nonce || ciphertext || tag``.
    """
    # The nonce is drawn per call and shipped in front of the ciphertext so the
    # decrypting side can recover it.
    iv = os.urandom(NONCE_LENGTH)
    sealed = AESGCM(key).encrypt(iv, plaintext, aad)
    return iv + sealed
36+
37+
38+
def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes:
    """Decrypt a buffer laid out as ``nonce || ciphertext || tag``.

    Args:
        key: AES key handed to ``AESGCM``.
        ciphertext: The full encrypted buffer, nonce first.
        aad: Optional additional authenticated data passed through to ``AESGCM.decrypt``.

    Returns:
        The value returned by ``AESGCM.decrypt`` for the nonce-stripped payload.

    Raises:
        ValueError: If the buffer is shorter than one nonce plus one GCM tag.
    """
    total = len(ciphertext)
    if total < NONCE_LENGTH + GCM_TAG_LENGTH:
        raise ValueError(f"Ciphertext too short: {total} bytes")

    # Split the leading nonce from the remainder (ciphertext followed by tag).
    iv, sealed = ciphertext[:NONCE_LENGTH], ciphertext[NONCE_LENGTH:]
    return AESGCM(key).decrypt(iv, sealed, aad)
46+
47+
48+
# AGS1 stream constants
49+
GCM_STREAM_MAGIC = b"AGS1"
50+
GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 block size
51+
52+
53+
def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes:
    """Build the additional authenticated data for one AGS1 block.

    The result is ``aad_prefix`` followed by ``block_index`` packed as a
    4-byte little-endian unsigned integer. A falsy prefix (empty or ``None``)
    yields just the packed index.
    """
    ordinal = struct.pack("<I", block_index)
    return aad_prefix + ordinal if aad_prefix else ordinal
62+
63+
64+
def decrypt_ags1_stream(key: bytes, encrypted_data: bytes, aad_prefix: bytes) -> bytes:
    """Decrypt a complete in-memory AGS1 stream and return the joined plaintext.

    Layout, as parsed here:
    - Header: the magic ``AGS1`` (4 bytes) followed by the plaintext block
      size (4 bytes, little-endian unsigned).
    - Body: consecutive blocks of nonce (12 bytes) + ciphertext + tag (16 bytes).
      Every block except possibly the last spans the full cipher block size
      derived from the header.
    - Per-block AAD: produced by ``stream_block_aad`` from ``aad_prefix`` and
      the zero-based block index.

    Raises:
        ValueError: If the header is incomplete, the magic does not match, or
            the trailing block is too short to hold a nonce and a tag.
    """
    total = len(encrypted_data)
    if total < GCM_STREAM_HEADER_LENGTH:
        raise ValueError(f"AGS1 stream too short: {total} bytes")

    magic = encrypted_data[:4]
    if magic != GCM_STREAM_MAGIC:
        raise ValueError(f"Invalid AGS1 magic: {magic!r}, expected {GCM_STREAM_MAGIC!r}")

    (plain_block_size,) = struct.unpack_from("<I", encrypted_data, 4)
    overhead = NONCE_LENGTH + GCM_TAG_LENGTH
    cipher_block_size = plain_block_size + overhead

    body = encrypted_data[GCM_STREAM_HEADER_LENGTH:]
    if not body:
        # A header with no blocks is an empty stream.
        return b""

    cipher = AESGCM(key)
    pieces = []
    # Step through the body one full cipher block at a time; slicing clamps the
    # final block to whatever bytes remain.
    for block_index, offset in enumerate(range(0, len(body), cipher_block_size)):
        block = body[offset : offset + cipher_block_size]
        block_cipher_size = len(block)
        if block_cipher_size < overhead:
            raise ValueError(
                f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes (minimum {NONCE_LENGTH + GCM_TAG_LENGTH})"
            )

        iv, sealed = block[:NONCE_LENGTH], block[NONCE_LENGTH:]
        block_aad = stream_block_aad(aad_prefix, block_index)
        pieces.append(cipher.decrypt(iv, sealed, block_aad))

    return b"".join(pieces)

pyiceberg/encryption/io.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""InputFile implementation backed by in-memory bytes."""
18+
19+
from __future__ import annotations
20+
21+
import io
22+
from types import TracebackType
23+
24+
from pyiceberg.io import InputFile, InputStream
25+
26+
27+
class BytesInputStream(InputStream):
    """Read-only stream over an in-memory byte string, backed by ``io.BytesIO``."""

    def __init__(self, data: bytes) -> None:
        self._buffer = io.BytesIO(data)

    def read(self, size: int = 0) -> bytes:
        """Return up to ``size`` bytes; a non-positive ``size`` reads to the end."""
        # BytesIO treats a negative size as "everything that is left".
        return self._buffer.read(size if size > 0 else -1)

    def seek(self, offset: int, whence: int = 0) -> int:
        """Move the read position and return the new absolute position."""
        position = self._buffer.seek(offset, whence)
        return position

    def tell(self) -> int:
        """Return the current read position."""
        position = self._buffer.tell()
        return position

    def close(self) -> None:
        """Close the underlying buffer."""
        self._buffer.close()

    def __enter__(self) -> BytesInputStream:
        """Enter the context manager, yielding this stream."""
        return self

    def __exit__(
        self,
        exctype: type[BaseException] | None,
        excinst: BaseException | None,
        exctb: TracebackType | None,
    ) -> None:
        """Close the stream when the context manager exits."""
        self.close()
59+
60+
61+
class BytesInputFile(InputFile):
    """InputFile whose content is a byte string already held in memory.

    Lets decrypted data be handed to AvroFile and other readers that expect
    an InputFile.
    """

    def __init__(self, location: str, data: bytes) -> None:
        super().__init__(location)
        self._data = data

    def __len__(self) -> int:
        """Return the number of bytes held in memory."""
        size = len(self._data)
        return size

    def exists(self) -> bool:
        """Always report the file as present."""
        return True

    def open(self, seekable: bool = True) -> InputStream:
        """Return a new stream over the held bytes.

        ``seekable`` is accepted for interface compatibility and is not consulted.
        """
        stream = BytesInputStream(self._data)
        return stream
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""StandardKeyMetadata Avro serialization.
18+
19+
Wire format: ``0x01 version byte || Avro-encoded fields``
20+
21+
Avro schema:
22+
- encryption_key: bytes (required)
23+
- aad_prefix: union[null, bytes] (optional)
24+
- file_length: union[null, long] (optional)
25+
"""
26+
27+
from __future__ import annotations
28+
29+
from dataclasses import dataclass
30+
31+
V1 = 0x01
32+
33+
34+
def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]:
35+
"""Read a zigzag-encoded Avro long from data at offset. Returns (value, new_offset)."""
36+
result = 0
37+
shift = 0
38+
while True:
39+
if offset >= len(data):
40+
raise ValueError("Unexpected end of Avro data reading long")
41+
b = data[offset]
42+
offset += 1
43+
result |= (b & 0x7F) << shift
44+
if (b & 0x80) == 0:
45+
break
46+
shift += 7
47+
# Zigzag decode
48+
return (result >> 1) ^ -(result & 1), offset
49+
50+
51+
def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a length-prefixed Avro ``bytes`` value starting at ``offset``.

    Returns:
        ``(payload, new_offset)`` where ``new_offset`` points just past the payload.

    Raises:
        ValueError: If the length prefix is negative or the payload would run
            past the end of ``data``.
    """
    size, start = _read_avro_long(data, offset)
    if size < 0:
        raise ValueError(f"Negative Avro bytes length: {size}")

    stop = start + size
    if stop > len(data):
        raise ValueError("Unexpected end of Avro data reading bytes")
    return data[start:stop], stop
60+
61+
62+
@dataclass(frozen=True)
class StandardKeyMetadata:
    """Standard key metadata for Iceberg table encryption.

    Holds the plaintext encryption key (DEK), the AAD prefix, and an optional
    file length. Because the key is secret material, ``repr()`` deliberately
    shows only its length.
    """

    # Plaintext data encryption key; never rendered by __repr__.
    encryption_key: bytes
    # AAD prefix; an empty value is serialized as an Avro null.
    aad_prefix: bytes = b""
    # Optional file length; None is serialized as an Avro null.
    file_length: int | None = None

    def __repr__(self) -> str:
        """Return a debug representation that never includes the key bytes."""
        # The dataclass-generated repr would print the DEK verbatim, which
        # leaks it into logs and tracebacks; dataclass keeps this explicit one.
        return (
            f"{type(self).__name__}("
            f"encryption_key=<redacted {len(self.encryption_key)} bytes>, "
            f"aad_prefix={self.aad_prefix!r}, "
            f"file_length={self.file_length!r})"
        )

    @staticmethod
    def deserialize(data: bytes) -> StandardKeyMetadata:
        """Deserialize from wire format: ``0x01 version || Avro-encoded fields``.

        Raises:
            ValueError: If the buffer is empty, the version byte is not ``V1``,
                a union index is neither 0 nor 1, or a field is truncated.
        """
        if not data:
            raise ValueError("Empty key metadata buffer")

        version = data[0]
        if version != V1:
            raise ValueError(f"Unsupported key metadata version: {version}")

        offset = 1

        # encryption_key: required bytes.
        encryption_key, offset = _read_avro_bytes(data, offset)

        # aad_prefix: union[null, bytes]; null is represented as b"".
        union_index, offset = _read_avro_long(data, offset)
        if union_index == 0:
            aad_prefix = b""
        elif union_index == 1:
            aad_prefix, offset = _read_avro_bytes(data, offset)
        else:
            raise ValueError(f"Invalid union index for aad_prefix: {union_index}")

        # file_length: union[null, long]. A buffer that ends right after
        # aad_prefix is tolerated and treated as "no file length".
        file_length = None
        if offset < len(data):
            union_index, offset = _read_avro_long(data, offset)
            if union_index == 1:
                file_length, offset = _read_avro_long(data, offset)
            elif union_index != 0:
                raise ValueError(f"Invalid union index for file_length: {union_index}")

        return StandardKeyMetadata(
            encryption_key=encryption_key,
            aad_prefix=aad_prefix,
            file_length=file_length,
        )

    def serialize(self) -> bytes:
        """Serialize to wire format: ``0x01 version || Avro-encoded fields``."""
        parts = [bytes([V1])]

        # encryption_key: required bytes.
        parts.append(_encode_avro_bytes(self.encryption_key))

        # aad_prefix: union[null, bytes]; an empty prefix is written as null.
        if self.aad_prefix:
            parts.append(_encode_avro_long(1))  # union index 1 = bytes
            parts.append(_encode_avro_bytes(self.aad_prefix))
        else:
            parts.append(_encode_avro_long(0))  # union index 0 = null

        # file_length: union[null, long].
        if self.file_length is not None:
            parts.append(_encode_avro_long(1))  # union index 1 = long
            parts.append(_encode_avro_long(self.file_length))
        else:
            parts.append(_encode_avro_long(0))  # union index 0 = null

        return b"".join(parts)
136+
137+
138+
def _encode_avro_long(value: int) -> bytes:
139+
"""Encode a long as zigzag-encoded Avro varint."""
140+
# Zigzag encode
141+
n = (value << 1) ^ (value >> 63)
142+
result = bytearray()
143+
while n & ~0x7F:
144+
result.append((n & 0x7F) | 0x80)
145+
n >>= 7
146+
result.append(n & 0x7F)
147+
return bytes(result)
148+
149+
150+
def _encode_avro_bytes(data: bytes) -> bytes:
    """Encode ``data`` as Avro ``bytes``: an Avro-long length followed by the payload."""
    length_prefix = _encode_avro_long(len(data))
    return length_prefix + data

0 commit comments

Comments
 (0)