-
Notifications
You must be signed in to change notification settings - Fork 77
[refactor] sampler: derive sequence state from feature configs + accept bare sub-feature attr_fields #520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[refactor] sampler: derive sequence state from feature configs + accept bare sub-feature attr_fields #520
Changes from all commits
ab1d6da
ffae57b
703690f
6834b2f
a80f431
2515693
b0213f8
8f6f4fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -479,10 +479,11 @@ def test_dataset_with_sampler(self, force_base_data_group, mode, input_tile): | |
| def test_dataset_with_sampler_list_item_id(self, mode): | ||
| """E2E: list-typed item_id positives through a real NegativeSampler. | ||
|
|
||
| Schema declares `item_id` as `pa.list_(pa.int64())` (Parquet-style | ||
| multi-value column). Exercises `build_sampler_input`'s list-pass- | ||
| through + flatten, the dynamic-`expand_factor` path in the | ||
| sampler, and `combine_negs_to_candidate_sequence`'s list-typed-negs | ||
| Schema declares `cand_seq__item_id` (grouped sequence sub-feature | ||
| flattened name) as `pa.list_(pa.int64())` (multi-positive column). | ||
| Exercises `build_sampler_input`'s list-pass-through + flatten, | ||
| the dynamic-`expand_factor` path in the sampler, and | ||
| `combine_negs_to_candidate_sequence`'s list-typed-negs | ||
| normalization. Parameterized over Mode.TRAIN / Mode.EVAL so the | ||
| eval-mode `num_eval_sample` path is also covered. | ||
| """ | ||
|
|
@@ -494,18 +495,25 @@ def test_dataset_with_sampler_list_item_id(self, mode): | |
| f.flush() | ||
|
|
||
| input_fields = [ | ||
| pa.field(name="item_id", type=pa.list_(pa.int64())), | ||
| pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())), | ||
| pa.field(name="label", type=pa.int32()), | ||
| ] | ||
| feature_cfgs = [ | ||
| feature_pb2.FeatureConfig( | ||
| sequence_id_feature=feature_pb2.IdFeature( | ||
| feature_name="item_id", | ||
| expression="item:item_id", | ||
| sequence_feature=feature_pb2.SequenceFeature( | ||
| sequence_name="cand_seq", | ||
| sequence_length=10, | ||
| sequence_delim=";", | ||
| num_buckets=200, | ||
| embedding_dim=8, | ||
| features=[ | ||
| feature_pb2.SeqFeatureConfig( | ||
| id_feature=feature_pb2.IdFeature( | ||
| feature_name="item_id", | ||
| expression="item:item_id", | ||
| num_buckets=200, | ||
| embedding_dim=8, | ||
| ) | ||
| ), | ||
| ], | ||
| ) | ||
| ), | ||
| ] | ||
|
|
@@ -525,8 +533,8 @@ def test_dataset_with_sampler_list_item_id(self, mode): | |
| input_path=f.name, | ||
| num_sample=8, | ||
| num_eval_sample=4, | ||
| attr_fields=["item_id"], | ||
| item_id_field="item_id", | ||
| attr_fields=["item_id"], # bare; gets rewritten | ||
| item_id_field="cand_seq__item_id", # qualified | ||
| ), | ||
| force_base_data_group=True, | ||
| ), | ||
|
|
@@ -540,7 +548,7 @@ def test_dataset_with_sampler_list_item_id(self, mode): | |
| # Multi-positive mode (item_id is sequence-positive in train): | ||
| # launch_sampler_cluster strips the outer list so the sampler emits | ||
| # scalar item_id negs, not 1-elem lists via the multival_sep round-trip. | ||
| item_id_idx = dataset._sampler._attr_names.index("item_id") | ||
| item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id") | ||
| self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64()) | ||
|
|
||
| dataloader = DataLoader( | ||
|
|
@@ -599,7 +607,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): | |
|
|
||
| input_fields = [ | ||
| pa.field(name="user_id", type=pa.int64()), | ||
| pa.field(name="item_id", type=pa.list_(pa.int64())), | ||
| pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())), | ||
| pa.field(name="label", type=pa.int32()), | ||
| ] | ||
| feature_cfgs = [ | ||
|
|
@@ -612,13 +620,20 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): | |
| ) | ||
| ), | ||
| feature_pb2.FeatureConfig( | ||
| sequence_id_feature=feature_pb2.IdFeature( | ||
| feature_name="item_id", | ||
| expression="item:item_id", | ||
| sequence_feature=feature_pb2.SequenceFeature( | ||
| sequence_name="cand_seq", | ||
| sequence_length=10, | ||
| sequence_delim=";", | ||
| num_buckets=200, | ||
| embedding_dim=8, | ||
| features=[ | ||
| feature_pb2.SeqFeatureConfig( | ||
| id_feature=feature_pb2.IdFeature( | ||
| feature_name="item_id", | ||
| expression="item:item_id", | ||
| num_buckets=200, | ||
| embedding_dim=8, | ||
| ) | ||
| ), | ||
| ], | ||
| ) | ||
| ), | ||
| ] | ||
|
|
@@ -641,8 +656,8 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): | |
| num_sample=8, | ||
| num_eval_sample=4, | ||
| num_hard_sample=2, | ||
| attr_fields=["item_id"], | ||
| item_id_field="item_id", | ||
| attr_fields=["item_id"], # bare; gets rewritten | ||
| item_id_field="cand_seq__item_id", # qualified | ||
| user_id_field="user_id", | ||
| ), | ||
| force_base_data_group=True, | ||
|
|
@@ -655,7 +670,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): | |
| dataset.launch_sampler_cluster(2) | ||
|
|
||
| # Multi-positive mode: outer list stripped so sampler emits scalar negs. | ||
| item_id_idx = dataset._sampler._attr_names.index("item_id") | ||
| item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id") | ||
| self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64()) | ||
|
|
||
| dataloader = DataLoader( | ||
|
|
@@ -677,27 +692,19 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): | |
| hard_neg_indices = batch.additional_infos[HARD_NEG_INDICES] | ||
| self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7}) | ||
|
|
||
| def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): | ||
| """End-to-end strip decisions across the per-attr filter. | ||
|
|
||
| Grouped LookupFeature sub with two item-side inputs; only ``cat_key`` | ||
| is in ``sequence_fields`` so only it enters ``_seq_field_delims``. | ||
| ``cat_key`` is typed ``list<list<int64>>`` (multi-value attr layered | ||
| under multi-positive grouping); after strip it becomes | ||
| ``list<int64>`` (ONE level stripped, not bare-stripped to ``int64``). | ||
| ``cat_map`` is item-side but excluded from ``_seq_field_delims``, | ||
| so it stays ``list<string>`` unchanged. | ||
| """ | ||
| def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): | ||
| """Prefix-rewrite + outer-list strip across multiple sub-feature types.""" | ||
| f = tempfile.NamedTemporaryFile("w") | ||
| self._temp_files.append(f) | ||
| f.write("id:int64\tweight:float\tattrs:string\n") | ||
| for i in range(100): | ||
| f.write(f"{i}\t1.0\t{i}:{i + 1000}\n") | ||
| f.write(f"{i}\t1.0\t{i}:{i + 1000}:0.5\n") | ||
| f.flush() | ||
|
|
||
| input_fields = [ | ||
| pa.field(name="cat_map", type=pa.list_(pa.string())), | ||
| pa.field(name="click_seq__cat_key", type=pa.list_(pa.list_(pa.int64()))), | ||
| pa.field(name="click_seq__item_id", type=pa.list_(pa.int64())), | ||
| pa.field(name="click_seq__cat_id", type=pa.list_(pa.int64())), | ||
| pa.field(name="click_seq__duration", type=pa.list_(pa.float32())), | ||
| pa.field(name="label", type=pa.int32()), | ||
| ] | ||
| feature_cfgs = [ | ||
|
|
@@ -708,23 +715,35 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): | |
| sequence_delim=";", | ||
| features=[ | ||
| feature_pb2.SeqFeatureConfig( | ||
| lookup_feature=feature_pb2.LookupFeature( | ||
| feature_name="lookup_c", | ||
| map="item:cat_map", | ||
| key="item:cat_key", | ||
| sequence_fields=["cat_key"], | ||
| id_feature=feature_pb2.IdFeature( | ||
| feature_name="item_id", | ||
| expression="item:item_id", | ||
| num_buckets=200, | ||
| embedding_dim=8, | ||
| ) | ||
| ), | ||
| feature_pb2.SeqFeatureConfig( | ||
| id_feature=feature_pb2.IdFeature( | ||
| feature_name="cat_id", | ||
| expression="item:cat_id", | ||
| num_buckets=10, | ||
| embedding_dim=8, | ||
| ) | ||
| ), | ||
| feature_pb2.SeqFeatureConfig( | ||
| raw_feature=feature_pb2.RawFeature( | ||
| feature_name="duration", | ||
| expression="item:duration", | ||
| ) | ||
| ), | ||
| ], | ||
| ) | ||
| ), | ||
| ] | ||
| features = create_features( | ||
| feature_cfgs, | ||
| fg_mode=data_pb2.FgMode.FG_NORMAL, | ||
| neg_fields=["cat_map", "cat_key"], | ||
| neg_fields=["item_id", "cat_id", "duration"], | ||
| force_base_data_group=True, | ||
| ) | ||
| dataset = _TestDataset( | ||
|
|
@@ -736,8 +755,9 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): | |
| negative_sampler=sampler_pb2.NegativeSampler( | ||
| input_path=f.name, | ||
| num_sample=4, | ||
| attr_fields=["cat_map", "click_seq__cat_key"], | ||
| item_id_field="click_seq__cat_key", | ||
| # bare names; prefix-rewritten at launch | ||
| attr_fields=["item_id", "cat_id", "duration"], | ||
| item_id_field="click_seq__item_id", | ||
| ), | ||
| force_base_data_group=True, | ||
| ), | ||
|
|
@@ -746,22 +766,23 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): | |
| input_fields=input_fields, | ||
| mode=Mode.TRAIN, | ||
| ) | ||
| # Narrowed _seq_field_delims excludes the non-sequence item-side input. | ||
| self.assertIn("click_seq__cat_key", dataset._seq_field_delims) | ||
| self.assertNotIn("cat_map", dataset._seq_field_delims) | ||
| self.assertEqual(dataset._sampler_seq_delim, ";") | ||
| self.assertEqual(dataset._sampler_seq_prefix, "click_seq__") | ||
|
|
||
| dataset.launch_sampler_cluster(2) | ||
| # outer guard True (item_id_field is sequence-positive): | ||
| # - cat_key: list<list<int64>> -> list<int64> (one strip). | ||
| # - cat_map: list<string>, not in _seq_field_delims -> unstripped. | ||
| cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") | ||
| cat_map_idx = dataset._sampler._attr_names.index("cat_map") | ||
| self.assertEqual( | ||
| dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64()) | ||
| ) | ||
| # Deep-copy guard: data_config not mutated by the rewrite. | ||
| self.assertEqual( | ||
| dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string()) | ||
| list(dataset._data_config.negative_sampler.attr_fields), | ||
| ["item_id", "cat_id", "duration"], | ||
| ) | ||
| # Each bare entry prefix-rewritten and outer-list stripped. | ||
| for qualified, expected_type in [ | ||
| ("click_seq__item_id", pa.int64()), | ||
| ("click_seq__cat_id", pa.int64()), | ||
| ("click_seq__duration", pa.float32()), | ||
| ]: | ||
| idx = dataset._sampler._attr_names.index(qualified) | ||
| self.assertEqual(dataset._sampler._attr_types[idx], expected_type) | ||
|
Comment on lines 695 to +785
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Two coverage gaps worth a follow-up commit:
|
||
|
|
||
| def test_dataset_with_sample_mask(self): | ||
| input_fields = [ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.