@@ -109,6 +109,7 @@ def __init__(
109109 port : int = 8000 ,
110110 enable_graph : bool = False ,
111111 attn_backend : str = "default" ,
112+ ignore_eos : bool = False ,
112113 ):
113114 """Initialize inference server.
114115
@@ -150,6 +151,7 @@ def __init__(
150151 self .port = port
151152 self .enable_graph = enable_graph
152153 self .attn_backend = attn_backend
154+ self .ignore_eos = ignore_eos
153155
154156 self .engine : AsyncLLMEngine = None
155157
@@ -331,6 +333,7 @@ def pick(key: str, default):
331333 top_k = int (pick ("top_k" , self .top_k )),
332334 max_tokens = int (max_tokens ) if max_tokens is not None else None ,
333335 stop = stop ,
336+ ignore_eos = self .ignore_eos ,
334337 )
335338
336339 async def _stream_chat (self , request_id : str , data : dict , http_request : Request ):
@@ -382,7 +385,11 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request)
382385 # Skip EOS token text for OpenAI API compatibility
383386 # Check if this token is an EOS token by comparing token_id with eos_token_ids
384387 eos_token_ids = self .engine .engine .eos_token_ids
385- is_eos_token = eos_token_ids and token_output .token_id in eos_token_ids
388+ is_eos_token = (
389+ not sampling_params .ignore_eos
390+ and eos_token_ids
391+ and token_output .token_id in eos_token_ids
392+ )
386393
387394 if not is_eos_token and token_output .token_text :
388395 # Send token
@@ -631,6 +638,13 @@ def parse_args():
631638 choices = ["DEBUG" , "INFO" , "WARNING" , "ERROR" , "CRITICAL" ],
632639 help = "Logging level" ,
633640 )
641+ parser .add_argument (
642+ "--ignore-eos" ,
643+ action = "store_true" ,
644+ dest = "ignore_eos" ,
645+ default = False ,
646+ help = "Ignore EOS token and continue generation" ,
647+ )
634648
635649 return parser .parse_args ()
636650
@@ -688,6 +702,7 @@ def main():
688702 port = args .port ,
689703 enable_graph = args .enable_graph ,
690704 attn_backend = args .attn ,
705+ ignore_eos = args .ignore_eos ,
691706 )
692707 server .start ()
693708
0 commit comments