Skip to content

Commit b5a3678

Browse files
committed
fix: guard D2H copy with rank==0, replace per-element .item() with tolist()
1 parent 57050fb commit b5a3678

2 files changed

Lines changed: 16 additions & 18 deletions

File tree

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,8 +479,9 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
479479
result = __random_sampling(scores, indices)
480480

481481
if self.guided_decoding_manager and self.guided_processors:
482+
result_cpu = result.tolist()
482483
for i, processor in self.guided_processors.items():
483-
self.guided_decoding_manager.accept_token(processor, result[i].item())
484+
self.guided_decoding_manager.accept_token(processor, result_cpu[i])
484485

485486
return result
486487

src/turbomind/generation/guided_decoding.cc

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ void GuidedDecoding::ApplyMask(int phase, TensorMap& env)
105105

106106
void GuidedDecoding::ScheduleUpdate(int phase, TensorMap& env)
107107
{
108-
if (auto& d = *data_.at(phase); d.active) {
108+
if (auto& d = *data_.at(phase); d.active && tp_group_->rank() == 0) {
109109
// Record event on main stream after sampling GPU work is submitted.
110110
// The secondary stream will wait for this before issuing the D2H copy,
111111
// ensuring it reads the output_ids written by sampling.
@@ -121,30 +121,27 @@ void GuidedDecoding::ScheduleUpdate(int phase, TensorMap& env)
121121

122122
void GuidedDecoding::FinishUpdate(int phase)
123123
{
124-
if (auto& d = *data_.at(phase); d.active) {
124+
if (auto& d = *data_.at(phase); d.active && tp_group_->rank() == 0) {
125125
// Wait only for the D2H copy to complete — the main stream's
126126
// AppendTokenIds + stop_criteria may still be executing on GPU.
127127
d2h_done_.Sync();
128128

129-
if (tp_group_->rank() == 0) {
130-
// Collect active matchers and their token IDs for batch AcceptToken
131-
std::vector<xgrammar::GrammarMatcher> active_matchers;
132-
std::vector<int32_t> active_token_ids;
133-
active_matchers.reserve(d.matchers.size());
134-
active_token_ids.reserve(d.matchers.size());
129+
// Collect active matchers and their token IDs for batch AcceptToken
130+
std::vector<xgrammar::GrammarMatcher> active_matchers;
131+
std::vector<int32_t> active_token_ids;
132+
active_matchers.reserve(d.matchers.size());
133+
active_token_ids.reserve(d.matchers.size());
135134

136-
for (size_t i = 0; i < d.matchers.size(); ++i) {
137-
if (const auto& m = d.matchers[i]; m && !m->IsTerminated()) {
138-
active_matchers.emplace_back(*m);
139-
active_token_ids.emplace_back(output_ids_buf_[i]);
140-
}
135+
for (size_t i = 0; i < d.matchers.size(); ++i) {
136+
if (const auto& m = d.matchers[i]; m && !m->IsTerminated()) {
137+
active_matchers.emplace_back(*m);
138+
active_token_ids.emplace_back(output_ids_buf_[i]);
141139
}
140+
}
142141

143-
if (!active_matchers.empty()) {
144-
xgrammar::BatchGrammarMatcher::BatchAcceptToken(&active_matchers, active_token_ids);
145-
}
142+
if (!active_matchers.empty()) {
143+
xgrammar::BatchGrammarMatcher::BatchAcceptToken(&active_matchers, active_token_ids);
146144
}
147-
// active_matchers destroyed here: refcount-- for each entry
148145
}
149146
}
150147

0 commit comments

Comments
 (0)