
Commit fec2a10

Authored by Manisha4
fix: Add retry with exponential backoff for online store writes (streaming) (#347)
* fix(streaming): add retry with exponential backoff for online store writes
* fix linting error
* fix(build): pin setuptools_scm<10 to avoid vcs_versioning dependency
* addressing PR comments, making timeout check more explicit and adding retry unit tests
* fixing linting
* fixing linting

Co-authored-by: Manisha4 <Manisha4@github.com>
1 parent a0596d8 · commit fec2a10

5 files changed: 390 additions & 6 deletions
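For orientation before the diff: with the defaults introduced below (max_retries=3, base_delay=1.0, max_delay=30.0), the backoff schedule works out to roughly 1 s, 2 s, and 4 s between attempts, each padded with up to 10% random jitter. A minimal sketch of the schedule computation:

# Sketch of the delay schedule produced by the new retry defaults
# (base_delay=1.0, max_delay=30.0, max_retries=3); random jitter adds
# up to 10% of the delay on top of each value.
for attempt in range(3):
    delay = min(1.0 * (2**attempt), 30.0)
    print(f"retry {attempt + 1}: {delay:.0f}s (+ up to {delay * 0.1:.1f}s jitter)")
# retry 1: 1s (+ up to 0.1s jitter)
# retry 2: 2s (+ up to 0.2s jitter)
# retry 3: 4s (+ up to 0.4s jitter)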


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ requires = [
   "pybindgen==0.22.0",
   # https://amitylearning.vercel.app/?question=git-1742419854670&update=1742342400027
   "setuptools>=60,<81",
-  "setuptools_scm>=6.2",
+  "setuptools_scm>=6.2,<10",
   "sphinx!=4.0.0",
   "wheel",
 ]
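The pin keeps build environments on the setuptools_scm 9.x line or older. A quick, illustrative way to confirm what an environment actually resolved (this check is not part of the commit):

# Illustrative sanity check, not part of the commit: verify the resolved
# setuptools_scm release satisfies the new "<10" constraint.
from importlib.metadata import version

major = int(version("setuptools-scm").split(".")[0])
assert major < 10, f"unexpected setuptools_scm version: {version('setuptools-scm')}"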

sdk/python/feast/infra/contrib/spark_kafka_processor.py

Lines changed: 118 additions & 4 deletions
@@ -1,6 +1,9 @@
+import logging
 import os
+import random
+import time
 from types import MethodType
-from typing import List, Optional, Set, Union, no_type_check
+from typing import Callable, List, Optional, Set, Union, no_type_check
 
 import pandas as pd
 from pyspark import SparkContext
@@ -23,6 +26,89 @@
 from feast.sorted_feature_view import SortedFeatureView
 from feast.stream_feature_view import StreamFeatureView
 
+logger = logging.getLogger(__name__)
+
+# Patterns that indicate transient errors which should be retried
+TRANSIENT_ERROR_PATTERNS = [
+    "writetimeout",
+    "readtimeout",
+    "unavailable",
+    "operationtimedout",
+    "nohostsavailable",
+    "connection refused",
+    "connection reset",
+    "overloaded",
+]
+
+
+def _is_transient_error(exc: Exception) -> bool:
+    """Check if an exception is a transient error that should be retried."""
+    exc_str = str(exc).lower()
+    exc_type = type(exc).__name__.lower()
+
+    for pattern in TRANSIENT_ERROR_PATTERNS:
+        if pattern in exc_str or pattern in exc_type:
+            return True
+    return False
+
+
+def _write_with_retry(
+    write_fn: Callable[[], None],
+    operation_name: str,
+    max_retries: int = 3,
+    base_delay: float = 1.0,
+    max_delay: float = 30.0,
+) -> None:
+    """
+    Execute a write function with exponential backoff retry for transient errors.
+
+    Args:
+        write_fn: The write function to execute
+        operation_name: Name of the operation for logging
+        max_retries: Maximum number of retry attempts
+        base_delay: Base delay in seconds for exponential backoff
+        max_delay: Maximum delay in seconds between retries
+
+    Raises:
+        Exception: The last exception if all retries are exhausted or if a
+            non-transient error occurs
+    """
+    for attempt in range(max_retries + 1):
+        try:
+            write_fn()
+            if attempt > 0:
+                logger.info(
+                    f"[{operation_name}] Succeeded after {attempt} retry attempt(s)"
+                )
+            return  # Success
+        except Exception as e:
+            if not _is_transient_error(e):
+                # Permanent error - don't retry, bubble up immediately
+                logger.error(
+                    f"[{operation_name}] Permanent error (not retrying): "
+                    f"{type(e).__name__}: {e}"
+                )
+                raise
+
+            if attempt < max_retries:
+                # Calculate delay with exponential backoff + jitter
+                delay = min(base_delay * (2**attempt), max_delay)
+                jitter = random.uniform(0, delay * 0.1)
+                total_delay = delay + jitter
+
+                logger.warning(
+                    f"[{operation_name}] Transient error, retry {attempt + 1}/{max_retries} "
+                    f"after {total_delay:.2f}s: {type(e).__name__}: {e}"
+                )
+                time.sleep(total_delay)
+            else:
+                # Max retries exceeded - bubble up the exception
+                logger.error(
+                    f"[{operation_name}] Max retries ({max_retries}) exceeded: "
+                    f"{type(e).__name__}: {e}"
+                )
+                raise
+
 
 class SparkProcessorConfig(ProcessorConfig):
     """spark_kafka_options, schema_registry_config and checkpoint_location are only used for ConfluentAvroFormat"""
@@ -279,11 +365,28 @@ def batch_write(row: DataFrame, batch_id: int):
                 rows = self.preprocess_fn(rows)
 
             # Finally persist the data to the online store and/or offline store.
+            # Use retry with exponential backoff for transient errors.
             if rows.size > 0:
                 if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE:
-                    self.fs.write_to_online_store(self.sfv.name, rows)
+                    _write_with_retry(
+                        write_fn=lambda: self.fs.write_to_online_store(
+                            self.sfv.name, rows
+                        ),
+                        operation_name=f"write_to_online_store[{self.sfv.name}][batch_id={batch_id}]",
+                        max_retries=3,
+                        base_delay=1.0,
+                        max_delay=30.0,
+                    )
                 if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE:
-                    self.fs.write_to_offline_store(self.sfv.name, rows)
+                    _write_with_retry(
+                        write_fn=lambda: self.fs.write_to_offline_store(
+                            self.sfv.name, rows
+                        ),
+                        operation_name=f"write_to_offline_store[{self.sfv.name}][batch_id={batch_id}]",
+                        max_retries=3,
+                        base_delay=1.0,
+                        max_delay=30.0,
+                    )
 
         query = (
             df.writeStream.outputMode("update")
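A side note on the wiring above: _write_with_retry takes a zero-argument callable, so each write is wrapped in a lambda that closes over self.sfv.name and rows and can be re-invoked on every attempt. An equivalent, slightly more explicit spelling (a sketch shown out of context; store, name, and rows stand in for self.fs, self.sfv.name, and the micro-batch DataFrame):

# Sketch: functools.partial is an equivalent way to bind the write
# arguments once so the same call can be repeated on retry.
from functools import partial

write_fn = partial(store.write_to_online_store, name, rows)
_write_with_retry(write_fn=write_fn, operation_name=f"write_to_online_store[{name}]")

Worst case, with the defaults a micro-batch now blocks for about 1 + 2 + 4 s of backoff (plus up to 10% jitter) before the exception propagates out of batch_write and fails the streaming query.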
@@ -293,5 +396,16 @@ def batch_write(row: DataFrame, batch_id: int):
             .start()
         )
 
-        query.awaitTermination(timeout=self.query_timeout)
+        terminated = query.awaitTermination(timeout=self.query_timeout)
+
+        if terminated:
+            # Query terminated before timeout - check if it was an error
+            # This ensures exceptions from batch_write() bubble up to the caller
+            query_exception = query.exception()
+            if query_exception is not None:
+                logger.error(
+                    f"Streaming query terminated with exception: {query_exception}"
+                )
+                raise query_exception
+
         return query
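For context on the new termination check (not part of the diff): in PySpark, StreamingQuery.awaitTermination(timeout) returns True if the query stopped within the timeout and False if it is still running, and StreamingQuery.exception() returns the StreamingQueryException that stopped the query, or None. A hedged caller-side sketch of what this enables, where processor and the surrounding names are illustrative:

# Hypothetical caller-side sketch: with the check above, a micro-batch
# failure that exhausted its retries now propagates out of ingestion
# instead of being lost when awaitTermination merely returned.
try:
    query = processor.ingest_stream_feature_view()  # assumed entry point
except Exception as e:
    # e is the StreamingQueryException re-raised via query.exception()
    logger.error(f"stream ingestion failed: {e}")
    raise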

sdk/python/tests/unit/infra/contrib/__init__.py

Whitespace-only changes.
