Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
1a8b7d1
[fix] sampler: resolve bare attr_fields against item_id_field's seque…
tiankongdeguiji May 20, 2026
aa76da2
[feat] features: project_grouped_sequence_feature_to_scalar helper fo…
tiankongdeguiji May 20, 2026
7a9318f
[feat] HSTUMatchItemTower: scalar export view via set_is_inference(True)
tiankongdeguiji May 20, 2026
0c1dcf1
[feat] HSTUMatch item-tower scalar export view + item_input_path routing
tiankongdeguiji May 20, 2026
2da7c29
[ci] add kuairand-1k-match-item-c1.parquet wget for HSTUMatch item-to…
tiankongdeguiji May 20, 2026
4578c75
[refactor] sampler: unify sequence state via feature configs, drop _s…
tiankongdeguiji May 20, 2026
8cd5d3e
[refactor] HSTUMatchItemTower view via lazy properties; materialize s…
tiankongdeguiji May 20, 2026
35b18cb
[refactor] expose features/feature_groups as properties end-to-end; d…
tiankongdeguiji May 20, 2026
11ef729
[doc] explain why the prefix-resolve if-else is load-bearing
tiankongdeguiji May 20, 2026
e7836ff
[refactor] gate prefix-resolve on candidate-side sequence_input_names…
tiankongdeguiji May 20, 2026
1e885b5
[refactor] move item_input_path out of proto; expose as generic --dat…
tiankongdeguiji May 20, 2026
59af910
[bugfix] MatchTower: expose features/feature_groups properties (CI fix)
tiankongdeguiji May 20, 2026
f4925dc
Merge remote-tracking branch 'origin/master' into feat/hstu-match-sca…
tiankongdeguiji May 21, 2026
ae17668
[refactor] cleanup PR comments; rename CLI flag back to --item_input_…
tiankongdeguiji May 21, 2026
9f11f05
[bugfix] address PR review: drop dead hasattr gate, lock scalar-view …
tiankongdeguiji May 21, 2026
fc50ddd
[refactor] drop dead model_config.feature_groups write; inline item_t…
tiankongdeguiji May 21, 2026
d25eb5b
[doc] hstu_match: add item-tower export instructions with --item_inpu…
tiankongdeguiji May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/source/models/hstu_match.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,16 @@ model_config {

## 模型导出

HSTU Match 模型导出时需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。例如:
HSTU Match 模型导出时,若使用 Triton kernel,需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。

同时需要通过命令行参数 `--item_input_path` 指定 item 侧的输入数据路径(一行一个 item 的 parquet,schema 与候选序列子特征对齐,例如包含 `video_id` 列)。item tower 导出时会从该路径读取一个样本 batch 用于 trace;user tower 不受影响,仍使用 `train_input_path`。例如:

```
ENABLE_AOT=1 torchrun --master_addr=localhost --master_port=32555 \
--nnodes=1 --nproc-per-node=1 --node_rank=0 \
-m tzrec.export \
--pipeline_config_path experiments/hstu_match/pipeline.config \
--item_input_path experiments/hstu_match/item_data/\*.parquet \
--export_dir experiments/hstu_match/export
```

Expand Down
1 change: 1 addition & 0 deletions scripts/ci/ci_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-mot-1k-eval-c4
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-train-c4096-s100-f1892eabc70ae3407afe9ff5bca8cb5f.parquet -O data/test/kuairand-1k-match-train-c4096-s100.parquet
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-eval-c4096-s100-e4ca5e15d157efa723041cd05c127228.parquet -O data/test/kuairand-1k-match-eval-c4096-s100.parquet
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-gl-3d459148303acd9f838da108efcc40e5.txt -O data/test/kuairand-1k-match-item-gl.txt
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-c1-8dcadabdc3e9049ed9c2250565b4b134.parquet -O data/test/kuairand-1k-match-item-c1.parquet
10 changes: 10 additions & 0 deletions tzrec/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@
help="JSON string of extra key/value pairs merged into model_acc.json, "
'e.g. \'{"cand_seq_pk": "cand_seq"}\' for DlrmHSTU.',
)
parser.add_argument(
"--item_input_path",
type=str,
default=None,
help="Optional input path for the item-tower's predict-mode "
"dataloader. When set, the item tower reads from this path "
"(a one-row-per-item table matching the scalar export view) "
"instead of `train_input_path`.",
)
args, extra_args = parser.parse_known_args()

