
[feat] HSTUMatch: request-time-anchored time bias (fix ts_gap semantics)#526

Merged
tiankongdeguiji merged 4 commits into master from fix/hstu-match-query-time
May 26, 2026

@tiankongdeguiji tiankongdeguiji commented May 25, 2026

Collaborator

Problem

HSTU's time encoding turns each event into a learned bias from ts_gap = query_time - event_time, where query_time is gathered as the last in-sequence timestamp (pt_position.py / triton_position.py).

  • DLRM-HSTU (ranking) concatenates candidate_timestamp (the cand_seq___query_time request time) as the final position, so the anchor is the request time — correct.
  • HSTUMatch (two-tower) is UIH-only: UIHPreprocessor builds the sequence from uih_timestamp alone and the user embedding is the last UIH position. So query_time = last UIH event time, which means:
    • the most-recent event always lands in time-bucket 0 — the recency signal time encoding exists to provide is destroyed;
    • at serving the ANN query fires at now (later than the last event), but the exported user tower has no input to receive the real request time, so history staleness relative to the request is invisible;
    • the same positional_encoder config means request-anchored gaps in ranking but last-event-anchored gaps in match.
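The effect of the anchor can be seen on a toy sequence. The sketch below is framework-free Python, not the tzrec kernels: `ts_gaps` and `log_bucket` are hypothetical helpers, and the natural-log bucketization is an assumed stand-in for whatever bucket function the positional encoder is configured with.

```python
import math


def ts_gaps(event_times, query_time=None):
    """Return query_time - event_time for each event.

    With query_time=None the anchor falls back to the last event,
    which is the behaviour described above for UIH-only sequences.
    """
    anchor = event_times[-1] if query_time is None else query_time
    return [anchor - t for t in event_times]


def log_bucket(gap, num_buckets=16):
    """Assumed bucketization: floor(ln(gap)), clamped to [0, num_buckets]."""
    clamped = max(gap, 1e-6)  # avoid log(0) when the gap is zero
    return min(max(int(math.floor(math.log(clamped))), 0), num_buckets)


events = [1000, 4000, 4540]  # UIH event timestamps, in seconds
request_time = 4600          # the request fires 60 s after the last event

last_anchored = ts_gaps(events)
request_anchored = ts_gaps(events, request_time)

print(last_anchored, [log_bucket(g) for g in last_anchored])
print(request_anchored, [log_bucket(g) for g in request_anchored])
```

Under these assumptions the last event sits in bucket 0 when the anchor is the last event, and moves to bucket 4 once the anchor is the request time, while the older events keep their buckets.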

#518 made the item tower exportable for ANN retrieval, which turned the user tower into the live query encoder — surfacing that it cannot be told the request time.

Fix

Decouple query_time from the last event: let HSTUMatch consume an optional per-row request-time scalar and anchor ts_gap on it.

  • add_timestamp_positional_embeddings (pytorch + triton + dispatcher) takes an optional query_time. Passed → anchors every position on it; None → the last-timestamp gather is unchanged. DLRM-HSTU and existing configs are byte-identical. The triton kernel gates the new pointer behind a HAS_QUERY_TIME constexpr, mirroring NumTargets/HAS_MULTIPLE_TARGETS; backward replays stored bucket indices and is untouched.
  • _HSTUPipelineBase reads the anchor from grouped_features[query_time_key] (not the shared preprocessor tuple, so the ranking path is undisturbed) and threads it to the positional encoder.
  • HSTUUserTower detects an optional query_time DEEP group (one scalar request-time raw feature) by name and plumbs the key through. No proto change — routed by name like contextual / uih_action / uih_timestamp.

At training the group carries the sample's request time; at serving, now.
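The optional-anchor contract can be sketched for a jagged batch as follows. This is a plain-Python illustration, not the pytorch or triton op: `anchored_gaps` and its list-based `offsets` layout are assumptions made for the example.

```python
def anchored_gaps(timestamps, offsets, query_time=None):
    """Compute per-event gaps for a jagged batch.

    timestamps: flat list of event times for all rows.
    offsets:    row boundaries, with len(offsets) == B + 1.
    query_time: optional list of B request times; None gathers the last
                timestamp of each row instead.
    """
    batch_size = len(offsets) - 1
    if query_time is not None and len(query_time) != batch_size:
        raise ValueError("query_time must have one entry per row")
    gaps = []
    for b in range(batch_size):
        row = timestamps[offsets[b]:offsets[b + 1]]
        if not row:
            continue  # empty history: nothing to encode for this row
        anchor = row[-1] if query_time is None else query_time[b]
        gaps.extend(anchor - t for t in row)
    return gaps


timestamps = [10, 20, 30, 100, 150]
offsets = [0, 3, 5]  # row 0 has 3 events, row 1 has 2

default = anchored_gaps(timestamps, offsets)
same = anchored_gaps(timestamps, offsets, query_time=[30, 150])
later = anchored_gaps(timestamps, offsets, query_time=[40, 200])
```

Passing each row's last timestamp as `query_time` reproduces the `None` path, and a later request time shifts every gap in the row, including the last position.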

Fixtures

The kuairand-1k-match train/eval parquets are regenerated to add a per-row request_time = the retrieval split boundary (the first post-history event's timestamp, >= max(uih_seq__action_timestamp)). Regeneration is deterministic and otherwise byte-identical to the prior fixtures (same users, same uih/cand; item-gl md5 unchanged). It is surfaced via a request_time raw feature + a query_time DEEP group in hstu_kuairand_1k.config and the doc example; ci_data.sh points at the new train/eval objects (item-gl / item-c1 unchanged).
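The regeneration script is not shown in this PR, so the following is only a sketch of how a split-boundary `request_time` could be derived. `split_history` and `history_len` are hypothetical names.

```python
def split_history(event_times, history_len):
    """Split a user's events into a history window and a request time.

    The request time is the timestamp of the first event after the
    history window, so it is never earlier than any history event.
    """
    if not 0 < history_len < len(event_times):
        raise ValueError("need at least one history and one post-history event")
    ordered = sorted(event_times)
    history = ordered[:history_len]
    request_time = ordered[history_len]
    return history, request_time


history, request_time = split_history([300, 100, 200, 450, 500], history_len=3)
print(history, request_time)
```

Because the events are sorted before the split, `request_time >= max(history)` holds by construction, which is the invariant stated for the fixtures.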

Test plan

  • position_test.py::…explicit_query_time — passing the last timestamp as query_time reproduces the None (gather-last) path exactly on both kernels; pytorch and triton explicit paths agree; a later request time shifts the gaps and changes the output.
  • hstu_test.py::test_hstu_match — now always builds with the query_time group and asserts the threaded _query_time_key, across NORMAL/FX_TRACE/JIT_SCRIPT × PYTORCH/TRITON × cpu/cuda (the standalone query_time test was folded in).
  • match_integration_test.py::test_hstu_with_fg_train_eval — train + eval + AOT export of the request_time-bearing user tower + item-tower scalar export/predict + user-tower predict over the request_time eval parquet. Green on local A10.
  • Backward-compat: existing …_triton timestamp op (None path), hstu_transducer_test (ranking), pre-commit — all green.

🤖 Generated with Claude Code

tiankongdeguiji and others added 4 commits May 25, 2026 14:41
add_timestamp_positional_embeddings derives the time-bias gap as
`ts_gap = query_time - event_time`, where `query_time` was always the
last in-sequence timestamp. That is correct for DLRM-HSTU (the candidate
request time is concatenated last) but wrong for any UIH-only sequence,
where the most-recent event always lands at gap 0.

Add an optional per-row `query_time` tensor to the pytorch + triton
kernels and the dispatcher. When passed, it anchors the gap for every
position (including the last); when None, the last-timestamp gather is
unchanged, so DLRM-HSTU and existing configs are byte-identical.

The triton kernel gates the new `QueryTime` pointer behind a
`HAS_QUERY_TIME` constexpr, mirroring the existing `NumTargets` /
`HAS_MULTIPLE_TARGETS` pattern; backward is unaffected (it replays the
stored time-bucket indices).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… encoder

HSTUPositionalEncoder forwards an optional `query_time` to the time-bias
op. `_HSTUPipelineBase` reads it from `grouped_features[query_time_key]`
(a `[B, 1]` raw request-time scalar) and passes it to the encoder —
sourced from grouped_features rather than the shared preprocessor tuple
so the ranking (ContextualInterleavePreprocessor) path is untouched.

`query_time_key` defaults to "" (last-timestamp anchor, unchanged);
HSTUMatchEncoder plumbs it to the base so the two-tower user side can
opt into request-time anchoring.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
HSTUUserTower detects an optional `query_time` DEEP group (a single
per-row request-time raw feature) among its feature groups and threads
the key to HSTUMatchEncoder. With it, the HSTU time bias anchors on the
request time instead of the last UIH event, so the user embedding
reflects how stale the history is relative to the (serving) request —
the signal time encoding is meant to provide, and which the exported
two-tower user encoder otherwise had no way to receive.

Absent the group, behavior is unchanged. No proto change: the group is
routed by name, like contextual / uih_action / uih_timestamp.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Turn on request-time anchoring in the integration path. The kuairand-1k-match
train/eval parquets now carry a per-row `request_time` (the retrieval split
boundary: first post-history event, >= max uih timestamp), surfaced via a
`request_time` raw feature + `query_time` DEEP group in the config and the doc
example; ci_data.sh points at the regenerated fixtures.

Fold the standalone query_time unit test into test_hstu_match -- the model is
always built with the query_time group and asserts the threaded key, across
NORMAL/FX_TRACE/JIT_SCRIPT x PYTORCH/TRITON x cpu/cuda. Trim the inline/docstring
comments to one line.

Verified on A10: match_integration_test.test_hstu_with_fg_train_eval green
(train + eval + AOT export of the request_time-bearing user tower + item/user
predict); hstu_test + position_test green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label May 26, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label May 26, 2026
Comment thread tzrec/models/hstu.py
Comment on lines +101 to +108
query_time_key = next(
    (
        feature_group.group_name
        for feature_group in feature_groups
        if feature_group.group_name == "query_time"
    ),
    "",
)

The query_time group is detected purely by name, with no check that it resolves to a single scalar feature (dim 1). If it is misconfigured with multiple features (or a sparse/embedded feature), the grouped tensor is [B, K] and the op's query_time.view(-1) yields length B*K. The triton kernel then reads QueryTime + off_b against a buffer whose length is not B (see the other comment on the missing row-count assert), and the pytorch path either errors deep inside the op or, when the flattened tensor has length 1, silently broadcasts one row to all users.

Consider a construction-time guard mirroring the adjacent contextual check (lines 90-97), e.g. assert embedding_group.group_total_dim("query_time") == 1 with a message naming the group.

Separately: when this group is configured but positional_encoder.use_time_encoding=false (or no positional_encoder), the anchor is fetched in _preprocess and silently dropped (positional_encoder.py:115). Since driving the time bias is the group's only purpose, a warning or a hard check that query_time implies use_time_encoding would save a hard-to-debug silent no-op.
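A construction-time guard along these lines might look like the sketch below. `GroupSpec` and `find_query_time_key` are hypothetical stand-ins, not tzrec APIs, and raising on `use_time_encoding=false` is one of the two options mentioned above (the other being a warning).

```python
from dataclasses import dataclass


@dataclass
class GroupSpec:
    """Hypothetical stand-in for a resolved feature group."""

    name: str
    total_dim: int


def find_query_time_key(groups, use_time_encoding):
    """Return "query_time" if the group is present and valid, else ""."""
    group = next((g for g in groups if g.name == "query_time"), None)
    if group is None:
        return ""  # no group: keep the last-timestamp anchor
    if group.total_dim != 1:
        raise ValueError(
            "group 'query_time' must be a single scalar feature, "
            f"got total dim {group.total_dim}"
        )
    if not use_time_encoding:
        raise ValueError("group 'query_time' requires use_time_encoding=true")
    return group.name


groups = [GroupSpec("contextual", 32), GroupSpec("query_time", 1)]
print(find_query_time_key(groups, use_time_encoding=True))
```

Failing at construction keeps the error next to the config that caused it, instead of surfacing later as a shape mismatch inside the op.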

Comment on lines +500 to +503
if query_time is not None:
    query_time = switch_to_contiguous_if_needed(
        query_time.view(-1).to(timestamps.dtype)
    )

The kernel loads the anchor with tl.load(QueryTime + off_b) (line ~354) where off_b ∈ [0, B), but this wrapper never asserts query_time has B rows — a shorter buffer causes an unmasked out-of-bounds GPU read (silent garbage or a CUDA fault), with no clean diagnostic. B is already computed at line 487. Suggest mirroring the existing torch._assert(... == B) convention in jagged_tensors.py:151:

Suggested change
- if query_time is not None:
-     query_time = switch_to_contiguous_if_needed(
-         query_time.view(-1).to(timestamps.dtype)
-     )
+ timestamps = switch_to_contiguous_if_needed(timestamps)
+ if query_time is not None:
+     query_time = switch_to_contiguous_if_needed(
+         query_time.view(-1).to(timestamps.dtype)
+     )
+     torch._assert(query_time.shape[0] == B, "query_time must have B rows")

(The pytorch path at pt_position.py:131 has the same latent issue — view(-1, 1) silently broadcasts when query_time has 1 element rather than B; an analogous assert there would keep the two kernels from diverging on bad input.)
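The failure mode and the proposed check can be mimicked without a GPU. The sketch below uses nested lists in place of tensors, so `flatten_query_time` is a hypothetical analogue of `view(-1)` followed by the row-count assert, not the wrapper itself.

```python
def flatten_query_time(query_time_rows, batch_size):
    """Flatten a [B, K] nested list and insist on exactly B entries."""
    flat = [value for row in query_time_rows for value in row]
    if len(flat) != batch_size:
        raise ValueError(
            f"query_time must have {batch_size} rows, got {len(flat)} values"
        )
    return flat


# Well-formed [B, 1] input flattens to one anchor per row.
print(flatten_query_time([[4600], [9100]], batch_size=2))

# A [B, 2] group would flatten to 2 * B values and is rejected.
try:
    flatten_query_time([[4600, 1], [9100, 2]], batch_size=2)
except ValueError as err:
    print(err)
```

The same check also rejects a single-element input for `B > 1`, which is the case where a broadcasting implementation would silently reuse one anchor for every row.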

)

@unittest.skipIf(*gpu_unavailable)
def test_add_timestamp_positional_embeddings_explicit_query_time(self) -> None:

Two coverage gaps worth closing on this new test:

  • It's @unittest.skipIf(*gpu_unavailable), so on CPU-only CI the only thing exercising the explicit-query_time path is hstu_test.test_hstu_match, which asserts key-wiring + output shape but never that query_time numerically affects the output. The pytorch branch (pt_position.py:128-131) runs fine on CPU — consider running the None-vs-last_ts equality and the "future shifts output" check on Kernel.PYTORCH without the GPU gate, so CPU CI validates correctness, not just wiring.
  • The test is forward-only, so the modified autograd backward (the new trailing None for query_time) is never exercised. A .backward() assertion on the explicit path would lock that signature in.

Minor: only a future query_time is tested — a request time before the events (negative gap → clamp(min=1e-6)) is a realistic clock-skew case worth one assertion that both kernels clamp identically.
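A CPU-only test for the clock-skew case could be shaped like this. `clamped_gaps` is a hypothetical reference function that applies the `clamp(min=1e-6)` mentioned above; a real test would call both kernels and compare their outputs instead.

```python
import unittest


def clamped_gaps(event_times, query_time, min_gap=1e-6):
    """Reference gap computation with the lower clamp applied."""
    return [max(query_time - t, min_gap) for t in event_times]


class ClockSkewTest(unittest.TestCase):
    def test_request_before_all_events(self):
        # Every raw gap is negative, so every gap clamps to min_gap.
        self.assertEqual(clamped_gaps([100.0, 200.0], 50.0), [1e-6, 1e-6])

    def test_request_between_events(self):
        # Only the event after the request time is clamped.
        self.assertEqual(clamped_gaps([100.0, 200.0], 150.0), [50.0, 1e-6])


if __name__ == "__main__":
    unittest.main()
```

Nothing here needs a GPU, so a check of this shape can run in CPU CI alongside the GPU-gated kernel comparison.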

@github-actions

Code review summary

Clean, well-scoped fix. The diagnosis in the description is convincing, the optional-query_time threading is consistent across all three layers (pytorch op / triton kernel / dispatcher), and the None default keeps the DLRM-HSTU ranking path byte-identical. A few things I verified hold up well:

  • None path is genuinely zero-overhead — the HAS_QUERY_TIME constexpr gates the new pointer deref at compile time, so the existing kernel compiles unchanged; the pytorch None branch is the unmodified gather. Backward correctly appends one None for the non-differentiable input.
  • Docs are accurate — the markdown config example, the Chinese query_time description (units/optional/fallback), and the docstrings all match the code.
  • The op test asserting query_time=last_ts reproduces the gather-last path on both kernels is a strong regression guard.

The noteworthy items are in inline comments, in priority order:

  1. Validate the query_time group is a single scalar + add a host-side torch._assert(query_time.shape[0] == B) — together these close a silent OOB GPU read (triton) / silent broadcast (pytorch) under a misconfigured group. (hstu.py, triton_position.py)
  2. Silent no-op when the group is configured but use_time_encoding=false. (hstu.py)
  3. Test gaps: the new op test is GPU-gated + forward-only, so CPU CI only smoke-tests query_time and the modified autograd backward is never exercised; negative-gap clamp untested. (position_test.py)

Nothing blocking; (1) is the one I'd most want addressed before merge.

@tiankongdeguiji tiankongdeguiji merged commit 58fd5cc into master May 26, 2026
11 checks passed
@tiankongdeguiji tiankongdeguiji deleted the fix/hstu-match-query-time branch June 23, 2026 10:53

2 participants