Skip to content

Commit 58fd5cc

Browse files
[feat] HSTUMatch: request-time-anchored time bias (fix ts_gap semantics) (#526)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6c8e2bd commit 58fd5cc

11 files changed

Lines changed: 204 additions & 9 deletions

File tree

docs/source/models/hstu_match.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ feature_configs {
104104
}
105105
}
106106
}
107+
feature_configs {
108+
raw_feature {
109+
feature_name: "request_time"
110+
expression: "user:request_time"
111+
}
112+
}
107113
model_config {
108114
feature_groups {
109115
group_name: "contextual"
@@ -140,6 +146,11 @@ model_config {
140146
feature_names: "uih_seq__action_timestamp"
141147
group_type: JAGGED_SEQUENCE
142148
}
149+
feature_groups {
150+
group_name: "query_time"
151+
feature_names: "request_time"
152+
group_type: DEEP
153+
}
143154
hstu_match {
144155
user_tower {
145156
input: "uih"
@@ -221,6 +232,7 @@ model_config {
221232
- uih_action: 用户历史交互的行为事件序列,注: 该行为事件按位存储,如 expr, click, add, buy 三个行为,则一般 expr=0, click=1, add=2, buy=4;类型为 JAGGED_SEQUENCE,当 `uih_preprocessor.action_encoder` 配置时必填
222233
- uih_watchtime: 用户历史交互的行为时长序列;类型为 JAGGED_SEQUENCE,当 action encoder 需要 watchtime 时必填
223234
- uih_timestamp: 用户历史交互的行为时间戳序列;类型为 JAGGED_SEQUENCE,当 `positional_encoder.use_time_encoding=true` 时必填
235+
- query_time: 每行一个标量的请求时间 raw 特征 (需与 uih_timestamp 同单位);类型为 DEEP,可选。配置后时间编码以请求时间为基准 (`ts_gap = query_time - 行为时间戳`),否则回退到最后一个 UIH 行为时间
224236

225237
**group_name 不能变**,user_tower/item_tower 通过 group_name 索引对应的 feature_group
226238

scripts/ci/ci_data.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-rtp-eval-c4
1010
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-mot-1k-train-c4096-s100-e28061f3c88f543b9e18f40be6ddb94d.parquet -O data/test/kuairand-mot-1k-train-c4096-s100.parquet
1111
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-mot-1k-eval-c4096-s100-f185f38e3b4a49cb791d2e4302087a1f.parquet -O data/test/kuairand-mot-1k-eval-c4096-s100.parquet
1212
# kuairand-1k-match (HSTUMatch integration test fixtures)
13-
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-train-c4096-s100-f1892eabc70ae3407afe9ff5bca8cb5f.parquet -O data/test/kuairand-1k-match-train-c4096-s100.parquet
14-
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-eval-c4096-s100-e4ca5e15d157efa723041cd05c127228.parquet -O data/test/kuairand-1k-match-eval-c4096-s100.parquet
13+
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-train-c4096-s100-aa77964ed7f50ca30645f8dd08dbf10d.parquet -O data/test/kuairand-1k-match-train-c4096-s100.parquet
14+
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-eval-c4096-s100-8678a3ac699fb08f0602f4c06cef2edf.parquet -O data/test/kuairand-1k-match-eval-c4096-s100.parquet
1515
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-gl-3d459148303acd9f838da108efcc40e5.txt -O data/test/kuairand-1k-match-item-gl.txt
1616
wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-c1-8dcadabdc3e9049ed9c2250565b4b134.parquet -O data/test/kuairand-1k-match-item-c1.parquet

tzrec/models/hstu.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,17 @@ def __init__(
9696
contextual_feature_dim = contextual_dims[0]
9797
max_contextual_seq_len = len(contextual_dims)
9898

99+
# Optional `query_time` DEEP group: per-row request-time anchor for the
100+
# HSTU time bias (absent -> anchor on the last UIH timestamp).
101+
query_time_key = next(
102+
(
103+
feature_group.group_name
104+
for feature_group in feature_groups
105+
if feature_group.group_name == "query_time"
106+
),
107+
"",
108+
)
109+
99110
self._hstu_encoder: HSTUMatchEncoder = HSTUMatchEncoder(
100111
uih_embedding_dim=embedding_group.group_total_dim(
101112
f"{tower_config.input}.sequence"
@@ -105,6 +116,7 @@ def __init__(
105116
contextual_group_name=contextual_group_name,
106117
scaling_seqlen=tower_config.max_seq_len,
107118
is_inference=False,
119+
query_time_key=query_time_key,
108120
**config_to_kwargs(tower_config.hstu),
109121
)
110122
if self._output_dim > 0:
@@ -266,6 +278,9 @@ class HSTUMatch(MatchModel):
266278
UIHPreprocessor's action_encoder and the HSTU positional
267279
encoder's time bias. Required when `uih_preprocessor.action_encoder`
268280
is configured.
281+
- "query_time" (optional, DEEP): a single per-row scalar request-time
282+
raw feature used as the HSTU time-bias anchor; absent, the anchor
283+
falls back to the last UIH timestamp.
269284
270285
User tower returns the last-position UIH embedding per user; it is compared
271286
against candidate embeddings via the configured similarity at both train and

tzrec/models/hstu_test.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from hypothesis import Verbosity, assume, given, settings
1616
from hypothesis import strategies as st
17-
from torchrec import JaggedTensor, KeyedJaggedTensor
17+
from torchrec import JaggedTensor, KeyedJaggedTensor, KeyedTensor
1818

1919
from tzrec.datasets.utils import BASE_DATA_GROUP, CAND_POS_LENGTHS, Batch
2020
from tzrec.features.feature import create_features
@@ -43,6 +43,10 @@ def _build_model(device: torch.device) -> HSTUMatch:
4343
dim / `embedding_name` so the two flattened features share one
4444
embedding table. `uih_seq` also carries the `historical_ts` raw
4545
sub-feature for the timestamp dense path.
46+
47+
Time encoding is on, with a scalar ``request_time`` raw feature exposed
48+
through a ``query_time`` DEEP group — the per-row time-bias anchor
49+
(mirrors the production config).
4650
"""
4751
feature_cfgs = [
4852
feature_pb2.FeatureConfig(
@@ -84,6 +88,11 @@ def _build_model(device: torch.device) -> HSTUMatch:
8488
)
8589
),
8690
]
91+
feature_cfgs.append(
92+
feature_pb2.FeatureConfig(
93+
raw_feature=feature_pb2.RawFeature(feature_name="request_time")
94+
)
95+
)
8796
features = create_features(feature_cfgs)
8897
feature_groups = [
8998
model_pb2.FeatureGroupConfig(
@@ -102,6 +111,13 @@ def _build_model(device: torch.device) -> HSTUMatch:
102111
group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE,
103112
),
104113
]
114+
feature_groups.append(
115+
model_pb2.FeatureGroupConfig(
116+
group_name="query_time",
117+
feature_names=["request_time"],
118+
group_type=model_pb2.FeatureGroupType.DEEP,
119+
)
120+
)
105121
model_config = model_pb2.ModelConfig(
106122
feature_groups=feature_groups,
107123
hstu_match=match_model_pb2.HSTUMatch(
@@ -120,6 +136,8 @@ def _build_model(device: torch.device) -> HSTUMatch:
120136
attn_num_layers=2,
121137
positional_encoder=module_pb2.GRPositionalEncoder(
122138
num_position_buckets=512,
139+
num_time_buckets=512,
140+
use_time_encoding=True,
123141
),
124142
input_preprocessor=module_pb2.GRInputPreprocessor(
125143
uih_preprocessor=module_pb2.GRUIHPreprocessor(),
@@ -160,6 +178,9 @@ def _build_batch(device: torch.device) -> Batch:
160178
Candidates: row 0 = [pos_0]; row 1 (last) = [pos_1, simple_neg_0,
161179
simple_neg_1] -- the shared simple-neg pool sits in the last row's suffix.
162180
pos_lengths = [1, 1].
181+
182+
A per-row ``request_time`` dense scalar (strictly after each user's last
183+
UIH event at ts 3 / 7) is included as the time-bias anchor.
163184
"""
164185
sparse_feature = KeyedJaggedTensor.from_lengths_sync(
165186
keys=["uih_seq__video_id", "cand_seq__video_id"],
@@ -172,7 +193,14 @@ def _build_batch(device: torch.device) -> Batch:
172193
lengths=torch.tensor([3, 4]),
173194
),
174195
}
196+
dense_features = {
197+
BASE_DATA_GROUP: KeyedTensor.from_tensor_list(
198+
keys=["request_time"],
199+
tensors=[torch.tensor([[100.0], [100.0]])],
200+
)
201+
}
175202
return Batch(
203+
dense_features=dense_features,
176204
sparse_features={BASE_DATA_GROUP: sparse_feature},
177205
sequence_dense_features=sequence_dense_features,
178206
jagged_labels={
@@ -217,6 +245,9 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None:
217245

218246
device = torch.device(device_str)
219247
hstu = _build_model(device=device)
248+
# The query_time DEEP group is detected and threaded as the per-row
249+
# time-bias anchor (request-time anchoring, not the last UIH event).
250+
self.assertEqual(hstu.user_tower._hstu_encoder._query_time_key, "query_time")
220251
hstu.set_kernel(kernel)
221252
batch = _build_batch(device=device)
222253

tzrec/modules/gr/hstu_transducer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,13 @@ def __init__(
7171
attn_truncation_split_layer: int = 0,
7272
attn_truncation_tail_len: int = 0,
7373
name: str = "",
74+
query_time_key: str = "",
7475
) -> None:
7576
super().__init__(is_inference=is_inference)
77+
# Grouped-feature key of the per-row request time used as the time-bias
78+
# anchor. Empty -> anchor on the last in-sequence timestamp (canonical
79+
# HSTU / DLRM-HSTU, which concatenates the candidate request time).
80+
self._query_time_key: str = query_time_key
7681
self._input_preprocessor: InputPreprocessor = create_input_preprocessor(
7782
input_preprocessor,
7883
uih_embedding_dim=uih_embedding_dim,
@@ -129,6 +134,13 @@ def _preprocess(
129134
output_num_targets,
130135
) = self._input_preprocessor(grouped_features)
131136

137+
# Per-row request time anchor (HSTUMatch). Read from grouped_features
138+
# rather than the preprocessor tuple so the shared ranking path is
139+
# untouched. `[B, 1]` raw values -> the op normalizes shape/dtype itself (PyTorch path: `view(-1, 1)` + cast to the timestamp dtype).
140+
query_time: Optional[torch.Tensor] = None
141+
if self._query_time_key != "":
142+
query_time = grouped_features[self._query_time_key]
143+
132144
with record_function("hstu_positional_encoder"):
133145
if self._positional_encoder is not None:
134146
output_seq_embeddings = self._positional_encoder(
@@ -138,6 +150,7 @@ def _preprocess(
138150
seq_timestamps=output_seq_timestamps,
139151
seq_embeddings=output_seq_embeddings,
140152
num_targets=output_num_targets,
153+
query_time=query_time,
141154
)
142155

143156
output_seq_embeddings = torch.nn.functional.dropout(
@@ -468,6 +481,10 @@ class HSTUMatchEncoder(_HSTUPipelineBase):
468481
is_inference (bool): whether to run in inference mode.
469482
attn_truncation_split_layer (int): see `HSTUTransducer`.
470483
attn_truncation_tail_len (int): see `HSTUTransducer`.
484+
query_time_key (str): grouped-feature key of the per-row request time
485+
used as the time-bias anchor. Empty (default) anchors on the last
486+
UIH timestamp; pass a scalar request-time group to anchor on the
487+
actual request time (decoupled from UIH staleness).
471488
"""
472489

473490
def __init__(
@@ -487,6 +504,7 @@ def __init__(
487504
attn_truncation_split_layer: int = 0,
488505
attn_truncation_tail_len: int = 0,
489506
name: str = "",
507+
query_time_key: str = "",
490508
) -> None:
491509
super().__init__(
492510
uih_embedding_dim=uih_embedding_dim,
@@ -504,6 +522,7 @@ def __init__(
504522
attn_truncation_split_layer=attn_truncation_split_layer,
505523
attn_truncation_tail_len=attn_truncation_tail_len,
506524
name=name,
525+
query_time_key=query_time_key,
507526
)
508527
self._output_postprocessor: OutputPostprocessor = create_output_postprocessor(
509528
output_postprocessor, embedding_dim=stu["embedding_dim"]

tzrec/modules/gr/positional_encoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def forward(
7676
seq_timestamps: torch.Tensor,
7777
seq_embeddings: torch.Tensor,
7878
num_targets: Optional[torch.Tensor],
79+
query_time: Optional[torch.Tensor] = None,
7980
) -> torch.Tensor:
8081
"""Forward the module.
8182
@@ -86,6 +87,9 @@ def forward(
8687
seq_timestamps (torch.Tensor): input sequence timestamps.
8788
seq_embeddings (torch.Tensor): input sequence embeddings.
8889
num_targets (torch.Tensor, optional): number of targets.
90+
query_time (torch.Tensor, optional): per-row request time used as
91+
the time-bias anchor (``ts_gap = query_time - timestamp``).
92+
When ``None``, the last in-sequence timestamp is used.
8993
9094
Returns:
9195
torch.Tensor: output sequence embedding with position embedding.
@@ -106,6 +110,7 @@ def forward(
106110
time_bucket_fn=self._time_bucket_fn,
107111
time_bucket_increments=self._time_bucket_increments,
108112
kernel=self.kernel(),
113+
query_time=query_time,
109114
)
110115
else:
111116
seq_embeddings = add_positional_embeddings(

tzrec/ops/_pytorch/pt_position.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def pytorch_add_timestamp_positional_embeddings(
9393
interleave_targets: bool,
9494
time_bucket_fn: str,
9595
time_bucket_increments: float,
96+
query_time: Optional[torch.Tensor] = None,
9697
) -> torch.Tensor:
9798
max_pos_ind = pos_embeddings.size(0)
9899
# position encoding
@@ -115,11 +116,19 @@ def pytorch_add_timestamp_positional_embeddings(
115116
max_lengths=[max_seq_len],
116117
padding_value=0.0,
117118
).squeeze(-1)
118-
query_time = torch.gather(
119-
timestamps,
120-
dim=1,
121-
index=(seq_lengths - 1).unsqueeze(1).clamp(min=0).to(torch.int64),
122-
)
119+
if query_time is None:
120+
# No explicit anchor: use the last in-sequence timestamp. For
121+
# DLRM-HSTU the candidate is concatenated last, so this is the
122+
# request time; for any UIH-only sequence it is the most-recent event.
123+
query_time = torch.gather(
124+
timestamps,
125+
dim=1,
126+
index=(seq_lengths - 1).unsqueeze(1).clamp(min=0).to(torch.int64),
127+
)
128+
else:
129+
# Explicit per-row request time (HSTUMatch two-tower: no candidate is
130+
# concatenated, so the anchor cannot be derived from the sequence).
131+
query_time = query_time.view(-1, 1).to(timestamps.dtype)
123132
ts = query_time - timestamps
124133
ts = ts + time_delta
125134
ts = ts.clamp(min=1e-6) / time_bucket_increments

0 commit comments

Comments
 (0)