From ab1d6daa6b25534de18ea4488c082fbe6ac814b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 20:25:09 +0800 Subject: [PATCH 1/8] [refactor] dataset: derive sampler sequence state from feature configs Replace the global `_seq_field_delims: Dict[str, str]` (every sequence input -> delim) with two candidate-side fields derived from the matching feature config: _sampler_seq_delim -- "" if `item_id_field` isn't a sequence input; the parent feature's `sequence_delim` otherwise (covers top-level `sequence_id_feature` AND grouped sequence sub-features). _sampler_seq_inputs -- candidate-side sequence input names: {feature.name} for top-level; union of `sequence_input_names` across all sub-features sharing `sequence_name` for grouped. Lookup is `item_id_field in feature.sequence_input_names`, which returns `[feature.name]` for top-level FG_NONE/FG_BUCKETIZE sequences and the flattened input names for grouped FG_DAG/FG_NORMAL sub-features -- single check covers both cases the old `_seq_field_delims` dict did. Three downstream simplifications in `tzrec/datasets/dataset.py`: * `launch_sampler_cluster`: outer-list-strip filter switches from `f.name in self._seq_field_delims` to `f.name in self._sampler_seq_inputs`, narrowing strip to the candidate-sequence inputs (excludes unrelated grouped sequences and non-sequence item-side attrs from the same lookup feature). * `_apply_negative_sampler`: passes the single `self._sampler_seq_delim` to `build_sampler_input` instead of the whole dict. * `_merge_sampled_features`: gates the per-key block-suffix combine on `k in self._sampler_seq_inputs` (a single set membership check replacing the per-key dict lookup). Correct for both top-level and grouped, and for mixed-attrs configs where a lookup feature exposes both candidate-sequence sub-features and unrelated item-side attrs. 
`tzrec/datasets/utils.py`: * `build_sampler_input(...)` signature change from `seq_field_delims: Dict[str, str]` to `seq_delim: str`. Docstring updated to name both top-level `sequence_id_feature` and grouped sub-feature cases (the master version mentioned only top-level). DSSM/MIND/TDM behaviour unchanged: when `item_id_field` is a top-level scalar, `_sampler_seq_delim` stays empty, `_sampler_seq_inputs` stays empty, and every branch degrades to the pre-refactor scalar path. Tests: * `dataset_test.py::test_launch_sampler_cluster_multi_attr_strip_decision_matrix` -- replace `_seq_field_delims` membership assertions with `_sampler_seq_delim` / `_sampler_seq_inputs` checks on the same mixed-attrs lookup_feature case (`cat_map` non-sequence item-side attr stays excluded; `click_seq__cat_key` grouped sub-feature stays included). * `utils_test.py::test_build_sampler_input` -- update `@parameterized.expand` rows to pass `seq_delim=` (str); rename the "no entry" case to "empty_seq_delim_passthrough". Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 105 +++++++++++++++++---------------- tzrec/datasets/dataset_test.py | 21 ++++--- tzrec/datasets/utils.py | 17 +++--- tzrec/datasets/utils_test.py | 20 +++---- 4 files changed, 87 insertions(+), 76 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 21459fea..88f8c740 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -165,27 +165,36 @@ def __init__( ): self._selected_input_names = None - # Map input_name -> sequence_delim for true sequence inputs only - # (via feature.sequence_input_names). Excludes non-sequence - # sub-inputs of grouped sequence_feature. 
- self._seq_field_delims: Dict[str, str] = {} - for feature in features: - if not feature.sequence_delim: - continue - seq_inputs = set(feature.sequence_input_names) - for input_name in feature.inputs: - if input_name not in seq_inputs: - continue - existing = self._seq_field_delims.get(input_name) - if existing is not None and existing != feature.sequence_delim: - logger.warning( - "Conflicting sequence_delim for input '%s': %r vs %r; " - "latter wins.", - input_name, - existing, - feature.sequence_delim, - ) - self._seq_field_delims[input_name] = feature.sequence_delim + # Candidate-side sequence state derived from the matching feature + # config when `item_id_field` is a sequence input (top-level + # `sequence_id_feature` OR grouped sequence sub-feature). + # `_sampler_seq_delim` is the parent feature's `sequence_delim`; + # `_sampler_seq_inputs` is the set of sequence input names that + # share the candidate's sequence context -- {feature.name} for + # top-level, the union of `sequence_input_names` across all + # sub-features sharing `sequence_name` for grouped. Used by the + # outer-list strip in `launch_sampler_cluster` and by the per-key + # branch in `_merge_sampled_features`. 
+ self._sampler_seq_delim: str = "" + self._sampler_seq_inputs: set = set() + if self._sampler_item_id_field is not None: + for feature in features: + if ( + feature.sequence_delim + and self._sampler_item_id_field in feature.sequence_input_names + ): + self._sampler_seq_delim = feature.sequence_delim + if feature.is_grouped_sequence: + seq_name = feature.sequence_name + self._sampler_seq_inputs = { + inp + for f in features + if f.is_grouped_sequence and f.sequence_name == seq_name + for inp in f.sequence_input_names + } + else: + self._sampler_seq_inputs = set(feature.sequence_input_names) + break self._fg_mode = data_config.fg_mode self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep @@ -212,26 +221,23 @@ def launch_sampler_cluster( sampler_config = getattr(self._data_config, sampler_type) # Multi-positive sampling: when the sampler's item_id_field is - # itself a sequence-positive train column, the per-row outer list - # on every item-side attr is the positive-grouping container, not - # a multi-value field. Strip the outer list so the sampler sees - # the pool's native scalar storage and _to_arrow_array emits - # scalar negs directly (avoiding the multival_sep split that - # would wrap each scalar in a 1-elem list). - sampler_fields = self.input_fields - if ( - self._sampler_item_id_field is not None - and self._sampler_item_id_field in self._seq_field_delims - ): - sampler_fields = [ - pa.field(f.name, f.type.value_type) - if ( - f.name in self._seq_field_delims - and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) - ) - else f - for f in self.input_fields - ] + # itself a sequence-positive train column, the per-row outer + # list on every candidate-sequence attr is the positive- + # grouping container, not a multi-value field. Strip the outer + # list so the sampler sees the pool's native scalar storage and + # _to_arrow_array emits scalar negs directly. 
Filter is the + # candidate-sequence inputs set, so unrelated grouped sequences + # (e.g. uih_seq__*) and non-sequence item-side attrs from the + # same lookup feature (e.g. `cat_map`) are left untouched. + sampler_fields = [ + pa.field(f.name, f.type.value_type) + if ( + f.name in self._sampler_seq_inputs + and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) + ) + else f + for f in self.input_fields + ] # pyre-ignore [16] self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)( @@ -391,7 +397,7 @@ def _apply_negative_sampler( input_data, self._sampler_item_id_field, self._sampler_user_id_field, - self._seq_field_delims, + self._sampler_seq_delim, ) sampled = self._sampler.get(sampler_input) @@ -459,11 +465,11 @@ def _merge_sampled_features( ) -> Optional[np.ndarray]: """Merge sampler outputs into input_data in place; return per-row pos lengths. - Per sampled key: new keys are assigned as-is, keys with a - `seq_delim` use the block-suffix combine, others fall back to - `pa.concat_arrays`. `pos_lengths` is sourced from the configured - item_id_field combine; returns None if no sequence field was - merged. + Per sampled key: new keys are assigned as-is; candidate-sequence + keys (those in `_sampler_seq_inputs`) use the block-suffix + combine; others fall back to `pa.concat_arrays`. `pos_lengths` is + sourced from the configured item_id_field combine; returns None + if no candidate-sequence field was merged. """ # Prefer item_id_field; fall back to first-seen seq-field if absent. 
prefer_key = self._sampler_item_id_field @@ -473,12 +479,11 @@ def _merge_sampled_features( if k not in input_data: input_data[k] = v continue - seq_delim = self._seq_field_delims.get(k) - if seq_delim is None: + if k not in self._sampler_seq_inputs: input_data[k] = pa.concat_arrays([input_data[k], v]) continue combined, pl = combine_negs_to_candidate_sequence( - input_data[k], v, seq_delim + input_data[k], v, self._sampler_seq_delim ) input_data[k] = combined if k == prefer_key: diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index a8928ef1..b1c40b46 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -681,12 +681,13 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): """End-to-end strip decisions across the per-attr filter. Grouped LookupFeature sub with two item-side inputs; only ``cat_key`` - is in ``sequence_fields`` so only it enters ``_seq_field_delims``. + is in ``sequence_fields`` so only it is a candidate-sequence input. ``cat_key`` is typed ``list<list<int64>>`` (multi-value attr layered under multi-positive grouping); after strip it becomes ``list<int64>`` (ONE level stripped, not bare-stripped to ``int64``). - ``cat_map`` is item-side but excluded from ``_seq_field_delims``, - so it stays ``list<string>`` unchanged. + ``cat_map`` is item-side but NOT a candidate-sequence input + (excluded from ``_sampler_seq_inputs``), so it stays + ``list<string>`` unchanged. """ f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) @@ -746,14 +747,18 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): input_fields=input_fields, mode=Mode.TRAIN, ) - # Narrowed _seq_field_delims excludes the non-sequence item-side input. - self.assertIn("click_seq__cat_key", dataset._seq_field_delims) - self.assertNotIn("cat_map", dataset._seq_field_delims) + # Candidate-side sequence state is derived from item_id_field's + # matching feature.
`click_seq__cat_key` is a grouped sequence + # input; `cat_map` is a non-sequence item-side attr of the same + # lookup feature and is excluded. + self.assertEqual(dataset._sampler_seq_delim, ";") + self.assertIn("click_seq__cat_key", dataset._sampler_seq_inputs) + self.assertNotIn("cat_map", dataset._sampler_seq_inputs) dataset.launch_sampler_cluster(2) - # outer guard True (item_id_field is sequence-positive): + # item_id_field is a candidate-sequence sub-feature: # - cat_key: list<list<int64>> -> list<int64> (one strip). - # - cat_map: list<string>, not in _seq_field_delims -> unstripped. + # - cat_map: list<string>, not in _sampler_seq_inputs -> unstripped. cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") cat_map_idx = dataset._sampler._attr_names.index("cat_map") self.assertEqual( diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 98f3727f..57646220 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -582,31 +582,32 @@ def build_sampler_input( input_data: Dict[str, pa.Array], item_id_field: Optional[str], user_id_field: Optional[str], - seq_field_delims: Dict[str, str], + seq_delim: str, ) -> Dict[str, pa.Array]: """Shallow-copy input_data with item_id (and user_id) flattened for the sampler. - When `item_id_field` is a sequence_id_feature, per-row positives - (delimited string or list array) are flattened to 1D and + When `item_id_field` is a sequence input (top-level + `sequence_id_feature` or grouped sequence sub-feature), per-row + positives (delimited string or list array) are flattened to 1D and `user_id_field` (if any) is expanded by per-row positive count. - Scalar item_id or unconfigured seq_delim falls through unchanged. - The caller's `input_data` is not mutated. + Scalar item_id (`seq_delim=""`) falls through unchanged. The + caller's `input_data` is not mutated. Args: input_data: per-row input column dict. item_id_field: sampler config's `item_id_field`, or None. user_id_field: sampler config's `user_id_field`, or None.
- seq_field_delims: input_name -> sequence_delim mapping. + seq_delim: candidate sequence's `sequence_delim`, or "" when + `item_id_field` is a top-level scalar feature. Returns: A new shallow-copy dict with item_id flattened and user_id expanded when both apply. """ sampler_input = dict(input_data) - if item_id_field is None or item_id_field not in seq_field_delims: + if item_id_field is None or not seq_delim: return sampler_input - seq_delim = seq_field_delims[item_id_field] raw = input_data[item_id_field] if pa.types.is_string(raw.type) or pa.types.is_large_string(raw.type): pos_lists = pc.split_pattern(raw, seq_delim) diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index c7405b29..3e95a52a 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -210,7 +210,7 @@ def test_calc_slice_intervals_topology_change(self): @parameterized.expand( [ # (name, input_data, item_id_field, user_id_field, - # seq_field_delims, expected_output) + # seq_delim, expected_output) ( # NegativeSampler-style: no user_id_field; item_id is # delimited string; gets flattened. @@ -218,7 +218,7 @@ def test_calc_slice_intervals_topology_change(self): {"item_id": pa.array(["1;2", "3"]), "label": pa.array([1, 0])}, "item_id", None, - {"item_id": ";"}, + ";", {"item_id": ["1", "2", "3"], "label": [1, 0]}, ), ( @@ -231,7 +231,7 @@ def test_calc_slice_intervals_topology_change(self): }, "item_id", "user_id", - {"item_id": ";"}, + ";", {"item_id": [1, 2, 3], "user_id": ["u0", "u0", "u1"]}, ), ( @@ -244,16 +244,16 @@ def test_calc_slice_intervals_topology_change(self): }, "item_id", "user_id", - {"item_id": ";"}, + ";", {"item_id": [1, 2, 3], "user_id": ["u0", "u1", "u2"]}, ), ( - # item_id_field has no seq_delim entry -> pass through. - "item_id_not_in_seq_field_delims", + # item_id_field is a top-level scalar -> seq_delim="" -> pass through. 
+ "empty_seq_delim_passthrough", {"item_id": pa.array(["1", "2"])}, "item_id", None, - {}, + "", {"item_id": ["1", "2"]}, ), ( @@ -263,7 +263,7 @@ def test_calc_slice_intervals_topology_change(self): {"a": pa.array([1, 2])}, None, None, - {}, + "", {"a": [1, 2]}, ), ] @@ -274,7 +274,7 @@ def test_build_sampler_input( input_data, item_id_field, user_id_field, - seq_field_delims, + seq_delim, expected_output, ): # Snapshot input_data so we can verify the function didn't mutate it. @@ -284,7 +284,7 @@ def test_build_sampler_input( input_data, item_id_field=item_id_field, user_id_field=user_id_field, - seq_field_delims=seq_field_delims, + seq_delim=seq_delim, ) # Contract 1: output equals expected (per-column pylist compare). From ffae57b7cf04b3b79de3e3b9fe0d572db579fd3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 20:35:22 +0800 Subject: [PATCH 2/8] [refactor] sampler: accept bare sub-feature attr_fields for grouped-sequence candidate HSTUMatch is the only model whose sampler config carried qualified `attr_fields` (e.g. `attr_fields: "cand_seq__video_id"`), leaking DataParser's `{sequence_name}__{sub_feature}` flatten convention into user-facing config. Every other sampler config (DSSM/MIND/TDM) uses the bare sub-feature name because their candidate is a top-level feature where bare-name == flattened-name. Translate at the dataset boundary via an explicit alias map: _sampler_bare_attr_to_sequence_input: Dict[str, str] For grouped-sequence candidates (when item_id_field's matching feature is `is_grouped_sequence`), the map is built from the candidate- sequence input names by stripping `feature.sequence_name + feature._underline`: {"video_id": "cand_seq__video_id", ...} `_resolve_sampler_attr_field(name) -> str` does `map.get(name, name)`, and `launch_sampler_cluster` rewrites `sampler_config.attr_fields` in place over a deep-copy of the sampler sub-message (so `self._data_config` is not mutated). 
Three cases handled by the alias-map's `.get(a, a)` fallthrough: 1. Bare grouped-sequence sub-feature name -> rewritten to qualified. 2. Already-qualified name -> no map entry -> passed through. 3. Top-level item-side attr from the same lookup feature (e.g. `cat_map` excluded from `sequence_fields`) -> no map entry -> passed through. Flatten prefix uses `feature._underline` directly so RTP-safe ("_" vs "__" never enumerated). Deep-copy guard is `if alias map is non-empty`, so DSSM/MIND/TDM (alias map empty) skip the copy entirely. Test config + doc switch to the new convention: attr_fields: "video_id" (bare sub-feature name) item_id_field: "cand_seq__video_id" (qualified; doubles as the sequence_name source) Add `dataset_test.py::test_launch_sampler_cluster_bare_attr_resolves_against_seq_prefix` exercising the rewrite path end-to-end: asserts the alias map content, that `self._data_config` is not mutated, and that the sampler sees the qualified column name after `launch_sampler_cluster`. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/source/models/hstu_match.md | 4 +- tzrec/datasets/dataset.py | 44 ++++++++++++ tzrec/datasets/dataset_test.py | 80 +++++++++++++++++++++ tzrec/tests/configs/hstu_kuairand_1k.config | 2 +- 4 files changed, 128 insertions(+), 2 deletions(-) diff --git a/docs/source/models/hstu_match.md b/docs/source/models/hstu_match.md index 1ec1c6cf..3f013d2a 100644 --- a/docs/source/models/hstu_match.md +++ b/docs/source/models/hstu_match.md @@ -16,7 +16,7 @@ data_config { negative_sampler { input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1" num_sample: 128 - attr_fields: "cand_seq__video_id" + attr_fields: "video_id" item_id_field: "cand_seq__video_id" attr_delimiter: "\t" } @@ -211,6 +211,8 @@ model_config { - data_config: 数据配置,其中需要配置负采样 Sampler,负采样 Sampler 的配置详见 [DSSM](dssm.md) 文档中的**负采样配置**章节 + - HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 写为带 `sequence_name` 前缀的名(例如 `cand_seq__video_id`),`attr_fields` 写为不带前缀的子特征名(例如 `video_id`)。 + - feature_groups: 特征组 - uih: 用户历史行为序列,可增加 side info;类型为 JAGGED_SEQUENCE,**必填** diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 88f8c740..41c693f8 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import os import random from collections import OrderedDict @@ -175,8 +176,17 @@ def __init__( # sub-features sharing `sequence_name` for grouped. Used by the # outer-list strip in `launch_sampler_cluster` and by the per-key # branch in `_merge_sampled_features`. + # `_sampler_bare_attr_to_sequence_input` is the bare->qualified + # alias map used to rewrite `attr_fields` at sampler-launch time + # (only meaningful for the grouped case; empty otherwise -> the + # rewrite is a no-op). 
HSTUMatch's `attr_fields: "video_id"` + # writes the bare sub-feature name; the alias maps it to the + # flattened parquet column `cand_seq__video_id` the sampler pool + # actually carries. The flatten prefix uses `feature._underline` + # so RTP-safe (no "_" vs "__" guessing). self._sampler_seq_delim: str = "" self._sampler_seq_inputs: set = set() + self._sampler_bare_attr_to_sequence_input: Dict[str, str] = {} if self._sampler_item_id_field is not None: for feature in features: if ( @@ -186,12 +196,18 @@ def __init__( self._sampler_seq_delim = feature.sequence_delim if feature.is_grouped_sequence: seq_name = feature.sequence_name + # pyre-ignore [16]: _underline is the source of + # truth for the flatten separator. + prefix = seq_name + feature._underline self._sampler_seq_inputs = { inp for f in features if f.is_grouped_sequence and f.sequence_name == seq_name for inp in f.sequence_input_names } + self._sampler_bare_attr_to_sequence_input = { + inp[len(prefix) :]: inp for inp in self._sampler_seq_inputs + } else: self._sampler_seq_inputs = set(feature.sequence_input_names) break @@ -220,6 +236,20 @@ def launch_sampler_cluster( sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) + # Resolve bare sub-feature `attr_fields` against the grouped- + # sequence flatten prefix when applicable. Bare names (HSTUMatch + # convention, `attr_fields: "video_id"`) map to the qualified + # parquet column (`cand_seq__video_id`); already-qualified + # names and unrelated top-level item-side attrs pass through + # the alias map's `.get(a, a)` fallthrough. Deep-copy so the + # in-place mutation here doesn't leak back to `self._data_config`. 
+ if self._sampler_bare_attr_to_sequence_input: + sampler_config = copy.deepcopy(sampler_config) + sampler_config.attr_fields[:] = [ + self._resolve_sampler_attr_field(a) + for a in sampler_config.attr_fields + ] + # Multi-positive sampling: when the sampler's item_id_field is # itself a sequence-positive train column, the per-row outer # list on every candidate-sequence attr is the positive- @@ -382,6 +412,20 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: batch.checkpoint_info = checkpoint_info return batch + def _resolve_sampler_attr_field(self, attr_field: str) -> str: + """Translate a sampler `attr_field` from the user-facing namespace. + + For grouped-sequence candidate configs (e.g. HSTUMatch), the + user-facing `attr_fields` entry can be a bare sub-feature name + (e.g. ``"video_id"``); we resolve it to the flattened parquet + column the sampler pool actually carries (e.g. + ``"cand_seq__video_id"``). For already-qualified names and for + unrelated top-level item-side attrs (e.g. lookup_feature's + ``"cat_map"``), the alias map has no entry and the input is + returned unchanged. + """ + return self._sampler_bare_attr_to_sequence_input.get(attr_field, attr_field) + def _apply_negative_sampler( self, input_data: Dict[str, pa.Array], diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index b1c40b46..bfeae6d6 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -768,6 +768,86 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string()) ) + def test_launch_sampler_cluster_bare_attr_resolves_against_seq_prefix(self): + """Bare `attr_fields` get rewritten to the qualified flatten name. + + HSTUMatch-style sampler config: `attr_fields: "cat_key"` (bare + sub-feature name) + `item_id_field: "click_seq__cat_key"` + (qualified). 
The dataset boundary maps the bare name to the + qualified parquet column via + ``_sampler_bare_attr_to_sequence_input``, so the sampler sees + the fully-flattened input name. + """ + f = tempfile.NamedTemporaryFile("w") + self._temp_files.append(f) + f.write("id:int64\tweight:float\tattrs:string\n") + for i in range(100): + f.write(f"{i}\t1.0\t{i}\n") + f.flush() + + input_fields = [ + pa.field(name="click_seq__cat_key", type=pa.list_(pa.int64())), + pa.field(name="label", type=pa.int32()), + ] + feature_cfgs = [ + feature_pb2.FeatureConfig( + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="click_seq", + sequence_length=10, + sequence_delim=";", + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="cat_key", + expression="item:cat_key", + num_buckets=10, + embedding_dim=8, + ) + ), + ], + ) + ), + ] + features = create_features( + feature_cfgs, + fg_mode=data_pb2.FgMode.FG_NORMAL, + neg_fields=["cat_key"], + force_base_data_group=True, + ) + dataset = _TestDataset( + data_config=data_pb2.DataConfig( + batch_size=4, + dataset_type=data_pb2.DatasetType.OdpsDataset, + fg_mode=data_pb2.FgMode.FG_NORMAL, + label_fields=["label"], + negative_sampler=sampler_pb2.NegativeSampler( + input_path=f.name, + num_sample=4, + attr_fields=["cat_key"], # bare; gets rewritten + item_id_field="click_seq__cat_key", # qualified + ), + force_base_data_group=True, + ), + features=features, + input_path="", + input_fields=input_fields, + mode=Mode.TRAIN, + ) + # Alias map is built from the grouped sequence's flatten prefix. + self.assertEqual( + dataset._sampler_bare_attr_to_sequence_input, + {"cat_key": "click_seq__cat_key"}, + ) + # data_config.sampler is not mutated by the rewrite (deep-copied). + self.assertEqual( + list(dataset._data_config.negative_sampler.attr_fields), ["cat_key"] + ) + + dataset.launch_sampler_cluster(2) + # Sampler sees the QUALIFIED column name after rewrite. 
+ self.assertIn("click_seq__cat_key", dataset._sampler._attr_names) + self.assertNotIn("cat_key", dataset._sampler._attr_names) + def test_dataset_with_sample_mask(self): input_fields = [ pa.field(name="int_a", type=pa.int64()), diff --git a/tzrec/tests/configs/hstu_kuairand_1k.config b/tzrec/tests/configs/hstu_kuairand_1k.config index 31df5d47..89d99987 100644 --- a/tzrec/tests/configs/hstu_kuairand_1k.config +++ b/tzrec/tests/configs/hstu_kuairand_1k.config @@ -33,7 +33,7 @@ data_config { negative_sampler { input_path: "data/test/kuairand-1k-match-item-gl.txt" num_sample: 128 - attr_fields: "cand_seq__video_id" + attr_fields: "video_id" item_id_field: "cand_seq__video_id" attr_delimiter: "\t" } From 703690f07ad99406ad3e89f236095a3a9c2d06d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 20:44:27 +0800 Subject: [PATCH 3/8] [polish] dataset: trim sampler-seq init comments, restore scalar fast-path, flatten loop Addresses /simplify review on top of the prior two commits: * Trim the 17-line block comment above `_sampler_seq_*` init to 4 lines. Original block mostly paraphrased the code (WHAT); keep the load- bearing WHY (HSTUMatch bare-name convention + RTP-safe prefix). * Restore the scalar-config (DSSM/MIND/TDM) fast-path in `launch_sampler_cluster`: skip the `sampler_fields` list-comp when `_sampler_seq_inputs` is empty -- pass through `self.input_fields` unchanged like master did via the old guard. * Flatten the three-level init loop with `continue` guards. * Drop inaccurate `pyre-ignore [16]` on `feature._underline` -- the attribute is publicly defined on `BaseFeature.__init__` and other in-class call sites (e.g. `feature.py:473,749,787`) read it without any ignore. * Annotate `_sampler_seq_inputs` as `Set[str]` (was bare `set`), matching the `Dict[str, str]` style on the next line. No behavior change. Tests stay green (`tzrec.datasets.dataset_test`, `tzrec.datasets.utils_test`). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 89 +++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 51 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 41c693f8..520006b7 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -13,7 +13,7 @@ import os import random from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union import numpy as np import pyarrow as pa @@ -166,51 +166,36 @@ def __init__( ): self._selected_input_names = None - # Candidate-side sequence state derived from the matching feature - # config when `item_id_field` is a sequence input (top-level - # `sequence_id_feature` OR grouped sequence sub-feature). - # `_sampler_seq_delim` is the parent feature's `sequence_delim`; - # `_sampler_seq_inputs` is the set of sequence input names that - # share the candidate's sequence context -- {feature.name} for - # top-level, the union of `sequence_input_names` across all - # sub-features sharing `sequence_name` for grouped. Used by the - # outer-list strip in `launch_sampler_cluster` and by the per-key - # branch in `_merge_sampled_features`. - # `_sampler_bare_attr_to_sequence_input` is the bare->qualified - # alias map used to rewrite `attr_fields` at sampler-launch time - # (only meaningful for the grouped case; empty otherwise -> the - # rewrite is a no-op). HSTUMatch's `attr_fields: "video_id"` - # writes the bare sub-feature name; the alias maps it to the - # flattened parquet column `cand_seq__video_id` the sampler pool - # actually carries. The flatten prefix uses `feature._underline` - # so RTP-safe (no "_" vs "__" guessing). + # Candidate-side sequence state for sampler I/O. 
HSTUMatch writes + # bare sub-feature names in `attr_fields`; the alias map rewrites + # them to the flattened parquet columns (`<sequence_name>__<sub_feature>`) the + # sampler pool actually carries. Prefix derived from + # `feature._underline` (not "__" literal) so RTP-safe. self._sampler_seq_delim: str = "" - self._sampler_seq_inputs: set = set() + self._sampler_seq_inputs: Set[str] = set() self._sampler_bare_attr_to_sequence_input: Dict[str, str] = {} if self._sampler_item_id_field is not None: for feature in features: - if ( - feature.sequence_delim - and self._sampler_item_id_field in feature.sequence_input_names - ): - self._sampler_seq_delim = feature.sequence_delim - if feature.is_grouped_sequence: - seq_name = feature.sequence_name - # pyre-ignore [16]: _underline is the source of - # truth for the flatten separator. - prefix = seq_name + feature._underline - self._sampler_seq_inputs = { - inp - for f in features - if f.is_grouped_sequence and f.sequence_name == seq_name - for inp in f.sequence_input_names - } - self._sampler_bare_attr_to_sequence_input = { - inp[len(prefix) :]: inp for inp in self._sampler_seq_inputs - } - else: - self._sampler_seq_inputs = set(feature.sequence_input_names) - break + if not feature.sequence_delim: + continue + if self._sampler_item_id_field not in feature.sequence_input_names: + continue + self._sampler_seq_delim = feature.sequence_delim + if feature.is_grouped_sequence: + seq_name = feature.sequence_name + prefix = seq_name + feature._underline + self._sampler_seq_inputs = { + inp + for f in features + if f.is_grouped_sequence and f.sequence_name == seq_name + for inp in f.sequence_input_names + } + self._sampler_bare_attr_to_sequence_input = { + inp[len(prefix) :]: inp for inp in self._sampler_seq_inputs + } + else: + self._sampler_seq_inputs = set(feature.sequence_input_names) + break self._fg_mode = data_config.fg_mode self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep @@ -259,15 +244,17 @@ def launch_sampler_cluster( #
candidate-sequence inputs set, so unrelated grouped sequences # (e.g. uih_seq__*) and non-sequence item-side attrs from the # same lookup feature (e.g. `cat_map`) are left untouched. - sampler_fields = [ - pa.field(f.name, f.type.value_type) - if ( - f.name in self._sampler_seq_inputs - and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) - ) - else f - for f in self.input_fields - ] + sampler_fields = self.input_fields + if self._sampler_seq_inputs: + sampler_fields = [ + pa.field(f.name, f.type.value_type) + if ( + f.name in self._sampler_seq_inputs + and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) + ) + else f + for f in self.input_fields + ] # pyre-ignore [16] self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)( From 6834b2f4c16e19b3b3a0cfce69927ac2e890f2d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 20:57:40 +0800 Subject: [PATCH 4/8] [polish] dataset_test: multi_attr test uses new bare convention; fold focused test The pre-existing `test_launch_sampler_cluster_multi_attr_strip_decision_matrix` still wrote `attr_fields=["cat_map", "click_seq__cat_key"]` -- the OLD qualified form for the sequence sub-feature, the very leak this PR is fixing. Update to `["cat_map", "cat_key"]` (bare for the sub-feature, top-level item-side `cat_map` unchanged), which both: * Demonstrates the canonical user-facing form under the new convention. * Exercises the alias-map rewrite (`cat_key` -> `click_seq__cat_key`) end-to-end alongside the fallthrough (`cat_map` has no alias entry). Now-redundant assertions from the focused `test_launch_sampler_cluster_bare_attr_resolves_against_seq_prefix` test fold into the multi_attr test: * `_sampler_bare_attr_to_sequence_input == {"cat_key": "click_seq__cat_key"}` (alias-map structural check). * `list(data_config.negative_sampler.attr_fields) == ["cat_map", "cat_key"]` after launch (deep-copy guard: `self._data_config` not mutated). 
* `assertNotIn("cat_key", _sampler._attr_names)` (rewrite replaces, doesn't append). Drop the focused test; the multi_attr test is now the comprehensive end-to-end check (rewrite + fallthrough + outer-list strip in one). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset_test.py | 143 ++++++++++----------------------- 1 file changed, 44 insertions(+), 99 deletions(-) diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index bfeae6d6..b3314c0c 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -678,16 +678,25 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7}) def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): - """End-to-end strip decisions across the per-attr filter. - - Grouped LookupFeature sub with two item-side inputs; only ``cat_key`` - is in ``sequence_fields`` so only it is a candidate-sequence input. - ``cat_key`` is typed ``list>`` (multi-value attr layered - under multi-positive grouping); after strip it becomes - ``list`` (ONE level stripped, not bare-stripped to ``int64``). - ``cat_map`` is item-side but NOT a candidate-sequence input - (excluded from ``_sampler_seq_inputs``), so it stays - ``list`` unchanged. + """End-to-end attr-rewrite + strip decisions across the per-attr filter. + + Grouped LookupFeature sub with two item-side inputs; only + ``cat_key`` is in ``sequence_fields`` so only it is a + candidate-sequence input. ``attr_fields`` mixes the two + user-facing forms the new convention supports: + + - ``cat_key`` (bare grouped-sequence sub-feature name): alias- + map rewrites to ``click_seq__cat_key`` before sampler launch. + - ``cat_map`` (top-level item-side input from the same lookup + feature): not a sequence input, no alias entry, passes + through unchanged. 
+ + Outer-list strip on ``input_fields`` then applies independently: + ``click_seq__cat_key`` is typed ``list<list<int64>>`` (multi- + value attr layered under multi-positive grouping); after strip + it becomes ``list<int64>`` (ONE level stripped, not bare- + stripped to ``int64``). ``cat_map`` is ``list<string>``, not in + ``_sampler_seq_inputs``, stays unchanged. """ f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) @@ -737,7 +746,9 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): negative_sampler=sampler_pb2.NegativeSampler( input_path=f.name, num_sample=4, - attr_fields=["cat_map", "click_seq__cat_key"], + # Bare `cat_key` (gets rewritten) + top-level + # `cat_map` (no alias entry, passes through). + attr_fields=["cat_map", "cat_key"], item_id_field="click_seq__cat_key", ), force_base_data_group=True, @@ -747,17 +758,31 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): input_fields=input_fields, mode=Mode.TRAIN, ) - # Candidate-side sequence state is derived from item_id_field's - # matching feature. `click_seq__cat_key` is a grouped sequence - # input; `cat_map` is a non-sequence item-side attr of the same - # lookup feature and is excluded. + # Candidate-side sequence state derived from item_id_field's + # matching feature. Only sequence-input sub-features end up in + # `_sampler_seq_inputs`; the alias map's keys are the bare + # forms. self.assertEqual(dataset._sampler_seq_delim, ";") - self.assertIn("click_seq__cat_key", dataset._sampler_seq_inputs) - self.assertNotIn("cat_map", dataset._sampler_seq_inputs) + self.assertEqual(dataset._sampler_seq_inputs, {"click_seq__cat_key"}) + self.assertEqual( + dataset._sampler_bare_attr_to_sequence_input, + {"cat_key": "click_seq__cat_key"}, + ) dataset.launch_sampler_cluster(2) - # item_id_field is a candidate-sequence sub-feature: - # - cat_key: list<list<int64>> -> list<int64> (one strip). + # data_config.sampler not mutated by the rewrite (deep-copied).
+ self.assertEqual( + list(dataset._data_config.negative_sampler.attr_fields), + ["cat_map", "cat_key"], + ) + # Sampler sees the QUALIFIED column name after rewrite, plus + # the unchanged top-level `cat_map`. + self.assertIn("click_seq__cat_key", dataset._sampler._attr_names) + self.assertIn("cat_map", dataset._sampler._attr_names) + # Rewrite replaces, doesn't append: bare `cat_key` is gone. + self.assertNotIn("cat_key", dataset._sampler._attr_names) + # Outer-list strip: + # - click_seq__cat_key: list<list<int64>> -> list<int64> (one strip). # - cat_map: list<string>, not in _sampler_seq_inputs -> unstripped. cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") cat_map_idx = dataset._sampler._attr_names.index("cat_map") @@ -768,86 +793,6 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string()) ) - def test_launch_sampler_cluster_bare_attr_resolves_against_seq_prefix(self): - """Bare `attr_fields` get rewritten to the qualified flatten name. - - HSTUMatch-style sampler config: `attr_fields: "cat_key"` (bare - sub-feature name) + `item_id_field: "click_seq__cat_key"` - (qualified). The dataset boundary maps the bare name to the - qualified parquet column via - ``_sampler_bare_attr_to_sequence_input``, so the sampler sees - the fully-flattened input name.
- """ - f = tempfile.NamedTemporaryFile("w") - self._temp_files.append(f) - f.write("id:int64\tweight:float\tattrs:string\n") - for i in range(100): - f.write(f"{i}\t1.0\t{i}\n") - f.flush() - - input_fields = [ - pa.field(name="click_seq__cat_key", type=pa.list_(pa.int64())), - pa.field(name="label", type=pa.int32()), - ] - feature_cfgs = [ - feature_pb2.FeatureConfig( - sequence_feature=feature_pb2.SequenceFeature( - sequence_name="click_seq", - sequence_length=10, - sequence_delim=";", - features=[ - feature_pb2.SeqFeatureConfig( - id_feature=feature_pb2.IdFeature( - feature_name="cat_key", - expression="item:cat_key", - num_buckets=10, - embedding_dim=8, - ) - ), - ], - ) - ), - ] - features = create_features( - feature_cfgs, - fg_mode=data_pb2.FgMode.FG_NORMAL, - neg_fields=["cat_key"], - force_base_data_group=True, - ) - dataset = _TestDataset( - data_config=data_pb2.DataConfig( - batch_size=4, - dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_mode=data_pb2.FgMode.FG_NORMAL, - label_fields=["label"], - negative_sampler=sampler_pb2.NegativeSampler( - input_path=f.name, - num_sample=4, - attr_fields=["cat_key"], # bare; gets rewritten - item_id_field="click_seq__cat_key", # qualified - ), - force_base_data_group=True, - ), - features=features, - input_path="", - input_fields=input_fields, - mode=Mode.TRAIN, - ) - # Alias map is built from the grouped sequence's flatten prefix. - self.assertEqual( - dataset._sampler_bare_attr_to_sequence_input, - {"cat_key": "click_seq__cat_key"}, - ) - # data_config.sampler is not mutated by the rewrite (deep-copied). - self.assertEqual( - list(dataset._data_config.negative_sampler.attr_fields), ["cat_key"] - ) - - dataset.launch_sampler_cluster(2) - # Sampler sees the QUALIFIED column name after rewrite. 
- self.assertIn("click_seq__cat_key", dataset._sampler._attr_names) - self.assertNotIn("cat_key", dataset._sampler._attr_names) - def test_dataset_with_sample_mask(self): input_fields = [ pa.field(name="int_a", type=pa.int64()), From a80f4319f40c7384f04aab204ea509e38a2e8b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 10:19:00 +0800 Subject: [PATCH 5/8] [refactor] dataset: collapse sampler-seq state to a single prefix; only grouped item_id_field supported Domain invariant the prior rounds over-engineered against: when `item_id_field` is sequence-positive, `attr_fields` are uniformly sequence too -- they're all item-side per-positive, and the per-row positive layout makes any non-sequence item-side attr structurally incoherent. Real configs confirm: every NegativeSampler / HardNegativeSampler config in `tzrec/tests/configs/` and `examples/` is either all-scalar (DSSM/MIND/TDM) or uniformly grouped-sequence (HSTUMatch). Collapse to two fields driven by that invariant: _sampler_seq_delim -- "" if not in sequence mode; the matching grouped sequence's `sequence_delim` otherwise. _sampler_seq_prefix -- "" if not in sequence mode; the flatten prefix `f"{sequence_name}{_underline}"` otherwise. Replaces both `_sampler_seq_inputs` (set; dropped) and `_sampler_bare_attr_to_sequence_input` (dict; dropped) from the prior rounds. Rewrite rule shrinks to one line -- unconditional prepend: sampler_config.attr_fields[:] = [ self._sampler_seq_prefix + a for a in sampler_config.attr_fields ] No `startswith` guard for the legacy qualified form -- this PR establishes bare names as the single canonical convention, and the migration burden is small (HSTUMatch is the only model that ever wrote the qualified form; that config is flipped in commit ffae57b). Legacy qualified attr_fields now fail loud at sampler init (the double- prefixed column doesn't exist). 
`launch_sampler_cluster`'s outer-list strip narrows its filter from `_sampler_seq_inputs` to `set(sampler_config.attr_fields)` -- in every real config `item_id_field` is one of the `attr_fields` entries (post-rewrite), so no separate union is needed. `_merge_sampled_features` collapses to a single delim gate (`if not self._sampler_seq_delim`); the Round 1 per-key `_sampler_seq_inputs` membership check was defensive complexity for the would-be mixed-attrs case that doesn't exist in real configs. Top-level `sequence_id_feature` as `item_id_field` is now rejected via an `assert feature.is_grouped_sequence`: this shape is unused by any current config, and migrating the two `_TestReader` tests that posited it (`test_dataset_with_sampler_list_item_id`, `test_dataset_with_hard_negative_sampler_list_item_id`) to the grouped style matches how multi-positive sampling actually works in production. `_resolve_sampler_attr_field` helper dropped (one-line `.get` on a now-removed dict). `Set` import dropped (no other uses). `test_launch_sampler_cluster_multi_attr_strip_decision_matrix` renamed to `test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite` and reframed: drops the synthetic `cat_map`-in-attr_fields scenario (mixed-attrs configs don't exist in production), keeps `cat_map` as a parquet column NOT in attr_fields to verify strip filter precision, adds the prefix assertion `_sampler_seq_prefix == "click_seq__"`. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 106 +++++++++++---------------- tzrec/datasets/dataset_test.py | 127 +++++++++++++++------------------ tzrec/datasets/utils.py | 3 +- 3 files changed, 101 insertions(+), 135 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 520006b7..fb74d5ed 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -13,7 +13,7 @@ import os import random from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import numpy as np import pyarrow as pa @@ -166,35 +166,27 @@ def __init__( ): self._selected_input_names = None - # Candidate-side sequence state for sampler I/O. HSTUMatch writes - # bare sub-feature names in `attr_fields`; the alias map rewrites - # them to the flattened parquet columns (`__`) the - # sampler pool actually carries. Prefix derived from - # `feature._underline` (not "__" literal) so RTP-safe. + # When item_id_field is a grouped sequence sub-feature (HSTUMatch's + # candidate side), attr_fields are uniformly sequence too -- item- + # side per-positive, never mixed with top-level scalars. So a + # single delim + prefix capture the entire sequence path; no + # per-attr branching. Prefix uses feature._underline (not "__" + # literal) so RTP-safe. Top-level sequence_id_feature as + # item_id_field is rejected: only grouped is supported. 
self._sampler_seq_delim: str = "" - self._sampler_seq_inputs: Set[str] = set() - self._sampler_bare_attr_to_sequence_input: Dict[str, str] = {} + self._sampler_seq_prefix: str = "" if self._sampler_item_id_field is not None: for feature in features: - if not feature.sequence_delim: - continue if self._sampler_item_id_field not in feature.sequence_input_names: continue + assert feature.is_grouped_sequence, ( + f"item_id_field '{self._sampler_item_id_field}' is a " + f"sequence input but its matching feature is not a " + f"grouped sequence; only grouped sequence sub-features " + f"are supported as item_id_field." + ) self._sampler_seq_delim = feature.sequence_delim - if feature.is_grouped_sequence: - seq_name = feature.sequence_name - prefix = seq_name + feature._underline - self._sampler_seq_inputs = { - inp - for f in features - if f.is_grouped_sequence and f.sequence_name == seq_name - for inp in f.sequence_input_names - } - self._sampler_bare_attr_to_sequence_input = { - inp[len(prefix) :]: inp for inp in self._sampler_seq_inputs - } - else: - self._sampler_seq_inputs = set(feature.sequence_input_names) + self._sampler_seq_prefix = feature.sequence_name + feature._underline break self._fg_mode = data_config.fg_mode @@ -221,35 +213,31 @@ def launch_sampler_cluster( sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) - # Resolve bare sub-feature `attr_fields` against the grouped- - # sequence flatten prefix when applicable. Bare names (HSTUMatch - # convention, `attr_fields: "video_id"`) map to the qualified - # parquet column (`cand_seq__video_id`); already-qualified - # names and unrelated top-level item-side attrs pass through - # the alias map's `.get(a, a)` fallthrough. Deep-copy so the - # in-place mutation here doesn't leak back to `self._data_config`. 
- if self._sampler_bare_attr_to_sequence_input: + # Prepend the candidate sequence's flatten prefix to bare + # `attr_fields` entries (HSTUMatch's `"video_id"` becomes the + # qualified parquet column `"cand_seq__video_id"`). Deep-copy + # so the in-place mutation doesn't leak back to + # `self._data_config`. + if self._sampler_seq_prefix: sampler_config = copy.deepcopy(sampler_config) sampler_config.attr_fields[:] = [ - self._resolve_sampler_attr_field(a) - for a in sampler_config.attr_fields + self._sampler_seq_prefix + a for a in sampler_config.attr_fields ] - # Multi-positive sampling: when the sampler's item_id_field is - # itself a sequence-positive train column, the per-row outer - # list on every candidate-sequence attr is the positive- - # grouping container, not a multi-value field. Strip the outer - # list so the sampler sees the pool's native scalar storage and - # _to_arrow_array emits scalar negs directly. Filter is the - # candidate-sequence inputs set, so unrelated grouped sequences - # (e.g. uih_seq__*) and non-sequence item-side attrs from the - # same lookup feature (e.g. `cat_map`) are left untouched. + # Multi-positive sampling: when item_id_field is a candidate- + # sequence sub-feature, the per-row outer list on every + # attr_fields column is the positive-grouping container, not a + # multi-value field. Strip that outer list so the sampler sees + # the pool's native scalar storage and _to_arrow_array emits + # scalar negs directly. item_id_field is one of attr_fields in + # every real sampler config, so no separate union is needed. 
sampler_fields = self.input_fields - if self._sampler_seq_inputs: + if self._sampler_seq_delim: + sampler_attrs = set(sampler_config.attr_fields) sampler_fields = [ pa.field(f.name, f.type.value_type) if ( - f.name in self._sampler_seq_inputs + f.name in sampler_attrs and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) ) else f @@ -399,20 +387,6 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: batch.checkpoint_info = checkpoint_info return batch - def _resolve_sampler_attr_field(self, attr_field: str) -> str: - """Translate a sampler `attr_field` from the user-facing namespace. - - For grouped-sequence candidate configs (e.g. HSTUMatch), the - user-facing `attr_fields` entry can be a bare sub-feature name - (e.g. ``"video_id"``); we resolve it to the flattened parquet - column the sampler pool actually carries (e.g. - ``"cand_seq__video_id"``). For already-qualified names and for - unrelated top-level item-side attrs (e.g. lookup_feature's - ``"cat_map"``), the alias map has no entry and the input is - returned unchanged. - """ - return self._sampler_bare_attr_to_sequence_input.get(attr_field, attr_field) - def _apply_negative_sampler( self, input_data: Dict[str, pa.Array], @@ -496,11 +470,13 @@ def _merge_sampled_features( ) -> Optional[np.ndarray]: """Merge sampler outputs into input_data in place; return per-row pos lengths. - Per sampled key: new keys are assigned as-is; candidate-sequence - keys (those in `_sampler_seq_inputs`) use the block-suffix - combine; others fall back to `pa.concat_arrays`. `pos_lengths` is - sourced from the configured item_id_field combine; returns None - if no candidate-sequence field was merged. + Per sampled key: new keys are assigned as-is. When + `_sampler_seq_delim` is empty (scalar item_id_field), existing + keys are concatenated via `pa.concat_arrays`. 
When it's set + (candidate-sequence item_id_field), `attr_fields` are uniformly + sequence so every existing key goes through the block-suffix + combine. `pos_lengths` is sourced from the configured + item_id_field combine; returns None when not in sequence mode. """ # Prefer item_id_field; fall back to first-seen seq-field if absent. prefer_key = self._sampler_item_id_field @@ -510,7 +486,7 @@ def _merge_sampled_features( if k not in input_data: input_data[k] = v continue - if k not in self._sampler_seq_inputs: + if not self._sampler_seq_delim: input_data[k] = pa.concat_arrays([input_data[k], v]) continue combined, pl = combine_negs_to_candidate_sequence( diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index b3314c0c..7206d7e4 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -479,10 +479,11 @@ def test_dataset_with_sampler(self, force_base_data_group, mode, input_tile): def test_dataset_with_sampler_list_item_id(self, mode): """E2E: list-typed item_id positives through a real NegativeSampler. - Schema declares `item_id` as `pa.list_(pa.int64())` (Parquet-style - multi-value column). Exercises `build_sampler_input`'s list-pass- - through + flatten, the dynamic-`expand_factor` path in the - sampler, and `combine_negs_to_candidate_sequence`'s list-typed-negs + Schema declares `cand_seq__item_id` (grouped sequence sub-feature + flattened name) as `pa.list_(pa.int64())` (multi-positive column). + Exercises `build_sampler_input`'s list-pass-through + flatten, + the dynamic-`expand_factor` path in the sampler, and + `combine_negs_to_candidate_sequence`'s list-typed-negs normalization. Parameterized over Mode.TRAIN / Mode.EVAL so the eval-mode `num_eval_sample` path is also covered. 
""" @@ -494,18 +495,25 @@ def test_dataset_with_sampler_list_item_id(self, mode): f.flush() input_fields = [ - pa.field(name="item_id", type=pa.list_(pa.int64())), + pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())), pa.field(name="label", type=pa.int32()), ] feature_cfgs = [ feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="item_id", - expression="item:item_id", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", sequence_length=10, sequence_delim=";", - num_buckets=200, - embedding_dim=8, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="item_id", + expression="item:item_id", + num_buckets=200, + embedding_dim=8, + ) + ), + ], ) ), ] @@ -525,8 +533,8 @@ def test_dataset_with_sampler_list_item_id(self, mode): input_path=f.name, num_sample=8, num_eval_sample=4, - attr_fields=["item_id"], - item_id_field="item_id", + attr_fields=["item_id"], # bare; gets rewritten + item_id_field="cand_seq__item_id", # qualified ), force_base_data_group=True, ), @@ -540,7 +548,7 @@ def test_dataset_with_sampler_list_item_id(self, mode): # Multi-positive mode (item_id is sequence-positive in train): # launch_sampler_cluster strips the outer list so the sampler emits # scalar item_id negs, not 1-elem lists via the multival_sep round-trip. 
- item_id_idx = dataset._sampler._attr_names.index("item_id") + item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id") self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64()) dataloader = DataLoader( @@ -599,7 +607,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): input_fields = [ pa.field(name="user_id", type=pa.int64()), - pa.field(name="item_id", type=pa.list_(pa.int64())), + pa.field(name="cand_seq__item_id", type=pa.list_(pa.int64())), pa.field(name="label", type=pa.int32()), ] feature_cfgs = [ @@ -612,13 +620,20 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): ) ), feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="item_id", - expression="item:item_id", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", sequence_length=10, sequence_delim=";", - num_buckets=200, - embedding_dim=8, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="item_id", + expression="item:item_id", + num_buckets=200, + embedding_dim=8, + ) + ), + ], ) ), ] @@ -641,8 +656,8 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): num_sample=8, num_eval_sample=4, num_hard_sample=2, - attr_fields=["item_id"], - item_id_field="item_id", + attr_fields=["item_id"], # bare; gets rewritten + item_id_field="cand_seq__item_id", # qualified user_id_field="user_id", ), force_base_data_group=True, @@ -655,7 +670,7 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): dataset.launch_sampler_cluster(2) # Multi-positive mode: outer list stripped so sampler emits scalar negs. 
- item_id_idx = dataset._sampler._attr_names.index("item_id") + item_id_idx = dataset._sampler._attr_names.index("cand_seq__item_id") self.assertEqual(dataset._sampler._attr_types[item_id_idx], pa.int64()) dataloader = DataLoader( @@ -677,32 +692,25 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): hard_neg_indices = batch.additional_infos[HARD_NEG_INDICES] self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7}) - def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): - """End-to-end attr-rewrite + strip decisions across the per-attr filter. - - Grouped LookupFeature sub with two item-side inputs; only - ``cat_key`` is in ``sequence_fields`` so only it is a - candidate-sequence input. ``attr_fields`` mixes the two - user-facing forms the new convention supports: - - - ``cat_key`` (bare grouped-sequence sub-feature name): alias- - map rewrites to ``click_seq__cat_key`` before sampler launch. - - ``cat_map`` (top-level item-side input from the same lookup - feature): not a sequence input, no alias entry, passes - through unchanged. - - Outer-list strip on ``input_fields`` then applies independently: - ``click_seq__cat_key`` is typed ``list>`` (multi- - value attr layered under multi-positive grouping); after strip - it becomes ``list`` (ONE level stripped, not bare- - stripped to ``int64``). ``cat_map`` is ``list``, not in - ``_sampler_seq_inputs``, stays unchanged. + def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): + """End-to-end prefix-rewrite + outer-list strip for HSTUMatch-shaped configs. + + Grouped LookupFeature sub with ``cat_key`` in ``sequence_fields``. + The sampler's ``attr_fields`` uses the new bare convention + (``"cat_key"``); the dataset boundary prepends the candidate + sequence's flatten prefix (``"click_seq__"``) before sampler + launch. 
Outer-list strip then applies to attr_fields columns + only: ``click_seq__cat_key`` (``list<list<int64>>``) loses ONE + level (multi-positive container) and stays ``list<int64>`` + (multi-value layer). ``cat_map`` (``list<string>``) is an + unrelated parquet column not in ``attr_fields`` -- strip skips + it, so it stays ``list<string>`` unchanged. """ f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) f.write("id:int64\tweight:float\tattrs:string\n") for i in range(100): - f.write(f"{i}\t1.0\t{i}:{i + 1000}\n") + f.write(f"{i}\t1.0\t{i}\n") f.flush() input_fields = [ @@ -734,7 +742,7 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): features = create_features( feature_cfgs, fg_mode=data_pb2.FgMode.FG_NORMAL, - neg_fields=["cat_map", "cat_key"], + neg_fields=["cat_key"], force_base_data_group=True, ) dataset = _TestDataset( @@ -746,9 +754,7 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): negative_sampler=sampler_pb2.NegativeSampler( input_path=f.name, num_sample=4, - # Bare `cat_key` (gets rewritten) + top-level - # `cat_map` (no alias entry, passes through). - attr_fields=["cat_map", "cat_key"], + attr_fields=["cat_key"], # bare; prefix-rewritten at launch item_id_field="click_seq__cat_key", ), force_base_data_group=True, @@ -758,40 +764,25 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): input_fields=input_fields, mode=Mode.TRAIN, ) - # Candidate-side sequence state derived from item_id_field's - # matching feature. Only sequence-input sub-features end up in - # `_sampler_seq_inputs`; the alias map's keys are the bare - # forms.
self.assertEqual(dataset._sampler_seq_delim, ";") - self.assertEqual(dataset._sampler_seq_inputs, {"click_seq__cat_key"}) - self.assertEqual( - dataset._sampler_bare_attr_to_sequence_input, - {"cat_key": "click_seq__cat_key"}, - ) + self.assertEqual(dataset._sampler_seq_prefix, "click_seq__") dataset.launch_sampler_cluster(2) # data_config.sampler not mutated by the rewrite (deep-copied). self.assertEqual( - list(dataset._data_config.negative_sampler.attr_fields), - ["cat_map", "cat_key"], + list(dataset._data_config.negative_sampler.attr_fields), ["cat_key"] ) - # Sampler sees the QUALIFIED column name after rewrite, plus - # the unchanged top-level `cat_map`. + # Sampler sees the QUALIFIED column name; bare `cat_key` is gone. self.assertIn("click_seq__cat_key", dataset._sampler._attr_names) - self.assertIn("cat_map", dataset._sampler._attr_names) - # Rewrite replaces, doesn't append: bare `cat_key` is gone. self.assertNotIn("cat_key", dataset._sampler._attr_names) # Outer-list strip: - # - click_seq__cat_key: list<list<int64>> -> list<int64> (one strip). - # - cat_map: list<string>, not in _sampler_seq_inputs -> unstripped. + # - click_seq__cat_key (in attr_fields): list<list<int64>> -> list<int64>. + # - cat_map (parquet column NOT in attr_fields): stays list<string>. cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") - cat_map_idx = dataset._sampler._attr_names.index("cat_map") self.assertEqual( dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64()) ) - self.assertEqual( - dataset._sampler._attr_types[cat_map_idx], pa.list_(pa.string()) - ) + self.assertNotIn("cat_map", dataset._sampler._attr_names) def test_dataset_with_sample_mask(self): input_fields = [ pa.field(name="int_a", type=pa.int64()), diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 57646220..d006c0cc 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -586,8 +586,7 @@ def build_sampler_input( ) -> Dict[str, pa.Array]: """Shallow-copy input_data with item_id (and user_id) flattened for the sampler.
- When `item_id_field` is a sequence input (top-level - `sequence_id_feature` or grouped sequence sub-feature), per-row + When `item_id_field` is a grouped sequence sub-feature, per-row positives (delimited string or list array) are flattened to 1D and `user_id_field` (if any) is expanded by per-row positive count. Scalar item_id (`seq_delim=""`) falls through unchanged. The From 25156935c320db97ccd43dfbe7f888b1b69fac1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 10:27:06 +0800 Subject: [PATCH 6/8] [polish] dataset: trim verbose comments + docstrings on PR-touched code Comments + docstrings I added in earlier rounds had paragraph-style explanations. Trim to one-liners per `feedback_concise_inline_comments`: * `BaseDataset.__init__` sampler-seq state init: 7-line block -> 1 line. * `launch_sampler_cluster` rewrite block: 5-line comment -> 1 line. * `launch_sampler_cluster` strip block: 7-line comment -> 1 line. * `_merge_sampled_features` docstring: 7 lines -> 4 lines. * `test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite` docstring: 12 lines -> 5; trim three WHAT-style assertion comments. No code-behavior change; tests stay green. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 35 +++++++++------------------------- tzrec/datasets/dataset_test.py | 24 +++++++---------------- 2 files changed, 16 insertions(+), 43 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index fb74d5ed..a56fc1e7 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -166,13 +166,7 @@ def __init__( ): self._selected_input_names = None - # When item_id_field is a grouped sequence sub-feature (HSTUMatch's - # candidate side), attr_fields are uniformly sequence too -- item- - # side per-positive, never mixed with top-level scalars. So a - # single delim + prefix capture the entire sequence path; no - # per-attr branching. 
Prefix uses feature._underline (not "__" - # literal) so RTP-safe. Top-level sequence_id_feature as - # item_id_field is rejected: only grouped is supported. + # Sequence state when item_id_field is a grouped sequence sub-feature. self._sampler_seq_delim: str = "" self._sampler_seq_prefix: str = "" if self._sampler_item_id_field is not None: @@ -213,24 +207,16 @@ def launch_sampler_cluster( sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) - # Prepend the candidate sequence's flatten prefix to bare - # `attr_fields` entries (HSTUMatch's `"video_id"` becomes the - # qualified parquet column `"cand_seq__video_id"`). Deep-copy - # so the in-place mutation doesn't leak back to - # `self._data_config`. + # Rewrite bare attr_fields to flattened (`video_id` -> + # `cand_seq__video_id`); deep-copy so data_config isn't mutated. if self._sampler_seq_prefix: sampler_config = copy.deepcopy(sampler_config) sampler_config.attr_fields[:] = [ self._sampler_seq_prefix + a for a in sampler_config.attr_fields ] - # Multi-positive sampling: when item_id_field is a candidate- - # sequence sub-feature, the per-row outer list on every - # attr_fields column is the positive-grouping container, not a - # multi-value field. Strip that outer list so the sampler sees - # the pool's native scalar storage and _to_arrow_array emits - # scalar negs directly. item_id_field is one of attr_fields in - # every real sampler config, so no separate union is needed. + # Strip the per-row positive-grouping outer list on attr_fields + # columns so the sampler emits scalar negs. sampler_fields = self.input_fields if self._sampler_seq_delim: sampler_attrs = set(sampler_config.attr_fields) @@ -470,13 +456,10 @@ def _merge_sampled_features( ) -> Optional[np.ndarray]: """Merge sampler outputs into input_data in place; return per-row pos lengths. - Per sampled key: new keys are assigned as-is. 
When - `_sampler_seq_delim` is empty (scalar item_id_field), existing - keys are concatenated via `pa.concat_arrays`. When it's set - (candidate-sequence item_id_field), `attr_fields` are uniformly - sequence so every existing key goes through the block-suffix - combine. `pos_lengths` is sourced from the configured - item_id_field combine; returns None when not in sequence mode. + Per sampled key: new keys assigned as-is; otherwise routed via + `pa.concat_arrays` (scalar item_id_field) or the block-suffix + combine (grouped-sequence item_id_field). `pos_lengths` returns + None when not in sequence mode. """ # Prefer item_id_field; fall back to first-seen seq-field if absent. prefer_key = self._sampler_item_id_field diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index 7206d7e4..a03389d5 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -693,18 +693,12 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7}) def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): - """End-to-end prefix-rewrite + outer-list strip for HSTUMatch-shaped configs. - - Grouped LookupFeature sub with ``cat_key`` in ``sequence_fields``. - The sampler's ``attr_fields`` uses the new bare convention - (``"cat_key"``); the dataset boundary prepends the candidate - sequence's flatten prefix (``"click_seq__"``) before sampler - launch. Outer-list strip then applies to attr_fields columns - only: ``click_seq__cat_key`` (``list>``) loses ONE - level (multi-positive container) and stays ``list`` - (multi-value layer). ``cat_map`` (``list``) is an - unrelated parquet column not in ``attr_fields`` -- strip skips - it, so it stays ``list`` unchanged. + """Prefix-rewrite + outer-list strip for HSTUMatch-shaped configs. 
+ + Bare ``attr_fields=["cat_key"]`` is rewritten to + ``["click_seq__cat_key"]``; strip drops the multi-positive outer + list (``list>`` -> ``list``). ``cat_map``, a + parquet column not in ``attr_fields``, is left untouched. """ f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) @@ -768,16 +762,12 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): self.assertEqual(dataset._sampler_seq_prefix, "click_seq__") dataset.launch_sampler_cluster(2) - # data_config.sampler not mutated by the rewrite (deep-copied). + # Deep-copy guard: data_config not mutated by the rewrite. self.assertEqual( list(dataset._data_config.negative_sampler.attr_fields), ["cat_key"] ) - # Sampler sees the QUALIFIED column name; bare `cat_key` is gone. self.assertIn("click_seq__cat_key", dataset._sampler._attr_names) self.assertNotIn("cat_key", dataset._sampler._attr_names) - # Outer-list strip: - # - click_seq__cat_key (in attr_fields): list> -> list. - # - cat_map (parquet column NOT in attr_fields): stays list. cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") self.assertEqual( dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64()) From b0213f8874e14dcc603808de61a1b2d52ac84bc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 11:11:40 +0800 Subject: [PATCH 7/8] [polish] dataset_test: expand grouped-sequence test to cover multiple sub-feature types `test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite` used a single LookupFeature sub with `sequence_fields=["cat_key"]` to construct one `list<list<int64>>` candidate input. The synthetic shape (multi-value layered under multi-positive) doesn't reflect HSTUMatch's actual usage where the candidate sequence carries multiple sub-features of varied types (id, raw, ...).
Replace with three sub-features under a single `sequence_feature`: * `id_feature item_id` -> `click_seq__item_id: list<int64>` * `id_feature cat_id` -> `click_seq__cat_id: list<int64>` * `raw_feature duration` -> `click_seq__duration: list<float32>` `attr_fields=["item_id", "cat_id", "duration"]` exercises the prefix-prepend rewrite across all three; the strip filter scopes correctly per type (int64 / int64 / float32). Drops the `sequence_fields` complexity (LookupFeature-only knob) and the `cat_map` parquet column (the `assertNotIn("cat_map", ...)` assertion was checking sampler-side state unrelated to the strip filter's behavior). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset_test.py | 62 ++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index a03389d5..5c0276bf 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -693,23 +693,18 @@ def test_dataset_with_hard_negative_sampler_list_item_id(self, mode): self.assertEqual(set(hard_neg_indices[:, 0].tolist()), {0, 1, 2, 3, 4, 5, 6, 7}) def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): - """Prefix-rewrite + outer-list strip for HSTUMatch-shaped configs. - - Bare ``attr_fields=["cat_key"]`` is rewritten to - ``["click_seq__cat_key"]``; strip drops the multi-positive outer - list (``list>`` -> ``list``). ``cat_map``, a - parquet column not in ``attr_fields``, is left untouched.
- """ + """Prefix-rewrite + outer-list strip across multiple sub-feature types.""" f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) f.write("id:int64\tweight:float\tattrs:string\n") for i in range(100): - f.write(f"{i}\t1.0\t{i}\n") + f.write(f"{i}\t1.0\t{i}:{i + 1000}:0.5\n") f.flush() input_fields = [ - pa.field(name="cat_map", type=pa.list_(pa.string())), - pa.field(name="click_seq__cat_key", type=pa.list_(pa.list_(pa.int64()))), + pa.field(name="click_seq__item_id", type=pa.list_(pa.int64())), + pa.field(name="click_seq__cat_id", type=pa.list_(pa.int64())), + pa.field(name="click_seq__duration", type=pa.list_(pa.float32())), pa.field(name="label", type=pa.int32()), ] feature_cfgs = [ @@ -720,15 +715,27 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): sequence_delim=";", features=[ feature_pb2.SeqFeatureConfig( - lookup_feature=feature_pb2.LookupFeature( - feature_name="lookup_c", - map="item:cat_map", - key="item:cat_key", - sequence_fields=["cat_key"], + id_feature=feature_pb2.IdFeature( + feature_name="item_id", + expression="item:item_id", + num_buckets=200, + embedding_dim=8, + ) + ), + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="cat_id", + expression="item:cat_id", num_buckets=10, embedding_dim=8, ) ), + feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="duration", + expression="item:duration", + ) + ), ], ) ), @@ -736,7 +743,7 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): features = create_features( feature_cfgs, fg_mode=data_pb2.FgMode.FG_NORMAL, - neg_fields=["cat_key"], + neg_fields=["item_id", "cat_id", "duration"], force_base_data_group=True, ) dataset = _TestDataset( @@ -748,8 +755,9 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): negative_sampler=sampler_pb2.NegativeSampler( input_path=f.name, num_sample=4, - attr_fields=["cat_key"], # bare; prefix-rewritten at launch - 
item_id_field="click_seq__cat_key", + # bare names; prefix-rewritten at launch + attr_fields=["item_id", "cat_id", "duration"], + item_id_field="click_seq__item_id", ), force_base_data_group=True, ), @@ -764,15 +772,17 @@ def test_launch_sampler_cluster_grouped_sequence_strip_and_rewrite(self): dataset.launch_sampler_cluster(2) # Deep-copy guard: data_config not mutated by the rewrite. self.assertEqual( - list(dataset._data_config.negative_sampler.attr_fields), ["cat_key"] - ) - self.assertIn("click_seq__cat_key", dataset._sampler._attr_names) - self.assertNotIn("cat_key", dataset._sampler._attr_names) - cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") - self.assertEqual( - dataset._sampler._attr_types[cat_key_idx], pa.list_(pa.int64()) + list(dataset._data_config.negative_sampler.attr_fields), + ["item_id", "cat_id", "duration"], ) - self.assertNotIn("cat_map", dataset._sampler._attr_names) + # Each bare entry prefix-rewritten and outer-list stripped. + for qualified, expected_type in [ + ("click_seq__item_id", pa.int64()), + ("click_seq__cat_id", pa.int64()), + ("click_seq__duration", pa.float32()), + ]: + idx = dataset._sampler._attr_names.index(qualified) + self.assertEqual(dataset._sampler._attr_types[idx], expected_type) def test_dataset_with_sample_mask(self): input_fields = [ From 8f6f4fb1e659dc7a7189b54bff5b7bd8aff412b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 12:06:42 +0800 Subject: [PATCH 8/8] [refactor] feature: grouped_sequence_prefix property; dataset: assert -> raise ValueError Two PR-review fixes on top of the Round 3 prefix collapse: 1. BaseDataset's `assert feature.is_grouped_sequence, ...` is user- facing config validation, not an internal invariant. 
Under `python -O`/`-OO` the assert is stripped and the next line -- `feature.sequence_name + feature._underline` -- runs unconditionally; for the rejected top-level sequence_id_feature shape `sequence_name` is None (feature.py:427), producing `TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'`. Convert to `raise ValueError` matching the explicit-raise pattern at dataset.py:260. 2. `feature._underline` (feature.py:426) is a private attribute; its prior call sites at feature.py:473, 749, 787 were all in-class. The Round 3 cross-module read from dataset.py is the first one crossing the underscore boundary. Add a public `grouped_sequence_prefix` property on BaseFeature next to `is_grouped_sequence`: @property def grouped_sequence_prefix(self) -> str: return f"{self.sequence_name}{self._underline}" if self._is_grouped_seq else "" Empty-string for non-grouped features so callers can use it unconditionally. Migrate the dataset.py read AND the three in-class call sites: * feature.py:469-475 -- `name` property: the if/else for grouped vs not collapses to one line via the property's "" fallback. * feature.py:749 -- _is_sequence_input: prefix construction. * feature.py:787 -- side_inputs: seq_prefix construction. The explicit raise from (1) stays as the source of truth for config validation (the property's "" fallback only protects against the None+str TypeError; without the explicit raise a non-grouped item_id would silently flow through with empty prefix + set delim, producing a half-configured sequence path that fails further downstream). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 15 ++++++++------- tzrec/features/feature.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index a56fc1e7..e589cfee 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -173,14 +173,15 @@ def __init__( for feature in features: if self._sampler_item_id_field not in feature.sequence_input_names: continue - assert feature.is_grouped_sequence, ( - f"item_id_field '{self._sampler_item_id_field}' is a " - f"sequence input but its matching feature is not a " - f"grouped sequence; only grouped sequence sub-features " - f"are supported as item_id_field." - ) + if not feature.is_grouped_sequence: + raise ValueError( + f"item_id_field '{self._sampler_item_id_field}' is a " + "sequence input but its matching feature is not a " + "grouped sequence; only grouped sequence sub-features " + "are supported as item_id_field." 
+ ) self._sampler_seq_delim = feature.sequence_delim - self._sampler_seq_prefix = feature.sequence_name + feature._underline + self._sampler_seq_prefix = feature.grouped_sequence_prefix break self._fg_mode = data_config.fg_mode diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 55cc587f..bd5bb416 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -469,10 +469,7 @@ def __init__( @property def name(self) -> str: """Feature name.""" - if self._is_grouped_seq: - return f"{self.sequence_name}{self._underline}{self.config.feature_name}" - else: - return self.config.feature_name + return f"{self.grouped_sequence_prefix}{self.config.feature_name}" @property def is_neg(self) -> bool: @@ -568,6 +565,11 @@ def is_grouped_sequence(self) -> bool: """Feature is grouped sequence or not.""" return self._is_grouped_seq + @property + def grouped_sequence_prefix(self) -> str: + """Flatten prefix ``<_underline>``; empty if not grouped.""" + return f"{self.sequence_name}{self._underline}" if self._is_grouped_seq else "" + @property def is_weighted(self) -> bool: """Feature is weighted id feature or not.""" @@ -746,7 +748,7 @@ def _is_sequence_input(self, side: str, name: str) -> bool: if not self.is_sequence: return False if self._is_grouped_seq and self.sequence_name: - prefix = f"{self.sequence_name}{self._underline}" + prefix = self.grouped_sequence_prefix if name.startswith(prefix): name = name[len(prefix) :] if ( @@ -784,7 +786,7 @@ def side_inputs(self) -> List[Tuple[str, str]]: ) side, name = x[0], x[1] seq_prefix = ( - f"{self.sequence_name}{self._underline}" + self.grouped_sequence_prefix if self._need_seq_prefix(side, name) else "" )