Skip to content

Commit b87a20a

Browse files
Fix code review issues in DIGEST-MD5 delegation token auth
Address all findings from code review:

Critical:
- Rewrite VInt decoder to match Java WritableUtils.readVLong exactly, using signed-byte interpretation and correct prefix/length semantics

High:
- Catch OSError (not just FileNotFoundError) when reading token file
- Reject unknown auth mechanisms with HiveAuthError instead of silently falling back to unauthenticated TBufferedTransport
- Replace monkey-patching sasl.process in _DigestMD5SaslTransport with a clean send_sasl_msg override (thread-safe, no shared state mutation)

Medium:
- Fix kerberos_service_name default from config key to actual value
- Wrap UnicodeDecodeError in HiveAuthError for invalid UTF-8 in tokens
- Rewrite VInt test encoder to match real Hadoop encoding format
- Fix dead kerberos backward-compat tests to actually exercise __init__

Low:
- Add upper bound to pure-sasl dependency (<1.0.0)
- Fix tmp_path typing from object to pathlib.Path
- Fix docs to say pure-sasl (pip package name) not puresasl

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent af7565a commit b87a20a

File tree

7 files changed

+173
-121
lines changed

7 files changed

+173
-121
lines changed

mkdocs/docs/configuration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ catalog:
687687
| ugi | t-1234:secret | Hadoop UGI for Hive client. |
688688
| hive.metastore.authentication | DIGEST-MD5 | Auth mechanism: `NONE` (default), `KERBEROS`, or `DIGEST-MD5` |
689689

690-
When using DIGEST-MD5 authentication, PyIceberg reads a Hive delegation token from the file pointed to by the `$HADOOP_TOKEN_FILE_LOCATION` environment variable. This is the standard mechanism used in secure Hadoop environments where delegation tokens are distributed to jobs. Install PyIceberg with `pip install "pyiceberg[hive]"` to get the required `puresasl` dependency.
690+
When using DIGEST-MD5 authentication, PyIceberg reads a Hive delegation token from the file pointed to by the `$HADOOP_TOKEN_FILE_LOCATION` environment variable. This is the standard mechanism used in secure Hadoop environments where delegation tokens are distributed to jobs. Install PyIceberg with `pip install "pyiceberg[hive]"` to get the required `pure-sasl` dependency.
691691

692692
When using Hive 2.x, make sure to set the compatibility flag:
693693

