Skip to content

Commit 6958472

Browse files
committed
fix: limit guided decoding loops to generation_size from logits shape
FillMask, ScheduleUpdate, and FinishUpdate previously iterated over all d.matchers.size() entries, but only the first generation_size (= logits.shape(0)) slots are actively generating. Entries beyond that index contain stale output_ids and unused bitmasks.

- FillMask: limit matcher iteration and reserve to gs = logits.shape(0)
- ScheduleUpdate: copy only gs output_ids entries for D2H transfer
- FinishUpdate: add TensorMap& env param, iterate only over gs slots

Fixes review comments on PR #4605 (3280137130, 3280137198).
1 parent 7cec779 commit 6958472

3 files changed

Lines changed: 22 additions & 11 deletions

File tree

src/turbomind/generation/generation.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ struct Generation::Impl {
298298

299299
stop_criteria_->Forward(phase, env);
300300

301-
guided_decoding_->FinishUpdate(phase);
301+
guided_decoding_->FinishUpdate(phase, env);
302302
}
303303
}
304304
};

src/turbomind/generation/guided_decoding.cc

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ void GuidedDecoding::Setup(int phase, TensorMap& env)
5252
void GuidedDecoding::FillMask(int phase, TensorMap& env)
5353
{
5454
if (auto& d = *data_.at(phase); d.active) {
55+
// Only the first `generation_size` (= logits.shape(0)) slots are actively
56+
// generating; matchers beyond this index belong to idle/prefill requests
57+
// whose output_ids are stale and whose bitmasks are never applied.
58+
const int gs = env.at("logits").shape(0);
59+
5560
static_assert(sizeof(ssize_t) == sizeof(int64_t));
5661
DLTensor dlbitmask{bitmask_buf_.data(),
5762
DLDevice{kDLCPU, 0},
@@ -63,11 +68,11 @@ void GuidedDecoding::FillMask(int phase, TensorMap& env)
6368

6469
std::vector<xgrammar::GrammarMatcher> active_matchers;
6570
std::vector<int32_t> active_indices;
66-
active_matchers.reserve(d.matchers.size());
67-
active_indices.reserve(d.matchers.size());
71+
active_matchers.reserve(gs);
72+
active_indices.reserve(gs);
6873

6974
if (tp_group_->rank() == 0) {
70-
for (size_t i = 0; i < d.matchers.size(); ++i) {
75+
for (int i = 0; i < gs; ++i) {
7176
if (const auto& m = d.matchers[i]; m && !m->IsTerminated()) {
7277
active_matchers.emplace_back(*m);
7378
active_indices.emplace_back(static_cast<int32_t>(i));
@@ -113,26 +118,32 @@ void GuidedDecoding::ScheduleUpdate(int phase, TensorMap& env)
113118

114119
// D2H copy on secondary stream — overlaps with subsequent GPU kernels
115120
// on the main stream (AppendTokenIds, stop_criteria).
121+
// Only copy the first `generation_size` entries: sampling writes exactly
122+
// that many output_ids, and entries beyond it contain stale values.
123+
const int gs = env.at("logits").shape(0);
116124
d2h_stream_.Wait(sampling_done_);
117-
Copy(env.at("output_ids").buffer(), d.matchers.size(), output_ids_buf_, d2h_stream_);
125+
Copy(env.at("output_ids").buffer(), gs, output_ids_buf_, d2h_stream_);
118126
d2h_done_.Record(d2h_stream_);
119127
}
120128
}
121129

122-
void GuidedDecoding::FinishUpdate(int phase)
130+
void GuidedDecoding::FinishUpdate(int phase, TensorMap& env)
123131
{
124132
if (auto& d = *data_.at(phase); d.active && tp_group_->rank() == 0) {
125133
// Wait only for the D2H copy to complete — the main stream's
126134
// AppendTokenIds + stop_criteria may still be executing on GPU.
127135
d2h_done_.Sync();
128136

129-
// Collect active matchers and their token IDs for batch AcceptToken
137+
// Collect active matchers and their token IDs for batch AcceptToken.
138+
// Only iterate over the first `generation_size` (= logits.shape(0)) slots —
139+
// beyond that index the output_ids buffer contains stale data from prior steps.
140+
const int gs = env.at("logits").shape(0);
130141
std::vector<xgrammar::GrammarMatcher> active_matchers;
131142
std::vector<int32_t> active_token_ids;
132-
active_matchers.reserve(d.matchers.size());
133-
active_token_ids.reserve(d.matchers.size());
143+
active_matchers.reserve(gs);
144+
active_token_ids.reserve(gs);
134145

135-
for (size_t i = 0; i < d.matchers.size(); ++i) {
146+
for (int i = 0; i < gs; ++i) {
136147
if (const auto& m = d.matchers[i]; m && !m->IsTerminated()) {
137148
active_matchers.emplace_back(*m);
138149
active_token_ids.emplace_back(output_ids_buf_[i]);

src/turbomind/generation/guided_decoding.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class GuidedDecoding: public BaseGenerationParam {
2222
void ApplyMask(int phase, TensorMap& env);
2323

2424
void ScheduleUpdate(int phase, TensorMap& env);
25-
void FinishUpdate(int phase);
25+
void FinishUpdate(int phase, TensorMap& env);
2626

2727
private:
2828
comm::HostComm tp_group_;

0 commit comments

Comments
 (0)