Skip to content

Commit 39e50ce

Browse files
committed
feat: implement RAG management module with knowledge base and file models, and add processing endpoint
1 parent e5132fc commit 39e50ce

9 files changed

Lines changed: 205 additions & 0 deletions

File tree

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Tables of RAG Management Module
3+
"""
4+
import uuid
5+
from sqlalchemy import Column, String, TIMESTAMP, Text, Integer, JSON
6+
from sqlalchemy.sql import func
7+
from app.db.session import Base
8+
9+
10+
class RagKnowledgeBase(Base):
    """ORM model for a knowledge base (table ``t_rag_knowledge_base``).

    Each row describes one knowledge base: its display name, type,
    optional description, the embedding/chat model references, and
    audit columns (creation/update timestamps and actors).
    """

    __tablename__ = "t_rag_knowledge_base"

    # Primary key: a UUID4 string produced in Python when no id is supplied.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )

    # Descriptive fields.
    name = Column(String(255), nullable=False, comment="知识库名称")
    type = Column(String(50), nullable=False, comment="知识库类型")
    description = Column(String(512), nullable=True, comment="知识库描述")

    # Model references; presumably identifiers of model configurations
    # stored elsewhere -- TODO confirm against the model-config table.
    embedding_model = Column(String(255), nullable=False, comment="嵌入模型")
    chat_model = Column(String(255), nullable=True, comment="聊天模型")

    # Audit timestamps: both default to the database's current timestamp;
    # updated_at is additionally refreshed on every UPDATE issued via the ORM.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )

    # Audit actors.
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self):
        """Return a short debugging representation (id, name, type)."""
        return f"<RagKnowledgeBase(id={self.id}, name={self.name}, type={self.type})>"
28+
29+
30+
class RagFile(Base):
    """ORM model for a file attached to a knowledge base (table ``t_rag_file``).

    Each row records one file: which knowledge base it belongs to, its
    name and external file id, chunk count, free-form metadata, processing
    status / error message, and audit columns.
    """

    __tablename__ = "t_rag_file"

    # Primary key: a UUID4 string produced in Python when no id is supplied.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")

    # Id of the owning knowledge base.
    # NOTE(review): this is a plain string column with no ForeignKey or index;
    # confirm whether referential integrity is enforced elsewhere.
    knowledge_base_id = Column(String(36), nullable=False, comment="知识库ID")

    file_name = Column(String(255), nullable=False, comment="文件名")
    file_id = Column(String(255), nullable=False, comment="文件ID")
    chunk_count = Column(Integer, nullable=True, comment="切片数")

    # The database column is called "metadata", but the Python attribute must
    # use a different name because ``metadata`` is reserved on declarative
    # classes by SQLAlchemy.
    file_metadata = Column("metadata", JSON, nullable=True, comment="元数据")

    # Processing state and the error text recorded for a failed file.
    status = Column(String(50), nullable=True, comment="文件状态")
    err_msg = Column(Text, nullable=True, comment="错误信息")

    # Audit timestamps: both default to the database's current timestamp;
    # updated_at is additionally refreshed on every UPDATE issued via the ORM.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )

    # Audit actors.
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self):
        """Return a short debugging representation (id, file name, status)."""
        return f"<RagFile(id={self.id}, file_name={self.file_name}, status={self.status})>"
47+

runtime/datamate-python/app/module/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .generation.interface import router as generation_router
77
from .evaluation.interface import router as evaluation_router
88
from .collection.interface import router as collection_route
9+
from .rag.interface.rag_interface import router as rag_router
910

1011
router = APIRouter(
1112
prefix="/api"
@@ -17,5 +18,6 @@
1718
router.include_router(generation_router)
1819
router.include_router(evaluation_router)
1920
router.include_router(collection_route)
21+
router.include_router(rag_router)
2022

2123
__all__ = ["router"]

runtime/datamate-python/app/module/rag/__init__.py

Whitespace-only changes.

runtime/datamate-python/app/module/rag/interface/__init__.py

Whitespace-only changes.
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from fastapi import APIRouter, Depends, HTTPException
2+
from sqlalchemy.ext.asyncio import AsyncSession
3+
4+
from app.db.session import get_db
5+
from app.module.rag.service.rag_service import RAGService
6+
from app.module.shared.schema import StandardResponse
7+
8+
# Router for the RAG endpoints; every route registered on it is served under the "/rag" prefix.
router = APIRouter(prefix="/rag", tags=["rag"])
9+
10+
@router.post("/process/{knowledge_base_id}")
async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depends(get_db)):
    """
    Initialise the graph RAG instance for a knowledge base.

    Delegates to ``RAGService.init_graph_rag`` and wraps the outcome in a
    ``StandardResponse``.

    Args:
        knowledge_base_id: Id of the knowledge base, taken from the URL path.
        db: Async database session injected by FastAPI via ``get_db``.

    Returns:
        A ``StandardResponse`` with code 200 and no data on success.

    Raises:
        HTTPException: 404 when the service raises ``ValueError`` (it does so
            for a missing knowledge base or missing model configuration);
            500 for any other failure.
    """
    try:
        await RAGService(db).init_graph_rag(knowledge_base_id)
    except HTTPException:
        # Already an HTTP error: pass it through untouched.
        raise
    except ValueError as e:
        # The service signals "knowledge base / model config not found" with
        # ValueError; report that as a client-visible 404 instead of a 500.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # Unexpected failure: keep the original exception chained for logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): the message says processing "started", but at this point
    # only the RAG instance has been initialised -- confirm the wording.
    return StandardResponse(
        code=200,
        message="Processing started for knowledge base.",
        data=None
    )
24+
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pydantic import BaseModel
2+
3+
# Request payload that identifies the knowledge base to process.
class ProcessRequest(BaseModel):
    # Id of the target knowledge base.
    knowledge_base_id: str
5+

runtime/datamate-python/app/module/rag/service/__init__.py

Whitespace-only changes.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import asyncio
2+
import os
3+
from typing import Awaitable, Callable, Optional
4+
5+
import numpy as np
6+
from lightrag import LightRAG, QueryParam
7+
from lightrag.llm.openai import openai_embed, openai_complete_if_cache
8+
from lightrag.utils import setup_logger, EmbeddingFunc
9+
10+
# Log level for the "lightrag" logger. Defaults to the previous hard-coded
# DEBUG; set LIGHTRAG_LOG_LEVEL (e.g. "INFO") to make it less verbose.
# An empty value falls back to the default.
setup_logger("lightrag", level=(os.getenv("LIGHTRAG_LOG_LEVEL") or "DEBUG").upper())

# Root directory under which RAG working directories are created. Defaults to
# the previous hard-coded "/rag_storage"; set RAG_WORKING_DIR to relocate it.
DEFAULT_WORKING_DIR = os.getenv("RAG_WORKING_DIR") or "/rag_storage"
12+
13+
14+
async def build_llm_model_func(model_name: str, base_url: str, api_key: str) -> Callable[..., Awaitable[str]]:
    """Build an async completion callable bound to one model endpoint.

    Args:
        model_name: Model identifier passed as the first argument of
            ``openai_complete_if_cache``.
        base_url: Base URL forwarded to ``openai_complete_if_cache``.
        api_key: API key forwarded to ``openai_complete_if_cache``.

    Returns:
        An async function ``(prompt, system_prompt=None,
        history_messages=None, **kwargs)`` that awaits
        ``openai_complete_if_cache`` with the bound settings and returns
        its result.
    """

    async def _llm_model(
        prompt, system_prompt=None, history_messages=None, **kwargs
    ) -> str:
        # A missing or empty history is normalised to a fresh list.
        if not history_messages:
            history_messages = []
        completion = await openai_complete_if_cache(
            model_name,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            api_key=api_key,
            base_url=base_url,
            **kwargs,
        )
        return completion

    return _llm_model
30+
31+
32+
async def build_embedding_func(
    model_name: str, base_url: str, api_key: str, embedding_dim: int
) -> EmbeddingFunc:
    """Build an ``EmbeddingFunc`` bound to one embedding model endpoint.

    Args:
        model_name: Model identifier forwarded as ``model``.
        base_url: Base URL forwarded to the embedding call.
        api_key: API key forwarded to the embedding call.
        embedding_dim: Vector dimension; forwarded to the embedding call and
            also declared on the returned ``EmbeddingFunc``.

    Returns:
        An ``EmbeddingFunc`` whose ``func`` awaits ``openai_embed.func`` with
        the bound settings.
    """
    # Settings shared by every embedding request issued through this wrapper.
    bound_options = {
        "model": model_name,
        "api_key": api_key,
        "base_url": base_url,
        "embedding_dim": embedding_dim,
    }

    async def _embedding_func(texts: list[str]) -> np.ndarray:
        # ``openai_embed.func`` is presumably the undecorated embedding
        # coroutine behind ``openai_embed`` -- TODO confirm against LightRAG.
        vectors = await openai_embed.func(texts, **bound_options)
        return vectors

    return EmbeddingFunc(embedding_dim=embedding_dim, func=_embedding_func)
45+
46+
47+
async def initialize_rag(
    llm_callable: Callable[..., Awaitable[str]],
    embedding_callable: EmbeddingFunc,
    working_dir: Optional[str] = None,
):
    """Create a ``LightRAG`` instance and initialise its storages.

    Args:
        llm_callable: Async completion function passed as ``llm_model_func``.
        embedding_callable: Embedding function passed as ``embedding_func``.
        working_dir: Directory for the instance's data; when falsy,
            ``DEFAULT_WORKING_DIR`` is used. The directory is created if it
            does not exist.

    Returns:
        The ``LightRAG`` instance after ``initialize_storages()`` has been
        awaited.
    """
    target_dir = working_dir if working_dir else DEFAULT_WORKING_DIR
    os.makedirs(target_dir, exist_ok=True)

    rag_instance = LightRAG(
        working_dir=target_dir,
        llm_model_func=llm_callable,
        embedding_func=embedding_callable,
    )
    # NOTE(review): some LightRAG releases also expect the pipeline status to
    # be initialised after the storages -- confirm against the pinned version.
    await rag_instance.initialize_storages()
    return rag_instance
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
from typing import Optional
3+
4+
from fastapi import Depends
5+
from sqlalchemy import select
6+
from sqlalchemy.ext.asyncio import AsyncSession
7+
8+
from app.db.models.knowledge_gen import RagKnowledgeBase
9+
from app.db.models.model_config import ModelConfig
10+
from app.db.session import AsyncSessionLocal
11+
from .graph_rag import (
12+
DEFAULT_WORKING_DIR,
13+
build_embedding_func,
14+
build_llm_model_func,
15+
initialize_rag,
16+
)
17+
18+
19+
class RAGService:
    """Service that prepares a graph RAG instance for a knowledge base.

    It looks up the knowledge base and its embedding/chat model
    configurations through the given database session, builds the LLM and
    embedding callables, and initialises the RAG instance, which is kept on
    ``self.rag``.
    """

    # Embedding dimension used when the model configuration does not give one.
    _FALLBACK_EMBEDDING_DIM = 1024

    def __init__(
        self,
        db: AsyncSession = Depends(AsyncSessionLocal),
    ):
        """Store the database session; no RAG instance exists yet.

        NOTE(review): the ``Depends`` default is only resolved when FastAPI
        injects this class as a dependency; code constructing the service
        directly should pass ``db`` explicitly.
        """
        self.db = db
        self.rag = None

    async def get_unprocessed_files(self, knowledge_base_id: str) -> list[str]:
        """Placeholder: not implemented yet; currently returns ``None``.

        TODO: return the files of the knowledge base still to be processed.
        """
        pass

    async def init_graph_rag(self, knowledge_base_id: str):
        """Initialise the RAG instance for a knowledge base.

        Args:
            knowledge_base_id: Id of the knowledge base to initialise.

        Returns:
            ``{"status": "initialized", "knowledge_base_id": <id>}``.

        Raises:
            ValueError: If the knowledge base or one of its model
                configurations is missing, or if the knowledge base name
                would place the working directory outside the storage root.
        """
        kb = await self._get_knowledge_base(knowledge_base_id)
        embedding_model = await self._get_model_config(kb.embedding_model)
        chat_model = await self._get_model_config(kb.chat_model)

        llm_callable = await build_llm_model_func(
            chat_model.model_name, chat_model.base_url, chat_model.api_key
        )
        # Fall back when the attribute is absent *or* present but unset
        # (None/0); the previous hasattr() check let None through.
        embedding_dim = (
            getattr(embedding_model, "embedding_dim", None)
            or self._FALLBACK_EMBEDDING_DIM
        )
        embedding_callable = await build_embedding_func(
            embedding_model.model_name,
            embedding_model.base_url,
            embedding_model.api_key,
            embedding_dim=embedding_dim,
        )

        kb_working_dir = self._resolve_working_dir(kb.name)
        self.rag = await initialize_rag(llm_callable, embedding_callable, kb_working_dir)
        return {"status": "initialized", "knowledge_base_id": knowledge_base_id}

    @staticmethod
    def _resolve_working_dir(kb_name: str) -> str:
        """Return the working directory for ``kb_name`` under the storage root.

        The name is stored data, and ``os.path.join`` discards the root when
        the second part is absolute and follows ``..`` segments, so the
        resolved path is checked to still lie inside ``DEFAULT_WORKING_DIR``.

        Raises:
            ValueError: If the resolved path escapes the storage root.
        """
        joined = os.path.join(DEFAULT_WORKING_DIR, kb_name)
        root = os.path.abspath(DEFAULT_WORKING_DIR)
        resolved = os.path.abspath(joined)
        if os.path.commonpath([root, resolved]) != root:
            raise ValueError(
                f"Knowledge base name {kb_name!r} resolves outside the RAG storage directory."
            )
        # Return the same (un-normalised) path the original code produced.
        return joined

    async def _get_knowledge_base(self, knowledge_base_id: str):
        """Load the knowledge base with the given id.

        Raises:
            ValueError: If no such knowledge base exists.
        """
        result = await self.db.execute(
            select(RagKnowledgeBase).where(RagKnowledgeBase.id == knowledge_base_id)
        )
        knowledge_base = result.scalars().first()
        if not knowledge_base:
            raise ValueError(f"Knowledge base with ID {knowledge_base_id} not found.")
        return knowledge_base

    async def _get_model_config(self, model_id: Optional[str]):
        """Load the model configuration with the given id.

        Raises:
            ValueError: If ``model_id`` is empty or no such configuration
                exists.
        """
        if not model_id:
            raise ValueError("Model ID is required for initializing RAG.")
        result = await self.db.execute(select(ModelConfig).where(ModelConfig.id == model_id))
        model = result.scalars().first()
        if not model:
            raise ValueError(f"Model config with ID {model_id} not found.")
        return model
67+

0 commit comments

Comments
 (0)