Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions tzrec/models/dlrm_hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
with record_function("## preprocess ##"):
grouped_features = self.build_input(batch)

# Capture num_targets before the descending-timestamp flip below, so the
# output split key stays in the original (un-flipped) request order.
num_targets = grouped_features["candidate.sequence_length"]

if not self._model_config.sequence_timestamp_is_ascending:
# if timestamp of sequence is descending,
# we should reverse all features
Expand Down Expand Up @@ -235,9 +239,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
suffix=f"_{task_name}",
)
)
predictions[TARGET_REPEAT_INTERLEAVE_KEY] = grouped_features[
"candidate.sequence_length"
]
predictions[TARGET_REPEAT_INTERLEAVE_KEY] = num_targets

return predictions

Expand Down
29 changes: 29 additions & 0 deletions tzrec/models/dlrm_hstu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tzrec.features.feature import create_features
from tzrec.models.dlrm_hstu import DlrmHSTU
from tzrec.models.model import TrainWrapper
from tzrec.models.rank_model import TARGET_REPEAT_INTERLEAVE_KEY
from tzrec.ops import Kernel
from tzrec.protos import (
feature_pb2,
Expand Down Expand Up @@ -445,6 +446,34 @@ def test_dlrm_hstu(
self.assertEqual(predictions["logits_is_comment"].size(), (6,))
self.assertEqual(predictions["probs_is_comment"].size(), (6,))

@unittest.skipIf(*gpu_unavailable)
def test_dlrm_hstu_predict_num_targets_order(self) -> None:
    """Verify predict() reports num_targets in original request order.

    ``_write_predictions`` splits predictions back into requests using
    ``cumsum(predictions[TARGET_REPEAT_INTERLEAVE_KEY])``. When
    ``sequence_timestamp_is_ascending=False``, ``predict()`` flips the
    features (which reverses request order) and later flips the predictions
    back, so the split key has to be in un-flipped order as well. Taking it
    from the still-flipped ``candidate.sequence_length`` gives [4, 2] for
    the [2, 4] test batch and assigns whole-request blocks to the wrong
    request. A ``size()`` check (== 6) cannot detect that, hence this
    value-level comparison.
    """
    cuda = torch.device("cuda")
    # Per-request candidate (cand_seq) counts from _build_batch, input order.
    want = [2, 4]
    # Both timestamp orderings must yield the same, un-flipped split key.
    for ascending in (True, False):
        hstu = _build_model(
            device=cuda,
            sequence_timestamp_is_ascending=ascending,
        )
        inputs = _build_batch(device=cuda)
        with torch.no_grad():
            outputs = hstu.predict(inputs)
        got = outputs[TARGET_REPEAT_INTERLEAVE_KEY].cpu().tolist()
        self.assertEqual(
            got,
            want,
            msg=f"num_targets order wrong for ascending={ascending}",
        )

@unittest.skipIf(*gpu_unavailable)
def test_dlrm_hstu_task_weight(self) -> None:
device = torch.device("cuda")
Expand Down
2 changes: 1 addition & 1 deletion tzrec/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "1.2.16"
__version__ = "1.2.17"
Loading