Skip to content

Commit c07972d

Browse files
committed
fix: Fix streaming materialization for exotic sources with lazy UDF pipelines
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent 312eea3 commit c07972d

6 files changed

Lines changed: 271 additions & 51 deletions

File tree

sdk/python/feast/infra/common/serde.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def unserialize(self):
3030
# unserialize
3131
proto = FeatureViewProto()
3232
proto.ParseFromString(self.feature_view_proto)
33-
feature_view = FeatureView.from_proto(proto)
33+
# skip_udf=True: the write node only needs schema / entity metadata.
34+
feature_view = FeatureView.from_proto(proto, skip_udf=True)
3435

3536
# load
3637
repo_config = dill.loads(self.repo_config_byte)

sdk/python/feast/infra/compute_engines/ray/config.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,20 @@ class RayComputeEngineConfig(FeastConfigBaseModel):
4141

4242
# Additional configuration options
4343
max_workers: Optional[int] = None
44-
"""Maximum number of Ray workers. If None, uses all available cores."""
44+
"""Maximum number of Ray workers for transformation and join nodes.
45+
If None, Ray uses all available cores."""
46+
47+
write_concurrency: Optional[int] = None
48+
"""Concurrency for the RayWriteNode's map_batches call (online-store writes).
49+
If None, falls back to max_workers, then 1 (safe default
50+
for single-file stores).
51+
52+
Example - SQLite online store (default for local deployments):
53+
write_concurrency: 1
54+
55+
Example - Redis / DynamoDB online store (supports parallel writes):
56+
write_concurrency: 8
57+
"""
4558

4659
enable_optimization: bool = True
4760
"""Enable automatic performance optimizations."""

sdk/python/feast/infra/compute_engines/ray/feature_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def build_dedup_node(self, view, input_node):
136136
name="dedup",
137137
column_info=column_info,
138138
config=self.config,
139+
is_materialization=self.is_materialization,
139140
)
140141
node.add_input(input_node)
141142

sdk/python/feast/infra/compute_engines/ray/nodes.py

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -559,17 +559,33 @@ def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Data
559559
class RayDedupNode(DAGNode):
560560
"""
561561
Ray node for deduplicating records.
562+
563+
Two dedup strategies are provided:
564+
565+
* **Materialization** (``is_materialization=True``): per-block
566+
``drop_duplicates``. This is streaming-friendly because it never needs
567+
to see all blocks at once. Any cross-block duplicates are resolved by
568+
the online store, which does an UPSERT and therefore naturally keeps the
569+
last-written value (which may not be the row with the latest event timestamp if blocks are written out of order). This avoids the ``groupby().map_groups()`` full
570+
shuffle that would otherwise block until every single block was produced.
571+
572+
* **Historical retrieval** (``is_materialization=False``): global
573+
``groupby().map_groups()``. Correctness is required here because the
574+
entity-timestamp join must return exactly one feature row per
575+
(entity, query-timestamp) pair.
562576
"""
563577

564578
def __init__(
565579
self,
566580
name: str,
567581
column_info,
568582
config: RayComputeEngineConfig,
583+
is_materialization: bool = False,
569584
):
570585
super().__init__(name)
571586
self.column_info = column_info
572587
self.config = config
588+
self.is_materialization = is_materialization
573589

574590
def execute(self, context: ExecutionContext) -> DAGValue:
575591
"""Execute the deduplication operation."""
@@ -581,26 +597,54 @@ def execute(self, context: ExecutionContext) -> DAGValue:
581597
timestamp_col = self.column_info.timestamp_column
582598

