Skip to content

Commit a21d26f

Browse files
[refactor] sampler: derive sequence state from feature configs + accept bare sub-feature attr_fields (#520)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent dc236b3 commit a21d26f

7 files changed

Lines changed: 143 additions & 122 deletions

File tree

docs/source/models/hstu_match.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ data_config {
1616
negative_sampler {
1717
input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1"
1818
num_sample: 128
19-
attr_fields: "cand_seq__video_id"
19+
attr_fields: "video_id"
2020
item_id_field: "cand_seq__video_id"
2121
attr_delimiter: "\t"
2222
}
@@ -211,6 +211,8 @@ model_config {
211211
212212
- data_config: 数据配置,其中需要配置负采样 Sampler,负采样 Sampler 的配置详见 [DSSM](dssm.md) 文档中的**负采样配置**章节
213213

214+
- HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 写为带 `sequence_name` 前缀的名(例如 `cand_seq__video_id`),`attr_fields` 写为不带前缀的子特征名(例如 `video_id`)。
215+
214216
- feature_groups: 特征组
215217

216218
- uih: 用户历史行为序列,可增加 side info;类型为 JAGGED_SEQUENCE,**必填**

tzrec/datasets/dataset.py

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import copy
1213
import os
1314
import random
1415
from collections import OrderedDict
@@ -165,27 +166,23 @@ def __init__(
165166
):
166167
self._selected_input_names = None
167168

168-
# Map input_name -> sequence_delim for true sequence inputs only
169-
# (via feature.sequence_input_names). Excludes non-sequence
170-
# sub-inputs of grouped sequence_feature.
171-
self._seq_field_delims: Dict[str, str] = {}
172-
for feature in features:
173-
if not feature.sequence_delim:
174-
continue
175-
seq_inputs = set(feature.sequence_input_names)
176-
for input_name in feature.inputs:
177-
if input_name not in seq_inputs:
169+
# Sequence state when item_id_field is a grouped sequence sub-feature.
170+
self._sampler_seq_delim: str = ""
171+
self._sampler_seq_prefix: str = ""
172+
if self._sampler_item_id_field is not None:
173+
for feature in features:
174+
if self._sampler_item_id_field not in feature.sequence_input_names:
178175
continue
179-
existing = self._seq_field_delims.get(input_name)
180-
if existing is not None and existing != feature.sequence_delim:
181-
logger.warning(
182-
"Conflicting sequence_delim for input '%s': %r vs %r; "
183-
"latter wins.",
184-
input_name,
185-
existing,
186-
feature.sequence_delim,
176+
if not feature.is_grouped_sequence:
177+
raise ValueError(
178+
f"item_id_field '{self._sampler_item_id_field}' is a "
179+
"sequence input but its matching feature is not a "
180+
"grouped sequence; only grouped sequence sub-features "
181+
"are supported as item_id_field."
187182
)
188-
self._seq_field_delims[input_name] = feature.sequence_delim
183+
self._sampler_seq_delim = feature.sequence_delim
184+
self._sampler_seq_prefix = feature.grouped_sequence_prefix
185+
break
189186

190187
self._fg_mode = data_config.fg_mode
191188
self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep
@@ -211,22 +208,23 @@ def launch_sampler_cluster(
211208
sampler_type = self._data_config.WhichOneof("sampler")
212209
sampler_config = getattr(self._data_config, sampler_type)
213210

214-
# Multi-positive sampling: when the sampler's item_id_field is
215-
# itself a sequence-positive train column, the per-row outer list
216-
# on every item-side attr is the positive-grouping container, not
217-
# a multi-value field. Strip the outer list so the sampler sees
218-
# the pool's native scalar storage and _to_arrow_array emits
219-
# scalar negs directly (avoiding the multival_sep split that
220-
# would wrap each scalar in a 1-elem list).
211+
# Rewrite bare attr_fields to flattened (`video_id` ->
212+
# `cand_seq__video_id`); deep-copy so data_config isn't mutated.
213+
if self._sampler_seq_prefix:
214+
sampler_config = copy.deepcopy(sampler_config)
215+
sampler_config.attr_fields[:] = [
216+
self._sampler_seq_prefix + a for a in sampler_config.attr_fields
217+
]
218+
219+
# Strip the per-row positive-grouping outer list on attr_fields
220+
# columns so the sampler emits scalar negs.
221221
sampler_fields = self.input_fields
222-
if (
223-
self._sampler_item_id_field is not None
224-
and self._sampler_item_id_field in self._seq_field_delims
225-
):
222+
if self._sampler_seq_delim:
223+
sampler_attrs = set(sampler_config.attr_fields)
226224
sampler_fields = [
227225
pa.field(f.name, f.type.value_type)
228226
if (
229-
f.name in self._seq_field_delims
227+
f.name in sampler_attrs
230228
and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type))
231229
)
232230
else f
@@ -391,7 +389,7 @@ def _apply_negative_sampler(
391389
input_data,
392390
self._sampler_item_id_field,
393391
self._sampler_user_id_field,
394-
self._seq_field_delims,
392+
self._sampler_seq_delim,
395393
)
396394
sampled = self._sampler.get(sampler_input)
397395

@@ -459,11 +457,10 @@ def _merge_sampled_features(
459457
) -> Optional[np.ndarray]:
460458
"""Merge sampler outputs into input_data in place; return per-row pos lengths.
461459
462-
Per sampled key: new keys are assigned as-is, keys with a
463-
`seq_delim` use the block-suffix combine, others fall back to
464-
`pa.concat_arrays`. `pos_lengths` is sourced from the configured
465-
item_id_field combine; returns None if no sequence field was
466-
merged.
460+
Per sampled key: new keys assigned as-is; otherwise routed via
461+
`pa.concat_arrays` (scalar item_id_field) or the block-suffix
462+
combine (grouped-sequence item_id_field). `pos_lengths` returns
463+
None when not in sequence mode.
467464
"""
468465
# Prefer item_id_field; fall back to first-seen seq-field if absent.
469466
prefer_key = self._sampler_item_id_field
@@ -473,12 +470,11 @@ def _merge_sampled_features(
473470
if k not in input_data:
474471
input_data[k] = v
475472
continue
476-
seq_delim = self._seq_field_delims.get(k)
477-
if seq_delim is None:
473+
if not self._sampler_seq_delim:
478474
input_data[k] = pa.concat_arrays([input_data[k], v])
479475
continue
480476
combined, pl = combine_negs_to_candidate_sequence(
481-
input_data[k], v, seq_delim
477+
input_data[k], v, self._sampler_seq_delim
482478
)
483479
input_data[k] = combined
484480
if k == prefer_key:

tzrec/datasets/dataset_test.py

Lines changed: 77 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -479,10 +479,11 @@ def test_dataset_with_sampler(self, force_base_data_group, mode, input_tile):
479479
def test_dataset_with_sampler_list_item_id(self, mode):
480480
"""E2E: list-typed item_id positives through a real NegativeSampler.
481481
482-
Schema declares `item_id` as `pa.list_(pa.int64())` (Parquet-style
483-
multi-value column). Exercises `build_sampler_input`'s list-pass-
484-
through + flatten, the dynamic-`expand_factor` path in the
485-
sampler, and `combine_negs_to_candidate_sequence`'s list-typed-negs
482+
Schema declares `cand_seq__item_id` (grouped sequence sub-feature
483+
flattened name) as `pa.list_(pa.int64())` (multi-positive column).
484+
Exercises `build_sampler_input`'s list-pass-through + flatten,
485+
the dynamic-`expand_factor` path in the sampler, and
486+
`combine_negs_to_candidate_sequence`'s list-typed-negs
486487
normalization. Parameterized over Mode.TRAIN / Mode.EVAL so the
487488
eval-mode `num_eval_sample` path is also covered.
488489
"""
@@ -494,18 +495,25 @@ def test_dataset_with_sampler_list_item_id(self, mode):
494495
f.flush()
495496

496497
input_fields = [
497-
pa.field(name="item_id", type=pa.list_(pa.int64())),
498+
pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())),
498499
pa.field(name="label", type=pa.int32()),
499500
]
500501
feature_cfgs = [
501502
feature_pb2.FeatureConfig(
502-
sequence_id_feature=feature_pb2.IdFeature(
503-
feature_name="item_id",
504-
expression="item:item_id",
503+
sequence_feature=feature_pb2.SequenceFeature(
504+
sequence_name="cand_seq",
505505
sequence_length=10,
506506
sequence_delim=";",
507-
num_buckets=200,
508-
embedding_dim=8,
507+
features=[
508+
feature_pb2.SeqFeatureConfig(
509+
id_feature=feature_pb2.IdFeature(
510+
feature_name="item_id",
511+
expression="item:item_id",
512+
num_buckets=200,
513+
embedding_dim=8,
514+
)
515+
),
516+
],
509517
)
510518
),
511519
]
@@ -525,8 +533,8 @@ def test_dataset_with_sampler_list_item_id(self, mode):
525533
input_path=f.name,
526534
num_sample=8,
527535
num_eval_sample=4,
528-
attr_fields=["item_id"],
529-
item_id_field="item_id",
536+
attr_fields=["item_id"], # bare; gets rewritten
537+
item_id_field="cand_seq__item_id", # qualified
530538
),
531539
force_base_data_group=True,
532540
),
@@ -540,7 +548,7 @@ def test_dataset_with_sampler_list_item_id(self, mode):
540548
# Multi-positive mode (item_id is sequence-positive in train):
541549
# launch_sampler_cluster strips the outer list so the sampler emits
542550
# scalar item_id negs, not 1-elem lists via the multival_sep round-trip.
543-
item_id_idx = dataset._sampler._attr_names.index("item_id")
551+
item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id")
544552
self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64())
545553

546554
dataloader = DataLoader(
@@ -599,7 +607,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
599607

600608
input_fields = [
601609
pa.field(name="user_id", type=pa.int64()),
602-
pa.field(name="item_id", type=pa.list_(pa.int64())),
610+
pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())),
603611
pa.field(name="label", type=pa.int32()),
604612
]
605613
feature_cfgs = [
@@ -612,13 +620,20 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
612620
)
613621
),
614622
feature_pb2.FeatureConfig(
615-
sequence_id_feature=feature_pb2.IdFeature(
616-
feature_name="item_id",
617-
expression="item:item_id",
623+
sequence_feature=feature_pb2.SequenceFeature(
624+
sequence_name="cand_seq",
618625
sequence_length=10,
619626
sequence_delim=";",
620-
num_buckets=200,
621-
embedding_dim=8,
627+
features=[
628+
feature_pb2.SeqFeatureConfig(
629+
id_feature=feature_pb2.IdFeature(
630+
feature_name="item_id",
631+
expression="item:item_id",
632+
num_buckets=200,
633+
embedding_dim=8,
634+
)
635+
),
636+
],
622637
)
623638
),
624639
]
@@ -641,8 +656,8 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
641656
num_sample=8,
642657
num_eval_sample=4,
643658
num_hard_sample=2,
644-
attr_fields=["item_id"],
645-
item_id_field="item_id",
659+
attr_fields=["item_id"], # bare; gets rewritten
660+
item_id_field="cand_seq__item_id", # qualified
646661
user_id_field="user_id",
647662
),
648663
force_base_data_group=True,
@@ -655,7 +670,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
655670
dataset.launch_sampler_cluster(2)
656671

