Skip to content

Commit 2740083

Browse files
blueswhen and niushengxiao authored
add Flashinfer sampling backend (#1328)
Co-authored-by: niushengxiao <niushengxiao@sensetime.com>
1 parent 78e34a7 commit 2740083

5 files changed

Lines changed: 22 additions & 10 deletions

File tree

lightllm/server/api_cli.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -647,10 +647,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
647647
parser.add_argument(
648648
"--sampling_backend",
649649
type=str,
650-
choices=["triton", "sglang_kernel"],
650+
choices=["triton", "flashinfer"],
651651
default="triton",
652652
help="""sampling used impl. 'triton' is use torch and triton kernel,
653-
sglang_kernel use sglang_kernel impl""",
653+
flashinfer use flashinfer sampling impl""",
654654
)
655655
parser.add_argument(
656656
"--penalty_counter_mode",

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ class StartArgs:
149149
default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]}
150150
)
151151
llm_kv_quant_group_size: int = field(default=8)
152-
sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
152+
sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "flashinfer"]})
153153
penalty_counter_mode: str = field(
154154
default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]}
155155
)

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,9 @@ def _top_p_top_k_sample(
114114
b_top_ks: torch.Tensor,
115115
exist_req_use_random_seed: bool,
116116
) -> Tuple[torch.Tensor, torch.Tensor]:
117-
if get_env_start_args().sampling_backend == "triton":
117+
sampling_backend = get_env_start_args().sampling_backend
118+
119+
if sampling_backend == "triton":
118120
probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks)
119121
if not exist_req_use_random_seed:
120122
sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
@@ -124,8 +126,8 @@ def _top_p_top_k_sample(
124126
next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index))
125127
return next_token_ids.view(-1), next_token_logprobs.view(-1)
126128

127-
elif get_env_start_args().sampling_backend == "sglang_kernel":
128-
from sgl_kernel import top_k_top_p_sampling_from_probs
129+
elif sampling_backend == "flashinfer":
130+
from flashinfer.sampling import top_k_top_p_sampling_from_probs
129131

130132
batch_next_token_ids = top_k_top_p_sampling_from_probs(
131133
probs,

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,8 +80,8 @@ frozendict==2.4.6
8080
atomics==1.0.3
8181
easydict==1.13
8282
hypercorn==0.18.0
83-
flashinfer-python==0.6.8.post1
84-
flashinfer-cubin==0.6.8.post1
83+
flashinfer-python==0.6.12
84+
flashinfer-cubin==0.6.12
8585
sglang-kernel==0.4.2.post1
8686
httpx==0.28.1
8787
librosa==0.11.0

test/benchmark/service/benchmark_qps.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,11 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
108108

109109

110110
model_name = []
111+
sampling_config = {
112+
"temperature": 1.0,
113+
"top_p": 0.9,
114+
"top_k": -1,
115+
}
111116

112117

113118
# Minimal fix: one retry on transient network errors.
@@ -123,7 +128,9 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session):
123128
"max_tokens": max_new_tokens,
124129
"ignore_eos": True,
125130
"stream": True,
126-
"temperature": 0.0,
131+
"temperature": sampling_config["temperature"],
132+
"top_p": sampling_config["top_p"],
133+
"top_k": sampling_config["top_k"],
127134
"best_of": 1,
128135
}
129136
headers = {"Content-Type": "application/json"}
@@ -166,9 +173,12 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
166173
data = {
167174
"inputs": text_input,
168175
"parameters": {
169-
"do_sample": False,
176+
"do_sample": True,
170177
"ignore_eos": True,
171178
"max_new_tokens": max_new_tokens,
179+
"temperature": sampling_config["temperature"],
180+
"top_p": sampling_config["top_p"],
181+
"top_k": sampling_config["top_k"],
172182
"add_special_tokens": False,
173183
},
174184
}

0 commit comments

Comments
 (0)