additional_export_config = (
Expand All @@ -62,4 +71,5 @@
checkpoint_path=args.checkpoint_path,
asset_files=args.asset_files,
additional_export_config=additional_export_config,
item_input_path=args.item_input_path,
)
37 changes: 37 additions & 0 deletions tzrec/features/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,43 @@ def create_features(
return features


def project_grouped_sequence_feature_to_scalar(
    feature: BaseFeature,
) -> feature_pb2.FeatureConfig:
    """Build a top-level scalar FeatureConfig from a grouped sequence sub-feature.

    The message held in the sub-feature's ``SeqFeatureConfig`` oneof is copied
    into the same-named oneof slot of a new ``FeatureConfig``. Afterwards
    ``default_value`` and ``value_dim`` are filled in from the built feature
    whenever the copied message leaves them unset, so the exported scalar
    feature carries the same effective values as the training sub-feature
    (scalar mode and sequence mode would otherwise fall back to different
    defaults).

    Args:
        feature: a grouped sequence sub-feature.

    Returns:
        a new FeatureConfig that ``create_features()`` can build as scalar.
        The config held by ``feature`` is only read, never written.

    Raises:
        ValueError: if ``feature.is_grouped_sequence`` is falsy.
    """
    if not feature.is_grouped_sequence:
        raise ValueError(
            "project_grouped_sequence_feature_to_scalar only accepts grouped "
            f"sequence sub-features; got {feature.name} "
            "(is_grouped_sequence=False)"
        )

    seq_cfg = feature.feature_config  # SeqFeatureConfig
    oneof_name = seq_cfg.WhichOneof("feature")
    inner_msg = getattr(seq_cfg, oneof_name)

    scalar_cfg = feature_pb2.FeatureConfig()
    target = getattr(scalar_cfg, oneof_name)
    target.CopyFrom(inner_msg)

    # The ``hasattr`` guards exist because not every feature message declares
    # both fields. NOTE(review): per the review of this change, OverlapFeature
    # is the variant without ``default_value`` -- confirm against the proto.
    # A value already present on the copied message is never overwritten.
    default_unset = hasattr(target, "default_value") and not target.default_value
    if default_unset:
        target.default_value = feature.default_value
    dim_unset = hasattr(target, "value_dim") and not target.HasField("value_dim")
    if dim_unset:
        target.value_dim = feature.value_dim

    return scalar_cfg


def _copy_assets(
feature: BaseFeature,
asset_dir: Optional[str] = None,
Expand Down
73 changes: 73 additions & 0 deletions tzrec/features/feature_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,5 +750,78 @@ def test_sequence_input_names(self, fg_mode):
)


class ProjectGroupedSequenceFeatureToScalarTest(unittest.TestCase):
    """Tests for ``project_grouped_sequence_feature_to_scalar``."""

    def _build_grouped(self, seq_sub_cfg):
        """Build features from a ``cand_seq`` group holding one sub-feature."""
        grouped_cfg = feature_pb2.FeatureConfig(
            sequence_feature=feature_pb2.SequenceFeature(
                sequence_name="cand_seq",
                sequence_delim="|",
                sequence_length=100,
                features=[seq_sub_cfg],
            )
        )
        return feature_lib.create_features([grouped_cfg])

    def test_projection_materializes_defaults_and_passes_through_create_features(self):
        # Part 1 -- id_feature: the projection must materialize default_value
        # and value_dim, and the result must be buildable by create_features.
        grouped = self._build_grouped(
            feature_pb2.SeqFeatureConfig(
                id_feature=feature_pb2.IdFeature(
                    feature_name="video_id",
                    expression="item:video_id",
                    embedding_dim=32,
                    num_buckets=10000000,
                )
            )
        )
        self.assertEqual(len(grouped), 1)
        seq_feature = grouped[0]
        self.assertTrue(seq_feature.is_grouped_sequence)
        # Effective values reported by the built sequence sub-feature.
        self.assertEqual(seq_feature.default_value, "0")
        self.assertEqual(seq_feature.value_dim, 1)

        projected = feature_lib.project_grouped_sequence_feature_to_scalar(seq_feature)
        self.assertEqual(projected.WhichOneof("feature"), "id_feature")
        # Both values are now written explicitly on the scalar proto.
        projected_id = projected.id_feature
        self.assertEqual(projected_id.default_value, "0")
        self.assertTrue(projected_id.HasField("value_dim"))
        self.assertEqual(projected_id.value_dim, 1)
        # The sub-feature's own proto must be left untouched.
        source_id = seq_feature.feature_config.id_feature
        self.assertEqual(source_id.default_value, "")
        self.assertFalse(source_id.HasField("value_dim"))

        # The projected config builds as a plain top-level feature.
        rebuilt = feature_lib.create_features([projected])
        self.assertEqual(len(rebuilt), 1)
        scalar_feature = rebuilt[0]
        self.assertEqual(scalar_feature.name, "video_id")
        self.assertFalse(scalar_feature.is_grouped_sequence)
        self.assertEqual(scalar_feature.value_dim, 1)

        # Part 2 -- raw_feature: the oneof variant is carried over, i.e. the
        # helper is not specific to id_feature.
        raw_grouped = self._build_grouped(
            feature_pb2.SeqFeatureConfig(
                raw_feature=feature_pb2.RawFeature(
                    feature_name="watch_time", expression="user:watch_time"
                )
            )
        )
        raw_projected = feature_lib.project_grouped_sequence_feature_to_scalar(
            raw_grouped[0]
        )
        self.assertEqual(raw_projected.WhichOneof("feature"), "raw_feature")

    def test_projection_rejects_non_grouped_feature(self):
        # A top-level (non-sequence) feature must be refused with ValueError.
        plain = feature_lib.create_features(
            [
                feature_pb2.FeatureConfig(
                    id_feature=feature_pb2.IdFeature(feature_name="user_id")
                )
            ]
        )
        with self.assertRaisesRegex(ValueError, "is_grouped_sequence=False"):
            feature_lib.project_grouped_sequence_feature_to_scalar(plain[0])


if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,7 @@ def export(
checkpoint_path: Optional[str] = None,
asset_files: Optional[str] = None,
additional_export_config: Optional[Dict[str, str]] = None,
item_input_path: Optional[str] = None,
) -> None:
"""Export a EasyRec model.

Expand All @@ -906,6 +907,9 @@ def export(
asset_files (str, optional): more files will be copied to export_dir.
additional_export_config (dict, optional): extra key/value pairs merged
into model_acc.json (e.g. ``{"cand_seq_pk": "cand_seq"}`` for DlrmHSTU).
item_input_path (str, optional): override for the item tower's
predict-mode dataloader input path. When set, the item tower
reads from this path instead of ``train_input_path``.
"""
is_rank_zero = int(os.environ.get("RANK", 0)) == 0

Expand Down Expand Up @@ -936,6 +940,10 @@ def export(
sampler_type=None,
)
InferWrapper = ScriptWrapper
# Flip to inference *before* wrapping so view-dependent state
# (e.g. HSTUMatchItemTower's lazy properties, wrapper EmbeddingGroups)
# is snapshot from the scalar view.
model.set_is_inference(True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth a short comment here naming the reason this must flip on all ranks pre-wrap: export_rtp_model runs init_process_group() / dist.barrier() / DistributedModelParallel collectives, which require every rank to share the same model topology. If set_is_inference ever alters sharded-module shape (as it now does, via HSTUMatchItemTower.features's scalar projection), a rank-0-only flip would diverge ranks. A one-liner would prevent a future refactor from moving it back inside the rank-0 branch.

Also, asymmetry to watch for: the outer ScriptWrapper is constructed after the flip, so its own _is_inference is False while the inner self.model._is_inference is True. Functionally fine (nothing reads the outer flag), but surprising — same comment can call this out.

model = InferWrapper(model)

if not checkpoint_path:
Expand All @@ -959,13 +967,18 @@ def export(
)
tower = InferWrapper(wrapper(module, name))
tower_export_dir = os.path.join(export_dir, name.replace("_tower", ""))
# item-tower-only; user tower falls back to `train_input_path`.
tower_data_input_path = (
item_input_path if name == "item_tower" else None
)
Comment on lines +970 to +973
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user-tower fallback to train_input_path is asserted only implicitly by the integration test (which happens to predict the user tower over the eval parquet, not the train parquet). So if a refactor accidentally routed item_input_path to the user tower too, the integration test would still pass.

If you want to lock this in cheaply, a small unit test in main_test.py that monkey-patches export_model and asserts data_input_path per tower (item_tower → item_input_path, user_tower → None) would pin the policy.

export_model(
ori_pipeline_config,
tower,
checkpoint_path,
tower_export_dir,
assets=assets,
additional_export_config=additional_export_config,
data_input_path=tower_data_input_path,
)
elif isinstance(model.model, TDM):
for name, module in model.model.named_children():
Expand Down
69 changes: 66 additions & 3 deletions tzrec/models/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,17 @@
import torch.nn.functional as F
from torch import nn

from tzrec.datasets.utils import CAND_POS_LENGTHS, HARD_NEG_INDICES, Batch
from tzrec.features.feature import BaseFeature
from tzrec.datasets.utils import (
BASE_DATA_GROUP,
CAND_POS_LENGTHS,
HARD_NEG_INDICES,
Batch,
)
from tzrec.features.feature import (
BaseFeature,
create_features,
project_grouped_sequence_feature_to_scalar,
)
from tzrec.models.match_model import MatchModel, MatchTowerWoEG
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.gr.hstu_transducer import HSTUMatchEncoder
Expand Down Expand Up @@ -151,8 +160,18 @@ def __init__(
# tower_config.input names on the user-tower proto). Use the item-side
# tower_config.input here, which equals feature_groups[0].group_name.
self._group_name = tower_config.input
# MLP sized off the training candidate group; the scalar view has
# identical per-feature embedding dim.
candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence")
candidate_total_dim = sum(candidate_dims)

# Lazy caches for the scalar export view (populated on first
# property access after `set_is_inference(True)`).
self._features_scalar: Optional[List[BaseFeature]] = None
self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None
Comment on lines +168 to +171
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth documenting the cache-validity invariant: set_is_inference(False) flips the flag back but does not clear _features_scalar / _feature_groups_scalar. Reuse on a subsequent flip is correct because _features is immutable post-__init__, but a future maintainer reading this will reasonably wonder whether the cache should be invalidated on revert. One line ("cache valid for tower lifetime; _features is immutable") would prevent a well-intentioned "fix".

# `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`,
# so init `_is_inference` here.
self._is_inference: bool = False
if tower_config.HasField("mlp"):
self.mlp: torch.nn.Module = MLP(
in_features=candidate_total_dim,
Expand All @@ -166,6 +185,48 @@ def __init__(
if self._output_dim > 0:
self.output = nn.Linear(mlp_out_dim, output_dim)

@property
def features(self) -> List[BaseFeature]:
    """Return the item features for the currently active view.

    While ``_is_inference`` is falsy this is ``_features`` (the grouped
    sequence sub-features). Otherwise it is the scalar projection, which is
    built on first access and then served from ``_features_scalar``.
    """
    if not self._is_inference:
        return self._features
    if self._features_scalar is None:
        # First access in the export view: populate both scalar caches.
        self._build_scalar_features()
    return self._features_scalar

@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
    """Return the item feature groups for the currently active view.

    While ``_is_inference`` is falsy this is ``_feature_groups``. Otherwise it
    is the scalar-view group list, which is built on first access and then
    served from ``_feature_groups_scalar``.
    """
    if not self._is_inference:
        return self._feature_groups
    if self._feature_groups_scalar is None:
        # First access in the export view: populate both scalar caches.
        self._build_scalar_features()
    return self._feature_groups_scalar

def _build_scalar_features(self) -> None:
    """Populate ``_features_scalar`` and ``_feature_groups_scalar``.

    Every entry of ``_features`` is projected to a scalar FeatureConfig and
    the resulting configs are rebuilt through ``create_features``. ``fg_mode``
    and the encoded multi-value separator are taken from the first entry of
    ``_features``; the base data group is forced when any entry already
    belongs to ``BASE_DATA_GROUP``. The rebuilt features are then listed in a
    single ``JAGGED_SEQUENCE`` feature group named after the tower's group.
    """
    projected_cfgs = list(
        map(project_grouped_sequence_feature_to_scalar, self._features)
    )
    # Build settings are copied from the first training feature.
    template = self._features[0]
    rebuilt = create_features(
        projected_cfgs,
        fg_mode=template.fg_mode,
        neg_fields=None,
        fg_encoded_multival_sep=template._fg_encoded_multival_sep,
        force_base_data_group=any(
            feat.data_group == BASE_DATA_GROUP for feat in self._features
        ),
    )
    self._features_scalar = rebuilt

    scalar_group = model_pb2.FeatureGroupConfig(
        group_name=self._group_name,
        feature_names=[feat.name for feat in rebuilt],
        group_type=model_pb2.JAGGED_SEQUENCE,
    )
    self._feature_groups_scalar = [scalar_group]

def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the item tower.

Expand All @@ -175,7 +236,9 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
Returns:
item embeddings of shape (sum_candidates, D).
"""
cand_emb = grouped_features[f"{self._group_name}.sequence"]
# `.sequence` (jagged) at training, `.query` (scalar) at export.
suffix = ".query" if self._is_inference else ".sequence"
cand_emb = grouped_features[self._group_name + suffix]
item_emb = self.mlp(cand_emb)
if self._output_dim > 0:
item_emb = self.output(item_emb)
Expand Down
Loading
Loading