@@ -393,16 +393,51 @@ def __init__(self, llm_model, inference_type, vllm_endpoint, **kwargs):
393393 self .model_path = llm_instance .model_path
394394
395395 self .llm = llm_model
396+ if self .inference_type == InferenceType .LOCAL :
397+ self .lock = asyncio .Lock ()
396398 if self .inference_type == InferenceType .VLLM :
397399 self .vllm_name = llm_model ().model_id
398400 if vllm_endpoint == "" :
399401 vllm_endpoint = os .getenv ("vLLM_ENDPOINT" , "http://localhost:8086" )
400402 self .vllm_endpoint = vllm_endpoint
401403
async def run(self, chat_request, retrieved_nodes, node_parser_type, **kwargs):
    """Route a chat request to the handler for the configured inference type.

    Args:
        chat_request: The incoming chat request; forwarded unchanged.
        retrieved_nodes: Retrieved context nodes; forwarded unchanged.
        node_parser_type: Node parser identifier; forwarded unchanged.
        **kwargs: Extra keyword arguments; forwarded unchanged.

    Returns:
        Whatever the selected handler (``run_local`` or ``run_vllm``) returns.

    Raises:
        ValueError: If ``self.inference_type`` is neither LOCAL nor VLLM.
    """
    mode = self.inference_type
    if mode == InferenceType.LOCAL:
        handler = self.run_local
    elif mode == InferenceType.VLLM:
        handler = self.run_vllm
    else:
        raise ValueError("LLM inference_type not supported")
    return await handler(chat_request, retrieved_nodes, node_parser_type, **kwargs)
405412
async def run_local(self, chat_request, retrieved_nodes, node_parser_type, **kwargs):
    """Run the request against the locally loaded LLM.

    The sampling parameters and token limit from ``chat_request`` are written
    onto the shared LLM object, and the request is rendered to a prompt with
    ``chatcompletion_to_chatml``.

    Args:
        chat_request: Request carrying ``temperature``, ``top_p``, ``top_k``,
            ``typical_p``, ``repetition_penalty``, ``max_tokens`` and ``stream``.
        retrieved_nodes: Not used by this method.
        node_parser_type: Not used by this method.
        **kwargs: Not used by this method.

    Returns:
        An async generator of text chunks when ``chat_request.stream`` is
        truthy (falsy chunks are emitted as ``""``); otherwise the object
        returned by the LLM's ``complete`` call.

    Raises:
        ValueError: If no LLM is currently loaded.
    """
    # Resolve the model once so the None check and every later use refer to
    # the same object.
    llm = self.llm()
    if llm is None:
        # This can happen when the user deletes all LLMs through the RESTful API.
        raise ValueError("No LLM available, please load LLM")

    generate_kwargs = dict(
        temperature=chat_request.temperature,
        do_sample=chat_request.temperature > 0.0,
        top_p=chat_request.top_p,
        top_k=chat_request.top_k,
        typical_p=chat_request.typical_p,
        repetition_penalty=chat_request.repetition_penalty,
    )
    max_new_tokens = chat_request.max_tokens
    prompt_str = chatcompletion_to_chatml(chat_request)

    if chat_request.stream:
        llm.generate_kwargs = generate_kwargs
        llm.max_new_tokens = max_new_tokens

        # NOTE(review): local_stream_generator receives self.lock and is
        # presumed to hold it while streaming - confirm against its definition.
        async def generator():
            async for chunk in local_stream_generator(self.lock, llm, prompt_str, ""):
                yield chunk or ""
                # Give other tasks a chance to run between chunks.
                await asyncio.sleep(0)

        return generator()

    # Take the same lock the streaming path passes along, so a non-streamed
    # completion cannot reconfigure or use the shared model while a stream is
    # suspended between chunks.
    async with self.lock:
        llm.generate_kwargs = generate_kwargs
        llm.max_new_tokens = max_new_tokens
        # complete() is a synchronous call; the event loop is blocked until
        # it returns.
        return llm.complete(prompt_str)
440+
406441 async def run_vllm (self , chat_request , retrieved_nodes , node_parser_type , ** kwargs ):
407442 llm = OpenAILike (
408443 api_key = "fake" ,
0 commit comments