Skip to content

Commit d28f516

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7ac16dd commit d28f516

56 files changed

Lines changed: 731 additions & 851 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

EdgeCraftRAG/edgecraftrag/api/v1/agent.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import json
55
import os
66
import time
7+
78
from edgecraftrag.api_schema import AgentCreateIn
89
from edgecraftrag.base import AgentType
10+
from edgecraftrag.config_repository import MilvusConfigRepository, save_agent_configurations
911
from edgecraftrag.context import ctx
1012
from edgecraftrag.env import AGENT_FILE
11-
from edgecraftrag.config_repository import MilvusConfigRepository, save_agent_configurations
1213
from fastapi import FastAPI, HTTPException, status
1314

1415
agent_app = FastAPI()
@@ -21,14 +22,16 @@ async def get_all_agents():
2122
agents = ctx.get_agent_mgr().get_agents()
2223
active_id = ctx.get_agent_mgr().get_active_agent_id()
2324
for k, agent in agents.items():
24-
out.append(AgentCreateIn(
25-
idx=agent.idx,
26-
name=agent.name,
27-
type=agent.comp_subtype,
28-
pipeline_idx=agent.pipeline_idx,
29-
configs=agent.configs,
30-
active=True if agent.idx == active_id else False
31-
))
25+
out.append(
26+
AgentCreateIn(
27+
idx=agent.idx,
28+
name=agent.name,
29+
type=agent.comp_subtype,
30+
pipeline_idx=agent.pipeline_idx,
31+
configs=agent.configs,
32+
active=True if agent.idx == active_id else False,
33+
)
34+
)
3235
return out
3336

3437

@@ -44,7 +47,7 @@ async def get_agent(name):
4447
type=agent.comp_subtype,
4548
pipeline_idx=agent.pipeline_idx,
4649
configs=agent.configs,
47-
active=isactive
50+
active=isactive,
4851
)
4952
else:
5053
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -95,7 +98,7 @@ async def delete_agent(name):
9598
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
9699

97100

98-
# GET Agent Type defualt configs
101+
# GET Agent Type default configs
99102
@agent_app.get(path="/v1/settings/agents/configs/{agent_type}")
100103
async def get_agent_default_configs(agent_type):
101104
try:

EdgeCraftRAG/edgecraftrag/api/v1/chatqna.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import requests
5-
import json
64
import asyncio
7-
from typing import List
5+
import json
86
from concurrent.futures import ThreadPoolExecutor
7+
from typing import List
8+
9+
import requests
910
from comps.cores.proto.api_protocol import ChatCompletionRequest
1011
from edgecraftrag.api_schema import RagOut
1112
from edgecraftrag.context import ctx
12-
from edgecraftrag.utils import serialize_contexts, stream_generator, chain_async_generators
13+
from edgecraftrag.utils import chain_async_generators, serialize_contexts, stream_generator
1314
from fastapi import Body, FastAPI, HTTPException, status
1415
from fastapi.responses import StreamingResponse
1516

@@ -25,7 +26,10 @@ async def retrieval(request: ChatCompletionRequest):
2526
if active_kb:
2627
request.user = active_kb
2728
else:
28-
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Retrieval needs to have an active knowledgebase")
29+
raise HTTPException(
30+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
31+
detail="Retrieval needs to have an active knowledgebase",
32+
)
2933
contexts = await ctx.get_pipeline_mgr().run_retrieve_postprocess(chat_request=request)
3034
serialized_contexts = serialize_contexts(contexts)
3135

@@ -63,16 +67,10 @@ async def chatqna(request: ChatCompletionRequest):
6367
request.model = generator.model_id
6468

6569
if request.stream:
66-
run_pipeline_gen, contexts = await ctx.get_pipeline_mgr().run_pipeline(
67-
chat_request=request
68-
)
69-
return StreamingResponse(
70-
save_session(sessionid, run_pipeline_gen), media_type="text/plain"
71-
)
70+
run_pipeline_gen, contexts = await ctx.get_pipeline_mgr().run_pipeline(chat_request=request)
71+
return StreamingResponse(save_session(sessionid, run_pipeline_gen), media_type="text/plain")
7272
else:
73-
ret, contexts = await ctx.get_pipeline_mgr().run_pipeline(
74-
chat_request=request
75-
)
73+
ret, contexts = await ctx.get_pipeline_mgr().run_pipeline(chat_request=request)
7674
ctx.get_session_mgr().save_current_message(sessionid, "assistant", str(ret))
7775
return str(ret)
7876

@@ -105,11 +103,11 @@ async def res_gen_json():
105103
yield token.replace("\n", "\\n")
106104

107105
# Reconstruct RagOut in stream response
108-
query_gen = stream_generator("{\"query\":\"" + request.messages + "\",")
106+
query_gen = stream_generator('{"query":"' + request.messages + '",')
109107

110108
s_contexts = json.dumps(serialize_contexts(contexts))
111-
context_gen = stream_generator("\"contexts\":" + s_contexts + ",\"response\":\"")
112-
final_gen = stream_generator("\"}")
109+
context_gen = stream_generator('"contexts":' + s_contexts + ',"response":"')
110+
final_gen = stream_generator('"}')
113111
output_gen = chain_async_generators([query_gen, context_gen, res_gen_json(), final_gen])
114112

115113
return StreamingResponse(output_gen, media_type="text/plain")

EdgeCraftRAG/edgecraftrag/api/v1/data.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import json
45
import os
56
from typing import List
67

78
from edgecraftrag.api_schema import DataIn, FilesIn
8-
from edgecraftrag.context import ctx
9-
from fastapi import FastAPI, File, HTTPException, UploadFile, status
10-
import json
119
from edgecraftrag.config_repository import MilvusConfigRepository
10+
from edgecraftrag.context import ctx
1211
from edgecraftrag.env import UI_DIRECTORY
12+
from fastapi import FastAPI, File, HTTPException, UploadFile, status
1313

1414
data_app = FastAPI()
1515

@@ -52,22 +52,22 @@ async def redindex_data():
5252
# Gets the current nodelist
5353
@data_app.get(path="/v1/data/nodes")
5454
async def get_nodes_with_kb(kb_name=None):
55-
node_lists = {}
55+
node_lists = {}
5656
active_pl = ctx.get_pipeline_mgr().get_active_pipeline()
5757
if kb_name:
5858
kb = ctx.get_knowledge_mgr().get_knowledge_base_by_name_or_id(kb_name)
5959
else:
6060
kb = ctx.get_knowledge_mgr().get_active_knowledge_base()
6161
if active_pl.indexer.comp_subtype == "faiss_vector":
62-
return active_pl.indexer.docstore.docs
62+
return active_pl.indexer.docstore.docs
6363
elif active_pl.indexer.comp_subtype == "milvus_vector":
6464
collection_name = kb.name + active_pl.name
65-
Milvus_node_list = MilvusConfigRepository.create_connection(collection_name,1, active_pl.indexer.vector_url)
65+
Milvus_node_list = MilvusConfigRepository.create_connection(collection_name, 1, active_pl.indexer.vector_url)
6666
results = Milvus_node_list.get_configs(output_fields=["text", "_node_content", "doc_id"])
6767
for node_list in results:
6868
text = node_list.get("text")
69-
node_content = json.loads(node_list.get("_node_content"))
70-
node_content["doc_id"]=node_list.get("doc_id")
69+
node_content = json.loads(node_list.get("_node_content"))
70+
node_content["doc_id"] = node_list.get("doc_id")
7171
node_content["text"] = text
7272
node_lists[node_content.get("id_")] = node_content
7373
return node_lists
@@ -81,12 +81,10 @@ async def get_nodes_by_document_name(document_name: str):
8181
all_nodes = await get_nodes_with_kb()
8282
matching_nodes = []
8383
for node in all_nodes.values() if isinstance(all_nodes, dict) else all_nodes:
84-
metadata = node.get('metadata', {}) if isinstance(node, dict) else getattr(node, 'metadata', {})
85-
node_file_name = metadata.get('file_name', '')
86-
node_file_path = metadata.get('file_path', '')
87-
if (node_file_name == document_name or
88-
document_name in node_file_name or
89-
document_name in node_file_path):
84+
metadata = node.get("metadata", {}) if isinstance(node, dict) else getattr(node, "metadata", {})
85+
node_file_name = metadata.get("file_name", "")
86+
node_file_path = metadata.get("file_path", "")
87+
if node_file_name == document_name or document_name in node_file_name or document_name in node_file_path:
9088
matching_nodes.append(node)
9189
return matching_nodes
9290

@@ -100,23 +98,20 @@ async def get_document_names():
10098

