Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/models/hstu_match.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ data_config {
negative_sampler {
input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1"
num_sample: 128
attr_fields: "cand_seq__video_id"
attr_fields: "video_id"
item_id_field: "cand_seq__video_id"
attr_delimiter: "\t"
}
Expand Down Expand Up @@ -211,6 +211,8 @@ model_config {

- data_config: 数据配置,其中需要配置负采样 Sampler,负采样 Sampler 的配置详见 [DSSM](dssm.md) 文档中的**负采样配置**章节

- HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 需写为带 `sequence_name` 前缀的完整名称(例如 `cand_seq__video_id`),`attr_fields` 则写为不带前缀的子特征名(例如 `video_id`),框架会在启动采样器时自动为 `attr_fields` 中的每一项补上该前缀。

- feature_groups: 特征组

- uih: 用户历史行为序列,可增加 side info;类型为 JAGGED_SEQUENCE,**必填**
Expand Down
76 changes: 36 additions & 40 deletions tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import random
from collections import OrderedDict
Expand Down Expand Up @@ -165,27 +166,23 @@ def __init__(
):
self._selected_input_names = None

# Map input_name -> sequence_delim for true sequence inputs only
# (via feature.sequence_input_names). Excludes non-sequence
# sub-inputs of grouped sequence_feature.
self._seq_field_delims: Dict[str, str] = {}
for feature in features:
if not feature.sequence_delim:
continue
seq_inputs = set(feature.sequence_input_names)
for input_name in feature.inputs:
if input_name not in seq_inputs:
# Sequence state when item_id_field is a grouped sequence sub-feature.
self._sampler_seq_delim: str = ""
self._sampler_seq_prefix: str = ""
if self._sampler_item_id_field is not None:
for feature in features:
if self._sampler_item_id_field not in feature.sequence_input_names:
continue
existing = self._seq_field_delims.get(input_name)
if existing is not None and existing != feature.sequence_delim:
logger.warning(
"Conflicting sequence_delim for input '%s': %r vs %r; "
"latter wins.",
input_name,
existing,
feature.sequence_delim,
if not feature.is_grouped_sequence:
raise ValueError(
f"item_id_field '{self._sampler_item_id_field}' is a "
"sequence input but its matching feature is not a "
"grouped sequence; only grouped sequence sub-features "
"are supported as item_id_field."
)
self._seq_field_delims[input_name] = feature.sequence_delim
self._sampler_seq_delim = feature.sequence_delim
self._sampler_seq_prefix = feature.grouped_sequence_prefix
break

self._fg_mode = data_config.fg_mode
self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep
Expand All @@ -211,22 +208,23 @@ def launch_sampler_cluster(
sampler_type = self._data_config.WhichOneof("sampler")
sampler_config = getattr(self._data_config, sampler_type)

# Multi-positive sampling: when the sampler's item_id_field is
# itself a sequence-positive train column, the per-row outer list
# on every item-side attr is the positive-grouping container, not
# a multi-value field. Strip the outer list so the sampler sees
# the pool's native scalar storage and _to_arrow_array emits
# scalar negs directly (avoiding the multival_sep split that
# would wrap each scalar in a 1-elem list).
# Rewrite bare attr_fields to flattened (`video_id` ->
# `cand_seq__video_id`); deep-copy so data_config isn't mutated.
if self._sampler_seq_prefix:
sampler_config = copy.deepcopy(sampler_config)
sampler_config.attr_fields[:] = [
self._sampler_seq_prefix + a for a in sampler_config.attr_fields
]
Comment thread
tiankongdeguiji marked this conversation as resolved.

# Strip the per-row positive-grouping outer list on attr_fields
# columns so the sampler emits scalar negs.
sampler_fields = self.input_fields
if (
self._sampler_item_id_field is not None
and self._sampler_item_id_field in self._seq_field_delims
):
if self._sampler_seq_delim:
sampler_attrs = set(sampler_config.attr_fields)
sampler_fields = [
pa.field(f.name, f.type.value_type)
if (
f.name in self._seq_field_delims
f.name in sampler_attrs
and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type))
)
else f
Expand Down Expand Up @@ -391,7 +389,7 @@ def _apply_negative_sampler(
input_data,
self._sampler_item_id_field,
self._sampler_user_id_field,
self._seq_field_delims,
self._sampler_seq_delim,
)
sampled = self._sampler.get(sampler_input)

Expand Down Expand Up @@ -459,11 +457,10 @@ def _merge_sampled_features(
) -> Optional[np.ndarray]:
"""Merge sampler outputs into input_data in place; return per-row pos lengths.

Per sampled key: new keys are assigned as-is, keys with a
`seq_delim` use the block-suffix combine, others fall back to
`pa.concat_arrays`. `pos_lengths` is sourced from the configured
item_id_field combine; returns None if no sequence field was
merged.
Per sampled key: new keys assigned as-is; otherwise routed via
`pa.concat_arrays` (scalar item_id_field) or the block-suffix
combine (grouped-sequence item_id_field). `pos_lengths` returns
None when not in sequence mode.
"""
# Prefer item_id_field; fall back to first-seen seq-field if absent.
prefer_key = self._sampler_item_id_field
Expand All @@ -473,12 +470,11 @@ def _merge_sampled_features(
if k not in input_data:
input_data[k] = v
continue
seq_delim = self._seq_field_delims.get(k)
if seq_delim is None:
if not self._sampler_seq_delim:
input_data[k] = pa.concat_arrays([input_data[k], v])
continue
combined, pl = combine_negs_to_candidate_sequence(
input_data[k], v, seq_delim
input_data[k], v, self._sampler_seq_delim
)
input_data[k] = combined
if k == prefer_key:
Expand Down
133 changes: 77 additions & 56 deletions tzrec/datasets/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,11 @@ def test_dataset_with_sampler(self, force_base_data_group, mode, input_tile):
def test_dataset_with_sampler_list_item_id(self, mode):
"""E2E: list-typed item_id positives through a real NegativeSampler.

Schema declares `item_id` as `pa.list_(pa.int64())` (Parquet-style
multi-value column). Exercises `build_sampler_input`'s list-pass-
through + flatten, the dynamic-`expand_factor` path in the
sampler, and `combine_negs_to_candidate_sequence`'s list-typed-negs
Schema declares `cand_seq__item_id` (grouped sequence sub-feature
flattened name) as `pa.list_(pa.int64())` (multi-positive column).
Exercises `build_sampler_input`'s list-pass-through + flatten,
the dynamic-`expand_factor` path in the sampler, and
`combine_negs_to_candidate_sequence`'s list-typed-negs
normalization. Parameterized over Mode.TRAIN / Mode.EVAL so the
eval-mode `num_eval_sample` path is also covered.
"""
Expand All @@ -494,18 +495,25 @@ def test_dataset_with_sampler_list_item_id(self, mode):
f.flush()

input_fields = [
pa.field(name="item_id", type=pa.list_(pa.int64())),
pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())),
pa.field(name="label", type=pa.int32()),
]
feature_cfgs = [
feature_pb2.FeatureConfig(
sequence_id_feature=feature_pb2.IdFeature(
feature_name="item_id",
expression="item:item_id",
sequence_feature=feature_pb2.SequenceFeature(
sequence_name="cand_seq",
sequence_length=10,
sequence_delim=";",
num_buckets=200,
embedding_dim=8,
features=[
feature_pb2.SeqFeatureConfig(
id_feature=feature_pb2.IdFeature(
feature_name="item_id",
expression="item:item_id",
num_buckets=200,
embedding_dim=8,
)
),
],
)
),
]
Expand All @@ -525,8 +533,8 @@ def test_dataset_with_sampler_list_item_id(self, mode):
input_path=f.name,
num_sample=8,
num_eval_sample=4,
attr_fields=["item_id"],
item_id_field="item_id",
attr_fields=["item_id"], # bare; gets rewritten
item_id_field="cand_seq__item_id", # qualified
),
force_base_data_group=True,
),
Expand All @@ -540,7 +548,7 @@ def test_dataset_with_sampler_list_item_id(self, mode):
# Multi-positive mode (item_id is sequence-positive in train):
# launch_sampler_cluster strips the outer list so the sampler emits
# scalar item_id negs, not 1-elem lists via the multival_sep round-trip.
item_id_idx = dataset._sampler._attr_names.index("item_id")
item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id")
self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64())

