@@ -115,11 +115,7 @@ def prefill_normal(
115115 model_input , run_reqs = prepare_prefill_inputs (prefill_reqs , is_chuncked_mode = not self .disable_chunked_prefill )
116116 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
117117 model_output = self .model .forward (model_input )
118- (
119- _ ,
120- next_token_ids_cpu ,
121- next_token_logprobs_cpu ,
122- ) = self ._sample_and_scatter_token (
118+ (_ , next_token_ids_cpu , next_token_logprobs_cpu ,) = self ._sample_and_scatter_token (
123119 logits = model_output .logits ,
124120 b_req_idx = model_input .b_req_idx ,
125121 b_mtp_index = model_input .b_mtp_index ,
@@ -162,11 +158,7 @@ def decode_normal(
162158 model_input , run_reqs = prepare_decode_inputs (decode_reqs )
163159 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
164160 model_output = self .model .forward (model_input )
165- (
166- _ ,
167- next_token_ids_cpu ,
168- next_token_logprobs_cpu ,
169- ) = self ._sample_and_scatter_token (
161+ (_ , next_token_ids_cpu , next_token_logprobs_cpu ,) = self ._sample_and_scatter_token (
170162 logits = model_output .logits ,
171163 b_req_idx = model_input .b_req_idx ,
172164 b_mtp_index = model_input .b_mtp_index ,
@@ -204,11 +196,7 @@ def prefill_mtp(
204196 model_input , run_reqs = prepare_prefill_inputs (prefill_reqs , is_chuncked_mode = not self .disable_chunked_prefill )
205197 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
206198 model_output = self .model .forward (model_input )
207- (
208- next_token_ids ,
209- next_token_ids_cpu ,
210- next_token_logprobs_cpu ,
211- ) = self ._sample_and_scatter_token (
199+ (next_token_ids , next_token_ids_cpu , next_token_logprobs_cpu ,) = self ._sample_and_scatter_token (
212200 logits = model_output .logits ,
213201 b_req_idx = model_input .b_req_idx ,
214202 b_mtp_index = model_input .b_mtp_index ,
@@ -490,11 +478,7 @@ def _draft_decode_eagle(
490478 g_infer_state_lock .release ()
491479 eagle_mem_indexes = eagle_mem_indexes_cpu .cuda (non_blocking = True )
492480
493- (
494- draft_model_input ,
495- draft_next_token_ids ,
496- accepted_req_idx ,
497- ) = self ._build_eagle_accepted_draft_input (
481+ (draft_model_input , draft_next_token_ids , accepted_req_idx ,) = self ._build_eagle_accepted_draft_input (
498482 main_model_input = main_model_input ,
499483 main_model_output = main_model_output ,
500484 next_token_ids = next_token_ids ,
0 commit comments