|
29 | 29 | REDIS_RETRIEVER_PORT = int(os.getenv("REDIS_RETRIEVER_PORT", 7000)) |
30 | 30 | TEI_EMBEDDING_HOST_IP = os.getenv("TEI_EMBEDDING_HOST_IP", "0.0.0.0") |
31 | 31 | EMBEDDER_PORT = int(os.getenv("EMBEDDER_PORT", 6000)) |
| 32 | +LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "Qwen/Qwen2.5-Coder-7B-Instruct") |
32 | 33 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None) |
33 | 34 |
|
34 | 35 | grader_prompt = """You are a grader assessing relevance of a retrieved document to a user question. \n |
@@ -67,11 +68,22 @@ def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **k |
67 | 68 | inputs["input"] = inputs["query"] |
68 | 69 |
|
69 | 70 | # Check if the current service type is RETRIEVER |
70 | | - if self.services[cur_node].service_type == ServiceType.RETRIEVER: |
| 71 | + elif self.services[cur_node].service_type == ServiceType.RETRIEVER: |
71 | 72 | # Extract the embedding from the inputs |
72 | 73 | embedding = inputs["data"][0]["embedding"] |
73 | 74 | # Align the inputs for the retriever service |
74 | 75 | inputs = {"index_name": llm_parameters_dict["index_name"], "text": self.input_query, "embedding": embedding} |
| 76 | + elif self.services[cur_node].service_type == ServiceType.LLM: |
| 77 | + # convert TGI/vLLM to unified OpenAI /v1/chat/completions format |
| 78 | + next_inputs = {} |
| 79 | + next_inputs["model"] = LLM_MODEL_ID |
| 80 | + next_inputs["messages"] = [{"role": "user", "content": inputs["query"]}] |
| 81 | + next_inputs["max_tokens"] = llm_parameters_dict["max_tokens"] |
| 82 | + next_inputs["top_p"] = llm_parameters_dict["top_p"] |
| 83 | + next_inputs["stream"] = inputs["stream"] |
| 84 | + next_inputs["frequency_penalty"] = inputs["frequency_penalty"] |
| 85 | + next_inputs["temperature"] = inputs["temperature"] |
| 86 | + inputs = next_inputs |
75 | 87 |
|
76 | 88 | return inputs |
77 | 89 |
|
|
0 commit comments