Skip to content

Commit 4eab14d

Browse files
authored
Merge pull request #300 from InfiniTensor/issue/299
issue/299 - allow ignoring EOS in server
2 parents f73e18b + cb62ce2 commit 4eab14d

File tree

3 files changed

+36
-15
lines changed

3 files changed

+36
-15
lines changed

python/infinilm/llm/llm.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -334,21 +334,23 @@ def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool:
334334
req.finish_reason = FinishReason.LENGTH
335335
return True
336336

337-
# Check EOS token
338-
eos_ids = req.eos_token_ids or self.eos_token_ids
339-
if eos_ids and token_id in eos_ids:
340-
req.finish_reason = FinishReason.EOS_TOKEN
341-
return True
342-
343-
# Check stop strings
344-
# Remove stop string from generated_text if STOP_STRING finish reason
345-
stop_strings = req.sampling_params.stop or []
346-
for stop_str in stop_strings:
347-
if req.generated_text.endswith(stop_str):
348-
req.generated_text = req.generated_text[: -len(stop_str)]
349-
req.finish_reason = FinishReason.STOP_STRING
337+
if not req.sampling_params.ignore_eos:
338+
# Check EOS token - only stop if ignore_eos is False
339+
eos_ids = req.eos_token_ids or self.eos_token_ids
340+
if eos_ids and token_id in eos_ids:
341+
req.finish_reason = FinishReason.EOS_TOKEN
350342
return True
351343

344+
# While ignoring EOS, stop strings are also ignored to avoid requiring additional arguments for benchmarking.
345+
# Check stop strings
346+
# Remove stop string from generated_text if STOP_STRING is the finishing reason
347+
stop_strings = req.sampling_params.stop or []
348+
for stop_str in stop_strings:
349+
if req.generated_text.endswith(stop_str):
350+
req.generated_text = req.generated_text[: -len(stop_str)]
351+
req.finish_reason = FinishReason.STOP_STRING
352+
return True
353+
352354
return False
353355

354356
def tokenize(self, text: str) -> List[int]:

python/infinilm/llm/sampling_params.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ class SamplingParams:
1515
top_k: int = 1
1616
max_tokens: Optional[int] = None
1717
stop: Optional[List[str]] = None
18-
stop_token_ids: Optional[List[int]] = None # Placeholder for future usage, not currently handled
18+
stop_token_ids: Optional[List[int]] = (
19+
None # Placeholder for future usage, not currently handled
20+
)
21+
ignore_eos: bool = False
1922

2023
def __post_init__(self):
2124
if self.stop is None:
@@ -32,4 +35,5 @@ def clone(self) -> "SamplingParams":
3235
max_tokens=self.max_tokens,
3336
stop=self.stop.copy() if self.stop else None,
3437
stop_token_ids=self.stop_token_ids.copy() if self.stop_token_ids else None,
38+
ignore_eos=self.ignore_eos,
3539
)

python/infinilm/server/inference_server.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
port: int = 8000,
110110
enable_graph: bool = False,
111111
attn_backend: str = "default",
112+
ignore_eos: bool = False,
112113
):
113114
"""Initialize inference server.
114115
@@ -150,6 +151,7 @@ def __init__(
150151
self.port = port
151152
self.enable_graph = enable_graph
152153
self.attn_backend = attn_backend
154+
self.ignore_eos = ignore_eos
153155

154156
self.engine: AsyncLLMEngine = None
155157

@@ -331,6 +333,7 @@ def pick(key: str, default):
331333
top_k=int(pick("top_k", self.top_k)),
332334
max_tokens=int(max_tokens) if max_tokens is not None else None,
333335
stop=stop,
336+
ignore_eos=self.ignore_eos,
334337
)
335338

336339
async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
@@ -382,7 +385,11 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request)
382385
# Skip EOS token text for OpenAI API compatibility
383386
# Check if this token is an EOS token by comparing token_id with eos_token_ids
384387
eos_token_ids = self.engine.engine.eos_token_ids
385-
is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
388+
is_eos_token = (
389+
not sampling_params.ignore_eos
390+
and eos_token_ids
391+
and token_output.token_id in eos_token_ids
392+
)
386393

387394
if not is_eos_token and token_output.token_text:
388395
# Send token
@@ -631,6 +638,13 @@ def parse_args():
631638
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
632639
help="Logging level",
633640
)
641+
parser.add_argument(
642+
"--ignore-eos",
643+
action="store_true",
644+
dest="ignore_eos",
645+
default=False,
646+
help="Ignore EOS token and continue generation",
647+
)
634648

635649
return parser.parse_args()
636650

@@ -688,6 +702,7 @@ def main():
688702
port=args.port,
689703
enable_graph=args.enable_graph,
690704
attn_backend=args.attn,
705+
ignore_eos=args.ignore_eos,
691706
)
692707
server.start()
693708

0 commit comments

Comments (0)