dataloader = DataLoader(
Expand Down Expand Up @@ -599,7 +607,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):

input_fields = [
pa.field(name="user_id", type=pa.int64()),
pa.field(name="item_id", type=pa.list_(pa.int64())),
pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())),
pa.field(name="label", type=pa.int32()),
]
feature_cfgs = [
Expand All @@ -612,13 +620,20 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
)
),
feature_pb2.FeatureConfig(
sequence_id_feature=feature_pb2.IdFeature(
feature_name="item_id",
expression="item:item_id",
sequence_feature=feature_pb2.SequenceFeature(
sequence_name="cand_seq",
sequence_length=10,
sequence_delim=";",
num_buckets=200,
embedding_dim=8,
features=[
feature_pb2.SeqFeatureConfig(
id_feature=feature_pb2.IdFeature(
feature_name="item_id",
expression="item:item_id",
num_buckets=200,
embedding_dim=8,
)
),
],
)
),
]
Expand All @@ -641,8 +656,8 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
num_sample=8,
num_eval_sample=4,
num_hard_sample=2,
attr_fields=["item_id"],
item_id_field="item_id",
attr_fields=["item_id"], # bare; gets rewritten
item_id_field="cand_seq__item_id", # qualified
user_id_field="user_id",
),
force_base_data_group=True,
Expand All @@ -655,7 +670,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
dataset.launch_sampler_cluster(2)

# Multi-positive mode: outer list stripped so sampler emits scalar negs.
item_id_idx = dataset._sampler._attr_names.index("item_id")
item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id")
self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64())

dataloader = DataLoader(
Expand All @@ -677,27 +692,19 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode):
hard_neg_indices = batch.additional_infos[HARD_NEG_INDICES]
self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7})

