@@ -109,6 +109,7 @@ def __init__(
109109 port : int = 8000 ,
110110 enable_graph : bool = False ,
111111 attn_backend : str = "default" ,
112+ ignore_eos : bool = False ,
112113 ):
113114 """Initialize inference server.
114115
@@ -150,6 +151,7 @@ def __init__(
150151 self .port = port
151152 self .enable_graph = enable_graph
152153 self .attn_backend = attn_backend
154+ self .ignore_eos = ignore_eos
153155
154156 self .engine : AsyncLLMEngine = None
155157
@@ -331,6 +333,7 @@ def pick(key: str, default):
331333 top_k = int (pick ("top_k" , self .top_k )),
332334 max_tokens = int (max_tokens ) if max_tokens is not None else None ,
333335 stop = stop ,
336+ ignore_eos = self .ignore_eos ,
334337 )
335338
336339 async def _stream_chat (self , request_id : str , data : dict , http_request : Request ):
@@ -382,7 +385,11 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request)
382385 # Skip EOS token text for OpenAI API compatibility
383386 # Check if this token is an EOS token by comparing token_id with eos_token_ids
384387 eos_token_ids = self .engine .engine .eos_token_ids
385- is_eos_token = eos_token_ids and token_output .token_id in eos_token_ids
388+ is_eos_token = (
389+ not sampling_params .ignore_eos
390+ and eos_token_ids
391+ and token_output .token_id in eos_token_ids
392+ )
386393
387394 if not is_eos_token and token_output .token_text :
388395 # Send token
@@ -631,6 +638,13 @@ def parse_args():
631638 choices = ["DEBUG" , "INFO" , "WARNING" , "ERROR" , "CRITICAL" ],
632639 help = "Logging level" ,
633640 )
641+ parser .add_argument (
642+ "--ignore-eos" ,
643+ action = "store_true" ,
644+ dest = "ignore_eos" ,
645+ default = False ,
646+ help = "Ignore EOS token and continue generation" ,
647+ )
634648
635649 return parser .parse_args ()
636650
@@ -644,21 +658,22 @@ def main():
644658 server = InferenceServer (
645- model_path = cfg .model ,
659+ model_path = args .model ,
646660 device = device ,
647- dtype = cfg .dtype ,
648- tensor_parallel_size = cfg .tp ,
649- cache_type = cfg .cache_type ,
650- max_tokens = cfg .max_tokens ,
651- max_batch_size = cfg .max_batch_size ,
652- num_blocks = cfg .num_blocks ,
653- block_size = cfg .block_size ,
654- max_cache_len = cfg .max_cache_len ,
655- temperature = cfg .temperature ,
656- top_p = cfg .top_p ,
657- top_k = cfg .top_k ,
658- host = cfg .host ,
659- port = cfg .port ,
660- enable_graph = cfg .enable_graph ,
661- attn_backend = cfg .attn ,
661+ dtype = args .dtype ,
662+ tensor_parallel_size = args .tp ,
663+ cache_type = args .cache_type ,
664+ max_tokens = args .max_tokens ,
665+ max_batch_size = args .max_batch_size ,
666+ num_blocks = args .num_blocks ,
667+ block_size = args .block_size ,
668+ max_cache_len = args .max_cache_len ,
669+ temperature = args .temperature ,
670+ top_p = args .top_p ,
671+ top_k = args .top_k ,
672+ host = args .host ,
673+ port = args .port ,
674+ enable_graph = args .enable_graph ,
675+ attn_backend = args .attn ,
676+ ignore_eos = args .ignore_eos ,
662677 )
663678 server .start ()
664679
0 commit comments