+import logging
 import os
+import random
+import time
 from types import MethodType
-from typing import List, Optional, Set, Union, no_type_check
+from typing import Callable, List, Optional, Set, Union, no_type_check

 import pandas as pd
 from pyspark import SparkContext
 from feast.sorted_feature_view import SortedFeatureView
 from feast.stream_feature_view import StreamFeatureView

+logger = logging.getLogger(__name__)
+
+# Patterns that indicate transient errors which should be retried
+TRANSIENT_ERROR_PATTERNS = [
+    "writetimeout",
+    "readtimeout",
+    "unavailable",
+    "operationtimedout",
+    "nohostsavailable",
+    "connection refused",
+    "connection reset",
+    "overloaded",
+]
+
+
+def _is_transient_error(exc: Exception) -> bool:
+    """Check if an exception is a transient error that should be retried."""
+    exc_str = str(exc).lower()
+    exc_type = type(exc).__name__.lower()
+
+    for pattern in TRANSIENT_ERROR_PATTERNS:
+        if pattern in exc_str or pattern in exc_type:
+            return True
+    return False
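+
+# For example, a driver-level "WriteTimeout" exception matches by type name
+# alone, while a ValueError whose message contains none of the patterns above
+# is treated as permanent. (Both exception types here are illustrative
+# assumptions, not errors the surrounding code is known to raise.)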
+
+
+def _write_with_retry(
+    write_fn: Callable[[], None],
+    operation_name: str,
+    max_retries: int = 3,
+    base_delay: float = 1.0,
+    max_delay: float = 30.0,
+) -> None:
+    """
+    Execute a write function with exponential backoff retry for transient errors.
+
+    Args:
+        write_fn: The write function to execute
+        operation_name: Name of the operation for logging
+        max_retries: Maximum number of retry attempts
+        base_delay: Base delay in seconds for exponential backoff
+        max_delay: Maximum delay in seconds between retries
+
+    Raises:
+        Exception: The last exception if all retries are exhausted or if a
+            non-transient error occurs
+    """
+    for attempt in range(max_retries + 1):
+        try:
+            write_fn()
+            if attempt > 0:
+                logger.info(
+                    f"[{operation_name}] Succeeded after {attempt} retry attempt(s)"
+                )
+            return  # Success
+        except Exception as e:
+            if not _is_transient_error(e):
+                # Permanent error - don't retry, bubble up immediately
+                logger.error(
+                    f"[{operation_name}] Permanent error (not retrying): "
+                    f"{type(e).__name__}: {e}"
+                )
+                raise
+
+            if attempt < max_retries:
+                # Calculate delay with exponential backoff + jitter
+                delay = min(base_delay * (2**attempt), max_delay)
+                jitter = random.uniform(0, delay * 0.1)
+                total_delay = delay + jitter
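+                # With the defaults this yields delays of ~1s, 2s and 4s for
+                # attempts 0-2 (each plus up to 10% jitter), capped at max_delay.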
+
+                logger.warning(
+                    f"[{operation_name}] Transient error, retry {attempt + 1}/{max_retries} "
+                    f"after {total_delay:.2f}s: {type(e).__name__}: {e}"
+                )
+                time.sleep(total_delay)
+            else:
+                # Max retries exceeded - bubble up the exception
+                logger.error(
+                    f"[{operation_name}] Max retries ({max_retries}) exceeded: "
+                    f"{type(e).__name__}: {e}"
+                )
+                raise
+
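+
+# Usage sketch (illustrative only; `write_rows` and `batch` are hypothetical):
+#
+#     _write_with_retry(
+#         write_fn=lambda: write_rows(batch),
+#         operation_name="example_write",
+#     )
+#
+# Transient failures sleep and retry up to max_retries times; any other
+# exception is re-raised immediately.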

 class SparkProcessorConfig(ProcessorConfig):
     """spark_kafka_options, schema_registry_config and checkpoint_location are only used for ConfluentAvroFormat"""
@@ -279,11 +365,28 @@ def batch_write(row: DataFrame, batch_id: int):
                 rows = self.preprocess_fn(rows)

         # Finally persist the data to the online store and/or offline store.
+        # Use retry with exponential backoff for transient errors.
         if rows.size > 0:
             if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE:
-                self.fs.write_to_online_store(self.sfv.name, rows)
+                _write_with_retry(
+                    write_fn=lambda: self.fs.write_to_online_store(
+                        self.sfv.name, rows
+                    ),
+                    operation_name=f"write_to_online_store[{self.sfv.name}][batch_id={batch_id}]",
+                    max_retries=3,
+                    base_delay=1.0,
+                    max_delay=30.0,
+                )
             if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE:
-                self.fs.write_to_offline_store(self.sfv.name, rows)
+                _write_with_retry(
+                    write_fn=lambda: self.fs.write_to_offline_store(
+                        self.sfv.name, rows
+                    ),
+                    operation_name=f"write_to_offline_store[{self.sfv.name}][batch_id={batch_id}]",
+                    max_retries=3,
+                    base_delay=1.0,
+                    max_delay=30.0,
+                )
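+            # If retries are exhausted, the exception propagates out of
+            # batch_write and fails the micro-batch; the awaitTermination
+            # handling below then surfaces it to the caller.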

         query = (
             df.writeStream.outputMode("update")
@@ -293,5 +396,16 @@ def batch_write(row: DataFrame, batch_id: int):
             .start()
         )

-        query.awaitTermination(timeout=self.query_timeout)
+        terminated = query.awaitTermination(timeout=self.query_timeout)
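+        # awaitTermination(timeout) returns True if the query terminated within
+        # the timeout, and False if it is still running.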
+
+        if terminated:
+            # Query terminated before timeout - check if it was an error
+            # This ensures exceptions from batch_write() bubble up to the caller
+            query_exception = query.exception()
+            if query_exception is not None:
+                logger.error(
+                    f"Streaming query terminated with exception: {query_exception}"
+                )
+                raise query_exception
+
         return query