|
1 | | -from unittest.mock import Mock |
| 1 | +from unittest.mock import Mock, patch |
2 | 2 |
|
3 | 3 | import paddle |
4 | 4 | import pytest |
|
12 | 12 | SpeculativeConfig, |
13 | 13 | StructuredOutputsConfig, |
14 | 14 | ) |
15 | | -from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode |
| 15 | +from fastdeploy.worker.input_batch import ( |
| 16 | + InputBatch, |
| 17 | + ProposerInputBatch, |
| 18 | + reorder_split_prefill_and_decode, |
| 19 | +) |
16 | 20 |
|
17 | 21 |
|
18 | 22 | def create_mock_config(): |
@@ -61,6 +65,7 @@ def create_mock_config(): |
61 | 65 | scheduler_config = Mock(spec=SchedulerConfig) |
62 | 66 | scheduler_config.max_num_seqs = 10 |
63 | 67 | scheduler_config.max_num_batched_tokens = 2048 |
| 68 | + scheduler_config.max_extra_num_batched_tokens = 0 |
64 | 69 |
|
65 | 70 | speculative_config = Mock(spec=SpeculativeConfig) |
66 | 71 | speculative_config.method = None |
@@ -315,5 +320,119 @@ def test_reorder_all_prefill(self): |
315 | 320 | assert paddle.equal_all(input_batch.input_ids[i], original_data[i]) |
316 | 321 |
|
317 | 322 |
|
class TestProposerInputBatchReset:
    """Exercise ``ProposerInputBatch.reset_model_inputs`` with CUDA reported.

    Two scenarios are covered: the target batch exposes ``token_ids_all``,
    and the target batch has had that attribute removed.

    NOTE(review): the original author tied these cases to
    fastdeploy/worker/input_batch.py lines ~972-988; those line numbers will
    drift, so treat them as a hint only.
    """

    @staticmethod
    def _force_cuda_platform(mock_platform):
        """Make the patched ``current_platform`` report CUDA and not XPU."""
        mock_platform.is_cuda.return_value = True
        mock_platform.is_xpu.return_value = False

    def _make_config(self):
        """Return the shared mock config with speculative decoding set to MTP."""
        # The original author notes that enabling the spec_decoding path makes
        # InputBatch.init_share_inputs allocate cu_seqlens_q_output /
        # draft_tokens / accept_num etc. (not verifiable from this file).
        config = create_mock_config()
        spec = config.speculative_config
        spec.method = "mtp"
        spec.num_speculative_tokens = 1
        spec.num_model_steps = 1
        return config

    def _make_target(self, fd_config):
        """Build an ``InputBatch`` and run ``init_share_inputs`` on it."""
        batch = InputBatch(fd_config)
        batch.init_share_inputs()
        return batch

    def _make_proposer(self, fd_config, target):
        """Build a ``ProposerInputBatch`` with hand-assigned tensors.

        ``init_share_inputs`` is deliberately NOT called on the proposer.
        According to the original author, the attributes assigned here are
        the ones ``reset_model_inputs`` writes through ``fill_paddle_tensor``,
        and skipping the full init avoids rope_emb / attention-backend setup
        that is unrelated to the behavior under test.
        """
        batch = ProposerInputBatch(fd_config, target)

        num_seqs = fd_config.scheduler_config.max_num_seqs
        hidden = fd_config.model_config.hidden_size
        num_draft = fd_config.speculative_config.num_speculative_tokens

        # (attribute, shape, fill value, dtype); a dtype of None means the
        # dtype argument is omitted so paddle.full picks its own default.
        tensor_specs = (
            ("target_hidden_states", [num_seqs, hidden], 0, "bfloat16"),
            ("draft_tokens", [num_seqs, num_draft + 1], -1, "int64"),
            ("is_block_step", [num_seqs, 1], False, "bool"),
            ("batch_drop", [num_seqs, 1], False, "bool"),
            ("used_list_len", [num_seqs], 0, "int32"),
            ("first_token_hidden_states", [num_seqs, hidden], -1, None),
            ("batch_token_num", [num_seqs], 0, "int32"),
            ("next_token_num", [num_seqs], 0, "int32"),
            ("cu_batch_token_offset", [num_seqs + 1], 0, "int32"),
            ("cu_next_token_offset", [num_seqs + 1], 0, "int32"),
            ("mask_rollback", [num_seqs, 1], 0, "int32"),
            ("recompute_token_num", [num_seqs, 1], 0, "int32"),
        )
        for attr_name, shape, fill_value, dtype in tensor_specs:
            if dtype is None:
                tensor = paddle.full(shape, fill_value)
            else:
                tensor = paddle.full(shape, fill_value, dtype=dtype)
            setattr(batch, attr_name, tensor)
        return batch

    def _build_batches(self):
        """Return ``(fd_config, target, proposer)`` built by the helpers above."""
        config = self._make_config()
        target_batch = self._make_target(config)
        proposer_batch = self._make_proposer(config, target_batch)
        return config, target_batch, proposer_batch

    @patch("fastdeploy.worker.input_batch.current_platform")
    def test_reset_rebinds_token_ids_all_on_cuda(self, mock_platform):
        """With CUDA reported and ``target.token_ids_all`` present, expect
        ``reset_model_inputs`` to rebind ``proposer.token_ids_all`` to the
        target's tensor and to fill ``pre_ids`` from
        ``token_ids_all[:, prompt_len:]``, padding the remainder with -1."""
        self._force_cuda_platform(mock_platform)
        config, target_batch, proposer_batch = self._build_batches()

        num_seqs = config.scheduler_config.max_num_seqs
        seq_len = config.model_config.max_model_len

        # Swap in a brand-new tensor with distinctive content, so identity
        # with this object proves the reset re-read the target instead of
        # keeping whatever reference was bound earlier.
        fresh_ids = paddle.arange(num_seqs * seq_len, dtype="int64").reshape([num_seqs, seq_len])
        target_batch.token_ids_all = fresh_ids

        # A non-zero prompt length keeps the [:, prompt_len:] slice non-trivial.
        prompt_len = 3
        target_batch.prompt_lens = paddle.full([num_seqs, 1], prompt_len, dtype="int64")

        proposer_batch.reset_model_inputs()

        # The proposer must hold the very same tensor object as the target.
        assert proposer_batch.token_ids_all is fresh_ids

        # pre_ids keeps the full [num_seqs, seq_len] shape.
        assert proposer_batch.pre_ids.shape == [num_seqs, seq_len]

        # Leading columns mirror the post-prompt tokens ...
        copied_cols = seq_len - prompt_len
        post_prompt_tokens = fresh_ids[:, prompt_len:]
        assert paddle.equal_all(proposer_batch.pre_ids[:, :copied_cols], post_prompt_tokens)

        # ... and the trailing prompt_len columns stay at the -1 sentinel.
        sentinel_tail = paddle.full([num_seqs, prompt_len], -1, dtype="int64")
        assert paddle.equal_all(proposer_batch.pre_ids[:, copied_cols:], sentinel_tail)

    @patch("fastdeploy.worker.input_batch.current_platform")
    def test_reset_falls_back_to_pre_ids_clone_when_no_token_ids_all(self, mock_platform):
        """With CUDA reported but ``token_ids_all`` deleted from the target,
        expect ``reset_model_inputs`` to leave ``proposer.token_ids_all`` as
        None and give the proposer its own copy of ``target.pre_ids``."""
        self._force_cuda_platform(mock_platform)
        config, target_batch, proposer_batch = self._build_batches()

        num_seqs = config.scheduler_config.max_num_seqs
        seq_len = config.model_config.max_model_len

        # Drop the attribute entirely so the target no longer exposes it.
        del target_batch.token_ids_all
        # Give the target an easily recognizable pre_ids payload.
        target_batch.pre_ids = paddle.full([num_seqs, seq_len], 42, dtype="int64")

        proposer_batch.reset_model_inputs()

        assert proposer_batch.token_ids_all is None
        expected_pre_ids = paddle.full([num_seqs, seq_len], 42, dtype="int64")
        assert paddle.equal_all(proposer_batch.pre_ids, expected_pre_ids)
        # Must be a separate object, not a shared reference to the target's.
        assert proposer_batch.pre_ids is not target_batch.pre_ids
| 436 | + |
if __name__ == "__main__":
    # Allow running this test module directly as a script, in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
0 commit comments