Skip to content

Commit f57e06d

Browse files
[polish] dataset_test: expand grouped-sequence test to cover multiple sub-feature types
`test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite` used a single LookupFeature sub with `sequence_fields=["cat_key"]` to construct one `list<list<int64>>` candidate input. The synthetic shape (multi-value layered under multi-positive) doesn't reflect HSTUMatch's actual usage where the candidate sequence carries multiple sub-features of varied types (id, raw, ...). Replace with three sub-features under a single `sequence_feature`: * `id_feature item_id` -> `click_seq__item_id: list<int64>` * `id_feature cat_id` -> `click_seq__cat_id: list<int64>` * `raw_feature watch_time` -> `click_seq__watch_time: list<float32>` `attr_fields=["item_id", "cat_id", "watch_time"]` exercises the prefix-prepend rewrite across all three; the strip filter scopes correctly per type (int64 / int64 / float32). Drops the `sequence_fields` complexity (LookupFeature-only knob) and the `cat_map` parquet column (the `assertNotIn("cat_map", ...)` assertion was checking sampler-side state unrelated to the strip filter's behavior). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2515693 commit f57e06d

1 file changed

Lines changed: 36 additions & 26 deletions

File tree

tzrec/datasets/dataset_test.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -693,23 +693,18 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
693693
self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7})
694694

695695
def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self):
696-
"""Prefix-rewrite + outer-list strip for HSTUMatch-shaped configs.
697-
698-
Bare ``attr_fields=["cat_key"]`` is rewritten to
699-
``["click_seq__cat_key"]``; strip drops the multi-positive outer
700-
list (``list<list<int64>>`` -> ``list<int64>``). ``cat_map``, a
701-
parquet column not in ``attr_fields``, is left untouched.
702-
"""
696+
"""Prefix-rewrite + outer-list strip across multiple sub-feature types."""
703697
f = tempfile.NamedTemporaryFile("w")
704698
self._temp_files.append(f)
705699
f.write("id:int64\tweight:float\tattrs:string\n")
706700
for i in range(100):
707-
f.write(f"{i}\t1.0\t{i}\n")
701+
f.write(f"{i}\t1.0\t{i}:{i + 1000}:0.5\n")
708702
f.flush()
709703

710704
input_fields = [
711-
pa.field(name="cat_map", type=pa.list_(pa.string())),
712-
pa.field(name="click_seq__cat_key", type=pa.list_(pa.list_(pa.int64()))),
705+
pa.field(name="click_seq__item_id", type=pa.list_(pa.int64())),
706+
pa.field(name="click_seq__cat_id", type=pa.list_(pa.int64())),
707+
pa.field(name="click_seq__watch_time", type=pa.list_(pa.float32())),
713708
pa.field(name="label", type=pa.int32()),
714709
]
715710
feature_cfgs = [
@@ -720,23 +715,35 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self):
720715
sequence_delim=";",
721716
features=[
722717
feature_pb2.SeqFeatureConfig(
723-
lookup_feature=feature_pb2.LookupFeature(
724-
feature_name="lookup_c",
725-
map="item:cat_map",
726-
key="item:cat_key",
727-
sequence_fields=["cat_key"],
718+
id_feature=feature_pb2.IdFeature(
719+
feature_name="item_id",
720+
expression="item:item_id",
721+
num_buckets=200,
722+
embedding_dim=8,
723+
)
724+
),
725+
feature_pb2.SeqFeatureConfig(
726+
id_feature=feature_pb2.IdFeature(
727+
feature_name="cat_id",
728+
expression="item:cat_id",
728729
num_buckets=10,
729730
embedding_dim=8,
730731
)
731732
),
733+
feature_pb2.SeqFeatureConfig(
734+
raw_feature=feature_pb2.RawFeature(
735+
feature_name="watch_time",
736+
expression="item:watch_time",
737+
)
738+
),
732739
],
733740
)
734741
),
735742
]
736743
features = create_features(
737744
feature_cfgs,
738745
fg_mode=data_pb2.FgMode.FG_NORMAL,
739-
neg_fields=["cat_key"],
746+
neg_fields=["item_id", "cat_id", "watch_time"],
740747
force_base_data_group=True,
741748
)
742749
dataset = _TestDataset(
@@ -748,8 +755,9 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self):
748755
negative_sampler=sampler_pb2.NegativeSampler(
749756
input_path=f.name,
750757
num_sample=4,
751-
attr_fields=["cat_key"], # bare; prefix-rewritten at launch
752-
item_id_field="click_seq__cat_key",
758+
# bare names; prefix-rewritten at launch
759+
attr_fields=["item_id", "cat_id", "watch_time"],
760+
item_id_field="click_seq__item_id",
753761
),
754762
force_base_data_group=True,
755763
),
@@ -764,15 +772,17 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self):
764772
dataset.launch_sampler_cluster(2)
765773
# Deep-copy guard: data_config not mutated by the rewrite.
766774
self.assertEqual(
767-
list(dataset._data_config.negative_sampler.attr_fields), ["cat_key"]
768-
)
769-
self.assertIn("click_seq__cat_key", dataset._sampler._attr_names)
770-
self.assertNotIn("cat_key", dataset._sampler._attr_names)
771-
cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key")
772-
self.assertEqual(
773-
dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64())
775+
list(dataset._data_config.negative_sampler.attr_fields),
776+
["item_id", "cat_id", "watch_time"],
774777
)
775-
self.assertNotIn("cat_map", dataset._sampler._attr_names)
778+
# Each bare entry prefix-rewritten and outer-list stripped.
779+
for qualified, expected_type in [
780+
("click_seq__item_id", pa.int64()),
781+
("click_seq__cat_id", pa.int64()),
782+
("click_seq__watch_time", pa.float32()),
783+
]:
784+
idx = dataset._sampler._attr_names.index(qualified)
785+
self.assertEqual(dataset._sampler._attr_types[idx], expected_type)
776786

777787
def test_dataset_with_sample_mask(self):
778788
input_fields = [

0 commit comments

Comments
 (0)