pyiceberg/catalog/hive.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,8 @@ class _DigestMD5SaslTransport(TTransport.TSaslClientTransport):
154154
coerces ``None`` to ``b""`` so the SASL handshake proceeds normally.
155155
"""
156156

157-
def open(self) -> None:
158-
# Intercept sasl.process to coerce the initial None response
159-
original_process = self.sasl.process
160-
161-
def _patched_process(challenge: bytes | None = None) -> bytes | None:
162-
result = original_process(challenge)
163-
if result is None:
164-
return b""
165-
return result
166-
167-
self.sasl.process = _patched_process
168-
try:
169-
super().open()
170-
finally:
171-
self.sasl.process = original_process
157+
def send_sasl_msg(self, status: int, body: bytes | None) -> None: # type: ignore[override]
158+
super().send_sasl_msg(status, body if body is not None else b"")
172159

173160

174161
class _HiveClient:
@@ -182,7 +169,7 @@ def __init__(
182169
uri: str,
183170
ugi: str | None = None,
184171
kerberos_auth: bool | None = HIVE_KERBEROS_AUTH_DEFAULT,
185-
kerberos_service_name: str | None = HIVE_KERBEROS_SERVICE_NAME,
172+
kerberos_service_name: str | None = HIVE_KERBEROS_SERVICE_NAME_DEFAULT,
186173
auth_mechanism: str | None = None,
187174
):
188175
self._uri = uri
@@ -204,7 +191,9 @@ def _init_thrift_transport(self) -> TTransport:
204191
url_parts = urlparse(self._uri)
205192
socket = TSocket.TSocket(url_parts.hostname, url_parts.port)
206193

207-
if self._auth_mechanism == "KERBEROS":
194+
if self._auth_mechanism == "NONE":
195+
return TTransport.TBufferedTransport(socket)
196+
elif self._auth_mechanism == "KERBEROS":
208197
return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service=self._kerberos_service_name)
209198
elif self._auth_mechanism == "DIGEST-MD5":
210199
identifier, password = read_hive_delegation_token()
@@ -217,7 +206,10 @@ def _init_thrift_transport(self) -> TTransport:
217206
password=password,
218207
)
219208
else:
220-
return TTransport.TBufferedTransport(socket)
209+
raise HiveAuthError(
210+
f"Unknown auth mechanism: {self._auth_mechanism!r}. "
211+
f"Valid values: NONE, KERBEROS, DIGEST-MD5"
212+
)
221213

222214
def _client(self) -> Client:
223215
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)

pyiceberg/utils/hadoop_credentials.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,32 @@
3636

3737

3838
def _read_hadoop_vint(stream: BytesIO) -> int:
39-
"""Decode a Hadoop WritableUtils VInt/VLong from a byte stream."""
39+
"""Decode a Hadoop WritableUtils VInt/VLong from a byte stream.
40+
41+
Matches the encoding in Java's ``WritableUtils.readVInt``/``readVLong``:
42+
- If the first byte (interpreted as signed) is >= -112, it *is* the value.
43+
- Otherwise the first byte encodes both a negativity flag and the number
44+
of additional big-endian payload bytes that carry the actual value.
45+
"""
4046
first = stream.read(1)
4147
if not first:
4248
raise HiveAuthError("Unexpected end of token file while reading VInt")
49+
# Reinterpret as signed byte to match Java's signed-byte semantics
4350
b = first[0]
44-
if b <= 0x7F:
51+
if b > 127:
52+
b -= 256
53+
if b >= -112:
4554
return b
46-
# Number of additional bytes is encoded in leading 1-bits
47-
num_extra = 0
48-
mask = 0x80
49-
while b & mask:
50-
num_extra += 1
51-
mask >>= 1
52-
# First byte contributes the remaining bits
53-
result = b & (mask - 1)
54-
extra = stream.read(num_extra)
55-
if len(extra) != num_extra:
55+
negative = b < -120
56+
length = (-119 - b) if negative else (-111 - b)
57+
extra = stream.read(length)
58+
if len(extra) != length:
5659
raise HiveAuthError("Unexpected end of token file while reading VInt")
57-
for byte in extra:
58-
result = (result << 8) | byte
59-
# Sign-extend if negative (high bit of decoded value is set)
60-
if result >= (1 << (8 * num_extra + (8 - num_extra - 1) - 1)):
61-
result -= 1 << (8 * num_extra + (8 - num_extra - 1))
60+
result = 0
61+
for byte_val in extra:
62+
result = (result << 8) | byte_val
63+
if negative:
64+
result = ~result
6265
return result
6366

6467

@@ -75,7 +78,11 @@ def _read_hadoop_bytes(stream: BytesIO) -> bytes:
7578

7679
def _read_hadoop_text(stream: BytesIO) -> str:
7780
"""Read a VInt-prefixed UTF-8 string from a Hadoop token stream."""
78-
return _read_hadoop_bytes(stream).decode("utf-8")
81+
raw = _read_hadoop_bytes(stream)
82+
try:
83+
return raw.decode("utf-8")
84+
except UnicodeDecodeError as e:
85+
raise HiveAuthError(f"Token file contains invalid UTF-8 in text field: {e}") from e
7986

8087

8188
def read_hive_delegation_token() -> tuple[str, str]:
@@ -99,8 +106,8 @@ def read_hive_delegation_token() -> tuple[str, str]:
99106
try:
100107
with open(token_file, "rb") as f:
101108
data = f.read()
102-
except FileNotFoundError:
103-
raise HiveAuthError(f"Hadoop token file not found: {token_file}")
109+
except OSError as e:
110+
raise HiveAuthError(f"Cannot read Hadoop token file {token_file}: {e}") from e
104111

105112
stream = BytesIO(data)
106113

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ polars = ["polars>=1.21.0,<2"]
7676
snappy = ["python-snappy>=0.6.0,<1.0.0"]
7777
hive = [
7878
"thrift>=0.13.0,<1.0.0",
79-
"pure-sasl>=0.6.0",
79+
"pure-sasl>=0.6.0,<1.0.0",
8080
]
8181
hive-kerberos = [
8282
"thrift>=0.13.0,<1.0.0",

tests/catalog/test_hive.py

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,15 +1429,13 @@ def test_auth_mechanism_none_creates_buffered_transport_explicit() -> None:
14291429
assert client._auth_mechanism == "NONE"
14301430

14311431

1432-
def test_auth_mechanism_kerberos_resolved() -> None:
1433-
"""When auth_mechanism is KERBEROS, _auth_mechanism is set correctly.
1434-
1435-
We don't fully instantiate because TSaslClientTransport with GSSAPI
1436-
requires the kerberos C module which may not be installed.
1437-
"""
1438-
client = _HiveClient.__new__(_HiveClient)
1439-
client._auth_mechanism = "KERBEROS"
1432+
def test_auth_mechanism_kerberos_resolved(monkeypatch: pytest.MonkeyPatch) -> None:
1433+
"""When auth_mechanism is KERBEROS, _auth_mechanism is resolved correctly."""
1434+
# Stub TSaslClientTransport.__init__ to avoid requiring the kerberos C module
1435+
monkeypatch.setattr(TTransport.TSaslClientTransport, "__init__", lambda *a, **kw: None)
1436+
client = _HiveClient(uri="thrift://localhost:9083", auth_mechanism="KERBEROS")
14401437
assert client._auth_mechanism == "KERBEROS"
1438+
assert isinstance(client._transport, TTransport.TSaslClientTransport)
14411439

14421440

14431441
def test_auth_mechanism_digest_md5_creates_digest_transport(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -1447,12 +1445,12 @@ def test_auth_mechanism_digest_md5_creates_digest_transport(monkeypatch: pytest.
14471445
assert isinstance(client._transport, _DigestMD5SaslTransport)
14481446

14491447

1450-
def test_legacy_kerberos_auth_backward_compat() -> None:
1448+
def test_legacy_kerberos_auth_backward_compat(monkeypatch: pytest.MonkeyPatch) -> None:
14511449
"""Legacy kerberos_auth=True resolves to KERBEROS auth_mechanism."""
1452-
client = _HiveClient.__new__(_HiveClient)
1453-
# Replicate the constructor's mechanism resolution logic
1454-
client._auth_mechanism = "KERBEROS" # what kerberos_auth=True produces
1450+
monkeypatch.setattr(TTransport.TSaslClientTransport, "__init__", lambda *a, **kw: None)
1451+
client = _HiveClient(uri="thrift://localhost:9083", kerberos_auth=True)
14551452
assert client._auth_mechanism == "KERBEROS"
1453+
assert isinstance(client._transport, TTransport.TSaslClientTransport)
14561454

14571455

14581456
def test_auth_mechanism_overrides_kerberos_auth() -> None:
@@ -1462,6 +1460,14 @@ def test_auth_mechanism_overrides_kerberos_auth() -> None:
14621460
assert client._auth_mechanism == "NONE"
14631461

14641462

1463+
def test_auth_mechanism_unknown_raises() -> None:
1464+
"""Unknown auth mechanism should raise HiveAuthError, not silently fall back."""
1465+
from pyiceberg.exceptions import HiveAuthError
1466+
1467+
with pytest.raises(HiveAuthError, match="Unknown auth mechanism.*PLAIN"):
1468+
_HiveClient(uri="thrift://localhost:9083", auth_mechanism="PLAIN")
1469+
1470+
14651471
def test_auth_mechanism_case_insensitive(monkeypatch: pytest.MonkeyPatch) -> None:
14661472
"""Auth mechanism should be case-insensitive."""
14671473
monkeypatch.setattr("pyiceberg.catalog.hive.read_hive_delegation_token", _fake_read_token)
@@ -1480,8 +1486,8 @@ def test_create_hive_client_passes_auth_mechanism(monkeypatch: pytest.MonkeyPatc
14801486
assert client._auth_mechanism == "DIGEST-MD5"
14811487

14821488

1483-
def test_digest_md5_transport_coerces_none_to_empty_bytes(monkeypatch: pytest.MonkeyPatch) -> None:
1484-
"""_DigestMD5SaslTransport.open() coerces None initial sasl.process() to b''."""
1489+
def test_digest_md5_transport_send_sasl_msg_coerces_none(monkeypatch: pytest.MonkeyPatch) -> None:
1490+
"""_DigestMD5SaslTransport.send_sasl_msg coerces None body to b''."""
14851491
monkeypatch.setattr("pyiceberg.catalog.hive.read_hive_delegation_token", _fake_read_token)
14861492

14871493
transport = _DigestMD5SaslTransport(
@@ -1493,36 +1499,18 @@ def test_digest_md5_transport_coerces_none_to_empty_bytes(monkeypatch: pytest.Mo
14931499
password="dGVzdC1wdw==",
14941500
)
14951501

1496-
# Build a simple sasl stand-in that returns None on first call
1497-
class FakeSasl:
1498-
def __init__(self) -> None:
1499-
self._call_count = 0
1500-
self.complete = True
1501-
1502-
def process(self, challenge: bytes | None = None) -> bytes | None:
1503-
self._call_count += 1
1504-
if self._call_count == 1:
1505-
return None # DIGEST-MD5 initial response is None
1506-
return b"response-data"
1507-
1508-
original_sasl = transport.sasl
1509-
transport.sasl = FakeSasl() # type: ignore[assignment]
1510-
1511-
# Capture what the patched process returns during open()
1512-
captured_results: list[bytes | None] = []
1502+
# Capture what the parent send_sasl_msg receives
1503+
captured_calls: list[tuple[int, bytes | None]] = []
15131504

1514-
def fake_super_open(self: TTransport.TSaslClientTransport) -> None:
1515-
captured_results.append(self.sasl.process(None))
1516-
captured_results.append(self.sasl.process(b"challenge"))
1505+
def capture_send(self: TTransport.TSaslClientTransport, status: int, body: bytes | None) -> None:
1506+
captured_calls.append((status, body))
15171507

1518-
monkeypatch.setattr(TTransport.TSaslClientTransport, "open", fake_super_open)
1508+
monkeypatch.setattr(TTransport.TSaslClientTransport, "send_sasl_msg", capture_send)
15191509

1520-
try:
1521-
transport.open()
1522-
finally:
1523-
transport.sasl = original_sasl
1510+
# Send with None body — should be coerced to b""
1511+
transport.send_sasl_msg(1, None)
1512+
# Send with real body — should pass through unchanged
1513+
transport.send_sasl_msg(2, b"real-data")
15241514

1525-
# None from first process() should have been coerced to b""
1526-
assert captured_results[0] == b""
1527-
# Non-None result passes through unchanged
1528-
assert captured_results[1] == b"response-data"
1515+
assert captured_calls[0] == (1, b""), "None body should be coerced to b''"
1516+
assert captured_calls[1] == (2, b"real-data"), "Non-None body should pass through"

0 commit comments

Comments (0)