657672
# Multi-positive mode: outer list stripped so sampler emits scalar negs.
658-
item_id_idx = dataset._sampler._attr_names.index("item_id")
673+
item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id")
659674
self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64())
660675

661676
dataloader = DataLoader(
@@ -677,27 +692,19 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
677692
hard_neg_indices = batch.additional_infos[HARD_NEG_INDICES]
678693
self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7})
679694

680-
def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
681-
"""End-to-end strip decisions across the per-attr filter.
682-
683-
Grouped LookupFeature sub with two item-side inputs; only ``cat_key``
684-
is in ``sequence_fields`` so only it enters ``_seq_field_delims``.
685-
``cat_key`` is typed ``list<list<int64>>`` (multi-value attr layered
686-
under multi-positive grouping); after strip it becomes
687-
``list<int64>`` (ONE level stripped, not bare-stripped to ``int64``).
688-
``cat_map`` is item-side but excluded from ``_seq_field_delims``,
689-
so it stays ``list<string>`` unchanged.
690-
"""
695+
def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self):
696+
"""Prefix-rewrite + outer-list strip across multiple sub-feature types."""
691697
f = tempfile.NamedTemporaryFile("w")
692698
self._temp_files.append(f)
693699
f.write("id:int64\tweight:float\tattrs:string\n")
694700
for i in range(100):
695-
f.write(f"{i}\t1.0\t{i}:{i + 1000}\n")
701+
f.write(f"{i}\t1.0\t{i}:{i + 1000}:0.5\n")
696702
f.flush()
697703

