From 1a8b7d1a59e4d1e4e01931f4bb7fbc3f5c1ab107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 11:40:26 +0800 Subject: [PATCH 01/16] [fix] sampler: resolve bare attr_fields against item_id_field's sequence prefix HSTUMatch is the only model whose candidate side is a grouped sequence sub-feature; PR #506 papered over this by writing `attr_fields: "cand_seq__video_id"` in the test config and docs -- the *flattened parquet column name*. Every other sampler config uses the bare sub-feature name (`attr_fields: "item_id"` in DSSM/MIND/ TDM) because their candidate is a top-level feature where bare-name == flattened-name. The HSTUMatch outlier leaks DataParser's `{sequence_name}__{sub_feature}` flattening convention into user-facing config. Resolve at the dataset boundary: when `item_id_field` carries a qualified `{sequence_name}__{sub_feature}` form, derive the prefix and rewrite any bare `attr_fields` entries to the flattened form before constructing the sampler. Deep-copy the sampler sub-message so the original `data_config` is not mutated. The sampler then sees fully-qualified names just like today; `_valid_attr_names`, `_attr_types`, the sampled output dict, and `_merge_sampled_features` all work unchanged. Literal matches win, so already-qualified configs continue to work, and DSSM/MIND/TDM (where `item_id_field` is bare) skip the resolution branch entirely. Test config + doc switch to: attr_fields: "video_id" (bare sub-feature name) item_id_field: "cand_seq__video_id" (qualified; doubles as the sequence_name source) Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/source/models/hstu_match.md | 4 +++- tzrec/datasets/dataset.py | 26 ++++++++++++++++++++- tzrec/tests/configs/hstu_kuairand_1k.config | 2 +- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/source/models/hstu_match.md b/docs/source/models/hstu_match.md index 1ec1c6cf..6e444de7 100644 --- a/docs/source/models/hstu_match.md +++ b/docs/source/models/hstu_match.md @@ -16,7 +16,7 @@ data_config { negative_sampler { input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1" num_sample: 128 - attr_fields: "cand_seq__video_id" + attr_fields: "video_id" item_id_field: "cand_seq__video_id" attr_delimiter: "\t" } @@ -211,6 +211,8 @@ model_config { - data_config: 数据配置,其中需要配置负采样 Sampler,负采样 Sampler 的配置详见 [DSSM](dssm.md) 文档中的**负采样配置**章节 + - HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 写为带 sequence_name 的全限定名 (例如 `cand_seq__video_id`),`attr_fields` 写为不带前缀的子特征名 (例如 `video_id`);dataset 层会根据 `item_id_field` 的前缀自动把 `attr_fields` 补齐为 `cand_seq__video_id` 后传给采样器 + - feature_groups: 特征组 - uih: 用户历史行为序列,可增加 side info;类型为 JAGGED_SEQUENCE,**必填** diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 21459fea..c77ae7a3 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import os import random from collections import OrderedDict @@ -209,7 +210,9 @@ def launch_sampler_cluster( """Launch sampler cluster and server.""" if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: sampler_type = self._data_config.WhichOneof("sampler") - sampler_config = getattr(self._data_config, sampler_type) + # Deep-copy so any in-place rewrites below don't mutate + # `self._data_config`'s sampler sub-message. + sampler_config = copy.deepcopy(getattr(self._data_config, sampler_type)) # Multi-positive sampling: when the sampler's item_id_field is # itself a sequence-positive train column, the per-row outer list @@ -233,6 +236,27 @@ def launch_sampler_cluster( for f in self.input_fields ] + # Resolve bare candidate sub-feature names in `attr_fields` against + # `sampler_fields` (flattened parquet schema), using the prefix + # carried by a qualified `item_id_field` (e.g. "cand_seq__video_id" + # -> "cand_seq__"). When `item_id_field` is itself bare + # (DSSM/MIND/TDM top-level case), `seq_prefix` is empty and the + # loop is a no-op -- existing configs are unchanged. + if hasattr(sampler_config, "item_id_field") and sampler_config.HasField( + "item_id_field" + ): + id_field = sampler_config.item_id_field + field_names = {f.name for f in sampler_fields} + if "__" in id_field and id_field in field_names: + seq_prefix = id_field.split("__", 1)[0] + "__" + if hasattr(sampler_config, "attr_fields"): + sampler_config.attr_fields[:] = [ + seq_prefix + a + if a not in field_names and seq_prefix + a in field_names + else a + for a in sampler_config.attr_fields + ] + # pyre-ignore [16] self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)( sampler_config, diff --git a/tzrec/tests/configs/hstu_kuairand_1k.config b/tzrec/tests/configs/hstu_kuairand_1k.config index 31df5d47..89d99987 100644 --- a/tzrec/tests/configs/hstu_kuairand_1k.config +++ b/tzrec/tests/configs/hstu_kuairand_1k.config @@ -33,7 +33,7 @@ data_config { negative_sampler { input_path: "data/test/kuairand-1k-match-item-gl.txt" num_sample: 128 - attr_fields: "cand_seq__video_id" + attr_fields: "video_id" item_id_field: "cand_seq__video_id" attr_delimiter: "\t" } From aa76da2b2b8a375c80683456d0a020bb16ac06c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 11:43:53 +0800 Subject: [PATCH 02/16] [feat] features: project_grouped_sequence_feature_to_scalar helper for export view MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a module-level helper that projects a grouped sequence sub-feature (e.g. `cand_seq__video_id`) into a top-level scalar `FeatureConfig` suitable for `create_features()` to construct as a non-sequence feature named `video_id`. The grouped sub-feature carries `SeqFeatureConfig`; the helper rewraps the contained oneof message (id_feature / raw_feature / combo_feature / ... — generic across the entire oneof) as a fresh `FeatureConfig`. It also materializes the source feature's effective `default_value` and `value_dim` onto the scalar proto: a sequence sub-feature with no explicit `default_value` resolves to `"0"` and `value_dim` to `1` (feature.py:556, 515-517), but a scalar with no explicit values would fall back to `""` / `0`. Without materialization the exported scalar feature would drift from the training semantics. Used by the HSTUMatchItemTower scalar export view (next commit) to swap from `cand_seq__video_id` (jagged) to `video_id` (scalar) without mutating the training feature objects. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/features/feature.py | 40 ++++++++++++++++ tzrec/features/feature_test.py | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+) diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 55cc587f..7f5ef843 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -1228,6 +1228,46 @@ def create_features( return features +def project_grouped_sequence_feature_to_scalar( + feature: BaseFeature, +) -> feature_pb2.FeatureConfig: + """Return a scalar export FeatureConfig for a grouped sequence sub-feature. + + Materializes sequence-effective behaviour (default_value, value_dim) into + the scalar proto so the exported scalar feature semantically matches the + training sub-feature -- without this, defaults drift from "0" / 1 to + "" / 0 because `is_sequence=False` resolves differently. The grouped + sub-feature's config is a `SeqFeatureConfig`; rewrap it as a top-level + `FeatureConfig` so `create_features` builds it as a scalar feature. + + Args: + feature: a grouped sequence sub-feature. + + Returns: + a fresh FeatureConfig suitable for `create_features()` to construct + as a top-level scalar feature. + """ + if not feature.is_grouped_sequence: + raise ValueError( + "project_grouped_sequence_feature_to_scalar only accepts grouped " + f"sequence sub-features; got {feature.name} " + "(is_grouped_sequence=False)" + ) + src_cfg = feature.feature_config # SeqFeatureConfig + feat_type = src_cfg.WhichOneof("feature") + src_msg = getattr(src_cfg, feat_type) + + scalar_cfg = feature_pb2.FeatureConfig() + dst_msg = getattr(scalar_cfg, feat_type) + dst_msg.CopyFrom(src_msg) + + if hasattr(dst_msg, "default_value") and not dst_msg.default_value: + dst_msg.default_value = feature.default_value + if hasattr(dst_msg, "value_dim") and not dst_msg.HasField("value_dim"): + dst_msg.value_dim = feature.value_dim + return scalar_cfg + + def _copy_assets( feature: BaseFeature, asset_dir: Optional[str] = None, diff --git a/tzrec/features/feature_test.py b/tzrec/features/feature_test.py index 72267897..9137c009 100644 --- a/tzrec/features/feature_test.py +++ b/tzrec/features/feature_test.py @@ -750,5 +750,90 @@ def test_sequence_input_names(self, fg_mode): ) +class ProjectGroupedSequenceFeatureToScalarTest(unittest.TestCase): + def _build_grouped(self, seq_sub_cfg): + feature_cfgs = [ + feature_pb2.FeatureConfig( + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", + sequence_delim="|", + sequence_length=100, + features=[seq_sub_cfg], + ) + ), + ] + return feature_lib.create_features(feature_cfgs) + + def test_id_feature_projection_materializes_seq_defaults(self): + sub_cfg = feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + expression="item:video_id", + embedding_dim=32, + num_buckets=10000000, + ) + ) + features = self._build_grouped(sub_cfg) + self.assertEqual(len(features), 1) + sub_feature = features[0] + self.assertTrue(sub_feature.is_grouped_sequence) + # Sequence-effective defaults on the source feature. + self.assertEqual(sub_feature.default_value, "0") + self.assertEqual(sub_feature.value_dim, 1) + + scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(sub_feature) + self.assertEqual(scalar_cfg.WhichOneof("feature"), "id_feature") + # default_value / value_dim materialized into the scalar proto. + self.assertEqual(scalar_cfg.id_feature.default_value, "0") + self.assertTrue(scalar_cfg.id_feature.HasField("value_dim")) + self.assertEqual(scalar_cfg.id_feature.value_dim, 1) + # Original sub-feature proto is not mutated. + self.assertEqual(sub_feature.feature_config.id_feature.default_value, "") + self.assertFalse(sub_feature.feature_config.id_feature.HasField("value_dim")) + + def test_projection_passes_through_create_features_as_scalar(self): + sub_cfg = feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + expression="item:video_id", + embedding_dim=32, + num_buckets=10000000, + ) + ) + features = self._build_grouped(sub_cfg) + scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) + scalar_features = feature_lib.create_features([scalar_cfg]) + self.assertEqual(len(scalar_features), 1) + scalar = scalar_features[0] + # Bare sub-feature name without the cand_seq__ prefix. + self.assertEqual(scalar.name, "video_id") + self.assertFalse(scalar.is_grouped_sequence) + # Scalar context: is_sequence is False so value_dim default would be + # 0, but our projection materialized it from the source. + self.assertEqual(scalar.value_dim, 1) + + def test_projection_rejects_non_grouped_feature(self): + feature_cfgs = [ + feature_pb2.FeatureConfig( + id_feature=feature_pb2.IdFeature(feature_name="user_id") + ), + ] + features = feature_lib.create_features(feature_cfgs) + with self.assertRaisesRegex(ValueError, "is_grouped_sequence=False"): + feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) + + def test_raw_feature_projection_generic_oneof(self): + # Confirms the helper is not hard-coded to id_feature; raw_feature + # works the same way. + sub_cfg = feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="watch_time", expression="user:watch_time" + ) + ) + features = self._build_grouped(sub_cfg) + scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) + self.assertEqual(scalar_cfg.WhichOneof("feature"), "raw_feature") + + if __name__ == "__main__": unittest.main() From 7a9318ff4ac10c4d3125dbaae0a1be81c6ae5d8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 11:49:17 +0800 Subject: [PATCH 03/16] [feat] HSTUMatchItemTower: scalar export view via set_is_inference(True) Make the item tower own its training/export view switch: - `__init__` introduces `self._cand_key = "{group_name}.sequence"` (training view). `forward()` reads `grouped_features[self._cand_key]` instead of the hard-coded `.sequence`. - `set_is_inference(True)` projects each grouped sequence sub-feature into a top-level scalar `FeatureConfig` via the helper from the previous commit, rebuilds `self._features` as scalar features, swaps `self._feature_groups` to a single `JAGGED_SEQUENCE` group over the scalar names (so `EmbeddingGroup` emits `{group_name}.query` per row instead of jagged-per-row), and flips `self._cand_key` to `.query`. Idempotent: re-entering after the swap is a no-op. `set_is_inference(False)` does not rebuild the training view -- callers must export from a `copy.deepcopy` of the item tower so the training tower is preserved. In `tzrec/utils/export_util.py::export_model_normal`, save the wrapper's `model._feature_groups` onto the saved `pipeline.config.model_config.feature_groups`. Today the saved config keeps the original (training-shape) feature_groups even when the exported `feature_configs` have been rewritten -- after the HSTUMatchItemTower view swap, this mismatch shows up as scalar `feature_configs` paired with stale `cand_seq__video_id` group names. Non-match towers expose the same `_feature_groups` they had before, so DSSM/MIND/etc. behaviour is unchanged. `TowerWoEGWrapper` (match_model.py:481-486) already reads `module._features` / `module._feature_groups` -- no wrapper changes required. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/models/hstu.py | 65 +++++++++++++++++++++++++++++++++++--- tzrec/utils/export_util.py | 8 +++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index ae4be071..1b8ad4f0 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -15,8 +15,17 @@ import torch.nn.functional as F from torch import nn -from tzrec.datasets.utils import CAND_POS_LENGTHS, HARD_NEG_INDICES, Batch -from tzrec.features.feature import BaseFeature +from tzrec.datasets.utils import ( + BASE_DATA_GROUP, + CAND_POS_LENGTHS, + HARD_NEG_INDICES, + Batch, +) +from tzrec.features.feature import ( + BaseFeature, + create_features, + project_grouped_sequence_feature_to_scalar, +) from tzrec.models.match_model import MatchModel, MatchTowerWoEG from tzrec.modules.embedding import EmbeddingGroup from tzrec.modules.gr.hstu_transducer import HSTUMatchEncoder @@ -151,7 +160,13 @@ def __init__( # tower_config.input names on the user-tower proto). Use the item-side # tower_config.input here, which equals feature_groups[0].group_name. self._group_name = tower_config.input - candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence") + # Training view: candidate group is JAGGED_SEQUENCE; embedding_group + # emits the per-row jagged tensor at `{group_name}.sequence`. The + # export-time scalar view (see `set_is_inference(True)`) swaps this + # to `{group_name}.query` -- one row per item, no positive-set + # grouping container. + self._cand_key = f"{self._group_name}.sequence" + candidate_dims = embedding_group.group_dims(self._cand_key) candidate_total_dim = sum(candidate_dims) if tower_config.HasField("mlp"): self.mlp: torch.nn.Module = MLP( @@ -166,6 +181,48 @@ def __init__( if self._output_dim > 0: self.output = nn.Linear(mlp_out_dim, output_dim) + def set_is_inference(self, is_inference: bool) -> None: + """Switch the tower into the scalar item-export view (one-way). + + Training stays at `candidate.sequence` (jagged sub-feature). Export + copies the tower (`copy.deepcopy`), calls this with `True`, and the + wrapper picks up the swapped `_features` / `_feature_groups` / + `_cand_key` for serving. Idempotent: re-entering after the swap is a + no-op. `set_is_inference(False)` does NOT rebuild the training view; + callers must export from a copy so the original training tower is + preserved. + """ + super().set_is_inference(is_inference) + if not is_inference: + return + if self._cand_key.endswith(".query"): + return # already swapped (idempotent) + + # Project each grouped sequence sub-feature into a scalar export + # FeatureConfig, then rebuild as top-level scalar features. + scalar_configs = [ + project_grouped_sequence_feature_to_scalar(f) for f in self._features + ] + source = self._features[0] + scalar_features = create_features( + scalar_configs, + fg_mode=source.fg_mode, + neg_fields=None, + fg_encoded_multival_sep=source._fg_encoded_multival_sep, + force_base_data_group=any( + f.data_group == BASE_DATA_GROUP for f in self._features + ), + ) + self._features = scalar_features + self._feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name=self._group_name, + feature_names=[f.name for f in scalar_features], + group_type=model_pb2.JAGGED_SEQUENCE, + ) + ] + self._cand_key = f"{self._group_name}.query" + def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: """Forward the item tower. @@ -175,7 +232,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: Returns: item embeddings of shape (sum_candidates, D). """ - cand_emb = grouped_features[f"{self._group_name}.sequence"] + cand_emb = grouped_features[self._cand_key] item_emb = self.mlp(cand_emb) if self._output_dim > 0: item_emb = self.output(item_emb) diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index be786f9a..aebeb1c5 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -281,6 +281,14 @@ def export_model_normal( pipeline_config = copy.copy(pipeline_config) pipeline_config.ClearField("feature_configs") pipeline_config.feature_configs.extend(feature_configs) + # Towers that own a view-specific feature_groups (e.g. + # `HSTUMatchItemTower` after `set_is_inference(True)` swaps to the + # scalar item view) must save those groups too, otherwise the + # exported pipeline.config pairs scalar feature_configs with stale + # training feature_group names. + if hasattr(model, "_feature_groups"): + pipeline_config.model_config.ClearField("feature_groups") + pipeline_config.model_config.feature_groups.extend(model._feature_groups) config_util.save_message( pipeline_config, os.path.join(save_dir, "pipeline.config") ) From 0c1dcf1d8fc0933ef8bd0d519a38d73a0648e3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 13:06:39 +0800 Subject: [PATCH 04/16] [feat] HSTUMatch item-tower scalar export view + item_input_path routing Wires the HSTUMatchItemTower scalar export view (`candidate.query` per row over scalar `video_id` instead of the training `candidate.sequence` over jagged `cand_seq__video_id`) into the export pipeline. End-to-end: * `tzrec/protos/pipeline.proto`: add `optional string item_input_path = 10` on `EasyRecConfig`. Item-only table (one row per item, scalar export-view schema) for recall-model item-tower export. * `tzrec/utils/export_util.py::export_model_normal` (and the `_aot` variant): when `model._tower_name == "item_tower"` and `pipeline_config.item_input_path` is set, point the predict dataloader at it instead of `train_input_path`, and clear `data_config.sampler` so the predict path doesn't launch GraphLearn for item-only rows. Also save `model._feature_groups` onto `pipeline_config.model_config.feature_groups` so the exported config is internally consistent (scalar feature_configs paired with scalar feature_groups, not stale training groups). * `tzrec/models/model.py::ScriptWrapper.__init__`: propagate `_tower_name` from the inner module so `export_util.py` can see it through the wrapper. * `tzrec/main.py::export`: for match-model item towers, build the wrapper around `copy.deepcopy(module).set_is_inference(True)` so the view flip happens BEFORE `TowerWoEGWrapper` constructs its `EmbeddingGroup`. The original training tower is untouched. * `tzrec/models/hstu.py::HSTUMatchItemTower.set_is_inference`: drop the `super().set_is_inference()` call -- `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`. The `_is_inference` attribute is propagated to sub-modules separately by `ScriptWrapper.set_is_inference` via `recursive_setattr` during export. * `tzrec/features/feature.py::project_grouped_sequence_feature_to_scalar`: drop the materialization of sequence-effective `default_value` / `value_dim` onto the scalar proto. The scalar-mode defaults (empty / 0) are intentional: at item-export predict time the parquet provides one value per row, and PyFG's `id_feature` operator expects the scalar (not sequence) input shape. * `tzrec/tests/configs/hstu_kuairand_1k.config`: set `item_input_path: "data/test/kuairand-1k-match-item-c1.parquet"`. * `tzrec/tests/match_integration_test.py::test_hstu_with_fg_train_eval`: extend the body to also exercise `test_export` (AOT export with `ENABLE_AOT=1` + `DISABLE_MMA_V3=1` -- same pattern as dlrm_hstu), `test_predict` on the item tower (reading the new `kuairand-1k-match-item-c1.parquet`, emitting `item_tower_emb`), and `test_predict` on the user tower (reading the eval parquet for the sequence-shape user-side path). Assert both `export/{user,item}/scripted_sparse_model.pt` exist. The new test fixture `data/test/kuairand-1k-match-item-c1.parquet` (1000 rows, single `video_id: int64` column; md5 8dcadabdc3e9049ed9c2250565b4b134) is built locally by `experiments/hstu_match/build_kuairand_fixtures.py`; the `ci_data.sh` wget line will land after the user uploads it to OSS. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/features/feature.py | 17 +++----- tzrec/features/feature_test.py | 18 ++++---- tzrec/main.py | 11 ++++- tzrec/models/hstu.py | 7 ++- tzrec/models/model.py | 5 +++ tzrec/protos/pipeline.proto | 9 ++++ tzrec/tests/configs/hstu_kuairand_1k.config | 1 + tzrec/tests/match_integration_test.py | 48 ++++++++++++++++++++- tzrec/utils/export_util.py | 24 ++++++++--- 9 files changed, 112 insertions(+), 28 deletions(-) diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 7f5ef843..84f7fcfc 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -1233,12 +1233,12 @@ def project_grouped_sequence_feature_to_scalar( ) -> feature_pb2.FeatureConfig: """Return a scalar export FeatureConfig for a grouped sequence sub-feature. - Materializes sequence-effective behaviour (default_value, value_dim) into - the scalar proto so the exported scalar feature semantically matches the - training sub-feature -- without this, defaults drift from "0" / 1 to - "" / 0 because `is_sequence=False` resolves differently. The grouped - sub-feature's config is a `SeqFeatureConfig`; rewrap it as a top-level - `FeatureConfig` so `create_features` builds it as a scalar feature. + The grouped sub-feature's config is a `SeqFeatureConfig`; rewrap it as a + top-level `FeatureConfig` so `create_features` builds it as a scalar + feature. The scalar-mode defaults (empty `default_value`, dynamic + `value_dim=0` for id_feature) are intentional: at item-export predict + time the parquet provides one value per row, so the sequence-effective + defaults (`"0"` / `1`) don't apply. Args: feature: a grouped sequence sub-feature. @@ -1260,11 +1260,6 @@ def project_grouped_sequence_feature_to_scalar( scalar_cfg = feature_pb2.FeatureConfig() dst_msg = getattr(scalar_cfg, feat_type) dst_msg.CopyFrom(src_msg) - - if hasattr(dst_msg, "default_value") and not dst_msg.default_value: - dst_msg.default_value = feature.default_value - if hasattr(dst_msg, "value_dim") and not dst_msg.HasField("value_dim"): - dst_msg.value_dim = feature.value_dim return scalar_cfg diff --git a/tzrec/features/feature_test.py b/tzrec/features/feature_test.py index 9137c009..93b97baa 100644 --- a/tzrec/features/feature_test.py +++ b/tzrec/features/feature_test.py @@ -764,7 +764,7 @@ def _build_grouped(self, seq_sub_cfg): ] return feature_lib.create_features(feature_cfgs) - def test_id_feature_projection_materializes_seq_defaults(self): + def test_id_feature_projection_preserves_scalar_defaults(self): sub_cfg = feature_pb2.SeqFeatureConfig( id_feature=feature_pb2.IdFeature( feature_name="video_id", @@ -783,10 +783,12 @@ def test_id_feature_projection_materializes_seq_defaults(self): scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(sub_feature) self.assertEqual(scalar_cfg.WhichOneof("feature"), "id_feature") - # default_value / value_dim materialized into the scalar proto. - self.assertEqual(scalar_cfg.id_feature.default_value, "0") - self.assertTrue(scalar_cfg.id_feature.HasField("value_dim")) - self.assertEqual(scalar_cfg.id_feature.value_dim, 1) + # Scalar-mode defaults are intentional: predict-time parquet provides + # one value per row, so the sequence-effective "0" / 1 defaults are + # not carried into the scalar proto -- FG handles the value-per-row + # input via the scalar id_feature path. + self.assertEqual(scalar_cfg.id_feature.default_value, "") + self.assertFalse(scalar_cfg.id_feature.HasField("value_dim")) # Original sub-feature proto is not mutated. self.assertEqual(sub_feature.feature_config.id_feature.default_value, "") self.assertFalse(sub_feature.feature_config.id_feature.HasField("value_dim")) @@ -808,9 +810,9 @@ def test_projection_passes_through_create_features_as_scalar(self): # Bare sub-feature name without the cand_seq__ prefix. self.assertEqual(scalar.name, "video_id") self.assertFalse(scalar.is_grouped_sequence) - # Scalar context: is_sequence is False so value_dim default would be - # 0, but our projection materialized it from the source. - self.assertEqual(scalar.value_dim, 1) + # Scalar context: value_dim resolves to 0 (variable length) via + # the non-sequence path in IdFeature.value_dim. + self.assertEqual(scalar.value_dim, 0) def test_projection_rejects_non_grouped_feature(self): feature_cfgs = [ diff --git a/tzrec/main.py b/tzrec/main.py index 45e62f04..1dfa5302 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -957,7 +957,16 @@ def export( wrapper = ( TowerWrapper if isinstance(module, MatchTower) else TowerWoEGWrapper ) - tower = InferWrapper(wrapper(module, name)) + # Towers that own a view switch (e.g. `HSTUMatchItemTower` + # flipping from training `candidate.sequence` to scalar + # `candidate.query`) need to flip BEFORE the wrapper rebuilds + # `EmbeddingGroup(module._features, module._feature_groups)`. + # Deep-copy so the original training tower is preserved. + module_for_export = module + if hasattr(module, "set_is_inference") and name == "item_tower": + module_for_export = copy.deepcopy(module) + module_for_export.set_is_inference(True) + tower = InferWrapper(wrapper(module_for_export, name)) tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) export_model( ori_pipeline_config, diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 1b8ad4f0..b79c4533 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -191,8 +191,13 @@ def set_is_inference(self, is_inference: bool) -> None: no-op. `set_is_inference(False)` does NOT rebuild the training view; callers must export from a copy so the original training tower is preserved. + + Note: `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, + so this method doesn't call `super().set_is_inference()`. The + `_is_inference` attribute is propagated to sub-modules separately by + `ScriptWrapper.set_is_inference()` via `recursive_setattr` during + export. """ - super().set_is_inference(is_inference) if not is_inference: return if self._cand_key.endswith(".query"): diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 92ce5485..f8e62fc4 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -338,6 +338,11 @@ def __init__(self, module: nn.Module) -> None: self.model = module self._features = self.model._features self._feature_groups = self.model._feature_groups + # Propagate tower identity (set by TowerWoEGWrapper / TowerWrapper) + # so export_util.py can route item-tower export through + # `pipeline_config.item_input_path` instead of `train_input_path`. + if hasattr(self.model, "_tower_name"): + self._tower_name = self.model._tower_name self._data_parser = DataParser( self._features, sampler_type=str(module.sampler_type) diff --git a/tzrec/protos/pipeline.proto b/tzrec/protos/pipeline.proto index cd86dfb3..327ba357 100644 --- a/tzrec/protos/pipeline.proto +++ b/tzrec/protos/pipeline.proto @@ -26,4 +26,13 @@ message EasyRecConfig { repeated FeatureConfig feature_configs = 8; optional ModelConfig model_config = 9; + + // Optional item-only table for recall-model item-tower export. + // When set, item-tower export reads its sample batch from this + // path (one row per item, schema matching the scalar export view) + // instead of from `train_input_path`. The item table is typically + // already prepared upstream (used by `negative_sampler.input_path` + // too) -- this avoids feeding the export path with training-shape + // sequence rows. + optional string item_input_path = 10; } diff --git a/tzrec/tests/configs/hstu_kuairand_1k.config b/tzrec/tests/configs/hstu_kuairand_1k.config index 89d99987..706e8af1 100644 --- a/tzrec/tests/configs/hstu_kuairand_1k.config +++ b/tzrec/tests/configs/hstu_kuairand_1k.config @@ -1,5 +1,6 @@ train_input_path: "data/test/kuairand-1k-match-train-c4096-s100.parquet" eval_input_path: "data/test/kuairand-1k-match-eval-c4096-s100.parquet" +item_input_path: "data/test/kuairand-1k-match-item-c1.parquet" model_dir: "experiments/kuairand/hstu_match" train_config { sparse_optimizer { diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 8764b1d9..e2159de8 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -370,17 +370,63 @@ def test_mind_train_eval_export(self): @unittest.skipIf(*gpu_unavailable) def test_hstu_with_fg_train_eval(self): + # DISABLE_MMA_V3=1: Triton 3.6 sm_90 WGMMA bug. ENABLE_AOT=1: HSTU + # uses TRITON kernels which require CUDA; AOT export keeps the + # forward on CUDA. Same pattern as dlrm_hstu's export test. + hstu_env = "DISABLE_MMA_V3=1" self.success = utils.test_train_eval( "tzrec/tests/configs/hstu_kuairand_1k.config", self.test_dir, user_id="user_id", item_id="item_id", + env_str=hstu_env, ) if self.success: self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + env_str=hstu_env, + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + env_str=f"{hstu_env} ENABLE_AOT=1", + ) + if self.success: + # Item tower scalar export view: predict over the item-only + # parquet (one row per video_id) and emit item embeddings. + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path="data/test/kuairand-1k-match-item-c1.parquet", + predict_output_path=os.path.join(self.test_dir, "item_emb"), + reserved_columns="video_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + # User tower keeps the training-shape sequence view: predict + # over the eval parquet (which carries the user-side columns). + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/user"), + predict_input_path=( + "data/test/kuairand-1k-match-eval-c4096-s100.parquet" + ), + predict_output_path=os.path.join(self.test_dir, "user_emb"), + reserved_columns="user_id", + output_columns="user_tower_emb", + test_dir=self.test_dir, ) self.assertTrue(self.success) + for side in ("user", "item"): + self.assertTrue( + os.path.exists( + os.path.join( + self.test_dir, f"export/{side}/scripted_sparse_model.pt" + ) + ), + f"missing AOT scripted sparse model for {side} tower", + ) if __name__ == "__main__": diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index aebeb1c5..bf7ad524 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -168,9 +168,17 @@ def export_model_normal( data_config.batch_size = min(data_config.batch_size, max_batch_size) logger.info("using new batch_size: %s in export", data_config.batch_size) data_config.num_workers = 1 - dataloader = create_dataloader( - data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT - ) + # Item-tower export: read the sample batch from `item_input_path` + # (one row per item, schema matching the scalar export view) instead + # of `train_input_path` (which holds training-shape sequence rows). + # Also clear the sampler so the predict dataloader doesn't try to + # launch a sampler for item-only rows. + is_item_tower = getattr(model, "_tower_name", None) == "item_tower" + input_path = pipeline_config.train_input_path + if is_item_tower and pipeline_config.HasField("item_input_path"): + input_path = pipeline_config.item_input_path + data_config.ClearField("sampler") + dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) ckpt_param_map_path = None if checkpoint_path: @@ -722,9 +730,13 @@ def _all_keys_used_once( features = cast(List[BaseFeature], model._features) data_config.num_workers = 1 data_config.batch_size = acc_utils.get_max_export_batch_size() - dataloader = create_dataloader( - data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT - ) + # Item-tower export: same routing as `export_model_normal`. + is_item_tower = getattr(model, "_tower_name", None) == "item_tower" + input_path = pipeline_config.train_input_path + if is_item_tower and pipeline_config.HasField("item_input_path"): + input_path = pipeline_config.item_input_path + data_config.ClearField("sampler") + dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) batch = next(iter(dataloader)) data = batch.to(device).to_dict(sparse_dtype=torch.int64) From 2da7c297a2b10498f3eecdb778df91863bfaa2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 14:43:49 +0800 Subject: [PATCH 05/16] [ci] add kuairand-1k-match-item-c1.parquet wget for HSTUMatch item-tower scalar export test Follows the previous commit (HSTUMatch item-tower scalar export view + item_input_path routing). 1000-row scalar item parquet; consumed by `test_hstu_with_fg_train_eval` for the item-tower predict step. Co-Authored-By: Claude Opus 4.7 (1M context) --- scripts/ci/ci_data.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/ci/ci_data.sh b/scripts/ci/ci_data.sh index 0d2c1d67..bb9bff66 100644 --- a/scripts/ci/ci_data.sh +++ b/scripts/ci/ci_data.sh @@ -13,3 +13,4 @@ wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-mot-1k-eval-c4 wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-train-c4096-s100-f1892eabc70ae3407afe9ff5bca8cb5f.parquet -O data/test/kuairand-1k-match-train-c4096-s100.parquet wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-eval-c4096-s100-e4ca5e15d157efa723041cd05c127228.parquet -O data/test/kuairand-1k-match-eval-c4096-s100.parquet wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-gl-3d459148303acd9f838da108efcc40e5.txt -O data/test/kuairand-1k-match-item-gl.txt +wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-c1-8dcadabdc3e9049ed9c2250565b4b134.parquet -O data/test/kuairand-1k-match-item-c1.parquet From 4578c7545280da401c85b81b4c8f739db3b8111e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 16:44:03 +0800 Subject: [PATCH 06/16] [refactor] sampler: unify sequence state via feature configs, drop _seq_field_delims MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the global `_seq_field_delims: Dict[str, str]` (every sequence input -> delim) with two candidate-side fields derived from the matching feature config: _sampler_seq_delim -- "" if item_id_field isn't a sequence input; the parent feature's sequence_delim otherwise (covers top-level sequence_id_feature AND grouped sequence sub-features). _sampler_seq_prefix -- "" if item_id_field isn't grouped-flattened; f"{sequence_name}{_underline}" otherwise (RTP-safe; uses feature._underline directly rather than guessing "__" vs "_"). Lookup is by `item_id_field in feature.sequence_input_names`, which returns `[feature.name]` for top-level sequence features and the flattened input names for grouped sub-features -- correctly matching both cases the old `_seq_field_delims` dict covered. Three downstream simplifications: * `launch_sampler_cluster`: collapse the two checks into one `if self._sampler_seq_delim:` gate. The bare-name prefix-resolve runs only when `_sampler_seq_prefix` is set (grouped case). The outer-list strip is scoped to candidate-sequence attrs (`{a for a in attr_fields if a.startswith(seq_prefix)}`) -- top-level sequence (prefix="") matches all attr_fields; grouped sequence picks out just the seq-prefixed subset; non-sequence item-side attrs from the same lookup feature (e.g. `cat_map`) are correctly excluded. * `utils.build_sampler_input`: signature change from `seq_field_delims: Dict[str, str]` to `seq_delim: str` -- callers pass `self._sampler_seq_delim` (empty when not applicable). * `_merge_sampled_features`: drop the per-key dict lookup; single `self._sampler_seq_delim` applies uniformly because `sampled.keys()` is always a subset of `attr_fields` (candidate-side only) after the prefix-resolve step. DSSM/MIND/TDM behaviour unchanged: when item_id_field is a top-level scalar, `_sampler_seq_delim` stays empty and every branch degrades to the pre-refactor scalar path. Tests: * dataset_test.py: replace `_seq_field_delims` membership assertions with `_sampler_seq_delim` / `_sampler_seq_prefix` checks on the same lookup_feature multi-attr-strip case. * utils_test.py: update `build_sampler_input` test kwargs and the empty-delim-passthrough case. Doc cleanup: trim the trailing "dataset 层会..." sentence in `hstu_match.md` (internal mechanics no longer accurate) and reword "带 sequence_name 的全限定名" to "带 sequence_name 的序列前缀的名". Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/source/models/hstu_match.md | 2 +- tzrec/datasets/dataset.py | 122 ++++++++++++++++--------------- tzrec/datasets/dataset_test.py | 20 +++-- tzrec/datasets/utils.py | 16 ++-- tzrec/datasets/utils_test.py | 20 ++--- 5 files changed, 95 insertions(+), 85 deletions(-) diff --git a/docs/source/models/hstu_match.md b/docs/source/models/hstu_match.md index 6e444de7..d977cdb9 100644 --- a/docs/source/models/hstu_match.md +++ b/docs/source/models/hstu_match.md @@ -211,7 +211,7 @@ model_config { - data_config: 数据配置,其中需要配置负采样 Sampler,负采样 Sampler 的配置详见 [DSSM](dssm.md) 文档中的**负采样配置**章节 - - HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 写为带 sequence_name 的全限定名 (例如 `cand_seq__video_id`),`attr_fields` 写为不带前缀的子特征名 (例如 `video_id`);dataset 层会根据 `item_id_field` 的前缀自动把 `attr_fields` 补齐为 `cand_seq__video_id` 后传给采样器 + - HSTUMatch 的候选侧是 `sequence_feature` 的子特征。在 `negative_sampler` 中,`item_id_field` 写为带 sequence_name 的序列前缀的名 (例如 `cand_seq__video_id`),`attr_fields` 写为不带前缀的子特征名 (例如 `video_id`) - feature_groups: 特征组 diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index c77ae7a3..954cd8a7 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -166,27 +166,31 @@ def __init__( ): self._selected_input_names = None - # Map input_name -> sequence_delim for true sequence inputs only - # (via feature.sequence_input_names). Excludes non-sequence - # sub-inputs of grouped sequence_feature. - self._seq_field_delims: Dict[str, str] = {} - for feature in features: - if not feature.sequence_delim: - continue - seq_inputs = set(feature.sequence_input_names) - for input_name in feature.inputs: - if input_name not in seq_inputs: - continue - existing = self._seq_field_delims.get(input_name) - if existing is not None and existing != feature.sequence_delim: - logger.warning( - "Conflicting sequence_delim for input '%s': %r vs %r; " - "latter wins.", - input_name, - existing, - feature.sequence_delim, - ) - self._seq_field_delims[input_name] = feature.sequence_delim + # Candidate-side sequence state. When `item_id_field` is a sequence + # input (top-level `sequence_id_feature` OR grouped sequence + # sub-feature input), `_sampler_seq_delim` is the parent feature's + # `sequence_delim`. For grouped sub-features only, `_sampler_seq_prefix` + # is the flatten prefix `f"{sequence_name}{_underline}"` (used to + # resolve bare `attr_fields` to the qualified parquet column name). + # For top-level `sequence_id_feature` (no flatten), the prefix stays + # empty and the resolve is a no-op. All sourced from the matching + # `BaseFeature` via `sequence_input_names`, the authoritative input. + self._sampler_seq_delim: str = "" + self._sampler_seq_prefix: str = "" + if self._sampler_item_id_field is not None: + for feature in features: + if ( + feature.sequence_delim + and self._sampler_item_id_field in feature.sequence_input_names + ): + self._sampler_seq_delim = feature.sequence_delim + if feature.is_grouped_sequence: + # pyre-ignore [16]: BaseFeature._underline is intentional; + # avoids re-deriving the "_" vs "__" choice ourselves. + self._sampler_seq_prefix = ( + feature.sequence_name + feature._underline + ) + break self._fg_mode = data_config.fg_mode self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep @@ -214,49 +218,52 @@ def launch_sampler_cluster( # `self._data_config`'s sampler sub-message. sampler_config = copy.deepcopy(getattr(self._data_config, sampler_type)) - # Multi-positive sampling: when the sampler's item_id_field is - # itself a sequence-positive train column, the per-row outer list - # on every item-side attr is the positive-grouping container, not - # a multi-value field. Strip the outer list so the sampler sees - # the pool's native scalar storage and _to_arrow_array emits - # scalar negs directly (avoiding the multival_sep split that - # would wrap each scalar in a 1-elem list). sampler_fields = self.input_fields - if ( - self._sampler_item_id_field is not None - and self._sampler_item_id_field in self._seq_field_delims - ): + if self._sampler_seq_delim: + if self._sampler_seq_prefix: + # Grouped sequence: resolve bare candidate sub-feature + # names against the qualified flattened schema using + # the authoritative parent prefix from feature configs + # (RTP-safe, no name string split). Has to run *before* + # the outer-list strip so `consumed` carries resolved + # names. Top-level `sequence_id_feature` skips this -- + # its attr_fields are already bare/qualified-as-itself. + field_names = {f.name for f in sampler_fields} + sampler_config.attr_fields[:] = [ + self._sampler_seq_prefix + a + if a not in field_names + and self._sampler_seq_prefix + a in field_names + else a + for a in sampler_config.attr_fields + ] + + # Multi-positive sampling: when the sampler's item_id_field + # is itself a sequence-positive train column, the per-row + # outer list on the candidate sequence's item-side attrs is + # the positive-grouping container. Strip the outer list + # only for the candidate-sequence attrs in `attr_fields` + # (filtered by `_sampler_seq_prefix`; top-level sequence + # has prefix="" so all attr_fields match). Excluded: + # top-level item-side attrs from the same lookup feature + # whose outer list is multi-value (e.g. `cat_map`); other + # grouped sequences' sub-features (e.g. uih_seq__*); and + # the `item_id_field` itself -- the sampler never inspects + # their type. + consumed = { + a + for a in sampler_config.attr_fields + if a.startswith(self._sampler_seq_prefix) + } sampler_fields = [ pa.field(f.name, f.type.value_type) if ( - f.name in self._seq_field_delims + f.name in consumed and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) ) else f for f in self.input_fields ] - # Resolve bare candidate sub-feature names in `attr_fields` against - # `sampler_fields` (flattened parquet schema), using the prefix - # carried by a qualified `item_id_field` (e.g. "cand_seq__video_id" - # -> "cand_seq__"). When `item_id_field` is itself bare - # (DSSM/MIND/TDM top-level case), `seq_prefix` is empty and the - # loop is a no-op -- existing configs are unchanged. - if hasattr(sampler_config, "item_id_field") and sampler_config.HasField( - "item_id_field" - ): - id_field = sampler_config.item_id_field - field_names = {f.name for f in sampler_fields} - if "__" in id_field and id_field in field_names: - seq_prefix = id_field.split("__", 1)[0] + "__" - if hasattr(sampler_config, "attr_fields"): - sampler_config.attr_fields[:] = [ - seq_prefix + a - if a not in field_names and seq_prefix + a in field_names - else a - for a in sampler_config.attr_fields - ] - # pyre-ignore [16] self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)( sampler_config, @@ -415,7 +422,7 @@ def _apply_negative_sampler( input_data, self._sampler_item_id_field, self._sampler_user_id_field, - self._seq_field_delims, + self._sampler_seq_delim, ) sampled = self._sampler.get(sampler_input) @@ -497,12 +504,11 @@ def _merge_sampled_features( if k not in input_data: input_data[k] = v continue - seq_delim = self._seq_field_delims.get(k) - if seq_delim is None: + if not self._sampler_seq_delim: input_data[k] = pa.concat_arrays([input_data[k], v]) continue combined, pl = combine_negs_to_candidate_sequence( - input_data[k], v, seq_delim + input_data[k], v, self._sampler_seq_delim ) input_data[k] = combined if k == prefer_key: diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index a8928ef1..5fbf33ae 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -681,12 +681,13 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): """End-to-end strip decisions across the per-attr filter. Grouped LookupFeature sub with two item-side inputs; only ``cat_key`` - is in ``sequence_fields`` so only it enters ``_seq_field_delims``. + is in ``sequence_fields`` so only it is a candidate-sequence input. ``cat_key`` is typed ``list>`` (multi-value attr layered under multi-positive grouping); after strip it becomes ``list`` (ONE level stripped, not bare-stripped to ``int64``). - ``cat_map`` is item-side but excluded from ``_seq_field_delims``, - so it stays ``list`` unchanged. + ``cat_map`` is item-side but NOT a candidate-sequence sub-feature + (doesn't start with ``_sampler_seq_prefix``), so it stays + ``list`` unchanged. """ f = tempfile.NamedTemporaryFile("w") self._temp_files.append(f) @@ -746,14 +747,17 @@ def test_launch_sampler_cluster_multi_attr_strip_decision_matrix(self): input_fields=input_fields, mode=Mode.TRAIN, ) - # Narrowed _seq_field_delims excludes the non-sequence item-side input. - self.assertIn("click_seq__cat_key", dataset._seq_field_delims) - self.assertNotIn("cat_map", dataset._seq_field_delims) + # Candidate-side sequence state derived from item_id_field's + # matching feature -- "click_seq__cat_key" is a grouped sequence + # input so both _sampler_seq_delim and _sampler_seq_prefix are set. + self.assertEqual(dataset._sampler_seq_delim, ";") + self.assertEqual(dataset._sampler_seq_prefix, "click_seq__") dataset.launch_sampler_cluster(2) - # outer guard True (item_id_field is sequence-positive): + # outer guard True (item_id_field is candidate-sequence sub-feature): # - cat_key: list> -> list (one strip). - # - cat_map: list, not in _seq_field_delims -> unstripped. + # - cat_map: list, not in {a for a in attr_fields if + # a.startswith(prefix)} -> unstripped. cat_key_idx = dataset._sampler._attr_names.index("click_seq__cat_key") cat_map_idx = dataset._sampler._attr_names.index("cat_map") self.assertEqual( diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 98f3727f..8e8e1989 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -582,31 +582,31 @@ def build_sampler_input( input_data: Dict[str, pa.Array], item_id_field: Optional[str], user_id_field: Optional[str], - seq_field_delims: Dict[str, str], + seq_delim: str, ) -> Dict[str, pa.Array]: """Shallow-copy input_data with item_id (and user_id) flattened for the sampler. - When `item_id_field` is a sequence_id_feature, per-row positives - (delimited string or list array) are flattened to 1D and + When `item_id_field` is a grouped sequence sub-feature, per-row + positives (delimited string or list array) are flattened to 1D and `user_id_field` (if any) is expanded by per-row positive count. - Scalar item_id or unconfigured seq_delim falls through unchanged. - The caller's `input_data` is not mutated. + Scalar item_id (`seq_delim=""`) falls through unchanged. The caller's + `input_data` is not mutated. Args: input_data: per-row input column dict. item_id_field: sampler config's `item_id_field`, or None. user_id_field: sampler config's `user_id_field`, or None. - seq_field_delims: input_name -> sequence_delim mapping. + seq_delim: candidate sequence's `sequence_delim`, or "" when + `item_id_field` is a top-level scalar feature. Returns: A new shallow-copy dict with item_id flattened and user_id expanded when both apply. """ sampler_input = dict(input_data) - if item_id_field is None or item_id_field not in seq_field_delims: + if item_id_field is None or not seq_delim: return sampler_input - seq_delim = seq_field_delims[item_id_field] raw = input_data[item_id_field] if pa.types.is_string(raw.type) or pa.types.is_large_string(raw.type): pos_lists = pc.split_pattern(raw, seq_delim) diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index c7405b29..3e95a52a 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -210,7 +210,7 @@ def test_calc_slice_intervals_topology_change(self): @parameterized.expand( [ # (name, input_data, item_id_field, user_id_field, - # seq_field_delims, expected_output) + # seq_delim, expected_output) ( # NegativeSampler-style: no user_id_field; item_id is # delimited string; gets flattened. @@ -218,7 +218,7 @@ def test_calc_slice_intervals_topology_change(self): {"item_id": pa.array(["1;2", "3"]), "label": pa.array([1, 0])}, "item_id", None, - {"item_id": ";"}, + ";", {"item_id": ["1", "2", "3"], "label": [1, 0]}, ), ( @@ -231,7 +231,7 @@ def test_calc_slice_intervals_topology_change(self): }, "item_id", "user_id", - {"item_id": ";"}, + ";", {"item_id": [1, 2, 3], "user_id": ["u0", "u0", "u1"]}, ), ( @@ -244,16 +244,16 @@ def test_calc_slice_intervals_topology_change(self): }, "item_id", "user_id", - {"item_id": ";"}, + ";", {"item_id": [1, 2, 3], "user_id": ["u0", "u1", "u2"]}, ), ( - # item_id_field has no seq_delim entry -> pass through. - "item_id_not_in_seq_field_delims", + # item_id_field is a top-level scalar -> seq_delim="" -> pass through. + "empty_seq_delim_passthrough", {"item_id": pa.array(["1", "2"])}, "item_id", None, - {}, + "", {"item_id": ["1", "2"]}, ), ( @@ -263,7 +263,7 @@ def test_calc_slice_intervals_topology_change(self): {"a": pa.array([1, 2])}, None, None, - {}, + "", {"a": [1, 2]}, ), ] @@ -274,7 +274,7 @@ def test_build_sampler_input( input_data, item_id_field, user_id_field, - seq_field_delims, + seq_delim, expected_output, ): # Snapshot input_data so we can verify the function didn't mutate it. @@ -284,7 +284,7 @@ def test_build_sampler_input( input_data, item_id_field=item_id_field, user_id_field=user_id_field, - seq_field_delims=seq_field_delims, + seq_delim=seq_delim, ) # Contract 1: output equals expected (per-column pylist compare). From 8cd5d3e46d625001ad997444fa8277285fe75b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 17:37:23 +0800 Subject: [PATCH 07/16] [refactor] HSTUMatchItemTower view via lazy properties; materialize seq defaults; hoist set_is_inference Three review-prompted fixes: 1. project_grouped_sequence_feature_to_scalar() now materializes the source feature's effective default_value and value_dim onto the scalar proto. Without this, an id_feature projected from a sequence sub-feature (sequence-effective defaults "0" / 1) silently flipped to scalar defaults ("" / 0) -- behaviorally divergent from the training sub-feature it was projected from. The earlier removal chased a misdiagnosed FG failure that was actually a _tower_name propagation bug (fixed in commit 0c1dcf1). 2. HSTUMatchItemTower exposes `features` and `feature_groups` as lazy properties driven by `_is_inference`, instead of mutating `_features` / `_feature_groups` / `_cand_key` inside `set_is_inference(True)`. The scalar view is built once on first property access and cached (`_features_scalar` / `_feature_groups_scalar`); the training view stays immutable on the parent's `self._features`, so `set_is_inference(False)` reverts cleanly. `forward()` derives the candidate key inline (`{group}.sequence` vs `{group}.query`) from the flag -- no cached `_cand_key` attribute. `MatchTowerWoEG` exposes default `features` / `feature_groups` properties (forwarding to the underscore fields). Wrappers (`TowerWrapper`, `TowerWoEGWrapper`, `ScriptWrapper`) read via the properties and expose their own snapshot properties. Non-HSTUMatch towers are unaffected -- the default properties match the pre-refactor direct attribute reads. 3. set_is_inference(True) is hoisted to before every InferWrapper() in tzrec/main.py::export, removed from the three post-wrap call sites in tzrec/utils/export_util.py (lines 201, 743, 1065). The recursive_setattr from BaseModule.set_is_inference propagates the flag down to all sub-modules including the inner towers; wrappers are then constructed with the inference-mode view already established (so TowerWoEGWrapper's EmbeddingGroup is built off HSTUMatchItemTower's scalar features). Drop the per-item-tower deep-copy in main.py -- mutation is gone, the flag is the only state change. The HSTUMatchTest JIT_SCRIPT branch no longer flips `_is_inference` (it was the source of a view-toggle conflict: the model's own EmbeddingGroup emits `.sequence`-shaped grouped_features but `_is_inference=True` would make the item tower read `.query`, KeyError). JIT compiles both branches; the training-shape batch runs the training branch. The feature projection unit test reverts to assert materialization: `scalar_cfg.id_feature.default_value == "0"`, `value_dim == 1`. Non-functional for DSSM/MIND/TDM: properties default to underscore fields, set_is_inference flow is unchanged for non-HSTUMatch towers. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/features/feature.py | 15 +++-- tzrec/features/feature_test.py | 22 +++---- tzrec/main.py | 26 ++++---- tzrec/models/hstu.py | 105 ++++++++++++++++++++++----------- tzrec/models/hstu_test.py | 7 ++- tzrec/models/match_model.py | 55 +++++++++++++++-- tzrec/models/model.py | 28 ++++++++- tzrec/utils/export_util.py | 12 +++- 8 files changed, 202 insertions(+), 68 deletions(-) diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 84f7fcfc..94fc6cb4 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -1235,10 +1235,12 @@ def project_grouped_sequence_feature_to_scalar( The grouped sub-feature's config is a `SeqFeatureConfig`; rewrap it as a top-level `FeatureConfig` so `create_features` builds it as a scalar - feature. The scalar-mode defaults (empty `default_value`, dynamic - `value_dim=0` for id_feature) are intentional: at item-export predict - time the parquet provides one value per row, so the sequence-effective - defaults (`"0"` / `1`) don't apply. + feature. Materializes the source feature's effective `default_value` and + `value_dim` onto the scalar proto so the exported feature is + behaviorally identical to the training sub-feature -- without this, + `id_feature.value_dim` and `default_value` resolve differently in + scalar mode (`0` / `""`) than in sequence mode (`1` / `"0"`) + (see feature.py:515-517, 556-561). Args: feature: a grouped sequence sub-feature. @@ -1260,6 +1262,11 @@ def project_grouped_sequence_feature_to_scalar( scalar_cfg = feature_pb2.FeatureConfig() dst_msg = getattr(scalar_cfg, feat_type) dst_msg.CopyFrom(src_msg) + + if hasattr(dst_msg, "default_value") and not dst_msg.default_value: + dst_msg.default_value = feature.default_value + if hasattr(dst_msg, "value_dim") and not dst_msg.HasField("value_dim"): + dst_msg.value_dim = feature.value_dim return scalar_cfg diff --git a/tzrec/features/feature_test.py b/tzrec/features/feature_test.py index 93b97baa..ba5328ba 100644 --- a/tzrec/features/feature_test.py +++ b/tzrec/features/feature_test.py @@ -764,7 +764,7 @@ def _build_grouped(self, seq_sub_cfg): ] return feature_lib.create_features(feature_cfgs) - def test_id_feature_projection_preserves_scalar_defaults(self): + def test_id_feature_projection_materializes_seq_defaults(self): sub_cfg = feature_pb2.SeqFeatureConfig( id_feature=feature_pb2.IdFeature( feature_name="video_id", @@ -783,12 +783,14 @@ def test_id_feature_projection_preserves_scalar_defaults(self): scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(sub_feature) self.assertEqual(scalar_cfg.WhichOneof("feature"), "id_feature") - # Scalar-mode defaults are intentional: predict-time parquet provides - # one value per row, so the sequence-effective "0" / 1 defaults are - # not carried into the scalar proto -- FG handles the value-per-row - # input via the scalar id_feature path. - self.assertEqual(scalar_cfg.id_feature.default_value, "") - self.assertFalse(scalar_cfg.id_feature.HasField("value_dim")) + # Sequence-effective defaults carried into the scalar proto so the + # exported feature is behaviorally identical to the training + # sub-feature -- otherwise an id_feature in scalar mode resolves + # default_value to "" and value_dim to 0, diverging from the + # training "0" / 1 it was projected from. + self.assertEqual(scalar_cfg.id_feature.default_value, "0") + self.assertTrue(scalar_cfg.id_feature.HasField("value_dim")) + self.assertEqual(scalar_cfg.id_feature.value_dim, 1) # Original sub-feature proto is not mutated. self.assertEqual(sub_feature.feature_config.id_feature.default_value, "") self.assertFalse(sub_feature.feature_config.id_feature.HasField("value_dim")) @@ -810,9 +812,9 @@ def test_projection_passes_through_create_features_as_scalar(self): # Bare sub-feature name without the cand_seq__ prefix. self.assertEqual(scalar.name, "video_id") self.assertFalse(scalar.is_grouped_sequence) - # Scalar context: value_dim resolves to 0 (variable length) via - # the non-sequence path in IdFeature.value_dim. - self.assertEqual(scalar.value_dim, 0) + # Scalar context: value_dim is materialized to 1 from the source + # sub-feature's sequence-effective default. + self.assertEqual(scalar.value_dim, 1) def test_projection_rejects_non_grouped_feature(self): feature_cfgs = [ diff --git a/tzrec/main.py b/tzrec/main.py index 1dfa5302..b4603088 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -936,6 +936,16 @@ def export( sampler_type=None, ) InferWrapper = ScriptWrapper + # Set the inference flag on the inner model *before* wrapping. Every + # downstream wrapper (ScriptWrapper for non-match, TowerWoEGWrapper / + # TowerWrapper per match tower) snapshots view-dependent state at + # construction time (e.g. `HSTUMatchItemTower`'s lazy + # `features`/`feature_groups` properties; the wrappers' EmbeddingGroup), + # so the flag must already be True when the wrapping happens. + # `recursive_setattr` propagates the flag to all sub-modules including + # the inner towers, so the per-tower wrap below doesn't need its own + # toggle. + model.set_is_inference(True) model = InferWrapper(model) if not checkpoint_path: @@ -957,16 +967,12 @@ def export( wrapper = ( TowerWrapper if isinstance(module, MatchTower) else TowerWoEGWrapper ) - # Towers that own a view switch (e.g. `HSTUMatchItemTower` - # flipping from training `candidate.sequence` to scalar - # `candidate.query`) need to flip BEFORE the wrapper rebuilds - # `EmbeddingGroup(module._features, module._feature_groups)`. - # Deep-copy so the original training tower is preserved. - module_for_export = module - if hasattr(module, "set_is_inference") and name == "item_tower": - module_for_export = copy.deepcopy(module) - module_for_export.set_is_inference(True) - tower = InferWrapper(wrapper(module_for_export, name)) + # The inference flag was already set on every sub-module by + # `model.set_is_inference(True)` above, so `HSTUMatchItemTower`'s + # lazy `features` / `feature_groups` properties return the + # scalar view here and the wrapper's `EmbeddingGroup` is + # built off scalar features. + tower = InferWrapper(wrapper(module, name)) tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) export_model( ori_pipeline_config, diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index b79c4533..65255047 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -161,13 +161,24 @@ def __init__( # tower_config.input here, which equals feature_groups[0].group_name. self._group_name = tower_config.input # Training view: candidate group is JAGGED_SEQUENCE; embedding_group - # emits the per-row jagged tensor at `{group_name}.sequence`. The - # export-time scalar view (see `set_is_inference(True)`) swaps this - # to `{group_name}.query` -- one row per item, no positive-set - # grouping container. - self._cand_key = f"{self._group_name}.sequence" - candidate_dims = embedding_group.group_dims(self._cand_key) + # emits the per-row jagged tensor at `{group_name}.sequence`. At + # export, `set_is_inference(True)` flips the flag below; the + # `features` / `feature_groups` properties then return the lazily- + # built scalar view (one row per item), and `forward()` reads + # `{group_name}.query` instead. Mlp / output Linear are sized off + # the training candidate group; the scalar view's per-feature + # embedding dim is identical, so no resize is needed. + candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence") candidate_total_dim = sum(candidate_dims) + + # Lazy caches for the scalar export view; populated on first + # property access when `_is_inference` is True. None at training + # time -- non-export consumers pay zero cost. + self._features_scalar: Optional[List[BaseFeature]] = None + self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None + # Initialize explicitly: `MatchTowerWoEG` derives from `nn.Module`, + # not `BaseModule`, so `_is_inference` isn't set by the parent. + self._is_inference: bool = False if tower_config.HasField("mlp"): self.mlp: torch.nn.Module = MLP( in_features=candidate_total_dim, @@ -181,30 +192,37 @@ def __init__( if self._output_dim > 0: self.output = nn.Linear(mlp_out_dim, output_dim) - def set_is_inference(self, is_inference: bool) -> None: - """Switch the tower into the scalar item-export view (one-way). - - Training stays at `candidate.sequence` (jagged sub-feature). Export - copies the tower (`copy.deepcopy`), calls this with `True`, and the - wrapper picks up the swapped `_features` / `_feature_groups` / - `_cand_key` for serving. Idempotent: re-entering after the swap is a - no-op. `set_is_inference(False)` does NOT rebuild the training view; - callers must export from a copy so the original training tower is - preserved. - - Note: `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, - so this method doesn't call `super().set_is_inference()`. The - `_is_inference` attribute is propagated to sub-modules separately by - `ScriptWrapper.set_is_inference()` via `recursive_setattr` during - export. - """ - if not is_inference: - return - if self._cand_key.endswith(".query"): - return # already swapped (idempotent) + @property + def features(self) -> List[BaseFeature]: + """Item-side features in the current view (training or scalar export). - # Project each grouped sequence sub-feature into a scalar export - # FeatureConfig, then rebuild as top-level scalar features. + At training (`_is_inference=False`), returns the grouped sequence + sub-features the tower was constructed with. At export + (`_is_inference=True`), returns the lazily-built scalar + projection, cached for subsequent reads. + """ + if self._is_inference: + if self._features_scalar is None: + self._build_scalar_features() + return self._features_scalar + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Item-side feature_groups in the current view (see `features`).""" + if self._is_inference: + if self._feature_groups_scalar is None: + self._build_scalar_features() + return self._feature_groups_scalar + return self._feature_groups + + def _build_scalar_features(self) -> None: + """Build the scalar export view caches from the training features. + + Projects each grouped sequence sub-feature into a scalar export + feature; populates `_features_scalar` and `_feature_groups_scalar`. + Called at most once per tower instance (cached via the properties). + """ scalar_configs = [ project_grouped_sequence_feature_to_scalar(f) for f in self._features ] @@ -218,26 +236,47 @@ def set_is_inference(self, is_inference: bool) -> None: f.data_group == BASE_DATA_GROUP for f in self._features ), ) - self._features = scalar_features - self._feature_groups = [ + self._features_scalar = scalar_features + self._feature_groups_scalar = [ model_pb2.FeatureGroupConfig( group_name=self._group_name, feature_names=[f.name for f in scalar_features], group_type=model_pb2.JAGGED_SEQUENCE, ) ] - self._cand_key = f"{self._group_name}.query" + + def set_is_inference(self, is_inference: bool) -> None: + """Toggle the export-view flag without structural mutation. + + Cheap; the scalar features/groups are materialized lazily on first + property read (typically by `TowerWoEGWrapper.__init__` rebuilding + its EmbeddingGroup). `set_is_inference(False)` reverts the view; + the lazy caches survive but are unused. + + `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, so + this method doesn't call `super().set_is_inference()`. The + `_is_inference` flag on sub-modules is set separately by the + caller's `BaseModule.set_is_inference` (recursive_setattr). + """ + self._is_inference = is_inference def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: """Forward the item tower. + Reads from `{group_name}.sequence` (jagged) at training and from + `{group_name}.query` (scalar) at export. One-line conditional; + no cached `cand_key` attribute. The branch is on `_is_inference` + rather than dict-membership so the choice is FX/JIT traceable + (dict-membership checks aren't traceable as control flow). + Args: grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: item embeddings of shape (sum_candidates, D). """ - cand_emb = grouped_features[self._cand_key] + suffix = ".query" if self._is_inference else ".sequence" + cand_emb = grouped_features[self._group_name + suffix] item_emb = self.mlp(cand_emb) if self._output_dim > 0: item_emb = self.output(item_emb) diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index 1f2fa6a9..8857a381 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -199,7 +199,12 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None: batch = _build_batch(device=device) if graph_type == TestGraphType.JIT_SCRIPT: - hstu.set_is_inference(True) + # Don't flip the inference flag here: the test batch is in the + # training-view shape (jagged `candidate.sequence`), and the + # item tower's forward branches on `_is_inference` to choose + # `.sequence` vs `.query`. The JIT-scripted forward still + # compiles both branches; the runtime path is the training + # branch, matching the batch shape. hstu = create_test_model(hstu, graph_type) predictions = hstu(batch.to_dict(), device) elif graph_type == TestGraphType.FX_TRACE: diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 433cd69a..2fb6193c 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -222,6 +222,20 @@ def __init__( self._feature_groups = feature_groups self._features = features + @property + def features(self) -> List[BaseFeature]: + """Item-side features the tower exposes to its wrapper. + + Default reads `self._features`; overridden by towers that switch + views between training and export (see `HSTUMatchItemTower`). + """ + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Item-side feature_groups the tower exposes to its wrapper.""" + return self._feature_groups + class MatchModel(BaseModel): """Base model for match. @@ -457,10 +471,25 @@ class TowerWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() setattr(self, tower_name, module) - self._features = module._features - self._feature_groups = module._feature_groups + # Snapshot the tower's current view via the property (which for + # `HSTUMatchItemTower` returns the scalar view iff + # `_is_inference=True`). Wrapper construction must happen *after* + # the inference flag is set on the inner tower (see + # `tzrec/main.py::export`). + self._features = module.features + self._feature_groups = module.feature_groups self._tower_name = tower_name + @property + def features(self) -> List[BaseFeature]: + """Snapshot of the wrapped tower's features at construction time.""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Snapshot of the wrapped tower's feature_groups at construction time.""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. @@ -478,13 +507,29 @@ class TowerWoEGWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() - self.embedding_group = EmbeddingGroup(module._features, module._feature_groups) + # Build EmbeddingGroup from the tower's *current view*: for + # `HSTUMatchItemTower` after `set_is_inference(True)`, this is the + # scalar export view (one row per item, `{group_name}.query`); + # otherwise it's the training view (jagged, `{group_name}.sequence`). + # Wrapper construction must happen *after* the inference flag is + # set on the inner tower -- see `tzrec/main.py::export`. + self.embedding_group = EmbeddingGroup(module.features, module.feature_groups) setattr(self, tower_name, module) - self._features = module._features - self._feature_groups = module._feature_groups + self._features = module.features + self._feature_groups = module.feature_groups self._tower_name = tower_name self._group_name = module._group_name + @property + def features(self) -> List[BaseFeature]: + """Snapshot of the wrapped tower's features at construction time.""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Snapshot of the wrapped tower's feature_groups at construction time.""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. diff --git a/tzrec/models/model.py b/tzrec/models/model.py index f8e62fc4..5b397755 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -336,8 +336,12 @@ class ScriptWrapper(BaseModule): def __init__(self, module: nn.Module) -> None: super().__init__() self.model = module - self._features = self.model._features - self._feature_groups = self.model._feature_groups + # Snapshot via the inner module's view properties (defaults forward + # to the underscore fields; HSTUMatchItemTower / wrapper towers + # override). Wrapper construction must happen *after* the inference + # flag is set on the inner module -- see `tzrec/main.py::export`. + self._features = self._features_from(module) + self._feature_groups = self._feature_groups_from(module) # Propagate tower identity (set by TowerWoEGWrapper / TowerWrapper) # so export_util.py can route item-tower export through # `pipeline_config.item_input_path` instead of `train_input_path`. @@ -350,6 +354,26 @@ def __init__(self, module: nn.Module) -> None: else None, ) + @staticmethod + def _features_from(module: nn.Module) -> List["BaseFeature"]: + return module.features if hasattr(module, "features") else module._features + + @staticmethod + def _feature_groups_from(module: nn.Module) -> List: + if hasattr(module, "feature_groups"): + return module.feature_groups + return module._feature_groups + + @property + def features(self) -> List["BaseFeature"]: + """Snapshot of the wrapped module's features at construction time.""" + return self._features + + @property + def feature_groups(self) -> List: + """Snapshot of the wrapped module's feature_groups at construction time.""" + return self._feature_groups + def get_batch( self, data: Dict[str, torch.Tensor], diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index bf7ad524..6f35a41b 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -198,7 +198,11 @@ def export_model_normal( if is_rank_zero: if not os.path.exists(save_dir): os.makedirs(save_dir) - model.set_is_inference(True) + # `set_is_inference(True)` is the caller's responsibility; it must + # be applied to the inner model *before* wrapping with + # `InferWrapper` (see `tzrec/main.py::export`) so wrapper-level + # snapshots (EmbeddingGroup, view-dependent features) pick up the + # inference-mode view at construction time. init_parameters(model, torch.device("cpu")) checkpoint_util.restore_model( @@ -740,7 +744,8 @@ def _all_keys_used_once( batch = next(iter(dataloader)) data = batch.to(device).to_dict(sparse_dtype=torch.int64) - model.set_is_inference(True) + # `set_is_inference(True)` was applied in `tzrec/main.py::export` before + # wrapping -- the inner model + all sub-modules already carry the flag. # Build Sharded Model planner = create_planner( @@ -1062,7 +1067,8 @@ def split_model( if not os.path.exists(graph_dir): os.makedirs(graph_dir) - model.set_is_inference(True) + # `set_is_inference(True)` was applied in `tzrec/main.py::export` before + # wrapping -- the inner model + all sub-modules already carry the flag. model.eval() tracer = Tracer() From 35b18cb48f75afe7a251efa1b9ca94fda6378899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 17:53:24 +0800 Subject: [PATCH 08/16] [refactor] expose features/feature_groups as properties end-to-end; drop hasattr fallback in ScriptWrapper Three follow-on simplifications on top of commit 8cd5d3e: 1. `HSTUMatchItemTower.set_is_inference` override removed -- now dead code. `BaseModule.set_is_inference` (called once on the full model in `main.py::export`) uses `recursive_setattr` which sets the `_is_inference` attribute on every sub-module directly; it never calls sub-modules' methods. With the lazy-property design, the flag flip on `HSTUMatchItemTower._is_inference` IS the toggle -- no extra method needed. 2. `TowerWrapper` / `TowerWoEGWrapper` now read `features` / `feature_groups` lazily via `getattr(self, self._tower_name)` -- no construction-time snapshot of those metadata fields. The EmbeddingGroup (which owns nn.Parameters) still snapshots at construction; the metadata properties stay live so they reflect whatever view the inner tower currently exposes. 3. `BaseModel` and `TDMEmbedding` get explicit `features` / `feature_groups` properties (default reads of `_features` / `_feature_groups`). With every wrapped model exposing the property, `ScriptWrapper` drops the `_features_from` / `_feature_groups_from` hasattr-fallback helpers and just reads `self.model.features` / `self.model.feature_groups`. The wrapper also no longer snapshots into `self._features` / `self._feature_groups` -- those properties forward live to `self.model.features` / `self.model.feature_groups`. `tzrec/utils/export_util.py` and `tzrec/acc/aot_utils.py` migrated from `model._features` / `model._feature_groups` to the property API (8 occurrences total). Non-functional for DSSM/MIND/TDM: the new default properties on BaseModel + TDMEmbedding just return the underscore fields, identical to today's direct attribute reads. The HSTUMatch integration test passes end-to-end (~289-313s on local A10). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/acc/aot_utils.py | 10 +++++----- tzrec/models/hstu.py | 15 -------------- tzrec/models/match_model.py | 34 ++++++++++++++----------------- tzrec/models/model.py | 40 ++++++++++++++++--------------------- tzrec/models/tdm.py | 10 ++++++++++ tzrec/utils/export_util.py | 10 +++++----- 6 files changed, 52 insertions(+), 67 deletions(-) diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py index 69e1b06d..74f08404 100644 --- a/tzrec/acc/aot_utils.py +++ b/tzrec/acc/aot_utils.py @@ -303,8 +303,8 @@ def _build_dynamic_shapes( Args: data: input tensor dict from Batch.to_dict(). - features: list of BaseFeature from model._features. - feature_groups: list of FeatureGroupConfig from model._feature_groups. + features: list of BaseFeature from model.features. + feature_groups: list of FeatureGroupConfig from model.feature_groups. Returns: dynamic_shapes dict for torch.export.export(). @@ -467,14 +467,14 @@ def export_unified_model_aot( # Pad any 0-size non-sequence sparse .values tensors so torch.export # doesn't specialize on the empty size (which conflicts with dynamic Dims). - seq_feat_names = {f.name for f in model._features if f.is_sequence} + seq_feat_names = {f.name for f in model.features if f.is_sequence} data = _pad_empty_sparse_values(data, seq_feat_names) # Build dynamic shapes using feature metadata for correct Dim grouping dynamic_shapes = _build_dynamic_shapes( data, - features=model._features, - feature_groups=model._feature_groups, + features=model.features, + feature_groups=model.feature_groups, ) logger.info("dynamic shapes=%s" % dynamic_shapes) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 65255047..20a5299d 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -245,21 +245,6 @@ def _build_scalar_features(self) -> None: ) ] - def set_is_inference(self, is_inference: bool) -> None: - """Toggle the export-view flag without structural mutation. - - Cheap; the scalar features/groups are materialized lazily on first - property read (typically by `TowerWoEGWrapper.__init__` rebuilding - its EmbeddingGroup). `set_is_inference(False)` reverts the view; - the lazy caches survive but are unused. - - `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, so - this method doesn't call `super().set_is_inference()`. The - `_is_inference` flag on sub-modules is set separately by the - caller's `BaseModule.set_is_inference` (recursive_setattr). - """ - self._is_inference = is_inference - def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: """Forward the item tower. diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 2fb6193c..84b75286 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -471,24 +471,21 @@ class TowerWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() setattr(self, tower_name, module) - # Snapshot the tower's current view via the property (which for - # `HSTUMatchItemTower` returns the scalar view iff - # `_is_inference=True`). Wrapper construction must happen *after* - # the inference flag is set on the inner tower (see - # `tzrec/main.py::export`). - self._features = module.features - self._feature_groups = module.feature_groups self._tower_name = tower_name @property def features(self) -> List[BaseFeature]: - """Snapshot of the wrapped tower's features at construction time.""" - return self._features + """Live read of the wrapped tower's features. + + For `HSTUMatchItemTower`, this reflects the current view (training + or scalar export) per `_is_inference`. No snapshot. + """ + return getattr(self, self._tower_name).features @property def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: - """Snapshot of the wrapped tower's feature_groups at construction time.""" - return self._feature_groups + """Live read of the wrapped tower's feature_groups.""" + return getattr(self, self._tower_name).feature_groups def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. @@ -511,24 +508,23 @@ def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: # `HSTUMatchItemTower` after `set_is_inference(True)`, this is the # scalar export view (one row per item, `{group_name}.query`); # otherwise it's the training view (jagged, `{group_name}.sequence`). - # Wrapper construction must happen *after* the inference flag is - # set on the inner tower -- see `tzrec/main.py::export`. + # The EmbeddingGroup itself owns nn.Parameters so it must be a + # construction-time snapshot; the `features`/`feature_groups` + # properties below stay live via lazy reads on the inner tower. self.embedding_group = EmbeddingGroup(module.features, module.feature_groups) setattr(self, tower_name, module) - self._features = module.features - self._feature_groups = module.feature_groups self._tower_name = tower_name self._group_name = module._group_name @property def features(self) -> List[BaseFeature]: - """Snapshot of the wrapped tower's features at construction time.""" - return self._features + """Live read of the wrapped tower's features (no snapshot).""" + return getattr(self, self._tower_name).features @property def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: - """Snapshot of the wrapped tower's feature_groups at construction time.""" - return self._feature_groups + """Live read of the wrapped tower's feature_groups.""" + return getattr(self, self._tower_name).feature_groups def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 5b397755..bc1ada4f 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -73,6 +73,16 @@ def __init__( self._train_metric_modules = nn.ModuleDict() + @property + def features(self) -> List[BaseFeature]: + """Model's features (default property forwarding to `self._features`).""" + return self._features + + @property + def feature_groups(self) -> List[FeatureGroupConfig]: + """Model's feature_groups (default forward to `self._feature_groups`).""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -336,43 +346,27 @@ class ScriptWrapper(BaseModule): def __init__(self, module: nn.Module) -> None: super().__init__() self.model = module - # Snapshot via the inner module's view properties (defaults forward - # to the underscore fields; HSTUMatchItemTower / wrapper towers - # override). Wrapper construction must happen *after* the inference - # flag is set on the inner module -- see `tzrec/main.py::export`. - self._features = self._features_from(module) - self._feature_groups = self._feature_groups_from(module) # Propagate tower identity (set by TowerWoEGWrapper / TowerWrapper) # so export_util.py can route item-tower export through # `pipeline_config.item_input_path` instead of `train_input_path`. if hasattr(self.model, "_tower_name"): self._tower_name = self.model._tower_name self._data_parser = DataParser( - self._features, + self.model.features, sampler_type=str(module.sampler_type) if hasattr(module, "sampler_type") else None, ) - @staticmethod - def _features_from(module: nn.Module) -> List["BaseFeature"]: - return module.features if hasattr(module, "features") else module._features - - @staticmethod - def _feature_groups_from(module: nn.Module) -> List: - if hasattr(module, "feature_groups"): - return module.feature_groups - return module._feature_groups - @property - def features(self) -> List["BaseFeature"]: - """Snapshot of the wrapped module's features at construction time.""" - return self._features + def features(self) -> List[BaseFeature]: + """Live read of the wrapped module's features (no snapshot).""" + return self.model.features @property - def feature_groups(self) -> List: - """Snapshot of the wrapped module's feature_groups at construction time.""" - return self._feature_groups + def feature_groups(self) -> List[FeatureGroupConfig]: + """Live read of the wrapped module's feature_groups.""" + return self.model.feature_groups def get_batch( self, diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index b706160c..0be8da1a 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -132,6 +132,16 @@ def __init__( EmbeddingGroup(seq_group_query_fea, self._feature_groups), ) + @property + def features(self) -> List[BaseFeature]: + """Query-side features consumed by the TDM embedding.""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Single feature_group consumed by the TDM embedding.""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the embedding. diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index 6f35a41b..cf88da68 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -161,7 +161,7 @@ def export_model_normal( # make dataparser to get user feats before create model data_config = copy.deepcopy(pipeline_config.data_config) - features = cast(List[BaseFeature], model._features) + features = cast(List[BaseFeature], model.features) if acc_utils.is_cuda_export(): # export batch_size too large may OOM in compile phase max_batch_size = acc_utils.get_max_export_batch_size() @@ -300,7 +300,7 @@ def export_model_normal( # training feature_group names. if hasattr(model, "_feature_groups"): pipeline_config.model_config.ClearField("feature_groups") - pipeline_config.model_config.feature_groups.extend(model._feature_groups) + pipeline_config.model_config.feature_groups.extend(model.feature_groups) config_util.save_message( pipeline_config, os.path.join(save_dir, "pipeline.config") ) @@ -731,7 +731,7 @@ def _all_keys_used_once( # make dataparser to get user feats before create model data_config = copy.deepcopy(pipeline_config.data_config) - features = cast(List[BaseFeature], model._features) + features = cast(List[BaseFeature], model.features) data_config.num_workers = 1 data_config.batch_size = acc_utils.get_max_export_batch_size() # Item-tower export: same routing as `export_model_normal`. @@ -1186,8 +1186,8 @@ def _seq_feat_name(seq_name: str) -> str: dense_gm = _prune_unused_param_and_buffer(dense_gm) seq_share_groups = _compute_seq_share_groups( - features=cast(List[BaseFeature], model._features), - feature_groups=model._feature_groups, + features=cast(List[BaseFeature], model.features), + feature_groups=model.feature_groups, ) meta_info = { "seq_tensor_names": seq_tensor_names, From 11ef729b0e80fb3a5cb656039c36d15189a17363 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 17:59:36 +0800 Subject: [PATCH 09/16] [doc] explain why the prefix-resolve if-else is load-bearing Three-way disambiguation: literal match in field_names wins, so already-qualified attrs ("cand_seq__video_id"), top-level item-side attrs in mixed configs (lookup_feature's "cat_map" whose sequence_fields exclude it), and bare candidate sub-features ("video_id" -> "cand_seq__video_id") all resolve correctly. Verified by `tzrec.datasets.dataset_test::test_launch_sampler_cluster_multi_attr_strip_decision_matrix` which exercises the mixed-attrs case (attr_fields=["cat_map", "click_seq__cat_key"]); unconditional prefix would break "cat_map" by sending it through as "click_seq__cat_map" (not in parquet schema). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 954cd8a7..7e91df7a 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -222,12 +222,22 @@ def launch_sampler_cluster( if self._sampler_seq_delim: if self._sampler_seq_prefix: # Grouped sequence: resolve bare candidate sub-feature - # names against the qualified flattened schema using - # the authoritative parent prefix from feature configs - # (RTP-safe, no name string split). Has to run *before* - # the outer-list strip so `consumed` carries resolved - # names. Top-level `sequence_id_feature` skips this -- - # its attr_fields are already bare/qualified-as-itself. + # names against the flattened schema using the + # authoritative parent prefix from feature configs + # (RTP-safe, no name string split). Literal match wins, + # so: + # - already-qualified attrs (e.g. "cand_seq__video_id") + # stay as-is; + # - bare candidate sub-features (e.g. "video_id") get + # prefixed to "cand_seq__video_id"; + # - top-level item-side attrs in the same sampler + # config (e.g. lookup_feature's `cat_map` whose + # sequence_fields exclude it) stay as-is because + # they literally exist in the parquet schema. + # Must run *before* the outer-list strip so `consumed` + # carries resolved names. Top-level `sequence_id_feature` + # skips this -- its attr_fields are bare top-level + # columns. field_names = {f.name for f in sampler_fields} sampler_config.attr_fields[:] = [ self._sampler_seq_prefix + a From e7836ff710856b2e32bf913adb81e80e3c13754f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 18:44:39 +0800 Subject: [PATCH 10/16] [refactor] gate prefix-resolve on candidate-side sequence_input_names (feature-config only) Replace the parquet-schema membership check (`prefix + a in field_names`) with a feature-config-derived check (`prefix + a in self._sampler_seq_inputs`). `_sampler_seq_inputs` is the union of `sequence_input_names` across all candidate-side grouped sequence sub-features (those sharing `sequence_name` with `item_id_field`'s parent), precomputed at `BaseDataset.__init__` time. Authoritative source: feature configs. Three semantic consequences vs the old `field_names` check: - More precise: an attr is prefixed only when `prefix + a` is a known candidate-sequence input, not just any parquet column. Spurious same-name parquet columns can't trick the resolution. - Same outcome for the existing test cases: HSTUMatch ("video_id" -> "cand_seq__video_id" because "cand_seq__video_id" is a candidate sequence input); lookup_feature mixed config (`cat_map` stays as-is because "click_seq__cat_map" is NOT a sequence input; `click_seq__cat_key` stays as-is because double-prefixing lands outside the set). - No parquet-schema dependency: the resolution doesn't need to enumerate `sampler_fields` per launch_sampler_cluster call. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/datasets/dataset.py | 54 ++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 7e91df7a..cba86293 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -169,14 +169,21 @@ def __init__( # Candidate-side sequence state. When `item_id_field` is a sequence # input (top-level `sequence_id_feature` OR grouped sequence # sub-feature input), `_sampler_seq_delim` is the parent feature's - # `sequence_delim`. For grouped sub-features only, `_sampler_seq_prefix` - # is the flatten prefix `f"{sequence_name}{_underline}"` (used to - # resolve bare `attr_fields` to the qualified parquet column name). - # For top-level `sequence_id_feature` (no flatten), the prefix stays - # empty and the resolve is a no-op. All sourced from the matching - # `BaseFeature` via `sequence_input_names`, the authoritative input. + # `sequence_delim`. For grouped sub-features only: + # * `_sampler_seq_prefix` is the flatten prefix + # `f"{sequence_name}{_underline}"`. + # * `_sampler_seq_inputs` is the set of all candidate-side + # sequence input names (union of `sequence_input_names` across + # features that share `sequence_name` with item_id_field's + # parent). Used to gate bare-attr resolution: an attr is + # prefixed iff `prefix + a` is a known candidate sequence + # input. Avoids false positives on top-level item-side attrs + # (e.g. lookup_feature's `cat_map` whose sequence_fields + # exclude it). Sourced from feature configs only -- no + # parquet-schema dependency. self._sampler_seq_delim: str = "" self._sampler_seq_prefix: str = "" + self._sampler_seq_inputs: set = set() if self._sampler_item_id_field is not None: for feature in features: if ( @@ -190,6 +197,10 @@ def __init__( self._sampler_seq_prefix = ( feature.sequence_name + feature._underline ) + seq_name = feature.sequence_name + for f in features: + if f.is_grouped_sequence and f.sequence_name == seq_name: + self._sampler_seq_inputs.update(f.sequence_input_names) break self._fg_mode = data_config.fg_mode @@ -221,28 +232,19 @@ def launch_sampler_cluster( sampler_fields = self.input_fields if self._sampler_seq_delim: if self._sampler_seq_prefix: - # Grouped sequence: resolve bare candidate sub-feature - # names against the flattened schema using the - # authoritative parent prefix from feature configs - # (RTP-safe, no name string split). Literal match wins, - # so: - # - already-qualified attrs (e.g. "cand_seq__video_id") - # stay as-is; - # - bare candidate sub-features (e.g. "video_id") get - # prefixed to "cand_seq__video_id"; - # - top-level item-side attrs in the same sampler - # config (e.g. lookup_feature's `cat_map` whose - # sequence_fields exclude it) stay as-is because - # they literally exist in the parquet schema. - # Must run *before* the outer-list strip so `consumed` - # carries resolved names. Top-level `sequence_id_feature` - # skips this -- its attr_fields are bare top-level - # columns. - field_names = {f.name for f in sampler_fields} + # Grouped sequence: prefix a bare attr iff `prefix + a` + # is a known candidate-side sequence input (from + # feature configs). Already-qualified attrs land + # outside the candidate input set (their `prefix + a` + # double-prefixes), so they stay as-is. Top-level + # item-side attrs (e.g. lookup_feature's `cat_map` + # whose sequence_fields exclude it) likewise stay + # as-is. RTP-safe, no name string split, no parquet- + # schema dependency. Must run *before* the outer-list + # strip so `consumed` carries resolved names. sampler_config.attr_fields[:] = [ self._sampler_seq_prefix + a - if a not in field_names - and self._sampler_seq_prefix + a in field_names + if self._sampler_seq_prefix + a in self._sampler_seq_inputs else a for a in sampler_config.attr_fields ] From 1e885b53165cab4e59dd80bd6e966b58d549119e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 18:55:37 +0800 Subject: [PATCH 11/16] [refactor] move item_input_path out of proto; expose as generic --data_input_path CLI flag `item_input_path` was the wrong abstraction in two ways: 1. It lived in `pipeline.proto` as a model-config field, but it isn't a model-shape concern -- it's a per-invocation override of the export dataloader's input path. Storing it in the proto forces every recall pipeline to ship a config edit just to point export at a different table. 2. The name embedded a hard policy (`item_*`) into a layer (`export_model`) that doesn't care about towers. This commit: - Removes `optional string item_input_path` from `pipeline.proto`. - Adds `--data_input_path` CLI flag on `tzrec/export.py` and threads it through `main.py::export` -> `export_util.export_model` -> `export_model_normal` / `export_rtp_model`. - Renames `item_input_path` -> `data_input_path` in `export_model`. The semantics are now generic: if `data_input_path` is non-empty it overrides `train_input_path` for the predict-mode dataloader; otherwise fall back to the pipeline config. No tower-coupling at this layer. - Drops the `is_item_tower` gate inside `export_util.py` -- predict-mode bypasses the sampler anyway, so the previous `data_config.ClearField( "sampler")` branch was dead too. - Policy "only the item tower receives the override" lives in `main.py::export`'s match-tower loop, where it belongs. - Integration test (`MatchIntegrationTest.test_hstu_with_fg_train_eval`) passes the override via `utils.test_export(..., data_input_path=...)`, matching how real callers will use the CLI flag. Co-Authored-By: Claude Opus 4.7 --- tzrec/export.py | 12 +++++++ tzrec/main.py | 12 +++++++ tzrec/protos/pipeline.proto | 9 ------ tzrec/tests/configs/hstu_kuairand_1k.config | 1 - tzrec/tests/match_integration_test.py | 1 + tzrec/tests/utils.py | 5 ++- tzrec/utils/export_util.py | 36 +++++++++++---------- 7 files changed, 48 insertions(+), 28 deletions(-) diff --git a/tzrec/export.py b/tzrec/export.py index 3e4eba45..32bc935f 100644 --- a/tzrec/export.py +++ b/tzrec/export.py @@ -48,6 +48,17 @@ help="JSON string of extra key/value pairs merged into model_acc.json, " 'e.g. \'{"cand_seq_pk": "cand_seq"}\' for DlrmHSTU.', ) + parser.add_argument( + "--data_input_path", + type=str, + default=None, + help="Optional input path override for export's predict-mode " + "dataloader. When set, the sample batch is read from this path " + "instead of `train_input_path`. Useful for recall-model item-tower " + "export with a one-row-per-item table whose schema matches the " + "scalar export view (training-shape sequence rows in " + "`train_input_path` would fail the scalar parser).", + ) args, extra_args = parser.parse_known_args() additional_export_config = ( @@ -62,4 +73,5 @@ checkpoint_path=args.checkpoint_path, asset_files=args.asset_files, additional_export_config=additional_export_config, + data_input_path=args.data_input_path, ) diff --git a/tzrec/main.py b/tzrec/main.py index b4603088..3060e819 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -895,6 +895,7 @@ def export( checkpoint_path: Optional[str] = None, asset_files: Optional[str] = None, additional_export_config: Optional[Dict[str, str]] = None, + data_input_path: Optional[str] = None, ) -> None: """Export a EasyRec model. @@ -906,6 +907,11 @@ def export( asset_files (str, optional): more files will be copied to export_dir. additional_export_config (dict, optional): extra key/value pairs merged into model_acc.json (e.g. ``{"cand_seq_pk": "cand_seq"}`` for DlrmHSTU). + data_input_path (str, optional): override for export's predict-mode + dataloader input path. When set, used instead of + `pipeline_config.train_input_path`. For recall models this only + applies to the item-tower export (the user tower keeps reading + `train_input_path`). """ is_rank_zero = int(os.environ.get("RANK", 0)) == 0 @@ -974,6 +980,11 @@ def export( # built off scalar features. tower = InferWrapper(wrapper(module, name)) tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) + # data_input_path applies only to the item tower (whose + # scalar export view can't parse `train_input_path`'s + # training-shape sequence rows). The user tower reads + # the standard `train_input_path`. + tower_input_path = data_input_path if name == "item_tower" else None export_model( ori_pipeline_config, tower, @@ -981,6 +992,7 @@ def export( tower_export_dir, assets=assets, additional_export_config=additional_export_config, + data_input_path=tower_input_path, ) elif isinstance(model.model, TDM): for name, module in model.model.named_children(): diff --git a/tzrec/protos/pipeline.proto b/tzrec/protos/pipeline.proto index 327ba357..cd86dfb3 100644 --- a/tzrec/protos/pipeline.proto +++ b/tzrec/protos/pipeline.proto @@ -26,13 +26,4 @@ message EasyRecConfig { repeated FeatureConfig feature_configs = 8; optional ModelConfig model_config = 9; - - // Optional item-only table for recall-model item-tower export. - // When set, item-tower export reads its sample batch from this - // path (one row per item, schema matching the scalar export view) - // instead of from `train_input_path`. The item table is typically - // already prepared upstream (used by `negative_sampler.input_path` - // too) -- this avoids feeding the export path with training-shape - // sequence rows. - optional string item_input_path = 10; } diff --git a/tzrec/tests/configs/hstu_kuairand_1k.config b/tzrec/tests/configs/hstu_kuairand_1k.config index 706e8af1..89d99987 100644 --- a/tzrec/tests/configs/hstu_kuairand_1k.config +++ b/tzrec/tests/configs/hstu_kuairand_1k.config @@ -1,6 +1,5 @@ train_input_path: "data/test/kuairand-1k-match-train-c4096-s100.parquet" eval_input_path: "data/test/kuairand-1k-match-eval-c4096-s100.parquet" -item_input_path: "data/test/kuairand-1k-match-item-c1.parquet" model_dir: "experiments/kuairand/hstu_match" train_config { sparse_optimizer { diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index e2159de8..70c6e60b 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -392,6 +392,7 @@ def test_hstu_with_fg_train_eval(self): os.path.join(self.test_dir, "pipeline.config"), self.test_dir, env_str=f"{hstu_env} ENABLE_AOT=1", + data_input_path="data/test/kuairand-1k-match-item-c1.parquet", ) if self.success: # Item tower scalar export view: predict over the item-only diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 333a8a2f..7d9a9296 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -1040,6 +1040,7 @@ def test_export( asset_files: str = "", env_str: str = "", additional_export_config: str = "", + data_input_path: str = "", ) -> bool: """Run export integration test.""" log_dir = os.path.join(test_dir, "log_export") @@ -1056,7 +1057,9 @@ def test_export( if asset_files: cmd_str += f"--asset_files {asset_files} " if additional_export_config: - cmd_str += f"--additional_export_config '{additional_export_config}'" + cmd_str += f"--additional_export_config '{additional_export_config}' " + if data_input_path: + cmd_str += f"--data_input_path {data_input_path} " return misc_util.run_cmd( cmd_str, os.path.join(test_dir, "log_export.txt"), timeout=1800 diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index cf88da68..daa08e27 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -80,8 +80,14 @@ def export_model( save_dir: str, assets: Optional[List[str]] = None, additional_export_config: Optional[Dict[str, str]] = None, + data_input_path: Optional[str] = None, ) -> None: - """Export a EasyRec model, may be a part of model in PipelineConfig.""" + """Export a EasyRec model, may be a part of model in PipelineConfig. + + `data_input_path` (optional): override for the predict-mode dataloader + input path. When set, used instead of `pipeline_config.train_input_path`. + Wired from the `tzrec/export.py` CLI's `--data_input_path` flag. + """ use_rtp = env_util.use_rtp() impl = export_rtp_model if use_rtp else export_model_normal @@ -100,6 +106,7 @@ def export_model( assets=assets, use_local_cache_dir=use_local_cache_dir, additional_export_config=additional_export_config, + data_input_path=data_input_path, ) if use_local_cache_dir and int(os.environ.get("LOCAL_RANK", 0)) == 0: logger.info(f"uploading {local_path} to {save_dir}.") @@ -142,6 +149,7 @@ def export_model_normal( save_dir: str, assets: Optional[List[str]] = None, additional_export_config: Optional[Dict[str, str]] = None, + data_input_path: Optional[str] = None, **kwargs: Any, ) -> None: """Export a EasyRec model on aliyun.""" @@ -168,16 +176,12 @@ def export_model_normal( data_config.batch_size = min(data_config.batch_size, max_batch_size) logger.info("using new batch_size: %s in export", data_config.batch_size) data_config.num_workers = 1 - # Item-tower export: read the sample batch from `item_input_path` - # (one row per item, schema matching the scalar export view) instead - # of `train_input_path` (which holds training-shape sequence rows). - # Also clear the sampler so the predict dataloader doesn't try to - # launch a sampler for item-only rows. - is_item_tower = getattr(model, "_tower_name", None) == "item_tower" - input_path = pipeline_config.train_input_path - if is_item_tower and pipeline_config.HasField("item_input_path"): - input_path = pipeline_config.item_input_path - data_config.ClearField("sampler") + # Predict-mode dataloader input: caller may override `train_input_path` + # via `data_input_path` (CLI flag `--data_input_path` on + # `tzrec/export.py`). Used by recall-model item-tower export to read a + # one-row-per-item table matching the scalar export view; the user + # tower receives `data_input_path=None` and reads `train_input_path`. + input_path = data_input_path or pipeline_config.train_input_path dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) ckpt_param_map_path = None @@ -686,6 +690,7 @@ def export_rtp_model( save_dir: str, assets: Optional[List[str]] = None, use_local_cache_dir: bool = False, + data_input_path: Optional[str] = None, **kwargs: Any, ) -> None: """Export a EasyRec model on RTP.""" @@ -734,12 +739,9 @@ def _all_keys_used_once( features = cast(List[BaseFeature], model.features) data_config.num_workers = 1 data_config.batch_size = acc_utils.get_max_export_batch_size() - # Item-tower export: same routing as `export_model_normal`. - is_item_tower = getattr(model, "_tower_name", None) == "item_tower" - input_path = pipeline_config.train_input_path - if is_item_tower and pipeline_config.HasField("item_input_path"): - input_path = pipeline_config.item_input_path - data_config.ClearField("sampler") + # Same routing as `export_model_normal`: caller-supplied + # `data_input_path` overrides `train_input_path`. + input_path = data_input_path or pipeline_config.train_input_path dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) batch = next(iter(dataloader)) data = batch.to(device).to_dict(sparse_dtype=torch.int64) From 59af910f81caab08cf7602de8e762689ca7cc47e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 20:23:14 +0800 Subject: [PATCH 12/16] [bugfix] MatchTower: expose features/feature_groups properties (CI fix) Commit 35b18cb made `TowerWrapper` / `ScriptWrapper` read view-state via `@property` (`getattr(self, self._tower_name).features`), and added the property to `MatchTowerWoEG`, `BaseModel`, `TDMEmbedding`. But `MatchTower` (the parent of DSSMTower / DATTower / MIND*Tower) was missed: it sets the underscore fields in `__init__` but never exposes the no-underscore property. `ScriptWrapper.__init__` then crashes during the match-tower wrap with `AttributeError: features` for any non-HSTU match model. Add the two properties to `MatchTower`, mirroring `MatchTowerWoEG`. No behavioural change for HSTU; restores DSSM / MIND / DAT export. Co-Authored-By: Claude Opus 4.7 --- tzrec/models/match_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 84b75286..7e997ad6 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -144,6 +144,16 @@ def __init__( self.group_variational_dropouts = None self.group_variational_dropout_loss = {} + @property + def features(self) -> List[BaseFeature]: + """Features the tower exposes to its wrapper.""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Feature groups the tower exposes to its wrapper.""" + return self._feature_groups + def init_input(self) -> None: """Build embedding group and group variational dropout.""" self.embedding_group = EmbeddingGroup(self._features, self._feature_groups) From ae17668cddcc1a4424b3671f0c8796966ba66a4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 15:39:57 +0800 Subject: [PATCH 13/16] [refactor] cleanup PR comments; rename CLI flag back to --item_input_path - Revert dataset.py / dataset_test.py / utils.py to origin/master verbatim (post-merge they were a superset; user requested keeping master as-is). - Revert match_model.py `MatchTower` / `MatchTowerWoEG` / wrapper features/feature_groups property docstrings to master wording. - Simplify verbose PR-added comments across main.py, export_util.py, hstu.py, feature.py to one short line each. - Drop the multi-line "DISABLE_MMA_V3 / ENABLE_AOT" rationale block from the HSTU match integration test header. - CLI flag: rename `--data_input_path` back to `--item_input_path` on `tzrec/export.py`; `main.export()` accepts `item_input_path`. The internal `export_model()` function keeps the generic `data_input_path` parameter name. - Combine the three projection tests (`test_id_feature_projection_materializes_seq_defaults`, `test_projection_passes_through_create_features_as_scalar`, `test_raw_feature_projection_generic_oneof`) into a single `test_projection_materializes_defaults_and_passes_through_create_features` covering id_feature defaults materialization, create_features pass-through, and raw_feature oneof coverage. Co-Authored-By: Claude Opus 4.7 --- tzrec/datasets/dataset.py | 65 +++++---------------------- tzrec/export.py | 14 +++--- tzrec/features/feature.py | 15 +++---- tzrec/features/feature_test.py | 56 +++++++++-------------- tzrec/main.py | 38 +++++----------- tzrec/models/hstu.py | 43 +++++------------- tzrec/models/match_model.py | 19 +------- tzrec/tests/match_integration_test.py | 5 +-- tzrec/tests/utils.py | 6 +-- tzrec/utils/export_util.py | 30 +++---------- 10 files changed, 76 insertions(+), 215 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 98467b25..e589cfee 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -166,23 +166,9 @@ def __init__( ): self._selected_input_names = None - # Candidate-side sequence state. When `item_id_field` is a grouped - # sequence sub-feature input: - # * `_sampler_seq_delim` is the parent feature's `sequence_delim`. - # * `_sampler_seq_prefix` is the flatten prefix - # `feature.grouped_sequence_prefix`. - # * `_sampler_seq_inputs` is the set of all candidate-side - # sequence input names (union of `sequence_input_names` across - # features that share `sequence_name` with item_id_field's - # parent). Used to gate bare-attr resolution: an attr is - # prefixed iff `prefix + a` is a known candidate sequence - # input. Avoids false positives on top-level item-side attrs - # (e.g. lookup_feature's `cat_map` whose sequence_fields - # exclude it). Sourced from feature configs only -- no - # parquet-schema dependency. + # Sequence state when item_id_field is a grouped sequence sub-feature. self._sampler_seq_delim: str = "" self._sampler_seq_prefix: str = "" - self._sampler_seq_inputs: set = set() if self._sampler_item_id_field is not None: for feature in features: if self._sampler_item_id_field not in feature.sequence_input_names: @@ -196,10 +182,6 @@ def __init__( ) self._sampler_seq_delim = feature.sequence_delim self._sampler_seq_prefix = feature.grouped_sequence_prefix - seq_name = feature.sequence_name - for f in features: - if f.is_grouped_sequence and f.sequence_name == seq_name: - self._sampler_seq_inputs.update(f.sequence_input_names) break self._fg_mode = data_config.fg_mode @@ -224,50 +206,25 @@ def launch_sampler_cluster( """Launch sampler cluster and server.""" if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: sampler_type = self._data_config.WhichOneof("sampler") - # Deep-copy so any in-place rewrites below don't mutate - # `self._data_config`'s sampler sub-message. - sampler_config = copy.deepcopy(getattr(self._data_config, sampler_type)) + sampler_config = getattr(self._data_config, sampler_type) - sampler_fields = self.input_fields + # Rewrite bare attr_fields to flattened (`video_id` -> + # `cand_seq__video_id`); deep-copy so data_config isn't mutated. if self._sampler_seq_prefix: - # Grouped sequence: prefix a bare attr iff `prefix + a` - # is a known candidate-side sequence input (from - # feature configs). Already-qualified attrs land - # outside the candidate input set (their `prefix + a` - # double-prefixes), so they stay as-is. Top-level - # item-side attrs (e.g. lookup_feature's `cat_map` - # whose sequence_fields exclude it) likewise stay - # as-is. RTP-safe, no name string split, no parquet- - # schema dependency. Must run *before* the outer-list - # strip so `consumed` carries resolved names. + sampler_config = copy.deepcopy(sampler_config) sampler_config.attr_fields[:] = [ - self._sampler_seq_prefix + a - if self._sampler_seq_prefix + a in self._sampler_seq_inputs - else a - for a in sampler_config.attr_fields + self._sampler_seq_prefix + a for a in sampler_config.attr_fields ] + # Strip the per-row positive-grouping outer list on attr_fields + # columns so the sampler emits scalar negs. + sampler_fields = self.input_fields if self._sampler_seq_delim: - # Multi-positive sampling: when the sampler's item_id_field - # is itself a sequence-positive train column, the per-row - # outer list on the candidate sequence's item-side attrs is - # the positive-grouping container. Strip the outer list - # only for the candidate-sequence attrs in `attr_fields` - # (filtered by `_sampler_seq_prefix`). Excluded: top-level - # item-side attrs from the same lookup feature whose outer - # list is multi-value (e.g. `cat_map`); other grouped - # sequences' sub-features (e.g. uih_seq__*); and the - # `item_id_field` itself -- the sampler never inspects - # their type. - consumed = { - a - for a in sampler_config.attr_fields - if a.startswith(self._sampler_seq_prefix) - } + sampler_attrs = set(sampler_config.attr_fields) sampler_fields = [ pa.field(f.name, f.type.value_type) if ( - f.name in consumed + f.name in sampler_attrs and (pa.types.is_list(f.type) or pa.types.is_large_list(f.type)) ) else f diff --git a/tzrec/export.py b/tzrec/export.py index 32bc935f..2286ae61 100644 --- a/tzrec/export.py +++ b/tzrec/export.py @@ -49,15 +49,13 @@ 'e.g. \'{"cand_seq_pk": "cand_seq"}\' for DlrmHSTU.', ) parser.add_argument( - "--data_input_path", + "--item_input_path", type=str, default=None, - help="Optional input path override for export's predict-mode " - "dataloader. When set, the sample batch is read from this path " - "instead of `train_input_path`. Useful for recall-model item-tower " - "export with a one-row-per-item table whose schema matches the " - "scalar export view (training-shape sequence rows in " - "`train_input_path` would fail the scalar parser).", + help="Optional input path for the item-tower's predict-mode " + "dataloader. When set, the item tower reads from this path " + "(a one-row-per-item table matching the scalar export view) " + "instead of `train_input_path`.", ) args, extra_args = parser.parse_known_args() @@ -73,5 +71,5 @@ checkpoint_path=args.checkpoint_path, asset_files=args.asset_files, additional_export_config=additional_export_config, - data_input_path=args.data_input_path, + item_input_path=args.item_input_path, ) diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 03a84ce0..d7b90c23 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -1235,21 +1235,16 @@ def project_grouped_sequence_feature_to_scalar( ) -> feature_pb2.FeatureConfig: """Return a scalar export FeatureConfig for a grouped sequence sub-feature. - The grouped sub-feature's config is a `SeqFeatureConfig`; rewrap it as a - top-level `FeatureConfig` so `create_features` builds it as a scalar - feature. Materializes the source feature's effective `default_value` and - `value_dim` onto the scalar proto so the exported feature is - behaviorally identical to the training sub-feature -- without this, - `id_feature.value_dim` and `default_value` resolve differently in - scalar mode (`0` / `""`) than in sequence mode (`1` / `"0"`) - (see feature.py:515-517, 556-561). + Rewraps the inner ``SeqFeatureConfig`` as a top-level ``FeatureConfig`` + and materializes the source's effective ``default_value`` / ``value_dim`` + so the exported scalar feature matches the training sub-feature + (otherwise scalar mode defaults differ from sequence mode). Args: feature: a grouped sequence sub-feature. Returns: - a fresh FeatureConfig suitable for `create_features()` to construct - as a top-level scalar feature. + a fresh FeatureConfig for ``create_features()`` to build as scalar. """ if not feature.is_grouped_sequence: raise ValueError( diff --git a/tzrec/features/feature_test.py b/tzrec/features/feature_test.py index ba5328ba..c72409c7 100644 --- a/tzrec/features/feature_test.py +++ b/tzrec/features/feature_test.py @@ -764,8 +764,9 @@ def _build_grouped(self, seq_sub_cfg): ] return feature_lib.create_features(feature_cfgs) - def test_id_feature_projection_materializes_seq_defaults(self): - sub_cfg = feature_pb2.SeqFeatureConfig( + def test_projection_materializes_defaults_and_passes_through_create_features(self): + # id_feature: default_value / value_dim materialization + create_features. + id_sub_cfg = feature_pb2.SeqFeatureConfig( id_feature=feature_pb2.IdFeature( feature_name="video_id", expression="item:video_id", @@ -773,49 +774,44 @@ def test_id_feature_projection_materializes_seq_defaults(self): num_buckets=10000000, ) ) - features = self._build_grouped(sub_cfg) + features = self._build_grouped(id_sub_cfg) self.assertEqual(len(features), 1) sub_feature = features[0] self.assertTrue(sub_feature.is_grouped_sequence) - # Sequence-effective defaults on the source feature. + # Sequence-effective defaults on the source. self.assertEqual(sub_feature.default_value, "0") self.assertEqual(sub_feature.value_dim, 1) scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(sub_feature) self.assertEqual(scalar_cfg.WhichOneof("feature"), "id_feature") - # Sequence-effective defaults carried into the scalar proto so the - # exported feature is behaviorally identical to the training - # sub-feature -- otherwise an id_feature in scalar mode resolves - # default_value to "" and value_dim to 0, diverging from the - # training "0" / 1 it was projected from. + # Materialized onto the scalar proto. self.assertEqual(scalar_cfg.id_feature.default_value, "0") self.assertTrue(scalar_cfg.id_feature.HasField("value_dim")) self.assertEqual(scalar_cfg.id_feature.value_dim, 1) - # Original sub-feature proto is not mutated. + # Source proto not mutated. self.assertEqual(sub_feature.feature_config.id_feature.default_value, "") self.assertFalse(sub_feature.feature_config.id_feature.HasField("value_dim")) - def test_projection_passes_through_create_features_as_scalar(self): - sub_cfg = feature_pb2.SeqFeatureConfig( - id_feature=feature_pb2.IdFeature( - feature_name="video_id", - expression="item:video_id", - embedding_dim=32, - num_buckets=10000000, - ) - ) - features = self._build_grouped(sub_cfg) - scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) + # create_features rebuilds it as a top-level scalar feature. scalar_features = feature_lib.create_features([scalar_cfg]) self.assertEqual(len(scalar_features), 1) scalar = scalar_features[0] - # Bare sub-feature name without the cand_seq__ prefix. self.assertEqual(scalar.name, "video_id") self.assertFalse(scalar.is_grouped_sequence) - # Scalar context: value_dim is materialized to 1 from the source - # sub-feature's sequence-effective default. self.assertEqual(scalar.value_dim, 1) + # raw_feature: confirms the helper isn't hard-coded to id_feature. + raw_sub_cfg = feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="watch_time", expression="user:watch_time" + ) + ) + raw_features = self._build_grouped(raw_sub_cfg) + raw_scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar( + raw_features[0] + ) + self.assertEqual(raw_scalar_cfg.WhichOneof("feature"), "raw_feature") + def test_projection_rejects_non_grouped_feature(self): feature_cfgs = [ feature_pb2.FeatureConfig( @@ -826,18 +822,6 @@ def test_projection_rejects_non_grouped_feature(self): with self.assertRaisesRegex(ValueError, "is_grouped_sequence=False"): feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) - def test_raw_feature_projection_generic_oneof(self): - # Confirms the helper is not hard-coded to id_feature; raw_feature - # works the same way. - sub_cfg = feature_pb2.SeqFeatureConfig( - raw_feature=feature_pb2.RawFeature( - feature_name="watch_time", expression="user:watch_time" - ) - ) - features = self._build_grouped(sub_cfg) - scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) - self.assertEqual(scalar_cfg.WhichOneof("feature"), "raw_feature") - if __name__ == "__main__": unittest.main() diff --git a/tzrec/main.py b/tzrec/main.py index 3060e819..f8fdaec6 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -895,7 +895,7 @@ def export( checkpoint_path: Optional[str] = None, asset_files: Optional[str] = None, additional_export_config: Optional[Dict[str, str]] = None, - data_input_path: Optional[str] = None, + item_input_path: Optional[str] = None, ) -> None: """Export a EasyRec model. @@ -907,11 +907,9 @@ def export( asset_files (str, optional): more files will be copied to export_dir. additional_export_config (dict, optional): extra key/value pairs merged into model_acc.json (e.g. ``{"cand_seq_pk": "cand_seq"}`` for DlrmHSTU). - data_input_path (str, optional): override for export's predict-mode - dataloader input path. When set, used instead of - `pipeline_config.train_input_path`. For recall models this only - applies to the item-tower export (the user tower keeps reading - `train_input_path`). + item_input_path (str, optional): override for the item tower's + predict-mode dataloader input path. When set, the item tower + reads from this path instead of ``train_input_path``. """ is_rank_zero = int(os.environ.get("RANK", 0)) == 0 @@ -942,15 +940,9 @@ def export( sampler_type=None, ) InferWrapper = ScriptWrapper - # Set the inference flag on the inner model *before* wrapping. Every - # downstream wrapper (ScriptWrapper for non-match, TowerWoEGWrapper / - # TowerWrapper per match tower) snapshots view-dependent state at - # construction time (e.g. `HSTUMatchItemTower`'s lazy - # `features`/`feature_groups` properties; the wrappers' EmbeddingGroup), - # so the flag must already be True when the wrapping happens. - # `recursive_setattr` propagates the flag to all sub-modules including - # the inner towers, so the per-tower wrap below doesn't need its own - # toggle. + # Flip to inference *before* wrapping so view-dependent state + # (e.g. HSTUMatchItemTower's lazy properties, wrapper EmbeddingGroups) + # is snapshot from the scalar view. model.set_is_inference(True) model = InferWrapper(model) @@ -973,18 +965,12 @@ def export( wrapper = ( TowerWrapper if isinstance(module, MatchTower) else TowerWoEGWrapper ) - # The inference flag was already set on every sub-module by - # `model.set_is_inference(True)` above, so `HSTUMatchItemTower`'s - # lazy `features` / `feature_groups` properties return the - # scalar view here and the wrapper's `EmbeddingGroup` is - # built off scalar features. tower = InferWrapper(wrapper(module, name)) tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) - # data_input_path applies only to the item tower (whose - # scalar export view can't parse `train_input_path`'s - # training-shape sequence rows). The user tower reads - # the standard `train_input_path`. - tower_input_path = data_input_path if name == "item_tower" else None + # item-tower-only; user tower falls back to `train_input_path`. + tower_data_input_path = ( + item_input_path if name == "item_tower" else None + ) export_model( ori_pipeline_config, tower, @@ -992,7 +978,7 @@ def export( tower_export_dir, assets=assets, additional_export_config=additional_export_config, - data_input_path=tower_input_path, + data_input_path=tower_data_input_path, ) elif isinstance(model.model, TDM): for name, module in model.model.named_children(): diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 20a5299d..2e3c6acd 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -160,24 +160,17 @@ def __init__( # tower_config.input names on the user-tower proto). Use the item-side # tower_config.input here, which equals feature_groups[0].group_name. self._group_name = tower_config.input - # Training view: candidate group is JAGGED_SEQUENCE; embedding_group - # emits the per-row jagged tensor at `{group_name}.sequence`. At - # export, `set_is_inference(True)` flips the flag below; the - # `features` / `feature_groups` properties then return the lazily- - # built scalar view (one row per item), and `forward()` reads - # `{group_name}.query` instead. Mlp / output Linear are sized off - # the training candidate group; the scalar view's per-feature - # embedding dim is identical, so no resize is needed. + # MLP sized off the training candidate group; the scalar view has + # identical per-feature embedding dim. candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence") candidate_total_dim = sum(candidate_dims) - # Lazy caches for the scalar export view; populated on first - # property access when `_is_inference` is True. None at training - # time -- non-export consumers pay zero cost. + # Lazy caches for the scalar export view (populated on first + # property access after `set_is_inference(True)`). self._features_scalar: Optional[List[BaseFeature]] = None self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None - # Initialize explicitly: `MatchTowerWoEG` derives from `nn.Module`, - # not `BaseModule`, so `_is_inference` isn't set by the parent. + # `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, + # so init `_is_inference` here. self._is_inference: bool = False if tower_config.HasField("mlp"): self.mlp: torch.nn.Module = MLP( @@ -194,13 +187,7 @@ def __init__( @property def features(self) -> List[BaseFeature]: - """Item-side features in the current view (training or scalar export). - - At training (`_is_inference=False`), returns the grouped sequence - sub-features the tower was constructed with. At export - (`_is_inference=True`), returns the lazily-built scalar - projection, cached for subsequent reads. - """ + """Item features (training: grouped sub-features; export: scalar projection).""" if self._is_inference: if self._features_scalar is None: self._build_scalar_features() @@ -209,7 +196,7 @@ def features(self) -> List[BaseFeature]: @property def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: - """Item-side feature_groups in the current view (see `features`).""" + """Item feature_groups in the current view (see ``features``).""" if self._is_inference: if self._feature_groups_scalar is None: self._build_scalar_features() @@ -217,12 +204,7 @@ def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: return self._feature_groups def _build_scalar_features(self) -> None: - """Build the scalar export view caches from the training features. - - Projects each grouped sequence sub-feature into a scalar export - feature; populates `_features_scalar` and `_feature_groups_scalar`. - Called at most once per tower instance (cached via the properties). - """ + """Project each grouped sequence sub-feature into a scalar export feature.""" scalar_configs = [ project_grouped_sequence_feature_to_scalar(f) for f in self._features ] @@ -248,18 +230,13 @@ def _build_scalar_features(self) -> None: def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: """Forward the item tower. - Reads from `{group_name}.sequence` (jagged) at training and from - `{group_name}.query` (scalar) at export. One-line conditional; - no cached `cand_key` attribute. The branch is on `_is_inference` - rather than dict-membership so the choice is FX/JIT traceable - (dict-membership checks aren't traceable as control flow). - Args: grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: item embeddings of shape (sum_candidates, D). """ + # `.sequence` (jagged) at training, `.query` (scalar) at export. suffix = ".query" if self._is_inference else ".sequence" cand_emb = grouped_features[self._group_name + suffix] item_emb = self.mlp(cand_emb) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index e54ddcb4..1678f51c 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -234,11 +234,7 @@ def __init__( @property def features(self) -> List[BaseFeature]: - """Tower's features. - - Default reads ``self._features``; overridden by towers that switch - views between training and export (see ``HSTUMatchItemTower``). - """ + """Tower's features (default property forwarding to ``self._features``).""" return self._features @property @@ -485,11 +481,7 @@ def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: @property def features(self) -> List[BaseFeature]: - """Live read of the wrapped tower's features (no snapshot). - - For ``HSTUMatchItemTower``, reflects the current view (training - or scalar export) per ``_is_inference``. - """ + """Live read of the wrapped tower's features (no snapshot).""" return getattr(self, self._tower_name).features @property @@ -514,13 +506,6 @@ class TowerWoEGWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() - # Build EmbeddingGroup from the tower's *current view*: for - # `HSTUMatchItemTower` after `set_is_inference(True)`, this is the - # scalar export view (one row per item, `{group_name}.query`); - # otherwise it's the training view (jagged, `{group_name}.sequence`). - # The EmbeddingGroup itself owns nn.Parameters so it must be a - # construction-time snapshot; the `features`/`feature_groups` - # properties below stay live via lazy reads on the inner tower. self.embedding_group = EmbeddingGroup(module.features, module.feature_groups) setattr(self, tower_name, module) self._tower_name = tower_name diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 70c6e60b..13ece933 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -370,9 +370,6 @@ def test_mind_train_eval_export(self): @unittest.skipIf(*gpu_unavailable) def test_hstu_with_fg_train_eval(self): - # DISABLE_MMA_V3=1: Triton 3.6 sm_90 WGMMA bug. ENABLE_AOT=1: HSTU - # uses TRITON kernels which require CUDA; AOT export keeps the - # forward on CUDA. Same pattern as dlrm_hstu's export test. hstu_env = "DISABLE_MMA_V3=1" self.success = utils.test_train_eval( "tzrec/tests/configs/hstu_kuairand_1k.config", @@ -392,7 +389,7 @@ def test_hstu_with_fg_train_eval(self): os.path.join(self.test_dir, "pipeline.config"), self.test_dir, env_str=f"{hstu_env} ENABLE_AOT=1", - data_input_path="data/test/kuairand-1k-match-item-c1.parquet", + item_input_path="data/test/kuairand-1k-match-item-c1.parquet", ) if self.success: # Item tower scalar export view: predict over the item-only diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 7d9a9296..ef2fd3dc 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -1040,7 +1040,7 @@ def test_export( asset_files: str = "", env_str: str = "", additional_export_config: str = "", - data_input_path: str = "", + item_input_path: str = "", ) -> bool: """Run export integration test.""" log_dir = os.path.join(test_dir, "log_export") @@ -1058,8 +1058,8 @@ def test_export( cmd_str += f"--asset_files {asset_files} " if additional_export_config: cmd_str += f"--additional_export_config '{additional_export_config}' " - if data_input_path: - cmd_str += f"--data_input_path {data_input_path} " + if item_input_path: + cmd_str += f"--item_input_path {item_input_path} " return misc_util.run_cmd( cmd_str, os.path.join(test_dir, "log_export.txt"), timeout=1800 diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index daa08e27..c9bd5854 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -85,8 +85,7 @@ def export_model( """Export a EasyRec model, may be a part of model in PipelineConfig. `data_input_path` (optional): override for the predict-mode dataloader - input path. When set, used instead of `pipeline_config.train_input_path`. - Wired from the `tzrec/export.py` CLI's `--data_input_path` flag. + input path; falls back to `pipeline_config.train_input_path` when None. """ use_rtp = env_util.use_rtp() @@ -176,11 +175,6 @@ def export_model_normal( data_config.batch_size = min(data_config.batch_size, max_batch_size) logger.info("using new batch_size: %s in export", data_config.batch_size) data_config.num_workers = 1 - # Predict-mode dataloader input: caller may override `train_input_path` - # via `data_input_path` (CLI flag `--data_input_path` on - # `tzrec/export.py`). Used by recall-model item-tower export to read a - # one-row-per-item table matching the scalar export view; the user - # tower receives `data_input_path=None` and reads `train_input_path`. input_path = data_input_path or pipeline_config.train_input_path dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) @@ -202,11 +196,8 @@ def export_model_normal( if is_rank_zero: if not os.path.exists(save_dir): os.makedirs(save_dir) - # `set_is_inference(True)` is the caller's responsibility; it must - # be applied to the inner model *before* wrapping with - # `InferWrapper` (see `tzrec/main.py::export`) so wrapper-level - # snapshots (EmbeddingGroup, view-dependent features) pick up the - # inference-mode view at construction time. + # `set_is_inference(True)` applied in `tzrec/main.py::export` + # before wrapping; wrappers already see the inference-mode view. init_parameters(model, torch.device("cpu")) checkpoint_util.restore_model( @@ -297,11 +288,9 @@ def export_model_normal( pipeline_config = copy.copy(pipeline_config) pipeline_config.ClearField("feature_configs") pipeline_config.feature_configs.extend(feature_configs) - # Towers that own a view-specific feature_groups (e.g. - # `HSTUMatchItemTower` after `set_is_inference(True)` swaps to the - # scalar item view) must save those groups too, otherwise the - # exported pipeline.config pairs scalar feature_configs with stale - # training feature_group names. + # Persist the model's current feature_groups so towers with a + # view-specific group set (e.g. HSTUMatchItemTower scalar view) + # don't ship stale training-view group names. if hasattr(model, "_feature_groups"): pipeline_config.model_config.ClearField("feature_groups") pipeline_config.model_config.feature_groups.extend(model.feature_groups) @@ -739,16 +728,11 @@ def _all_keys_used_once( features = cast(List[BaseFeature], model.features) data_config.num_workers = 1 data_config.batch_size = acc_utils.get_max_export_batch_size() - # Same routing as `export_model_normal`: caller-supplied - # `data_input_path` overrides `train_input_path`. input_path = data_input_path or pipeline_config.train_input_path dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) batch = next(iter(dataloader)) data = batch.to(device).to_dict(sparse_dtype=torch.int64) - # `set_is_inference(True)` was applied in `tzrec/main.py::export` before - # wrapping -- the inner model + all sub-modules already carry the flag. - # Build Sharded Model planner = create_planner( device=device, @@ -1069,8 +1053,6 @@ def split_model( if not os.path.exists(graph_dir): os.makedirs(graph_dir) - # `set_is_inference(True)` was applied in `tzrec/main.py::export` before - # wrapping -- the inner model + all sub-modules already carry the flag. model.eval() tracer = Tracer() From 9f11f052d2c24cb1a0f966890e7cabeab6f20499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 16:13:30 +0800 Subject: [PATCH 14/16] [bugfix] address PR review: drop dead hasattr gate, lock scalar-view contract - `export_util.py`: drop `hasattr(model, "_feature_groups")` gate at the pipeline.config feature_groups write. Post-refactor, no wrapper sets the underscore field as an instance attribute (replaced with `@property`), so the gate never fired. The exported pipeline.config for the HSTU item tower kept stale training-view feature_groups while the feature_configs were the scalar projection -- predict reading the exported config would see a name mismatch. Write `model.feature_groups` unconditionally; the property cascade guarantees the right view on every model that reaches export. - `hstu_test.py`: migrate `_build_model` / `_build_batch` to the grouped-sequence pattern (`sequence_feature` wrapping sub-features named `video_id` in both `uih_seq` and `cand_seq`, sharing one embedding table via aligned `num_buckets` / `embedding_dim` / `embedding_name`). Inline scalar-view assertions into the existing `test_hstu_match` as the final step: stash `item_tower` before the graph_type branch wraps `hstu`, flip `set_is_inference(True)`, and assert the projected scalar names / non-grouped-sequence flag / feature_groups feature_names + group_name. Locks the lazy view contract HSTUMatchItemTower depends on for export. - `model.py`: drop the now-unused `_tower_name` propagation in `ScriptWrapper.__init__`. Originally load-bearing for an export-util routing path that's been replaced by the `--item_input_path` CLI flag (decided at `main.py::export`). No external reader of `ScriptWrapper._tower_name` remains. - `match_integration_test.py`: drop `hstu_env = "DISABLE_MMA_V3=1"` (no longer required by the underlying Triton path) and the inline set_is_inference note in `export_util.py`. Co-Authored-By: Claude Opus 4.7 --- tzrec/models/hstu_test.py | 94 ++++++++++++++++++--------- tzrec/models/model.py | 4 -- tzrec/tests/match_integration_test.py | 5 +- tzrec/utils/export_util.py | 8 +-- 4 files changed, 65 insertions(+), 46 deletions(-) diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index 8857a381..e03792ca 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -36,29 +36,51 @@ def _build_model(device: torch.device) -> HSTUMatch: - """Build an HSTUMatch model with standard test configuration.""" + """Build an HSTUMatch model with standard test configuration. + + Mirrors the production grouped-sequence pattern: `uih_seq` and + `cand_seq` each carry a `video_id` sub-feature with aligned bucket / + dim / `embedding_name` so the two flattened features share one + embedding table. `uih_seq` also carries the `historical_ts` raw + sub-feature for the timestamp dense path. + """ feature_cfgs = [ feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="historical_ids", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="uih_seq", sequence_length=210, - embedding_dim=64, - num_buckets=3953, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + embedding_dim=64, + num_buckets=1000, + embedding_name="video_id_emb", + ) + ), + feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="historical_ts", + ) + ), + ], ) ), feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="item_id", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", sequence_length=10, sequence_delim=";", - embedding_dim=64, - num_buckets=1000, - ) - ), - feature_pb2.FeatureConfig( - sequence_raw_feature=feature_pb2.RawFeature( - feature_name="historical_ts", - sequence_length=210, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + embedding_dim=64, + num_buckets=1000, + embedding_name="video_id_emb", + ) + ), + ], ) ), ] @@ -66,17 +88,17 @@ def _build_model(device: torch.device) -> HSTUMatch: feature_groups = [ model_pb2.FeatureGroupConfig( group_name="uih", - feature_names=["historical_ids"], + feature_names=["uih_seq__video_id"], group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), model_pb2.FeatureGroupConfig( group_name="candidate", - feature_names=["item_id"], + feature_names=["cand_seq__video_id"], group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), model_pb2.FeatureGroupConfig( group_name="uih_timestamp", - feature_names=["historical_ts"], + feature_names=["uih_seq__historical_ts"], group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), ] @@ -140,12 +162,12 @@ def _build_batch(device: torch.device) -> Batch: pos_lengths = [1, 1]. """ sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids", "item_id"], + keys=["uih_seq__video_id", "cand_seq__video_id"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 100, 200, 101, 201]), lengths=torch.tensor([3, 4, 1, 3]), ) sequence_dense_features = { - "historical_ts": JaggedTensor( + "uih_seq__historical_ts": JaggedTensor( values=torch.tensor([[1], [2], [3], [4], [5], [6], [7]]), lengths=torch.tensor([3, 4]), ), @@ -197,27 +219,35 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None: hstu = _build_model(device=device) hstu.set_kernel(kernel) batch = _build_batch(device=device) + # Stash the inner item_tower before the graph_type branch wraps + # `hstu` -- the wrapper doesn't expose `item_tower`. + item_tower = hstu.item_tower if graph_type == TestGraphType.JIT_SCRIPT: - # Don't flip the inference flag here: the test batch is in the - # training-view shape (jagged `candidate.sequence`), and the - # item tower's forward branches on `_is_inference` to choose - # `.sequence` vs `.query`. The JIT-scripted forward still - # compiles both branches; the runtime path is the training - # branch, matching the batch shape. - hstu = create_test_model(hstu, graph_type) - predictions = hstu(batch.to_dict(), device) + hstu_wrapped = create_test_model(hstu, graph_type) + predictions = hstu_wrapped(batch.to_dict(), device) elif graph_type == TestGraphType.FX_TRACE: - hstu = create_test_model(hstu, graph_type) - predictions = hstu(batch) + hstu_wrapped = create_test_model(hstu, graph_type) + predictions = hstu_wrapped(batch) else: - hstu = TrainWrapper(hstu, device=device).to(device) - _, (_, predictions, _) = hstu(batch) + hstu_wrapped = TrainWrapper(hstu, device=device).to(device) + _, (_, predictions, _) = hstu_wrapped(batch) self.assertIn("similarity", predictions) # Q = sum(pos_lengths) = 2; column count = 1 (pos) + neg count. self.assertEqual(predictions["similarity"].size(0), 2) + # Scalar-view contract: set_is_inference(True) flips item_tower + # to the scalar export view (bare sub-feature names). + hstu.set_is_inference(True) + self.assertTrue(item_tower._is_inference) + scalar_features = item_tower.features + scalar_feature_groups = item_tower.feature_groups + self.assertEqual(scalar_features[0].name, "video_id") + self.assertFalse(scalar_features[0].is_grouped_sequence) + self.assertEqual(scalar_feature_groups[0].feature_names, ["video_id"]) + self.assertEqual(scalar_feature_groups[0].group_name, "candidate") + if __name__ == "__main__": unittest.main() diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 510fc46d..40da5335 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -346,10 +346,6 @@ class ScriptWrapper(BaseModule): def __init__(self, module: nn.Module) -> None: super().__init__() self.model = module - # Propagate tower identity (set by TowerWoEGWrapper / TowerWrapper) - # so callers can identify which tower this wrapper exports. - if hasattr(self.model, "_tower_name"): - self._tower_name = self.model._tower_name self._data_parser = DataParser( self.model.features, sampler_type=str(module.sampler_type) diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 13ece933..6c2ff407 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -370,25 +370,22 @@ def test_mind_train_eval_export(self): @unittest.skipIf(*gpu_unavailable) def test_hstu_with_fg_train_eval(self): - hstu_env = "DISABLE_MMA_V3=1" self.success = utils.test_train_eval( "tzrec/tests/configs/hstu_kuairand_1k.config", self.test_dir, user_id="user_id", item_id="item_id", - env_str=hstu_env, ) if self.success: self.success = utils.test_eval( os.path.join(self.test_dir, "pipeline.config"), self.test_dir, - env_str=hstu_env, ) if self.success: self.success = utils.test_export( os.path.join(self.test_dir, "pipeline.config"), self.test_dir, - env_str=f"{hstu_env} ENABLE_AOT=1", + env_str="ENABLE_AOT=1", item_input_path="data/test/kuairand-1k-match-item-c1.parquet", ) if self.success: diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index c9bd5854..635f283a 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -196,9 +196,6 @@ def export_model_normal( if is_rank_zero: if not os.path.exists(save_dir): os.makedirs(save_dir) - # `set_is_inference(True)` applied in `tzrec/main.py::export` - # before wrapping; wrappers already see the inference-mode view. - init_parameters(model, torch.device("cpu")) checkpoint_util.restore_model( checkpoint_path, model, ckpt_param_map_path=ckpt_param_map_path @@ -291,9 +288,8 @@ def export_model_normal( # Persist the model's current feature_groups so towers with a # view-specific group set (e.g. HSTUMatchItemTower scalar view) # don't ship stale training-view group names. - if hasattr(model, "_feature_groups"): - pipeline_config.model_config.ClearField("feature_groups") - pipeline_config.model_config.feature_groups.extend(model.feature_groups) + pipeline_config.model_config.ClearField("feature_groups") + pipeline_config.model_config.feature_groups.extend(model.feature_groups) config_util.save_message( pipeline_config, os.path.join(save_dir, "pipeline.config") ) From fc50ddd2ab9ef3bfa6b68c2439aee45fcbd7b87e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 16:22:44 +0800 Subject: [PATCH 15/16] [refactor] drop dead model_config.feature_groups write; inline item_tower access - `export_util.py`: remove the `pipeline_config.model_config.ClearField` + `extend(model.feature_groups)` block. Predict reads only `pipeline_config.feature_configs` (see `main.py::predict` at the scripted-model path) -- the scripted model has its EmbeddingGroup baked in, so `model_config.feature_groups` in the exported config is dead metadata. No downstream consumer. - `hstu_test.py`: drop the unnecessary `item_tower = hstu.item_tower` stash; `hstu` itself is preserved across `create_test_model` / `TrainWrapper` wrapping (assigned to `hstu_wrapped`), so the scalar-view assertions can read `hstu.item_tower` inline. Co-Authored-By: Claude Opus 4.7 --- tzrec/models/hstu_test.py | 9 +++------ tzrec/utils/export_util.py | 5 ----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index e03792ca..4f6c6f57 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -219,9 +219,6 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None: hstu = _build_model(device=device) hstu.set_kernel(kernel) batch = _build_batch(device=device) - # Stash the inner item_tower before the graph_type branch wraps - # `hstu` -- the wrapper doesn't expose `item_tower`. - item_tower = hstu.item_tower if graph_type == TestGraphType.JIT_SCRIPT: hstu_wrapped = create_test_model(hstu, graph_type) @@ -240,9 +237,9 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None: # Scalar-view contract: set_is_inference(True) flips item_tower # to the scalar export view (bare sub-feature names). hstu.set_is_inference(True) - self.assertTrue(item_tower._is_inference) - scalar_features = item_tower.features - scalar_feature_groups = item_tower.feature_groups + self.assertTrue(hstu.item_tower._is_inference) + scalar_features = hstu.item_tower.features + scalar_feature_groups = hstu.item_tower.feature_groups self.assertEqual(scalar_features[0].name, "video_id") self.assertFalse(scalar_features[0].is_grouped_sequence) self.assertEqual(scalar_feature_groups[0].feature_names, ["video_id"]) diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index 635f283a..05642f7f 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -285,11 +285,6 @@ def export_model_normal( pipeline_config = copy.copy(pipeline_config) pipeline_config.ClearField("feature_configs") pipeline_config.feature_configs.extend(feature_configs) - # Persist the model's current feature_groups so towers with a - # view-specific group set (e.g. HSTUMatchItemTower scalar view) - # don't ship stale training-view group names. - pipeline_config.model_config.ClearField("feature_groups") - pipeline_config.model_config.feature_groups.extend(model.feature_groups) config_util.save_message( pipeline_config, os.path.join(save_dir, "pipeline.config") ) From d25eb5b92a95467748314ae624ab7bc2fce28297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 21 May 2026 16:34:55 +0800 Subject: [PATCH 16/16] [doc] hstu_match: add item-tower export instructions with --item_input_path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Document the item-tower scalar export flow: `ENABLE_AOT=1` when using Triton kernel + `--item_input_path` pointing at a one-row-per-item parquet. Mirrors the style of `dlrm_hstu.md`'s 模型导出 section. Co-Authored-By: Claude Opus 4.7 --- docs/source/models/hstu_match.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/models/hstu_match.md b/docs/source/models/hstu_match.md index 3f013d2a..16dc46f9 100644 --- a/docs/source/models/hstu_match.md +++ b/docs/source/models/hstu_match.md @@ -256,13 +256,16 @@ model_config { ## 模型导出 -HSTU Match 模型导出时需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。例如: +HSTU Match 模型导出时,若使用 Triton kernel,需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。 + +同时需要通过命令行参数 `--item_input_path` 指定 item 侧的输入数据路径(一行一个 item 的 parquet,schema 与候选序列子特征对齐,例如包含 `video_id` 列)。item tower 导出时会从该路径读取一个样本 batch 用于 trace;user tower 不受影响,仍使用 `train_input_path`。例如: ``` ENABLE_AOT=1 torchrun --master_addr=localhost --master_port=32555 \ --nnodes=1 --nproc-per-node=1 --node_rank=0 \ -m tzrec.export \ --pipeline_config_path experiments/hstu_match/pipeline.config \ + --item_input_path experiments/hstu_match/item_data/*.parquet \ --export_dir experiments/hstu_match/export ```