Skip to content

Commit 6ae3f04

Browse files
zabarnZach Barnett
andauthored
fix: Wrap exception in cassandra on_failure for pickle issue (#345)
* fix: Wrap exception in cassandra on_failure for pickle issue * fix: test issue * fix: update test to actually call cassandra code * fix: change test location * noop * revert last commit * ci: trigger workflow * ci: allow manual workflow dispatch for unit tests * remove manual workflow dispatch --------- Co-authored-by: Zach Barnett <zbarnett@expediagroup.com>
1 parent 472c557 commit 6ae3f04

2 files changed

Lines changed: 71 additions & 2 deletions

File tree

sdk/python/feast/infra/online_stores/cassandra_online_store/cassandra_online_store.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from queue import Queue
3030
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
3131

32+
from cassandra import Timeout
3233
from cassandra.auth import PlainTextAuthProvider
3334
from cassandra.cluster import (
3435
EXEC_PROFILE_DEFAULT,
@@ -406,7 +407,19 @@ def on_success(result, concurrent_queue):
406407

407408
def on_failure(exc, concurrent_queue):
408409
nonlocal ex
409-
ex = exc
410+
# The cassandra-driver's Timeout subclasses (WriteTimeout,
411+
# ReadTimeout) fail to unpickle because __init__ re-translates
412+
# write_type via WriteType.value_to_name, but pickle only
413+
# preserves the formatted message string — so write_type
414+
# defaults to None on reconstruction, causing KeyError.
415+
# Wrap them in a plain Exception so they survive pickle
416+
# round-trip across Spark's multiprocessing boundaries.
417+
if isinstance(exc, Timeout):
418+
ex = Exception(
419+
f"Error writing batch to Cassandra: {type(exc).__name__}: {exc}"
420+
)
421+
else:
422+
ex = exc
410423
concurrent_queue.get_nowait()
411424
logger.exception(f"Error writing a batch: {exc}")
412425

sdk/python/tests/expediagroup/test_cassandra_online_store.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import pickle
12
import textwrap
23
import time
34
from datetime import datetime, timedelta
45

56
import pytest
6-
from cassandra import InvalidRequest
7+
from cassandra import InvalidRequest, WriteTimeout, WriteType
78
from cassandra.cluster import Cluster
89

910
from feast import Entity, Field, FileSource, RepoConfig, ValueType, utils
@@ -326,6 +327,61 @@ def test_validate_invalid_request_error_when_sort_keys_are_null(
326327
== 'Error from server: code=2200 [Invalid query] message="Invalid null value in condition for column int"'
327328
)
328329

330+
def test_on_failure_wraps_timeout_for_pickle_safety(
331+
self,
332+
repo_config: RepoConfig,
333+
online_store: CassandraOnlineStore,
334+
mocker,
335+
):
336+
"""
337+
When a WriteTimeout occurs during online_write_batch, the on_failure
338+
callback must wrap it in a plain Exception so it survives pickle
339+
round-trip across Spark's multiprocessing boundaries.
340+
"""
341+
write_timeout = WriteTimeout("Operation timed out", write_type=WriteType.SIMPLE)
342+
mock_future = mocker.MagicMock()
343+
mock_future.add_callbacks = lambda ok, err: err(write_timeout)
344+
345+
mock_session = mocker.MagicMock()
346+
mock_session.execute_async.return_value = mock_future
347+
mock_session.is_shutdown = False
348+
mock_session.prepare.return_value = mocker.MagicMock()
349+
350+
store = CassandraOnlineStore()
351+
mock_table = mocker.MagicMock()
352+
mock_table.name = "test_fv"
353+
mock_table.ttl = None
354+
355+
entity_key = mocker.MagicMock()
356+
feature_val = mocker.MagicMock()
357+
data = [(entity_key, {"feature1": feature_val}, datetime.utcnow(), None)]
358+
359+
mocker.patch.object(store, "_get_session", return_value=mock_session)
360+
mocker.patch.object(
361+
store, "_get_cql_statement", return_value=mocker.MagicMock()
362+
)
363+
mocker.patch(
364+
"feast.infra.online_stores.cassandra_online_store.cassandra_online_store.serialize_entity_key",
365+
return_value=b"\x00",
366+
)
367+
368+
with pytest.raises(Exception) as exc_info:
369+
store.online_write_batch(
370+
config=repo_config,
371+
table=mock_table,
372+
data=data,
373+
progress=None,
374+
)
375+
376+
raised = exc_info.value
377+
assert type(raised) is Exception
378+
assert "WriteTimeout" in str(raised)
379+
assert "Operation timed out" in str(raised)
380+
381+
# Must survive pickle round-trip
382+
restored = pickle.loads(pickle.dumps(raised))
383+
assert str(restored) == str(raised)
384+
329385
def test_cassandra_online_write_batch_ttl(
330386
self,
331387
cassandra_session,

0 commit comments

Comments
 (0)