Skip to content

Commit 4012ced

Browse files
committed
weight score draft
1 parent 00b2486 commit 4012ced

8 files changed

Lines changed: 176 additions & 49 deletions

File tree

backend/modules/evaluation/application/convertor/experiment/expt.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,19 @@ type EvalConfConvert struct{}
2929

3030
func (e *EvalConfConvert) ConvertToEntity(cer *expt.CreateExperimentRequest) (*entity.EvaluationConfiguration, error) {
3131
ec := &entity.EvaluationConfiguration{
32-
ItemConcurNum: ptr.ConvIntPtr[int32, int](cer.ItemConcurNum),
33-
EnableWeightedScore: gptr.Indirect(cer.EnableWeightedScore),
34-
EvaluatorScoreWeights: cer.GetEvaluatorScoreWeights(),
32+
ItemConcurNum: ptr.ConvIntPtr[int32, int](cer.ItemConcurNum),
3533
}
34+
3635
ec.ConnectorConf.TargetConf = &entity.TargetConf{
3736
TargetVersionID: cer.GetTargetVersionID(),
3837
IngressConf: toTargetFieldMappingDO(cer.GetTargetFieldMapping(), cer.GetTargetRuntimeParam()),
3938
}
4039
if cer.GetEvaluatorFieldMapping() != nil {
4140
ec.ConnectorConf.EvaluatorsConf = &entity.EvaluatorsConf{
42-
EvaluatorConcurNum: ptr.ConvIntPtr[int32, int](cer.EvaluatorsConcurNum),
43-
EvaluatorConf: toEvaluatorFieldMappingDo(cer.GetEvaluatorFieldMapping()),
41+
EvaluatorConcurNum: ptr.ConvIntPtr[int32, int](cer.EvaluatorsConcurNum),
42+
EvaluatorConf: toEvaluatorFieldMappingDo(cer.GetEvaluatorFieldMapping()),
43+
EnableWeightedScore: gptr.Indirect(cer.EnableWeightedScore),
44+
EvaluatorScoreWeights: cer.GetEvaluatorScoreWeights(),
4445
}
4546
}
4647
return ec, nil

backend/modules/evaluation/application/convertor/experiment/expt_template.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ func ConvertCreateExptTemplateReq(req *expt.CreateExperimentTemplateRequest) (*e
3434

3535
// 转换模板配置
3636
templateConf := &entity.ExptTemplateConfiguration{
37-
EnableWeightedScore: gptr.Indirect(req.EnableWeightedScore),
38-
EvaluatorScoreWeights: req.GetEvaluatorScoreWeights(),
39-
ItemConcurNum: ptr.ConvIntPtr[int32, int](req.DefaultItemConcurNum),
40-
EvaluatorsConcurNum: ptr.ConvIntPtr[int32, int](req.DefaultEvaluatorsConcurNum),
37+
ItemConcurNum: ptr.ConvIntPtr[int32, int](req.DefaultItemConcurNum),
38+
EvaluatorsConcurNum: ptr.ConvIntPtr[int32, int](req.DefaultEvaluatorsConcurNum),
4139
}
4240

4341
// 构建 ConnectorConf
@@ -51,7 +49,9 @@ func ConvertCreateExptTemplateReq(req *expt.CreateExperimentTemplateRequest) (*e
5149

5250
if len(evaluatorFieldMapping) > 0 {
5351
templateConf.ConnectorConf.EvaluatorsConf = &entity.EvaluatorsConf{
54-
EvaluatorConf: evaluatorFieldMapping,
52+
EvaluatorConf: evaluatorFieldMapping,
53+
EnableWeightedScore: gptr.Indirect(req.EnableWeightedScore),
54+
EvaluatorScoreWeights: req.GetEvaluatorScoreWeights(),
5555
}
5656
}
5757
}

backend/modules/evaluation/domain/entity/expt.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,6 @@ func (e *ExptEvaluatorVersionRef) String() string {
160160
type EvaluationConfiguration struct {
161161
ConnectorConf Connector
162162
ItemConcurNum *int
163-
// 评估器得分加权配置
164-
EnableWeightedScore bool
165-
EvaluatorScoreWeights map[int64]float64
166163
}
167164

168165
type Connector struct {
@@ -196,6 +193,9 @@ type TargetIngressConf struct {
196193
// EvaluatorsConf holds the evaluator-side connector configuration: an optional
// evaluator concurrency number, the per-evaluator configs, and the weighted-score settings.
type EvaluatorsConf struct {
197194
EvaluatorConcurNum *int
198195
EvaluatorConf []*EvaluatorConf
196+
// Weighted evaluator-score configuration (moved here from EvaluationConfiguration).
197+
EnableWeightedScore bool
198+
// NOTE(review): map key is presumably the evaluator version ID — confirm against callers.
EvaluatorScoreWeights map[int64]float64
199199
}
200200

201201
func (e *EvaluatorsConf) Valid(ctx context.Context) error {

backend/modules/evaluation/domain/entity/expt_result.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ const (
3434
// 标注项, FieldKey为TagKeyID
3535
FieldType_Annotation FieldType = 23
3636

37-
// 加权得分, FieldKey为expt_id, value为weightedScore
37+
// 加权得分, FieldKey为expt_id
3838
FieldType_WeightedScore FieldType = 24
3939
)
4040

backend/modules/evaluation/domain/entity/expt_template.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ type ExptTemplateConfiguration struct {
5454
ConnectorConf Connector
5555
ItemConcurNum *int
5656

57-
// 评估器得分加权配置
58-
EnableWeightedScore bool
59-
EvaluatorScoreWeights map[int64]float64
60-
6157
// 默认评估器并发数
6258
EvaluatorsConcurNum *int
6359
}

backend/modules/evaluation/domain/service/expt_result_aggr_impl.go

Lines changed: 122 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ func (e *ExptAggrResultServiceImpl) CreateExptAggrResult(ctx context.Context, sp
119119

120120
func (e *ExptAggrResultServiceImpl) createExptAggrResult(ctx context.Context, spaceID, experimentID int64, evaluatorVersionID2AggregatorGroup map[int64]*AggregatorGroup) error {
121121
exptAggrResults := make([]*entity.ExptAggrResult, 0)
122+
122123
for evaluatorVersionID, aggregatorGroup := range evaluatorVersionID2AggregatorGroup {
123124
aggrResult := aggregatorGroup.Result()
124125
var averageScore float64
@@ -143,7 +144,18 @@ func (e *ExptAggrResultServiceImpl) createExptAggrResult(ctx context.Context, sp
143144
})
144145
}
145146

146-
err := e.exptAggrResultRepo.BatchCreateExptAggrResult(ctx, exptAggrResults)
147+
// 追加“加权得分”聚合指标(FieldType_WeightedScore):
148+
// 基于行级 WeightedScore 做聚合(加权评分的聚合),而不是对各评估器聚合结果再加权。
149+
experiment, err := e.experimentRepo.GetByID(ctx, experimentID, spaceID)
150+
if err == nil && experiment != nil && experiment.EvalConf != nil && experiment.EvalConf.EnableWeightedScore {
151+
if weightedAggr, err := e.createWeightedScoreAggrResult(ctx, spaceID, experimentID); err != nil {
152+
return err
153+
} else if weightedAggr != nil {
154+
exptAggrResults = append(exptAggrResults, weightedAggr)
155+
}
156+
}
157+
158+
err = e.exptAggrResultRepo.BatchCreateExptAggrResult(ctx, exptAggrResults)
147159
if err != nil {
148160
return err
149161
}
@@ -153,6 +165,77 @@ func (e *ExptAggrResultServiceImpl) createExptAggrResult(ctx context.Context, sp
153165
return nil
154166
}
155167

168+
// createWeightedScoreAggrResult builds the FieldType_WeightedScore aggregate record
// for an experiment from the row-level WeightedScore of its turn results.
169+
// Only turns in TurnRunState_Success are scanned; it returns (nil, nil) when no turn result was found.
170+
func (e *ExptAggrResultServiceImpl) createWeightedScoreAggrResult(ctx context.Context, spaceID, experimentID int64) (*entity.ExptAggrResult, error) {
171+
const (
172+
// limit is the page size passed to each ScanTurnResults call.
limit = int64(500)
173+
// maxTry bounds the number of scan iterations so the loop always terminates.
maxTry = 10000
174+
)
175+
176+
aggGroup := NewAggregatorGroup(WithScoreDistributionAggregator())
177+
var (
178+
cursor int64
179+
hasData bool
180+
)
181+
182+
for i := 0; i < maxTry; i++ {
183+
turnResults, nextCursor, err := e.exptTurnResultRepo.ScanTurnResults(
184+
ctx,
185+
experimentID,
186+
[]int32{int32(entity.TurnRunState_Success)},
187+
cursor,
188+
limit,
189+
spaceID,
190+
)
191+
if err != nil {
192+
return nil, err
193+
}
194+
if len(turnResults) == 0 {
195+
break
196+
}
197+
198+
for _, tr := range turnResults {
199+
aggGroup.Append(tr.WeightedScore)
200+
hasData = true
201+
}
202+
203+
// Stop when the returned cursor is 0 (presumably end of data — confirm) or did not advance.
if nextCursor == 0 || nextCursor == cursor {
204+
break
205+
}
206+
cursor = nextCursor
207+
}
208+
209+
if !hasData {
210+
return nil, nil
211+
}
212+
213+
aggrResult := aggGroup.Result()
214+
// Score is taken from the Average aggregator result; it stays 0 if none is present.
var averageScore float64
215+
for _, r := range aggrResult.AggregatorResults {
216+
if r.AggregatorType == entity.Average {
217+
averageScore = r.GetScore()
218+
break
219+
}
220+
}
221+
222+
aggrBytes, err := json.Marshal(aggrResult)
223+
if err != nil {
224+
return nil, err
225+
}
226+
227+
return &entity.ExptAggrResult{
228+
SpaceID: spaceID,
229+
ExperimentID: experimentID,
230+
FieldType: int32(entity.FieldType_WeightedScore),
231+
// By convention, FieldKey is the experimentID.
232+
FieldKey: strconv.FormatInt(experimentID, 10),
233+
Score: averageScore,
234+
AggrResult: aggrBytes,
235+
Version: 0,
236+
}, nil
237+
}
238+
156239
func (e *ExptAggrResultServiceImpl) UpdateExptAggrResult(ctx context.Context, param *entity.UpdateExptAggrResultParam) (err error) {
157240
now := time.Now().Unix()
158241
defer func() {
@@ -307,8 +390,10 @@ func (e *ExptAggrResultServiceImpl) BatchGetExptAggrResultByExperimentIDs(ctx co
307390
for exptID, exptResult := range expt2AggrResults {
308391
evaluatorResults := make(map[int64]*entity.EvaluatorAggregateResult)
309392
annotationResults := make(map[int64]*entity.AnnotationAggregateResult)
393+
var weightedResults []*entity.AggregatorResult
310394

311395
for _, fieldResult := range exptResult {
396+
// 标注类聚合
312397
if fieldResult.FieldType == int32(entity.FieldType_Annotation) {
313398
tagKeyID, err := strconv.ParseInt(fieldResult.FieldKey, 10, 64)
314399
if err != nil {
@@ -332,46 +417,55 @@ func (e *ExptAggrResultServiceImpl) BatchGetExptAggrResultByExperimentIDs(ctx co
332417
annotationResults[tagKeyID] = annotationResult
333418
}
334419

335-
if fieldResult.FieldType != int32(entity.FieldType_EvaluatorScore) {
336-
continue
337-
}
420+
// 评估器聚合得分
421+
if fieldResult.FieldType == int32(entity.FieldType_EvaluatorScore) {
422+
evaluatorVersionID, err := strconv.ParseInt(fieldResult.FieldKey, 10, 64)
423+
if err != nil {
424+
return nil, fmt.Errorf("failed to parse evaluator version id from field key %s, err: %v", fieldResult.FieldKey, err)
425+
}
338426

339-
evaluatorVersionID, err := strconv.ParseInt(fieldResult.FieldKey, 10, 64)
340-
if err != nil {
341-
return nil, fmt.Errorf("failed to parse evaluator version id from field key %s, err: %v", fieldResult.FieldKey, err)
342-
}
427+
aggregateResultDO := entity.AggregateResult{}
428+
err = json.Unmarshal(fieldResult.AggrResult, &aggregateResultDO)
429+
if err != nil {
430+
return nil, fmt.Errorf("json.Unmarshal(%s) failed, err: %v", fieldResult.AggrResult, err)
431+
}
343432

344-
aggregateResultDO := entity.AggregateResult{}
345-
err = json.Unmarshal(fieldResult.AggrResult, &aggregateResultDO)
346-
if err != nil {
347-
return nil, fmt.Errorf("json.Unmarshal(%s) failed, err: %v", fieldResult.AggrResult, err)
348-
}
433+
evaluator, ok := versionID2Evaluator[evaluatorVersionID]
434+
if !ok {
435+
return nil, fmt.Errorf("failed to get evaluator by version_id %d", evaluatorVersionID)
436+
}
349437

350-
evaluator, ok := versionID2Evaluator[evaluatorVersionID]
351-
if !ok {
352-
return nil, fmt.Errorf("failed to get evaluator by version_id %d", evaluatorVersionID)
438+
evaluatorAggrResult := entity.EvaluatorAggregateResult{
439+
EvaluatorID: evaluator.ID,
440+
EvaluatorVersionID: evaluatorVersionID,
441+
AggregatorResults: aggregateResultDO.AggregatorResults,
442+
Name: gptr.Of(evaluator.Name),
443+
Version: gptr.Of(evaluator.GetVersion()),
444+
}
445+
evaluatorResults[evaluatorVersionID] = &evaluatorAggrResult
446+
continue
353447
}
354448

355-
evaluatorAggrResult := entity.EvaluatorAggregateResult{
356-
EvaluatorID: evaluator.ID,
357-
EvaluatorVersionID: evaluatorVersionID,
358-
AggregatorResults: aggregateResultDO.AggregatorResults,
359-
Name: gptr.Of(evaluator.Name),
360-
Version: gptr.Of(evaluator.GetVersion()),
449+
// 加权得分聚合(FieldType_WeightedScore):直接使用预先计算好的加权评分聚合结果
450+
if fieldResult.FieldType == int32(entity.FieldType_WeightedScore) {
451+
aggregateResultDO := entity.AggregateResult{}
452+
if err := json.Unmarshal(fieldResult.AggrResult, &aggregateResultDO); err != nil {
453+
return nil, fmt.Errorf("json.Unmarshal(%s) failed, err: %v", fieldResult.AggrResult, err)
454+
}
455+
weightedResults = aggregateResultDO.AggregatorResults
456+
continue
361457
}
362-
evaluatorResults[evaluatorVersionID] = &evaluatorAggrResult
363-
364458
}
459+
365460
exptAgg := &entity.ExptAggregateResult{
366461
ExperimentID: exptID,
367462
EvaluatorResults: evaluatorResults,
368463
AnnotationResults: annotationResults,
369464
}
370465

371-
// 计算所有聚合指标的加权结果(如 avg、p99 等)
372-
experiment, err := e.experimentRepo.GetByID(ctx, exptID, spaceID)
373-
if err == nil && experiment != nil && experiment.EvalConf != nil && experiment.EvalConf.EnableWeightedScore {
374-
exptAgg.WeightedResults = e.calculateWeightedAggregateResults(evaluatorResults, experiment.EvalConf.EvaluatorScoreWeights)
466+
// 将加权聚合指标挂到结果中
467+
if len(weightedResults) > 0 {
468+
exptAgg.WeightedResults = weightedResults
375469
}
376470

377471
results = append(results, exptAgg)

backend/modules/evaluation/domain/service/expt_result_impl.go

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ func (e ExptResultServiceImpl) RecordItemRunLogs(ctx context.Context, exptID, ex
171171
scoreWeights map[int64]float64
172172
)
173173
expt, err := e.ExperimentRepo.GetByID(ctx, exptID, spaceID)
174-
if err == nil && expt != nil && expt.EvalConf != nil && expt.EvalConf.EnableWeightedScore {
174+
if err == nil && expt != nil && expt.EvalConf != nil && expt.EvalConf.ConnectorConf.EvaluatorsConf != nil && expt.EvalConf.ConnectorConf.EvaluatorsConf.EnableWeightedScore {
175175
enableWeightedScore = true
176-
scoreWeights = expt.EvalConf.EvaluatorScoreWeights
176+
scoreWeights = expt.EvalConf.ConnectorConf.EvaluatorsConf.EvaluatorScoreWeights
177177
}
178178

179179
var (
@@ -1364,10 +1364,43 @@ func calculateWeightedScore(
13641364
evaluatorRecords map[int64]*entity.EvaluatorRecord,
13651365
weights map[int64]float64,
13661366
) *float64 {
1367-
if len(evaluatorRecords) == 0 || len(weights) == 0 {
1367+
if len(evaluatorRecords) == 0 {
13681368
return nil
13691369
}
13701370

1371+
// 如果未配置权重(weights 为空),则按所有评估器权重相同计算加权分(即简单平均)
1372+
if len(weights) == 0 {
1373+
var (
1374+
sumScore float64
1375+
cnt int
1376+
)
1377+
for _, record := range evaluatorRecords {
1378+
if record == nil {
1379+
continue
1380+
}
1381+
// 获取评估器分数(优先使用修正分数)
1382+
var score *float64
1383+
if record.EvaluatorOutputData != nil && record.EvaluatorOutputData.EvaluatorResult != nil {
1384+
if record.EvaluatorOutputData.EvaluatorResult.Correction != nil &&
1385+
record.EvaluatorOutputData.EvaluatorResult.Correction.Score != nil {
1386+
score = record.EvaluatorOutputData.EvaluatorResult.Correction.Score
1387+
} else if record.EvaluatorOutputData.EvaluatorResult.Score != nil {
1388+
score = record.EvaluatorOutputData.EvaluatorResult.Score
1389+
}
1390+
}
1391+
if score == nil {
1392+
continue
1393+
}
1394+
sumScore += *score
1395+
cnt++
1396+
}
1397+
if cnt == 0 {
1398+
return nil
1399+
}
1400+
avg := sumScore / float64(cnt)
1401+
return &avg
1402+
}
1403+
13711404
var totalWeightedScore float64
13721405
var totalWeight float64
13731406
hasValidScore := false

idl/thrift/coze/loop/evaluation/domain/expt.thrift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ struct Experiment {
6363
43: optional string source_id
6464

6565
50: optional ExptTemplate expt_template
66+
// 评估器得分加权配置
67+
51: optional bool enable_weighted_score
68+
52: optional map<i64, double> evaluator_score_weights
6669
}
6770

6871
// 离线实验模板,用于预先配置评测对象、评测集与评估器,并在创建实验时复用

0 commit comments

Comments
 (0)