Skip to content

Commit 105fcc6

Browse files
authored
[fix][backend] auto task optimze (#405)
* feat(backend): add lock for dispatch * feat(backend): UT * feat(backend): UT * feat(backend): UT * feat(backend): Longer time out * feat(backend): minor fix * feat(backend): minor fix
1 parent 893fefb commit 105fcc6

8 files changed

Lines changed: 894 additions & 56 deletions

File tree

backend/modules/observability/domain/task/service/task_service.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,16 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e
213213
}
214214

215215
if event != nil {
216+
if event.After == entity.TaskStatusRunning && event.Before != entity.TaskStatusRunning {
217+
if err := t.TaskRepo.AddNonFinalTask(ctx, strconv.FormatInt(taskDO.WorkspaceID, 10), taskDO.ID); err != nil {
218+
logs.CtxError(ctx, "add non final task failed, task_id=%d, err=%v", taskDO.ID, err)
219+
}
220+
}
221+
if event.Before == entity.TaskStatusRunning && event.After == entity.TaskStatusPending {
222+
if err := t.TaskRepo.RemoveNonFinalTask(ctx, strconv.FormatInt(taskDO.WorkspaceID, 10), taskDO.ID); err != nil {
223+
logs.CtxError(ctx, "remove non final task failed, task_id=%d, err=%v", taskDO.ID, err)
224+
}
225+
}
216226
if event.After == entity.TaskStatusDisabled {
217227
// 禁用操作处理
218228
proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType)

backend/modules/observability/domain/task/service/task_service_test.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,156 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) {
442442
})
443443
assert.EqualError(t, err, "finish fail")
444444
})
445+
446+
t.Run("running to pending removes cache and skips finish", func(t *testing.T) {
447+
t.Parallel()
448+
ctrl := gomock.NewController(t)
449+
defer ctrl.Finish()
450+
451+
repoMock := repomocks.NewMockITaskRepo(ctrl)
452+
taskDO := &entity.ObservabilityTask{
453+
ID: 1,
454+
WorkspaceID: 2,
455+
TaskType: entity.TaskTypeAutoEval,
456+
TaskStatus: entity.TaskStatusRunning,
457+
Sampler: &entity.Sampler{},
458+
EffectiveTime: &entity.EffectiveTime{
459+
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
460+
EndAt: time.Now().Add(time.Minute).UnixMilli(),
461+
},
462+
TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}},
463+
}
464+
465+
repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
466+
repoMock.EXPECT().RemoveNonFinalTask(gomock.Any(), "2", int64(1)).Return(nil)
467+
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)
468+
469+
procMock := &fakeProcessor{onTaskRunFinishedErr: errors.New("finish fail")}
470+
tp := processor.NewTaskProcessor()
471+
tp.Register(entity.TaskTypeAutoEval, procMock)
472+
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}
473+
474+
err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
475+
TaskID: 1,
476+
WorkspaceID: 2,
477+
TaskStatus: gptr.Of(entity.TaskStatusPending),
478+
UserID: "user",
479+
})
480+
assert.NoError(t, err)
481+
assert.Equal(t, entity.TaskStatusPending, taskDO.TaskStatus)
482+
})
483+
484+
t.Run("running to pending remove cache error", func(t *testing.T) {
485+
t.Parallel()
486+
ctrl := gomock.NewController(t)
487+
defer ctrl.Finish()
488+
489+
repoMock := repomocks.NewMockITaskRepo(ctrl)
490+
taskDO := &entity.ObservabilityTask{
491+
ID: 1,
492+
WorkspaceID: 2,
493+
TaskType: entity.TaskTypeAutoEval,
494+
TaskStatus: entity.TaskStatusRunning,
495+
Sampler: &entity.Sampler{},
496+
EffectiveTime: &entity.EffectiveTime{
497+
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
498+
EndAt: time.Now().Add(time.Minute).UnixMilli(),
499+
},
500+
TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}},
501+
}
502+
503+
repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
504+
repoMock.EXPECT().RemoveNonFinalTask(gomock.Any(), "2", int64(1)).Return(errors.New("remove cache fail"))
505+
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)
506+
507+
procMock := &fakeProcessor{}
508+
tp := processor.NewTaskProcessor()
509+
tp.Register(entity.TaskTypeAutoEval, procMock)
510+
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}
511+
512+
err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
513+
TaskID: 1,
514+
WorkspaceID: 2,
515+
TaskStatus: gptr.Of(entity.TaskStatusPending),
516+
UserID: "user",
517+
})
518+
assert.NoError(t, err)
519+
assert.Equal(t, entity.TaskStatusPending, taskDO.TaskStatus)
520+
})
521+
522+
t.Run("pending to running adds cache", func(t *testing.T) {
523+
t.Parallel()
524+
ctrl := gomock.NewController(t)
525+
defer ctrl.Finish()
526+
527+
repoMock := repomocks.NewMockITaskRepo(ctrl)
528+
taskDO := &entity.ObservabilityTask{
529+
ID: 1,
530+
WorkspaceID: 2,
531+
TaskType: entity.TaskTypeAutoEval,
532+
TaskStatus: entity.TaskStatusPending,
533+
Sampler: &entity.Sampler{},
534+
EffectiveTime: &entity.EffectiveTime{
535+
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
536+
EndAt: time.Now().Add(time.Minute).UnixMilli(),
537+
},
538+
}
539+
540+
repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
541+
repoMock.EXPECT().AddNonFinalTask(gomock.Any(), "2", int64(1)).Return(nil)
542+
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)
543+
544+
procMock := &fakeProcessor{}
545+
tp := processor.NewTaskProcessor()
546+
tp.Register(entity.TaskTypeAutoEval, procMock)
547+
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}
548+
549+
err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
550+
TaskID: 1,
551+
WorkspaceID: 2,
552+
TaskStatus: gptr.Of(entity.TaskStatusRunning),
553+
UserID: "user",
554+
})
555+
assert.NoError(t, err)
556+
assert.Equal(t, entity.TaskStatusRunning, taskDO.TaskStatus)
557+
})
558+
559+
t.Run("pending to running add cache error", func(t *testing.T) {
560+
t.Parallel()
561+
ctrl := gomock.NewController(t)
562+
defer ctrl.Finish()
563+
564+
repoMock := repomocks.NewMockITaskRepo(ctrl)
565+
taskDO := &entity.ObservabilityTask{
566+
ID: 1,
567+
WorkspaceID: 2,
568+
TaskType: entity.TaskTypeAutoEval,
569+
TaskStatus: entity.TaskStatusPending,
570+
Sampler: &entity.Sampler{},
571+
EffectiveTime: &entity.EffectiveTime{
572+
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
573+
EndAt: time.Now().Add(time.Minute).UnixMilli(),
574+
},
575+
}
576+
577+
repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
578+
repoMock.EXPECT().AddNonFinalTask(gomock.Any(), "2", int64(1)).Return(errors.New("add cache fail"))
579+
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)
580+
581+
procMock := &fakeProcessor{}
582+
tp := processor.NewTaskProcessor()
583+
tp.Register(entity.TaskTypeAutoEval, procMock)
584+
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}
585+
586+
err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
587+
TaskID: 1,
588+
WorkspaceID: 2,
589+
TaskStatus: gptr.Of(entity.TaskStatusRunning),
590+
UserID: "user",
591+
})
592+
assert.NoError(t, err)
593+
assert.Equal(t, entity.TaskStatusRunning, taskDO.TaskStatus)
594+
})
445595
}
446596

