Skip to content

Commit ef7c18e

Browse files
authored
bugfix: fix mtp sampling mixed issue. (#1128)
1 parent 0f7d258 commit ef7c18e

4 files changed

Lines changed: 68 additions & 2 deletions

File tree

xllm/core/framework/sampling/rejection_sampler.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,11 @@ RejectionSampler::RejectionSampler(
105105
rate_controller_(rate_controller),
106106
enable_fused_kernel_(enable_fused_kernel) {
107107
CHECK(do_sample.defined());
108-
// [batch_size, 1]
109-
do_sample_ = do_sample.unsqueeze_(/*dim=*/-1);
108+
// Keep a private expanded view and do not mutate the caller-owned tensor.
109+
// The same SamplingParameters object is reused later by MTP draft extend.
110+
// An in-place unsqueeze here corrupts Sampler::forward() mixed-mode shape
111+
// assumptions and can broadcast sampled token ids into 2D.
112+
do_sample_ = do_sample.unsqueeze(/*dim=*/-1);
110113
}
111114

112115
// draft_token_ids: [batch_size, n_speculative_tokens]

xllm/core/framework/sampling/rejection_sampler_test.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,55 @@ TEST(RejectionSamplerTest, LogProbs) {
254254
EXPECT_TRUE(torch::equal(output.top_tokens, top_k_indices));
255255
}
256256

257+
TEST(RejectionSamplerTest, ConstructorDoesNotMutateDoSampleShape) {
258+
const auto device = get_test_device();
259+
auto do_sample = torch::tensor({false, true}, torch::device(device));
260+
261+
ASSERT_EQ(do_sample.dim(), 1);
262+
ASSERT_EQ(do_sample.sizes(), torch::IntArrayRef({2}));
263+
264+
RejectionSampler sampler(do_sample,
265+
do_sample.all().item<bool>(),
266+
!do_sample.any().item<bool>(),
267+
/*logprobs=*/false,
268+
/*max_top_logprobs=*/0);
269+
270+
EXPECT_EQ(do_sample.dim(), 1);
271+
EXPECT_EQ(do_sample.sizes(), torch::IntArrayRef({2}));
272+
}
273+
274+
TEST(RejectionSamplerTest,
275+
ReusingDoSampleAfterRejectionSamplerKeepsSamplerOutput1D) {
276+
const auto options = get_test_options(torch::kFloat32);
277+
const auto device = get_test_device();
278+
auto do_sample = torch::tensor({false, true}, torch::device(device));
279+
280+
RejectionSampler rejection_sampler(do_sample,
281+
do_sample.all().item<bool>(),
282+
!do_sample.any().item<bool>(),
283+
/*logprobs=*/false,
284+
/*max_top_logprobs=*/0);
285+
(void)rejection_sampler;
286+
287+
SamplingParameters params;
288+
params.selected_token_idxes =
289+
torch::tensor({0, 1}, torch::dtype(torch::kInt64).device(device));
290+
params.sample_idxes =
291+
torch::tensor({0, 1}, torch::dtype(torch::kInt64).device(device));
292+
params.do_sample = do_sample;
293+
params.all_random_sample = false;
294+
params.all_greedy_sample = false;
295+
296+
auto logits =
297+
torch::tensor({{3.0f, 1.0f, 0.5f}, {0.1f, 0.2f, 4.0f}}, options);
298+
auto output = Sampler().forward(logits, params);
299+
300+
EXPECT_EQ(output.probs.dim(), 2);
301+
EXPECT_EQ(output.probs.size(0), 2);
302+
EXPECT_EQ(output.next_tokens.dim(), 1);
303+
EXPECT_EQ(output.next_tokens.size(0), 2);
304+
}
305+
257306
TEST(RejectionSamplerTest, Random) {
258307
const auto options = get_test_options(torch::kFloat32);
259308

xllm/core/framework/sampling/sampler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ SampleOutput Sampler::forward(torch::Tensor& logits,
5151
sample_logits = logits.index_select(/*dim=*/0, params.sample_idxes);
5252
}
5353

54+
CHECK(params.do_sample.defined()) << "params.do_sample must be defined";
55+
CHECK_EQ(params.do_sample.dim(), 1)
56+
<< "params.do_sample must be 1D [num_seqs], got "
57+
<< params.do_sample.sizes();
5458
// same batch size
5559
CHECK_EQ(sample_logits.size(0), params.do_sample.size(0));
5660

xllm/core/runtime/mtp_worker_impl.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,16 @@ std::optional<ForwardOutput> MTPWorkerImpl::run_validate(
507507

508508
void MTPWorkerImpl::process_draft_sample_output(SampleOutput& sample_output) {
509509
if (sample_output.probs.defined()) {
510+
CHECK(sample_output.next_tokens.defined())
511+
<< "draft sample_output.next_tokens must be defined when probs exist";
512+
CHECK_EQ(sample_output.next_tokens.dim(), 1)
513+
<< "MTP draft cache expects next_tokens [batch], got "
514+
<< sample_output.next_tokens.sizes();
515+
CHECK(sample_output.probs.dim() == 1 || sample_output.probs.dim() == 2)
516+
<< "MTP draft cache expects probs [batch] or [batch,vocab], got "
517+
<< sample_output.probs.sizes();
518+
CHECK_EQ(sample_output.probs.size(0), sample_output.next_tokens.size(0))
519+
<< "MTP draft cache probs/token batch mismatch";
510520
// Cache always stores selected-only draft probs [batch_size] to reduce HBM.
511521
sample_output.probs = specBuilder::draftProbs::compress_for_cache(
512522
sample_output.probs, sample_output.next_tokens);

0 commit comments

Comments (0)