Commit 5b543d5

wooway777 authored and MoringLotus committed
issue/297 - compile all paged batch sizes up to 64
1 parent 9acffbc commit 5b543d5

4 files changed

Lines changed: 52 additions & 34 deletions

csrc/engine/compiler/paged_compiler.cpp

Lines changed: 1 addition & 4 deletions
@@ -18,10 +18,7 @@ inline void set_minus_one(infinicore::Tensor &tensor) {
 namespace infinilm::engine {
 PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
     : GraphCompiler(model, barrier) {
-    for (size_t b = 1; b < 32; b++) {
-        decode_batch_sizes_.push_back(b);
-    }
-    for (size_t b = 32; b < 64; b += 8) {
+    for (size_t b = 1; b < 64; ++b) {
         decode_batch_sizes_.push_back(b);
     }
     for (size_t b = 64; b < 128; b += 16) {
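Before this change, per-size decode graphs were compiled for every batch size below 32 but only at a stride of 8 between 32 and 63, so a decode batch in that range would presumably fall back to the next larger compiled size. The Python sketch below is a hypothetical illustration, not code from the repo: it contrasts the two compiled-size sets and shows that kind of round-up; the padded_batch_size helper is an assumption about how a runtime batch would map onto a pre-compiled graph.

# Hypothetical illustration of the compiled decode batch sizes before and
# after this commit; the round-up helper is an assumption, not repo code.
import bisect

def compiled_sizes_before():
    # old: 1..31 individually, then strides of 8 up to 64, strides of 16 up to 128
    return list(range(1, 32)) + list(range(32, 64, 8)) + list(range(64, 128, 16))

def compiled_sizes_after():
    # new: 1..63 individually, then strides of 16 up to 128
    return list(range(1, 64)) + list(range(64, 128, 16))

def padded_batch_size(batch, sizes):
    # smallest compiled size that can hold the batch
    i = bisect.bisect_left(sizes, batch)
    return sizes[i] if i < len(sizes) else sizes[-1]

print(padded_batch_size(33, compiled_sizes_before()))  # 40
print(padded_batch_size(33, compiled_sizes_after()))   # 33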

python/infinilm/llm/llm.py

Lines changed: 15 additions & 13 deletions
@@ -334,21 +334,23 @@ def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool:
             req.finish_reason = FinishReason.LENGTH
             return True
 
-        # Check EOS token
-        eos_ids = req.eos_token_ids or self.eos_token_ids
-        if eos_ids and token_id in eos_ids:
-            req.finish_reason = FinishReason.EOS_TOKEN
-            return True
-
-        # Check stop strings
-        # Remove stop string from generated_text if STOP_STRING finish reason
-        stop_strings = req.sampling_params.stop or []
-        for stop_str in stop_strings:
-            if req.generated_text.endswith(stop_str):
-                req.generated_text = req.generated_text[: -len(stop_str)]
-                req.finish_reason = FinishReason.STOP_STRING
+        if not req.sampling_params.ignore_eos:
+            # Check EOS token - only stop if ignore_eos is False
+            eos_ids = req.eos_token_ids or self.eos_token_ids
+            if eos_ids and token_id in eos_ids:
+                req.finish_reason = FinishReason.EOS_TOKEN
                 return True
 
+            # While ignoring EOS, stop strings are also ignored to avoid requiring additional arguments for benchmarking.
+            # Check stop strings
+            # Remove stop string from generated_text if STOP_STRING is the finishing reason
+            stop_strings = req.sampling_params.stop or []
+            for stop_str in stop_strings:
+                if req.generated_text.endswith(stop_str):
+                    req.generated_text = req.generated_text[: -len(stop_str)]
+                    req.finish_reason = FinishReason.STOP_STRING
+                    return True
+
         return False
 
     def tokenize(self, text: str) -> List[int]:
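The rewritten check nests both the EOS test and the stop-string test under the new ignore_eos flag, so a request with ignore_eos=True only ever finishes on the length limit. Below is a small self-contained sketch of that control flow; the classes are simplified stand-ins for the package's InferenceRequest, SamplingParams, and FinishReason, not the real definitions.

# Self-contained sketch of the finish check above, with minimal stand-in types.
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

class FinishReason(Enum):
    EOS_TOKEN = "eos_token"
    STOP_STRING = "stop_string"

@dataclass
class Params:
    stop: List[str] = field(default_factory=list)
    ignore_eos: bool = False

@dataclass
class Request:
    generated_text: str = ""
    sampling_params: Params = field(default_factory=Params)
    finish_reason: Optional[FinishReason] = None

def check_finished(req: Request, token_id: int, eos_ids: set) -> bool:
    # With ignore_eos set, both the EOS check and the stop-string check are
    # skipped, mirroring the branch introduced in this commit.
    if not req.sampling_params.ignore_eos:
        if eos_ids and token_id in eos_ids:
            req.finish_reason = FinishReason.EOS_TOKEN
            return True
        for stop_str in req.sampling_params.stop:
            if req.generated_text.endswith(stop_str):
                req.generated_text = req.generated_text[: -len(stop_str)]
                req.finish_reason = FinishReason.STOP_STRING
                return True
    return False

req = Request(generated_text="...DONE", sampling_params=Params(stop=["DONE"], ignore_eos=True))
print(check_finished(req, token_id=2, eos_ids={2}))  # False: both checks are bypassed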

python/infinilm/llm/sampling_params.py

Lines changed: 5 additions & 1 deletion
@@ -15,7 +15,10 @@ class SamplingParams:
     top_k: int = 1
     max_tokens: Optional[int] = None
     stop: Optional[List[str]] = None
-    stop_token_ids: Optional[List[int]] = None  # Placeholder for future usage, not currently handled
+    stop_token_ids: Optional[List[int]] = (
+        None  # Placeholder for future usage, not currently handled
+    )
+    ignore_eos: bool = False
 
     def __post_init__(self):
         if self.stop is None:
@@ -32,4 +35,5 @@ def clone(self) -> "SamplingParams":
             max_tokens=self.max_tokens,
             stop=self.stop.copy() if self.stop else None,
             stop_token_ids=self.stop_token_ids.copy() if self.stop_token_ids else None,
+            ignore_eos=self.ignore_eos,
         )
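With the new field, a caller can ask for length-only termination on a per-request basis, and clone() preserves the flag on copies. A hedged usage sketch follows; the import path is assumed from the file location python/infinilm/llm/sampling_params.py, and the remaining fields are assumed to keep the defaults shown above.

from infinilm.llm.sampling_params import SamplingParams  # assumed import path

# Benchmarking-style request: stop on max_tokens only, run past EOS tokens.
params = SamplingParams(max_tokens=256, ignore_eos=True)

# clone() now carries ignore_eos over, so per-request copies keep the flag.
copy = params.clone()
assert copy.ignore_eos is True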

python/infinilm/server/inference_server.py

Lines changed: 31 additions & 16 deletions
@@ -109,6 +109,7 @@ def __init__(
         port: int = 8000,
         enable_graph: bool = False,
         attn_backend: str = "default",
+        ignore_eos: bool = False,
     ):
         """Initialize inference server.
 
@@ -150,6 +151,7 @@ def __init__(
         self.port = port
         self.enable_graph = enable_graph
         self.attn_backend = attn_backend
+        self.ignore_eos = ignore_eos
 
         self.engine: AsyncLLMEngine = None
 
@@ -331,6 +333,7 @@ def pick(key: str, default):
             top_k=int(pick("top_k", self.top_k)),
             max_tokens=int(max_tokens) if max_tokens is not None else None,
             stop=stop,
+            ignore_eos=self.ignore_eos,
         )
 
     async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
@@ -382,7 +385,11 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request)
             # Skip EOS token text for OpenAI API compatibility
             # Check if this token is an EOS token by comparing token_id with eos_token_ids
             eos_token_ids = self.engine.engine.eos_token_ids
-            is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
+            is_eos_token = (
+                not sampling_params.ignore_eos
+                and eos_token_ids
+                and token_output.token_id in eos_token_ids
+            )
 
             if not is_eos_token and token_output.token_text:
                 # Send token
@@ -631,6 +638,13 @@ def parse_args():
         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
         help="Logging level",
     )
+    parser.add_argument(
+        "--ignore-eos",
+        action="store_true",
+        dest="ignore_eos",
+        default=False,
+        help="Ignore EOS token and continue generation",
+    )
 
     return parser.parse_args()
 
@@ -644,21 +658,22 @@ def main():
     server = InferenceServer(
         model_path=cfg.model,
         device=device,
-        dtype=cfg.dtype,
-        tensor_parallel_size=cfg.tp,
-        cache_type=cfg.cache_type,
-        max_tokens=cfg.max_tokens,
-        max_batch_size=cfg.max_batch_size,
-        num_blocks=cfg.num_blocks,
-        block_size=cfg.block_size,
-        max_cache_len=cfg.max_cache_len,
-        temperature=cfg.temperature,
-        top_p=cfg.top_p,
-        top_k=cfg.top_k,
-        host=cfg.host,
-        port=cfg.port,
-        enable_graph=cfg.enable_graph,
-        attn_backend=cfg.attn,
+        dtype=args.dtype,
+        tensor_parallel_size=args.tp,
+        cache_type=args.cache_type,
+        max_tokens=args.max_tokens,
+        max_batch_size=args.max_batch_size,
+        num_blocks=args.num_blocks,
+        block_size=args.block_size,
+        max_cache_len=args.max_cache_len,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        top_k=args.top_k,
+        host=args.host,
+        port=args.port,
+        enable_graph=args.enable_graph,
+        attn_backend=args.attn,
+        ignore_eos=args.ignore_eos,
     )
     server.start()
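End to end, the new --ignore-eos flag becomes a server-wide default that is stamped onto every request's SamplingParams. The self-contained sketch below mirrors that flow; only the argparse definition is taken from the diff, while the dataclass and builder function are illustrative stand-ins rather than the server's API.

import argparse
from dataclasses import dataclass

@dataclass
class SamplingParams:  # stand-in for infinilm.llm.sampling_params.SamplingParams
    top_k: int = 1
    ignore_eos: bool = False

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        dest="ignore_eos",
        default=False,
        help="Ignore EOS token and continue generation",
    )
    return parser.parse_args(argv)

def build_sampling_params(server_ignore_eos: bool, top_k: int = 1) -> SamplingParams:
    # mirrors the pick(...) block above: the server-wide flag becomes the
    # per-request default
    return SamplingParams(top_k=top_k, ignore_eos=server_ignore_eos)

args = parse_args(["--ignore-eos"])
print(build_sampling_params(server_ignore_eos=args.ignore_eos))
# SamplingParams(top_k=1, ignore_eos=True)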
