Commit 427efad

[Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1 (#7159)
* [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1
* [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1
* fix
1 parent 9b970de commit 427efad

File tree

1 file changed (+11 −1 lines)


fastdeploy/worker/gpu_model_runner.py

Lines changed: 11 additions & 1 deletion
@@ -27,7 +27,7 @@
 from paddle import nn
 from paddleformers.utils.log import logger

-from fastdeploy.config import FDConfig
+from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig
 from fastdeploy.engine.pooling_params import PoolingParams
 from fastdeploy.engine.request import ImagePosition, Request, RequestType
 from fastdeploy.model_executor.graph_optimization.utils import (
@@ -2409,6 +2409,16 @@ def _postprocess(

         # 5.1. Async cpy
         post_process_event = paddle.device.cuda.create_event()
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            # If a query is preempted, no token is sampled for it; write the sentinel PREEMPTED_TOKEN_ID into its slot so the server knows the abort has finished.
+            paddle.assign(
+                paddle.where(
+                    self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1,
+                    PREEMPTED_TOKEN_ID,
+                    sampler_output.sampled_token_ids,
+                ),
+                sampler_output.sampled_token_ids,
+            )
         # if not self.speculative_decoding:
         self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
         if self.speculative_decoding:
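
The overwrite follows Paddle's elementwise select-then-assign pattern: paddle.where picks the sentinel wherever the preemption mask is set, and paddle.assign writes the result back into the same tensor in place. Below is a minimal runnable sketch of that pattern outside FastDeploy; the batch contents and the sentinel value -3 are illustrative assumptions (the real value lives in fastdeploy.config), and the sentinel is materialized with paddle.full_like rather than passed as a bare scalar.

import paddle

# Sentinel for preempted slots. The real value is defined in
# fastdeploy.config; -3 here is an illustrative assumption.
PREEMPTED_TOKEN_ID = -3

# Pretend batch of four sampled token ids, one row per running query.
sampled_token_ids = paddle.to_tensor([[101], [204], [57], [998]], dtype="int64")
# 1 marks a slot whose query was preempted this step, 0 marks a live query.
last_preempted_idx = paddle.to_tensor([[0], [1], [0], [0]], dtype="int64")

# Select the sentinel where the mask is set, keep the sampled id elsewhere,
# then write the result back into the original tensor in place.
paddle.assign(
    paddle.where(
        last_preempted_idx[: sampled_token_ids.shape[0]] == 1,
        paddle.full_like(sampled_token_ids, PREEMPTED_TOKEN_ID),
        sampled_token_ids,
    ),
    sampled_token_ids,
)

print(sampled_token_ids.numpy().flatten().tolist())  # [101, -3, 57, 998]

The in-place assign matters in the commit because the very next line copies sampler_output.sampled_token_ids into self.share_inputs["sampled_token_ids"], so the sentinel flows through to whatever consumes the saved output.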