583599
if join_keys:
584-
available_join_keys = [k for k in join_keys if k in dataset.schema().names]
585-
available_ts_col = (
586-
timestamp_col if timestamp_col in dataset.schema().names else None
587-
)
588-
589-
if available_join_keys:
590-
# groupby().map_groups() co-locates ALL rows for the same entity
591-
# in a single call, so deduplication is always correct regardless
592-
# of how Ray splits the dataset into partitions. sort + map_batches
593-
# is NOT safe: Ray can place the same entity's rows in different
594-
# partitions after a sort, causing surviving duplicates.
595-
def _keep_latest_in_group(group: pd.DataFrame) -> pd.DataFrame:
596-
if available_ts_col and available_ts_col in group.columns:
597-
group = group.sort_values(available_ts_col, ascending=False)
598-
return group.head(1)
599-
600-
dataset = dataset.groupby(available_join_keys).map_groups(
601-
_keep_latest_in_group, batch_format="pandas"
600+
if self.is_materialization:
601+
# Per-block dedup: streaming-safe, no full shuffle required.
602+
# Cross-block duplicates are handled by the online-store UPSERT.
603+
#
604+
# IMPORTANT: do NOT call dataset.schema() here. For streaming
605+
# datasets backed by slow map_batches actors, .schema() triggers
606+
# eager block execution to
607+
# infer the output type. Those blocks are consumed and LOST —
608+
# they never reach the write stage. We therefore defer the
609+
# column-existence check to inside _dedup_block, which runs in
610+
# a worker per block without interfering with streaming.
611+
_join_keys = list(join_keys)
612+
_ts_col = timestamp_col
613+
614+
def _dedup_block(block: pd.DataFrame) -> pd.DataFrame:
615+
available = [k for k in _join_keys if k in block.columns]
616+
if not available:
617+
return block
618+
if _ts_col and _ts_col in block.columns:
619+
block = block.sort_values(_ts_col, ascending=False)
620+
return block.drop_duplicates(subset=available)
621+
622+
dataset = dataset.map_batches(_dedup_block, batch_format="pandas")
623+
else:
624+
# Global dedup via groupby: required for historical retrieval
625+
# where the entity-timestamp join must return exactly one row
626+
# per (entity, query-timestamp) pair.
627+
# NOTE: groupby().map_groups() is a full shuffle and blocks
628+
# until ALL upstream blocks are produced. Use only when
629+
# correctness across partition boundaries is mandatory.
630+
available_join_keys = [
631+
k for k in join_keys if k in dataset.schema().names
632+
]
633+
available_ts_col = (
634+
timestamp_col if timestamp_col in dataset.schema().names else None
602635
)
603636

637+
if available_join_keys:
638+
639+
def _keep_latest_in_group(group: pd.DataFrame) -> pd.DataFrame:
640+
if available_ts_col and available_ts_col in group.columns:
641+
group = group.sort_values(available_ts_col, ascending=False)
642+
return group.head(1)
643+
644+
dataset = dataset.groupby(available_join_keys).map_groups(
645+
_keep_latest_in_group, batch_format="pandas"
646+
)
647+
604648
deduped_dataset = dataset
605649

606650
return DAGValue(
@@ -848,10 +892,19 @@ def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame:
848892

849893
return batch
850894

895+
# Resolve write concurrency from config.
896+
# write_concurrency takes precedence; falls back to max_workers, then 1.
897+
if self.config is not None and self.config.write_concurrency is not None:
898+
_write_concurrency = self.config.write_concurrency
899+
elif self.config is not None and self.config.max_workers is not None:
900+
_write_concurrency = self.config.max_workers
901+
else:
902+
_write_concurrency = 1
903+
851904
written_dataset = dataset.map_batches(
852905
write_batch_with_serialized_artifacts,
853906
batch_format="pandas",
854-
concurrency=self.config.max_workers if self.config else 12,
907+
concurrency=_write_concurrency,
855908
)
856909
written_dataset = written_dataset.materialize()
857910

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1793,37 +1793,49 @@ def _load_and_filter_dataset_ray(
17931793
if pre_loaded_ds is not None:
17941794
ds = pre_loaded_ds
17951795

1796-
# Normalize the timestamp column BEFORE the filter so that
1797-
# non-Parquet sources (CSV, JSON, SQL) whose raw dataset may
1798-
# contain strings or tz-naive datetimes can be compared against
1799-
# the tz-aware datetime bounds below without raising TypeError.
1800-
# This mirrors what _create_filtered_dataset does for file-based
1801-
# sources as part of its read pipeline.
1802-
if timestamp_field:
1803-
ts_cols_to_norm = [timestamp_field]
1804-
if created_timestamp_column:
1805-
ts_cols_to_norm.append(created_timestamp_column)
1806-
ds = ensure_timestamp_compatibility(ds, ts_cols_to_norm)
1807-
1808-
# Apply time-range filter inline (done by _create_filtered_dataset
1809-
# for path-based sources).
1810-
def _normalize(dt: Optional[datetime]) -> Optional[datetime]:
1811-
return make_tzaware(dt) if dt and dt.tzinfo is None else dt
1812-
1813-
s_date = _normalize(start_date)
1814-
e_date = _normalize(end_date)
1815-
ts_col = timestamp_field
1816-
1817-
if s_date and e_date:
1818-
ds = ds.filter(
1819-
lambda batch, s=s_date, e=e_date, col=ts_col: (
1820-
(batch[col] >= s) & (batch[col] <= e)
1796+
# Normalize timestamps and apply time-range filter inside
1797+
# map_batches so that ds.schema() is NEVER called eagerly.
1798+
# Column-existence checks are deferred to each batch so that
1799+
# exotic sources whose timestamp column is synthesised inside a
1800+
# downstream UDF (e.g. HuggingFace image datasets) are handled
1801+
# gracefully: normalization and filtering are simply skipped for
1802+
# batches that do not yet contain the column.
1803+
_ts_field = timestamp_field
1804+
_created_ts = created_timestamp_column
1805+
_s_date = (
1806+
make_tzaware(start_date)
1807+
if start_date and start_date.tzinfo is None
1808+
else start_date
1809+
)
1810+
_e_date = (
1811+
make_tzaware(end_date)
1812+
if end_date and end_date.tzinfo is None
1813+
else end_date
1814+
)
1815+
1816+
def _norm_and_filter(batch: pd.DataFrame) -> pd.DataFrame:
1817+
batch = make_df_tzaware(batch)
1818+
for col in [
1819+
c for c in [_ts_field, _created_ts] if c and c in batch.columns
1820+
]:
1821+
batch[col] = (
1822+
pd.to_datetime(batch[col], utc=True, errors="coerce")
1823+
.dt.floor("s")
1824+
.astype("datetime64[ns, UTC]")
18211825
)
1822-
)
1823-
elif s_date:
1824-
ds = ds.filter(lambda batch, s=s_date, col=ts_col: batch[col] >= s)
1825-
elif e_date:
1826-
ds = ds.filter(lambda batch, e=e_date, col=ts_col: batch[col] <= e)
1826+
if _ts_field and _ts_field in batch.columns:
1827+
if _s_date and _e_date:
1828+
batch = batch[
1829+
(batch[_ts_field] >= _s_date)
1830+
& (batch[_ts_field] <= _e_date)
1831+
]
1832+
elif _s_date:
1833+
batch = batch[batch[_ts_field] >= _s_date]
1834+
elif _e_date:
1835+
batch = batch[batch[_ts_field] <= _e_date]
1836+
return batch
1837+
1838+
ds = ds.map_batches(_norm_and_filter, batch_format="pandas")
18271839
else:
18281840
if not feature_name_columns:
18291841
columns_to_read = None

sdk/python/tests/component/ray/test_nodes.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,146 @@ def test_ray_dedup_node(
299299
assert "driver_id" in result_df.columns
300300

301301

302+
def test_ray_dedup_node_materialization_within_block(
303+
ray_session, ray_config, mock_context, column_info
304+
):
305+
"""Materialization path: within-block duplicates are removed and the row
306+
with the latest event_timestamp is kept.
307+
308+
is_materialization=True uses per-block map_batches (streaming-safe).
309+
No ds.schema() call should be triggered.
310+
"""
311+
now = datetime.now()
312+
older_ts = now - timedelta(hours=3)
313+
newer_ts = now - timedelta(hours=1)
314+
315+
block = pd.DataFrame(
316+
[
317+
{
318+
"driver_id": 1001,
319+
"event_timestamp": older_ts,
320+
"conv_rate": 0.5,
321+
},
322+
{
323+
"driver_id": 1001,
324+
"event_timestamp": newer_ts,
325+
"conv_rate": 0.8,
326+
},
327+
{
328+
"driver_id": 1002,
329+
"event_timestamp": now - timedelta(hours=2),
330+
"conv_rate": 0.7,
331+
},
332+
]
333+
)
334+
335+
ray_dataset = ray.data.from_pandas(block)
336+
input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY)
337+
dummy_node = DummyInputNode("input_node", input_value)
338+
node = RayDedupNode(
339+
name="dedup",
340+
column_info=column_info,
341+
config=ray_config,
342+
is_materialization=True,
343+
)
344+
node.add_input(dummy_node)
345+
mock_context.node_outputs = {"input_node": input_value}
346+
347+
result = node.execute(mock_context)
348+
result_df = result.data.to_pandas().sort_values("driver_id").reset_index(drop=True)
349+
350+
assert len(result_df) == 2, "One row per entity should survive within the block"
351+
driver_1001 = result_df[result_df["driver_id"] == 1001].iloc[0]
352+
assert driver_1001["event_timestamp"] == newer_ts, (
353+
"Latest timestamp should be kept for driver 1001"
354+
)
355+
356+
357+
def test_ray_dedup_node_materialization_cross_block_duplicates_survive(
358+
ray_session, ray_config, mock_context, column_info
359+
):
360+
"""Materialization path: the same entity in two *different* blocks both
361+
survive — cross-block dedup is delegated to the online-store UPSERT.
362+
363+
This validates the per-block (streaming-safe) semantics: a global shuffle
364+
is intentionally avoided so that slow upstream actors (EasyOCR, CLIP, etc.)
365+
do not need to finish all blocks before writes begin.
366+
"""
367+
now = datetime.now()
368+
block_a = pd.DataFrame(
369+
[
370+
{
371+
"driver_id": 1001,
372+
"event_timestamp": now - timedelta(hours=3),
373+
"conv_rate": 0.5,
374+
}
375+
]
376+
)
377+
block_b = pd.DataFrame(
378+
[
379+
{
380+
"driver_id": 1001,
381+
"event_timestamp": now - timedelta(hours=1),
382+
"conv_rate": 0.8,
383+
}
384+
]
385+
)
386+
387+
# Force two separate Ray blocks by passing a list of DataFrames.
388+
ray_dataset = ray.data.from_pandas([block_a, block_b])
389+
input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY)
390+
dummy_node = DummyInputNode("input_node", input_value)
391+
node = RayDedupNode(
392+
name="dedup",
393+
column_info=column_info,
394+
config=ray_config,
395+
is_materialization=True,
396+
)
397+
node.add_input(dummy_node)
398+
mock_context.node_outputs = {"input_node": input_value}
399+
400+
result = node.execute(mock_context)
401+
result_df = result.data.to_pandas()
402+
403+
assert len(result_df) == 2, (
404+
"Both blocks should each contribute one row; "
405+
"cross-block dedup is the online store's responsibility"
406+
)
407+
408+
409+
def test_ray_dedup_node_materialization_no_join_keys(
410+
ray_session, ray_config, mock_context, sample_data
411+
):
412+
"""Materialization path: when no join keys are present all rows pass through
413+
unchanged (there is nothing to deduplicate on).
414+
"""
415+
empty_column_info = ColumnInfo(
416+
join_keys=[],
417+
feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"],
418+
ts_col="event_timestamp",
419+
created_ts_col="created",
420+
field_mapping=None,
421+
)
422+
ray_dataset = ray.data.from_pandas(sample_data)
423+
input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY)
424+
dummy_node = DummyInputNode("input_node", input_value)
425+
node = RayDedupNode(
426+
name="dedup",
427+
column_info=empty_column_info,
428+
config=ray_config,
429+
is_materialization=True,
430+
)
431+
node.add_input(dummy_node)
432+
mock_context.node_outputs = {"input_node": input_value}
433+
434+
result = node.execute(mock_context)
435+
result_df = result.data.to_pandas()
436+
437+
assert len(result_df) == len(sample_data), (
438+
"All rows should survive when there are no join keys to deduplicate on"
439+
)
440+
441+
302442
def test_ray_config_validation():
303443
"""Test Ray configuration validation."""
304444
# Test valid configuration

0 commit comments

Comments
 (0)