Skip to content

Commit 70b4972

Browse files
authored
[Speculate Decoding] Reset reasoning_status when request finishes (#7660)
1 parent 78e3e9d commit 70b4972

3 files changed

Lines changed: 8 additions & 0 deletions

File tree

fastdeploy/worker/gpu_model_runner.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -851,6 +851,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
851851
self.share_inputs["enable_thinking"][idx : idx + 1, :] = enable_thinking
852852
if enable_thinking:
853853
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
854+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
854855
if request.get("reasoning_max_tokens") is not None:
855856
# Enable thinking
856857
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get(
@@ -870,6 +871,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
870871
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
871872
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
872873
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
874+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
873875

874876
if isinstance(request.prompt_token_ids, np.ndarray):
875877
prompt_token_ids = request.prompt_token_ids.tolist()

fastdeploy/worker/metax_model_runner.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -721,6 +721,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
721721
self.share_inputs["enable_thinking"][idx : idx + 1, :] = enable_thinking
722722
if enable_thinking:
723723
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
724+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
724725
if request.get("reasoning_max_tokens") is not None:
725726
# Enable thinking
726727
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get(
@@ -740,6 +741,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
740741
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
741742
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
742743
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
744+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
743745

744746
if isinstance(request.prompt_token_ids, np.ndarray):
745747
prompt_token_ids = request.prompt_token_ids.tolist()

fastdeploy/worker/xpu_model_runner.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -599,10 +599,12 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
599599
# Enable thinking
600600
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
601601
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
602+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
602603
else:
603604
# Disable thinking
604605
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
605606
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
607+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
606608

607609
if (
608610
hasattr(request, "sampling_params")
@@ -798,10 +800,12 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
798800
# Enable thinking
799801
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
800802
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
803+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
801804
else:
802805
# Disable thinking
803806
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
804807
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
808+
self.share_inputs["reasoning_status"][idx : idx + 1, :] = 0
805809

806810
def get_attr_from_request(request, attr, default_value=None):
807811
res = request.get(attr, default_value)

0 commit comments

Comments
 (0)