diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index 3d64543f1..f7737921f 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -213,6 +213,16 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e } if event != nil { + if event.After == entity.TaskStatusRunning && event.Before != entity.TaskStatusRunning { + if err := t.TaskRepo.AddNonFinalTask(ctx, strconv.FormatInt(taskDO.WorkspaceID, 10), taskDO.ID); err != nil { + logs.CtxError(ctx, "add non final task failed, task_id=%d, err=%v", taskDO.ID, err) + } + } + if event.Before == entity.TaskStatusRunning && event.After == entity.TaskStatusPending { + if err := t.TaskRepo.RemoveNonFinalTask(ctx, strconv.FormatInt(taskDO.WorkspaceID, 10), taskDO.ID); err != nil { + logs.CtxError(ctx, "remove non final task failed, task_id=%d, err=%v", taskDO.ID, err) + } + } if event.After == entity.TaskStatusDisabled { // 禁用操作处理 proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType) diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index 85b8d235a..c8e9300ca 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -442,6 +442,156 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { }) assert.EqualError(t, err, "finish fail") }) + + t.Run("running to pending removes cache and skips finish", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + repoMock := repomocks.NewMockITaskRepo(ctrl) + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: &entity.Sampler{}, + EffectiveTime: &entity.EffectiveTime{ + StartAt: time.Now().Add(-time.Minute).UnixMilli(), + EndAt: time.Now().Add(time.Minute).UnixMilli(), + }, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, + } + + repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) + repoMock.EXPECT().RemoveNonFinalTask(gomock.Any(), "2", int64(1)).Return(nil) + repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil) + + procMock := &fakeProcessor{onTaskRunFinishedErr: errors.New("finish fail")} + tp := processor.NewTaskProcessor() + tp.Register(entity.TaskTypeAutoEval, procMock) + svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp} + + err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ + TaskID: 1, + WorkspaceID: 2, + TaskStatus: gptr.Of(entity.TaskStatusPending), + UserID: "user", + }) + assert.NoError(t, err) + assert.Equal(t, entity.TaskStatusPending, taskDO.TaskStatus) + }) + + t.Run("running to pending remove cache error", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + repoMock := repomocks.NewMockITaskRepo(ctrl) + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: &entity.Sampler{}, + EffectiveTime: &entity.EffectiveTime{ + StartAt: time.Now().Add(-time.Minute).UnixMilli(), + EndAt: time.Now().Add(time.Minute).UnixMilli(), + }, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, + } + + repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) + repoMock.EXPECT().RemoveNonFinalTask(gomock.Any(), "2", int64(1)).Return(errors.New("remove cache fail")) + repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil) + + procMock := &fakeProcessor{} + tp := processor.NewTaskProcessor() + tp.Register(entity.TaskTypeAutoEval, procMock) + svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp} + + err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ + TaskID: 1, + WorkspaceID: 2, + TaskStatus: gptr.Of(entity.TaskStatusPending), + UserID: "user", + }) + assert.NoError(t, err) + assert.Equal(t, entity.TaskStatusPending, taskDO.TaskStatus) + }) + + t.Run("pending to running adds cache", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + repoMock := repomocks.NewMockITaskRepo(ctrl) + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusPending, + Sampler: &entity.Sampler{}, + EffectiveTime: &entity.EffectiveTime{ + StartAt: time.Now().Add(-time.Minute).UnixMilli(), + EndAt: time.Now().Add(time.Minute).UnixMilli(), + }, + } + + repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) + repoMock.EXPECT().AddNonFinalTask(gomock.Any(), "2", int64(1)).Return(nil) + repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil) + + procMock := &fakeProcessor{} + tp := processor.NewTaskProcessor() + tp.Register(entity.TaskTypeAutoEval, procMock) + svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp} + + err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ + TaskID: 1, + WorkspaceID: 2, + TaskStatus: gptr.Of(entity.TaskStatusRunning), + UserID: "user", + }) + assert.NoError(t, err) + assert.Equal(t, entity.TaskStatusRunning, taskDO.TaskStatus) + }) + + t.Run("pending to running add cache error", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + repoMock := repomocks.NewMockITaskRepo(ctrl) + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusPending, + Sampler: &entity.Sampler{}, + EffectiveTime: &entity.EffectiveTime{ + StartAt: time.Now().Add(-time.Minute).UnixMilli(), + EndAt: time.Now().Add(time.Minute).UnixMilli(), + }, + } + + repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) + repoMock.EXPECT().AddNonFinalTask(gomock.Any(), "2", int64(1)).Return(errors.New("add cache fail")) + repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil) + + procMock := &fakeProcessor{} + tp := processor.NewTaskProcessor() + tp.Register(entity.TaskTypeAutoEval, procMock) + svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp} + + err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ + TaskID: 1, + WorkspaceID: 2, + TaskStatus: gptr.Of(entity.TaskStatusRunning), + UserID: "user", + }) + assert.NoError(t, err) + assert.Equal(t, entity.TaskStatusRunning, taskDO.TaskStatus) + }) } func TestTaskServiceImpl_ListTasks(t *testing.T) { diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go index c87a5bd96..dd7d082a1 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go @@ -127,7 +127,7 @@ func (p *AutoEvaluateProcessor) Invoke(ctx context.Context, trigger *taskexe.Tri _ = p.taskRepo.DecrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL) return nil } - _, err := p.evaluationSvc.InvokeExperiment(ctx, &rpc.InvokeExperimentReq{ + addedItems, err := p.evaluationSvc.InvokeExperiment(ctx, &rpc.InvokeExperimentReq{ WorkspaceID: workspaceID, EvaluationSetID: taskRun.GetTaskRunConfig().GetAutoEvaluateRunConfig().GetEvalID(), Items: []*eval_set.EvaluationSetItem{ @@ -171,6 +171,11 @@ func (p *AutoEvaluateProcessor) Invoke(ctx context.Context, trigger *taskexe.Tri } return err } + if addedItems <= 0 { + _ = p.taskRepo.DecrTaskCount(ctx, trigger.Task.ID, taskTTL) + _ = p.taskRepo.DecrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL) + return nil + } return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index c716fec39..b2260f19b 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -20,6 +20,7 @@ import ( "github.com/cloudwego/kitex/client/callopt" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" + datadataset "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/expt" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/dataset" @@ -457,7 +458,12 @@ func TestAutoEvaluateProcessor_Invoke_WithEvaluationProvider_SuccessAddedItems(t repoMock.EXPECT().DecrTaskCount(gomock.Any(), gomock.Any(), gomock.Any()).Times(0) repoMock.EXPECT().DecrTaskRunCount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) - client := &fakeExperimentClient{invokeResp: &expt.InvokeExperimentResponse{AddedItems: map[int64]int64{1: 1, 2: 1}}} + client := &fakeExperimentClient{invokeResp: &expt.InvokeExperimentResponse{ + ItemOutputs: []*datadataset.CreateDatasetItemOutput{ + {IsNewItem: gptr.Of(true)}, + {IsNewItem: gptr.Of(true)}, + }, + }} provider := evalrpc.NewEvaluationRPCProvider(client) proc := &AutoEvaluateProcessor{evaluationSvc: provider, taskRepo: repoAdapter} err := proc.Invoke(context.Background(), trigger) @@ -765,6 +771,34 @@ func TestAutoEvaluateProcessor_Invoke(t *testing.T) { err := proc.Invoke(context.Background(), trigger) assert.NoError(t, err) }) + + t.Run("success but addedItems is zero", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + taskObj := buildTestTask(t) + taskObj.Sampler.SampleSize = 5 + trigger := buildTrigger(taskObj, textSchema) + + repoMock := repomocks.NewMockITaskRepo(ctrl) + repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock} + repoMock.EXPECT().IncrTaskCount(gomock.Any(), taskObj.ID, gomock.Any()).Return(nil) + repoMock.EXPECT().IncrTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID, gomock.Any()).Return(nil) + repoMock.EXPECT().GetTaskCount(gomock.Any(), taskObj.ID).Return(int64(1), nil) + repoMock.EXPECT().GetTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID).Return(int64(1), nil) + repoMock.EXPECT().DecrTaskCount(gomock.Any(), taskObj.ID, gomock.Any()).Return(nil) + repoMock.EXPECT().DecrTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID, gomock.Any()).Return(nil) + + evalMock := &fakeEvaluationAdapter{} + evalMock.invokeResp.addedItems = 0 + + proc := &AutoEvaluateProcessor{ + evaluationSvc: evalMock, + taskRepo: repoAdapter, + } + err := proc.Invoke(context.Background(), trigger) + assert.NoError(t, err) + }) } func TestAutoEvaluateProcessor_OnUpdateTaskChange(t *testing.T) { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index 819051c20..461dd5121 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -308,7 +308,10 @@ func (h *TraceHubServiceImpl) combineFilters(filters ...*loop_span.FilterFields) // fetchSpans paginates span data func (h *TraceHubServiceImpl) fetchSpans(ctx context.Context, listParam *repo.ListSpansParam, sub *spanSubscriber) ([]*loop_span.Span, string, error) { - result, err := h.traceRepo.ListSpans(ctx, listParam) + // 默认 30s to 60s 减少超时报错情况 + listCtx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + result, err := h.traceRepo.ListSpans(listCtx, listParam) if err != nil { logs.CtxError(ctx, "List spans failed, parma=%v, err=%v", listParam, err) return nil, "", err diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 021515a31..bf023de55 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -5,6 +5,7 @@ package tracehub import ( "context" + "errors" "fmt" "time" @@ -14,9 +15,55 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/pkg/logs" "github.com/hashicorp/go-multierror" - "github.com/pkg/errors" + pkgerrors "github.com/pkg/errors" ) +const ( + taskRunCreateLockKeyTemplate = "observability:task_run:create:%d:%s:%d:%d" + taskRunCreateLockTTL = 30 * time.Second +) + +var errSkipSubscriber = errors.New("skip subscriber") + +func calcTaskRunEndAt(t *entity.ObservabilityTask, runStartAt int64) int64 { + if !t.Sampler.IsCycle { + return t.EffectiveTime.EndAt + } + switch t.Sampler.CycleTimeUnit { + case entity.TimeUnitDay: + return runStartAt + t.Sampler.CycleInterval*24*time.Hour.Milliseconds() + case entity.TimeUnitWeek: + return runStartAt + t.Sampler.CycleInterval*7*24*time.Hour.Milliseconds() + default: + return runStartAt + t.Sampler.CycleInterval*10*time.Minute.Milliseconds() + } +} + +func (h *TraceHubServiceImpl) withTaskRunCreateLock( + ctx context.Context, + taskID int64, + runType entity.TaskRunType, + runStartAt int64, + runEndAt int64, + fn func() error, +) error { + if h.locker == nil { + return fn() + } + key := fmt.Sprintf(taskRunCreateLockKeyTemplate, taskID, runType, runStartAt, runEndAt) + locked, err := h.locker.Lock(ctx, key, taskRunCreateLockTTL) + if err != nil { + return err + } + if !locked { + return nil + } + defer func() { + _, _ = h.locker.Unlock(key) + }() + return fn() +} + func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, span *loop_span.Span) error { logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", span.LogID, span.TraceID, span.SpanID) @@ -76,9 +123,15 @@ func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *l if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.WorkspaceID) { continue } + if taskDO.EffectiveTime == nil || taskDO.EffectiveTime.StartAt == 0 { continue } + + if taskDO.TaskStatus == entity.TaskStatusPending { + continue + } + if span.StartTime < taskDO.EffectiveTime.StartAt { logs.CtxInfo(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) continue @@ -110,7 +163,7 @@ func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *l ok, err := s.Match(ctx, span) logs.CtxInfo(ctx, "Match span, task_id=%d, trace_id=%s, span_id=%s, ok=%v, err=%v", s.taskID, span.TraceID, span.SpanID, ok, err) if err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "match span,task_id=%d, trace_id=%s, span_id=%s", s.taskID, span.TraceID, span.SpanID)) + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "match span,task_id=%d, trace_id=%s, span_id=%s", s.taskID, span.TraceID, span.SpanID)) continue } if ok { @@ -128,56 +181,65 @@ func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *l func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - // First step: lock for task status change - // Task run status var runStartAt, runEndAt int64 if sub.t.TaskStatus == entity.TaskStatusUnstarted { - logs.CtxWarn(ctx, "task is unstarted, need sub.Creative") runStartAt = sub.t.EffectiveTime.StartAt - if !sub.t.Sampler.IsCycle { - runEndAt = sub.t.EffectiveTime.EndAt - } else { - switch sub.t.Sampler.CycleTimeUnit { - case entity.TimeUnitDay: - runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*24*time.Hour.Milliseconds() - case entity.TimeUnitWeek: - runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*7*24*time.Hour.Milliseconds() - default: - runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*10*time.Minute.Milliseconds() + runEndAt = calcTaskRunEndAt(sub.t, runStartAt) + if err := h.withTaskRunCreateLock(ctx, sub.taskID, sub.runType, runStartAt, runEndAt, func() error { + taskRunConfig, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, &sub.t.WorkspaceID, sub.taskID) + if err != nil { + return err } - } - if err := sub.Creative(ctx, runStartAt, runEndAt); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) - continue - } - if err := sub.processor.OnTaskUpdated(ctx, sub.t, entity.TaskStatusRunning); err != nil { - logs.CtxWarn(ctx, "OnTaskUpdated, task_id=%d, err=%v", sub.taskID, err) + if taskRunConfig != nil && + taskRunConfig.RunStartAt.UnixMilli() == runStartAt && + taskRunConfig.RunEndAt.UnixMilli() == runEndAt { + if sub.t.TaskStatus != entity.TaskStatusUnstarted { + return nil + } + sub.t.TaskStatus = entity.TaskStatusRunning + if err := sub.processor.OnTaskUpdated(ctx, sub.t, entity.TaskStatusRunning); err != nil { + logs.CtxWarn(ctx, "sub.processor.OnTaskUpdated err:%v", err) + return errSkipSubscriber + } + return nil + } + if err := sub.Creative(ctx, runStartAt, runEndAt); err != nil { + return err + } + if err := sub.processor.OnTaskUpdated(ctx, sub.t, entity.TaskStatusRunning); err != nil { + logs.CtxWarn(ctx, "sub.processor.OnTaskUpdated err:%v", err) + return errSkipSubscriber + } + sub.t.TaskStatus = entity.TaskStatusRunning + return nil + }); err != nil { + if errors.Is(err, errSkipSubscriber) { + continue + } + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) continue } } - // Fetch the corresponding task config + taskRunConfig, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, &sub.t.WorkspaceID, sub.taskID) if err != nil { logs.CtxWarn(ctx, "GetLatestNewDataTaskRun, task_id=%d, err=%v", sub.taskID, err) continue } if taskRunConfig == nil { - logs.CtxWarn(ctx, "task run config not found, task_id=%d", sub.taskID) runStartAt = sub.t.EffectiveTime.StartAt - if !sub.t.Sampler.IsCycle { - runEndAt = sub.t.EffectiveTime.EndAt - } else { - switch sub.t.Sampler.CycleTimeUnit { - case entity.TimeUnitDay: - runEndAt = runStartAt + sub.t.Sampler.CycleInterval*24*time.Hour.Milliseconds() - case entity.TimeUnitWeek: - runEndAt = runStartAt + sub.t.Sampler.CycleInterval*7*24*time.Hour.Milliseconds() - default: - runEndAt = runStartAt + sub.t.Sampler.CycleInterval*10*time.Minute.Milliseconds() + runEndAt = calcTaskRunEndAt(sub.t, runStartAt) + if err = h.withTaskRunCreateLock(ctx, sub.taskID, sub.runType, runStartAt, runEndAt, func() error { + existing, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, &sub.t.WorkspaceID, sub.taskID) + if err != nil { + return err } - } - if err = sub.Creative(ctx, runStartAt, runEndAt); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "task run config not found,creative processor, task_id=%d", sub.taskID)) + if existing != nil { + return nil + } + return sub.Creative(ctx, runStartAt, runEndAt) + }); err != nil { + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "task run config not found,creative processor, task_id=%d", sub.taskID)) } continue } @@ -192,7 +254,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubsc IsFinish: true, }); err != nil { logs.CtxWarn(ctx, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID) - merr = multierror.Append(merr, errors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) continue } } @@ -210,7 +272,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubsc TaskRun: taskRunConfig, IsFinish: true, }); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) continue } } @@ -224,13 +286,24 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubsc TaskRun: taskRunConfig, IsFinish: false, }); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) continue } runStartAt = taskRunConfig.RunEndAt.UnixMilli() runEndAt = taskRunConfig.RunEndAt.UnixMilli() + (taskRunConfig.RunEndAt.UnixMilli() - taskRunConfig.RunStartAt.UnixMilli()) - if err := sub.Creative(ctx, runStartAt, runEndAt); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "time.Now().After(cycleEndTime) creative processor, task_id=%d", sub.taskID)) + if err := h.withTaskRunCreateLock(ctx, sub.taskID, sub.runType, runStartAt, runEndAt, func() error { + existing, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, &sub.t.WorkspaceID, sub.taskID) + if err != nil { + return err + } + if existing != nil && + existing.RunStartAt.UnixMilli() == runStartAt && + existing.RunEndAt.UnixMilli() == runEndAt { + return nil + } + return sub.Creative(ctx, runStartAt, runEndAt) + }); err != nil { + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "time.Now().After(cycleEndTime) creative processor, task_id=%d", sub.taskID)) continue } } @@ -242,7 +315,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubsc TaskRun: taskRunConfig, IsFinish: false, }); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "time.Now().After(endTime) Finish processor, task_id=%d", sub.taskID)) continue } } @@ -258,7 +331,7 @@ func (h *TraceHubServiceImpl) dispatch(ctx context.Context, span *loop_span.Span continue } if err := sub.AddSpan(ctx, span); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "add span to subscriber, log_id=%s, trace_id=%s, span_id=%s, task_id=%d", + merr = multierror.Append(merr, pkgerrors.WithMessagef(err, "add span to subscriber, log_id=%s, trace_id=%s, span_id=%s, task_id=%d", span.LogID, span.TraceID, span.SpanID, sub.taskID)) } else { logs.CtxInfo(ctx, "add span to subscriber, task_id=%d, log_id=%s, trace_id=%s, span_id=%s", sub.taskID, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index 424d095c3..d7c801a63 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -6,11 +6,14 @@ package tracehub import ( "context" "errors" + "sync" + "sync/atomic" "testing" "time" "go.uber.org/mock/gomock" + lock_mocks "github.com/coze-dev/coze-loop/backend/infra/lock/mocks" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" taskconvertor "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" @@ -19,6 +22,7 @@ import ( tenant_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" trace_service_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/mocks" @@ -220,7 +224,7 @@ func TestTraceHubServiceImpl_preDispatchHandlesUnstartedAndLimits(t *testing.T) RunEndAt: now.Add(-30 * time.Minute), } - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil).AnyTimes() mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(1), nil) mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(1), nil) @@ -282,7 +286,7 @@ func TestTraceHubServiceImpl_preDispatchHandlesMissingTaskRunConfig(t *testing.T BaseInfo: &common.BaseInfo{}, }) - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil).AnyTimes() impl := &TraceHubServiceImpl{taskRepo: mockRepo} @@ -296,6 +300,350 @@ func TestTraceHubServiceImpl_preDispatchHandlesMissingTaskRunConfig(t *testing.T require.Equal(t, 0, procMock.finishChangeInvoked) } +func TestTraceHubServiceImpl_preDispatchDedupTaskRunCreateWithLock(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + mockLocker := lock_mocks.NewMockILocker(ctrl) + procMock := &stubProcessor{} + + now := time.Now() + startAt := now.Add(-10 * time.Minute).UnixMilli() + endAt := now.Add(time.Hour).UnixMilli() + workspaceID := int64(3103) + taskID := int64(3204) + + sampl := &task.Sampler{ + SampleRate: floatPtr(1), + SampleSize: int64Ptr(5), + IsCycle: boolPtr(false), + } + rule := &task.Rule{ + EffectiveTime: &task.EffectiveTime{ + StartAt: ptr.Of(startAt), + EndAt: ptr.Of(endAt), + }, + Sampler: sampl, + } + + sub := &spanSubscriber{ + taskID: taskID, + processor: procMock, + taskRepo: mockRepo, + runType: entity.TaskRunTypeNewData, + } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) + + taskRunConfig := &entity.TaskRun{ + ID: 3305, + TaskID: taskID, + WorkspaceID: workspaceID, + TaskType: entity.TaskRunTypeNewData, + RunStatus: task.TaskStatusRunning, + RunStartAt: time.UnixMilli(startAt), + RunEndAt: time.UnixMilli(endAt), + } + + getLatestCall := 0 + mockRepo.EXPECT(). + GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID). + DoAndReturn(func(context.Context, *int64, int64) (*entity.TaskRun, error) { + getLatestCall++ + if getLatestCall == 1 { + return nil, nil + } + return taskRunConfig, nil + }). + AnyTimes() + mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(0), nil).AnyTimes() + mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(0), nil).AnyTimes() + + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes() + mockLocker.EXPECT().Unlock(gomock.Any()).Return(true, nil).AnyTimes() + + impl := &TraceHubServiceImpl{taskRepo: mockRepo, locker: mockLocker} + + require.NoError(t, impl.preDispatch(context.Background(), []*spanSubscriber{sub})) + require.NoError(t, impl.preDispatch(context.Background(), []*spanSubscriber{sub})) + require.Equal(t, 1, len(procMock.createTaskRunReqs)) + require.Equal(t, 1, procMock.updateCallCount) +} + +func TestTraceHubServiceImpl_preDispatchConcurrent(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + mockLocker := lock_mocks.NewMockILocker(ctrl) + + now := time.Now() + startAt := now.Add(-10 * time.Minute).UnixMilli() + endAt := now.Add(time.Hour).UnixMilli() + workspaceID := int64(3301) + taskID := int64(3402) + + sampl := &task.Sampler{ + SampleRate: floatPtr(1), + SampleSize: int64Ptr(5), + IsCycle: boolPtr(false), + } + rule := &task.Rule{ + EffectiveTime: &task.EffectiveTime{ + StartAt: ptr.Of(startAt), + EndAt: ptr.Of(endAt), + }, + Sampler: sampl, + } + + // 构造基础 task DO,后续在 goroutine 中深拷贝使用 + baseTask := toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) + + taskRunConfig := &entity.TaskRun{ + ID: 3503, + TaskID: taskID, + WorkspaceID: workspaceID, + TaskType: entity.TaskRunTypeNewData, + RunStatus: task.TaskStatusRunning, + RunStartAt: time.UnixMilli(startAt), + RunEndAt: time.UnixMilli(endAt), + } + + // 并发控制状态 + var createCount int32 + var taskRunCreated atomic.Bool + + // 1. 模拟 GetLatestNewDataTaskRun: + // - 如果已经创建过 (taskRunCreated=true),返回存在的 config + // - 否则返回 nil,触发创建逻辑 + mockRepo.EXPECT(). + GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID). + DoAndReturn(func(context.Context, *int64, int64) (*entity.TaskRun, error) { + if taskRunCreated.Load() { + return taskRunConfig, nil + } + return nil, nil + }). + AnyTimes() + + // 模拟 Lock/Unlock:总是成功 + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes() + mockLocker.EXPECT().Unlock(gomock.Any()).Return(true, nil).AnyTimes() + + safeProc := &concurrentStubProcessor{ + createAction: func() error { + atomic.AddInt32(&createCount, 1) + taskRunCreated.Store(true) + time.Sleep(10 * time.Millisecond) + return nil + }, + } + + mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(0), nil).AnyTimes() + mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(0), nil).AnyTimes() + + impl := &TraceHubServiceImpl{taskRepo: mockRepo, locker: mockLocker} + + concurrency := 10 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + myTask := *baseTask + + mySub := &spanSubscriber{ + taskID: taskID, + processor: safeProc, + taskRepo: mockRepo, + runType: entity.TaskRunTypeNewData, + t: &myTask, + } + _ = impl.preDispatch(context.Background(), []*spanSubscriber{mySub}) + }() + } + + wg.Wait() + require.Equal(t, int32(1), atomic.LoadInt32(&createCount)) +} + +func TestTraceHubServiceImpl_preDispatchHandlesUnstartedTaskWithExistingRunConfig(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + mockLocker := lock_mocks.NewMockILocker(ctrl) + procMock := &stubProcessor{} + + now := time.Now() + startAt := now.Add(-10 * time.Minute).UnixMilli() + endAt := now.Add(time.Hour).UnixMilli() + workspaceID := int64(3601) + taskID := int64(3702) + + sampl := &task.Sampler{ + SampleRate: floatPtr(1), + SampleSize: int64Ptr(5), + IsCycle: boolPtr(false), + } + rule := &task.Rule{ + EffectiveTime: &task.EffectiveTime{ + StartAt: ptr.Of(startAt), + EndAt: ptr.Of(endAt), + }, + Sampler: sampl, + } + + sub := &spanSubscriber{ + taskID: taskID, + processor: procMock, + taskRepo: mockRepo, + runType: entity.TaskRunTypeNewData, + } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) + + taskRunConfig := &entity.TaskRun{ + ID: 3803, + TaskID: taskID, + WorkspaceID: workspaceID, + TaskType: entity.TaskRunTypeNewData, + RunStatus: task.TaskStatusRunning, + RunStartAt: time.UnixMilli(startAt), + RunEndAt: time.UnixMilli(endAt), + } + + // 模拟 GetLatestNewDataTaskRun 返回已存在的配置 + mockRepo.EXPECT(). + GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID). + Return(taskRunConfig, nil). + AnyTimes() + + // 模拟 Lock/Unlock + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes() + mockLocker.EXPECT().Unlock(gomock.Any()).Return(true, nil).AnyTimes() + + mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(0), nil).AnyTimes() + mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(0), nil).AnyTimes() + + impl := &TraceHubServiceImpl{taskRepo: mockRepo, locker: mockLocker} + + err := impl.preDispatch(context.Background(), []*spanSubscriber{sub}) + require.NoError(t, err) + + require.Empty(t, procMock.createTaskRunReqs) + // 应调用 OnTaskUpdated 将状态更新为 Running + require.Equal(t, 1, procMock.updateCallCount) +} + +func TestTraceHubServiceImpl_preDispatchHandlesUnstartedTaskWithExistingRunConfig_UpdateError(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + mockLocker := lock_mocks.NewMockILocker(ctrl) + procMock := &stubProcessor{updateErr: errors.New("update fail")} + + now := time.Now() + startAt := now.Add(-10 * time.Minute).UnixMilli() + endAt := now.Add(time.Hour).UnixMilli() + workspaceID := int64(3901) + taskID := int64(4002) + + sampl := &task.Sampler{ + SampleRate: floatPtr(1), + SampleSize: int64Ptr(5), + IsCycle: boolPtr(false), + } + rule := &task.Rule{ + EffectiveTime: &task.EffectiveTime{ + StartAt: ptr.Of(startAt), + EndAt: ptr.Of(endAt), + }, + Sampler: sampl, + } + + sub := &spanSubscriber{ + taskID: taskID, + processor: procMock, + taskRepo: mockRepo, + runType: entity.TaskRunTypeNewData, + } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) + + taskRunConfig := &entity.TaskRun{ + ID: 4103, + TaskID: taskID, + WorkspaceID: workspaceID, + TaskType: entity.TaskRunTypeNewData, + RunStatus: task.TaskStatusRunning, + RunStartAt: time.UnixMilli(startAt), + RunEndAt: time.UnixMilli(endAt), + } + + mockRepo.EXPECT(). + GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID). + Return(taskRunConfig, nil). + AnyTimes() + + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes() + mockLocker.EXPECT().Unlock(gomock.Any()).Return(true, nil).AnyTimes() + + impl := &TraceHubServiceImpl{taskRepo: mockRepo, locker: mockLocker} + + err := impl.preDispatch(context.Background(), []*spanSubscriber{sub}) + // 因为返回的是 errSkipSubscriber,在 loop 中会被 swallow 掉,所以外层 err 应该是 nil + require.NoError(t, err) + + require.Empty(t, procMock.createTaskRunReqs) + require.Equal(t, 1, procMock.updateCallCount) +} + +// 线程安全的桩 Processor +type concurrentStubProcessor struct { + stubProcessor // 继承其他方法的默认实现 + createAction func() error +} + +func (p *concurrentStubProcessor) OnTaskRunCreated(ctx context.Context, req taskexe.OnTaskRunCreatedReq) error { + if p.createAction != nil { + return p.createAction() + } + return nil +} + +func (p *concurrentStubProcessor) OnTaskUpdated(ctx context.Context, task *entity.ObservabilityTask, status entity.TaskStatus) error { + return nil +} + func TestTraceHubServiceImpl_preDispatchHandlesNonCycle(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -347,7 +695,7 @@ func TestTraceHubServiceImpl_preDispatchHandlesNonCycle(t *testing.T) { RunEndAt: now.Add(30 * time.Minute), } - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil).AnyTimes() mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(0), nil) mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(0), nil) @@ -401,7 +749,7 @@ func TestTraceHubServiceImpl_preDispatchHandlesCycleDefaultUnit(t *testing.T) { BaseInfo: &common.BaseInfo{}, }) - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil).AnyTimes() impl := &TraceHubServiceImpl{taskRepo: mockRepo} @@ -468,7 +816,7 @@ func TestTraceHubServiceImpl_preDispatchTimeLimitFinishError(t *testing.T) { RunEndAt: now.Add(-2 * time.Hour), } - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil).AnyTimes() mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(0), nil).AnyTimes() mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(0), nil).AnyTimes() @@ -718,6 +1066,8 @@ func TestTraceHubServiceImpl_preDispatchCreativeError(t *testing.T) { BaseInfo: &common.BaseInfo{}, }) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil).AnyTimes() + impl := &TraceHubServiceImpl{taskRepo: mockRepo} err := impl.preDispatch(context.Background(), []*spanSubscriber{sub}) @@ -802,7 +1152,8 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { BaseInfo: &common.BaseInfo{}, }) - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), secondTaskID).Return(secondRun, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), firstSub.taskID).Return(nil, nil).AnyTimes() + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), secondTaskID).Return(secondRun, nil).AnyTimes() mockRepo.EXPECT().GetTaskCount(gomock.Any(), secondTaskID).Return(int64(0), nil).AnyTimes() mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), secondTaskID, secondRun.ID).Return(int64(0), nil).AnyTimes() @@ -857,6 +1208,8 @@ func TestTraceHubServiceImpl_preDispatchUpdateError(t *testing.T) { BaseInfo: &common.BaseInfo{}, }) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil).AnyTimes() + impl := &TraceHubServiceImpl{taskRepo: mockRepo} err := impl.preDispatch(context.Background(), []*spanSubscriber{sub}) @@ -951,7 +1304,7 @@ func TestTraceHubServiceImpl_preDispatchTaskRunConfigDay(t *testing.T) { BaseInfo: &common.BaseInfo{}, }) - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil).AnyTimes() impl := &TraceHubServiceImpl{taskRepo: mockRepo} @@ -1017,7 +1370,7 @@ func TestTraceHubServiceImpl_preDispatchCycleCreativeError(t *testing.T) { RunEndAt: now.Add(-time.Minute), } - mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil) + mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(taskRunConfig, nil).AnyTimes() mockRepo.EXPECT().GetTaskCount(gomock.Any(), taskID).Return(int64(0), nil) mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), taskID, taskRunConfig.ID).Return(int64(0), nil) @@ -1028,3 +1381,207 @@ func TestTraceHubServiceImpl_preDispatchCycleCreativeError(t *testing.T) { require.ErrorContains(t, err, "cycle create fail") require.Equal(t, 1, len(procMock.createTaskRunReqs)) } + +func TestTraceHubServiceImpl_buildSubscriberOfSpan_Filtering(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + configLoader := config_mocks.NewMockITraceConfig(ctrl) + + impl := &TraceHubServiceImpl{ + taskRepo: mockRepo, + config: configLoader, + localCache: NewLocalCache(), + taskProcessor: processor.NewTaskProcessor(), + } + + // Setup local cache to pass the initial span filter + impl.localCache.taskCache.Store("ObjListWithTask", TaskCacheInfo{ + WorkspaceIDs: []string{"space-1"}, + }) + + baseSpan := &loop_span.Span{ + TraceID: "trace-1", + SpanID: "span-1", + WorkspaceID: "space-1", + StartTime: 2000, + } + + tests := []struct { + name string + configSetup func() + taskSetup func() *entity.ObservabilityTask + spanSetup func() *loop_span.Span + }{ + { + name: "Filter by SpaceList", + configSetup: func() { + configLoader.EXPECT().GetConsumerListening(gomock.Any()).Return(&componentconfig.ConsumerListening{ + IsAllSpace: false, + SpaceList: []int64{999}, + }, nil) + }, + taskSetup: func() *entity.ObservabilityTask { + return &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + EffectiveTime: &entity.EffectiveTime{StartAt: 1000}, + } + }, + }, + { + name: "Filter by Nil EffectiveTime", + configSetup: func() { + configLoader.EXPECT().GetConsumerListening(gomock.Any()).Return(&componentconfig.ConsumerListening{ + IsAllSpace: true, + }, nil) + }, + taskSetup: func() *entity.ObservabilityTask { + return &entity.ObservabilityTask{ + ID: 2, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + EffectiveTime: nil, + } + }, + }, + { + name: "Filter by Zero StartAt", + configSetup: func() { + configLoader.EXPECT().GetConsumerListening(gomock.Any()).Return(&componentconfig.ConsumerListening{ + IsAllSpace: true, + }, nil) + }, + taskSetup: func() *entity.ObservabilityTask { + return &entity.ObservabilityTask{ + ID: 3, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + EffectiveTime: &entity.EffectiveTime{StartAt: 0}, + } + }, + }, + { + name: "Filter by Pending Status", + configSetup: func() { + configLoader.EXPECT().GetConsumerListening(gomock.Any()).Return(&componentconfig.ConsumerListening{ + IsAllSpace: true, + }, nil) + }, + taskSetup: func() *entity.ObservabilityTask { + return &entity.ObservabilityTask{ + ID: 4, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusPending, + EffectiveTime: &entity.EffectiveTime{StartAt: 1000}, + } + }, + }, + { + name: "Filter by Span StartTime before Task StartAt", + configSetup: func() { + configLoader.EXPECT().GetConsumerListening(gomock.Any()).Return(&componentconfig.ConsumerListening{ + IsAllSpace: true, + }, nil) + }, + taskSetup: func() *entity.ObservabilityTask { + return &entity.ObservabilityTask{ + ID: 5, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + EffectiveTime: &entity.EffectiveTime{StartAt: 3000}, + } + }, + spanSetup: func() *loop_span.Span { + s := *baseSpan + s.StartTime = 2000 + return &s + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.configSetup() + taskDO := tt.taskSetup() + + mockRepo.EXPECT().ListNonFinalTaskBySpaceID(gomock.Any(), "space-1").Return([]int64{taskDO.ID}, nil) + mockRepo.EXPECT().GetTaskByCache(gomock.Any(), taskDO.ID).Return(taskDO, nil) + + span := baseSpan + if tt.spanSetup != nil { + span = tt.spanSetup() + } + + err := impl.SpanTrigger(context.Background(), span) + require.NoError(t, err) + }) + } +} + +func TestTraceHubServiceImpl_withTaskRunCreateLock(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLocker := lock_mocks.NewMockILocker(ctrl) + impl := &TraceHubServiceImpl{ + locker: mockLocker, + } + + ctx := context.Background() + taskID := int64(1001) + runType := entity.TaskRunTypeNewData + runStartAt := int64(2000) + runEndAt := int64(3000) + + t.Run("Lock error", func(t *testing.T) { + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, errors.New("redis error")) + + err := impl.withTaskRunCreateLock(ctx, taskID, runType, runStartAt, runEndAt, func() error { + return nil + }) + require.ErrorContains(t, err, "redis error") + }) + + t.Run("Not locked", func(t *testing.T) { + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil) + + called := false + err := impl.withTaskRunCreateLock(ctx, taskID, runType, runStartAt, runEndAt, func() error { + called = true + return nil + }) + require.NoError(t, err) + require.False(t, called) + }) + + t.Run("Lock success", func(t *testing.T) { + mockLocker.EXPECT().Lock(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + mockLocker.EXPECT().Unlock(gomock.Any()).Return(true, nil) + + called := false + err := impl.withTaskRunCreateLock(ctx, taskID, runType, runStartAt, runEndAt, func() error { + called = true + return nil + }) + require.NoError(t, err) + require.True(t, called) + }) + + t.Run("Locker nil", func(t *testing.T) { + nilLockerImpl := &TraceHubServiceImpl{locker: nil} + called := false + err := nilLockerImpl.withTaskRunCreateLock(ctx, taskID, runType, runStartAt, runEndAt, func() error { + called = true + return nil + }) + require.NoError(t, err) + require.True(t, called) + }) +} diff --git a/backend/modules/observability/infra/rpc/evaluation/evaluation.go b/backend/modules/observability/infra/rpc/evaluation/evaluation.go index 6caaa7dd7..10c59d3c4 100644 --- a/backend/modules/observability/infra/rpc/evaluation/evaluation.go +++ b/backend/modules/observability/infra/rpc/evaluation/evaluation.go @@ -6,6 +6,9 @@ package evaluation import ( "context" + "github.com/bytedance/gg/gslice" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/experimentservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/expt" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" @@ -83,7 +86,10 @@ func (e *EvaluationProvider) InvokeExperiment(ctx context.Context, param *rpc.In // 其他非 BizStatus 错误保留原始错误作为 cause,并包装为通用 RPC 错误 return 0, errorx.WrapByCode(err, obErrorx.CommonRPCErrorCode) } - return int64(len(resp.GetAddedItems())), nil + realAddedItems := gslice.Filter(resp.ItemOutputs, func(output *dataset.CreateDatasetItemOutput) bool { + return output.GetIsNewItem() + }) + return int64(len(realAddedItems)), nil } func (e *EvaluationProvider) FinishExperiment(ctx context.Context, param *rpc.FinishExperimentReq) (err error) {