Skip to content

Commit 77408f7

Browse files
[feat] HSTUMatch: scalar item-tower export view (#518)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a21d26f commit 77408f7

11 files changed

Lines changed: 322 additions & 45 deletions

File tree

docs/source/models/hstu_match.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,16 @@ model_config {
256256

257257
## 模型导出
258258

259-
HSTU Match 模型导出时需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。例如:
259+
HSTU Match 模型导出时,若使用 Triton kernel,需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。
260+
261+
同时需要通过命令行参数 `--item_input_path` 指定 item 侧的输入数据路径(一行一个 item 的 parquet,schema 与候选序列子特征对齐,例如包含 `video_id` 列)。item tower 导出时会从该路径读取一个样本 batch 用于 trace;user tower 不受影响,仍使用 `train_input_path`。例如:
260262

261263
```
262264
ENABLE_AOT=1 torchrun --master_addr=localhost --master_port=32555 \
263265
--nnodes=1 --nproc-per-node=1 --node_rank=0 \
264266
-m tzrec.export \
265267
--pipeline_config_path experiments/hstu_match/pipeline.config \
268+
--item_input_path experiments/hstu_match/item_data/*.parquet \
266269
--export_dir experiments/hstu_match/export
267270
```
268271

scripts/ci/ci_data.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-mot-1k-eval-c4
1313
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-train-c4096-s100-f1892eabc70ae3407afe9ff5bca8cb5f.parquet -O data/test/kuairand-1k-match-train-c4096-s100.parquet
1414
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-eval-c4096-s100-e4ca5e15d157efa723041cd05c127228.parquet -O data/test/kuairand-1k-match-eval-c4096-s100.parquet
1515
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-gl-3d459148303acd9f838da108efcc40e5.txt -O data/test/kuairand-1k-match-item-gl.txt
16+
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-c1-8dcadabdc3e9049ed9c2250565b4b134.parquet -O data/test/kuairand-1k-match-item-c1.parquet

tzrec/export.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@
4848
help="JSON string of extra key/value pairs merged into model_acc.json, "
4949
'e.g. \'{"cand_seq_pk": "cand_seq"}\' for DlrmHSTU.',
5050
)
51+
parser.add_argument(
52+
"--item_input_path",
53+
type=str,
54+
default=None,
55+
help="Optional input path for the item-tower's predict-mode "
56+
"dataloader. When set, the item tower reads from this path "
57+
"(a one-row-per-item table matching the scalar export view) "
58+
"instead of `train_input_path`.",
59+
)
5160
args, extra_args = parser.parse_known_args()
5261

5362
additional_export_config = (
@@ -62,4 +71,5 @@
6271
checkpoint_path=args.checkpoint_path,
6372
asset_files=args.asset_files,
6473
additional_export_config=additional_export_config,
74+
item_input_path=args.item_input_path,
6575
)

tzrec/features/feature.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,43 @@ def create_features(
12301230
return features
12311231

12321232

1233+
def project_grouped_sequence_feature_to_scalar(
1234+
feature: BaseFeature,
1235+
) -> feature_pb2.FeatureConfig:
1236+
"""Return a scalar export FeatureConfig for a grouped sequence sub-feature.
1237+
1238+
Rewraps the inner ``SeqFeatureConfig`` as a top-level ``FeatureConfig``
1239+
and materializes the source's effective ``default_value`` / ``value_dim``
1240+
so the exported scalar feature matches the training sub-feature
1241+
(otherwise scalar mode defaults differ from sequence mode).
1242+
1243+
Args:
1244+
feature: a grouped sequence sub-feature.
1245+
1246+
Returns:
1247+
a fresh FeatureConfig for ``create_features()`` to build as scalar.
1248+
"""
1249+
if not feature.is_grouped_sequence:
1250+
raise ValueError(
1251+
"project_grouped_sequence_feature_to_scalar only accepts grouped "
1252+
f"sequence sub-features; got {feature.name} "
1253+
"(is_grouped_sequence=False)"
1254+
)
1255+
src_cfg = feature.feature_config # SeqFeatureConfig
1256+
feat_type = src_cfg.WhichOneof("feature")
1257+
src_msg = getattr(src_cfg, feat_type)
1258+
1259+
scalar_cfg = feature_pb2.FeatureConfig()
1260+
dst_msg = getattr(scalar_cfg, feat_type)
1261+
dst_msg.CopyFrom(src_msg)
1262+
1263+
if hasattr(dst_msg, "default_value") and not dst_msg.default_value:
1264+
dst_msg.default_value = feature.default_value
1265+
if hasattr(dst_msg, "value_dim") and not dst_msg.HasField("value_dim"):
1266+
dst_msg.value_dim = feature.value_dim
1267+
return scalar_cfg
1268+
1269+
12331270
def _copy_assets(
12341271
feature: BaseFeature,
12351272
asset_dir: Optional[str] = None,

tzrec/features/feature_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,5 +750,78 @@ def test_sequence_input_names(self, fg_mode):
750750
)
751751

752752

753+
class ProjectGroupedSequenceFeatureToScalarTest(unittest.TestCase):
754+
def _build_grouped(self, seq_sub_cfg):
755+
feature_cfgs = [
756+
feature_pb2.FeatureConfig(
757+
sequence_feature=feature_pb2.SequenceFeature(
758+
sequence_name="cand_seq",
759+
sequence_delim="|",
760+
sequence_length=100,
761+
features=[seq_sub_cfg],
762+
)
763+
),
764+
]
765+
return feature_lib.create_features(feature_cfgs)
766+
767+
def test_projection_materializes_defaults_and_passes_through_create_features(self):
768+
# id_feature: default_value / value_dim materialization + create_features.
769+
id_sub_cfg = feature_pb2.SeqFeatureConfig(
770+
id_feature=feature_pb2.IdFeature(
771+
feature_name="video_id",
772+
expression="item:video_id",
773+
embedding_dim=32,
774+
num_buckets=10000000,
775+
)
776+
)
777+
features = self._build_grouped(id_sub_cfg)
778+
self.assertEqual(len(features), 1)
779+
sub_feature = features[0]
780+
self.assertTrue(sub_feature.is_grouped_sequence)
781+
# Sequence-effective defaults on the source.
782+
self.assertEqual(sub_feature.default_value, "0")
783+
self.assertEqual(sub_feature.value_dim, 1)
784+
785+
scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(sub_feature)
786+
self.assertEqual(scalar_cfg.WhichOneof("feature"), "id_feature")
787+
# Materialized onto the scalar proto.
788+
self.assertEqual(scalar_cfg.id_feature.default_value, "0")
789+
self.assertTrue(scalar_cfg.id_feature.HasField("value_dim"))
790+
self.assertEqual(scalar_cfg.id_feature.value_dim, 1)
791+
# Source proto not mutated.
792+
self.assertEqual(sub_feature.feature_config.id_feature.default_value, "")
793+
self.assertFalse(sub_feature.feature_config.id_feature.HasField("value_dim"))
794+
795+
# create_features rebuilds it as a top-level scalar feature.
796+
scalar_features = feature_lib.create_features([scalar_cfg])
797+
self.assertEqual(len(scalar_features), 1)
798+
scalar = scalar_features[0]
799+
self.assertEqual(scalar.name, "video_id")
800+
self.assertFalse(scalar.is_grouped_sequence)
801+
self.assertEqual(scalar.value_dim, 1)
802+
803+
# raw_feature: confirms the helper isn't hard-coded to id_feature.
804+
raw_sub_cfg = feature_pb2.SeqFeatureConfig(
805+
raw_feature=feature_pb2.RawFeature(
806+
feature_name="watch_time", expression="user:watch_time"
807+
)
808+
)
809+
raw_features = self._build_grouped(raw_sub_cfg)
810+
raw_scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(
811+
raw_features[0]
812+
)
813+
self.assertEqual(raw_scalar_cfg.WhichOneof("feature"), "raw_feature")
814+
815+
def test_projection_rejects_non_grouped_feature(self):
816+
feature_cfgs = [
817+
feature_pb2.FeatureConfig(
818+
id_feature=feature_pb2.IdFeature(feature_name="user_id")
819+
),
820+
]
821+
features = feature_lib.create_features(feature_cfgs)
822+
with self.assertRaisesRegex(ValueError, "is_grouped_sequence=False"):
823+
feature_lib.project_grouped_sequence_feature_to_scalar(features[0])
824+
825+
753826
if __name__ == "__main__":
754827
unittest.main()

tzrec/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,7 @@ def export(
895895
checkpoint_path: Optional[str] = None,
896896
asset_files: Optional[str] = None,
897897
additional_export_config: Optional[Dict[str, str]] = None,
898+
item_input_path: Optional[str] = None,
898899
) -> None:
899900
"""Export a EasyRec model.
900901
@@ -906,6 +907,9 @@ def export(
906907
asset_files (str, optional): more files will be copied to export_dir.
907908
additional_export_config (dict, optional): extra key/value pairs merged
908909
into model_acc.json (e.g. ``{"cand_seq_pk": "cand_seq"}`` for DlrmHSTU).
910+
item_input_path (str, optional): override for the item tower's
911+
predict-mode dataloader input path. When set, the item tower
912+
reads from this path instead of ``train_input_path``.
909913
"""
910914
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
911915

@@ -936,6 +940,10 @@ def export(
936940
sampler_type=None,
937941
)
938942
InferWrapper = ScriptWrapper
943+
# Flip to inference *before* wrapping so view-dependent state
944+
# (e.g. HSTUMatchItemTower's lazy properties, wrapper EmbeddingGroups)
945+
# is snapshot from the scalar view.
946+
model.set_is_inference(True)
939947
model = InferWrapper(model)
940948

941949
if not checkpoint_path:
@@ -959,13 +967,18 @@ def export(
959967
)
960968
tower = InferWrapper(wrapper(module, name))
961969
tower_export_dir = os.path.join(export_dir, name.replace("_tower", ""))
970+
# item-tower-only; user tower falls back to `train_input_path`.
971+
tower_data_input_path = (
972+
item_input_path if name == "item_tower" else None
973+
)
962974
export_model(
963975
ori_pipeline_config,
964976
tower,
965977
checkpoint_path,
966978
tower_export_dir,
967979
assets=assets,
968980
additional_export_config=additional_export_config,
981+
data_input_path=tower_data_input_path,
969982
)
970983
elif isinstance(model.model, TDM):
971984
for name, module in model.model.named_children():

tzrec/models/hstu.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,17 @@
1515
import torch.nn.functional as F
1616
from torch import nn
1717

18-
from tzrec.datasets.utils import CAND_POS_LENGTHS, HARD_NEG_INDICES, Batch
19-
from tzrec.features.feature import BaseFeature
18+
from tzrec.datasets.utils import (
19+
BASE_DATA_GROUP,
20+
CAND_POS_LENGTHS,
21+
HARD_NEG_INDICES,
22+
Batch,
23+
)
24+
from tzrec.features.feature import (
25+
BaseFeature,
26+
create_features,
27+
project_grouped_sequence_feature_to_scalar,
28+
)
2029
from tzrec.models.match_model import MatchModel, MatchTowerWoEG
2130
from tzrec.modules.embedding import EmbeddingGroup
2231
from tzrec.modules.gr.hstu_transducer import HSTUMatchEncoder
@@ -151,8 +160,18 @@ def __init__(
151160
# tower_config.input names on the user-tower proto). Use the item-side
152161
# tower_config.input here, which equals feature_groups[0].group_name.
153162
self._group_name = tower_config.input
163+
# MLP sized off the training candidate group; the scalar view has
164+
# identical per-feature embedding dim.
154165
candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence")
155166
candidate_total_dim = sum(candidate_dims)
167+
168+
# Lazy caches for the scalar export view (populated on first
169+
# property access after `set_is_inference(True)`).
170+
self._features_scalar: Optional[List[BaseFeature]] = None
171+
self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None
172+
# `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`,
173+
# so init `_is_inference` here.
174+
self._is_inference: bool = False
156175
if tower_config.HasField("mlp"):
157176
self.mlp: torch.nn.Module = MLP(
158177
in_features=candidate_total_dim,
@@ -166,6 +185,48 @@ def __init__(
166185
if self._output_dim > 0:
167186
self.output = nn.Linear(mlp_out_dim, output_dim)
168187

188+
@property
189+
def features(self) -> List[BaseFeature]:
190+
"""Item features (training: grouped sub-features; export: scalar projection)."""
191+
if self._is_inference:
192+
if self._features_scalar is None:
193+
self._build_scalar_features()
194+
return self._features_scalar
195+
return self._features
196+
197+
@property
198+
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
199+
"""Item feature_groups in the current view (see ``features``)."""
200+
if self._is_inference:
201+
if self._feature_groups_scalar is None:
202+
self._build_scalar_features()
203+
return self._feature_groups_scalar
204+
return self._feature_groups
205+
206+
def _build_scalar_features(self) -> None:
207+
"""Project each grouped sequence sub-feature into a scalar export feature."""
208+
scalar_configs = [
209+
project_grouped_sequence_feature_to_scalar(f) for f in self._features
210+
]
211+
source = self._features[0]
212+
scalar_features = create_features(
213+
scalar_configs,
214+
fg_mode=source.fg_mode,
215+
neg_fields=None,
216+
fg_encoded_multival_sep=source._fg_encoded_multival_sep,
217+
force_base_data_group=any(
218+
f.data_group == BASE_DATA_GROUP for f in self._features
219+
),
220+
)
221+
self._features_scalar = scalar_features
222+
self._feature_groups_scalar = [
223+
model_pb2.FeatureGroupConfig(
224+
group_name=self._group_name,
225+
feature_names=[f.name for f in scalar_features],
226+
group_type=model_pb2.JAGGED_SEQUENCE,
227+
)
228+
]
229+
169230
def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
170231
"""Forward the item tower.
171232
@@ -175,7 +236,9 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
175236
Returns:
176237
item embeddings of shape (sum_candidates, D).
177238
"""
178-
cand_emb = grouped_features[f"{self._group_name}.sequence"]
239+
# `.sequence` (jagged) at training, `.query` (scalar) at export.
240+
suffix = ".query" if self._is_inference else ".sequence"
241+
cand_emb = grouped_features[self._group_name + suffix]
179242
item_emb = self.mlp(cand_emb)
180243
if self._output_dim > 0:
181244
item_emb = self.output(item_emb)

0 commit comments

Comments
 (0)