@@ -559,17 +559,33 @@ def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Data
559559class RayDedupNode (DAGNode ):
560560 """
561561 Ray node for deduplicating records.
562+
563+ Two dedup strategies are provided:
564+
565+ * **Materialization** (``is_materialization=True``): per-block
566+ ``drop_duplicates``. This is streaming-friendly because it never needs
567+ to see all blocks at once. Any cross-block duplicates are resolved by
568+ the online store, which does an UPSERT and therefore naturally keeps the
569+ last-written value. This avoids the ``groupby().map_groups()`` full
570+ shuffle that would otherwise block until every single block was produced.
571+
572+ * **Historical retrieval** (``is_materialization=False``): global
573+ ``groupby().map_groups()``. Correctness is required here because the
574+ entity-timestamp join must return exactly one feature row per
575+ (entity, query-timestamp) pair.
562576 """
563577
564578 def __init__ (
565579 self ,
566580 name : str ,
567581 column_info ,
568582 config : RayComputeEngineConfig ,
583+ is_materialization : bool = False ,
569584 ):
570585 super ().__init__ (name )
571586 self .column_info = column_info
572587 self .config = config
588+ self .is_materialization = is_materialization
573589
574590 def execute (self , context : ExecutionContext ) -> DAGValue :
575591 """Execute the deduplication operation."""
@@ -581,26 +597,54 @@ def execute(self, context: ExecutionContext) -> DAGValue:
581597 timestamp_col = self .column_info .timestamp_column
582598
583599 if join_keys :
584- available_join_keys = [k for k in join_keys if k in dataset .schema ().names ]
585- available_ts_col = (
586- timestamp_col if timestamp_col in dataset .schema ().names else None
587- )
588-
589- if available_join_keys :
590- # groupby().map_groups() co-locates ALL rows for the same entity
591- # in a single call, so deduplication is always correct regardless
592- # of how Ray splits the dataset into partitions. sort + map_batches
593- # is NOT safe: Ray can place the same entity's rows in different
594- # partitions after a sort, causing surviving duplicates.
595- def _keep_latest_in_group (group : pd .DataFrame ) -> pd .DataFrame :
596- if available_ts_col and available_ts_col in group .columns :
597- group = group .sort_values (available_ts_col , ascending = False )
598- return group .head (1 )
599-
600- dataset = dataset .groupby (available_join_keys ).map_groups (
601- _keep_latest_in_group , batch_format = "pandas"
600+ if self .is_materialization :
601+ # Per-block dedup: streaming-safe, no full shuffle required.
602+ # Cross-block duplicates are handled by the online-store UPSERT.
603+ #
604+ # IMPORTANT: do NOT call dataset.schema() here. For streaming
605+ # datasets backed by slow map_batches actors, .schema() triggers
606+ # eager block execution to
607+ # infer the output type. Those blocks are consumed and LOST —
608+ # they never reach the write stage. We therefore defer the
609+ # column-existence check to inside _dedup_block, which runs in
610+ # a worker per block without interfering with streaming.
611+ _join_keys = list (join_keys )
612+ _ts_col = timestamp_col
613+
614+ def _dedup_block (block : pd .DataFrame ) -> pd .DataFrame :
615+ available = [k for k in _join_keys if k in block .columns ]
616+ if not available :
617+ return block
618+ if _ts_col and _ts_col in block .columns :
619+ block = block .sort_values (_ts_col , ascending = False )
620+ return block .drop_duplicates (subset = available )
621+
622+ dataset = dataset .map_batches (_dedup_block , batch_format = "pandas" )
623+ else :
624+ # Global dedup via groupby: required for historical retrieval
625+ # where the entity–timestamp join must return exactly one row
626+ # per (entity, query-timestamp) pair.
627+ # NOTE: groupby().map_groups() is a full shuffle and blocks
628+ # until ALL upstream blocks are produced. Use only when
629+ # correctness across partition boundaries is mandatory.
630+ available_join_keys = [
631+ k for k in join_keys if k in dataset .schema ().names
632+ ]
633+ available_ts_col = (
634+ timestamp_col if timestamp_col in dataset .schema ().names else None
602635 )
603636
637+ if available_join_keys :
638+
639+ def _keep_latest_in_group (group : pd .DataFrame ) -> pd .DataFrame :
640+ if available_ts_col and available_ts_col in group .columns :
641+ group = group .sort_values (available_ts_col , ascending = False )
642+ return group .head (1 )
643+
644+ dataset = dataset .groupby (available_join_keys ).map_groups (
645+ _keep_latest_in_group , batch_format = "pandas"
646+ )
647+
604648 deduped_dataset = dataset
605649
606650 return DAGValue (
@@ -848,10 +892,19 @@ def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame:
848892
849893 return batch
850894
895+ # Resolve write concurrency from config.
896+ # write_concurrency takes precedence; falls back to max_workers, then 1.
897+ if self .config is not None and self .config .write_concurrency is not None :
898+ _write_concurrency = self .config .write_concurrency
899+ elif self .config is not None and self .config .max_workers is not None :
900+ _write_concurrency = self .config .max_workers
901+ else :
902+ _write_concurrency = 1
903+
851904 written_dataset = dataset .map_batches (
852905 write_batch_with_serialized_artifacts ,
853906 batch_format = "pandas" ,
854- concurrency = self . config . max_workers if self . config else 12 ,
907+ concurrency = _write_concurrency ,
855908 )
856909 written_dataset = written_dataset .materialize ()
857910
0 commit comments