@@ -479,10 +479,11 @@ def test_dataset_with_sampler(self, force_base_data_group, mode, input_tile):
479479 def test_dataset_with_sampler_list_item_id (self , mode ):
480480 """E2E: list-typed item_id positives through a real NegativeSampler.
481481
482- Schema declares `item_id` as `pa.list_(pa.int64())` (Parquet-style
483- multi-value column). Exercises `build_sampler_input`'s list-pass-
484- through + flatten, the dynamic-`expand_factor` path in the
485- sampler, and `combine_negs_to_candidate_sequence`'s list-typed-negs
482+ Schema declares `cand_seq__item_id` (grouped sequence sub-feature
483+ flattened name) as `pa.list_(pa.int64())` (multi-positive column).
484+ Exercises `build_sampler_input`'s list-pass-through + flatten,
485+ the dynamic-`expand_factor` path in the sampler, and
486+ `combine_negs_to_candidate_sequence`'s list-typed-negs
486487 normalization. Parameterized over Mode.TRAIN / Mode.EVAL so the
487488 eval-mode `num_eval_sample` path is also covered.
488489 """
@@ -494,18 +495,25 @@ def test_dataset_with_sampler_list_item_id(self, mode):
494495 f .flush ()
495496
496497 input_fields = [
497- pa .field (name = "item_id " , type = pa .list_ (pa .int64 ())),
498+ pa .field (name = "cand_seq__item_id " , type = pa .list_ (pa .int64 ())),
498499 pa .field (name = "label" , type = pa .int32 ()),
499500 ]
500501 feature_cfgs = [
501502 feature_pb2 .FeatureConfig (
502- sequence_id_feature = feature_pb2 .IdFeature (
503- feature_name = "item_id" ,
504- expression = "item:item_id" ,
503+ sequence_feature = feature_pb2 .SequenceFeature (
504+ sequence_name = "cand_seq" ,
505505 sequence_length = 10 ,
506506 sequence_delim = ";" ,
507- num_buckets = 200 ,
508- embedding_dim = 8 ,
507+ features = [
508+ feature_pb2 .SeqFeatureConfig (
509+ id_feature = feature_pb2 .IdFeature (
510+ feature_name = "item_id" ,
511+ expression = "item:item_id" ,
512+ num_buckets = 200 ,
513+ embedding_dim = 8 ,
514+ )
515+ ),
516+ ],
509517 )
510518 ),
511519 ]
@@ -525,8 +533,8 @@ def test_dataset_with_sampler_list_item_id(self, mode):
525533 input_path = f .name ,
526534 num_sample = 8 ,
527535 num_eval_sample = 4 ,
528- attr_fields = ["item_id" ],
529- item_id_field = "item_id" ,
536+ attr_fields = ["item_id" ], # bare; gets rewritten
537+ item_id_field = "cand_seq__item_id" , # qualified
530538 ),
531539 force_base_data_group = True ,
532540 ),
@@ -540,7 +548,7 @@ def test_dataset_with_sampler_list_item_id(self, mode):
540548 # Multi-positive mode (item_id is sequence-positive in train):
541549 # launch_sampler_cluster strips the outer list so the sampler emits
542550 # scalar item_id negs, not 1-elem lists via the multival_sep round-trip.
543- item_id_idx = dataset ._sampler ._attr_names .index ("item_id " )
551+ item_id_idx = dataset ._sampler ._attr_names .index ("cand_seq__item_id " )
544552 self .assertEqual (dataset ._sampler ._attr_types [item_id_idx ], pa .int64 ())
545553
546554 dataloader = DataLoader (
@@ -599,7 +607,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
599607
600608 input_fields = [
601609 pa .field (name = "user_id" , type = pa .int64 ()),
602- pa .field (name = "item_id " , type = pa .list_ (pa .int64 ())),
610+ pa .field (name = "cand_seq__item_id " , type = pa .list_ (pa .int64 ())),
603611 pa .field (name = "label" , type = pa .int32 ()),
604612 ]
605613 feature_cfgs = [
@@ -612,13 +620,20 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
612620 )
613621 ),
614622 feature_pb2 .FeatureConfig (
615- sequence_id_feature = feature_pb2 .IdFeature (
616- feature_name = "item_id" ,
617- expression = "item:item_id" ,
623+ sequence_feature = feature_pb2 .SequenceFeature (
624+ sequence_name = "cand_seq" ,
618625 sequence_length = 10 ,
619626 sequence_delim = ";" ,
620- num_buckets = 200 ,
621- embedding_dim = 8 ,
627+ features = [
628+ feature_pb2 .SeqFeatureConfig (
629+ id_feature = feature_pb2 .IdFeature (
630+ feature_name = "item_id" ,
631+ expression = "item:item_id" ,
632+ num_buckets = 200 ,
633+ embedding_dim = 8 ,
634+ )
635+ ),
636+ ],
622637 )
623638 ),
624639 ]
@@ -641,8 +656,8 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
641656 num_sample = 8 ,
642657 num_eval_sample = 4 ,
643658 num_hard_sample = 2 ,
644- attr_fields = ["item_id" ],
645- item_id_field = "item_id" ,
659+ attr_fields = ["item_id" ], # bare; gets rewritten
660+ item_id_field = "cand_seq__item_id" , # qualified
646661 user_id_field = "user_id" ,
647662 ),
648663 force_base_data_group = True ,
@@ -655,7 +670,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
655670 dataset .launch_sampler_cluster (2 )
656671
657672 # Multi-positive mode: outer list stripped so sampler emits scalar negs.
658- item_id_idx = dataset ._sampler ._attr_names .index ("item_id " )
673+ item_id_idx = dataset ._sampler ._attr_names .index ("cand_seq__item_id " )
659674 self .assertEqual (dataset ._sampler ._attr_types [item_id_idx ], pa .int64 ())
660675
661676 dataloader = DataLoader (
@@ -677,27 +692,19 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
677692 hard_neg_indices = batch .additional_infos [HARD_NEG_INDICES ]
678693 self .assertEqual (set (hard_neg_indices [:, 0 ].tolist ()), {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 })
679694
680- def test_launch_sampler_cluster_multi_attr_strip_decision_matrix (self ):
681- """End-to-end strip decisions across the per-attr filter.
682-
683- Grouped LookupFeature sub with two item-side inputs; only ``cat_key``
684- is in ``sequence_fields`` so only it enters ``_seq_field_delims``.
685- ``cat_key`` is typed ``list<list<int64>>`` (multi-value attr layered
686- under multi-positive grouping); after strip it becomes
687- ``list<int64>`` (ONE level stripped, not bare-stripped to ``int64``).
688- ``cat_map`` is item-side but excluded from ``_seq_field_delims``,
689- so it stays ``list<string>`` unchanged.
690- """
695+ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite (self ):
696+ """Prefix-rewrite + outer-list strip across multiple sub-feature types."""
691697 f = tempfile .NamedTemporaryFile ("w" )
692698 self ._temp_files .append (f )
693699 f .write ("id:int64\t weight:float\t attrs:string\n " )
694700 for i in range (100 ):
695- f .write (f"{ i } \t 1.0\t { i } :{ i + 1000 } \n " )
701+ f .write (f"{ i } \t 1.0\t { i } :{ i + 1000 } :0.5 \n " )
696702 f .flush ()
697703
698704 input_fields = [
699- pa .field (name = "cat_map" , type = pa .list_ (pa .string ())),
700- pa .field (name = "click_seq__cat_key" , type = pa .list_ (pa .list_ (pa .int64 ()))),
705+ pa .field (name = "click_seq__item_id" , type = pa .list_ (pa .int64 ())),
706+ pa .field (name = "click_seq__cat_id" , type = pa .list_ (pa .int64 ())),
707+ pa .field (name = "click_seq__duration" , type = pa .list_ (pa .float32 ())),
701708 pa .field (name = "label" , type = pa .int32 ()),
702709 ]
703710 feature_cfgs = [
@@ -708,23 +715,35 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
708715 sequence_delim = ";" ,
709716 features = [
710717 feature_pb2 .SeqFeatureConfig (
711- lookup_feature = feature_pb2 .LookupFeature (
712- feature_name = "lookup_c" ,
713- map = "item:cat_map" ,
714- key = "item:cat_key" ,
715- sequence_fields = ["cat_key" ],
718+ id_feature = feature_pb2 .IdFeature (
719+ feature_name = "item_id" ,
720+ expression = "item:item_id" ,
721+ num_buckets = 200 ,
722+ embedding_dim = 8 ,
723+ )
724+ ),
725+ feature_pb2 .SeqFeatureConfig (
726+ id_feature = feature_pb2 .IdFeature (
727+ feature_name = "cat_id" ,
728+ expression = "item:cat_id" ,
716729 num_buckets = 10 ,
717730 embedding_dim = 8 ,
718731 )
719732 ),
733+ feature_pb2 .SeqFeatureConfig (
734+ raw_feature = feature_pb2 .RawFeature (
735+ feature_name = "duration" ,
736+ expression = "item:duration" ,
737+ )
738+ ),
720739 ],
721740 )
722741 ),
723742 ]
724743 features = create_features (
725744 feature_cfgs ,
726745 fg_mode = data_pb2 .FgMode .FG_NORMAL ,
727- neg_fields = ["cat_map " , "cat_key " ],
746+ neg_fields = ["item_id " , "cat_id" , "duration " ],
728747 force_base_data_group = True ,
729748 )
730749 dataset = _TestDataset (
@@ -736,8 +755,9 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
736755 negative_sampler = sampler_pb2 .NegativeSampler (
737756 input_path = f .name ,
738757 num_sample = 4 ,
739- attr_fields = ["cat_map" , "click_seq__cat_key" ],
740- item_id_field = "click_seq__cat_key" ,
758+ # bare names; prefix-rewritten at launch
759+ attr_fields = ["item_id" , "cat_id" , "duration" ],
760+ item_id_field = "click_seq__item_id" ,
741761 ),
742762 force_base_data_group = True ,
743763 ),
@@ -746,22 +766,23 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
746766 input_fields = input_fields ,
747767 mode = Mode .TRAIN ,
748768 )
749- # Narrowed _seq_field_delims excludes the non-sequence item-side input.
750- self .assertIn ("click_seq__cat_key" , dataset ._seq_field_delims )
751- self .assertNotIn ("cat_map" , dataset ._seq_field_delims )
769+ self .assertEqual (dataset ._sampler_seq_delim , ";" )
770+ self .assertEqual (dataset ._sampler_seq_prefix , "click_seq__" )
752771
753772 dataset .launch_sampler_cluster (2 )
754- # outer guard True (item_id_field is sequence-positive):
755- # - cat_key: list<list<int64>> -> list<int64> (one strip).
756- # - cat_map: list<string>, not in _seq_field_delims -> unstripped.
757- cat_key_idx = dataset ._sampler ._attr_names .index ("click_seq__cat_key" )
758- cat_map_idx = dataset ._sampler ._attr_names .index ("cat_map" )
759- self .assertEqual (
760- dataset ._sampler ._attr_types [cat_key_idx ], pa .list_ (pa .int64 ())
761- )
773+ # Deep-copy guard: data_config not mutated by the rewrite.
762774 self .assertEqual (
763- dataset ._sampler ._attr_types [cat_map_idx ], pa .list_ (pa .string ())
775+ list (dataset ._data_config .negative_sampler .attr_fields ),
776+ ["item_id" , "cat_id" , "duration" ],
764777 )
778+ # Each bare entry prefix-rewritten and outer-list stripped.
779+ for qualified , expected_type in [
780+ ("click_seq__item_id" , pa .int64 ()),
781+ ("click_seq__cat_id" , pa .int64 ()),
782+ ("click_seq__duration" , pa .float32 ()),
783+ ]:
784+ idx = dataset ._sampler ._attr_names .index (qualified )
785+ self .assertEqual (dataset ._sampler ._attr_types [idx ], expected_type )
765786
766787 def test_dataset_with_sample_mask (self ):
767788 input_fields = [
0 commit comments