@@ -119,6 +119,7 @@ func (e *ExptAggrResultServiceImpl) CreateExptAggrResult(ctx context.Context, sp
119119
120120func (e * ExptAggrResultServiceImpl ) createExptAggrResult (ctx context.Context , spaceID , experimentID int64 , evaluatorVersionID2AggregatorGroup map [int64 ]* AggregatorGroup ) error {
121121 exptAggrResults := make ([]* entity.ExptAggrResult , 0 )
122+
122123 for evaluatorVersionID , aggregatorGroup := range evaluatorVersionID2AggregatorGroup {
123124 aggrResult := aggregatorGroup .Result ()
124125 var averageScore float64
@@ -143,7 +144,18 @@ func (e *ExptAggrResultServiceImpl) createExptAggrResult(ctx context.Context, sp
143144 })
144145 }
145146
146- err := e .exptAggrResultRepo .BatchCreateExptAggrResult (ctx , exptAggrResults )
147+ // 追加“加权得分”聚合指标(FieldType_WeightedScore):
148+ // 基于行级 WeightedScore 做聚合(加权评分的聚合),而不是对各评估器聚合结果再加权。
149+ experiment , err := e .experimentRepo .GetByID (ctx , experimentID , spaceID )
150+ if err == nil && experiment != nil && experiment .EvalConf != nil && experiment .EvalConf .EnableWeightedScore {
151+ if weightedAggr , err := e .createWeightedScoreAggrResult (ctx , spaceID , experimentID ); err != nil {
152+ return err
153+ } else if weightedAggr != nil {
154+ exptAggrResults = append (exptAggrResults , weightedAggr )
155+ }
156+ }
157+
158+ err = e .exptAggrResultRepo .BatchCreateExptAggrResult (ctx , exptAggrResults )
147159 if err != nil {
148160 return err
149161 }
@@ -153,6 +165,77 @@ func (e *ExptAggrResultServiceImpl) createExptAggrResult(ctx context.Context, sp
153165 return nil
154166}
155167
168+ // createWeightedScoreAggrResult 基于行级 WeightedScore 计算聚合指标
169+ // 只统计成功的轮次(TurnRunState_Success)
170+ func (e * ExptAggrResultServiceImpl ) createWeightedScoreAggrResult (ctx context.Context , spaceID , experimentID int64 ) (* entity.ExptAggrResult , error ) {
171+ const (
172+ limit = int64 (500 )
173+ maxTry = 10000
174+ )
175+
176+ aggGroup := NewAggregatorGroup (WithScoreDistributionAggregator ())
177+ var (
178+ cursor int64
179+ hasData bool
180+ )
181+
182+ for i := 0 ; i < maxTry ; i ++ {
183+ turnResults , nextCursor , err := e .exptTurnResultRepo .ScanTurnResults (
184+ ctx ,
185+ experimentID ,
186+ []int32 {int32 (entity .TurnRunState_Success )},
187+ cursor ,
188+ limit ,
189+ spaceID ,
190+ )
191+ if err != nil {
192+ return nil , err
193+ }
194+ if len (turnResults ) == 0 {
195+ break
196+ }
197+
198+ for _ , tr := range turnResults {
199+ aggGroup .Append (tr .WeightedScore )
200+ hasData = true
201+ }
202+
203+ if nextCursor == 0 || nextCursor == cursor {
204+ break
205+ }
206+ cursor = nextCursor
207+ }
208+
209+ if ! hasData {
210+ return nil , nil
211+ }
212+
213+ aggrResult := aggGroup .Result ()
214+ var averageScore float64
215+ for _ , r := range aggrResult .AggregatorResults {
216+ if r .AggregatorType == entity .Average {
217+ averageScore = r .GetScore ()
218+ break
219+ }
220+ }
221+
222+ aggrBytes , err := json .Marshal (aggrResult )
223+ if err != nil {
224+ return nil , err
225+ }
226+
227+ return & entity.ExptAggrResult {
228+ SpaceID : spaceID ,
229+ ExperimentID : experimentID ,
230+ FieldType : int32 (entity .FieldType_WeightedScore ),
231+ // 约定 FieldKey 为 experimentID
232+ FieldKey : strconv .FormatInt (experimentID , 10 ),
233+ Score : averageScore ,
234+ AggrResult : aggrBytes ,
235+ Version : 0 ,
236+ }, nil
237+ }
238+
156239func (e * ExptAggrResultServiceImpl ) UpdateExptAggrResult (ctx context.Context , param * entity.UpdateExptAggrResultParam ) (err error ) {
157240 now := time .Now ().Unix ()
158241 defer func () {
@@ -307,8 +390,10 @@ func (e *ExptAggrResultServiceImpl) BatchGetExptAggrResultByExperimentIDs(ctx co
307390 for exptID , exptResult := range expt2AggrResults {
308391 evaluatorResults := make (map [int64 ]* entity.EvaluatorAggregateResult )
309392 annotationResults := make (map [int64 ]* entity.AnnotationAggregateResult )
393+ var weightedResults []* entity.AggregatorResult
310394
311395 for _ , fieldResult := range exptResult {
396+ // 标注类聚合
312397 if fieldResult .FieldType == int32 (entity .FieldType_Annotation ) {
313398 tagKeyID , err := strconv .ParseInt (fieldResult .FieldKey , 10 , 64 )
314399 if err != nil {
@@ -332,46 +417,55 @@ func (e *ExptAggrResultServiceImpl) BatchGetExptAggrResultByExperimentIDs(ctx co
332417 annotationResults [tagKeyID ] = annotationResult
333418 }
334419
335- if fieldResult .FieldType != int32 (entity .FieldType_EvaluatorScore ) {
336- continue
337- }
420+ // 评估器聚合得分
421+ if fieldResult .FieldType == int32 (entity .FieldType_EvaluatorScore ) {
422+ evaluatorVersionID , err := strconv .ParseInt (fieldResult .FieldKey , 10 , 64 )
423+ if err != nil {
424+ return nil , fmt .Errorf ("failed to parse evaluator version id from field key %s, err: %v" , fieldResult .FieldKey , err )
425+ }
338426
339- evaluatorVersionID , err := strconv .ParseInt (fieldResult .FieldKey , 10 , 64 )
340- if err != nil {
341- return nil , fmt .Errorf ("failed to parse evaluator version id from field key %s, err: %v" , fieldResult .FieldKey , err )
342- }
427+ aggregateResultDO := entity.AggregateResult {}
428+ err = json .Unmarshal (fieldResult .AggrResult , & aggregateResultDO )
429+ if err != nil {
430+ return nil , fmt .Errorf ("json.Unmarshal(%s) failed, err: %v" , fieldResult .AggrResult , err )
431+ }
343432
344- aggregateResultDO := entity.AggregateResult {}
345- err = json .Unmarshal (fieldResult .AggrResult , & aggregateResultDO )
346- if err != nil {
347- return nil , fmt .Errorf ("json.Unmarshal(%s) failed, err: %v" , fieldResult .AggrResult , err )
348- }
433+ evaluator , ok := versionID2Evaluator [evaluatorVersionID ]
434+ if ! ok {
435+ return nil , fmt .Errorf ("failed to get evaluator by version_id %d" , evaluatorVersionID )
436+ }
349437
350- evaluator , ok := versionID2Evaluator [evaluatorVersionID ]
351- if ! ok {
352- return nil , fmt .Errorf ("failed to get evaluator by version_id %d" , evaluatorVersionID )
438+ evaluatorAggrResult := entity.EvaluatorAggregateResult {
439+ EvaluatorID : evaluator .ID ,
440+ EvaluatorVersionID : evaluatorVersionID ,
441+ AggregatorResults : aggregateResultDO .AggregatorResults ,
442+ Name : gptr .Of (evaluator .Name ),
443+ Version : gptr .Of (evaluator .GetVersion ()),
444+ }
445+ evaluatorResults [evaluatorVersionID ] = & evaluatorAggrResult
446+ continue
353447 }
354448
355- evaluatorAggrResult := entity.EvaluatorAggregateResult {
356- EvaluatorID : evaluator .ID ,
357- EvaluatorVersionID : evaluatorVersionID ,
358- AggregatorResults : aggregateResultDO .AggregatorResults ,
359- Name : gptr .Of (evaluator .Name ),
360- Version : gptr .Of (evaluator .GetVersion ()),
449+ // 加权得分聚合(FieldType_WeightedScore):直接使用预先计算好的加权评分聚合结果
450+ if fieldResult .FieldType == int32 (entity .FieldType_WeightedScore ) {
451+ aggregateResultDO := entity.AggregateResult {}
452+ if err := json .Unmarshal (fieldResult .AggrResult , & aggregateResultDO ); err != nil {
453+ return nil , fmt .Errorf ("json.Unmarshal(%s) failed, err: %v" , fieldResult .AggrResult , err )
454+ }
455+ weightedResults = aggregateResultDO .AggregatorResults
456+ continue
361457 }
362- evaluatorResults [evaluatorVersionID ] = & evaluatorAggrResult
363-
364458 }
459+
365460 exptAgg := & entity.ExptAggregateResult {
366461 ExperimentID : exptID ,
367462 EvaluatorResults : evaluatorResults ,
368463 AnnotationResults : annotationResults ,
369464 }
370465
371- // 计算所有聚合指标的加权结果(如 avg、p99 等)
372- experiment , err := e .experimentRepo .GetByID (ctx , exptID , spaceID )
373- if err == nil && experiment != nil && experiment .EvalConf != nil && experiment .EvalConf .EnableWeightedScore {
374- exptAgg .WeightedResults = e .calculateWeightedAggregateResults (evaluatorResults , experiment .EvalConf .EvaluatorScoreWeights )
466+ // 将加权聚合指标挂到结果中
467+ if len (weightedResults ) > 0 {
468+ exptAgg .WeightedResults = weightedResults
375469 }
376470
377471 results = append (results , exptAgg )
0 commit comments