Skip to content

Commit e24ee27

Browse files
[test] guard DlrmHSTU predict num_targets order against descending-timestamp reversal
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent e66e90c commit e24ee27

1 file changed

Lines changed: 29 additions & 0 deletions

File tree

tzrec/models/dlrm_hstu_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tzrec.features.feature import create_features
2525
from tzrec.models.dlrm_hstu import DlrmHSTU
2626
from tzrec.models.model import TrainWrapper
27+
from tzrec.models.rank_model import TARGET_REPEAT_INTERLEAVE_KEY
2728
from tzrec.ops import Kernel
2829
from tzrec.protos import (
2930
feature_pb2,
@@ -445,6 +446,34 @@ def test_dlrm_hstu(
445446
self.assertEqual(predictions["logits_is_comment"].size(), (6,))
446447
self.assertEqual(predictions["probs_is_comment"].size(), (6,))
447448

449+
@unittest.skipIf(*gpu_unavailable)
450+
def test_dlrm_hstu_predict_num_targets_order(self) -> None:
451+
"""num_targets split key must stay in input order.
452+
453+
``_write_predictions`` regroups predictions per request via
454+
``cumsum(predictions[TARGET_REPEAT_INTERLEAVE_KEY])``. With
455+
``sequence_timestamp_is_ascending=False``, ``predict()`` flips features
456+
(reversing request order) and flips predictions back, so the key must
457+
also be un-flipped; reading it from the still-flipped
458+
``candidate.sequence_length`` returns [4, 2] for the [2, 4] test batch,
459+
misassigning whole-request blocks. ``size()`` (== 6) can't catch it.
460+
"""
461+
device = torch.device("cuda")
462+
# candidate (cand_seq) counts in _build_batch, in input order.
463+
expected_num_targets = [2, 4]
464+
for ascending in (True, False):
465+
model = _build_model(
466+
device=device, sequence_timestamp_is_ascending=ascending
467+
)
468+
batch = _build_batch(device=device)
469+
with torch.no_grad():
470+
predictions = model.predict(batch)
471+
self.assertEqual(
472+
predictions[TARGET_REPEAT_INTERLEAVE_KEY].cpu().tolist(),
473+
expected_num_targets,
474+
msg=f"num_targets order wrong for ascending={ascending}",
475+
)
476+
448477
@unittest.skipIf(*gpu_unavailable)
449478
def test_dlrm_hstu_task_weight(self) -> None:
450479
device = torch.device("cuda")

0 commit comments

Comments
 (0)