|
24 | 24 | from tzrec.features.feature import create_features |
25 | 25 | from tzrec.models.dlrm_hstu import DlrmHSTU |
26 | 26 | from tzrec.models.model import TrainWrapper |
| 27 | +from tzrec.models.rank_model import TARGET_REPEAT_INTERLEAVE_KEY |
27 | 28 | from tzrec.ops import Kernel |
28 | 29 | from tzrec.protos import ( |
29 | 30 | feature_pb2, |
@@ -445,6 +446,34 @@ def test_dlrm_hstu( |
445 | 446 | self.assertEqual(predictions["logits_is_comment"].size(), (6,)) |
446 | 447 | self.assertEqual(predictions["probs_is_comment"].size(), (6,)) |
447 | 448 |
|
@unittest.skipIf(*gpu_unavailable)
def test_dlrm_hstu_predict_num_targets_order(self) -> None:
    """Check that the num_targets split key stays in input order.

    ``_write_predictions`` regroups predictions per request via
    ``cumsum(predictions[TARGET_REPEAT_INTERLEAVE_KEY])``. With
    ``sequence_timestamp_is_ascending=False``, ``predict()`` flips features
    (reversing request order) and flips predictions back, so the key must
    also be un-flipped; reading it from the still-flipped
    ``candidate.sequence_length`` returns [4, 2] for the [2, 4] test batch,
    misassigning whole-request blocks. ``size()`` (== 6) can't catch it.

    Each ordering runs in its own ``subTest`` so that a failure for one
    value of ``sequence_timestamp_is_ascending`` does not hide the result
    for the other.
    """
    device = torch.device("cuda")
    # candidate (cand_seq) counts in _build_batch, in input order.
    expected_num_targets = [2, 4]
    for ascending in (True, False):
        # subTest keeps iterating after a failed assertion and labels the
        # failure with the parameter value that triggered it.
        with self.subTest(ascending=ascending):
            model = _build_model(
                device=device, sequence_timestamp_is_ascending=ascending
            )
            batch = _build_batch(device=device)
            # Inference only: no autograd graph is needed to read the key.
            with torch.no_grad():
                predictions = model.predict(batch)
            self.assertEqual(
                predictions[TARGET_REPEAT_INTERLEAVE_KEY].cpu().tolist(),
                expected_num_targets,
                msg=f"num_targets order wrong for ascending={ascending}",
            )
| 476 | + |
448 | 477 | @unittest.skipIf(*gpu_unavailable) |
449 | 478 | def test_dlrm_hstu_task_weight(self) -> None: |
450 | 479 | device = torch.device("cuda") |
|
0 commit comments