def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
"""End-to-end strip decisions across the per-attr filter.

Grouped LookupFeature sub with two item-side inputs; only ``cat_key``
is in ``sequence_fields`` so only it enters ``_seq_field_delims``.
``cat_key`` is typed ``list<list<int64>>`` (multi-value attr layered
under multi-positive grouping); after strip it becomes
``list<int64>`` (ONE level stripped, not bare-stripped to ``int64``).
``cat_map`` is item-side but excluded from ``_seq_field_delims``,
so it stays ``list<string>`` unchanged.
"""
def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self):
"""Prefix-rewrite + outer-list strip across multiple sub-feature types."""
f = tempfile.NamedTemporaryFile("w")
self._temp_files.append(f)
f.write("id:int64\tweight:float\tattrs:string\n")
for i in range(100):
f.write(f"{i}\t1.0\t{i}:{i + 1000}\n")
f.write(f"{i}\t1.0\t{i}:{i + 1000}:0.5\n")
f.flush()

input_fields = [
pa.field(name="cat_map", type=pa.list_(pa.string())),
pa.field(name="click_seq__cat_key", type=pa.list_(pa.list_(pa.int64()))),
pa.field(name="click_seq__item_id", type=pa.list_(pa.int64())),
pa.field(name="click_seq__cat_id", type=pa.list_(pa.int64())),
pa.field(name="click_seq__duration", type=pa.list_(pa.float32())),
pa.field(name="label", type=pa.int32()),
]
feature_cfgs = [
Expand All @@ -708,23 +715,35 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
sequence_delim=";",
features=[
feature_pb2.SeqFeatureConfig(
lookup_feature=feature_pb2.LookupFeature(
feature_name="lookup_c",
map="item:cat_map",
key="item:cat_key",
sequence_fields=["cat_key"],
id_feature=feature_pb2.IdFeature(
feature_name="item_id",
expression="item:item_id",
num_buckets=200,
embedding_dim=8,
)
),
feature_pb2.SeqFeatureConfig(
id_feature=feature_pb2.IdFeature(
feature_name="cat_id",
expression="item:cat_id",
num_buckets=10,
embedding_dim=8,
)
),
feature_pb2.SeqFeatureConfig(
raw_feature=feature_pb2.RawFeature(
feature_name="duration",
expression="item:duration",
)
),
],
)
),
]
features = create_features(
feature_cfgs,
fg_mode=data_pb2.FgMode.FG_NORMAL,
neg_fields=["cat_map", "cat_key"],
neg_fields=["item_id", "cat_id", "duration"],
force_base_data_group=True,
)
dataset = _TestDataset(
Expand All @@ -736,8 +755,9 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
negative_sampler=sampler_pb2.NegativeSampler(
input_path=f.name,
num_sample=4,
attr_fields=["cat_map", "click_seq__cat_key"],
item_id_field="click_seq__cat_key",
# bare names; prefix-rewritten at launch
attr_fields=["item_id", "cat_id", "duration"],
item_id_field="click_seq__item_id",
),
force_base_data_group=True,
),
Expand All @@ -746,22 +766,23 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self):
input_fields=input_fields,
mode=Mode.TRAIN,
)
# Narrowed _seq_field_delims excludes the non-sequence item-side input.
self.assertIn("click_seq__cat_key", dataset._seq_field_delims)
self.assertNotIn("cat_map", dataset._seq_field_delims)
self.assertEqual(dataset._sampler_seq_delim, ";")
self.assertEqual(dataset._sampler_seq_prefix, "click_seq__")

dataset.launch_sampler_cluster(2)
# outer guard True (item_id_field is sequence-positive):
# - cat_key: list<list<int64>> -> list<int64> (one strip).
# - cat_map: list<string>, not in _seq_field_delims -> unstripped.
cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key")
cat_map_idx = dataset._sampler._attr_names.index("cat_map")
self.assertEqual(
dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64())
)
# Deep-copy guard: data_config not mutated by the rewrite.
self.assertEqual(
dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string())
list(dataset._data_config.negative_sampler.attr_fields),
["item_id", "cat_id", "duration"],
)
# Each bare entry prefix-rewritten and outer-list stripped.
for qualified, expected_type in [
("click_seq__item_id", pa.int64()),
("click_seq__cat_id", pa.int64()),
("click_seq__duration", pa.float32()),
]:
idx = dataset._sampler._attr_names.index(qualified)
self.assertEqual(dataset._sampler._attr_types[idx], expected_type)
Comment on lines 695 to +785

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two coverage gaps worth a follow-up commit:

  1. No negative test for the assert feature.is_grouped_sequence path at dataset.py:176. A 5-line assertRaisesRegex construction with a top-level sequence_id_feature + matching item_id_field would lock the rejection — without it, the assertion can be silently weakened in a future refactor (and would be invisible under python -O; see the related comment on the assert itself).

  2. No large_list coverage in this test. The strip predicate at dataset.py:227 handles both is_list and is_large_list, but every input field here is pa.list_(...). Adding a single pa.large_list(pa.int64()) sub-feature row would cover the second arm.


def test_dataset_with_sample_mask(self):
input_fields = [
Expand Down
Loading
Loading