Skip to content

[bugfix] DlrmHSTU predict: keep candidate split key in original order for descending timestamps#540

Merged
tiankongdeguiji merged 2 commits into
alibaba:masterfrom
tiankongdeguiji:fix/hstu-predict-descending-misalign
Jun 8, 2026
Merged

[bugfix] DlrmHSTU predict: keep candidate split key in original order for descending timestamps#540
tiankongdeguiji merged 2 commits into
alibaba:masterfrom
tiankongdeguiji:fix/hstu-predict-descending-misalign

Conversation

@tiankongdeguiji

Copy link
Copy Markdown
Collaborator

Problem

DlrmHSTU.predict (inherited by UltraHSTU) can attach a request's predictions
to the wrong output row when running inference with batch_size > 1 and
sequence_timestamp_is_ascending = false. The per-candidate values themselves
are correct, but whole-request prediction blocks get assigned to the wrong
request_id/reserved columns. batch_size = 1 is always correct, which makes
the discrepancy look batch-size dependent.

Root cause

In the descending-timestamp path, predict():

  1. flips every tensor in grouped_features along dim 0 (_fx_flip_tensor_dict),
    which reverses both the in-sequence token order and the request order in
    the batch;
  2. runs the model and flips mt_preds back to the original request order;
  3. but sets the per-request output-split key from the still-flipped
    grouped_features["candidate.sequence_length"]:
predictions[TARGET_REPEAT_INTERLEAVE_KEY] = grouped_features["candidate.sequence_length"]

_write_predictions regroups the flat per-candidate predictions into per-request
rows via cumsum(TARGET_REPEAT_INTERLEAVE_KEY). Because the predictions are in
original order but the split key is reversed, the cumulative boundaries diverge
whenever the per-request candidate count (num_targets) is not uniform across
the batch. Where all requests share the same candidate count,
cumsum(reversed) == cumsum(original) so nothing breaks — which is why only the
requests near a differing-length request are affected, and why batch_size = 1
(a single request flips to itself) is always correct. The misaligned span is
exactly the rows between a variable-length request's position and its flip-mirror
within the batch.

Fix

Capture candidate.sequence_length before the flip and use that original-order
copy as the split key (the predictions are already flipped back to original
order, so the split key must match):

grouped_features = self.build_input(batch)
num_targets = grouped_features["candidate.sequence_length"]  # original order
if not self._model_config.sequence_timestamp_is_ascending:
    grouped_features = _fx_flip_tensor_dict(grouped_features)
...
predictions[TARGET_REPEAT_INTERLEAVE_KEY] = num_targets

Verification

Re-ran checkpoint prediction on a descending-timestamp DlrmHSTU model with
variable per-request candidate counts. Before the fix a subset of rows did not
match their own per-row reference scores and the affected count changed with
num_workers (which alters batch composition); after the fix every output row
matches its reference and the output is identical across batch_size and
num_workers. The bug also affects the exported/scripted serving graph (same
traced predict()), so a re-export is needed to propagate the fix to a deployed
artifact.

🤖 Generated with Claude Code

…ing timestamps

DlrmHSTU.predict flips grouped_features along dim 0 when
sequence_timestamp_is_ascending is false (reversing request order), runs the
model, then flips mt_preds back to original order. TARGET_REPEAT_INTERLEAVE_KEY
was read from the still-flipped grouped_features['candidate.sequence_length'],
so the per-request output-split boundaries were in reversed order while the
predictions were in original order. When candidate counts differ across
requests in a batch, whole-request prediction blocks were assigned to the wrong
rows. Capture candidate.sequence_length before the flip so the split key matches
the prediction order. Single-row batches were unaffected (a request flips to
itself), which is why batch_size=1 was always correct.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the fix/hstu-predict-descending-misalign branch from 89c9aff to e66e90c Compare June 6, 2026 08:00
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Jun 6, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 6, 2026
@github-actions

github-actions Bot commented Jun 6, 2026

Copy link
Copy Markdown

Review summary

Reviewed across code-quality, performance, tests, docs, and security. The fix is correct, minimal, and well-analyzed — capturing num_targets before _fx_flip_tensor_dict is exactly right, since _fx_flip_tensor_dict allocates new tensors (so the captured reference stays in original request order) and mt_preds is flipped back to original order before the split key is attached. The inline comment accurately documents the subtle ordering invariant, and the pattern is FX-trace/jit-script safe. No performance or security concerns.

One noteworthy gap: no regression test

This is a silent, serving-affecting correctness bug, and nothing in the suite guards against its return:

  • test_dlrm_hstu (dlrm_hstu_test.py:441-446) parametrizes sequence_timestamp_is_ascending over True/False and _build_batch already uses variable per-request candidate counts ([2, 4]) at batch_size 2 — i.e. the buggy path is triggered — but the test only asserts .size(), which is invariant under the bug. The corrupted ordering is never observed.
  • The export/predict consumer _write_predictions (tzrec/main.py:1000-1027, asynchronous_complete_cumsum over TARGET_REPEAT_INTERLEAVE_KEY) is only exercised by the integration test, whose config leaves sequence_timestamp_is_ascending at its proto default of true — so the descending branch is never run end-to-end either.

Low-cost guard: in the descending branch (batch_size ≥ 2, variable num_targets), assert predictions[TARGET_REPEAT_INTERLEAVE_KEY] equals the original-order candidate.sequence_length (e.g. [2, 4]). That single tensor-equality assertion would have failed pre-fix and pins the exact invariant this PR restores.

Not blocking — the fix is good to merge as-is, but a regression test is well worth adding given how quietly this one slipped through.

…mestamp reversal

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the fix/hstu-predict-descending-misalign branch from cae1cdc to e24ee27 Compare June 6, 2026 09:02
@tiankongdeguiji tiankongdeguiji merged commit ad083a7 into alibaba:master Jun 8, 2026
7 checks passed
WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 9, 2026
Merge upstream/master (1.2.17, incl. alibaba#540 DlrmHSTU fix) and bump.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 11, 2026
…_abstract

Brings the reviewed alibaba#539 foundation onto feat/sid_abstract (which already
carries alibaba#538 + an older RQ-VAE/RQ-Kmeans port), and syncs to upstream/master
(alibaba#540, alibaba#541, which alibaba#539 already contains).

Conflict resolutions:
- sid_rqkmeans.py(+test), residual_kmeans_quantizer.py, sid_model.py:
  take alibaba#539's canonical versions (BaseSidModel now hosts both SID models,
  with mse/rel_loss/unique_sid_ratio and the unified x_hat recon key).
- types.py: union — keep alibaba#539's QuantizeOutput, retain feat's
  QuantizeForwardMode enum + ResidualQuantizerOutput (RQ-VAE needs them).
- protos/models/sid_model.proto: union — alibaba#539's typed FaissKmeansConfig +
  clean SidRqkmeans, re-add feat's SinkhornConfig/ClipConfig/SidRqvae;
  drop the now-unused struct.proto import.
- protos/model.proto: enable `SidRqvae sid_rqvae = 600;` (the field alibaba#539
  reserved for this follow-up).
- main.py / model.py on_train_end: take alibaba#539's wording; drop feat's forced
  tail-checkpoint (SID models rely on the final=True tail save).

Transitional state: old modules/sid/kmeans.py still coexists with alibaba#539's
kmeans_quantize.py, and the RQ-VAE stack is still on the old abstraction —
both retired in the follow-up refactor commit. All SID modules import.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants