diff --git a/DocIndexRetriever/README.md b/DocIndexRetriever/README.md index 9ea6c4e37f..12e5e6d109 100644 --- a/DocIndexRetriever/README.md +++ b/DocIndexRetriever/README.md @@ -80,9 +80,22 @@ Example usage: ```python url = "http://{host_ip}:{port}/v1/retrievaltool".format(host_ip=host_ip, port=port) payload = { - "messages": query, + "messages": query, # must be a string, this is a required field "k": 5, # retriever top k "top_n": 2, # reranker top n } response = requests.post(url, json=payload) ``` + +**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can be changed are listed below. + + 1. retriever + * search_type: str = "similarity" + * k: int = 4 + * distance_threshold: Optional[float] = None + * fetch_k: int = 20 + * lambda_mult: float = 0.5 + * score_threshold: float = 0.2 + + 2. reranker + * top_n: int = 1 diff --git a/DocIndexRetriever/docker_compose/intel/cpu/xeon/README.md b/DocIndexRetriever/docker_compose/intel/cpu/xeon/README.md index f3cff5263a..9f20546a3d 100644 --- a/DocIndexRetriever/docker_compose/intel/cpu/xeon/README.md +++ b/DocIndexRetriever/docker_compose/intel/cpu/xeon/README.md @@ -97,9 +97,6 @@ Retrieval from KnowledgeBase curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ "messages": "Explain the OPEA project?" }' - -# expected output -{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"} ``` **Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. 
The parameters that can changed are listed below. @@ -128,7 +125,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati # embedding microservice curl http://${host_ip}:6000/v1/embeddings \ -X POST \ - -d '{"text":"Explain the OPEA project"}' \ + -d '{"messages":"Explain the OPEA project"}' \ -H 'Content-Type: application/json' > query docker container logs embedding-server diff --git a/DocIndexRetriever/docker_compose/intel/cpu/xeon/compose.yaml b/DocIndexRetriever/docker_compose/intel/cpu/xeon/compose.yaml index 6a389de450..f05d4ab715 100644 --- a/DocIndexRetriever/docker_compose/intel/cpu/xeon/compose.yaml +++ b/DocIndexRetriever/docker_compose/intel/cpu/xeon/compose.yaml @@ -13,10 +13,11 @@ services: dataprep-redis-service: image: ${REGISTRY:-opea}/dataprep:${TAG:-latest} container_name: dataprep-redis-server - # volumes: - # - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps depends_on: - - redis-vector-db + redis-vector-db: + condition: service_started + tei-embedding-service: + condition: service_healthy ports: - "6007:5000" - "6008:6008" @@ -28,7 +29,7 @@ services: REDIS_URL: ${REDIS_URL} REDIS_HOST: ${REDIS_HOST} INDEX_NAME: ${INDEX_NAME} - TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} + TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} LOGFLAG: ${LOGFLAG} tei-embedding-service: @@ -54,8 +55,6 @@ services: embedding: image: ${REGISTRY:-opea}/embedding:${TAG:-latest} container_name: embedding-server - # volumes: - # - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/comps ports: - "6000:6000" ipc: host @@ -114,8 +113,6 @@ services: reranking: image: ${REGISTRY:-opea}/reranking:${TAG:-latest} container_name: reranking-tei-xeon-server - # volumes: - # - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps depends_on: tei-reranking-service: condition: 
service_healthy diff --git a/DocIndexRetriever/docker_compose/intel/hpu/gaudi/README.md b/DocIndexRetriever/docker_compose/intel/hpu/gaudi/README.md index 12380afcf0..01a4dceb38 100644 --- a/DocIndexRetriever/docker_compose/intel/hpu/gaudi/README.md +++ b/DocIndexRetriever/docker_compose/intel/hpu/gaudi/README.md @@ -87,9 +87,6 @@ Retrieval from KnowledgeBase curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ "messages": "Explain the OPEA project?" }' - -# expected output -{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"} ``` **Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below. 
@@ -118,7 +115,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati # embedding microservice curl http://${host_ip}:6000/v1/embeddings \ -X POST \ - -d '{"text":"Explain the OPEA project"}' \ + -d '{"messages":"Explain the OPEA project"}' \ -H 'Content-Type: application/json' > query docker container logs embedding-server diff --git a/DocIndexRetriever/docker_compose/intel/hpu/gaudi/compose.yaml b/DocIndexRetriever/docker_compose/intel/hpu/gaudi/compose.yaml index 777fbb45ca..5bb9177d5a 100644 --- a/DocIndexRetriever/docker_compose/intel/hpu/gaudi/compose.yaml +++ b/DocIndexRetriever/docker_compose/intel/hpu/gaudi/compose.yaml @@ -15,8 +15,10 @@ services: image: ${REGISTRY:-opea}/dataprep:${TAG:-latest} container_name: dataprep-redis-server depends_on: - - redis-vector-db - - tei-embedding-service + redis-vector-db: + condition: service_started + tei-embedding-service: + condition: service_healthy ports: - "6007:5000" environment: @@ -25,7 +27,7 @@ services: https_proxy: ${https_proxy} REDIS_URL: ${REDIS_URL} INDEX_NAME: ${INDEX_NAME} - TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} + TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} tei-embedding-service: image: ghcr.io/huggingface/tei-gaudi:1.5.0 @@ -87,6 +89,8 @@ services: INDEX_NAME: ${INDEX_NAME} LOGFLAG: ${LOGFLAG} RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_REDIS" + TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} + HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} restart: unless-stopped tei-reranking-service: image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.6 diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index 4639517c14..27e22f105d 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -8,9 +8,8 @@ from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType from comps.cores.proto.api_protocol 
import ChatCompletionRequest, EmbeddingRequest -from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc +from comps.cores.proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc from fastapi import Request -from fastapi.responses import StreamingResponse MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889) EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") @@ -22,41 +21,75 @@ def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs): - print(f"Inputs to {cur_node}: {inputs}") + print(f"*** Inputs to {cur_node}:\n{inputs}") + print("--" * 50) for key, value in kwargs.items(): print(f"{key}: {value}") + if self.services[cur_node].service_type == ServiceType.EMBEDDING: + inputs["input"] = inputs["text"] + del inputs["text"] + elif self.services[cur_node].service_type == ServiceType.RETRIEVER: + # input is EmbedDoc + """Class EmbedDoc(BaseDoc): + + text: Union[str, List[str]] + embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]] + search_type: str = "similarity" + k: int = 4 + distance_threshold: Optional[float] = None + fetch_k: int = 20 + lambda_mult: float = 0.5 + score_threshold: float = 0.2 + constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None + index_name: Optional[str] = None + """ + # prepare the retriever params + retriever_parameters = kwargs.get("retriever_parameters", None) + if retriever_parameters: + inputs.update(retriever_parameters.dict()) + elif self.services[cur_node].service_type == ServiceType.RERANK: + # input is SearchedDoc + """Class SearchedDoc(BaseDoc): + + retrieved_docs: DocList[TextDoc] + initial_query: str + top_n: int = 1 + """ + # prepare the reranker params + reranker_parameters = kwargs.get("reranker_parameters", None) + if reranker_parameters: + inputs.update(reranker_parameters.dict()) + print(f"*** Formatted Inputs to 
{cur_node}:\n{inputs}") + print("--" * 50) return inputs def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs): - next_data = {} - if self.services[cur_node].service_type == ServiceType.EMBEDDING: - # turn into chat completion request - # next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]} - print("Assembing output from Embedding for next node...") - print("Inputs to Embedding: ", inputs) - print("Keyword arguments: ") - for key, value in kwargs.items(): - print(f"{key}: {value}") - - next_data = { - "input": inputs["input"], - "messages": inputs["input"], - "embedding": [item["embedding"] for item in data["data"]], - "k": kwargs["k"] if "k" in kwargs else 4, - "search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity", - "distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None, - "fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20, - "lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5, - "score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2, - "top_n": kwargs["top_n"] if "top_n" in kwargs else 1, - } - - print("Output from Embedding for next node:\n", next_data) + print(f"*** Direct Outputs from {cur_node}:\n{data}") + print("--" * 50) + if self.services[cur_node].service_type == ServiceType.EMBEDDING: + # direct output from Embedding microservice is EmbeddingResponse + """ + class EmbeddingResponse(BaseModel): + object: str = "list" + model: Optional[str] = None + data: List[EmbeddingResponseData] + usage: Optional[UsageInfo] = None + + class EmbeddingResponseData(BaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] + """ + # turn it into EmbedDoc + assert isinstance(data["data"], list) + next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]} # EmbedDoc else: next_data = data + print(f"*** Formatted 
Output from {cur_node} for next node:\n", next_data) + print("--" * 50) return next_data @@ -100,54 +133,41 @@ def add_remote_service(self): self.megaservice.flow_to(retriever, rerank) async def handle_request(self, request: Request): - def parser_input(data, TypeClass, key): - chat_request = None - try: - chat_request = TypeClass.parse_obj(data) - query = getattr(chat_request, key) - except: - query = None - return query, chat_request - data = await request.json() - query = None - for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): - query, chat_request = parser_input(data, TypeClass, key) - if query is not None: - break - if query is None: - raise ValueError(f"Unknown request type: {data}") - if chat_request is None: - raise ValueError(f"Unknown request type: {data}") - - if isinstance(chat_request, ChatCompletionRequest): - initial_inputs = { - "messages": query, - "input": query, # has to be input due to embedding expects either input or text - "search_type": chat_request.search_type if chat_request.search_type else "similarity", - "k": chat_request.k if chat_request.k else 4, - "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None, - "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20, - "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2, - "top_n": chat_request.top_n if chat_request.top_n else 1, - } - - kwargs = { - "search_type": chat_request.search_type if chat_request.search_type else "similarity", - "k": chat_request.k if chat_request.k else 4, - "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None, - "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20, - "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - "score_threshold": 
chat_request.score_threshold if chat_request.score_threshold else 0.2, - "top_n": chat_request.top_n if chat_request.top_n else 1, - } - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs=initial_inputs, - **kwargs, - ) - else: - result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query}) + chat_request = ChatCompletionRequest.parse_obj(data) + + prompt = chat_request.messages + + # dummy llm params + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + chat_template=chat_request.chat_template if chat_request.chat_template else None, + model=chat_request.model if chat_request.model else None, + ) + + retriever_parameters = RetrieverParms( + search_type=chat_request.search_type if chat_request.search_type else "similarity", + k=chat_request.k if chat_request.k else 4, + distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, + fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, + lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, + ) + reranker_parameters = RerankerParms( + top_n=chat_request.top_n if chat_request.top_n else 1, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": prompt}, + llm_parameters=parameters, + retriever_parameters=retriever_parameters, + 
reranker_parameters=reranker_parameters, + ) last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node] diff --git a/DocIndexRetriever/tests/test.py b/DocIndexRetriever/tests/test.py index ba74827fa6..daa7825d59 100644 --- a/DocIndexRetriever/tests/test.py +++ b/DocIndexRetriever/tests/test.py @@ -2,16 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import os +from typing import Any import requests -def search_knowledge_base(query: str) -> str: +def search_knowledge_base(query: str, args: Any) -> str: """Search the knowledge base for a specific query.""" - url = os.environ.get("RETRIEVAL_TOOL_URL") + url = os.environ.get("RETRIEVAL_TOOL_URL", "http://localhost:8889/v1/retrievaltool") print(url) proxies = {"http": ""} - payload = {"messages": query, "k": 5, "top_n": 2} + payload = {"messages": query, "k": args.k, "top_n": args.top_n} response = requests.post(url, json=payload, proxies=proxies) print(response) if "documents" in response.json(): @@ -33,6 +34,16 @@ def search_knowledge_base(query: str) -> str: if __name__ == "__main__": - resp = search_knowledge_base("What is OPEA?") - # resp = search_knowledge_base("Thriller") + import argparse + + parser = argparse.ArgumentParser(description="Test the knowledge base search.") + parser.add_argument("--k", type=int, default=5, help="retriever top k") + parser.add_argument("--top_n", type=int, default=2, help="reranker top n") + args = parser.parse_args() + + resp = search_knowledge_base("What is OPEA?", args) + print(resp) + + if not resp.startswith("Error"): + print("Test successful!") diff --git a/DocIndexRetriever/tests/test_compose_milvus_on_gaudi.sh b/DocIndexRetriever/tests/test_compose_milvus_on_gaudi.sh index 2e993604e4..40633be8f4 100644 --- a/DocIndexRetriever/tests/test_compose_milvus_on_gaudi.sh +++ b/DocIndexRetriever/tests/test_compose_milvus_on_gaudi.sh @@ -88,10 +88,10 @@ function validate_megaservice() { fi # Curl the Mega Service - echo "================Testing retriever service: Text 
Request ================" + echo "================Testing retriever service ================" cd $WORKPATH/tests local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ - "text": "Explain the OPEA project?" + "messages": "Explain the OPEA project?" }') local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-gaudi") diff --git a/DocIndexRetriever/tests/test_compose_milvus_on_xeon.sh b/DocIndexRetriever/tests/test_compose_milvus_on_xeon.sh index 70bdbc8ce2..59b1c40aa0 100755 --- a/DocIndexRetriever/tests/test_compose_milvus_on_xeon.sh +++ b/DocIndexRetriever/tests/test_compose_milvus_on_xeon.sh @@ -87,10 +87,10 @@ function validate_megaservice() { fi # Curl the Mega Service - echo "================Testing retriever service: Text Request ================" + echo "================Testing retriever service ================" cd $WORKPATH/tests local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ - "text": "Explain the OPEA project?" + "messages": "Explain the OPEA project?" 
}') local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-xeon") diff --git a/DocIndexRetriever/tests/test_compose_on_gaudi.sh b/DocIndexRetriever/tests/test_compose_on_gaudi.sh index d6dd8a7138..4b9de5a2df 100644 --- a/DocIndexRetriever/tests/test_compose_on_gaudi.sh +++ b/DocIndexRetriever/tests/test_compose_on_gaudi.sh @@ -38,7 +38,6 @@ function start_services() { export RERANK_MODEL_ID="BAAI/bge-reranker-base" export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:8090" export TEI_RERANKING_ENDPOINT="http://${ip_address}:8808" - export TGI_LLM_ENDPOINT="http://${ip_address}:8008" export REDIS_URL="redis://${ip_address}:6379" export INDEX_NAME="rag-redis" export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} @@ -46,14 +45,13 @@ function start_services() { export EMBEDDING_SERVICE_HOST_IP=${ip_address} export RETRIEVER_SERVICE_HOST_IP=${ip_address} export RERANK_SERVICE_HOST_IP=${ip_address} - export LLM_SERVICE_HOST_IP=${ip_address} export host_ip=${ip_address} export RERANK_TYPE="tei" export LOGFLAG=true # Start Docker Containers docker compose up -d - sleep 30 + sleep 1m echo "Docker services started!" } @@ -86,11 +84,13 @@ function validate_megaservice() { fi # Curl the Mega Service - echo "==============Testing retriever service: Text Request=================" - local CONTENT=$(curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ - "text": "Explain the OPEA project?" + echo "==============Testing retriever service=================" + local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ + "messages": "Explain the OPEA project?" 
}') + local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-gaudi") + echo "$EXIT_CODE" local EXIT_CODE="${EXIT_CODE:0-1}" echo "return value is $EXIT_CODE" diff --git a/DocIndexRetriever/tests/test_compose_on_xeon.sh b/DocIndexRetriever/tests/test_compose_on_xeon.sh index 06d2ef3dcd..467411653c 100644 --- a/DocIndexRetriever/tests/test_compose_on_xeon.sh +++ b/DocIndexRetriever/tests/test_compose_on_xeon.sh @@ -53,7 +53,7 @@ function start_services() { # Start Docker Containers docker compose up -d - sleep 5m + sleep 1m echo "Docker services started!" } @@ -86,10 +86,11 @@ function validate_megaservice() { fi # Curl the Mega Service - echo "================Testing retriever service: Text Request ================" + echo "================Testing retriever service ================" cd $WORKPATH/tests + local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ - "text": "Explain the OPEA project?" + "messages": "Explain the OPEA project?" 
}') local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-xeon") @@ -128,6 +129,7 @@ function main() { if [[ "$IMAGE_REPO" == "opea" ]]; then build_docker_images; fi echo "Dump current docker ps" docker ps + start_time=$(date +%s) start_services end_time=$(date +%s) diff --git a/DocIndexRetriever/tests/test_compose_without_rerank_on_xeon.sh b/DocIndexRetriever/tests/test_compose_without_rerank_on_xeon.sh index 94b298231f..c0e32c4e93 100644 --- a/DocIndexRetriever/tests/test_compose_without_rerank_on_xeon.sh +++ b/DocIndexRetriever/tests/test_compose_without_rerank_on_xeon.sh @@ -80,10 +80,10 @@ function validate_megaservice() { fi # Curl the Mega Service - echo "================Testing retriever service: Text Request ================" + echo "================Testing retriever service ================" cd $WORKPATH/tests local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{ - "text": "Explain the OPEA project?" + "messages": "Explain the OPEA project?" }') local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-xeon")