Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions backend/modules/observability/domain/task/service/task_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e
}

if event != nil {
if event.After == entity.TaskStatusRunning && event.Before != entity.TaskStatusRunning {
if err := t.TaskRepo.AddNonFinalTask(ctx, strconv.FormatInt(taskDO.WorkspaceID, 10), taskDO.ID); err != nil {
logs.CtxError(ctx, "add non final task failed, task_id=%d, err=%v", taskDO.ID, err)
}
}
if event.Before == entity.TaskStatusRunning && event.After == entity.TaskStatusPending {
if err := t.TaskRepo.RemoveNonFinalTask(ctx, strconv.FormatInt(taskDO.WorkspaceID, 10), taskDO.ID); err != nil {
logs.CtxError(ctx, "remove non final task failed, task_id=%d, err=%v", taskDO.ID, err)
}
}
if event.After == entity.TaskStatusDisabled {
// 禁用操作处理
proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType)
Expand Down
150 changes: 150 additions & 0 deletions backend/modules/observability/domain/task/service/task_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,156 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) {
})
assert.EqualError(t, err, "finish fail")
})

t.Run("running to pending removes cache and skips finish", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

repoMock := repomocks.NewMockITaskRepo(ctrl)
taskDO := &entity.ObservabilityTask{
ID: 1,
WorkspaceID: 2,
TaskType: entity.TaskTypeAutoEval,
TaskStatus: entity.TaskStatusRunning,
Sampler: &entity.Sampler{},
EffectiveTime: &entity.EffectiveTime{
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
EndAt: time.Now().Add(time.Minute).UnixMilli(),
},
TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}},
}

repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
repoMock.EXPECT().RemoveNonFinalTask(gomock.Any(), "2", int64(1)).Return(nil)
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)

procMock := &fakeProcessor{onTaskRunFinishedErr: errors.New("finish fail")}
tp := processor.NewTaskProcessor()
tp.Register(entity.TaskTypeAutoEval, procMock)
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}

err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
TaskID: 1,
WorkspaceID: 2,
TaskStatus: gptr.Of(entity.TaskStatusPending),
UserID: "user",
})
assert.NoError(t, err)
assert.Equal(t, entity.TaskStatusPending, taskDO.TaskStatus)
})

t.Run("running to pending remove cache error", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

repoMock := repomocks.NewMockITaskRepo(ctrl)
taskDO := &entity.ObservabilityTask{
ID: 1,
WorkspaceID: 2,
TaskType: entity.TaskTypeAutoEval,
TaskStatus: entity.TaskStatusRunning,
Sampler: &entity.Sampler{},
EffectiveTime: &entity.EffectiveTime{
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
EndAt: time.Now().Add(time.Minute).UnixMilli(),
},
TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}},
}

repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
repoMock.EXPECT().RemoveNonFinalTask(gomock.Any(), "2", int64(1)).Return(errors.New("remove cache fail"))
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)

procMock := &fakeProcessor{}
tp := processor.NewTaskProcessor()
tp.Register(entity.TaskTypeAutoEval, procMock)
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}

err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
TaskID: 1,
WorkspaceID: 2,
TaskStatus: gptr.Of(entity.TaskStatusPending),
UserID: "user",
})
assert.NoError(t, err)
assert.Equal(t, entity.TaskStatusPending, taskDO.TaskStatus)
})

t.Run("pending to running adds cache", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

repoMock := repomocks.NewMockITaskRepo(ctrl)
taskDO := &entity.ObservabilityTask{
ID: 1,
WorkspaceID: 2,
TaskType: entity.TaskTypeAutoEval,
TaskStatus: entity.TaskStatusPending,
Sampler: &entity.Sampler{},
EffectiveTime: &entity.EffectiveTime{
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
EndAt: time.Now().Add(time.Minute).UnixMilli(),
},
}

repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
repoMock.EXPECT().AddNonFinalTask(gomock.Any(), "2", int64(1)).Return(nil)
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)

procMock := &fakeProcessor{}
tp := processor.NewTaskProcessor()
tp.Register(entity.TaskTypeAutoEval, procMock)
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}

err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
TaskID: 1,
WorkspaceID: 2,
TaskStatus: gptr.Of(entity.TaskStatusRunning),
UserID: "user",
})
assert.NoError(t, err)
assert.Equal(t, entity.TaskStatusRunning, taskDO.TaskStatus)
})

t.Run("pending to running add cache error", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

repoMock := repomocks.NewMockITaskRepo(ctrl)
taskDO := &entity.ObservabilityTask{
ID: 1,
WorkspaceID: 2,
TaskType: entity.TaskTypeAutoEval,
TaskStatus: entity.TaskStatusPending,
Sampler: &entity.Sampler{},
EffectiveTime: &entity.EffectiveTime{
StartAt: time.Now().Add(-time.Minute).UnixMilli(),
EndAt: time.Now().Add(time.Minute).UnixMilli(),
},
}

repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil)
repoMock.EXPECT().AddNonFinalTask(gomock.Any(), "2", int64(1)).Return(errors.New("add cache fail"))
repoMock.EXPECT().UpdateTask(gomock.Any(), taskDO).Return(nil)

