diff --git a/backend/api/handler/coze/loop/apis/wire_gen.go b/backend/api/handler/coze/loop/apis/wire_gen.go index 6f0460673..b0ce0120b 100644 --- a/backend/api/handler/coze/loop/apis/wire_gen.go +++ b/backend/api/handler/coze/loop/apis/wire_gen.go @@ -8,7 +8,6 @@ package apis import ( "context" - "github.com/cloudwego/kitex/pkg/endpoint" "github.com/coze-dev/coze-loop/backend/infra/ck" "github.com/coze-dev/coze-loop/backend/infra/db" @@ -175,7 +174,7 @@ func InitObservabilityHandler(ctx context.Context, db2 db.Provider, ckDb ck.Prov if err != nil { return nil, err } - iTaskApplication, err := application6.InitTaskApplication(db2, idgen2, configFactory, benefit2, ckDb, redis2, mqFactory, userClient, authCli, evalClient, evalSetClient, experimentClient, datasetClient, fileClient, taskProcessor, aid, persistentCmdable) + iTaskApplication, err := application6.InitTaskApplication(db2, idgen2, configFactory, benefit2, ckDb, meter, redis2, mqFactory, userClient, authCli, evalClient, evalSetClient, experimentClient, datasetClient, fileClient, taskProcessor, aid, persistentCmdable) if err != nil { return nil, err } diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index 356e4a48b..f66351cc0 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -462,6 +462,7 @@ func InitTaskApplication( configFactory conf.IConfigLoaderFactory, benefit benefit.IBenefitService, ckDb ck.Provider, + meter metrics.Meter, redis redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 91927da43..22dc1b3ae 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -253,7 +253,7 @@ func InitTraceIngestionApplication(configFactory conf.IConfigLoaderFactory, stor return iTraceIngestionApplication, nil } -func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, redis3 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32, persistentCmdable redis.PersistentCmdable) (ITaskApplication, error) { +func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, meter metrics.Meter, redis3 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32, persistentCmdable redis.PersistentCmdable) (ITaskApplication, error) { iTaskDao := mysql.NewTaskDaoImpl(db2) iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) @@ -296,7 +296,20 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto return nil, err } iLocker := NewTaskLocker(redis3) - iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, aid, iBackfillProducer, iLocker, iTraceConfig) + iTraceProducer, err := producer.NewTraceProducerImpl(iTraceConfig, mqFactory) + if err != nil { + return nil, err + } + iAnnotationProducer, err := producer.NewAnnotationProducerImpl(iTraceConfig, mqFactory) + if err != nil { + return nil, err + } + iTraceMetrics := metrics2.NewTraceMetricsImpl(meter) + iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo, persistentCmdable) + if err != nil { + return nil, err + } + iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, aid, iBackfillProducer, iLocker, iTraceConfig, iTraceService) if err != nil { return nil, err } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index 819051c20..7bdc90040 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -409,6 +409,10 @@ func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans } batch := spans[i:end] + err = h.traceService.MergeHistoryMessagesByRespIDBatch(ctx, spans, sub.t.GetPlatformType()) + if err != nil { + return err, false + } err, shouldFinish = h.processBatchSpans(ctx, batch, sub) if err != nil { logs.CtxError(ctx, "process batch spans failed, task_id=%d, batch_start=%d, err=%v", diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index c67539f7b..18adfff56 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -173,9 +173,14 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { t.Cleanup(ctrl.Finish) mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + mockTraceService := builder_mocks.NewMockITraceService(ctrl) proc := &stubProcessor{invokeErr: errors.New("invoke fail")} - impl := &TraceHubServiceImpl{taskRepo: mockRepo} + impl := &TraceHubServiceImpl{taskRepo: mockRepo, traceService: mockTraceService} + mockTraceService.EXPECT(). + MergeHistoryMessagesByRespIDBatch(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() now := time.Now() sampler := &entity.Sampler{ @@ -202,12 +207,13 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { RunEndAt: now.Add(time.Minute), } sub := &spanSubscriber{ - taskID: 1, - t: taskDO, - tr: taskRun, - processor: proc, - runType: entity.TaskRunTypeNewData, - taskRepo: mockRepo, + taskID: 1, + t: taskDO, + tr: taskRun, + processor: proc, + traceService: mockTraceService, + runType: entity.TaskRunTypeNewData, + taskRepo: mockRepo, } spanRun := &entity.TaskRun{ @@ -310,12 +316,14 @@ func TestTraceHubServiceImpl_ListAndSendSpans_WithoutLastSpanPageToken(t *testin mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) mockBuilder := builder_mocks.NewMockTraceFilterProcessorBuilder(ctrl) filterMock := spanfilter_mocks.NewMockFilter(ctrl) + mockTraceService := builder_mocks.NewMockITraceService(ctrl) impl := &TraceHubServiceImpl{ taskRepo: mockTaskRepo, traceRepo: mockTraceRepo, tenantProvider: mockTenant, buildHelper: mockBuilder, + traceService: mockTraceService, } now := time.Now() @@ -329,6 +337,10 @@ func TestTraceHubServiceImpl_ListAndSendSpans_WithoutLastSpanPageToken(t *testin filterMock.EXPECT().BuildRootSpanFilter(gomock.Any(), gomock.Any()).Return([]*loop_span.FilterField{}, nil) mockBuilder.EXPECT().BuildGetTraceProcessors(gomock.Any(), gomock.Any()).Return([]span_processor.Processor(nil), nil).Times(2) mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformType(common.PlatformTypeCozeBot)).Return([]string{"tenant"}, nil) + mockTraceService.EXPECT(). + MergeHistoryMessagesByRespIDBatch(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + Times(2) mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { switch param.PageToken { @@ -370,12 +382,14 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) mockBuilder := builder_mocks.NewMockTraceFilterProcessorBuilder(ctrl) filterMock := spanfilter_mocks.NewMockFilter(ctrl) + mockTraceService := builder_mocks.NewMockITraceService(ctrl) impl := &TraceHubServiceImpl{ taskRepo: mockTaskRepo, traceRepo: mockTraceRepo, tenantProvider: mockTenant, buildHelper: mockBuilder, + traceService: mockTraceService, } now := time.Now() @@ -390,6 +404,10 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { filterMock.EXPECT().BuildRootSpanFilter(gomock.Any(), gomock.Any()).Return([]*loop_span.FilterField{}, nil) mockBuilder.EXPECT().BuildGetTraceProcessors(gomock.Any(), gomock.Any()).Return([]span_processor.Processor(nil), nil) mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformType(common.PlatformTypeCozeBot)).Return([]string{"tenant"}, nil) + mockTraceService.EXPECT(). + MergeHistoryMessagesByRespIDBatch(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + Times(1) mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { require.Equal(t, "tenant", param.Tenants[0]) @@ -470,7 +488,8 @@ func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { t.Cleanup(ctrl.Finish) mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) - impl := &TraceHubServiceImpl{taskRepo: mockTaskRepo} + mockTraceService := builder_mocks.NewMockITraceService(ctrl) + impl := &TraceHubServiceImpl{taskRepo: mockTaskRepo, traceService: mockTraceService} now := time.Now() sub, proc := newBackfillSubscriber(mockTaskRepo, now) @@ -480,6 +499,10 @@ func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) + mockTraceService.EXPECT(). + MergeHistoryMessagesByRespIDBatch(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + Times(1) // 调用flushSpans,然后手动调用OnTaskFinished来触发finish错误 err, _ := impl.flushSpans(context.Background(), []*loop_span.Span{span}, sub) diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 021515a31..d41dbef4e 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -19,7 +19,6 @@ import ( func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, span *loop_span.Span) error { logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", span.LogID, span.TraceID, span.SpanID) - // 1. perform initial filtering based on space_id // 1.1 Filter out spans that do not belong to any space or bot cacheInfo := h.localCache.LoadTaskCache(ctx) @@ -91,13 +90,14 @@ func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *l return nil, err } subscribers = append(subscribers, &spanSubscriber{ - taskID: taskDO.ID, - t: taskDO, - processor: proc, - taskRepo: h.taskRepo, - runType: entity.TaskRunTypeNewData, - buildHelper: h.buildHelper, - tenants: tenants, + taskID: taskDO.ID, + t: taskDO, + processor: proc, + taskRepo: h.taskRepo, + runType: entity.TaskRunTypeNewData, + buildHelper: h.buildHelper, + tenants: tenants, + traceService: h.traceService, }) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index 424d095c3..31a031985 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -62,6 +62,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { mockFilter := span_filter_mocks.NewMockFilter(ctrl) configLoader := config_mocks.NewMockITraceConfig(ctrl) tenantProvider := tenant_mocks.NewMockITenantProvider(ctrl) + mockTraceService := trace_service_mocks.NewMockITraceService(ctrl) now := time.Now() workspaceID := int64(1) @@ -115,6 +116,10 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { mockFilter.EXPECT().BuildALLSpanFilter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockBuilder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), gomock.Any()).Return(mockFilter, nil).AnyTimes() tenantProvider.EXPECT().GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformDefault, gomock.Any()).Return([]string{"tenant"}, nil).AnyTimes() + mockTraceService.EXPECT(). + MergeHistoryMessagesByRespIDBatch(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() spanRun := &entity.TaskRun{ ID: 201, @@ -141,6 +146,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { localCache: NewLocalCache(), config: configLoader, tenantProvider: tenantProvider, + traceService: mockTraceService, } impl.localCache.taskCache.Store("ObjListWithTask", TaskCacheInfo{WorkspaceIDs: []string{"space-1"}, Tasks: []*entity.ObservabilityTask{taskDO}}) diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go index 6a0a6652f..a311e4760 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go @@ -21,14 +21,15 @@ import ( ) type spanSubscriber struct { - taskID int64 - t *entity.ObservabilityTask - tr *entity.TaskRun - processor taskexe.Processor - tenants []string - taskRepo repo.ITaskRepo - runType entity.TaskRunType - buildHelper service.TraceFilterProcessorBuilder + taskID int64 + t *entity.ObservabilityTask + tr *entity.TaskRun + processor taskexe.Processor + tenants []string + taskRepo repo.ITaskRepo + runType entity.TaskRunType + buildHelper service.TraceFilterProcessorBuilder + traceService service.ITraceService } // Sampled determines whether a span is sampled based on the sampling rate; the sample size will be validated during flush. @@ -209,6 +210,15 @@ func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) erro } trigger := &taskexe.Trigger{Task: s.t, Span: span, TaskRun: taskRunConfig} logs.CtxDebug(ctx, "invoke processor, trigger: %v", trigger) + // New Data 在这里处理 + // Back fill 在前置批量处理 + if s.runType == entity.TaskRunTypeNewData { + err := s.traceService.MergeHistoryMessagesByRespIDBatch(ctx, []*loop_span.Span{span}, s.t.GetPlatformType()) + if err != nil { + logs.CtxError(ctx, "merge history messages failed, task_id=%d, span_id=%s err: %v", s.t.ID, span.SpanID, err) + return err + } + } err = s.processor.Invoke(ctx, trigger) if err != nil { logs.CtxWarn(ctx, "invoke processor failed, trace_id=%s, span_id=%s, err: %v", span.TraceID, span.SpanID, err) diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go index 4ad2e5017..12d50a488 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go @@ -36,6 +36,7 @@ func NewTraceHubImpl( backfillProducer mq.IBackfillProducer, locker lock.ILocker, config config.ITraceConfig, + traceService service.ITraceService, ) (ITraceHubService, error) { impl := &TraceHubServiceImpl{ taskRepo: tRepo, @@ -48,6 +49,7 @@ func NewTraceHubImpl( locker: locker, config: config, localCache: NewLocalCache(), + traceService: traceService, } return impl, nil } @@ -61,7 +63,7 @@ type TraceHubServiceImpl struct { backfillProducer mq.IBackfillProducer locker lock.ILocker config config.ITraceConfig - + traceService service.ITraceService // Local cache - caching non-terminal task information localCache *LocalCache diff --git a/backend/modules/observability/domain/trace/entity/loop_span/span.go b/backend/modules/observability/domain/trace/entity/loop_span/span.go index 258378f63..ed3a9bfc0 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/span.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/span.go @@ -55,6 +55,8 @@ const ( SpanFieldUserID = "user_id" SpanFieldPromptKey = "prompt_key" SpanFieldTenant = "tenant" + SpanFieldKeyPreviousResponseID = "previous_response_id" + SpanFieldKeyResponseID = "response_id" SpanTypePrompt = "prompt" SpanTypeModel = "model" @@ -195,6 +197,106 @@ func (s *Span) GetCustomTags() map[string]string { return ret } +type StringWrapper struct { + Role string `json:"role"` + Content string `json:"content"` + Type string `json:"type"` +} + +func (s *Span) IsResponseAPISpan() bool { + if s.SpanType != SpanTypeModel { + return false + } + if s.SystemTagsString == nil { + return false + } + v, ok := s.SystemTagsString[SpanFieldKeyPreviousResponseID] + return ok && v != "" +} + +func (s *Span) MergeHistoryContext(ctx context.Context, historySpans []*Span) { + // Normalize func for Response API String|List Message structure + normalizeMessages := func(v interface{}, role string, t string) ([]interface{}, bool) { + switch vv := v.(type) { + case []interface{}: + return vv, true + case string: + if vv == "" { + return nil, false + } + return []interface{}{StringWrapper{Role: role, Content: vv, Type: t}}, true + default: + return nil, false + } + } + + var currentInputMap map[string]interface{} + if err := sonic.UnmarshalString(s.Input, ¤tInputMap); err != nil { + logs.CtxWarn(ctx, "fail to trans input %s into map", s.Input) + return + } + + if s.SystemTagsString == nil { + s.SystemTagsString = make(map[string]string) + } + // 同一个 span 命中多个 subscriber 幂等 + if s.SystemTagsString["_history_merged"] == "true" { + logs.CtxInfo(ctx, "history context already merged, skip") + return + } + + logs.CtxInfo(ctx, "start to merge history context") + + var historyMessages []interface{} + for _, preSpan := range historySpans { + if preSpan.Input != "" { + var inputMap map[string]interface{} + if err := sonic.UnmarshalString(preSpan.Input, &inputMap); err == nil { + if msgs, ok := inputMap["messages"].([]interface{}); ok { + historyMessages = append(historyMessages, msgs...) + } else if msgs, ok := normalizeMessages(inputMap["input"], "user", "message"); ok { + historyMessages = append(historyMessages, msgs...) + } + } + } + if preSpan.Output != "" { + var outputMap map[string]interface{} + if err := sonic.UnmarshalString(preSpan.Output, &outputMap); err == nil { + if msgs, ok := outputMap["choices"].([]interface{}); ok { + historyMessages = append(historyMessages, msgs...) + } else if msgs, ok := normalizeMessages(outputMap["output"], "assistant", "message"); ok { + historyMessages = append(historyMessages, msgs...) + } + } + } + } + + if len(historyMessages) == 0 { + return + } + + // fill into current span input map + if msgs, ok := currentInputMap["messages"].([]interface{}); ok { + currentInputMap["messages"] = append(historyMessages, msgs...) + } else if msgs, ok := normalizeMessages(currentInputMap["input"], "user", "message"); ok { + currentInputMap["input"] = append(historyMessages, msgs...) + } else { + currentInputMap["input"] = historyMessages + } + + newInput, err := sonic.Marshal(currentInputMap) + if err != nil { + logs.CtxWarn(ctx, "fail to marshal new input, err:%v", err) + return + } + s.Input = string(newInput) + s.SystemTagsString["_history_merged"] = "true" +} + +func (s *Span) IsModelSpan() bool { + return s.SpanType == SpanTypeModel +} + func (s *Span) getTags() []*Tag { tags := make([]*Tag, 0) for k, v := range s.TagsString { diff --git a/backend/modules/observability/domain/trace/entity/loop_span/span_test.go b/backend/modules/observability/domain/trace/entity/loop_span/span_test.go index a731c380a..e5d2bffd0 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/span_test.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/span_test.go @@ -5,11 +5,14 @@ package loop_span import ( "context" + "reflect" "strconv" "strings" "testing" "time" + "github.com/coze-dev/coze-loop/backend/pkg/json" + "github.com/stretchr/testify/assert" ) @@ -455,6 +458,296 @@ func TestGetFieldValue_SystemTags(t *testing.T) { } } +func TestSpan_MergeHistoryContext(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("merge input and output history messages", func(t *testing.T) { + span := &Span{ + Input: `{"messages":[{"role":"user","content":"cur1"},{"role":"assistant","content":"cur2"}]}`, + } + history := []*Span{ + {Input: `{"messages":[{"role":"system","content":"hist_in1"}]}`}, + {Output: `{"choices":[{"role":"assistant","content":"hist_out1"}]}`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["messages"].([]interface{}) + assert.Equal(t, 4, len(msgs)) + first, _ := msgs[0].(map[string]interface{}) + second, _ := msgs[1].(map[string]interface{}) + third, _ := msgs[2].(map[string]interface{}) + fourth, _ := msgs[3].(map[string]interface{}) + assert.Equal(t, "system", first["role"]) + assert.Equal(t, "assistant", second["role"]) + assert.Equal(t, "user", third["role"]) + assert.Equal(t, "assistant", fourth["role"]) + }) + + t.Run("merge response api input/output string into wrappers", func(t *testing.T) { + span := &Span{ + Input: `{"input":"cur_in"}`, + } + history := []*Span{ + {Input: `{"input":"hist_in"}`}, + {Output: `{"output":"hist_out"}`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["input"].([]interface{}) + assert.Equal(t, 3, len(msgs)) + first, _ := msgs[0].(map[string]interface{}) + second, _ := msgs[1].(map[string]interface{}) + third, _ := msgs[2].(map[string]interface{}) + assert.Equal(t, "user", first["role"]) + assert.Equal(t, "hist_in", first["content"]) + assert.Equal(t, "assistant", second["role"]) + assert.Equal(t, "hist_out", second["content"]) + assert.Equal(t, "user", third["role"]) + assert.Equal(t, "cur_in", third["content"]) + }) + + t.Run("merge response api input/output list into messages", func(t *testing.T) { + span := &Span{ + Input: `{"messages":[{"role":"user","content":"cur"}]}`, + } + history := []*Span{ + {Input: `{"input":[{"role":"user","content":"hist_in"}]}`}, + {Output: `{"output":[{"role":"assistant","content":"hist_out"}]}`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["messages"].([]interface{}) + assert.Equal(t, 3, len(msgs)) + first, _ := msgs[0].(map[string]interface{}) + second, _ := msgs[1].(map[string]interface{}) + third, _ := msgs[2].(map[string]interface{}) + assert.Equal(t, "user", first["role"]) + assert.Equal(t, "assistant", second["role"]) + assert.Equal(t, "user", third["role"]) + }) + + t.Run("current messages fallback to input when messages is not array", func(t *testing.T) { + span := &Span{ + Input: `{"messages":"bad","input":"cur_in"}`, + } + history := []*Span{ + {Input: `{"messages":"bad","input":"hist_in"}`}, + {Output: `{"choices":"bad","output":"hist_out"}`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["input"].([]interface{}) + assert.Equal(t, 3, len(msgs)) + first, _ := msgs[0].(map[string]interface{}) + second, _ := msgs[1].(map[string]interface{}) + third, _ := msgs[2].(map[string]interface{}) + assert.Equal(t, "user", first["role"]) + assert.Equal(t, "hist_in", first["content"]) + assert.Equal(t, "assistant", second["role"]) + assert.Equal(t, "hist_out", second["content"]) + assert.Equal(t, "user", third["role"]) + assert.Equal(t, "cur_in", third["content"]) + }) + + t.Run("empty current input string merges history to input", func(t *testing.T) { + span := &Span{Input: `{"input":""}`} + history := []*Span{ + {Input: `{"input":"hist_in"}`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["input"].([]interface{}) + assert.Equal(t, 1, len(msgs)) + first, _ := msgs[0].(map[string]interface{}) + assert.Equal(t, "user", first["role"]) + assert.Equal(t, "hist_in", first["content"]) + }) + + t.Run("no messages and no input merges history to input", func(t *testing.T) { + span := &Span{Input: `{"foo":"bar"}`} + history := []*Span{ + {Input: `{"messages":[{"role":"system","content":"h"}]}`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["input"].([]interface{}) + assert.Equal(t, 1, len(msgs)) + first, _ := msgs[0].(map[string]interface{}) + assert.Equal(t, "system", first["role"]) + assert.Equal(t, "h", first["content"]) + assert.Equal(t, "bar", m["foo"]) + }) + + t.Run("no history messages keeps input unchanged", func(t *testing.T) { + orig := `{"messages":[{"role":"user","content":"cur"}]}` + span := &Span{Input: orig} + history := []*Span{ + {Input: `{"no_messages":[]}`}, + {Output: `{"info":"x"}`}, + } + span.MergeHistoryContext(ctx, history) + assert.Equal(t, orig, span.Input) + }) + + t.Run("invalid current input leaves unchanged", func(t *testing.T) { + span := &Span{Input: `not-json`} + history := []*Span{ + {Input: `{"messages":[{"role":"system","content":"h"}]}`}, + } + span.MergeHistoryContext(ctx, history) + assert.Equal(t, `not-json`, span.Input) + }) + + t.Run("invalid history json is skipped", func(t *testing.T) { + span := &Span{Input: `{"messages":[{"role":"user","content":"cur"}]}`} + history := []*Span{ + {Input: `{"messages":[{"role":"system","content":"h1"}]}`}, + {Output: `{"messages": "not-array"}`}, + {Input: `{"messages":[{"role":"assistant","content":"h2"}]}`}, + {Output: `bad-json`}, + } + span.MergeHistoryContext(ctx, history) + var m map[string]interface{} + _ = json.Unmarshal([]byte(span.Input), &m) + msgs, _ := m["messages"].([]interface{}) + assert.Equal(t, 3, len(msgs)) + }) + + t.Run("empty history does nothing", func(t *testing.T) { + orig := `{"messages":[{"role":"user","content":"cur"}]}` + span := &Span{Input: orig} + span.MergeHistoryContext(ctx, nil) + assert.Equal(t, orig, span.Input) + }) + + t.Run("helper methods in span.go", func(t *testing.T) { + span := &Span{ + SpanID: "0000000000000001", + TraceID: "00000000000000000000000000000001", + WorkspaceID: "1", + StartTime: time.Now().UnixMicro(), + SpanType: SpanTypeModel, + SystemTagsString: map[string]string{ + SpanFieldKeyPreviousResponseID: "prev", + SpanFieldTenant: "tenant1", + }, + } + assert.True(t, span.IsResponseAPISpan()) + assert.Equal(t, "tenant1", span.GetTenant()) + + span2 := &Span{SpanType: SpanTypePrompt} + assert.False(t, span2.IsResponseAPISpan()) + span3 := &Span{SpanType: SpanTypeModel} + assert.False(t, span3.IsResponseAPISpan()) + span4 := &Span{SpanType: SpanTypeModel, SystemTagsString: map[string]string{SpanFieldKeyPreviousResponseID: ""}} + assert.False(t, span4.IsResponseAPISpan()) + }) + + t.Run("AddAutoEvalAnnotation and SpanList helpers", func(t *testing.T) { + span := &Span{ + SpanID: "0000000000000001", + TraceID: "00000000000000000000000000000001", + WorkspaceID: "1", + StartTime: time.Now().UnixMicro(), + } + anno, err := span.AddAutoEvalAnnotation(1, 2, 3, 0.5, "reason", "user1") + assert.NoError(t, err) + assert.NotNil(t, anno) + assert.Equal(t, AnnotationTypeAutoEvaluate, anno.AnnotationType) + assert.Equal(t, 1, len(span.Annotations)) + + spans := SpanList{ + {SpanType: SpanTypePrompt, StartTime: 2}, + {SpanType: SpanTypeModel, StartTime: 3, TagsLong: map[string]int64{SpanFieldInputTokens: 10, SpanFieldOutputTokens: 20}}, + {SpanType: SpanTypeLLMCall, StartTime: 1, TagsLong: map[string]int64{SpanFieldInputTokens: 1, SpanFieldOutputTokens: 2}}, + } + in, out, err := spans.Stat(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(11), in) + assert.Equal(t, int64(22), out) + + filtered := spans.FilterSpans(GetModelSpansFilter()) + assert.Equal(t, 2, len(filtered)) + + spans.SortByStartTime(false) + assert.Equal(t, int64(1), spans[0].StartTime) + spans.SortByStartTime(true) + assert.Equal(t, int64(3), spans[0].StartTime) + + uniq := SpanList{ + {SpanID: "a", TraceID: "t"}, + {SpanID: "a", TraceID: "t"}, + {SpanID: "b", TraceID: "t"}, + }.Uniq() + assert.Equal(t, 2, len(uniq)) + }) + + t.Run("field and tag helpers", func(t *testing.T) { + type sample struct { + Str string `json:"str"` + Bool bool `json:"bool"` + I64 int64 `json:"i64"` + F64 float64 `json:"f64"` + Ptr *string `json:"ptr"` + Bad int `json:"bad"` + NoTag string + } + + s := &sample{} + fields := NewStruct(s).Fields() + assert.GreaterOrEqual(t, len(fields), 1) + + var ptrField *Field + var badField *Field + var noTagField *Field + for _, f := range fields { + if f.Name() == "Ptr" { + ptrField = f + } + if f.Name() == "Bad" { + badField = f + } + if f.Name() == "NoTag" { + noTagField = f + } + alias, err := f.TagJson() + if f.Name() == "Bad" { + assert.NoError(t, err) + assert.NotEmpty(t, alias) + } + } + assert.NotNil(t, noTagField) + _, err := noTagField.TagJson() + assert.Error(t, err) + + assert.NotNil(t, ptrField) + assert.Equal(t, reflect.Ptr, ptrField.Kind()) + assert.NoError(t, ptrField.Set("x")) + assert.NotNil(t, s.Ptr) + assert.Equal(t, "x", *s.Ptr) + vt, err := ptrField.ValueType() + assert.NoError(t, err) + assert.Equal(t, TagValueTypeString, vt) + + assert.NotNil(t, badField) + _, err = badField.ValueType() + assert.Error(t, err) + + assert.Equal(t, "Bool", TagValueTypeBool.String()) + assert.Equal(t, "I64", TagValueTypeInt64.String()) + assert.Equal(t, "F64", TagValueTypeFloat64.String()) + assert.Equal(t, "String", TagValueTypeString.String()) + assert.Equal(t, "", TagValueTypeUnknown.String()) + }) +} + // TestSizeofSpans tests the SizeofSpans function func TestSizeofSpans(t *testing.T) { t.Parallel() diff --git a/backend/modules/observability/domain/trace/service/mocks/trace_service.go b/backend/modules/observability/domain/trace/service/mocks/trace_service.go index e96169a9c..b2af2cee9 100644 --- a/backend/modules/observability/domain/trace/service/mocks/trace_service.go +++ b/backend/modules/observability/domain/trace/service/mocks/trace_service.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -destination=mocks/trace_service.go -package=mocks . ITraceService +// mockgen -destination=modules/observability/domain/trace/service/mocks/trace_service.go -package=mocks github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service ITraceService // // Package mocks is a generated GoMock package. @@ -23,7 +23,6 @@ import ( type MockITraceService struct { ctrl *gomock.Controller recorder *MockITraceServiceMockRecorder - isgomock struct{} } // MockITraceServiceMockRecorder is the mock recorder for MockITraceService. @@ -176,18 +175,18 @@ func (mr *MockITraceServiceMockRecorder) GetTracesMetaInfo(ctx, req any) *gomock } // GetTrajectories mocks base method. -func (m *MockITraceService) GetTrajectories(ctx context.Context, workspaceID int64, traceIDs []string, startTime, endTime int64, platformType loop_span.PlatformType) (map[string]*loop_span.Trajectory, error) { +func (m *MockITraceService) GetTrajectories(ctx context.Context, req int64, arg2 []string, arg3, arg4 int64, arg5 loop_span.PlatformType) (map[string]*loop_span.Trajectory, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTrajectories", ctx, workspaceID, traceIDs, startTime, endTime, platformType) + ret := m.ctrl.Call(m, "GetTrajectories", ctx, req, arg2, arg3, arg4, arg5) ret0, _ := ret[0].(map[string]*loop_span.Trajectory) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTrajectories indicates an expected call of GetTrajectories. -func (mr *MockITraceServiceMockRecorder) GetTrajectories(ctx, workspaceID, traceIDs, startTime, endTime, platformType any) *gomock.Call { +func (mr *MockITraceServiceMockRecorder) GetTrajectories(ctx, req, arg2, arg3, arg4, arg5 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTrajectories", reflect.TypeOf((*MockITraceService)(nil).GetTrajectories), ctx, workspaceID, traceIDs, startTime, endTime, platformType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTrajectories", reflect.TypeOf((*MockITraceService)(nil).GetTrajectories), ctx, req, arg2, arg3, arg4, arg5) } // GetTrajectoryConfig mocks base method. @@ -264,6 +263,21 @@ func (mr *MockITraceServiceMockRecorder) ListPreSpan(ctx, req any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPreSpan", reflect.TypeOf((*MockITraceService)(nil).ListPreSpan), ctx, req) } +// ListPreSpanBatch mocks base method. +func (m *MockITraceService) ListPreSpanBatch(ctx context.Context, req *service.ListPreSpanBatchReq) (*service.ListPreSpanBatchResp, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListPreSpanBatch", ctx, req) + ret0, _ := ret[0].(*service.ListPreSpanBatchResp) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListPreSpanBatch indicates an expected call of ListPreSpanBatch. +func (mr *MockITraceServiceMockRecorder) ListPreSpanBatch(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPreSpanBatch", reflect.TypeOf((*MockITraceService)(nil).ListPreSpanBatch), ctx, req) +} + // ListPreSpanOApi mocks base method. func (m *MockITraceService) ListPreSpanOApi(ctx context.Context, req *service.ListPreSpanOApiReq) (*service.ListPreSpanOApiResp, error) { m.ctrl.T.Helper() @@ -324,6 +338,20 @@ func (mr *MockITraceServiceMockRecorder) ListTrajectory(ctx, req any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTrajectory", reflect.TypeOf((*MockITraceService)(nil).ListTrajectory), ctx, req) } +// MergeHistoryMessagesByRespIDBatch mocks base method. +func (m *MockITraceService) MergeHistoryMessagesByRespIDBatch(ctx context.Context, spans []*loop_span.Span, platformType loop_span.PlatformType) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MergeHistoryMessagesByRespIDBatch", ctx, spans, platformType) + ret0, _ := ret[0].(error) + return ret0 +} + +// MergeHistoryMessagesByRespIDBatch indicates an expected call of MergeHistoryMessagesByRespIDBatch. +func (mr *MockITraceServiceMockRecorder) MergeHistoryMessagesByRespIDBatch(ctx, spans, platformType any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MergeHistoryMessagesByRespIDBatch", reflect.TypeOf((*MockITraceService)(nil).MergeHistoryMessagesByRespIDBatch), ctx, spans, platformType) +} + // SearchTraceOApi mocks base method. func (m *MockITraceService) SearchTraceOApi(ctx context.Context, req *service.SearchTraceOApiReq) (*service.SearchTraceOApiResp, error) { m.ctrl.T.Helper() @@ -340,17 +368,17 @@ func (mr *MockITraceServiceMockRecorder) SearchTraceOApi(ctx, req any) *gomock.C } // Send mocks base method. -func (m *MockITraceService) Send(ctx context.Context, msg *entity.AnnotationEvent) error { +func (m *MockITraceService) Send(ctx context.Context, req *entity.AnnotationEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Send", ctx, msg) + ret := m.ctrl.Call(m, "Send", ctx, req) ret0, _ := ret[0].(error) return ret0 } // Send indicates an expected call of Send. -func (mr *MockITraceServiceMockRecorder) Send(ctx, msg any) *gomock.Call { +func (mr *MockITraceServiceMockRecorder) Send(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockITraceService)(nil).Send), ctx, msg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockITraceService)(nil).Send), ctx, req) } // UpdateManualAnnotation mocks base method. diff --git a/backend/modules/observability/domain/trace/service/trace_export_service.go b/backend/modules/observability/domain/trace/service/trace_export_service.go index e10d04c80..072a5b2c8 100644 --- a/backend/modules/observability/domain/trace/service/trace_export_service.go +++ b/backend/modules/observability/domain/trace/service/trace_export_service.go @@ -166,7 +166,10 @@ func (r *TraceExportServiceImpl) ExportTracesToDataset(ctx context.Context, req if err := r.clearDataset(ctx, datasetID, req); err != nil { return resp, err } - + err = r.traceService.MergeHistoryMessagesByRespIDBatch(ctx, spans, req.PlatformType) + if err != nil { + return resp, err + } successItems, errorGroups, err := r.addToDataset(ctx, spans, req.FieldMappings, req.WorkspaceID, dataset, trajectoryMap) if err != nil { return resp, err @@ -204,12 +207,18 @@ func (r *TraceExportServiceImpl) PreviewExportTracesToDataset(ctx context.Contex return resp, err } + err = r.traceService.MergeHistoryMessagesByRespIDBatch(ctx, spans, req.PlatformType) + if err != nil { + return resp, err + } + successItems, failedItems, allItems := r.buildDatasetItems(ctx, spans, req.FieldMappings, req.WorkspaceID, dataset, nil) var ignoreCurrentCount *bool if !req.Config.IsNewDataset && req.ExportType == ExportType_Overwrite { ignoreCurrentCount = lo.ToPtr(true) } + addSuccess, errorGroups, err := r.getDatasetProvider(dataset.DatasetCategory).ValidateDatasetItems(ctx, dataset, successItems, ignoreCurrentCount) if err != nil { return resp, err diff --git a/backend/modules/observability/domain/trace/service/trace_export_service_test.go b/backend/modules/observability/domain/trace/service/trace_export_service_test.go index 2d24b83f0..64e5da3ef 100644 --- a/backend/modules/observability/domain/trace/service/trace_export_service_test.go +++ b/backend/modules/observability/domain/trace/service/trace_export_service_test.go @@ -32,7 +32,8 @@ import ( type stubTraceService struct { ITraceService - getTrajectoriesFunc func(ctx context.Context, workspaceID int64, traceIDs []string, startTime, endTime int64, platformType loop_span.PlatformType) (map[string]*loop_span.Trajectory, error) + getTrajectoriesFunc func(ctx context.Context, workspaceID int64, traceIDs []string, startTime, endTime int64, platformType loop_span.PlatformType) (map[string]*loop_span.Trajectory, error) + mergeHistoryMessagesByRespIDBatchFunc func(ctx context.Context, spans []*loop_span.Span, platformType loop_span.PlatformType) error } func (m *stubTraceService) GetTrajectories(ctx context.Context, workspaceID int64, traceIDs []string, startTime, endTime int64, platformType loop_span.PlatformType) (map[string]*loop_span.Trajectory, error) { @@ -42,6 +43,13 @@ func (m *stubTraceService) GetTrajectories(ctx context.Context, workspaceID int6 return nil, nil } +func (m *stubTraceService) MergeHistoryMessagesByRespIDBatch(ctx context.Context, spans []*loop_span.Span, platformType loop_span.PlatformType) error { + if m.mergeHistoryMessagesByRespIDBatchFunc != nil { + return m.mergeHistoryMessagesByRespIDBatchFunc(ctx, spans, platformType) + } + return nil +} + func TestTraceExportServiceImpl_ExportTracesToDataset(t *testing.T) { type fields struct { traceRepo repo.ITraceRepo @@ -119,7 +127,7 @@ func TestTraceExportServiceImpl_ExportTracesToDataset(t *testing.T) { tenantProvider: tenantMock, DatasetServiceAdaptor: adaptor, buildHelper: buildHelper, - traceService: nil, + traceService: &stubTraceService{}, } }, args: args{ @@ -201,6 +209,9 @@ func TestTraceExportServiceImpl_ExportTracesToDataset(t *testing.T) { defer ctrl.Finish() fields := tt.fieldsGetter(ctrl) + if fields.traceService == nil { + fields.traceService = &stubTraceService{} + } r := &TraceExportServiceImpl{ traceRepo: fields.traceRepo, traceConfig: fields.traceConfig, @@ -1001,6 +1012,9 @@ func TestTraceExportServiceImpl_PreviewExportTracesToDataset(t *testing.T) { defer ctrl.Finish() fields := tt.fieldsGetter(ctrl) + if fields.traceService == nil { + fields.traceService = &stubTraceService{} + } r := &TraceExportServiceImpl{ traceRepo: fields.traceRepo, traceConfig: fields.traceConfig, @@ -1779,6 +1793,9 @@ func TestTraceExportServiceImpl_ExportTracesToDataset_Additional(t *testing.T) { defer ctrl.Finish() fields := tt.fieldsGetter(ctrl) + if fields.traceService == nil { + fields.traceService = &stubTraceService{} + } r := &TraceExportServiceImpl{ traceRepo: fields.traceRepo, traceConfig: fields.traceConfig, @@ -1927,6 +1944,9 @@ func TestTraceExportServiceImpl_PreviewExportTracesToDataset_Additional(t *testi defer ctrl.Finish() fields := tt.fieldsGetter(ctrl) + if fields.traceService == nil { + fields.traceService = &stubTraceService{} + } r := &TraceExportServiceImpl{ traceRepo: fields.traceRepo, traceConfig: fields.traceConfig, diff --git a/backend/modules/observability/domain/trace/service/trace_service.go b/backend/modules/observability/domain/trace/service/trace_service.go index 860220fa4..618cfa1ed 100644 --- a/backend/modules/observability/domain/trace/service/trace_service.go +++ b/backend/modules/observability/domain/trace/service/trace_service.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "github.com/bytedance/gg/gslice" + "github.com/coze-dev/coze-loop/backend/infra/redis" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" @@ -75,6 +77,34 @@ type ListPreSpanResp struct { Spans loop_span.SpanList } +type ListPreSpanBatchReq struct { + WorkspaceID int64 + ThirdPartyWorkspaceID string + StartTime int64 // ms + EndTime int64 + Items []*ListPreSpanItem + PlatformType loop_span.PlatformType +} + +type ListPreSpanItem struct { + TraceID string + SpanID string + PreviousResponseID string + CurrentSpan *loop_span.Span +} + +type ListPreSpanBatchResp struct { + Results []*ListPreSpanResult +} + +type ListPreSpanResult struct { + TraceID string + SpanID string + PreviousResponseID string + Spans loop_span.SpanList + Error error +} + type GetTraceReq struct { WorkspaceID int64 LogID string @@ -358,6 +388,7 @@ type IAnnotationEvent interface { type ITraceService interface { ListSpans(ctx context.Context, req *ListSpansReq) (*ListSpansResp, error) ListPreSpan(ctx context.Context, req *ListPreSpanReq) (r *ListPreSpanResp, err error) + ListPreSpanBatch(ctx context.Context, req *ListPreSpanBatchReq) (*ListPreSpanBatchResp, error) GetTrace(ctx context.Context, req *GetTraceReq) (*GetTraceResp, error) SearchTraceOApi(ctx context.Context, req *SearchTraceOApiReq) (*SearchTraceOApiResp, error) ListSpansOApi(ctx context.Context, req *ListSpansOApiReq) (*ListSpansOApiResp, error) @@ -380,6 +411,7 @@ type ITraceService interface { ListTrajectory(ctx context.Context, req *ListTrajectoryRequest) (*ListTrajectoryResponse, error) GetTrajectories(ctx context.Context, workspaceID int64, traceIDs []string, startTime, endTime int64, platformType loop_span.PlatformType) (map[string]*loop_span.Trajectory, error) + MergeHistoryMessagesByRespIDBatch(ctx context.Context, spans []*loop_span.Span, platformType loop_span.PlatformType) error } func NewTraceServiceImpl( @@ -442,7 +474,7 @@ func (r *TraceServiceImpl) ListPreSpan(ctx context.Context, req *ListPreSpanReq) preAndCurrentSpanIDs = append(preAndCurrentSpanIDs, req.SpanID) // for select current span together // batch select from ck - preAndCurrentSpans, err := r.batchGetPreSpan(ctx, preAndCurrentSpanIDs, tenants, req.StartTime) + preAndCurrentSpans, err := r.batchGetPreSpan(ctx, preAndCurrentSpanIDs, tenants, req.StartTime-timeutil.Day2MillSec(30), req.StartTime+1) if err != nil { return nil, errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) } @@ -471,12 +503,12 @@ func (r *TraceServiceImpl) ListPreSpan(ctx context.Context, req *ListPreSpanReq) } // order SpanList: remove duplicate span_id, and remove current span - orderSpans := r.orderPreSpans(preAndCurrentSpans, respIDByOrder) + orderSpans := r.orderPreSpans(ctx, preAndCurrentSpans, respIDByOrder) return &ListPreSpanResp{Spans: orderSpans}, nil } -func (r *TraceServiceImpl) batchGetPreSpan(ctx context.Context, spanIDs []string, tenants []string, startTime int64) ([]*loop_span.Span, error) { +func (r *TraceServiceImpl) batchGetPreSpan(ctx context.Context, spanIDs []string, tenants []string, startTime int64, endTime int64) ([]*loop_span.Span, error) { batchNum := 100 batchPreSpan := make([][]string, 0) oneBatchPreSpan := make([]string, 0) @@ -504,8 +536,8 @@ func (r *TraceServiceImpl) batchGetPreSpan(ctx context.Context, spanIDs []string }, }, }, - StartAt: startTime - timeutil.Day2MillSec(30), // past 30 days - EndAt: startTime + 1, + StartAt: startTime, + EndAt: endTime, Limit: 200, }) if err != nil { @@ -595,7 +627,7 @@ func (r *TraceServiceImpl) checkGetPreSpanAuth(ctx context.Context, req *ListPre return nil } -func (r *TraceServiceImpl) orderPreSpans(preAndCurrentSpans []*loop_span.Span, respIDByOrder []string) loop_span.SpanList { +func (r *TraceServiceImpl) orderPreSpans(ctx context.Context, preAndCurrentSpans []*loop_span.Span, respIDByOrder []string) loop_span.SpanList { respIDSpanMap := make(map[string]*loop_span.Span) for _, span := range preAndCurrentSpans { if respID, ok := span.SystemTagsString[keyResponseID]; ok { @@ -612,6 +644,289 @@ func (r *TraceServiceImpl) orderPreSpans(preAndCurrentSpans []*loop_span.Span, r return orderSpans } +// ListPreSpanBatch batch version of ListPreSpan, processes multiple previous_response_id in one call. +func (r *TraceServiceImpl) ListPreSpanBatch(ctx context.Context, req *ListPreSpanBatchReq) (*ListPreSpanBatchResp, error) { + if len(req.Items) == 0 { + return &ListPreSpanBatchResp{Results: []*ListPreSpanResult{}}, nil + } + + // Step 1: Get tenants (shared across all items) + tenants, err := r.getTenants(ctx, req.PlatformType) + if err != nil { + return nil, err + } + + // Step 2: Batch get all pre span IDs from redis + spanIDsInfo, err := r.batchGetPreSpanIDsFromRedis(ctx, req.Items) + if err != nil { + return nil, err + } + logs.CtxInfo(ctx, "Got span from redis info: %v", tconv.ToJSONString(ctx, spanIDsInfo)) + // Step 3: Collect all unique span IDs to query + allSpanIDs := r.collectAllSpanIDs(spanIDsInfo, req.Items) + // Step 4: Batch query all spans from ClickHouse + allSpans, err := r.batchGetPreSpan(ctx, allSpanIDs, tenants, req.StartTime-timeutil.Day2MillSec(30), req.EndTime+1) + if err != nil { + return nil, errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) + } + + // Step 5: Apply span processors once for all spans + processedSpans, err := r.applyProcessors(ctx, allSpans, req) + if err != nil { + return nil, err + } + // Step 6: Build span map for quick lookup + allSpanMap := r.buildSpanMap(processedSpans) + + // Step 6.1: Add current spans from request items (for New Data scenario where span is not yet in CK) + for _, item := range req.Items { + if item.CurrentSpan != nil { + allSpanMap[item.CurrentSpan.SpanID] = item.CurrentSpan + } + } + + // Step 7: Process each item individually (auth check, ordering) + results := r.processEachItem(ctx, req, tenants, spanIDsInfo, allSpanMap) + return &ListPreSpanBatchResp{Results: results}, nil +} + +// batchGetPreSpanIDsFromRedis fetches pre span IDs from Redis for all items +// Returns a map keyed by SpanID (not PreviousResponseID) to handle multiple spans sharing the same PreviousResponseID +func (r *TraceServiceImpl) batchGetPreSpanIDsFromRedis( + ctx context.Context, + items []*ListPreSpanItem, +) (map[string]*preSpanIDsInfo, error) { + result := make(map[string]*preSpanIDsInfo, len(items)) + preRespIDCache := make(map[string]*preSpanIDsInfo) + + for _, item := range items { + if item.PreviousResponseID == "" { + continue + } + + if cached, ok := preRespIDCache[item.PreviousResponseID]; ok { + result[item.SpanID] = &preSpanIDsInfo{ + PreSpanIDs: cached.PreSpanIDs, + RespIDByOrder: cached.RespIDByOrder, + } + continue + } + + preSpanIDs, respIDByOrder, err := r.traceRepo.GetPreSpanIDs(ctx, &repo.GetPreSpanIDsParam{ + PreRespID: item.PreviousResponseID, + }) + if err != nil { + return nil, err + } + + info := &preSpanIDsInfo{ + PreSpanIDs: preSpanIDs, + RespIDByOrder: respIDByOrder, + } + preRespIDCache[item.PreviousResponseID] = info + result[item.SpanID] = info + } + + return result, nil +} + +// collectAllSpanIDs collects all unique span IDs that need to be queried +func (r *TraceServiceImpl) collectAllSpanIDs( + spanIDsInfo map[string]*preSpanIDsInfo, + items []*ListPreSpanItem, +) []string { + spanIDSet := make(map[string]struct{}) + + // Add current span IDs from items + for _, item := range items { + spanIDSet[item.SpanID] = struct{}{} + } + + // Add pre span IDs from Redis results + for _, info := range spanIDsInfo { + for _, spanID := range info.PreSpanIDs { + spanIDSet[spanID] = struct{}{} + } + } + + allSpanIDs := make([]string, 0, len(spanIDSet)) + for spanID := range spanIDSet { + allSpanIDs = append(allSpanIDs, spanID) + } + + return allSpanIDs +} + +// applyProcessors applies span processors to all spans at once +func (r *TraceServiceImpl) applyProcessors( + ctx context.Context, + spans []*loop_span.Span, + req *ListPreSpanBatchReq, +) ([]*loop_span.Span, error) { + processors, err := r.buildHelper.BuildGetTraceProcessors(ctx, span_processor.Settings{ + WorkspaceId: req.WorkspaceID, + PlatformType: req.PlatformType, + QueryStartTime: req.StartTime - timeutil.Day2MillSec(30), // past 30 days + QueryEndTime: req.EndTime, + }) + if err != nil { + return nil, errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) + } + + processedSpans := spans + for _, p := range processors { + processedSpans, err = p.Transform(ctx, processedSpans) + if err != nil { + return nil, err + } + } + + return processedSpans, nil +} + +// buildSpanMap creates a map for quick span lookup by span_id +func (r *TraceServiceImpl) buildSpanMap(spans []*loop_span.Span) map[string]*loop_span.Span { + spanMap := make(map[string]*loop_span.Span, len(spans)) + for _, span := range spans { + if span != nil { + spanMap[span.SpanID] = span + } + } + return spanMap +} + +// processEachItem processes each request item individually +func (r *TraceServiceImpl) processEachItem( + ctx context.Context, + req *ListPreSpanBatchReq, + tenants []string, + spanIDsInfo map[string]*preSpanIDsInfo, + spanMap map[string]*loop_span.Span, +) []*ListPreSpanResult { + results := make([]*ListPreSpanResult, 0, len(req.Items)) + + for _, item := range req.Items { + result := &ListPreSpanResult{ + TraceID: item.TraceID, + SpanID: item.SpanID, + PreviousResponseID: item.PreviousResponseID, + } + + // Get span IDs info for this item (now keyed by SpanID) + info, exists := spanIDsInfo[item.SpanID] + if !exists { + result.Error = errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, + errorx.WithExtraMsg("span_id not found in redis lookup")) + logs.CtxWarn(ctx, "Span id not found in redis lookup: %v", item.SpanID) + results = append(results, result) + continue + } + + // Collect pre spans + current span for this item + // Note: current span is needed for checkGetPreSpanAuth, but will be filtered out by orderPreSpans + preAndCurrentSpans := make([]*loop_span.Span, 0, len(info.PreSpanIDs)+1) + for _, spanID := range info.PreSpanIDs { + if span, ok := spanMap[spanID]; ok { + preAndCurrentSpans = append(preAndCurrentSpans, span) + } + } + if currentSpan, ok := spanMap[item.SpanID]; ok { + preAndCurrentSpans = append(preAndCurrentSpans, currentSpan) + } + + // Auth check + itemReq := &ListPreSpanReq{ + WorkspaceID: req.WorkspaceID, + ThirdPartyWorkspaceID: req.ThirdPartyWorkspaceID, + StartTime: req.StartTime, + TraceID: item.TraceID, + SpanID: item.SpanID, + PreviousResponseID: item.PreviousResponseID, + PlatformType: req.PlatformType, + } + if err := r.checkGetPreSpanAuth(ctx, itemReq, tenants, preAndCurrentSpans); err != nil { + result.Error = err + logs.CtxWarn(ctx, "CheckGetPreSpanAuth failed: %v", err) + results = append(results, result) + continue + } + + // Order spans + orderSpans := r.orderPreSpans(ctx, preAndCurrentSpans, info.RespIDByOrder) + result.Spans = orderSpans + + results = append(results, result) + } + + return results +} + +// preSpanIDsInfo holds the pre span IDs and their order for a single previous_response_id +type preSpanIDsInfo struct { + PreSpanIDs []string + RespIDByOrder []string +} + +func (r *TraceServiceImpl) MergeHistoryMessagesByRespIDBatch(ctx context.Context, spans []*loop_span.Span, platformType loop_span.PlatformType) error { + spansWithRespID := gslice.Filter(spans, func(span *loop_span.Span) bool { + if !span.IsModelSpan() { + return false + } + if span.SystemTagsString == nil { + return false + } + v, ok := span.SystemTagsString[keyPreviousResponseID] + return ok && v != "" + }) + if len(spansWithRespID) > 0 { + spanResp, err := r.ListPreSpanBatch(ctx, spanList2ListPreSpanBatchReq(spansWithRespID, platformType)) + if err != nil { + logs.CtxError(ctx, "MergeHistoryMessagesByRespIDBatch ListPreSpanBatch fail, err:%v", err) + return err + } + spanIdMap := gslice.ToMap(spanResp.Results, func(t *ListPreSpanResult) (string, *ListPreSpanResult) { + return t.SpanID, t + }) + for _, span := range spansWithRespID { + preResult, ok := spanIdMap[span.SpanID] + if !ok || preResult.Error != nil { + continue + } + + span.MergeHistoryContext(ctx, preResult.Spans) + } + } + return nil +} + +func spanList2ListPreSpanBatchReq(spanList []*loop_span.Span, platformType loop_span.PlatformType) *ListPreSpanBatchReq { + if len(spanList) == 0 { + return nil + } + workspaceId, _ := strconv.Atoi(spanList[0].WorkspaceID) + startTime := gslice.Min(gslice.Map(spanList, func(span *loop_span.Span) int64 { + return span.StartTime + })) + endTime := gslice.Max(gslice.Map(spanList, func(span *loop_span.Span) int64 { + return span.StartTime + })) + return &ListPreSpanBatchReq{ + WorkspaceID: int64(workspaceId), + ThirdPartyWorkspaceID: "", + StartTime: startTime.Value() / 1000, // us to ms + EndTime: endTime.Value() / 1000, + Items: gslice.Map(spanList, func(span *loop_span.Span) *ListPreSpanItem { + return &ListPreSpanItem{ + TraceID: span.TraceID, + SpanID: span.SpanID, + PreviousResponseID: span.SystemTagsString[keyPreviousResponseID], + CurrentSpan: span, + } + }), + PlatformType: platformType, + } +} + func (r *TraceServiceImpl) ListTrajectory(ctx context.Context, req *ListTrajectoryRequest) (*ListTrajectoryResponse, error) { if req.StartTime == nil { return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start_time is required")) @@ -898,7 +1213,7 @@ func (r *TraceServiceImpl) ListPreSpanOApi(ctx context.Context, req *ListPreSpan preAndCurrentSpanIDs = append(preAndCurrentSpanIDs, req.SpanID) // for select current span together // batch select from ck - preAndCurrentSpans, err := r.batchGetPreSpan(ctx, preAndCurrentSpanIDs, req.Tenants, req.StartTime) + preAndCurrentSpans, err := r.batchGetPreSpan(ctx, preAndCurrentSpanIDs, req.Tenants, req.StartTime-timeutil.Day2MillSec(30), req.StartTime+1) if err != nil { return nil, errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) } @@ -936,7 +1251,7 @@ func (r *TraceServiceImpl) ListPreSpanOApi(ctx context.Context, req *ListPreSpan } // order SpanList: remove duplicate span_id, and remove current span - orderSpans := r.orderPreSpans(preAndCurrentSpans, respIDByOrder) + orderSpans := r.orderPreSpans(ctx, preAndCurrentSpans, respIDByOrder) return &ListPreSpanOApiResp{ Spans: orderSpans, diff --git a/backend/modules/observability/domain/trace/service/trace_service_pre_span_oapi_test.go b/backend/modules/observability/domain/trace/service/trace_service_pre_span_oapi_test.go index b77f8f5d0..52522375b 100644 --- a/backend/modules/observability/domain/trace/service/trace_service_pre_span_oapi_test.go +++ b/backend/modules/observability/domain/trace/service/trace_service_pre_span_oapi_test.go @@ -87,7 +87,7 @@ func TestTraceServiceImpl_ListPreSpanOApi_Success(t *testing.T) { resp, err := r.ListPreSpanOApi(context.Background(), req) assert.NoError(t, err) if assert.NotNil(t, resp) { - // 顺序应按 respIDByOrder:resp-2 在前、resp-1 在后 + // 顺序应按 RespIDByOrder:resp-2 在前、resp-1 在后 got := make([]string, 0, len(resp.Spans)) for _, s := range resp.Spans { got = append(got, s.SpanID) diff --git a/backend/modules/observability/domain/trace/service/trace_service_test.go b/backend/modules/observability/domain/trace/service/trace_service_test.go index d8c1b324f..d5a362dbd 100644 --- a/backend/modules/observability/domain/trace/service/trace_service_test.go +++ b/backend/modules/observability/domain/trace/service/trace_service_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + timeutil "github.com/coze-dev/coze-loop/backend/pkg/time" + "github.com/coze-dev/coze-loop/backend/infra/middleware/session" "github.com/coze-dev/coze-loop/backend/infra/redis" "github.com/coze-dev/coze-loop/backend/infra/redis/mocks" @@ -2948,6 +2950,341 @@ func TestTraceServiceImpl_ListSpansOApi(t *testing.T) { }, wantErr: true, }, + { + name: "list spans successfully with valid request", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + filterFactoryMock := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockFilter := filtermocks.NewMockFilter(ctrl) + + filterFactoryMock.EXPECT(). + GetFilter(gomock.Any(), loop_span.PlatformCozeLoop). + Return(mockFilter, nil) + + mockFilter.EXPECT(). + BuildBasicSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, true, nil) + + mockFilter.EXPECT(). + BuildALLSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995200000, + }, + { + SpanID: "span-2", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995300000, + }, + }, + PageToken: "next-token", + HasMore: true, + }, nil) + + buildHelper := NewTraceFilterProcessorBuilder(filterFactoryMock, nil, nil, nil, nil, nil, nil) + + return fields{ + traceRepo: mockRepo, + buildHelper: buildHelper, + } + }, + args: args{ + ctx: context.Background(), + req: &ListSpansOApiReq{ + WorkspaceID: 123, + Tenants: []string{"tenant1"}, + StartTime: 1640995200000, + EndTime: 1640995800000, + Filters: &loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "span_type", + FieldType: loop_span.FieldTypeString, + Values: []string{"model"}, + QueryType: ptr.Of(loop_span.QueryTypeEnumIn), + }, + }, + }, + Limit: 100, + PlatformType: loop_span.PlatformCozeLoop, + SpanListType: loop_span.SpanListTypeAllSpan, + }, + }, + want: &ListSpansOApiResp{ + Spans: loop_span.SpanList{ + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995200000, + }, + { + SpanID: "span-2", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995300000, + }, + }, + NextPageToken: "next-token", + HasMore: true, + }, + wantErr: false, + }, + { + name: "list spans returns empty when builtin filter returns nil", + fieldsGetter: func(ctrl *gomock.Controller) fields { + filterFactoryMock := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockFilter := filtermocks.NewMockFilter(ctrl) + + filterFactoryMock.EXPECT(). + GetFilter(gomock.Any(), loop_span.PlatformCozeLoop). + Return(mockFilter, nil) + + mockFilter.EXPECT(). + BuildBasicSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, false, nil) + + buildHelper := NewTraceFilterProcessorBuilder(filterFactoryMock, nil, nil, nil, nil, nil, nil) + + return fields{ + buildHelper: buildHelper, + } + }, + args: args{ + ctx: context.Background(), + req: &ListSpansOApiReq{ + WorkspaceID: 123, + Tenants: []string{"tenant1"}, + StartTime: 1640995200000, + EndTime: 1640995800000, + Limit: 100, + PlatformType: loop_span.PlatformCozeLoop, + SpanListType: loop_span.SpanListTypeAllSpan, + }, + }, + want: &ListSpansOApiResp{ + Spans: loop_span.SpanList{}, + }, + wantErr: false, + }, + { + name: "list spans failed due to platform filter error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + filterFactoryMock := filtermocks.NewMockPlatformFilterFactory(ctrl) + + filterFactoryMock.EXPECT(). + GetFilter(gomock.Any(), loop_span.PlatformCozeLoop). + Return(nil, errorx.NewByCode(obErrorx.CommercialCommonInternalErrorCodeCode)) + + buildHelper := NewTraceFilterProcessorBuilder(filterFactoryMock, nil, nil, nil, nil, nil, nil) + + return fields{ + buildHelper: buildHelper, + } + }, + args: args{ + ctx: context.Background(), + req: &ListSpansOApiReq{ + WorkspaceID: 123, + Tenants: []string{"tenant1"}, + StartTime: 1640995200000, + EndTime: 1640995800000, + Limit: 100, + PlatformType: loop_span.PlatformCozeLoop, + SpanListType: loop_span.SpanListTypeAllSpan, + }, + }, + wantErr: true, + }, + { + name: "list spans failed due to repo error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + filterFactoryMock := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockFilter := filtermocks.NewMockFilter(ctrl) + + filterFactoryMock.EXPECT(). + GetFilter(gomock.Any(), loop_span.PlatformCozeLoop). + Return(mockFilter, nil) + + mockFilter.EXPECT(). + BuildBasicSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, true, nil) + + mockFilter.EXPECT(). + BuildALLSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(nil, errorx.NewByCode(obErrorx.CommercialCommonInternalErrorCodeCode)) + + buildHelper := NewTraceFilterProcessorBuilder(filterFactoryMock, nil, nil, nil, nil, nil, nil) + + return fields{ + traceRepo: mockRepo, + buildHelper: buildHelper, + } + }, + args: args{ + ctx: context.Background(), + req: &ListSpansOApiReq{ + WorkspaceID: 123, + Tenants: []string{"tenant1"}, + StartTime: 1640995200000, + EndTime: 1640995800000, + Limit: 100, + PlatformType: loop_span.PlatformCozeLoop, + SpanListType: loop_span.SpanListTypeAllSpan, + }, + }, + wantErr: true, + }, + { + name: "list spans with pagination", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + filterFactoryMock := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockFilter := filtermocks.NewMockFilter(ctrl) + + filterFactoryMock.EXPECT(). + GetFilter(gomock.Any(), loop_span.PlatformCozeLoop). + Return(mockFilter, nil) + + mockFilter.EXPECT(). + BuildBasicSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, true, nil) + + mockFilter.EXPECT(). + BuildALLSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-3", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995400000, + }, + }, + PageToken: "page-token-2", + HasMore: false, + }, nil) + + buildHelper := NewTraceFilterProcessorBuilder(filterFactoryMock, nil, nil, nil, nil, nil, nil) + + return fields{ + traceRepo: mockRepo, + buildHelper: buildHelper, + } + }, + args: args{ + ctx: context.Background(), + req: &ListSpansOApiReq{ + WorkspaceID: 123, + Tenants: []string{"tenant1"}, + StartTime: 1640995200000, + EndTime: 1640995800000, + Limit: 10, + DescByStartTime: true, + PageToken: "page-token-1", + PlatformType: loop_span.PlatformCozeLoop, + SpanListType: loop_span.SpanListTypeAllSpan, + }, + }, + want: &ListSpansOApiResp{ + Spans: loop_span.SpanList{ + { + SpanID: "span-3", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995400000, + }, + }, + NextPageToken: "page-token-2", + HasMore: false, + }, + wantErr: false, + }, + { + name: "list spans with third party workspace id", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + filterFactoryMock := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockFilter := filtermocks.NewMockFilter(ctrl) + + filterFactoryMock.EXPECT(). + GetFilter(gomock.Any(), loop_span.PlatformCozeLoop). + Return(mockFilter, nil) + + mockFilter.EXPECT(). + BuildBasicSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, true, nil) + + mockFilter.EXPECT(). + BuildALLSpanFilter(gomock.Any(), gomock.Any()). + Return([]*loop_span.FilterField{}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995200000, + }, + }, + HasMore: false, + }, nil) + + buildHelper := NewTraceFilterProcessorBuilder(filterFactoryMock, nil, nil, nil, nil, nil, nil) + + return fields{ + traceRepo: mockRepo, + buildHelper: buildHelper, + } + }, + args: args{ + ctx: context.Background(), + req: &ListSpansOApiReq{ + WorkspaceID: 123, + ThirdPartyWorkspaceID: "third-party-ws-1", + Tenants: []string{"tenant1"}, + StartTime: 1640995200000, + EndTime: 1640995800000, + Limit: 100, + PlatformType: loop_span.PlatformCozeLoop, + SpanListType: loop_span.SpanListTypeAllSpan, + }, + }, + want: &ListSpansOApiResp{ + Spans: loop_span.SpanList{ + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "123", + StartTime: 1640995200000, + }, + }, + NextPageToken: "", + HasMore: false, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -4160,7 +4497,7 @@ func TestTraceServiceImpl_batchGetPreSpanFromCk(t *testing.T) { r := &TraceServiceImpl{ traceRepo: fields.traceRepo, } - got, err := r.batchGetPreSpan(tt.args.ctx, tt.args.spanIDs, tt.args.tenants, tt.args.startTime) + got, err := r.batchGetPreSpan(tt.args.ctx, tt.args.spanIDs, tt.args.tenants, tt.args.startTime-timeutil.Day2MillSec(30), tt.args.startTime+1) assert.Equal(t, tt.wantErr, err != nil) if !tt.wantErr { assert.Equal(t, len(tt.want), len(got)) @@ -4714,3 +5051,792 @@ func TestTraceServiceImpl_checkGetPreSpanAuth_Comprehensive(t *testing.T) { }) } } + +func TestTraceServiceImpl_ListPreSpanBatch(t *testing.T) { + type fields struct { + traceRepo repo.ITraceRepo + buildHelper TraceFilterProcessorBuilder + tenantProvider tenant.ITenantProvider + } + type args struct { + ctx context.Context + req *ListPreSpanBatchReq + } + tests := []struct { + name string + fieldsGetter func(ctrl *gomock.Controller) fields + args args + want *ListPreSpanBatchResp + wantErr bool + }{ + { + name: "empty items - should return empty results", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{}, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + want: &ListPreSpanBatchResp{ + Results: []*ListPreSpanResult{}, + }, + wantErr: false, + }, + { + name: "single item - successful query", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + // Mock GetTenantsByPlatformType + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + // Mock GetPreSpanIDs + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{ + PreRespID: "prev-resp-1", + }). + Return([]string{"span-0"}, []string{"resp-0"}, nil) + + // Mock ListSpans - will be called in batchGetPreSpan + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + keyPreviousResponseID: "", + }, + }, + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-1", + keyPreviousResponseID: "prev-resp-1", + }, + }, + }, + }, nil). + AnyTimes() + + return fields{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + }, + }, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + want: &ListPreSpanBatchResp{ + Results: []*ListPreSpanResult{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + Spans: loop_span.SpanList{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + keyPreviousResponseID: "", + }, + }, + }, + Error: nil, + }, + }, + }, + wantErr: false, + }, + { + name: "multiple items - all successful", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + // Mock GetPreSpanIDs for item 1 + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{ + PreRespID: "prev-resp-1", + }). + Return([]string{"span-0"}, []string{"resp-0"}, nil) + + // Mock GetPreSpanIDs for item 2 + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{ + PreRespID: "prev-resp-2", + }). + Return([]string{"span-10"}, []string{"resp-10"}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-1", + keyPreviousResponseID: "prev-resp-1", + }, + }, + { + SpanID: "span-10", + TraceID: "trace-2", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-10", + }, + }, + { + SpanID: "span-2", + TraceID: "trace-2", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-2", + keyPreviousResponseID: "prev-resp-2", + }, + }, + }, + }, nil). + AnyTimes() + + return fields{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + }, + { + TraceID: "trace-2", + SpanID: "span-2", + PreviousResponseID: "prev-resp-2", + }, + }, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + want: &ListPreSpanBatchResp{ + Results: []*ListPreSpanResult{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + Spans: loop_span.SpanList{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + }, + Error: nil, + }, + { + TraceID: "trace-2", + SpanID: "span-2", + PreviousResponseID: "prev-resp-2", + Spans: loop_span.SpanList{ + { + SpanID: "span-10", + TraceID: "trace-2", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-10", + }, + }, + }, + Error: nil, + }, + }, + }, + wantErr: false, + }, + { + name: "tenant provider error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return(nil, errorx.NewByCode(obErrorx.CommercialCommonInternalErrorCodeCode)) + + return fields{ + tenantProvider: mockTenantProvider, + } + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + }, + }, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + wantErr: true, + }, + { + name: "span_id not found in redis lookup - empty previous_response_id should return per item error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + }, + }, + }, nil). + AnyTimes() + + return fields{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "", + }, + }, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + want: &ListPreSpanBatchResp{ + Results: []*ListPreSpanResult{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "", + Error: errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode), + }, + }, + }, + wantErr: false, + }, + { + name: "auth check failed - previous_response_id mismatch should return per item error", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{ + PreRespID: "prev-resp-1", + }). + Return([]string{"span-0"}, []string{"resp-0"}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-1", + keyPreviousResponseID: "wrong-prev-resp", + }, + }, + }, + }, nil). + AnyTimes() + + return fields{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + }, + }, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + want: &ListPreSpanBatchResp{ + Results: []*ListPreSpanResult{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "prev-resp-1", + Error: errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode), + }, + }, + }, + wantErr: false, + }, + { + name: "multiple items with same previous_response_id - should hit local cache", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{ + PreRespID: "shared-prev-resp", + }). + Return([]string{"span-0"}, []string{"resp-0"}, nil). + Times(1) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-1", + keyPreviousResponseID: "shared-prev-resp", + }, + }, + { + SpanID: "span-2", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-2", + keyPreviousResponseID: "shared-prev-resp", + }, + }, + }, + }, nil). + AnyTimes() + + return fields{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + }, + args: args{ + ctx: context.Background(), + req: &ListPreSpanBatchReq{ + WorkspaceID: 1, + StartTime: time.Now().UnixMilli(), + Items: []*ListPreSpanItem{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "shared-prev-resp", + }, + { + TraceID: "trace-1", + SpanID: "span-2", + PreviousResponseID: "shared-prev-resp", + }, + }, + PlatformType: loop_span.PlatformCozeLoop, + }, + }, + want: &ListPreSpanBatchResp{ + Results: []*ListPreSpanResult{ + { + TraceID: "trace-1", + SpanID: "span-1", + PreviousResponseID: "shared-prev-resp", + Spans: loop_span.SpanList{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + }, + Error: nil, + }, + { + TraceID: "trace-1", + SpanID: "span-2", + PreviousResponseID: "shared-prev-resp", + Spans: loop_span.SpanList{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + }, + Error: nil, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + fields := tt.fieldsGetter(ctrl) + r := &TraceServiceImpl{ + traceRepo: fields.traceRepo, + buildHelper: fields.buildHelper, + tenantProvider: fields.tenantProvider, + } + got, err := r.ListPreSpanBatch(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, len(tt.want.Results), len(got.Results)) + for i, wantResult := range tt.want.Results { + gotResult := got.Results[i] + assert.Equal(t, wantResult.TraceID, gotResult.TraceID) + assert.Equal(t, wantResult.SpanID, gotResult.SpanID) + assert.Equal(t, wantResult.PreviousResponseID, gotResult.PreviousResponseID) + if wantResult.Error != nil { + assert.Error(t, gotResult.Error) + } else { + assert.NoError(t, gotResult.Error) + assert.Equal(t, len(wantResult.Spans), len(gotResult.Spans)) + } + } + }) + } +} + +func TestTraceServiceImpl_MergeHistoryMessagesByRespIDBatch(t *testing.T) { + ctx := context.Background() + + t.Run("empty spans - noop", func(t *testing.T) { + r := &TraceServiceImpl{} + err := r.MergeHistoryMessagesByRespIDBatch(ctx, nil, loop_span.PlatformCozeLoop) + assert.NoError(t, err) + }) + + t.Run("non model spans - noop", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + r := &TraceServiceImpl{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + + span := &loop_span.Span{ + SpanID: "span-1", + TraceID: "trace-1", + SpanType: loop_span.SpanTypePrompt, + Input: `{"messages":[{"role":"user","content":"cur"}]}`, + SystemTagsString: map[string]string{ + keyPreviousResponseID: "prev-resp-1", + }, + } + err := r.MergeHistoryMessagesByRespIDBatch(ctx, []*loop_span.Span{span}, loop_span.PlatformCozeLoop) + assert.NoError(t, err) + assert.Equal(t, `{"messages":[{"role":"user","content":"cur"}]}`, span.Input) + }) + + t.Run("merge history messages successfully", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{PreRespID: "prev-resp-1"}). + Return([]string{"span-0"}, []string{"resp-0"}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + Input: `{"messages":[{"role":"system","content":"hist_in"}]}`, + Output: `{"choices":[{"role":"assistant","content":"hist_out"}]}`, + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-1", + keyPreviousResponseID: "prev-resp-1", + }, + }, + }, + }, nil). + AnyTimes() + + r := &TraceServiceImpl{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + + span := &loop_span.Span{ + StartTime: time.Now().UnixMilli(), + WorkspaceID: "1", + SpanID: "span-1", + TraceID: "trace-1", + SpanType: loop_span.SpanTypeModel, + Input: `{"messages":[{"role":"user","content":"cur"}]}`, + SystemTagsString: map[string]string{ + keyPreviousResponseID: "prev-resp-1", + }, + } + + err := r.MergeHistoryMessagesByRespIDBatch(ctx, []*loop_span.Span{span}, loop_span.PlatformCozeLoop) + assert.NoError(t, err) + + var m map[string]interface{} + assert.NoError(t, json.Unmarshal([]byte(span.Input), &m)) + msgs, ok := m["messages"].([]interface{}) + assert.True(t, ok) + assert.Equal(t, 3, len(msgs)) + }) + + t.Run("ListPreSpanBatch returns error - should return error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return(nil, fmt.Errorf("tenant error")) + + r := &TraceServiceImpl{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + + span := &loop_span.Span{ + StartTime: time.Now().UnixMilli(), + WorkspaceID: "1", + SpanID: "span-1", + TraceID: "trace-1", + SpanType: loop_span.SpanTypeModel, + Input: `{"messages":[{"role":"user","content":"cur"}]}`, + SystemTagsString: map[string]string{ + keyPreviousResponseID: "prev-resp-1", + }, + } + + err := r.MergeHistoryMessagesByRespIDBatch(ctx, []*loop_span.Span{span}, loop_span.PlatformCozeLoop) + assert.Error(t, err) + }) + + t.Run("current span from request overrides CK data - should merge successfully", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRepo := repomocks.NewMockITraceRepo(ctrl) + mockTenantProvider := tenantmocks.NewMockITenantProvider(ctrl) + mockFilterFactory := filtermocks.NewMockPlatformFilterFactory(ctrl) + mockBuilder := NewTraceFilterProcessorBuilder(mockFilterFactory, nil, nil, nil, nil, nil, nil) + + mockTenantProvider.EXPECT(). + GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformCozeLoop). + Return([]string{"tenant1"}, nil) + + mockRepo.EXPECT(). + GetPreSpanIDs(gomock.Any(), &repo.GetPreSpanIDsParam{PreRespID: "prev-resp-1"}). + Return([]string{"span-0"}, []string{"resp-0"}, nil) + + mockRepo.EXPECT(). + ListSpans(gomock.Any(), gomock.Any()). + Return(&repo.ListSpansResult{ + Spans: []*loop_span.Span{ + { + SpanID: "span-0", + TraceID: "trace-1", + WorkspaceID: "1", + Input: `{"messages":[{"role":"system","content":"hist_in"}]}`, + Output: `{"choices":[{"role":"assistant","content":"hist_out"}]}`, + SystemTagsString: map[string]string{ + keyResponseID: "resp-0", + }, + }, + { + SpanID: "span-1", + TraceID: "trace-1", + WorkspaceID: "1", + SystemTagsString: map[string]string{ + keyResponseID: "resp-1", + keyPreviousResponseID: "mismatch", + }, + }, + }, + }, nil). + AnyTimes() + + r := &TraceServiceImpl{ + traceRepo: mockRepo, + buildHelper: mockBuilder, + tenantProvider: mockTenantProvider, + } + + orig := `{"messages":[{"role":"user","content":"cur"}]}` + span := &loop_span.Span{ + StartTime: time.Now().UnixMilli(), + WorkspaceID: "1", + SpanID: "span-1", + TraceID: "trace-1", + SpanType: loop_span.SpanTypeModel, + Input: orig, + SystemTagsString: map[string]string{ + keyPreviousResponseID: "prev-resp-1", + }, + } + + err := r.MergeHistoryMessagesByRespIDBatch(ctx, []*loop_span.Span{span}, loop_span.PlatformCozeLoop) + assert.NoError(t, err) + assert.Equal(t, `{"messages":[{"role":"system","content":"hist_in"},{"role":"assistant","content":"hist_out"},{"role":"user","content":"cur"}]}`, span.Input) + }) +} diff --git a/backend/modules/observability/infra/mq/consumer/task_consumer.go b/backend/modules/observability/infra/mq/consumer/task_consumer.go index 0bfd13b59..a1973d0fe 100644 --- a/backend/modules/observability/infra/mq/consumer/task_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/task_consumer.go @@ -59,6 +59,6 @@ func (e *TaskConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) er logs.CtxWarn(ctx, "Task msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) return nil } - logs.CtxDebug(ctx, "Span msg,log_id=%s, trace_id=%s, span_id=%s,msgID=%s", event.LogID, event.TraceID, event.SpanID, ext.MsgID) + logs.CtxInfo(ctx, "Span msg,log_id=%s, trace_id=%s, span_id=%s,msgID=%s", event.LogID, event.TraceID, event.SpanID, ext.MsgID) return e.handler.SpanTrigger(ctx, event, nil) } diff --git a/backend/modules/observability/infra/repo/redis/spans.go b/backend/modules/observability/infra/repo/redis/spans.go index 814dc42be..45f57d49f 100644 --- a/backend/modules/observability/infra/repo/redis/spans.go +++ b/backend/modules/observability/infra/repo/redis/spans.go @@ -60,6 +60,7 @@ func (s SpansRedisDaoImpl) GetPreSpans(ctx context.Context, respID string) (span if spanID != "" { preSpanIDs = append(preSpanIDs, spanID) // do not need order, only for select from db } + // 时间升序 respIDByOrder = append([]string{preRespID}, respIDByOrder...) // need order, for order SpanList preRespID = redisValue.PreviousResponseID @@ -68,6 +69,5 @@ func (s SpansRedisDaoImpl) GetPreSpans(ctx context.Context, respID string) (span break } } - return preSpanIDs, respIDByOrder, nil }