diff --git a/docs/source/models/hstu_match.md b/docs/source/models/hstu_match.md index 1ec1c6cf..3f013d2a 100644 --- a/docs/source/models/hstu_match.md +++ b/docs/source/models/hstu_match.md @@ -16,7 +16,7 @@ data_config { negative_sampler { input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1" num_sample: 128 - attr_fields: "cand_seq__video_id" + attr_fields: "video_id" item_id_field: "cand_seq__video_id" attr_delimiter: "\t" } @@ -211,6 +211,8 @@ model_config { - data_config: 数据配置,其中需要配置负采样 Sampler,负采样 Sampler 的配置详见 [DSSM](dssm.md) 文档中的**负采样配置**章节 + - HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 写为带 `sequence_name` 前缀的名(例如 `cand_seq__video_id`),`attr_fields` 写为不带前缀的子特征名(例如 `video_id`)。 + - feature_groups: 特征组 - uih: 用户历史行为序列,可增加 side info;类型为 JAGGED_SEQUENCE,**必填** diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 21459fea..e589cfee 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import os import random from collections import OrderedDict @@ -165,27 +166,23 @@ def __init__( ): self._selected_input_names = None - # Map input_name -> sequence_delim for true sequence inputs only - # (via feature.sequence_input_names). Excludes non-sequence - # sub-inputs of grouped sequence_feature. - self._seq_field_delims: Dict[str, str] = {} - for feature in features: - if not feature.sequence_delim: - continue - seq_inputs = set(feature.sequence_input_names) - for input_name in feature.inputs: - if input_name not in seq_inputs: + # Sequence state when item_id_field is a grouped sequence sub-feature. 
+ self._sampler_seq_delim: str = "" + self._sampler_seq_prefix: str = "" + if self._sampler_item_id_field is not None: + for feature in features: + if self._sampler_item_id_field not in feature.sequence_input_names: continue - existing = self._seq_field_delims.get(input_name) - if existing is not None and existing != feature.sequence_delim: - logger.warning( - "Conflicting sequence_delim for input '%s': %r vs %r; " - "latter wins.", - input_name, - existing, - feature.sequence_delim, + if not feature.is_grouped_sequence: + raise ValueError( + f"item_id_field '{self._sampler_item_id_field}' is a " + "sequence input but its matching feature is not a " + "grouped sequence; only grouped sequence sub-features " + "are supported as item_id_field." ) - self._seq_field_delims[input_name] = feature.sequence_delim + self._sampler_seq_delim = feature.sequence_delim + self._sampler_seq_prefix = feature.grouped_sequence_prefix + break self._fg_mode = data_config.fg_mode self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep @@ -211,22 +208,23 @@ def launch_sampler_cluster( sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) - # Multi-positive sampling: when the sampler's item_id_field is - # itself a sequence-positive train column, the per-row outer list - # on every item-side attr is the positive-grouping container, not - # a multi-value field. Strip the outer list so the sampler sees - # the pool's native scalar storage and _to_arrow_array emits - # scalar negs directly (avoiding the multival_sep split that - # would wrap each scalar in a 1-elem list). + # Rewrite bare attr_fields to flattened (`video_id` -> + # `cand_seq__video_id`); deep-copy so data_config isn't mutated. 
+ if self._sampler_seq_prefix: + sampler_config = copy.deepcopy(sampler_config) + sampler_config.attr_fields[:] = [ + self._sampler_seq_prefix + a for a in sampler_config.attr_fields + ] + + # Strip the per-row positive-grouping outer list on attr_fields + # columns so the sampler emits scalar negs. sampler_fields = self.input_fields - if ( - self._sampler_item_id_field is not None - and self._sampler_item_id_field in self._seq_field_delims - ): + if self._sampler_seq_delim: + sampler_attrs = set(sampler_config.attr_fields) sampler_fields = [ pa.field(f.name, f.type.value_type) if ( - f.name in self._seq_field_delims + f.name in sampler_attrs and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) ) else f @@ -391,7 +389,7 @@ def _apply_negative_sampler( input_data, self._sampler_item_id_field, self._sampler_user_id_field, - self._seq_field_delims, + self._sampler_seq_delim, ) sampled = self._sampler.get(sampler_input) @@ -459,11 +457,10 @@ def _merge_sampled_features( ) -> Optional[np.ndarray]: """Merge sampler outputs into input_data in place; return per-row pos lengths. - Per sampled key: new keys are assigned as-is, keys with a - `seq_delim` use the block-suffix combine, others fall back to - `pa.concat_arrays`. `pos_lengths` is sourced from the configured - item_id_field combine; returns None if no sequence field was - merged. + Per sampled key: new keys assigned as-is; otherwise routed via + `pa.concat_arrays` (scalar item_id_field) or the block-suffix + combine (grouped-sequence item_id_field). `pos_lengths` returns + None when not in sequence mode. """ # Prefer item_id_field; fall back to first-seen seq-field if absent. 
prefer_key = self._sampler_item_id_field @@ -473,12 +470,11 @@ def _merge_sampled_features( if k not in input_data: input_data[k] = v continue - seq_delim = self._seq_field_delims.get(k) - if seq_delim is None: + if not self._sampler_seq_delim: input_data[k] = pa.concat_arrays([input_data[k], v]) continue combined, pl = combine_negs_to_candidate_sequence( - input_data[k], v, seq_delim + input_data[k], v, self._sampler_seq_delim ) input_data[k] = combined if k == prefer_key: diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index a8928ef1..5c0276bf 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -479,10 +479,11 @@ def test_dataset_with_sampler(self, force_base_data_group, mode, input_tile): def test_dataset_with_sampler_list_item_id(self, mode): """E2E: list-typed item_id positives through a real NegativeSampler. - Schema declares `item_id` as `pa.list_(pa.int64())` (Parquet-style - multi-value column). Exercises `build_sampler_input`'s list-pass- - through + flatten, the dynamic-`expand_factor` path in the - sampler, and `combine_negs_to_candidate_sequence`'s list-typed-negs + Schema declares `cand_seq__item_id` (grouped sequence sub-feature + flattened name) as `pa.list_(pa.int64())` (multi-positive column). + Exercises `build_sampler_input`'s list-pass-through + flatten, + the dynamic-`expand_factor` path in the sampler, and + `combine_negs_to_candidate_sequence`'s list-typed-negs normalization. Parameterized over Mode.TRAIN / Mode.EVAL so the eval-mode `num_eval_sample` path is also covered. 
""" @@ -494,18 +495,25 @@ def test_dataset_with_sampler_list_item_id(self, mode): f.flush() input_fields = [ - pa.field(name="item_id", type=pa.list_(pa.int64())), + pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())), pa.field(name="label", type=pa.int32()), ] feature_cfgs = [ feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="item_id", - expression="item:item_id", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", sequence_length=10, sequence_delim=";", - num_buckets=200, - embedding_dim=8, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="item_id", + expression="item:item_id", + num_buckets=200, + embedding_dim=8, + ) + ), + ], ) ), ] @@ -525,8 +533,8 @@ def test_dataset_with_sampler_list_item_id(self, mode): input_path=f.name, num_sample=8, num_eval_sample=4, - attr_fields=["item_id"], - item_id_field="item_id", + attr_fields=["item_id"], # bare; gets rewritten + item_id_field="cand_seq__item_id", # qualified ), force_base_data_group=True, ), @@ -540,7 +548,7 @@ def test_dataset_with_sampler_list_item_id(self, mode): # Multi-positive mode (item_id is sequence-positive in train): # launch_sampler_cluster strips the outer list so the sampler emits # scalar item_id negs, not 1-elem lists via the multival_sep round-trip. 
- item_id_idx = dataset._sampler._attr_names.index("item_id") + item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id") self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64()) dataloader = DataLoader( @@ -599,7 +607,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): input_fields = [ pa.field(name="user_id", type=pa.int64()), - pa.field(name="item_id", type=pa.list_(pa.int64())), + pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())), pa.field(name="label", type=pa.int32()), ] feature_cfgs = [ @@ -612,13 +620,20 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): ) ), feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="item_id", - expression="item:item_id", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", sequence_length=10, sequence_delim=";", - num_buckets=200, - embedding_dim=8, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="item_id", + expression="item:item_id", + num_buckets=200, + embedding_dim=8, + ) + ), + ], ) ), ] @@ -641,8 +656,8 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): num_sample=8, num_eval_sample=4, num_hard_sample=2, - attr_fields=["item_id"], - item_id_field="item_id", + attr_fields=["item_id"], # bare; gets rewritten + item_id_field="cand_seq__item_id", # qualified user_id_field="user_id", ), force_base_data_group=True, @@ -655,7 +670,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): dataset.launch_sampler_cluster(2) # Multi-positive mode: outer list stripped so sampler emits scalar negs. 
- item_id_idx = dataset._sampler._attr_names.index("item_id") + item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id") self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64()) dataloader = DataLoader( @@ -677,27 +692,19 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): hard_neg_indices = batch.additional_infos[HARD_NEG_INDICES] self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7}) - def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): - """End-to-end strip decisions across the per-attr filter. - - Grouped LookupFeature sub with two item-side inputs; only ``cat_key`` - is in ``sequence_fields`` so only it enters ``_seq_field_delims``. - ``cat_key`` is typed ``list<list<int64>>`` (multi-value attr layered - under multi-positive grouping); after strip it becomes - ``list<int64>`` (ONE level stripped, not bare-stripped to ``int64``). - ``cat_map`` is item-side but excluded from ``_seq_field_delims``, - so it stays ``list<string>`` unchanged. 
- """ + def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): + """Prefix-rewrite + outer-list strip across multiple sub-feature types.""" f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) f.write("id:int64\tweight:float\tattrs:string\n") for i in range(100): - f.write(f"{i}\t1.0\t{i}:{i + 1000}\n") + f.write(f"{i}\t1.0\t{i}:{i + 1000}:0.5\n") f.flush() input_fields = [ - pa.field(name="cat_map", type=pa.list_(pa.string())), - pa.field(name="click_seq__cat_key", type=pa.list_(pa.list_(pa.int64()))), + pa.field(name="click_seq__item_id", type=pa.list_(pa.int64())), + pa.field(name="click_seq__cat_id", type=pa.list_(pa.int64())), + pa.field(name="click_seq__duration", type=pa.list_(pa.float32())), pa.field(name="label", type=pa.int32()), ] feature_cfgs = [ @@ -708,15 +715,27 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): sequence_delim=";", features=[ feature_pb2.SeqFeatureConfig( - lookup_feature=feature_pb2.LookupFeature( - feature_name="lookup_c", - map="item:cat_map", - key="item:cat_key", - sequence_fields=["cat_key"], + id_feature=feature_pb2.IdFeature( + feature_name="item_id", + expression="item:item_id", + num_buckets=200, + embedding_dim=8, + ) + ), + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="cat_id", + expression="item:cat_id", num_buckets=10, embedding_dim=8, ) ), + feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="duration", + expression="item:duration", + ) + ), ], ) ), @@ -724,7 +743,7 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): features = create_features( feature_cfgs, fg_mode=data_pb2.FgMode.FG_NORMAL, - neg_fields=["cat_map", "cat_key"], + neg_fields=["item_id", "cat_id", "duration"], force_base_data_group=True, ) dataset = _TestDataset( @@ -736,8 +755,9 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): negative_sampler=sampler_pb2.NegativeSampler( 
input_path=f.name, num_sample=4, - attr_fields=["cat_map", "click_seq__cat_key"], - item_id_field="click_seq__cat_key", + # bare names; prefix-rewritten at launch + attr_fields=["item_id", "cat_id", "duration"], + item_id_field="click_seq__item_id", ), force_base_data_group=True, ), @@ -746,22 +766,23 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): input_fields=input_fields, mode=Mode.TRAIN, ) - # Narrowed _seq_field_delims excludes the non-sequence item-side input. - self.assertIn("click_seq__cat_key", dataset._seq_field_delims) - self.assertNotIn("cat_map", dataset._seq_field_delims) + self.assertEqual(dataset._sampler_seq_delim, ";") + self.assertEqual(dataset._sampler_seq_prefix, "click_seq__") dataset.launch_sampler_cluster(2) - # outer guard True (item_id_field is sequence-positive): - # - cat_key: list<list<int64>> -> list<int64> (one strip). - # - cat_map: list<string>, not in _seq_field_delims -> unstripped. - cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") - cat_map_idx = dataset._sampler._attr_names.index("cat_map") - self.assertEqual( - dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64()) - ) + # Deep-copy guard: data_config not mutated by the rewrite. self.assertEqual( - dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string()) + list(dataset._data_config.negative_sampler.attr_fields), + ["item_id", "cat_id", "duration"], ) + # Each bare entry prefix-rewritten and outer-list stripped. 
+ for qualified, expected_type in [ + ("click_seq__item_id", pa.int64()), + ("click_seq__cat_id", pa.int64()), + ("click_seq__duration", pa.float32()), + ]: + idx = dataset._sampler._attr_names.index(qualified) + self.assertEqual(dataset._sampler._attr_types[idx], expected_type) def test_dataset_with_sample_mask(self): input_fields = [ diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 98f3727f..d006c0cc 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -582,31 +582,31 @@ def build_sampler_input( input_data: Dict[str, pa.Array], item_id_field: Optional[str], user_id_field: Optional[str], - seq_field_delims: Dict[str, str], + seq_delim: str, ) -> Dict[str, pa.Array]: """Shallow-copy input_data with item_id (and user_id) flattened for the sampler. - When `item_id_field` is a sequence_id_feature, per-row positives - (delimited string or list array) are flattened to 1D and + When `item_id_field` is a grouped sequence sub-feature, per-row + positives (delimited string or list array) are flattened to 1D and `user_id_field` (if any) is expanded by per-row positive count. - Scalar item_id or unconfigured seq_delim falls through unchanged. - The caller's `input_data` is not mutated. + Scalar item_id (`seq_delim=""`) falls through unchanged. The + caller's `input_data` is not mutated. Args: input_data: per-row input column dict. item_id_field: sampler config's `item_id_field`, or None. user_id_field: sampler config's `user_id_field`, or None. - seq_field_delims: input_name -> sequence_delim mapping. + seq_delim: candidate sequence's `sequence_delim`, or "" when + `item_id_field` is a top-level scalar feature. Returns: A new shallow-copy dict with item_id flattened and user_id expanded when both apply. 
""" sampler_input = dict(input_data) - if item_id_field is None or item_id_field not in seq_field_delims: + if item_id_field is None or not seq_delim: return sampler_input - seq_delim = seq_field_delims[item_id_field] raw = input_data[item_id_field] if pa.types.is_string(raw.type) or pa.types.is_large_string(raw.type): pos_lists = pc.split_pattern(raw, seq_delim) diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index c7405b29..3e95a52a 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -210,7 +210,7 @@ def test_calc_slice_intervals_topology_change(self): @parameterized.expand( [ # (name, input_data, item_id_field, user_id_field, - # seq_field_delims, expected_output) + # seq_delim, expected_output) ( # NegativeSampler-style: no user_id_field; item_id is # delimited string; gets flattened. @@ -218,7 +218,7 @@ def test_calc_slice_intervals_topology_change(self): {"item_id": pa.array(["1;2", "3"]), "label": pa.array([1, 0])}, "item_id", None, - {"item_id": ";"}, + ";", {"item_id": ["1", "2", "3"], "label": [1, 0]}, ), ( @@ -231,7 +231,7 @@ def test_calc_slice_intervals_topology_change(self): }, "item_id", "user_id", - {"item_id": ";"}, + ";", {"item_id": [1, 2, 3], "user_id": ["u0", "u0", "u1"]}, ), ( @@ -244,16 +244,16 @@ def test_calc_slice_intervals_topology_change(self): }, "item_id", "user_id", - {"item_id": ";"}, + ";", {"item_id": [1, 2, 3], "user_id": ["u0", "u1", "u2"]}, ), ( - # item_id_field has no seq_delim entry -> pass through. - "item_id_not_in_seq_field_delims", + # item_id_field is a top-level scalar -> seq_delim="" -> pass through. 
+ "empty_seq_delim_passthrough", {"item_id": pa.array(["1", "2"])}, "item_id", None, - {}, + "", {"item_id": ["1", "2"]}, ), ( @@ -263,7 +263,7 @@ def test_calc_slice_intervals_topology_change(self): {"a": pa.array([1, 2])}, None, None, - {}, + "", {"a": [1, 2]}, ), ] @@ -274,7 +274,7 @@ def test_build_sampler_input( input_data, item_id_field, user_id_field, - seq_field_delims, + seq_delim, expected_output, ): # Snapshot input_data so we can verify the function didn't mutate it. @@ -284,7 +284,7 @@ def test_build_sampler_input( input_data, item_id_field=item_id_field, user_id_field=user_id_field, - seq_field_delims=seq_field_delims, + seq_delim=seq_delim, ) # Contract 1: output equals expected (per-column pylist compare). diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 55cc587f..bd5bb416 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -469,10 +469,7 @@ def __init__( @property def name(self) -> str: """Feature name.""" - if self._is_grouped_seq: - return f"{self.sequence_name}{self._underline}{self.config.feature_name}" - else: - return self.config.feature_name + return f"{self.grouped_sequence_prefix}{self.config.feature_name}" @property def is_neg(self) -> bool: @@ -568,6 +565,11 @@ def is_grouped_sequence(self) -> bool: """Feature is grouped sequence or not.""" return self._is_grouped_seq + @property + def grouped_sequence_prefix(self) -> str: + """Flatten prefix ``<sequence_name><_underline>``; empty if not grouped.""" + return f"{self.sequence_name}{self._underline}" if self._is_grouped_seq else "" + @property def is_weighted(self) -> bool: """Feature is weighted id feature or not.""" @@ -746,7 +748,7 @@ def _is_sequence_input(self, side: str, name: str) -> bool: if not self.is_sequence: return False if self._is_grouped_seq and self.sequence_name: - prefix = f"{self.sequence_name}{self._underline}" + prefix = self.grouped_sequence_prefix if name.startswith(prefix): name = name[len(prefix) :] if ( @@ -784,7 +786,7 @@ def 
side_inputs(self) -> List[Tuple[str, str]]: ) side, name = x[0], x[1] seq_prefix = ( - f"{self.sequence_name}{self._underline}" + self.grouped_sequence_prefix if self._need_seq_prefix(side, name) else "" ) diff --git a/tzrec/tests/configs/hstu_kuairand_1k.config b/tzrec/tests/configs/hstu_kuairand_1k.config index 31df5d47..89d99987 100644 --- a/tzrec/tests/configs/hstu_kuairand_1k.config +++ b/tzrec/tests/configs/hstu_kuairand_1k.config @@ -33,7 +33,7 @@ data_config { negative_sampler { input_path: "data/test/kuairand-1k-match-item-gl.txt" num_sample: 128 - attr_fields: "cand_seq__video_id" + attr_fields: "video_id" item_id_field: "cand_seq__video_id" attr_delimiter: "\t" }