diff --git a/backend/modules/evaluation/application/experiment_app.go b/backend/modules/evaluation/application/experiment_app.go index 246ae8084..ae29918d7 100644 --- a/backend/modules/evaluation/application/experiment_app.go +++ b/backend/modules/evaluation/application/experiment_app.go @@ -565,8 +565,71 @@ func (e *experimentApplication) SubmitExperiment(ctx context.Context, req *expt. // 2) 从有序 EvaluatorIDVersionList 中批量解析并按输入顺序回填版本ID // 3) 从 EvaluatorIDVersionList 中提取 runconfig 和权重配置,构建 evaluator_version_id 到 runconfig/权重的映射 // 注意:runconfig 用于评估器运行时配置,score weight 用于加权分数计算 +// validateEvaluatorVersionsBelongToWorkspace 校验直接传入的 evaluator_version_id 是否属于当前工作空间 +// 预置评估器(Builtin=true)允许跨空间复用,不做 SpaceID 校验 +func (e *experimentApplication) validateEvaluatorVersionsBelongToWorkspace(ctx context.Context, evaluatorVersionIDs []int64, workspaceID int64) error { + if len(evaluatorVersionIDs) == 0 || workspaceID <= 0 { + return nil + } + // 去重,避免重复查询 + seen := make(map[int64]struct{}, len(evaluatorVersionIDs)) + uniq := make([]int64, 0, len(evaluatorVersionIDs)) + for _, id := range evaluatorVersionIDs { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + uniq = append(uniq, id) + } + if len(uniq) == 0 { + return nil + } + evs, err := e.evaluatorService.BatchGetEvaluatorVersion(ctx, nil, uniq, false) + if err != nil { + return err + } + found := make(map[int64]*entity.Evaluator, len(evs)) + for _, ev := range evs { + if ev == nil { + continue + } + found[ev.GetEvaluatorVersionID()] = ev + } + for _, id := range uniq { + ev, ok := found[id] + if !ok || ev == nil { + return errorx.NewByCode( + errno.EvaluatorVersionNotFoundCode, + errorx.WithExtraMsg(fmt.Sprintf("evaluator version %d not found", id)), + ) + } + // 预置评估器允许跨空间复用 + if ev.Builtin { + continue + } + if ev.GetSpaceID() != workspaceID { + return errorx.NewByCode( + errno.EvaluatorVersionNotFoundCode, + errorx.WithExtraMsg(fmt.Sprintf("evaluator %d version %s does not belong to workspace %d", ev.ID, ev.GetVersion(), workspaceID)), + ) + } + } + return nil +} + func (e *experimentApplication) resolveEvaluatorVersionIDsFromCreateReq(ctx context.Context, req *expt.CreateExperimentRequest) ([]int64, map[int64]*evaluatordto.EvaluatorRunConfig, map[int64]float64, error) { + workspaceID := req.GetWorkspaceID() + evalVersionIDs := make([]int64, 0, len(req.EvaluatorVersionIds)) + // 对于直接传入的 evaluator_version_id,需要校验是否属于当前空间(预置评估器除外) + if len(req.EvaluatorVersionIds) > 0 && workspaceID > 0 { + if err := e.validateEvaluatorVersionsBelongToWorkspace(ctx, req.EvaluatorVersionIds, workspaceID); err != nil { + return nil, nil, nil, err + } + } evalVersionIDs = append(evalVersionIDs, req.EvaluatorVersionIds...) // 权重映射:key 为 evaluator_version_id,value 为权重(用于加权分数计算) @@ -609,6 +672,7 @@ func (e *experimentApplication) resolveEvaluatorVersionIDsFromCreateReq(ctx cont } for _, ev := range evs { if ev != nil { + // 预置评估器允许跨空间复用,这里不做 SpaceID 校验 id2Builtin[ev.ID] = ev } } @@ -624,6 +688,16 @@ func (e *experimentApplication) resolveEvaluatorVersionIDsFromCreateReq(ctx cont if ev == nil { continue } + // 非预置评估器必须与实验 WorkspaceID 一致,防止绑定其他空间的评估器 + // 同时校验根字段 SpaceID(来自 evaluator 元信息)和内层版本 SpaceID(来自 evaluator_version) + if workspaceID > 0 && !ev.Builtin { + if ev.GetSpaceID() != workspaceID { + return nil, nil, nil, errorx.NewByCode( + errno.EvaluatorVersionNotFoundCode, + errorx.WithExtraMsg(fmt.Sprintf("evaluator %d version %s does not belong to workspace %d", ev.ID, ev.GetVersion(), workspaceID)), + ) + } + } key := fmt.Sprintf("%d#%s", ev.ID, ev.GetVersion()) pair2Eval[key] = ev } diff --git a/backend/modules/evaluation/application/experiment_app_test.go b/backend/modules/evaluation/application/experiment_app_test.go index a8542f707..e5c6bb1e3 100644 --- a/backend/modules/evaluation/application/experiment_app_test.go +++ b/backend/modules/evaluation/application/experiment_app_test.go @@ -163,6 +163,19 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) { }, }, mockSetup: func() { + mockEvaluatorService.EXPECT().BatchGetEvaluatorVersion(gomock.Any(), gomock.Any(), []int64{10001}, false).Return([]*entity.Evaluator{ + { + ID: 3, + SpaceID: validWorkspaceID, + EvaluatorType: entity.EvaluatorTypePrompt, + PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{ + SpaceID: validWorkspaceID, + ID: 10001, + EvaluatorID: 3, + Version: "v1", + }, + }, + }, nil) mockEvaluatorService.EXPECT().BatchGetBuiltinEvaluator(gomock.Any(), []int64{1, 1}).Return([]*entity.Evaluator{ { ID: 1, @@ -178,8 +191,10 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) { mockEvaluatorService.EXPECT().BatchGetEvaluatorByIDAndVersion(gomock.Any(), gomock.Any()).Return([]*entity.Evaluator{ { ID: 2, + SpaceID: validWorkspaceID, EvaluatorType: entity.EvaluatorTypePrompt, PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{ + SpaceID: validWorkspaceID, ID: 20200, EvaluatorID: 2, Version: "1.0.0", @@ -234,6 +249,63 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) { }, wantErr: true, }, + { + name: "cross_workspace_evaluator_id_version_list_rejected", + req: &exptpb.CreateExperimentRequest{ + WorkspaceID: validWorkspaceID, + EvaluatorIDVersionList: []*evaluator.EvaluatorIDVersionItem{ + {EvaluatorID: gptr.Of(int64(2)), Version: gptr.Of("1.0.0")}, + }, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().BatchGetEvaluatorByIDAndVersion(gomock.Any(), gomock.Any()).Return([]*entity.Evaluator{ + { + ID: 2, + SpaceID: validWorkspaceID + 1, // 其他空间 + EvaluatorType: entity.EvaluatorTypePrompt, + Builtin: false, + PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{ + SpaceID: validWorkspaceID + 1, + ID: 20200, + EvaluatorID: 2, + Version: "1.0.0", + }, + }, + }, nil) + }, + wantErr: true, + wantCode: errno.EvaluatorVersionNotFoundCode, + }, + { + name: "cross_workspace_builtin_evaluator_id_version_list_allowed", + req: &exptpb.CreateExperimentRequest{ + WorkspaceID: validWorkspaceID, + EvaluatorIDVersionList: []*evaluator.EvaluatorIDVersionItem{ + {EvaluatorID: gptr.Of(int64(9)), Version: gptr.Of("1.0.0")}, + }, + CreateEvalTargetParam: &eval_target.CreateEvalTargetParam{ + EvalTargetType: gptr.Of(domain_eval_target.EvalTargetType_CozeBot), + }, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().BatchGetEvaluatorByIDAndVersion(gomock.Any(), gomock.Any()).Return([]*entity.Evaluator{ + { + ID: 9, + SpaceID: validWorkspaceID + 2, // 预置评估器通常属于平台空间 + EvaluatorType: entity.EvaluatorTypePrompt, + Builtin: true, // 预置评估器允许跨空间复用 + PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{ + SpaceID: validWorkspaceID + 2, + ID: 90900, + EvaluatorID: 9, + Version: "1.0.0", + }, + }, + }, nil) + mockManager.EXPECT().CreateExpt(gomock.Any(), gomock.Any(), gomock.Any()).Return(validExpt, nil) + }, + wantErr: false, + }, { name: "skip_missing_evaluators", req: &exptpb.CreateExperimentRequest{ @@ -248,6 +320,19 @@ func TestExperimentApplication_CreateExperiment(t *testing.T) { }, }, mockSetup: func() { + mockEvaluatorService.EXPECT().BatchGetEvaluatorVersion(gomock.Any(), gomock.Any(), []int64{10001}, false).Return([]*entity.Evaluator{ + { + ID: 3, + SpaceID: validWorkspaceID, + EvaluatorType: entity.EvaluatorTypePrompt, + PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{ + SpaceID: validWorkspaceID, + ID: 10001, + EvaluatorID: 3, + Version: "v1", + }, + }, + }, nil) mockEvaluatorService.EXPECT().BatchGetBuiltinEvaluator(gomock.Any(), []int64{1}).Return([]*entity.Evaluator{ { ID: 1, diff --git a/backend/modules/evaluation/infra/repo/evaluator/mysql/evaluator_tag.go b/backend/modules/evaluation/infra/repo/evaluator/mysql/evaluator_tag.go index a5c2693d9..ca018bd26 100644 --- a/backend/modules/evaluation/infra/repo/evaluator/mysql/evaluator_tag.go +++ b/backend/modules/evaluation/infra/repo/evaluator/mysql/evaluator_tag.go @@ -363,7 +363,7 @@ func (dao *EvaluatorTagDAOImpl) querySourceIDsForCondition(ctx context.Context, return nil, nil } } - if len(restrictTo) <= sourceIDInChunkSize { + if restrictTo == nil || len(restrictTo) <= sourceIDInChunkSize { return dao.querySourceIDsForConditionOnce(ctx, tagType, langType, condition, restrictTo, opts...) } set := make(map[int64]struct{}) @@ -412,7 +412,7 @@ func (dao *EvaluatorTagDAOImpl) sourceIDsForNameLike(ctx context.Context, tagTyp return nil, nil } } - if len(restrictTo) <= sourceIDInChunkSize { + if restrictTo == nil || len(restrictTo) <= sourceIDInChunkSize { return dao.sourceIDsForNameLikeOnce(ctx, tagType, langType, kw, restrictTo, opts...) } set := make(map[int64]struct{})