procMock := &fakeProcessor{}
tp := processor.NewTaskProcessor()
tp.Register(entity.TaskTypeAutoEval, procMock)
svc := &TaskServiceImpl{TaskRepo: repoMock, taskProcessor: *tp}

err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{
TaskID: 1,
WorkspaceID: 2,
TaskStatus: gptr.Of(entity.TaskStatusRunning),
UserID: "user",
})
assert.NoError(t, err)
assert.Equal(t, entity.TaskStatusRunning, taskDO.TaskStatus)
})
}

func TestTaskServiceImpl_ListTasks(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (p *AutoEvaluateProcessor) Invoke(ctx context.Context, trigger *taskexe.Tri
_ = p.taskRepo.DecrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL)
return nil
}
_, err := p.evaluationSvc.InvokeExperiment(ctx, &rpc.InvokeExperimentReq{
addedItems, err := p.evaluationSvc.InvokeExperiment(ctx, &rpc.InvokeExperimentReq{
WorkspaceID: workspaceID,
EvaluationSetID: taskRun.GetTaskRunConfig().GetAutoEvaluateRunConfig().GetEvalID(),
Items: []*eval_set.EvaluationSetItem{
Expand Down Expand Up @@ -171,6 +171,11 @@ func (p *AutoEvaluateProcessor) Invoke(ctx context.Context, trigger *taskexe.Tri
}
return err
}
if addedItems <= 0 {
_ = p.taskRepo.DecrTaskCount(ctx, trigger.Task.ID, taskTTL)
_ = p.taskRepo.DecrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL)
return nil
}
return nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/cloudwego/kitex/client/callopt"

"github.com/coze-dev/coze-loop/backend/infra/middleware/session"
datadataset "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset"
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common"
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/expt"
"github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/dataset"
Expand Down Expand Up @@ -457,7 +458,12 @@ func TestAutoEvaluateProcessor_Invoke_WithEvaluationProvider_SuccessAddedItems(t
repoMock.EXPECT().DecrTaskCount(gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
repoMock.EXPECT().DecrTaskRunCount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0)

client := &fakeExperimentClient{invokeResp: &expt.InvokeExperimentResponse{AddedItems: map[int64]int64{1: 1, 2: 1}}}
client := &fakeExperimentClient{invokeResp: &expt.InvokeExperimentResponse{
ItemOutputs: []*datadataset.CreateDatasetItemOutput{
{IsNewItem: gptr.Of(true)},
{IsNewItem: gptr.Of(true)},
},
}}
provider := evalrpc.NewEvaluationRPCProvider(client)
proc := &AutoEvaluateProcessor{evaluationSvc: provider, taskRepo: repoAdapter}
err := proc.Invoke(context.Background(), trigger)
Expand Down Expand Up @@ -765,6 +771,34 @@ func TestAutoEvaluateProcessor_Invoke(t *testing.T) {
err := proc.Invoke(context.Background(), trigger)
assert.NoError(t, err)
})

t.Run("success but addedItems is zero", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

taskObj := buildTestTask(t)
taskObj.Sampler.SampleSize = 5
trigger := buildTrigger(taskObj, textSchema)

repoMock := repomocks.NewMockITaskRepo(ctrl)
repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock}
repoMock.EXPECT().IncrTaskCount(gomock.Any(), taskObj.ID, gomock.Any()).Return(nil)
repoMock.EXPECT().IncrTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID, gomock.Any()).Return(nil)
repoMock.EXPECT().GetTaskCount(gomock.Any(), taskObj.ID).Return(int64(1), nil)
repoMock.EXPECT().GetTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID).Return(int64(1), nil)
repoMock.EXPECT().DecrTaskCount(gomock.Any(), taskObj.ID, gomock.Any()).Return(nil)
repoMock.EXPECT().DecrTaskRunCount(gomock.Any(), taskObj.ID, trigger.TaskRun.ID, gomock.Any()).Return(nil)

evalMock := &fakeEvaluationAdapter{}
evalMock.invokeResp.addedItems = 0

proc := &AutoEvaluateProcessor{
evaluationSvc: evalMock,
taskRepo: repoAdapter,
}
err := proc.Invoke(context.Background(), trigger)
assert.NoError(t, err)
})
}

func TestAutoEvaluateProcessor_OnUpdateTaskChange(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,10 @@ func (h *TraceHubServiceImpl) combineFilters(filters ...*loop_span.FilterFields)

// fetchSpans paginates span data
func (h *TraceHubServiceImpl) fetchSpans(ctx context.Context, listParam *repo.ListSpansParam, sub *spanSubscriber) ([]*loop_span.Span, string, error) {
result, err := h.traceRepo.ListSpans(ctx, listParam)
// 默认 30s to 60s 减少超时报错情况
listCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
result, err := h.traceRepo.ListSpans(listCtx, listParam)
if err != nil {
logs.CtxError(ctx, "List spans failed, parma=%v, err=%v", listParam, err)
return nil, "", err
Expand Down
Loading
Loading