698704
input_fields = [
699-
pa.field(name="cat_map", type=pa.list_(pa.string())),
700-
pa.field(name="click_seq__cat_key", type=pa.list_(pa.list_(pa.int64()))),
705+
pa.field(name="click_seq__item_id", type=pa.list_(pa.int64())),
706+
pa.field(name="click_seq__cat_id", type=pa.list_(pa.int64())),
707+
pa.field(name="click_seq__duration", type=pa.list_(pa.float32())),
701708
pa.field(name="label", type=pa.int32()),
702709
]
703710
feature_cfgs = [
@@ -708,23 +715,35 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
708715
sequence_delim=";",
709716
features=[
710717
feature_pb2.SeqFeatureConfig(
711-
lookup_feature=feature_pb2.LookupFeature(
712-
feature_name="lookup_c",
713-
map="item:cat_map",
714-
key="item:cat_key",
715-
sequence_fields=["cat_key"],
718+
id_feature=feature_pb2.IdFeature(
719+
feature_name="item_id",
720+
expression="item:item_id",
721+
num_buckets=200,
722+
embedding_dim=8,
723+
)
724+
),
725+
feature_pb2.SeqFeatureConfig(
726+
id_feature=feature_pb2.IdFeature(
727+
feature_name="cat_id",
728+
expression="item:cat_id",
716729
num_buckets=10,
717730
embedding_dim=8,
718731
)
719732
),
733+
feature_pb2.SeqFeatureConfig(
734+
raw_feature=feature_pb2.RawFeature(
735+
feature_name="duration",
736+
expression="item:duration",
737+
)
738+
),
720739
],
721740
)
722741
),
723742
]
724743
features = create_features(
725744
feature_cfgs,
726745
fg_mode=data_pb2.FgMode.FG_NORMAL,
727-
neg_fields=["cat_map", "cat_key"],
746+
neg_fields=["item_id", "cat_id", "duration"],
728747
force_base_data_group=True,
729748
)
730749
dataset = _TestDataset(
@@ -736,8 +755,9 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
736755
negative_sampler=sampler_pb2.NegativeSampler(
737756
input_path=f.name,
738757
num_sample=4,
739-
attr_fields=["cat_map", "click_seq__cat_key"],
740-
item_id_field="click_seq__cat_key",
758+
# bare names; prefix-rewritten at launch
759+
attr_fields=["item_id", "cat_id", "duration"],
760+
item_id_field="click_seq__item_id",
741761
),
742762
force_base_data_group=True,
743763
),
@@ -746,22 +766,23 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
746766
input_fields=input_fields,
747767
mode=Mode.TRAIN,
748768
)
749-
# Narrowed _seq_field_delims excludes the non-sequence item-side input.
750-
self.assertIn("click_seq__cat_key", dataset._seq_field_delims)
751-
self.assertNotIn("cat_map", dataset._seq_field_delims)
769+
self.assertEqual(dataset._sampler_seq_delim, ";")
770+
self.assertEqual(dataset._sampler_seq_prefix, "click_seq__")
752771

753772
dataset.launch_sampler_cluster(2)
754-
# outer guard True (item_id_field is sequence-positive):
755-
# - cat_key: list<list<int64>> -> list<int64> (one strip).
756-
# - cat_map: list<string>, not in _seq_field_delims -> unstripped.
757-
cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key")
758-
cat_map_idx = dataset._sampler._attr_names.index("cat_map")
759-
self.assertEqual(
760-
dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64())
761-
)
773+
# Deep-copy guard: data_config not mutated by the rewrite.
762774
self.assertEqual(
763-
dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string())
775+
list(dataset._data_config.negative_sampler.attr_fields),
776+
["item_id", "cat_id", "duration"],
764777
)
778+
# Each bare entry prefix-rewritten and outer-list stripped.
779+
for qualified, expected_type in [
780+
("click_seq__item_id", pa.int64()),
781+
("click_seq__cat_id", pa.int64()),
782+
("click_seq__duration", pa.float32()),
783+
]:
784+
idx = dataset._sampler._attr_names.index(qualified)
785+
self.assertEqual(dataset._sampler._attr_types[idx], expected_type)
765786

766787
def test_dataset_with_sample_mask(self):
767788
input_fields = [

0 commit comments

Comments
 (0)