Skip to content

Commit 0591b82

Browse files
xiguiw authored and cogniware-devops committed
CodeGen update input prompt template (opea-project#1997)
Signed-off-by: Wang, Xigui <xigui.wang@intel.com>
Signed-off-by: cogniware-devops <ambarish.desai@cogniware.ai>
1 parent 25cb9f9 commit 0591b82

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

CodeGen/codegen.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
REDIS_RETRIEVER_PORT = int(os.getenv("REDIS_RETRIEVER_PORT", 7000))
3030
TEI_EMBEDDING_HOST_IP = os.getenv("TEI_EMBEDDING_HOST_IP", "0.0.0.0")
3131
EMBEDDER_PORT = int(os.getenv("EMBEDDER_PORT", 6000))
32+
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "Qwen/Qwen2.5-Coder-7B-Instruct")
3233
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None)
3334

3435
grader_prompt = """You are a grader assessing relevance of a retrieved document to a user question. \n
@@ -67,11 +68,22 @@ def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **k
6768
inputs["input"] = inputs["query"]
6869

6970
# Check if the current service type is RETRIEVER
70-
if self.services[cur_node].service_type == ServiceType.RETRIEVER:
71+
elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
7172
# Extract the embedding from the inputs
7273
embedding = inputs["data"][0]["embedding"]
7374
# Align the inputs for the retriever service
7475
inputs = {"index_name": llm_parameters_dict["index_name"], "text": self.input_query, "embedding": embedding}
76+
elif self.services[cur_node].service_type == ServiceType.LLM:
77+
# convert TGI/vLLM to unified OpenAI /v1/chat/completions format
78+
next_inputs = {}
79+
next_inputs["model"] = LLM_MODEL_ID
80+
next_inputs["messages"] = [{"role": "user", "content": inputs["query"]}]
81+
next_inputs["max_tokens"] = llm_parameters_dict["max_tokens"]
82+
next_inputs["top_p"] = llm_parameters_dict["top_p"]
83+
next_inputs["stream"] = inputs["stream"]
84+
next_inputs["frequency_penalty"] = inputs["frequency_penalty"]
85+
next_inputs["temperature"] = inputs["temperature"]
86+
inputs = next_inputs
7587

7688
return inputs
7789

0 commit comments

Comments
 (0)