Skip to content

Commit fca3d76

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 688c93b commit fca3d76

56 files changed

Lines changed: 722 additions & 831 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

EdgeCraftRAG/edgecraftrag/api/v1/agent.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import json
55
import os
66
import time
7+
78
from edgecraftrag.api_schema import AgentCreateIn
89
from edgecraftrag.base import AgentType
10+
from edgecraftrag.config_repository import MilvusConfigRepository, save_agent_configurations
911
from edgecraftrag.context import ctx
1012
from edgecraftrag.env import AGENT_FILE
11-
from edgecraftrag.config_repository import MilvusConfigRepository, save_agent_configurations
1213
from fastapi import FastAPI, HTTPException, status
1314

1415
agent_app = FastAPI()
@@ -21,14 +22,16 @@ async def get_all_agents():
2122
agents = ctx.get_agent_mgr().get_agents()
2223
active_id = ctx.get_agent_mgr().get_active_agent_id()
2324
for k, agent in agents.items():
24-
out.append(AgentCreateIn(
25-
idx=agent.idx,
26-
name=agent.name,
27-
type=agent.comp_subtype,
28-
pipeline_idx=agent.pipeline_idx,
29-
configs=agent.configs,
30-
active=True if agent.idx == active_id else False
31-
))
25+
out.append(
26+
AgentCreateIn(
27+
idx=agent.idx,
28+
name=agent.name,
29+
type=agent.comp_subtype,
30+
pipeline_idx=agent.pipeline_idx,
31+
configs=agent.configs,
32+
active=True if agent.idx == active_id else False,
33+
)
34+
)
3235
return out
3336

3437

@@ -44,7 +47,7 @@ async def get_agent(name):
4447
type=agent.comp_subtype,
4548
pipeline_idx=agent.pipeline_idx,
4649
configs=agent.configs,
47-
active=isactive
50+
active=isactive,
4851
)
4952
else:
5053
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -95,7 +98,7 @@ async def delete_agent(name):
9598
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
9699

97100

98-
# GET Agent Type defualt configs
101+
# GET Agent Type default configs
99102
@agent_app.get(path="/v1/settings/agents/configs/{agent_type}")
100103
async def get_agent_default_configs(agent_type):
101104
try:

EdgeCraftRAG/edgecraftrag/api/v1/chatqna.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import requests
5-
import json
64
import asyncio
7-
from typing import List
5+
import json
86
from concurrent.futures import ThreadPoolExecutor
7+
from typing import List
8+
9+
import requests
910
from comps.cores.proto.api_protocol import ChatCompletionRequest
1011
from edgecraftrag.api_schema import RagOut
1112
from edgecraftrag.context import ctx
12-
from edgecraftrag.utils import serialize_contexts, stream_generator, chain_async_generators
13+
from edgecraftrag.utils import chain_async_generators, serialize_contexts, stream_generator
1314
from fastapi import Body, FastAPI, HTTPException, status
1415
from fastapi.responses import StreamingResponse
1516

@@ -25,7 +26,10 @@ async def retrieval(request: ChatCompletionRequest):
2526
if active_kb:
2627
request.user = active_kb
2728
else:
28-
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Retrieval needs to have an active knowledgebase")
29+
raise HTTPException(
30+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
31+
detail="Retrieval needs to have an active knowledgebase",
32+
)
2933
contexts = await ctx.get_pipeline_mgr().run_retrieve_postprocess(chat_request=request)
3034
serialized_contexts = serialize_contexts(contexts)
3135

@@ -63,16 +67,10 @@ async def chatqna(request: ChatCompletionRequest):
6367
request.model = generator.model_id
6468

6569
if request.stream:
66-
run_pipeline_gen, contexts = await ctx.get_pipeline_mgr().run_pipeline(
67-
chat_request=request
68-
)
69-
return StreamingResponse(
70-
save_session(sessionid, run_pipeline_gen), media_type="text/plain"
71-
)
70+
run_pipeline_gen, contexts = await ctx.get_pipeline_mgr().run_pipeline(chat_request=request)
71+
return StreamingResponse(save_session(sessionid, run_pipeline_gen), media_type="text/plain")
7272
else:
73-
ret, contexts = await ctx.get_pipeline_mgr().run_pipeline(
74-
chat_request=request
75-
)
73+
ret, contexts = await ctx.get_pipeline_mgr().run_pipeline(chat_request=request)
7674
ctx.get_session_mgr().save_current_message(sessionid, "assistant", str(ret))
7775
return str(ret)
7876

@@ -105,11 +103,11 @@ async def res_gen_json():
105103
yield token.replace("\n", "\\n")
106104

107105
# Reconstruct RagOut in stream response
108-
query_gen = stream_generator("{\"query\":\"" + request.messages + "\",")
106+
query_gen = stream_generator('{"query":"' + request.messages + '",')
109107

