[bugfix] DlrmHSTU predict: keep candidate split key in original order for descending timestamps #540
…ing timestamps DlrmHSTU.predict flips grouped_features along dim 0 when sequence_timestamp_is_ascending is false (reversing request order), runs the model, then flips mt_preds back to original order. TARGET_REPEAT_INTERLEAVE_KEY was read from the still-flipped grouped_features['candidate.sequence_length'], so the per-request output-split boundaries were in reversed order while the predictions were in original order. When candidate counts differ across requests in a batch, whole-request prediction blocks were assigned to the wrong rows. Capture candidate.sequence_length before the flip so the split key matches the prediction order. Single-row batches were unaffected (a request flips to itself), which is why batch_size=1 was always correct. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
89c9aff to e66e90c
Review summary
Reviewed across code-quality, performance, tests, docs, and security. The fix is correct, minimal, and well-analyzed, capturing …
One noteworthy gap: no regression test
This is a silent, serving-affecting correctness bug, and nothing in the suite guards against its return: …
Low-cost guard: in the descending branch (batch_size ≥ 2, variable …
Not blocking: the fix is good to merge as-is, but a regression test is well worth adding given how quietly this one slipped through.
…mestamp reversal Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
cae1cdc to e24ee27
Merge upstream/master (1.2.17, incl. alibaba#540 DlrmHSTU fix) and bump. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…_abstract Brings the reviewed alibaba#539 foundation onto feat/sid_abstract (which already carries alibaba#538 + an older RQ-VAE/RQ-Kmeans port), and syncs to upstream/master (alibaba#540, alibaba#541, which alibaba#539 already contains).
Conflict resolutions:
- sid_rqkmeans.py(+test), residual_kmeans_quantizer.py, sid_model.py: take alibaba#539's canonical versions (BaseSidModel now hosts both SID models, with mse/rel_loss/unique_sid_ratio and the unified x_hat recon key).
- types.py: union; keep alibaba#539's QuantizeOutput, retain feat's QuantizeForwardMode enum + ResidualQuantizerOutput (RQ-VAE needs them).
- protos/models/sid_model.proto: union; alibaba#539's typed FaissKmeansConfig + clean SidRqkmeans, re-add feat's SinkhornConfig/ClipConfig/SidRqvae; drop the now-unused struct.proto import.
- protos/model.proto: enable `SidRqvae sid_rqvae = 600;` (the field alibaba#539 reserved for this follow-up).
- main.py / model.py on_train_end: take alibaba#539's wording; drop feat's forced tail-checkpoint (SID models rely on the final=True tail save).
Transitional state: old modules/sid/kmeans.py still coexists with alibaba#539's kmeans_quantize.py, and the RQ-VAE stack is still on the old abstraction; both are retired in the follow-up refactor commit. All SID modules import.
Problem
`DlrmHSTU.predict` (inherited by `UltraHSTU`) can attach a request's predictions to the wrong output row when running inference with `batch_size > 1` and `sequence_timestamp_is_ascending = false`. The per-candidate values themselves are correct, but whole-request prediction blocks get assigned to the wrong `request_id`/reserved columns. `batch_size = 1` is always correct, which makes the discrepancy look batch-size dependent.
Root cause
In the descending-timestamp path, `predict()`:
1. flips `grouped_features` along dim 0 (`_fx_flip_tensor_dict`), which reverses both the in-sequence token order and the request order in the batch;
2. runs the model;
3. flips `mt_preds` back to the original request order;
4. but reads `TARGET_REPEAT_INTERLEAVE_KEY` from the still-flipped `grouped_features["candidate.sequence_length"]`.
`_write_predictions` regroups the flat per-candidate predictions into per-request rows via `cumsum(TARGET_REPEAT_INTERLEAVE_KEY)`. Because the predictions are in original order but the split key is reversed, the cumulative boundaries diverge whenever the per-request candidate count (`num_targets`) is not uniform across the batch. Where all requests share the same candidate count, `cumsum(reversed) == cumsum(original)`, so nothing breaks. That is why only the requests near a differing-length request are affected, and why `batch_size = 1` (a single request flips to itself) is always correct. The misaligned span is exactly the rows between a variable-length request's position and its flip-mirror within the batch.
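The boundary argument above can be checked with a small sketch. It is plain Python on lists rather than the actual tensor code, and `row_spans` and `affected_rows` are hypothetical helper names introduced only for illustration; they are not part of the repository.

```python
from itertools import accumulate

def row_spans(lengths):
    """Return the [start, end) span of each request in the flat prediction list."""
    ends = list(accumulate(lengths))      # cumulative boundaries, as cumsum would give
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))

def affected_rows(lengths):
    """Rows whose span changes when the split key is reversed (the buggy case)."""
    good = row_spans(lengths)
    bad = row_spans(lengths[::-1])
    return [i for i, (g, b) in enumerate(zip(good, bad)) if g != b]

print(affected_rows([2, 2, 2]))           # [] : uniform counts, boundaries agree
print(affected_rows([5]))                 # [] : a single request flips to itself
print(affected_rows([2, 2, 3, 2, 2, 2]))  # [2, 3] : the odd request and its mirror
```

In the last case the longer request sits at position 2 and its flip-mirror is position 3, so only those two rows receive the wrong slice, which matches the description of the misaligned span.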
Fix
Capture `candidate.sequence_length` before the flip and use that original-order copy as the split key (the predictions are already flipped back to original order, so the split key must match).
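The ordering of operations in the fix can be illustrated with a toy model. This is a sketch under assumptions, not the actual patch: `predict_sketch` and `split_rows` are hypothetical names, requests are plain lists of candidate values, and multiplying by 10 stands in for the model.

```python
from itertools import accumulate

def split_rows(flat, lengths):
    """Regroup flat per-candidate values into per-request rows via cumulative lengths."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]

def predict_sketch(requests, capture_before_flip):
    """Toy descending-timestamp path; `requests` is a list of candidate lists."""
    # The fix: record per-request candidate counts while still in original order.
    original_lengths = [len(c) for c in requests]

    flipped = requests[::-1]                                # flip along the batch dim
    preds = [[x * 10 for x in cands] for cands in flipped]  # stand-in for the model
    preds = preds[::-1]                                     # flip predictions back
    flat = [p for row in preds for p in row]

    flipped_lengths = [len(c) for c in flipped]             # what the buggy path read
    split_key = original_lengths if capture_before_flip else flipped_lengths
    return split_rows(flat, split_key)

requests = [[1], [2, 3, 4], [5, 6]]
print(predict_sketch(requests, capture_before_flip=True))   # [[10], [20, 30, 40], [50, 60]]
print(predict_sketch(requests, capture_before_flip=False))  # [[10, 20], [30, 40, 50], [60]]
```

Every individual score is the same in both outputs; only the row boundaries move, which is the symptom described under Problem.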
Verification
Re-ran checkpoint prediction on a descending-timestamp DlrmHSTU model with variable per-request candidate counts. Before the fix, a subset of rows did not match their own per-row reference scores, and the affected count changed with `num_workers` (which alters batch composition); after the fix, every output row matches its reference and the output is identical across `batch_size` and `num_workers`. The bug also affects the exported/scripted serving graph (same traced `predict()`), so a re-export is needed to propagate the fix to a deployed artifact.
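A batch-invariance check of the kind described above could look roughly like the following. This is a hedged sketch of a possible regression guard, not the test used for this PR: `toy_predict` is a hypothetical stand-in for a descending-timestamp `predict()` call, and `batch_invariant` is an illustrative helper name.

```python
from itertools import accumulate

def toy_predict(batch, reversed_split_key=False):
    """Hypothetical stand-in: batch is a list of (request_id, candidates)."""
    lengths = [len(c) for _, c in batch]
    flat = [x * 10 for _, cands in batch for x in cands]   # original request order
    key = lengths[::-1] if reversed_split_key else lengths  # reversed key = the bug
    ends = list(accumulate(key))
    starts = [0] + ends[:-1]
    return {rid: flat[s:e] for (rid, _), s, e in zip(batch, starts, ends)}

def batch_invariant(predict_fn, requests, batch_sizes):
    """True if every batch size reproduces the per-request (batch_size=1) reference."""
    reference = {}
    for req in requests:
        reference.update(predict_fn([req]))
    for bs in batch_sizes:
        out = {}
        for i in range(0, len(requests), bs):
            out.update(predict_fn(requests[i:i + bs]))
        if out != reference:
            return False
    return True

# Variable candidate counts are essential; uniform counts would hide the bug.
requests = [("r0", [1]), ("r1", [2, 3, 4]), ("r2", [5, 6]), ("r3", [7])]
print(batch_invariant(toy_predict, requests, [1, 2, 4]))
print(batch_invariant(lambda b: toy_predict(b, True), requests, [1, 2, 4]))
```

The per-request reference is computed with single-request batches because those are unaffected by the flip, so any mismatch at larger batch sizes points at the regrouping step rather than at the scores.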
🤖 Generated with Claude Code