Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions backend/modules/evaluation/application/experiment_app.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,71 @@ func (e *experimentApplication) SubmitExperiment(ctx context.Context, req *expt.
// 2) 从有序 EvaluatorIDVersionList 中批量解析并按输入顺序回填版本ID
// 3) 从 EvaluatorIDVersionList 中提取 runconfig 和权重配置,构建 evaluator_version_id 到 runconfig/权重的映射
// 注意:runconfig 用于评估器运行时配置,score weight 用于加权分数计算
// validateEvaluatorVersionsBelongToWorkspace verifies that every directly supplied
// evaluator_version_id belongs to the given workspace.
//
// Builtin evaluators (Builtin == true) may be reused across workspaces, so no
// SpaceID check is applied to them.
//
// It returns nil without querying when the ID list is empty, when workspaceID is
// not positive, or when no positive IDs remain after filtering. A lookup error is
// returned unchanged. An EvaluatorVersionNotFoundCode error is returned when a
// requested version is absent from the lookup result or when a non-builtin
// evaluator's space differs from workspaceID.
func (e *experimentApplication) validateEvaluatorVersionsBelongToWorkspace(ctx context.Context, evaluatorVersionIDs []int64, workspaceID int64) error {
	if workspaceID <= 0 || len(evaluatorVersionIDs) == 0 {
		return nil
	}

	// Keep each positive ID once, in first-seen order, so the lookup is not
	// issued with duplicates.
	visited := make(map[int64]struct{}, len(evaluatorVersionIDs))
	versionIDs := make([]int64, 0, len(evaluatorVersionIDs))
	for _, versionID := range evaluatorVersionIDs {
		if versionID <= 0 {
			continue
		}
		if _, dup := visited[versionID]; !dup {
			visited[versionID] = struct{}{}
			versionIDs = append(versionIDs, versionID)
		}
	}
	if len(versionIDs) == 0 {
		return nil
	}

	fetched, err := e.evaluatorService.BatchGetEvaluatorVersion(ctx, nil, versionIDs, false)
	if err != nil {
		return err
	}

	// Index the non-nil results by evaluator version ID.
	byVersionID := make(map[int64]*entity.Evaluator, len(fetched))
	for _, item := range fetched {
		if item != nil {
			byVersionID[item.GetEvaluatorVersionID()] = item
		}
	}

	for _, versionID := range versionIDs {
		// Nil entries are never stored, so a nil lookup result means the
		// version was not returned by the service.
		match := byVersionID[versionID]
		switch {
		case match == nil:
			return errorx.NewByCode(
				errno.EvaluatorVersionNotFoundCode,
				errorx.WithExtraMsg(fmt.Sprintf("evaluator version %d not found", versionID)),
			)
		case match.Builtin:
			// Builtin evaluators are allowed to be reused across workspaces.
		case match.GetSpaceID() != workspaceID:
			return errorx.NewByCode(
				errno.EvaluatorVersionNotFoundCode,
				errorx.WithExtraMsg(fmt.Sprintf("evaluator %d version %s does not belong to workspace %d", match.ID, match.GetVersion(), workspaceID)),
			)
		}
	}
	return nil
}

func (e *experimentApplication) resolveEvaluatorVersionIDsFromCreateReq(ctx context.Context, req *expt.CreateExperimentRequest) ([]int64, map[int64]*evaluatordto.EvaluatorRunConfig, map[int64]float64, error) {
workspaceID := req.GetWorkspaceID()

evalVersionIDs := make([]int64, 0, len(req.EvaluatorVersionIds))
// 对于直接传入的 evaluator_version_id,需要校验是否属于当前空间(预置评估器除外)
if len(req.EvaluatorVersionIds) > 0 && workspaceID > 0 {
if err := e.validateEvaluatorVersionsBelongToWorkspace(ctx, req.EvaluatorVersionIds, workspaceID); err != nil {
return nil, nil, nil, err
}
}
evalVersionIDs = append(evalVersionIDs, req.EvaluatorVersionIds...)

// 权重映射:key 为 evaluator_version_id,value 为权重(用于加权分数计算)
Expand Down Expand Up @@ -609,6 +672,7 @@ func (e *experimentApplication) resolveEvaluatorVersionIDsFromCreateReq(ctx cont
}
for _, ev := range evs {
if ev != nil {
// 预置评估器允许跨空间复用,这里不做 SpaceID 校验
id2Builtin[ev.ID] = ev
}
}
Expand All @@ -624,6 +688,16 @@ func (e *experimentApplication) resolveEvaluatorVersionIDsFromCreateReq(ctx cont
if ev == nil {
continue
}
// 非预置评估器必须与实验 WorkspaceID 一致,防止绑定其他空间的评估器
// 同时校验根字段 SpaceID(来自 evaluator 元信息)和内层版本 SpaceID(来自 evaluator_version)
if workspaceID > 0 && !ev.Builtin {
if ev.GetSpaceID() != workspaceID {
return nil, nil, nil, errorx.NewByCode(
errno.EvaluatorVersionNotFoundCode,
errorx.WithExtraMsg(fmt.Sprintf("evaluator %d version %s does not belong to workspace %d", ev.ID, ev.GetVersion(), workspaceID)),
)
}
}
key := fmt.Sprintf("%d#%s", ev.ID, ev.GetVersion())
pair2Eval[key] = ev
}
Expand Down
85 changes: 85 additions & 0 deletions backend/modules/evaluation/application/experiment_app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) {
},
},
mockSetup: func() {
mockEvaluatorService.EXPECT().BatchGetEvaluatorVersion(gomock.Any(), gomock.Any(), []int64{10001}, false).Return([]*entity.Evaluator{
{
ID: 3,
SpaceID: validWorkspaceID,
EvaluatorType: entity.EvaluatorTypePrompt,
PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{
SpaceID: validWorkspaceID,
ID: 10001,
EvaluatorID: 3,
Version: "v1",
},
},
}, nil)
mockEvaluatorService.EXPECT().BatchGetBuiltinEvaluator(gomock.Any(), []int64{1, 1}).Return([]*entity.Evaluator{
{
ID: 1,
Expand All @@ -178,8 +191,10 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) {
mockEvaluatorService.EXPECT().BatchGetEvaluatorByIDAndVersion(gomock.Any(), gomock.Any()).Return([]*entity.Evaluator{
{
ID: 2,
SpaceID: validWorkspaceID,
EvaluatorType: entity.EvaluatorTypePrompt,
PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{
SpaceID: validWorkspaceID,
ID: 20200,
EvaluatorID: 2,
Version: "1.0.0",
Expand Down Expand Up @@ -234,6 +249,63 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) {
},
wantErr: true,
},
{
name: "cross_workspace_evaluator_id_version_list_rejected",
req: &exptpb.CreateExperimentRequest{
WorkspaceID: validWorkspaceID,
EvaluatorIDVersionList: []*evaluator.EvaluatorIDVersionItem{
{EvaluatorID: gptr.Of(int64(2)), Version: gptr.Of("1.0.0")},
},
},
mockSetup: func() {
mockEvaluatorService.EXPECT().BatchGetEvaluatorByIDAndVersion(gomock.Any(), gomock.Any()).Return([]*entity.Evaluator{
{
ID: 2,
SpaceID: validWorkspaceID + 1, // 其他空间
EvaluatorType: entity.EvaluatorTypePrompt,
Builtin: false,
PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{
SpaceID: validWorkspaceID + 1,
ID: 20200,
EvaluatorID: 2,
Version: "1.0.0",
},
},
}, nil)
},
wantErr: true,
wantCode: errno.EvaluatorVersionNotFoundCode,
},
{
name: "cross_workspace_builtin_evaluator_id_version_list_allowed",
req: &exptpb.CreateExperimentRequest{
WorkspaceID: validWorkspaceID,
EvaluatorIDVersionList: []*evaluator.EvaluatorIDVersionItem{
{EvaluatorID: gptr.Of(int64(9)), Version: gptr.Of("1.0.0")},
},
CreateEvalTargetParam: &eval_target.CreateEvalTargetParam{
EvalTargetType: gptr.Of(domain_eval_target.EvalTargetType_CozeBot),
},
},
mockSetup: func() {
mockEvaluatorService.EXPECT().BatchGetEvaluatorByIDAndVersion(gomock.Any(), gomock.Any()).Return([]*entity.Evaluator{
{
ID: 9,
SpaceID: validWorkspaceID + 2, // 预置评估器通常属于平台空间
EvaluatorType: entity.EvaluatorTypePrompt,
Builtin: true, // 预置评估器允许跨空间复用
PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{
SpaceID: validWorkspaceID + 2,
ID: 90900,
EvaluatorID: 9,
Version: "1.0.0",
},
},
}, nil)
mockManager.EXPECT().CreateExpt(gomock.Any(), gomock.Any(), gomock.Any()).Return(validExpt, nil)
},
wantErr: false,
},
{
name: "skip_missing_evaluators",
req: &exptpb.CreateExperimentRequest{
Expand All @@ -248,6 +320,19 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) {
},
},
mockSetup: func() {
mockEvaluatorService.EXPECT().BatchGetEvaluatorVersion(gomock.Any(), gomock.Any(), []int64{10001}, false).Return([]*entity.Evaluator{
{
ID: 3,
SpaceID: validWorkspaceID,
EvaluatorType: entity.EvaluatorTypePrompt,
PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{
SpaceID: validWorkspaceID,
ID: 10001,
EvaluatorID: 3,
Version: "v1",
},
},
}, nil)
mockEvaluatorService.EXPECT().BatchGetBuiltinEvaluator(gomock.Any(), []int64{1}).Return([]*entity.Evaluator{
{
ID: 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func (dao *EvaluatorTagDAOImpl) querySourceIDsForCondition(ctx context.Context,
return nil, nil
}
}
if len(restrictTo) <= sourceIDInChunkSize {
if restrictTo == nil || len(restrictTo) <= sourceIDInChunkSize {
return dao.querySourceIDsForConditionOnce(ctx, tagType, langType, condition, restrictTo, opts...)
}
set := make(map[int64]struct{})
Expand Down Expand Up @@ -412,7 +412,7 @@ func (dao *EvaluatorTagDAOImpl) sourceIDsForNameLike(ctx context.Context, tagTyp
return nil, nil
}
}
if len(restrictTo) <= sourceIDInChunkSize {
if restrictTo == nil || len(restrictTo) <= sourceIDInChunkSize {
return dao.sourceIDsForNameLikeOnce(ctx, tagType, langType, kw, restrictTo, opts...)
}
set := make(map[int64]struct{})
Expand Down
Loading