110108
s_contexts = json.dumps(serialize_contexts(contexts))
111-
context_gen = stream_generator("\"contexts\":" + s_contexts + ",\"response\":\"")
112-
final_gen = stream_generator("\"}")
109+
context_gen = stream_generator('"contexts":' + s_contexts + ',"response":"')
110+
final_gen = stream_generator('"}')
113111
output_gen = chain_async_generators([query_gen, context_gen, res_gen_json(), final_gen])
114112

115113
return StreamingResponse(output_gen, media_type="text/plain")

EdgeCraftRAG/edgecraftrag/api/v1/data.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import json
45
import os
56
from typing import List
67

78
from edgecraftrag.api_schema import DataIn, FilesIn
8-
from edgecraftrag.context import ctx
9-
from fastapi import FastAPI, File, HTTPException, UploadFile, status
10-
import json
119
from edgecraftrag.config_repository import MilvusConfigRepository
10+
from edgecraftrag.context import ctx
1211
from edgecraftrag.env import UI_DIRECTORY
12+
from fastapi import FastAPI, File, HTTPException, UploadFile, status
1313

1414
data_app = FastAPI()
1515

@@ -52,22 +52,22 @@ async def redindex_data():
5252
# Gets the current nodelist
5353
@data_app.get(path="/v1/data/nodes")
5454
async def get_nodes_with_kb(kb_name=None):
55-
node_lists = {}
55+
node_lists = {}
5656
active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
5757
if kb_name:
5858
kb = ctx.get_knowledge_mgr().get_knowledge_base_by_name_or_id(kb_name)
5959
else:
6060
kb = ctx.get_knowledge_mgr().get_active_knowledge_base()
6161
if active_pl.indexer.comp_subtype == "faiss_vector":
62-
return active_pl.indexer.docstore.docs
62+
return active_pl.indexer.docstore.docs
6363
elif active_pl.indexer.comp_subtype == "milvus_vector":
6464
collection_name = kb.name + active_pl.name
65-
Milvus_node_list = MilvusConfigRepository.create_connection(collection_name,1, active_pl.indexer.vector_url)
65+
Milvus_node_list = MilvusConfigRepository.create_connection(collection_name, 1, active_pl.indexer.vector_url)
6666
results = Milvus_node_list.get_configs(output_fields=["text", "_node_content", "doc_id"])
6767
for node_list in results:
6868
text = node_list.get("text")
69-
node_content = json.loads(node_list.get("_node_content"))
70-
node_content["doc_id"]=node_list.get("doc_id")
69+
node_content = json.loads(node_list.get("_node_content"))
70+
node_content["doc_id"] = node_list.get("doc_id")
7171
node_content["text"] = text
7272
node_lists[node_content.get("id_")] = node_content
7373
return node_lists
@@ -81,12 +81,10 @@ async def get_nodes_by_document_name(document_name: str):
8181
all_nodes = await get_nodes_with_kb()
8282
matching_nodes = []
8383
for node in all_nodes.values() if isinstance(all_nodes, dict) else all_nodes:
84-
metadata = node.get('metadata', {}) if isinstance(node, dict) else getattr(node, 'metadata', {})
85-
node_file_name = metadata.get('file_name', '')
86-
node_file_path = metadata.get('file_path', '')
87-
if (node_file_name == document_name or
88-
document_name in node_file_name or
89-
document_name in node_file_path):
84+
metadata = node.get("metadata", {}) if isinstance(node, dict) else getattr(node, "metadata", {})
85+
node_file_name = metadata.get("file_name", "")
86+
node_file_path = metadata.get("file_path", "")
87+
if node_file_name == document_name or document_name in node_file_name or document_name in node_file_path:
9088
matching_nodes.append(node)
9189
return matching_nodes
9290

@@ -100,23 +98,20 @@ async def get_document_names():
10098

10199
documents = {}
102100
for node in all_nodes.values() if isinstance(all_nodes, dict) else all_nodes:
103-
metadata = node.get('metadata', {}) if isinstance(node, dict) else getattr(node, 'metadata', {})
104-
file_name = metadata.get('file_name')
105-
file_path = metadata.get('file_path')
101+
metadata = node.get("metadata", {}) if isinstance(node, dict) else getattr(node, "metadata", {})
102+
file_name = metadata.get("file_name")
103+
file_path = metadata.get("file_path")
106104
if file_name and file_name not in documents:
107105
documents[file_name] = {
108106
"file_name": file_name,
109107
"file_path": file_path,
110-
"file_type": metadata.get('file_type', 'unknown'),
111-
"chunk_count": 0
108+
"file_type": metadata.get("file_type", "unknown"),
109+
"chunk_count": 0,
112110
}
113111
if file_name:
114112
documents[file_name]["chunk_count"] += 1
115113

116-
return {
117-
"total_documents": len(documents),
118-
"documents": list(documents.values())
119-
}
114+
return {"total_documents": len(documents), "documents": list(documents.values())}
120115

121116

122117
# Upload files by a list of file_path

0 commit comments

Comments
 (0)