447597
func TestTaskServiceImpl_ListTasks(t *testing.T) {

backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (p *AutoEvaluateProcessor) Invoke(ctx context.Context, trigger *taskexe.Tri
127127
_ = p.taskRepo.DecrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL)
128128
return nil
129129
}
130-
_, err := p.evaluationSvc.InvokeExperiment(ctx, &rpc.InvokeExperimentReq{
130+
addedItems, err := p.evaluationSvc.InvokeExperiment(ctx, &rpc.InvokeExperimentReq{
131131
WorkspaceID: workspaceID,
132132
EvaluationSetID: taskRun.GetTaskRunConfig().GetAutoEvaluateRunConfig().GetEvalID(),
133133
Items: []*eval_set.EvaluationSetItem{
@@ -171,6 +171,11 @@ func (p *AutoEvaluateProcessor) Invoke(ctx context.Context, trigger *taskexe.Tri
171171
}
172172
return err
173173
}
174+
if addedItems <= 0 {
175+
_ = p.taskRepo.DecrTaskCount(ctx, trigger.Task.ID, taskTTL)
176+
_ = p.taskRepo.DecrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL)
177+
return nil
178+
}
174179
return nil
175180
}
176181

backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/cloudwego/kitex/client/callopt"
2121

2222
"github.com/coze-dev/coze-loop/backend/infra/middleware/session"
23+
datadataset "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset"
2324
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common"
2425
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/expt"
2526
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/dataset"
@@ -457,7 +458,12 @@ func TestAutoEvaluateProcessor_Invoke_WithEvaluationProvider_SuccessAddedItems(t
457458
repoMock.EXPECT().DecrTaskCount(gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
458459
repoMock.EXPECT().DecrTaskRunCount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
459460

460-
client := &fakeExperimentClient{invokeResp: &expt.InvokeExperimentResponse{AddedItems: map[int64]int64{1: 1, 2: 1}}}
461+
client := &fakeExperimentClient{invokeResp: &expt.InvokeExperimentResponse{
462+
ItemOutputs: []*datadataset.CreateDatasetItemOutput{
463+
{IsNewItem: gptr.Of(true)},
464+
{IsNewItem: gptr.Of(true)},
465+
},
466+
}}
461467
provider := evalrpc.NewEvaluationRPCProvider(client)
462468
proc := &AutoEvaluateProcessor{evaluationSvc: provider, taskRepo: repoAdapter}
463469
err := proc.Invoke(context.Background(), trigger)
@@ -765,6 +771,34 @@ func TestAutoEvaluateProcessor_Invoke(t *testing.T) {
765771
err := proc.Invoke(context.Background(), trigger)
766772
assert.NoError(t, err)
767773
})
774+
775+
t.Run("success but addedItems is zero", func(t *testing.T) {
776+
ctrl := gomock.NewController(t)
777+
defer ctrl.Finish()
778+
779+
taskObj := buildTestTask(t)
780+
taskObj.Sampler.SampleSize = 5
781+
trigger := buildTrigger(taskObj, textSchema)
782+
783+
repoMock := repomocks.NewMockITaskRepo(ctrl)
784+
repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock}
785+
repoMock.EXPECT().IncrTaskCount(gomock.Any(), taskObj.ID, gomock.Any()).Return(nil)
786+
repoMock.EXPECT().IncrTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID, gomock.Any()).Return(nil)
787+
repoMock.EXPECT().GetTaskCount(gomock.Any(), taskObj.ID).Return(int64(1), nil)
788+
repoMock.EXPECT().GetTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID).Return(int64(1), nil)
789+
repoMock.EXPECT().DecrTaskCount(gomock.Any(), taskObj.ID, gomock.Any()).Return(nil)
790+
repoMock.EXPECT().DecrTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID, gomock.Any()).Return(nil)
791+
792+
evalMock := &fakeEvaluationAdapter{}
793+
evalMock.invokeResp.addedItems = 0
794+
795+
proc := &AutoEvaluateProcessor{
796+
evaluationSvc: evalMock,
797+
taskRepo: repoAdapter,
798+
}
799+
err := proc.Invoke(context.Background(), trigger)
800+
assert.NoError(t, err)
801+
})
768802
}
769803

770804
func TestAutoEvaluateProcessor_OnUpdateTaskChange(t *testing.T) {

backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,10 @@ func (h *TraceHubServiceImpl) combineFilters(filters ...*loop_span.FilterFields)
308308

309309
// fetchSpans paginates span data
310310
func (h *TraceHubServiceImpl) fetchSpans(ctx context.Context, listParam *repo.ListSpansParam, sub *spanSubscriber) ([]*loop_span.Span, string, error) {
311-
result, err := h.traceRepo.ListSpans(ctx, listParam)
311+
// 默认 30s to 60s 减少超时报错情况
312+
listCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
313+
defer cancel()
314+
result, err := h.traceRepo.ListSpans(listCtx, listParam)
312315
if err != nil {
313316
logs.CtxError(ctx, "List spans failed, parma=%v, err=%v", listParam, err)
314317
return nil, "", err

0 commit comments

Comments
 (0)