[feat] HSTUMatch: request-time-anchored time bias (fix ts_gap semantics)#526
Conversation
add_timestamp_positional_embeddings derives the time-bias gap as `ts_gap = query_time - event_time`, where `query_time` was always the last in-sequence timestamp. That is correct for DLRM-HSTU (the candidate request time is concatenated last) but wrong for any UIH-only sequence, where the most-recent event always lands at gap 0. Add an optional per-row `query_time` tensor to the pytorch + triton kernels and the dispatcher. When passed, it anchors the gap for every position (including the last); when None, the last-timestamp gather is unchanged, so DLRM-HSTU and existing configs are byte-identical. The triton kernel gates the new `QueryTime` pointer behind a `HAS_QUERY_TIME` constexpr, mirroring the existing `NumTargets` / `HAS_MULTIPLE_TARGETS` pattern; backward is unaffected (it replays the stored time-bucket indices). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… encoder HSTUPositionalEncoder forwards an optional `query_time` to the time-bias op. `_HSTUPipelineBase` reads it from `grouped_features[query_time_key]` (a `[B, 1]` raw request-time scalar) and passes it to the encoder — sourced from grouped_features rather than the shared preprocessor tuple so the ranking (ContextualInterleavePreprocessor) path is untouched. `query_time_key` defaults to "" (last-timestamp anchor, unchanged); HSTUMatchEncoder plumbs it to the base so the two-tower user side can opt into request-time anchoring. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
HSTUUserTower detects an optional `query_time` DEEP group (a single per-row request-time raw feature) among its feature groups and threads the key to HSTUMatchEncoder. With it, the HSTU time bias anchors on the request time instead of the last UIH event, so the user embedding reflects how stale the history is relative to the (serving) request — the signal time encoding is meant to provide, and which the exported two-tower user encoder otherwise had no way to receive. Absent the group, behavior is unchanged. No proto change: the group is routed by name, like contextual / uih_action / uih_timestamp. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Turn on request-time anchoring in the integration path. The kuairand-1k-match train/eval parquets now carry a per-row `request_time` (the retrieval split boundary: first post-history event, >= max uih timestamp), surfaced via a `request_time` raw feature + `query_time` DEEP group in the config and the doc example; ci_data.sh points at the regenerated fixtures. Fold the standalone query_time unit test into test_hstu_match -- the model is always built with the query_time group and asserts the threaded key, across NORMAL/FX_TRACE/JIT_SCRIPT x PYTORCH/TRITON x cpu/cuda. Trim the inline/docstring comments to one line. Verified on A10: match_integration_test.test_hstu_with_fg_train_eval green (train + eval + AOT export of the request_time-bearing user tower + item/user predict); hstu_test + position_test green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| query_time_key = next( | ||
| ( | ||
| feature_group.group_name | ||
| for feature_group in feature_groups | ||
| if feature_group.group_name == "query_time" | ||
| ), | ||
| "", | ||
| ) |
There was a problem hiding this comment.
The query_time group is detected purely by name, with no check that it resolves to a single scalar feature (dim 1). If it's misconfigured with multiple features (or a sparse/embedded feature), the grouped tensor is [B, K] and the op does query_time.view(-1) → length B*K: the triton kernel then reads QueryTime + off_b past the buffer (silent OOB, see other comment) and the pytorch path either errors deep inside the op or, at K/length 1, silently broadcasts one row to all users.
Consider a construction-time guard mirroring the adjacent contextual check (lines 90-97), e.g. assert embedding_group.group_total_dim("query_time") == 1 with a message naming the group.
Separately: when this group is configured but positional_encoder.use_time_encoding=false (or no positional_encoder), the anchor is fetched in _preprocess and silently dropped (positional_encoder.py:115). Since driving the time bias is the group's only purpose, a warning or a hard check that query_time implies use_time_encoding would save a hard-to-debug silent no-op.
| if query_time is not None: | ||
| query_time = switch_to_contiguous_if_needed( | ||
| query_time.view(-1).to(timestamps.dtype) | ||
| ) |
There was a problem hiding this comment.
The kernel loads the anchor with tl.load(QueryTime + off_b) (line ~354) where off_b ∈ [0, B), but this wrapper never asserts query_time has B rows — a shorter buffer causes an unmasked out-of-bounds GPU read (silent garbage or a CUDA fault), with no clean diagnostic. B is already computed at line 487. Suggest mirroring the existing torch._assert(... == B) convention in jagged_tensors.py:151:
| if query_time is not None: | |
| query_time = switch_to_contiguous_if_needed( | |
| query_time.view(-1).to(timestamps.dtype) | |
| ) | |
| timestamps = switch_to_contiguous_if_needed(timestamps) | |
| if query_time is not None: | |
| query_time = switch_to_contiguous_if_needed( | |
| query_time.view(-1).to(timestamps.dtype) | |
| ) | |
| torch._assert(query_time.shape[0] == B, "query_time must have B rows") |
(The pytorch path at pt_position.py:131 has the same latent issue — view(-1, 1) silently broadcasts when query_time has 1 element rather than B; an analogous assert there would keep the two kernels from diverging on bad input.)
| ) | ||
|
|
||
| @unittest.skipIf(*gpu_unavailable) | ||
| def test_add_timestamp_positional_embeddings_explicit_query_time(self) -> None: |
There was a problem hiding this comment.
Two coverage gaps worth closing on this new test:
- It's
@unittest.skipIf(*gpu_unavailable), so on CPU-only CI the only thing exercising the explicit-query_timepath ishstu_test.test_hstu_match, which asserts key-wiring + output shape but never thatquery_timenumerically affects the output. The pytorch branch (pt_position.py:128-131) runs fine on CPU — consider running theNone-vs-last_tsequality and the "future shifts output" check onKernel.PYTORCHwithout the GPU gate, so CPU CI validates correctness, not just wiring. - The test is forward-only, so the modified autograd
backward(the new trailingNoneforquery_time) is never exercised. A.backward()assertion on the explicit path would lock that signature in.
Minor: only a future query_time is tested — a request time before the events (negative gap → clamp(min=1e-6)) is a realistic clock-skew case worth one assertion that both kernels clamp identically.
Code review summaryClean, well-scoped fix. The diagnosis in the description is convincing, the optional-
The noteworthy items are in inline comments, in priority order:
Nothing blocking; (1) is the one I'd most want addressed before merge. |
Problem
HSTU's time encoding turns each event into a learned bias from
ts_gap = query_time - event_time, wherequery_timeis gathered as the last in-sequence timestamp (pt_position.py/triton_position.py).candidate_timestamp(thecand_seq___query_timerequest time) as the final position, so the anchor is the request time — correct.UIHPreprocessorbuilds the sequence fromuih_timestampalone and the user embedding is the last UIH position. Soquery_time = last UIH event time, which means:positional_encoderconfig means request-anchored gaps in ranking but last-event-anchored gaps in match.#518 made the item tower exportable for ANN retrieval, which turned the user tower into the live query encoder — surfacing that it cannot be told the request time.
Fix
Decouple
query_timefrom the last event: let HSTUMatch consume an optional per-row request-time scalar and anchorts_gapon it.add_timestamp_positional_embeddings(pytorch + triton + dispatcher) takes an optionalquery_time. Passed → anchors every position on it;None→ the last-timestamp gather is unchanged. DLRM-HSTU and existing configs are byte-identical. The triton kernel gates the new pointer behind aHAS_QUERY_TIMEconstexpr, mirroringNumTargets/HAS_MULTIPLE_TARGETS; backward replays stored bucket indices and is untouched._HSTUPipelineBasereads the anchor fromgrouped_features[query_time_key](not the shared preprocessor tuple, so the ranking path is undisturbed) and threads it to the positional encoder.HSTUUserTowerdetects an optionalquery_timeDEEP group (one scalar request-time raw feature) by name and plumbs the key through. No proto change — routed by name likecontextual/uih_action/uih_timestamp.At training the group carries the sample's request time; at serving, now.
Fixtures
The
kuairand-1k-matchtrain/eval parquets are regenerated to add a per-rowrequest_time= the retrieval split boundary (the first post-history event's timestamp,>= max(uih_seq__action_timestamp)). Regeneration is deterministic and otherwise byte-identical to the prior fixtures (same users, same uih/cand; item-gl md5 unchanged). It is surfaced via arequest_timeraw feature + aquery_timeDEEP group inhstu_kuairand_1k.configand the doc example;ci_data.shpoints at the new train/eval objects (item-gl / item-c1 unchanged).Test plan
position_test.py::…explicit_query_time— passing the last timestamp asquery_timereproduces theNone(gather-last) path exactly on both kernels; pytorch and triton explicit paths agree; a later request time shifts the gaps and changes the output.hstu_test.py::test_hstu_match— now always builds with thequery_timegroup and asserts the threaded_query_time_key, across NORMAL/FX_TRACE/JIT_SCRIPT × PYTORCH/TRITON × cpu/cuda (the standalone query_time test was folded in).match_integration_test.py::test_hstu_with_fg_train_eval— train + eval + AOT export of the request_time-bearing user tower + item-tower scalar export/predict + user-tower predict over therequest_timeeval parquet. Green on local A10.…_tritontimestamp op (None path),hstu_transducer_test(ranking),pre-commit— all green.🤖 Generated with Claude Code