Skip to content

Commit dd2aa10

Browse files
authored
fix cuda graph capture failure in CI test (#7094)
1 parent daa9524 commit dd2aa10

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -252,9 +252,16 @@ __global__ void verify_draft_tokens(
252252
break;
253253
}
254254

255-
// Accept-all override (debug/warmup)
255+
// Accept-all override (debug/warmup/CUDA graph capture)
256256
if (accept_all) {
257-
if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break;
257+
int64_t token = ctx.step_input_ids_now[i + 1];
258+
// During dummy run (accept_all), replace EOS tokens with a safe
259+
// non-EOS value to prevent stop_flags being set, which would cause
260+
// CUDA graph capture failure due to token count mismatch.
261+
if (is_in_end(token, end_tokens, end_length)) {
262+
token = 5;
263+
}
264+
if (ctx.emit_token(i, token)) break;
258265
continue;
259266
}
260267

0 commit comments

Comments
 (0)