Skip to content

Commit 477af37

Browse files
committed
feat: implement LLMFactory for unified model creation and health checks; add is_deleted field to model config
1 parent 152f5cd commit 477af37

17 files changed

Lines changed: 180 additions & 141 deletions

File tree

backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
*/
1313
@Getter
1414
@Setter
15-
@TableName("t_model_config")
15+
@TableName("t_models")
1616
@Builder
1717
@ToString
1818
@NoArgsConstructor

runtime/datamate-python/app/db/models/model_config.py renamed to runtime/datamate-python/app/db/models/models.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
from sqlalchemy import Column, String, Integer, TIMESTAMP, select
1+
from sqlalchemy import Boolean, Column, String, TIMESTAMP
22

33
from app.db.models.base_entity import BaseEntity
44

55

6-
async def get_model_by_id(db_session, model_id: str):
7-
"""根据 ID 获取单个模型配置。"""
8-
result =await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id))
9-
model_config = result.scalar_one_or_none()
10-
return model_config
6+
class Models(BaseEntity):
7+
"""模型配置表,对应表 t_models
118
12-
class ModelConfig(BaseEntity):
13-
"""模型配置表,对应表 t_model_config
14-
15-
CREATE TABLE IF NOT EXISTS t_model_config (
9+
CREATE TABLE IF NOT EXISTS t_models (
1610
id VARCHAR(36) PRIMARY KEY COMMENT '主键ID',
1711
model_name VARCHAR(100) NOT NULL COMMENT '模型名称(如 qwen2)',
1812
provider VARCHAR(50) NOT NULL COMMENT '模型提供商(如 Ollama、OpenAI、DeepSeek)',
@@ -29,7 +23,7 @@ class ModelConfig(BaseEntity):
2923
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT ='模型配置表';
3024
"""
3125

32-
__tablename__ = "t_model_config"
26+
__tablename__ = "t_models"
3327

3428
id = Column(String(36), primary_key=True, index=True, comment="主键ID")
3529
model_name = Column(String(100), nullable=False, comment="模型名称(如 qwen2)")
@@ -38,10 +32,9 @@ class ModelConfig(BaseEntity):
3832
api_key = Column(String(512), nullable=False, default="", comment="API 密钥(无密钥则为空)")
3933
type = Column(String(50), nullable=False, comment="模型类型(如 chat、embedding)")
4034

41-
# 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool
42-
is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用")
43-
is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认")
44-
is_deleted = Column(Integer, nullable=False, default=0, comment="是否删除:1-已删除,0-未删除")
35+
is_enabled = Column(Boolean, nullable=False, default=True, comment="是否启用")
36+
is_default = Column(Boolean, nullable=False, default=False, comment="是否默认")
37+
is_deleted = Column(Boolean, nullable=False, default=False, comment="是否删除")
4538

4639
__table_args__ = (
4740
# 与 DDL 中的 uk_model_provider 保持一致

runtime/datamate-python/app/module/evaluation/interface/evaluation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ async def create_evaluation_task(
8080
if existing_task.scalar_one_or_none():
8181
raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists")
8282

83-
model_config = await get_model_by_id(db, request.eval_config.model_id)
84-
if not model_config:
83+
models = await get_model_by_id(db, request.eval_config.model_id)
84+
if not models:
8585
raise HTTPException(status_code=400, detail=f"Model with id '{request.eval_config.model_id}' not found")
8686

8787
# 创建评估任务
@@ -96,7 +96,7 @@ async def create_evaluation_task(
9696
eval_prompt=request.eval_prompt,
9797
eval_config=json.dumps({
9898
"modelId": request.eval_config.model_id,
99-
"modelName": model_config.model_name,
99+
"modelName": models.model_name,
100100
"dimensions": request.eval_config.dimensions,
101101
}),
102102
status=TaskStatus.PENDING.value,

runtime/datamate-python/app/module/evaluation/service/evaluation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_eval_prompt(self, item: EvaluationItem) -> str:
4343

4444
async def execute(self):
4545
eval_config = json.loads(self.task.eval_config)
46-
model_config = await get_model_by_id(self.db, eval_config.get("modelId"))
46+
models = await get_model_by_id(self.db, eval_config.get("modelId"))
4747
semaphore = asyncio.Semaphore(10)
4848
files = (await self.db.execute(
4949
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
@@ -55,7 +55,7 @@ async def execute(self):
5555
for file in files:
5656
items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
5757
tasks = [
58-
self.evaluate_item(model_config, item, semaphore)
58+
self.evaluate_item(models, item, semaphore)
5959
for item in items
6060
]
6161
await asyncio.gather(*tasks, return_exceptions=True)
@@ -64,13 +64,13 @@ async def execute(self):
6464
self.task.eval_process = evaluated_count / total
6565
await self.db.commit()
6666

67-
async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
67+
async def evaluate_item(self, models, item: EvaluationItem, semaphore: asyncio.Semaphore):
6868
async with semaphore:
6969
max_try = 3
7070
while max_try > 0:
7171
prompt_text = self.get_eval_prompt(item)
7272
resp_text = await asyncio.to_thread(
73-
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
73+
call_openai_style_model, models.base_url, models.api_key, models.model_name,
7474
prompt_text,
7575
)
7676
resp_text = extract_json_substring(resp_text)

runtime/datamate-python/app/module/generation/service/generation_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from app.module.shared.common.document_loaders import load_documents
2525
from app.module.shared.common.text_split import DocumentSplitter
2626
from app.module.shared.util.model_chat import extract_json_substring
27-
from app.core.llm import LLMFactory
27+
from app.module.shared.llm import LLMFactory
2828
from app.module.system.service.common_service import get_model_by_id
2929

3030

runtime/datamate-python/app/module/rag/service/rag_service.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import asyncio
33
from typing import Optional, Sequence
44

5-
from fastapi import BackgroundTasks, Depends
5+
from fastapi import Depends
66
from sqlalchemy import select
77
from sqlalchemy.ext.asyncio import AsyncSession
88

@@ -17,7 +17,7 @@
1717
build_llm_model_func,
1818
initialize_rag,
1919
)
20-
from app.core.llm import LLMFactory
20+
from app.module.shared.llm import LLMFactory
2121
from ...system.service.common_service import get_model_by_id
2222

2323
logger = get_logger(__name__)
@@ -44,8 +44,8 @@ async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFil
4444

4545
async def init_graph_rag(self, knowledge_base_id: str):
4646
kb = await self._get_knowledge_base(knowledge_base_id)
47-
embedding_model = await self._get_model_config(kb.embedding_model)
48-
chat_model = await self._get_model_config(kb.chat_model)
47+
embedding_model = await self._get_models(kb.embedding_model)
48+
chat_model = await self._get_models(kb.chat_model)
4949

5050
llm_callable = await build_llm_model_func(
5151
chat_model.model_name, chat_model.base_url, chat_model.api_key
@@ -126,13 +126,13 @@ async def _get_knowledge_base(self, knowledge_base_id: str):
126126
raise ValueError(f"Knowledge base with ID {knowledge_base_id} not found.")
127127
return knowledge_base
128128

129-
async def _get_model_config(self, model_id: Optional[str]):
129+
async def _get_models(self, model_id: Optional[str]):
130130
if not model_id:
131131
raise ValueError("Model ID is required for initializing RAG.")
132-
model = await get_model_by_id(self.db, model_id)
133-
if not model:
134-
raise ValueError(f"Model config with ID {model_id} not found.")
135-
return model
132+
models = await get_model_by_id(self.db, model_id)
133+
if not models:
134+
raise ValueError(f"Models with ID {model_id} not found.")
135+
return models
136136

137137
async def query_rag(self, query: str, knowledge_base_id: str) -> str:
138138
if not self.rag:
File renamed without changes.

runtime/datamate-python/app/core/llm/factory.py renamed to runtime/datamate-python/app/module/shared/llm/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def check_health(
4545
model_name: str,
4646
base_url: str,
4747
api_key: str | None,
48-
model_type: Literal["CHAT", "EMBEDDING"],
48+
model_type: Literal["CHAT", "EMBEDDING"] | str,
4949
) -> None:
5050
"""对配置做一次最小化调用进行健康检查,失败则抛出。"""
5151
if model_type == "CHAT":
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from fastapi import APIRouter
22

33
from .about import router as about_router
4-
from .model_config import router as model_config_router
4+
from app.module.system.interface.models import router as models_router
55

66
router = APIRouter()
77

88
router.include_router(about_router)
9-
router.include_router(model_config_router)
9+
router.include_router(models_router)

runtime/datamate-python/app/module/system/interface/model_config.py renamed to runtime/datamate-python/app/module/system/interface/models.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,34 @@
1-
"""
2-
模型配置 REST 接口:与 Java ModelConfigController 路径、语义一致,响应使用 StandardResponse。
3-
db 通过 ModelConfigService 的 Depends(get_db) 注入,不在本层传递。
4-
"""
51
from fastapi import APIRouter, Depends, Query
62

73
from app.module.shared.schema import StandardResponse, PaginatedData
8-
from app.module.system.schema import (
4+
from app.module.system.schema.models import (
95
CreateModelRequest,
106
QueryModelRequest,
11-
ModelConfigResponse,
7+
ModelsResponse,
128
ProviderItem,
139
ModelType,
1410
)
15-
from app.module.system.service.model_config_service import ModelConfigService
11+
from app.module.system.service.models_service import ModelsService
1612

1713
router = APIRouter(prefix="/models", tags=["models"])
1814

1915

2016
@router.get("/providers", response_model=StandardResponse[list[ProviderItem]])
21-
async def get_providers(svc: ModelConfigService = Depends()):
17+
async def get_providers(svc: ModelsService = Depends()):
2218
"""获取厂商列表,与 Java GET /models/providers 一致。"""
2319
data = await svc.get_providers()
2420
return StandardResponse(code=200, message="success", data=data)
2521

2622

27-
@router.get("/list", response_model=StandardResponse[PaginatedData[ModelConfigResponse]])
23+
@router.get("/list", response_model=StandardResponse[PaginatedData[ModelsResponse]])
2824
async def get_models(
2925
page: int = Query(0, ge=0, description="页码,从 0 开始"),
3026
size: int = Query(20, gt=0, le=500, description="每页大小"),
3127
provider: str | None = Query(None, description="模型提供商"),
3228
type: ModelType | None = Query(None, description="模型类型"),
3329
isEnabled: bool | None = Query(None, description="是否启用"),
3430
isDefault: bool | None = Query(None, description="是否默认"),
35-
svc: ModelConfigService = Depends(),
31+
svc: ModelsService = Depends(),
3632
):
3733
"""分页查询模型列表,与 Java GET /models/list 一致。"""
3834
q = QueryModelRequest(
@@ -47,33 +43,33 @@ async def get_models(
4743
return StandardResponse(code=200, message="success", data=data)
4844

4945

50-
@router.post("/create", response_model=StandardResponse[ModelConfigResponse])
51-
async def create_model(req: CreateModelRequest, svc: ModelConfigService = Depends()):
46+
@router.post("/create", response_model=StandardResponse[ModelsResponse])
47+
async def create_model(req: CreateModelRequest, svc: ModelsService = Depends()):
5248
"""创建模型配置,与 Java POST /models/create 一致。"""
5349
data = await svc.create_model(req)
5450
return StandardResponse(code=200, message="success", data=data)
5551

5652

57-
@router.get("/{model_id}", response_model=StandardResponse[ModelConfigResponse])
58-
async def get_model_detail(model_id: str, svc: ModelConfigService = Depends()):
53+
@router.get("/{model_id}", response_model=StandardResponse[ModelsResponse])
54+
async def get_model_detail(model_id: str, svc: ModelsService = Depends()):
5955
"""获取模型详情,与 Java GET /models/{modelId} 一致。"""
6056
data = await svc.get_model_detail(model_id)
6157
return StandardResponse(code=200, message="success", data=data)
6258

6359

64-
@router.put("/{model_id}", response_model=StandardResponse[ModelConfigResponse])
60+
@router.put("/{model_id}", response_model=StandardResponse[ModelsResponse])
6561
async def update_model(
6662
model_id: str,
6763
req: CreateModelRequest,
68-
svc: ModelConfigService = Depends(),
64+
svc: ModelsService = Depends(),
6965
):
7066
"""更新模型配置,与 Java PUT /models/{modelId} 一致。"""
7167
data = await svc.update_model(model_id, req)
7268
return StandardResponse(code=200, message="success", data=data)
7369

7470

7571
@router.delete("/{model_id}", response_model=StandardResponse[None])
76-
async def delete_model(model_id: str, svc: ModelConfigService = Depends()):
72+
async def delete_model(model_id: str, svc: ModelsService = Depends()):
7773
"""删除模型配置,与 Java DELETE /models/{modelId} 一致。"""
7874
await svc.delete_model(model_id)
7975
return StandardResponse(code=200, message="success", data=None)

0 commit comments

Comments
 (0)