Skip to content

Commit 0752b1d

Browse files
committed
style: fix black formatting and drop unused var for pre-commit
1 parent eaa4ba7 commit 0752b1d

6 files changed

Lines changed: 19 additions & 35 deletions

File tree

lightllm/common/basemodel/attention/fa3/fp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _nomarl_prefill_att(
105105

106106
k_descale, v_descale = None, None # disable quantization
107107
Lq = q.shape[-1]
108-
sm_scale = 1.0 / (Lq**0.5)
108+
sm_scale = 1.0 / (Lq ** 0.5)
109109
o = flash_attn_with_kvcache(
110110
q=q,
111111
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
@@ -237,7 +237,7 @@ def _normal_decode_att(
237237

238238
k_descale, v_descale = None, None # disable quantization
239239
Lq = q.shape[-1]
240-
sm_scale = 1.0 / (Lq**0.5)
240+
sm_scale = 1.0 / (Lq ** 0.5)
241241
o = flash_attn_with_kvcache(
242242
q=q,
243243
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),

lightllm/common/basemodel/attention/fa3/fp8.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ def init_state(self):
4444
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
4545
)
4646
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
47-
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
48-
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
49-
47+
self.k_descale = (
48+
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
49+
)
50+
self.v_descale = (
51+
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
52+
)
5053

5154
def prefill_att(
5255
self,
@@ -115,16 +118,19 @@ def init_state(self):
115118
super().init_state()
116119
self.backend: Fp8Fa3AttBackend = self.backend
117120

118-
device = self.infer_state.input_ids.device
119121
batch_size = self.b_att_seq_len.shape[0]
120122
mem_manager = self.backend.model.mem_manager
121123

122124
offline_scales: torch.Tensor = mem_manager.scales
123125
head_num = mem_manager.head_num
124126

125127
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
126-
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
127-
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
128+
self.k_descale = (
129+
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
130+
)
131+
self.v_descale = (
132+
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
133+
)
128134

129135
return
130136

lightllm/models/qwen3_5_mtp/layer_infer/pre_layer_infer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer):
9-
109
def __init__(self, network_config):
1110
super().__init__(network_config)
1211
self.eps_ = network_config["rms_norm_eps"]

lightllm/models/qwen3_5_mtp/layer_weights/pre_and_post_layer_weight.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010

1111
class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight):
12-
1312
def __init__(self, data_type, network_config, quant_cfg: Quantcfg):
1413
super().__init__(data_type, network_config)
1514
self.quant_cfg: Quantcfg = quant_cfg

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,7 @@ def prefill_normal(
115115
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
116116
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
117117
model_output = self.model.forward(model_input)
118-
(
119-
_,
120-
next_token_ids_cpu,
121-
next_token_logprobs_cpu,
122-
) = self._sample_and_scatter_token(
118+
(_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token(
123119
logits=model_output.logits,
124120
b_req_idx=model_input.b_req_idx,
125121
b_mtp_index=model_input.b_mtp_index,
@@ -162,11 +158,7 @@ def decode_normal(
162158
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
163159
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
164160
model_output = self.model.forward(model_input)
165-
(
166-
_,
167-
next_token_ids_cpu,
168-
next_token_logprobs_cpu,
169-
) = self._sample_and_scatter_token(
161+
(_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token(
170162
logits=model_output.logits,
171163
b_req_idx=model_input.b_req_idx,
172164
b_mtp_index=model_input.b_mtp_index,
@@ -204,11 +196,7 @@ def prefill_mtp(
204196
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
205197
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
206198
model_output = self.model.forward(model_input)
207-
(
208-
next_token_ids,
209-
next_token_ids_cpu,
210-
next_token_logprobs_cpu,
211-
) = self._sample_and_scatter_token(
199+
(next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token(
212200
logits=model_output.logits,
213201
b_req_idx=model_input.b_req_idx,
214202
b_mtp_index=model_input.b_mtp_index,
@@ -490,11 +478,7 @@ def _draft_decode_eagle(
490478
g_infer_state_lock.release()
491479
eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True)
492480

493-
(
494-
draft_model_input,
495-
draft_next_token_ids,
496-
accepted_req_idx,
497-
) = self._build_eagle_accepted_draft_input(
481+
(draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input(
498482
main_model_input=main_model_input,
499483
main_model_output=main_model_output,
500484
next_token_ids=next_token_ids,

unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,7 @@ def test_build_eagle_accepted_draft_input_narrows_to_accepted_rows():
286286
b_req_mtp_start_loc = torch.tensor([0, 3], dtype=torch.int32)
287287
mtp_accept_len = torch.tensor([2, 3], dtype=torch.int32)
288288

289-
(
290-
draft_input,
291-
accepted_next_tokens,
292-
accepted_req_idx,
293-
) = backend._build_eagle_accepted_draft_input(
289+
(draft_input, accepted_next_tokens, accepted_req_idx,) = backend._build_eagle_accepted_draft_input(
294290
main_model_input=main_input,
295291
main_model_output=main_output,
296292
next_token_ids=next_token_ids,

0 commit comments

Comments
 (0)