10199
documents = {}
102100
for node in all_nodes.values() if isinstance(all_nodes, dict) else all_nodes:
103-
metadata = node.get('metadata', {}) if isinstance(node, dict) else getattr(node, 'metadata', {})
104-
file_name = metadata.get('file_name')
105-
file_path = metadata.get('file_path')
101+
metadata = node.get("metadata", {}) if isinstance(node, dict) else getattr(node, "metadata", {})
102+
file_name = metadata.get("file_name")
103+
file_path = metadata.get("file_path")
106104
if file_name and file_name not in documents:
107105
documents[file_name] = {
108106
"file_name": file_name,
109107
"file_path": file_path,
110-
"file_type": metadata.get('file_type', 'unknown'),
111-
"chunk_count": 0
108+
"file_type": metadata.get("file_type", "unknown"),
109+
"chunk_count": 0,
112110
}
113111
if file_name:
114112
documents[file_name]["chunk_count"] += 1
115113

116-
return {
117-
"total_documents": len(documents),
118-
"documents": list(documents.values())
119-
}
114+
return {"total_documents": len(documents), "documents": list(documents.values())}
120115

121116

122117
# Upload files by a list of file_path
@@ -145,8 +140,7 @@ async def get_files():
145140
async def get_nodes_by_document_name(document_name: str) -> List[dict]:
146141
pl = ctx.get_pipeline_mgr().get_active_pipeline()
147142
if pl is None:
148-
raise HTTPException(
149-
status_code=status.HTTP_404_NOT_FOUND, detail="No active pipeline")
143+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No active pipeline")
150144

151145
nodelist = ctx.get_node_mgr().get_nodes(pl.node_parser.idx)
152146

@@ -155,13 +149,11 @@ async def get_nodes_by_document_name(document_name: str) -> List[dict]:
155149

156150
matching_nodes = []
157151
for node in nodelist:
158-
if hasattr(node, 'metadata') and node.metadata:
159-
node_file_name = node.metadata.get('file_name', '')
160-
node_file_path = node.metadata.get('file_path', '')
152+
if hasattr(node, "metadata") and node.metadata:
153+
node_file_name = node.metadata.get("file_name", "")
154+
node_file_path = node.metadata.get("file_path", "")
161155

162-
if (node_file_name == document_name or
163-
document_name in node_file_name or
164-
document_name in node_file_path):
156+
if node_file_name == document_name or document_name in node_file_name or document_name in node_file_path:
165157
node_dict = node.model_dump()
166158
matching_nodes.append(node_dict)
167159

@@ -173,8 +165,7 @@ async def get_nodes_by_document_name(document_name: str) -> List[dict]:
173165
async def get_node_by_id(node_id: str) -> dict:
174166
pl = ctx.get_pipeline_mgr().get_active_pipeline()
175167
if pl is None:
176-
raise HTTPException(
177-
status_code=status.HTTP_404_NOT_FOUND, detail="No active pipeline")
168+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No active pipeline")
178169

179170
nodelist = ctx.get_node_mgr().get_nodes(pl.node_parser.idx)
180171

@@ -193,8 +184,7 @@ async def get_node_by_id(node_id: str) -> dict:
193184
async def get_document_names():
194185
pl = ctx.get_pipeline_mgr().get_active_pipeline()
195186
if pl is None:
196-
raise HTTPException(
197-
status_code=status.HTTP_404_NOT_FOUND, detail="No active pipeline")
187+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No active pipeline")
198188

199189
nodelist = ctx.get_node_mgr().get_nodes(pl.node_parser.idx)
200190

@@ -203,25 +193,22 @@ async def get_document_names():
203193

204194
documents = {}
205195
for node in nodelist:
206-
if hasattr(node, 'metadata') and node.metadata:
207-
file_name = node.metadata.get('file_name')
208-
file_path = node.metadata.get('file_path')
196+
if hasattr(node, "metadata") and node.metadata:
197+
file_name = node.metadata.get("file_name")
198+
file_path = node.metadata.get("file_path")
209199

210200
if file_name and file_name not in documents:
211201
documents[file_name] = {
212202
"file_name": file_name,
213203
"file_path": file_path,
214-
"file_type": node.metadata.get('file_type', 'unknown'),
215-
"chunk_count": 0
204+
"file_type": node.metadata.get("file_type", "unknown"),
205+
"chunk_count": 0,
216206
}
217207

218208
if file_name:
219209
documents[file_name]["chunk_count"] += 1
220210

221-
return {
222-
"total_documents": len(documents),
223-
"documents": list(documents.values())
224-
}
211+
return {"total_documents": len(documents), "documents": list(documents.values())}
225212

226213

227214
# GET a file

0 commit comments

Comments
 (0)