diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 8dedc1925..df9c83601 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -9,6 +9,8 @@ import ( "fmt" "net/url" "os" + "os/signal" + "syscall" "time" "github.com/bytedance/gg/gptr" @@ -60,12 +62,21 @@ func main() { if err := initTracer(handler); err != nil { panic(err) } - consumerWorkers := MustInitConsumerWorkers(c.cfgFactory, handler, handler, handler, handler) - if err := registry.NewConsumerRegistry(c.mqFactory).Register(consumerWorkers).StartAll(ctx); err != nil { + + signalCtx, signalCancel := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGINT) + defer signalCancel() + + r := registry.NewConsumerRegistryWithShutdown(signalCtx, c.mqFactory).Register(MustInitConsumerWorkers(c.cfgFactory, handler, handler, handler, handler)) + if err := r.StartAll(ctx); err != nil { panic(err) } - api.Start(handler) + go api.Start(handler) + <-signalCtx.Done() + + stopCtx, stopCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer stopCancel() + _ = r.StopAll(stopCtx) } type ComponentConfig struct { diff --git a/backend/infra/mq/factory.go b/backend/infra/mq/factory.go index 33980af1c..173c77d2d 100644 --- a/backend/infra/mq/factory.go +++ b/backend/infra/mq/factory.go @@ -31,6 +31,9 @@ type ProducerConfig struct { FlushFrequency time.Duration // How long to wait for the cluster to settle between retries RetryBackoff time.Duration + + AccessKey *string + AccessSecret *string } type ConsumerConfig struct { @@ -50,6 +53,9 @@ type ConsumerConfig struct { ConsumeTimeout time.Duration EnablePPE *bool IsEnabled *bool + + AccessKey *string + AccessSecret *string } type CompressionCodec int diff --git a/backend/infra/mq/mocks/registry.go b/backend/infra/mq/mocks/registry.go index 67c9bf67a..5bb62ae61 100644 --- a/backend/infra/mq/mocks/registry.go +++ b/backend/infra/mq/mocks/registry.go @@ -69,6 +69,20 @@ func (mr *MockConsumerRegistryMockRecorder) StartAll(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartAll", reflect.TypeOf((*MockConsumerRegistry)(nil).StartAll), ctx) } +// StopAll mocks base method. +func (m *MockConsumerRegistry) StopAll(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopAll", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// StopAll indicates an expected call of StopAll. +func (mr *MockConsumerRegistryMockRecorder) StopAll(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopAll", reflect.TypeOf((*MockConsumerRegistry)(nil).StopAll), ctx) +} + // MockIConsumerWorker is a mock of IConsumerWorker interface. type MockIConsumerWorker struct { ctrl *gomock.Controller diff --git a/backend/infra/mq/registry.go b/backend/infra/mq/registry.go index ee3f9c74c..122d85b29 100644 --- a/backend/infra/mq/registry.go +++ b/backend/infra/mq/registry.go @@ -12,6 +12,7 @@ import ( type ConsumerRegistry interface { Register(worker []IConsumerWorker) ConsumerRegistry StartAll(ctx context.Context) error + StopAll(ctx context.Context) error } type IConsumerWorker interface { diff --git a/backend/infra/mq/registry/registry.go b/backend/infra/mq/registry/registry.go index 9db8c3e78..f319b6300 100644 --- a/backend/infra/mq/registry/registry.go +++ b/backend/infra/mq/registry/registry.go @@ -5,6 +5,7 @@ package registry import ( "context" + "errors" "github.com/coze-dev/coze-loop/backend/infra/mq" "github.com/coze-dev/coze-loop/backend/pkg/errorx" @@ -14,20 +15,27 @@ import ( ) type defaultConsumerRegistry struct { - factory mq.IFactory - workers []mq.IConsumerWorker + factory mq.IFactory + workers []mq.IConsumerWorker + consumers []mq.IConsumer + shutdownCtx context.Context } func NewConsumerRegistry(factory mq.IFactory) mq.ConsumerRegistry { return &defaultConsumerRegistry{factory: factory} } +func NewConsumerRegistryWithShutdown(shutdownCtx context.Context, factory mq.IFactory) mq.ConsumerRegistry { + return &defaultConsumerRegistry{factory: factory, shutdownCtx: shutdownCtx} +} + func (d *defaultConsumerRegistry) Register(worker []mq.IConsumerWorker) mq.ConsumerRegistry { d.workers = append(d.workers, worker...) return d } func (d *defaultConsumerRegistry) StartAll(ctx context.Context) error { + d.consumers = nil for _, worker := range d.workers { cfg, err := worker.ConsumerCfg(ctx) if err != nil { @@ -39,14 +47,47 @@ func (d *defaultConsumerRegistry) StartAll(ctx context.Context) error { return errorx.Wrapf(err, "NewConsumer fail, cfg: %v", json.Jsonify(cfg)) } - consumer.RegisterHandler(newSafeConsumerWrapper(worker)) + handler := newSafeConsumerWrapper(worker) + if d.shutdownCtx != nil { + handler = newShutdownContextWrapper(handler, d.shutdownCtx) + } + consumer.RegisterHandler(handler) if err := consumer.Start(); err != nil { return errorx.Wrapf(err, "StartConsumer fail, cfg: %v", json.Jsonify(cfg)) } + d.consumers = append(d.consumers, consumer) } return nil } +func (d *defaultConsumerRegistry) StopAll(ctx context.Context) error { + if len(d.consumers) == 0 { + return nil + } + var errs []error + for i := len(d.consumers) - 1; i >= 0; i-- { + select { + case <-ctx.Done(): + errs = append(errs, ctx.Err()) + return errors.Join(errs...) + default: + consumer := d.consumers[i] + done := make(chan error, 1) + go func(c mq.IConsumer) { done <- c.Close() }(consumer) + select { + case err := <-done: + if err != nil { + errs = append(errs, err) + } + case <-ctx.Done(): + errs = append(errs, ctx.Err()) + return errors.Join(errs...) + } + } + } + return errors.Join(errs...) +} + type safeConsumerHandlerDecorator struct { handler mq.IConsumerHandler } @@ -59,3 +100,25 @@ func (s *safeConsumerHandlerDecorator) HandleMessage(ctx context.Context, msg *m func newSafeConsumerWrapper(h mq.IConsumerHandler) mq.IConsumerHandler { return &safeConsumerHandlerDecorator{handler: h} } + +type shutdownContextDecorator struct { + handler mq.IConsumerHandler + shutdownCtx context.Context +} + +func (s *shutdownContextDecorator) HandleMessage(ctx context.Context, msg *mq.MessageExt) error { + nctx, cancel := context.WithCancel(ctx) + go func() { + defer goroutine.Recovery(ctx) + select { + case <-ctx.Done(): + case <-s.shutdownCtx.Done(): + } + cancel() + }() + return s.handler.HandleMessage(nctx, msg) +} + +func newShutdownContextWrapper(h mq.IConsumerHandler, shutdownCtx context.Context) mq.IConsumerHandler { + return &shutdownContextDecorator{handler: h, shutdownCtx: shutdownCtx} +} diff --git a/backend/infra/mq/registry/registry_test.go b/backend/infra/mq/registry/registry_test.go index 9db8d4a34..485f55c18 100644 --- a/backend/infra/mq/registry/registry_test.go +++ b/backend/infra/mq/registry/registry_test.go @@ -20,6 +20,7 @@ func TestDefaultConsumerRegistry_StartAll(t *testing.T) { name string workers []mq.IConsumerWorker setupMocks func(*mocks.MockIFactory, []*mocks.MockIConsumer, []*mocks.MockIConsumerWorker) + shutdownCtx context.Context expectedError error }{ { @@ -39,6 +40,21 @@ func TestDefaultConsumerRegistry_StartAll(t *testing.T) { }, expectedError: nil, }, + { + name: "successfully start all workers with shutdown ctx", + workers: []mq.IConsumerWorker{ + mocks.NewMockIConsumerWorker(gomock.NewController(t)), + }, + setupMocks: func(factory *mocks.MockIFactory, consumers []*mocks.MockIConsumer, workers []*mocks.MockIConsumerWorker) { + cfg := &mq.ConsumerConfig{} + workers[0].EXPECT().ConsumerCfg(gomock.Any()).Return(cfg, nil) + consumers[0].EXPECT().RegisterHandler(gomock.Any()).Return() + consumers[0].EXPECT().Start().Return(nil) + factory.EXPECT().NewConsumer(gomock.Any()).Return(consumers[0], nil) + }, + shutdownCtx: context.Background(), + expectedError: nil, + }, { name: "fail to get consumer config", workers: []mq.IConsumerWorker{ @@ -92,9 +108,12 @@ func TestDefaultConsumerRegistry_StartAll(t *testing.T) { } tt.setupMocks(factory, consumers, workers) - - registry := NewConsumerRegistry(factory).Register(tt.workers) - + var registry mq.ConsumerRegistry + if tt.shutdownCtx != nil { + registry = NewConsumerRegistryWithShutdown(tt.shutdownCtx, factory).Register(tt.workers) + } else { + registry = NewConsumerRegistry(factory).Register(tt.workers) + } err := registry.StartAll(context.Background()) if tt.expectedError != nil { assert.Error(t, err) @@ -106,6 +125,69 @@ func TestDefaultConsumerRegistry_StartAll(t *testing.T) { } } +func TestDefaultConsumerRegistry_StopAll(t *testing.T) { + t.Run("no consumers", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + factory := mocks.NewMockIFactory(ctrl) + registry := NewConsumerRegistry(factory) + err := registry.StopAll(context.Background()) + assert.NoError(t, err) + }) + + t.Run("successfully stop all consumers", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + factory := mocks.NewMockIFactory(ctrl) + workers := []mq.IConsumerWorker{ + mocks.NewMockIConsumerWorker(ctrl), + mocks.NewMockIConsumerWorker(ctrl), + } + consumers := []*mocks.MockIConsumer{ + mocks.NewMockIConsumer(ctrl), + mocks.NewMockIConsumer(ctrl), + } + cfg := &mq.ConsumerConfig{} + for i := range workers { + workers[i].(*mocks.MockIConsumerWorker).EXPECT().ConsumerCfg(gomock.Any()).Return(cfg, nil) + factory.EXPECT().NewConsumer(gomock.Any()).Return(consumers[i], nil) + consumers[i].EXPECT().RegisterHandler(gomock.Any()) + consumers[i].EXPECT().Start().Return(nil) + } + registry := NewConsumerRegistry(factory).Register(workers) + err := registry.StartAll(context.Background()) + assert.NoError(t, err) + + // StopAll 按逆序关闭,先关 consumers[1] 再关 consumers[0] + consumers[1].EXPECT().Close().Return(nil) + consumers[0].EXPECT().Close().Return(nil) + err = registry.StopAll(context.Background()) + assert.NoError(t, err) + }) + + t.Run("context cancelled during stop", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + factory := mocks.NewMockIFactory(ctrl) + worker := mocks.NewMockIConsumerWorker(ctrl) + consumer := mocks.NewMockIConsumer(ctrl) + cfg := &mq.ConsumerConfig{} + worker.EXPECT().ConsumerCfg(gomock.Any()).Return(cfg, nil) + factory.EXPECT().NewConsumer(gomock.Any()).Return(consumer, nil) + consumer.EXPECT().RegisterHandler(gomock.Any()) + consumer.EXPECT().Start().Return(nil) + registry := NewConsumerRegistry(factory).Register([]mq.IConsumerWorker{worker}) + err := registry.StartAll(context.Background()) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = registry.StopAll(ctx) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + }) +} + func TestSafeConsumerHandlerDecorator_HandleMessage(t *testing.T) { tests := []struct { name string @@ -150,3 +232,72 @@ func TestSafeConsumerHandlerDecorator_HandleMessage(t *testing.T) { }) } } + +func TestNewConsumerRegistryWithShutdown(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + factory := mocks.NewMockIFactory(ctrl) + shutdownCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + registry := NewConsumerRegistryWithShutdown(shutdownCtx, factory).(*defaultConsumerRegistry) + assert.Equal(t, factory, registry.factory) + assert.Equal(t, shutdownCtx, registry.shutdownCtx) +} + +func TestShutdownContextDecorator_HandleMessage(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockHandler := mocks.NewMockIConsumerWorker(ctrl) + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + + decorator := &shutdownContextDecorator{ + handler: mockHandler, + shutdownCtx: shutdownCtx, + } + + tests := []struct { + name string + setupMock func() + triggerCancel func() + ctx context.Context + }{ + { + name: "normal execution", + setupMock: func() { + mockHandler.EXPECT().HandleMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + triggerCancel: func() {}, + ctx: context.Background(), + }, + { + name: "shutdown context cancelled", + setupMock: func() { + mockHandler.EXPECT().HandleMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, msg *mq.MessageExt) error { + <-ctx.Done() + return ctx.Err() + }) + }, + triggerCancel: func() { + shutdownCancel() + }, + ctx: context.Background(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMock() + go tt.triggerCancel() + err := decorator.HandleMessage(tt.ctx, &mq.MessageExt{}) + if tt.name == "shutdown context cancelled" { + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/common/common.go b/backend/kitex_gen/coze/loop/evaluation/domain/common/common.go index 17a2f7fea..e28bfc994 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/common/common.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/common/common.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/apache/thrift/lib/go/thrift" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" "strings" ) @@ -4988,11 +4989,14 @@ type ModelConfig struct { // 模型id ModelID *int64 `thrift:"model_id,1,optional" frugal:"1,optional,i64" json:"model_id" form:"model_id" query:"model_id"` // 模型名称 - ModelName *string `thrift:"model_name,2,optional" frugal:"2,optional,string" form:"model_name" json:"model_name,omitempty" query:"model_name"` - Temperature *float64 `thrift:"temperature,3,optional" frugal:"3,optional,double" form:"temperature" json:"temperature,omitempty" query:"temperature"` - MaxTokens *int32 `thrift:"max_tokens,4,optional" frugal:"4,optional,i32" form:"max_tokens" json:"max_tokens,omitempty" query:"max_tokens"` - TopP *float64 `thrift:"top_p,5,optional" frugal:"5,optional,double" form:"top_p" json:"top_p,omitempty" query:"top_p"` - JSONExt *string `thrift:"json_ext,50,optional" frugal:"50,optional,string" form:"json_ext" json:"json_ext,omitempty" query:"json_ext"` + ModelName *string `thrift:"model_name,2,optional" frugal:"2,optional,string" form:"model_name" json:"model_name,omitempty" query:"model_name"` + Temperature *float64 `thrift:"temperature,3,optional" frugal:"3,optional,double" form:"temperature" json:"temperature,omitempty" query:"temperature"` + MaxTokens *int32 `thrift:"max_tokens,4,optional" frugal:"4,optional,i32" form:"max_tokens" json:"max_tokens,omitempty" query:"max_tokens"` + TopP *float64 `thrift:"top_p,5,optional" frugal:"5,optional,double" form:"top_p" json:"top_p,omitempty" query:"top_p"` + Protocol *manage.Protocol `thrift:"protocol,6,optional" frugal:"6,optional,string" form:"protocol" json:"protocol,omitempty" query:"protocol"` + Identification *string `thrift:"identification,7,optional" frugal:"7,optional,string" form:"identification" json:"identification,omitempty" query:"identification"` + PresetModel *bool `thrift:"preset_model,8,optional" frugal:"8,optional,bool" form:"preset_model" json:"preset_model,omitempty" query:"preset_model"` + JSONExt *string `thrift:"json_ext,50,optional" frugal:"50,optional,string" form:"json_ext" json:"json_ext,omitempty" query:"json_ext"` } func NewModelConfig() *ModelConfig { @@ -5062,6 +5066,42 @@ func (p *ModelConfig) GetTopP() (v float64) { return *p.TopP } +var ModelConfig_Protocol_DEFAULT manage.Protocol + +func (p *ModelConfig) GetProtocol() (v manage.Protocol) { + if p == nil { + return + } + if !p.IsSetProtocol() { + return ModelConfig_Protocol_DEFAULT + } + return *p.Protocol +} + +var ModelConfig_Identification_DEFAULT string + +func (p *ModelConfig) GetIdentification() (v string) { + if p == nil { + return + } + if !p.IsSetIdentification() { + return ModelConfig_Identification_DEFAULT + } + return *p.Identification +} + +var ModelConfig_PresetModel_DEFAULT bool + +func (p *ModelConfig) GetPresetModel() (v bool) { + if p == nil { + return + } + if !p.IsSetPresetModel() { + return ModelConfig_PresetModel_DEFAULT + } + return *p.PresetModel +} + var ModelConfig_JSONExt_DEFAULT string func (p *ModelConfig) GetJSONExt() (v string) { @@ -5088,6 +5128,15 @@ func (p *ModelConfig) SetMaxTokens(val *int32) { func (p *ModelConfig) SetTopP(val *float64) { p.TopP = val } +func (p *ModelConfig) SetProtocol(val *manage.Protocol) { + p.Protocol = val +} +func (p *ModelConfig) SetIdentification(val *string) { + p.Identification = val +} +func (p *ModelConfig) SetPresetModel(val *bool) { + p.PresetModel = val +} func (p *ModelConfig) SetJSONExt(val *string) { p.JSONExt = val } @@ -5098,6 +5147,9 @@ var fieldIDToName_ModelConfig = map[int16]string{ 3: "temperature", 4: "max_tokens", 5: "top_p", + 6: "protocol", + 7: "identification", + 8: "preset_model", 50: "json_ext", } @@ -5121,6 +5173,18 @@ func (p *ModelConfig) IsSetTopP() bool { return p.TopP != nil } +func (p *ModelConfig) IsSetProtocol() bool { + return p.Protocol != nil +} + +func (p *ModelConfig) IsSetIdentification() bool { + return p.Identification != nil +} + +func (p *ModelConfig) IsSetPresetModel() bool { + return p.PresetModel != nil +} + func (p *ModelConfig) IsSetJSONExt() bool { return p.JSONExt != nil } @@ -5183,6 +5247,30 @@ func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 6: + if fieldTypeId == thrift.STRING { + if err = p.ReadField6(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 7: + if fieldTypeId == thrift.STRING { + if err = p.ReadField7(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 8: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField8(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 50: if fieldTypeId == thrift.STRING { if err = p.ReadField50(iprot); err != nil { @@ -5275,6 +5363,39 @@ func (p *ModelConfig) ReadField5(iprot thrift.TProtocol) error { p.TopP = _field return nil } +func (p *ModelConfig) ReadField6(iprot thrift.TProtocol) error { + + var _field *manage.Protocol + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Protocol = _field + return nil +} +func (p *ModelConfig) ReadField7(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Identification = _field + return nil +} +func (p *ModelConfig) ReadField8(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.PresetModel = _field + return nil +} func (p *ModelConfig) ReadField50(iprot thrift.TProtocol) error { var _field *string @@ -5313,6 +5434,18 @@ func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 5 goto WriteFieldError } + if err = p.writeField6(oprot); err != nil { + fieldId = 6 + goto WriteFieldError + } + if err = p.writeField7(oprot); err != nil { + fieldId = 7 + goto WriteFieldError + } + if err = p.writeField8(oprot); err != nil { + fieldId = 8 + goto WriteFieldError + } if err = p.writeField50(oprot); err != nil { fieldId = 50 goto WriteFieldError @@ -5425,6 +5558,60 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) } +func (p *ModelConfig) writeField6(oprot thrift.TProtocol) (err error) { + if p.IsSetProtocol() { + if err = oprot.WriteFieldBegin("protocol", thrift.STRING, 6); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Protocol); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) +} +func (p *ModelConfig) writeField7(oprot thrift.TProtocol) (err error) { + if p.IsSetIdentification() { + if err = oprot.WriteFieldBegin("identification", thrift.STRING, 7); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Identification); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) +} +func (p *ModelConfig) writeField8(oprot thrift.TProtocol) (err error) { + if p.IsSetPresetModel() { + if err = oprot.WriteFieldBegin("preset_model", thrift.BOOL, 8); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.PresetModel); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) +} func (p *ModelConfig) writeField50(oprot thrift.TProtocol) (err error) { if p.IsSetJSONExt() { if err = oprot.WriteFieldBegin("json_ext", thrift.STRING, 50); err != nil { @@ -5473,6 +5660,15 @@ func (p *ModelConfig) DeepEqual(ano *ModelConfig) bool { if !p.Field5DeepEqual(ano.TopP) { return false } + if !p.Field6DeepEqual(ano.Protocol) { + return false + } + if !p.Field7DeepEqual(ano.Identification) { + return false + } + if !p.Field8DeepEqual(ano.PresetModel) { + return false + } if !p.Field50DeepEqual(ano.JSONExt) { return false } @@ -5539,6 +5735,42 @@ func (p *ModelConfig) Field5DeepEqual(src *float64) bool { } return true } +func (p *ModelConfig) Field6DeepEqual(src *manage.Protocol) bool { + + if p.Protocol == src { + return true + } else if p.Protocol == nil || src == nil { + return false + } + if strings.Compare(*p.Protocol, *src) != 0 { + return false + } + return true +} +func (p *ModelConfig) Field7DeepEqual(src *string) bool { + + if p.Identification == src { + return true + } else if p.Identification == nil || src == nil { + return false + } + if strings.Compare(*p.Identification, *src) != 0 { + return false + } + return true +} +func (p *ModelConfig) Field8DeepEqual(src *bool) bool { + + if p.PresetModel == src { + return true + } else if p.PresetModel == nil || src == nil { + return false + } + if *p.PresetModel != *src { + return false + } + return true +} func (p *ModelConfig) Field50DeepEqual(src *string) bool { if p.JSONExt == src { diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/common/k-common.go b/backend/kitex_gen/coze/loop/evaluation/domain/common/k-common.go index 6db2f7501..5c101a144 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/common/k-common.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/common/k-common.go @@ -12,10 +12,12 @@ import ( kutils "github.com/cloudwego/kitex/pkg/utils" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" ) var ( _ = dataset.KitexUnusedProtection + _ = manage.KitexUnusedProtection ) // unused protection @@ -3584,6 +3586,48 @@ func (p *ModelConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 6: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField6(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 7: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField7(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 8: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField8(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 50: if fieldTypeId == thrift.STRING { l, err = p.FastReadField50(buf[offset:]) @@ -3686,6 +3730,48 @@ func (p *ModelConfig) FastReadField5(buf []byte) (int, error) { return offset, nil } +func (p *ModelConfig) FastReadField6(buf []byte) (int, error) { + offset := 0 + + var _field *manage.Protocol + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Protocol = _field + return offset, nil +} + +func (p *ModelConfig) FastReadField7(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Identification = _field + return offset, nil +} + +func (p *ModelConfig) FastReadField8(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PresetModel = _field + return offset, nil +} + func (p *ModelConfig) FastReadField50(buf []byte) (int, error) { offset := 0 @@ -3711,7 +3797,10 @@ func (p *ModelConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField3(buf[offset:], w) offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField8(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField6(buf[offset:], w) + offset += p.fastWriteField7(buf[offset:], w) offset += p.fastWriteField50(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) @@ -3726,6 +3815,9 @@ func (p *ModelConfig) BLength() int { l += p.field3Length() l += p.field4Length() l += p.field5Length() + l += p.field6Length() + l += p.field7Length() + l += p.field8Length() l += p.field50Length() } l += thrift.Binary.FieldStopLength() @@ -3777,6 +3869,33 @@ func (p *ModelConfig) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ModelConfig) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetProtocol() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 6) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Protocol) + } + return offset +} + +func (p *ModelConfig) fastWriteField7(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetIdentification() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 7) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Identification) + } + return offset +} + +func (p *ModelConfig) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPresetModel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 8) + offset += thrift.Binary.WriteBool(buf[offset:], *p.PresetModel) + } + return offset +} + func (p *ModelConfig) fastWriteField50(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetJSONExt() { @@ -3831,6 +3950,33 @@ func (p *ModelConfig) field5Length() int { return l } +func (p *ModelConfig) field6Length() int { + l := 0 + if p.IsSetProtocol() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Protocol) + } + return l +} + +func (p *ModelConfig) field7Length() int { + l := 0 + if p.IsSetIdentification() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Identification) + } + return l +} + +func (p *ModelConfig) field8Length() int { + l := 0 + if p.IsSetPresetModel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + func (p *ModelConfig) field50Length() int { l := 0 if p.IsSetJSONExt() { @@ -3874,6 +4020,24 @@ func (p *ModelConfig) DeepCopy(s interface{}) error { p.TopP = &tmp } + if src.Protocol != nil { + tmp := *src.Protocol + p.Protocol = &tmp + } + + if src.Identification != nil { + var tmp string + if *src.Identification != "" { + tmp = kutils.StringDeepCopy(*src.Identification) + } + p.Identification = &tmp + } + + if src.PresetModel != nil { + tmp := *src.PresetModel + p.PresetModel = &tmp + } + if src.JSONExt != nil { var tmp string if *src.JSONExt != "" { diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/eval_target.go b/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/eval_target.go index 9b7103a0c..94c298c33 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/eval_target.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/eval_target.go @@ -51,6 +51,8 @@ const ( EvalTargetType_VolcengineAgent EvalTargetType = 5 // 自定义RPC服务 for内场 EvalTargetType_CustomRPCServer EvalTargetType = 6 + // 火山智能体Agentkit + EvalTargetType_VolcengineAgentAgentkit EvalTargetType = 7 ) func (p EvalTargetType) String() string { @@ -67,6 +69,8 @@ func (p EvalTargetType) String() string { return "VolcengineAgent" case EvalTargetType_CustomRPCServer: return "CustomRPCServer" + case EvalTargetType_VolcengineAgentAgentkit: + return "VolcengineAgentAgentkit" } return "" } @@ -85,6 +89,8 @@ func EvalTargetTypeFromString(s string) (EvalTargetType, error) { return EvalTargetType_VolcengineAgent, nil case "CustomRPCServer": return EvalTargetType_CustomRPCServer, nil + case "VolcengineAgentAgentkit": + return EvalTargetType_VolcengineAgentAgentkit, nil } return EvalTargetType(0), fmt.Errorf("not a valid EvalTargetType string") } @@ -4386,8 +4392,9 @@ type VolcengineAgent struct { // DTO使用,不存数据库 VolcengineAgentEndpoints []*VolcengineAgentEndpoint `thrift:"volcengine_agent_endpoints,12,optional" frugal:"12,optional,list" form:"volcengine_agent_endpoints" json:"volcengine_agent_endpoints,omitempty" query:"volcengine_agent_endpoints"` // 注册协议 - Protocol *VolcengineAgentProtocol `thrift:"protocol,13,optional" frugal:"13,optional,string" form:"protocol" json:"protocol,omitempty" query:"protocol"` - BaseInfo *common.BaseInfo `thrift:"base_info,100,optional" frugal:"100,optional,common.BaseInfo" form:"base_info" json:"base_info,omitempty" query:"base_info"` + Protocol *VolcengineAgentProtocol `thrift:"protocol,13,optional" frugal:"13,optional,string" form:"protocol" json:"protocol,omitempty" query:"protocol"` + RuntimeID *string `thrift:"runtime_id,14,optional" frugal:"14,optional,string" form:"runtime_id" json:"runtime_id,omitempty" query:"runtime_id"` + BaseInfo *common.BaseInfo `thrift:"base_info,100,optional" frugal:"100,optional,common.BaseInfo" form:"base_info" json:"base_info,omitempty" query:"base_info"` } func NewVolcengineAgent() *VolcengineAgent { @@ -4457,6 +4464,18 @@ func (p *VolcengineAgent) GetProtocol() (v VolcengineAgentProtocol) { return *p.Protocol } +var VolcengineAgent_RuntimeID_DEFAULT string + +func (p *VolcengineAgent) GetRuntimeID() (v string) { + if p == nil { + return + } + if !p.IsSetRuntimeID() { + return VolcengineAgent_RuntimeID_DEFAULT + } + return *p.RuntimeID +} + var VolcengineAgent_BaseInfo_DEFAULT *common.BaseInfo func (p *VolcengineAgent) GetBaseInfo() (v *common.BaseInfo) { @@ -4483,6 +4502,9 @@ func (p *VolcengineAgent) SetVolcengineAgentEndpoints(val []*VolcengineAgentEndp func (p *VolcengineAgent) SetProtocol(val *VolcengineAgentProtocol) { p.Protocol = val } +func (p *VolcengineAgent) SetRuntimeID(val *string) { + p.RuntimeID = val +} func (p *VolcengineAgent) SetBaseInfo(val *common.BaseInfo) { p.BaseInfo = val } @@ -4493,6 +4515,7 @@ var fieldIDToName_VolcengineAgent = map[int16]string{ 11: "description", 12: "volcengine_agent_endpoints", 13: "protocol", + 14: "runtime_id", 100: "base_info", } @@ -4516,6 +4539,10 @@ func (p *VolcengineAgent) IsSetProtocol() bool { return p.Protocol != nil } +func (p *VolcengineAgent) IsSetRuntimeID() bool { + return p.RuntimeID != nil +} + func (p *VolcengineAgent) IsSetBaseInfo() bool { return p.BaseInfo != nil } @@ -4578,6 +4605,14 @@ func (p *VolcengineAgent) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 14: + if fieldTypeId == thrift.STRING { + if err = p.ReadField14(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 100: if fieldTypeId == thrift.STRUCT { if err = p.ReadField100(iprot); err != nil { @@ -4682,6 +4717,17 @@ func (p *VolcengineAgent) ReadField13(iprot thrift.TProtocol) error { p.Protocol = _field return nil } +func (p *VolcengineAgent) ReadField14(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.RuntimeID = _field + return nil +} func (p *VolcengineAgent) ReadField100(iprot thrift.TProtocol) error { _field := common.NewBaseInfo() if err := _field.Read(iprot); err != nil { @@ -4717,6 +4763,10 @@ func (p *VolcengineAgent) Write(oprot thrift.TProtocol) (err error) { fieldId = 13 goto WriteFieldError } + if err = p.writeField14(oprot); err != nil { + fieldId = 14 + goto WriteFieldError + } if err = p.writeField100(oprot); err != nil { fieldId = 100 goto WriteFieldError @@ -4837,6 +4887,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) } +func (p *VolcengineAgent) writeField14(oprot thrift.TProtocol) (err error) { + if p.IsSetRuntimeID() { + if err = oprot.WriteFieldBegin("runtime_id", thrift.STRING, 14); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.RuntimeID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) +} func (p *VolcengineAgent) writeField100(oprot thrift.TProtocol) (err error) { if p.IsSetBaseInfo() { if err = oprot.WriteFieldBegin("base_info", thrift.STRUCT, 100); err != nil { @@ -4885,6 +4953,9 @@ func (p *VolcengineAgent) DeepEqual(ano *VolcengineAgent) bool { if !p.Field13DeepEqual(ano.Protocol) { return false } + if !p.Field14DeepEqual(ano.RuntimeID) { + return false + } if !p.Field100DeepEqual(ano.BaseInfo) { return false } @@ -4952,6 +5023,18 @@ func (p *VolcengineAgent) Field13DeepEqual(src *VolcengineAgentProtocol) bool { } return true } +func (p *VolcengineAgent) Field14DeepEqual(src *string) bool { + + if p.RuntimeID == src { + return true + } else if p.RuntimeID == nil || src == nil { + return false + } + if strings.Compare(*p.RuntimeID, *src) != 0 { + return false + } + return true +} func (p *VolcengineAgent) Field100DeepEqual(src *common.BaseInfo) bool { if !p.BaseInfo.DeepEqual(src) { diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/k-eval_target.go b/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/k-eval_target.go index a0a9dce96..4b891b1e3 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/k-eval_target.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/eval_target/k-eval_target.go @@ -3071,6 +3071,20 @@ func (p *VolcengineAgent) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 14: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField14(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 100: if fieldTypeId == thrift.STRUCT { l, err = p.FastReadField100(buf[offset:]) @@ -3184,6 +3198,20 @@ func (p *VolcengineAgent) FastReadField13(buf []byte) (int, error) { return offset, nil } +func (p *VolcengineAgent) FastReadField14(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.RuntimeID = _field + return offset, nil +} + func (p *VolcengineAgent) FastReadField100(buf []byte) (int, error) { offset := 0 _field := common.NewBaseInfo() @@ -3208,6 +3236,7 @@ func (p *VolcengineAgent) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int offset += p.fastWriteField11(buf[offset:], w) offset += p.fastWriteField12(buf[offset:], w) offset += p.fastWriteField13(buf[offset:], w) + offset += p.fastWriteField14(buf[offset:], w) offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) @@ -3222,6 +3251,7 @@ func (p *VolcengineAgent) BLength() int { l += p.field11Length() l += p.field12Length() l += p.field13Length() + l += p.field14Length() l += p.field100Length() } l += thrift.Binary.FieldStopLength() @@ -3280,6 +3310,15 @@ func (p *VolcengineAgent) fastWriteField13(buf []byte, w thrift.NocopyWriter) in return offset } +func (p *VolcengineAgent) fastWriteField14(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetRuntimeID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 14) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.RuntimeID) + } + return offset +} + func (p *VolcengineAgent) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetBaseInfo() { @@ -3338,6 +3377,15 @@ func (p *VolcengineAgent) field13Length() int { return l } +func (p *VolcengineAgent) field14Length() int { + l := 0 + if p.IsSetRuntimeID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.RuntimeID) + } + return l +} + func (p *VolcengineAgent) field100Length() int { l := 0 if p.IsSetBaseInfo() { @@ -3394,6 +3442,14 @@ func (p *VolcengineAgent) DeepCopy(s interface{}) error { p.Protocol = &tmp } + if src.RuntimeID != nil { + var tmp string + if *src.RuntimeID != "" { + tmp = kutils.StringDeepCopy(*src.RuntimeID) + } + p.RuntimeID = &tmp + } + var _baseInfo *common.BaseInfo if src.BaseInfo != nil { _baseInfo = &common.BaseInfo{} diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/evaluator/evaluator.go b/backend/kitex_gen/coze/loop/evaluation/domain/evaluator/evaluator.go index f2b67210f..4c89671c6 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/evaluator/evaluator.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/evaluator/evaluator.go @@ -3785,7 +3785,7 @@ func (p *EvaluatorContent) Field103DeepEqual(src *CustomRPCEvaluator) bool { // 明确有顺序的 evaluator 与版本映射元素 type EvaluatorIDVersionItem struct { EvaluatorID *int64 `thrift:"evaluator_id,1,optional" frugal:"1,optional,i64" json:"evaluator_id" form:"evaluator_id" query:"evaluator_id"` - Version *string `thrift:"version,2,optional" frugal:"2,optional,string" json:"version" form:"version" query:"version"` + Version *string `thrift:"version,2,optional" frugal:"2,optional,string" form:"version" json:"version,omitempty" query:"version"` RunConfig *EvaluatorRunConfig `thrift:"run_config,3,optional" frugal:"3,optional,EvaluatorRunConfig" json:"run_config" form:"run_config" query:"run_config"` EvaluatorVersionID *int64 `thrift:"evaluator_version_id,4,optional" frugal:"4,optional,i64" json:"evaluator_version_id" form:"evaluator_version_id" query:"evaluator_version_id"` ScoreWeight *float64 `thrift:"score_weight,5,optional" frugal:"5,optional,double" json:"score_weight" form:"score_weight" query:"score_weight"` diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/expt/expt.go b/backend/kitex_gen/coze/loop/evaluation/domain/expt/expt.go index b774ffbd3..3bb93f7b7 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/expt/expt.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/expt/expt.go @@ -22457,9 +22457,11 @@ type ExptResultExportRecord struct { BaseInfo *common.BaseInfo `thrift:"base_info,5,optional" frugal:"5,optional,common.BaseInfo" form:"base_info" json:"base_info,omitempty" query:"base_info"` StartTime *int64 `thrift:"start_time,6,optional" frugal:"6,optional,i64" json:"start_time" form:"start_time" query:"start_time"` EndTime *int64 `thrift:"end_time,7,optional" frugal:"7,optional,i64" json:"end_time" form:"end_time" query:"end_time"` - URL *string `thrift:"URL,8,optional" frugal:"8,optional,string" form:"URL" json:"URL,omitempty" query:"URL"` - Expired *bool `thrift:"expired,9,optional" frugal:"9,optional,bool" form:"expired" json:"expired,omitempty" query:"expired"` - Error *RunError `thrift:"error,10,optional" frugal:"10,optional,RunError" form:"error" json:"error,omitempty" query:"error"` + // deprecated, cause not match snake name + URL *string `thrift:"URL,8,optional" frugal:"8,optional,string" form:"URL" json:"URL,omitempty" query:"URL"` + Expired *bool `thrift:"expired,9,optional" frugal:"9,optional,bool" form:"expired" json:"expired,omitempty" query:"expired"` + Error *RunError `thrift:"error,10,optional" frugal:"10,optional,RunError" form:"error" json:"error,omitempty" query:"error"` + URL_ *string `thrift:"url,11,optional" frugal:"11,optional,string" form:"url" json:"url,omitempty" query:"url"` } func NewExptResultExportRecord() *ExptResultExportRecord { @@ -22568,6 +22570,18 @@ func (p *ExptResultExportRecord) GetError() (v *RunError) { } return p.Error } + +var ExptResultExportRecord_URL__DEFAULT string + +func (p *ExptResultExportRecord) GetURL_() (v string) { + if p == nil { + return + } + if !p.IsSetURL_() { + return ExptResultExportRecord_URL__DEFAULT + } + return *p.URL_ +} func (p *ExptResultExportRecord) SetExportID(val int64) { p.ExportID = val } @@ -22598,6 +22612,9 @@ func (p *ExptResultExportRecord) SetExpired(val *bool) { func (p *ExptResultExportRecord) SetError(val *RunError) { p.Error = val } +func (p *ExptResultExportRecord) SetURL_(val *string) { + p.URL_ = val +} var fieldIDToName_ExptResultExportRecord = map[int16]string{ 1: "export_id", @@ -22610,6 +22627,7 @@ var fieldIDToName_ExptResultExportRecord = map[int16]string{ 8: "URL", 9: "expired", 10: "error", + 11: "url", } func (p *ExptResultExportRecord) IsSetBaseInfo() bool { @@ -22636,6 +22654,10 @@ func (p *ExptResultExportRecord) IsSetError() bool { return p.Error != nil } +func (p *ExptResultExportRecord) IsSetURL_() bool { + return p.URL_ != nil +} + func (p *ExptResultExportRecord) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -22742,6 +22764,14 @@ func (p *ExptResultExportRecord) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 11: + if fieldTypeId == thrift.STRING { + if err = p.ReadField11(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -22896,6 +22926,17 @@ func (p *ExptResultExportRecord) ReadField10(iprot thrift.TProtocol) error { p.Error = _field return nil } +func (p *ExptResultExportRecord) ReadField11(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.URL_ = _field + return nil +} func (p *ExptResultExportRecord) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -22943,6 +22984,10 @@ func (p *ExptResultExportRecord) Write(oprot thrift.TProtocol) (err error) { fieldId = 10 goto WriteFieldError } + if err = p.writeField11(oprot); err != nil { + fieldId = 11 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -23133,6 +23178,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) } +func (p *ExptResultExportRecord) writeField11(oprot thrift.TProtocol) (err error) { + if p.IsSetURL_() { + if err = oprot.WriteFieldBegin("url", thrift.STRING, 11); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.URL_); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) +} func (p *ExptResultExportRecord) String() string { if p == nil { @@ -23178,6 +23241,9 @@ func (p *ExptResultExportRecord) DeepEqual(ano *ExptResultExportRecord) bool { if !p.Field10DeepEqual(ano.Error) { return false } + if !p.Field11DeepEqual(ano.URL_) { + return false + } return true } @@ -23271,6 +23337,18 @@ func (p *ExptResultExportRecord) Field10DeepEqual(src *RunError) bool { } return true } +func (p *ExptResultExportRecord) Field11DeepEqual(src *string) bool { + + if p.URL_ == src { + return true + } else if p.URL_ == nil || src == nil { + return false + } + if strings.Compare(*p.URL_, *src) != 0 { + return false + } + return true +} // 洞察分析记录 type ExptInsightAnalysisRecord struct { diff --git a/backend/kitex_gen/coze/loop/evaluation/domain/expt/k-expt.go b/backend/kitex_gen/coze/loop/evaluation/domain/expt/k-expt.go index 203907e7e..0803ff63e 100644 --- a/backend/kitex_gen/coze/loop/evaluation/domain/expt/k-expt.go +++ b/backend/kitex_gen/coze/loop/evaluation/domain/expt/k-expt.go @@ -15839,6 +15839,20 @@ func (p *ExptResultExportRecord) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 11: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField11(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -16014,6 +16028,20 @@ func (p *ExptResultExportRecord) FastReadField10(buf []byte) (int, error) { return offset, nil } +func (p *ExptResultExportRecord) FastReadField11(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.URL_ = _field + return offset, nil +} + func (p *ExptResultExportRecord) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -16031,6 +16059,7 @@ func (p *ExptResultExportRecord) FastWriteNocopy(buf []byte, w thrift.NocopyWrit offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField8(buf[offset:], w) offset += p.fastWriteField10(buf[offset:], w) + offset += p.fastWriteField11(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -16049,6 +16078,7 @@ func (p *ExptResultExportRecord) BLength() int { l += p.field8Length() l += p.field9Length() l += p.field10Length() + l += p.field11Length() } l += thrift.Binary.FieldStopLength() return l @@ -16136,6 +16166,15 @@ func (p *ExptResultExportRecord) fastWriteField10(buf []byte, w thrift.NocopyWri return offset } +func (p *ExptResultExportRecord) fastWriteField11(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetURL_() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 11) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.URL_) + } + return offset +} + func (p *ExptResultExportRecord) field1Length() int { l := 0 l += thrift.Binary.FieldBeginLength() @@ -16218,6 +16257,15 @@ func (p *ExptResultExportRecord) field10Length() int { return l } +func (p *ExptResultExportRecord) field11Length() int { + l := 0 + if p.IsSetURL_() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.URL_) + } + return l +} + func (p *ExptResultExportRecord) DeepCopy(s interface{}) error { src, ok := s.(*ExptResultExportRecord) if !ok { @@ -16273,6 +16321,14 @@ func (p *ExptResultExportRecord) DeepCopy(s interface{}) error { } p.Error = _error + if src.URL_ != nil { + var tmp string + if *src.URL_ != "" { + tmp = kutils.StringDeepCopy(*src.URL_) + } + p.URL_ = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/evaluation/evaluator/coze.loop.evaluation.evaluator.go b/backend/kitex_gen/coze/loop/evaluation/evaluator/coze.loop.evaluation.evaluator.go index 5f35921ac..f6da27dda 100644 --- a/backend/kitex_gen/coze/loop/evaluation/evaluator/coze.loop.evaluation.evaluator.go +++ b/backend/kitex_gen/coze/loop/evaluation/evaluator/coze.loop.evaluation.evaluator.go @@ -2675,9 +2675,10 @@ func (p *GetEvaluatorResponse) Field255DeepEqual(src *base.BaseResp) bool { } type CreateEvaluatorRequest struct { - Evaluator *evaluator.Evaluator `thrift:"evaluator,1,required" frugal:"1,required,evaluator.Evaluator" form:"evaluator,required" json:"evaluator,required"` - Cid *string `thrift:"cid,100,optional" frugal:"100,optional,string" form:"cid" json:"cid,omitempty"` - Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` + Evaluator *evaluator.Evaluator `thrift:"evaluator,1,required" frugal:"1,required,evaluator.Evaluator" form:"evaluator,required" json:"evaluator,required"` + WorkspaceID *int64 `thrift:"workspace_id,2,optional" frugal:"2,optional,i64" json:"workspace_id" form:"workspace_id" ` + Cid *string `thrift:"cid,100,optional" frugal:"100,optional,string" form:"cid" json:"cid,omitempty"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } func NewCreateEvaluatorRequest() *CreateEvaluatorRequest { @@ -2699,6 +2700,18 @@ func (p *CreateEvaluatorRequest) GetEvaluator() (v *evaluator.Evaluator) { return p.Evaluator } +var CreateEvaluatorRequest_WorkspaceID_DEFAULT int64 + +func (p *CreateEvaluatorRequest) GetWorkspaceID() (v int64) { + if p == nil { + return + } + if !p.IsSetWorkspaceID() { + return CreateEvaluatorRequest_WorkspaceID_DEFAULT + } + return *p.WorkspaceID +} + var CreateEvaluatorRequest_Cid_DEFAULT string func (p *CreateEvaluatorRequest) GetCid() (v string) { @@ -2725,6 +2738,9 @@ func (p *CreateEvaluatorRequest) GetBase() (v *base.Base) { func (p *CreateEvaluatorRequest) SetEvaluator(val *evaluator.Evaluator) { p.Evaluator = val } +func (p *CreateEvaluatorRequest) SetWorkspaceID(val *int64) { + p.WorkspaceID = val +} func (p *CreateEvaluatorRequest) SetCid(val *string) { p.Cid = val } @@ -2734,6 +2750,7 @@ func (p *CreateEvaluatorRequest) SetBase(val *base.Base) { var fieldIDToName_CreateEvaluatorRequest = map[int16]string{ 1: "evaluator", + 2: "workspace_id", 100: "cid", 255: "Base", } @@ -2742,6 +2759,10 @@ func (p *CreateEvaluatorRequest) IsSetEvaluator() bool { return p.Evaluator != nil } +func (p *CreateEvaluatorRequest) IsSetWorkspaceID() bool { + return p.WorkspaceID != nil +} + func (p *CreateEvaluatorRequest) IsSetCid() bool { return p.Cid != nil } @@ -2778,6 +2799,14 @@ func (p *CreateEvaluatorRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 2: + if fieldTypeId == thrift.I64 { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 100: if fieldTypeId == thrift.STRING { if err = p.ReadField100(iprot); err != nil { @@ -2837,6 +2866,17 @@ func (p *CreateEvaluatorRequest) ReadField1(iprot thrift.TProtocol) error { p.Evaluator = _field return nil } +func (p *CreateEvaluatorRequest) ReadField2(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.WorkspaceID = _field + return nil +} func (p *CreateEvaluatorRequest) ReadField100(iprot thrift.TProtocol) error { var _field *string @@ -2867,6 +2907,10 @@ func (p *CreateEvaluatorRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 1 goto WriteFieldError } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } if err = p.writeField100(oprot); err != nil { fieldId = 100 goto WriteFieldError @@ -2909,6 +2953,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } +func (p *CreateEvaluatorRequest) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetWorkspaceID() { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} func (p *CreateEvaluatorRequest) writeField100(oprot thrift.TProtocol) (err error) { if p.IsSetCid() { if err = oprot.WriteFieldBegin("cid", thrift.STRING, 100); err != nil { @@ -2963,6 +3025,9 @@ func (p *CreateEvaluatorRequest) DeepEqual(ano *CreateEvaluatorRequest) bool { if !p.Field1DeepEqual(ano.Evaluator) { return false } + if !p.Field2DeepEqual(ano.WorkspaceID) { + return false + } if !p.Field100DeepEqual(ano.Cid) { return false } @@ -2979,6 +3044,18 @@ func (p *CreateEvaluatorRequest) Field1DeepEqual(src *evaluator.Evaluator) bool } return true } +func (p *CreateEvaluatorRequest) Field2DeepEqual(src *int64) bool { + + if p.WorkspaceID == src { + return true + } else if p.WorkspaceID == nil || src == nil { + return false + } + if *p.WorkspaceID != *src { + return false + } + return true +} func (p *CreateEvaluatorRequest) Field100DeepEqual(src *string) bool { if p.Cid == src { diff --git a/backend/kitex_gen/coze/loop/evaluation/evaluator/k-coze.loop.evaluation.evaluator.go b/backend/kitex_gen/coze/loop/evaluation/evaluator/k-coze.loop.evaluation.evaluator.go index f588c033c..613979335 100644 --- a/backend/kitex_gen/coze/loop/evaluation/evaluator/k-coze.loop.evaluation.evaluator.go +++ b/backend/kitex_gen/coze/loop/evaluation/evaluator/k-coze.loop.evaluation.evaluator.go @@ -2004,6 +2004,20 @@ func (p *CreateEvaluatorRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 2: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 100: if fieldTypeId == thrift.STRING { l, err = p.FastReadField100(buf[offset:]) @@ -2068,6 +2082,20 @@ func (p *CreateEvaluatorRequest) FastReadField1(buf []byte) (int, error) { return offset, nil } +func (p *CreateEvaluatorRequest) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.WorkspaceID = _field + return offset, nil +} + func (p *CreateEvaluatorRequest) FastReadField100(buf []byte) (int, error) { offset := 0 @@ -2101,6 +2129,7 @@ func (p *CreateEvaluatorRequest) FastWrite(buf []byte) int { func (p *CreateEvaluatorRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p != nil { + offset += p.fastWriteField2(buf[offset:], w) offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField100(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) @@ -2113,6 +2142,7 @@ func (p *CreateEvaluatorRequest) BLength() int { l := 0 if p != nil { l += p.field1Length() + l += p.field2Length() l += p.field100Length() l += p.field255Length() } @@ -2127,6 +2157,15 @@ func (p *CreateEvaluatorRequest) fastWriteField1(buf []byte, w thrift.NocopyWrit return offset } +func (p *CreateEvaluatorRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWorkspaceID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 2) + offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + } + return offset +} + func (p *CreateEvaluatorRequest) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetCid() { @@ -2152,6 +2191,15 @@ func (p *CreateEvaluatorRequest) field1Length() int { return l } +func (p *CreateEvaluatorRequest) field2Length() int { + l := 0 + if p.IsSetWorkspaceID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + func (p *CreateEvaluatorRequest) field100Length() int { l := 0 if p.IsSetCid() { @@ -2185,6 +2233,11 @@ func (p *CreateEvaluatorRequest) DeepCopy(s interface{}) error { } p.Evaluator = _evaluator + if src.WorkspaceID != nil { + tmp := *src.WorkspaceID + p.WorkspaceID = &tmp + } + if src.Cid != nil { var tmp string if *src.Cid != "" { diff --git a/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt.go b/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt.go index c9db81e5e..124c98d96 100644 --- a/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt.go +++ b/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt.go @@ -12646,7 +12646,7 @@ func (p *BatchGetExperimentAggrResultRequest) Field255DeepEqual(src *base.Base) } type BatchGetExperimentAggrResultResponse struct { - ExptAggregateResults []*expt.ExptAggregateResult_ `thrift:"expt_aggregate_results,1,optional" frugal:"1,optional,list" form:"expt_aggregate_result" json:"expt_aggregate_result,omitempty"` + ExptAggregateResult_ []*expt.ExptAggregateResult_ `thrift:"expt_aggregate_result,1,optional" frugal:"1,optional,list" form:"expt_aggregate_result" json:"expt_aggregate_result,omitempty"` BaseResp *base.BaseResp `thrift:"BaseResp,255" frugal:"255,default,base.BaseResp" form:"BaseResp" json:"BaseResp" query:"BaseResp"` } @@ -12657,16 +12657,16 @@ func NewBatchGetExperimentAggrResultResponse() *BatchGetExperimentAggrResultResp func (p *BatchGetExperimentAggrResultResponse) InitDefault() { } -var BatchGetExperimentAggrResultResponse_ExptAggregateResults_DEFAULT []*expt.ExptAggregateResult_ +var BatchGetExperimentAggrResultResponse_ExptAggregateResult__DEFAULT []*expt.ExptAggregateResult_ -func (p *BatchGetExperimentAggrResultResponse) GetExptAggregateResults() (v []*expt.ExptAggregateResult_) { +func (p *BatchGetExperimentAggrResultResponse) GetExptAggregateResult_() (v []*expt.ExptAggregateResult_) { if p == nil { return } - if !p.IsSetExptAggregateResults() { - return BatchGetExperimentAggrResultResponse_ExptAggregateResults_DEFAULT + if !p.IsSetExptAggregateResult_() { + return BatchGetExperimentAggrResultResponse_ExptAggregateResult__DEFAULT } - return p.ExptAggregateResults + return p.ExptAggregateResult_ } var BatchGetExperimentAggrResultResponse_BaseResp_DEFAULT *base.BaseResp @@ -12680,20 +12680,20 @@ func (p *BatchGetExperimentAggrResultResponse) GetBaseResp() (v *base.BaseResp) } return p.BaseResp } -func (p *BatchGetExperimentAggrResultResponse) SetExptAggregateResults(val []*expt.ExptAggregateResult_) { - p.ExptAggregateResults = val +func (p *BatchGetExperimentAggrResultResponse) SetExptAggregateResult_(val []*expt.ExptAggregateResult_) { + p.ExptAggregateResult_ = val } func (p *BatchGetExperimentAggrResultResponse) SetBaseResp(val *base.BaseResp) { p.BaseResp = val } var fieldIDToName_BatchGetExperimentAggrResultResponse = map[int16]string{ - 1: "expt_aggregate_results", + 1: "expt_aggregate_result", 255: "BaseResp", } -func (p *BatchGetExperimentAggrResultResponse) IsSetExptAggregateResults() bool { - return p.ExptAggregateResults != nil +func (p *BatchGetExperimentAggrResultResponse) IsSetExptAggregateResult_() bool { + return p.ExptAggregateResult_ != nil } func (p *BatchGetExperimentAggrResultResponse) IsSetBaseResp() bool { @@ -12783,7 +12783,7 @@ func (p *BatchGetExperimentAggrResultResponse) ReadField1(iprot thrift.TProtocol if err := iprot.ReadListEnd(); err != nil { return err } - p.ExptAggregateResults = _field + p.ExptAggregateResult_ = _field return nil } func (p *BatchGetExperimentAggrResultResponse) ReadField255(iprot thrift.TProtocol) error { @@ -12828,14 +12828,14 @@ WriteStructEndError: } func (p *BatchGetExperimentAggrResultResponse) writeField1(oprot thrift.TProtocol) (err error) { - if p.IsSetExptAggregateResults() { - if err = oprot.WriteFieldBegin("expt_aggregate_results", thrift.LIST, 1); err != nil { + if p.IsSetExptAggregateResult_() { + if err = oprot.WriteFieldBegin("expt_aggregate_result", thrift.LIST, 1); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteListBegin(thrift.STRUCT, len(p.ExptAggregateResults)); err != nil { + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.ExptAggregateResult_)); err != nil { return err } - for _, v := range p.ExptAggregateResults { + for _, v := range p.ExptAggregateResult_ { if err := v.Write(oprot); err != nil { return err } @@ -12884,7 +12884,7 @@ func (p *BatchGetExperimentAggrResultResponse) DeepEqual(ano *BatchGetExperiment } else if p == nil || ano == nil { return false } - if !p.Field1DeepEqual(ano.ExptAggregateResults) { + if !p.Field1DeepEqual(ano.ExptAggregateResult_) { return false } if !p.Field255DeepEqual(ano.BaseResp) { @@ -12895,10 +12895,10 @@ func (p *BatchGetExperimentAggrResultResponse) DeepEqual(ano *BatchGetExperiment func (p *BatchGetExperimentAggrResultResponse) Field1DeepEqual(src []*expt.ExptAggregateResult_) bool { - if len(p.ExptAggregateResults) != len(src) { + if len(p.ExptAggregateResult_) != len(src) { return false } - for i, v := range p.ExptAggregateResults { + for i, v := range p.ExptAggregateResult_ { _src := src[i] if !v.DeepEqual(_src) { return false @@ -21598,7 +21598,7 @@ func (p *ListExperimentTemplatesResponse) Field255DeepEqual(src *base.BaseResp) type CheckExperimentTemplateNameRequest struct { WorkspaceID int64 `thrift:"workspace_id,1,required" frugal:"1,required,i64" json:"workspace_id" form:"workspace_id,required" ` - Name string `thrift:"name,2,required" frugal:"2,required,string" json:"name" form:"name,required" ` + Name string `thrift:"name,2,required" frugal:"2,required,string" form:"name,required" json:"name,required"` TemplateID *int64 `thrift:"template_id,3,optional" frugal:"3,optional,i64" json:"template_id" form:"template_id" ` Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } @@ -27843,8 +27843,8 @@ func (p *GetExptResultExportRecordRequest) Field255DeepEqual(src *base.Base) boo } type GetExptResultExportRecordResponse struct { - ExptResultExportRecord *expt.ExptResultExportRecord `thrift:"expt_result_export_record,1,optional" frugal:"1,optional,expt.ExptResultExportRecord" form:"expt_result_export_records" json:"expt_result_export_records,omitempty"` - BaseResp *base.BaseResp `thrift:"BaseResp,255" frugal:"255,default,base.BaseResp" form:"BaseResp" json:"BaseResp" query:"BaseResp"` + ExptResultExportRecords *expt.ExptResultExportRecord `thrift:"expt_result_export_records,1,optional" frugal:"1,optional,expt.ExptResultExportRecord" form:"expt_result_export_records" json:"expt_result_export_records,omitempty"` + BaseResp *base.BaseResp `thrift:"BaseResp,255" frugal:"255,default,base.BaseResp" form:"BaseResp" json:"BaseResp" query:"BaseResp"` } func NewGetExptResultExportRecordResponse() *GetExptResultExportRecordResponse { @@ -27854,16 +27854,16 @@ func NewGetExptResultExportRecordResponse() *GetExptResultExportRecordResponse { func (p *GetExptResultExportRecordResponse) InitDefault() { } -var GetExptResultExportRecordResponse_ExptResultExportRecord_DEFAULT *expt.ExptResultExportRecord +var GetExptResultExportRecordResponse_ExptResultExportRecords_DEFAULT *expt.ExptResultExportRecord -func (p *GetExptResultExportRecordResponse) GetExptResultExportRecord() (v *expt.ExptResultExportRecord) { +func (p *GetExptResultExportRecordResponse) GetExptResultExportRecords() (v *expt.ExptResultExportRecord) { if p == nil { return } - if !p.IsSetExptResultExportRecord() { - return GetExptResultExportRecordResponse_ExptResultExportRecord_DEFAULT + if !p.IsSetExptResultExportRecords() { + return GetExptResultExportRecordResponse_ExptResultExportRecords_DEFAULT } - return p.ExptResultExportRecord + return p.ExptResultExportRecords } var GetExptResultExportRecordResponse_BaseResp_DEFAULT *base.BaseResp @@ -27877,20 +27877,20 @@ func (p *GetExptResultExportRecordResponse) GetBaseResp() (v *base.BaseResp) { } return p.BaseResp } -func (p *GetExptResultExportRecordResponse) SetExptResultExportRecord(val *expt.ExptResultExportRecord) { - p.ExptResultExportRecord = val +func (p *GetExptResultExportRecordResponse) SetExptResultExportRecords(val *expt.ExptResultExportRecord) { + p.ExptResultExportRecords = val } func (p *GetExptResultExportRecordResponse) SetBaseResp(val *base.BaseResp) { p.BaseResp = val } var fieldIDToName_GetExptResultExportRecordResponse = map[int16]string{ - 1: "expt_result_export_record", + 1: "expt_result_export_records", 255: "BaseResp", } -func (p *GetExptResultExportRecordResponse) IsSetExptResultExportRecord() bool { - return p.ExptResultExportRecord != nil +func (p *GetExptResultExportRecordResponse) IsSetExptResultExportRecords() bool { + return p.ExptResultExportRecords != nil } func (p *GetExptResultExportRecordResponse) IsSetBaseResp() bool { @@ -27965,7 +27965,7 @@ func (p *GetExptResultExportRecordResponse) ReadField1(iprot thrift.TProtocol) e if err := _field.Read(iprot); err != nil { return err } - p.ExptResultExportRecord = _field + p.ExptResultExportRecords = _field return nil } func (p *GetExptResultExportRecordResponse) ReadField255(iprot thrift.TProtocol) error { @@ -28010,11 +28010,11 @@ WriteStructEndError: } func (p *GetExptResultExportRecordResponse) writeField1(oprot thrift.TProtocol) (err error) { - if p.IsSetExptResultExportRecord() { - if err = oprot.WriteFieldBegin("expt_result_export_record", thrift.STRUCT, 1); err != nil { + if p.IsSetExptResultExportRecords() { + if err = oprot.WriteFieldBegin("expt_result_export_records", thrift.STRUCT, 1); err != nil { goto WriteFieldBeginError } - if err := p.ExptResultExportRecord.Write(oprot); err != nil { + if err := p.ExptResultExportRecords.Write(oprot); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -28058,7 +28058,7 @@ func (p *GetExptResultExportRecordResponse) DeepEqual(ano *GetExptResultExportRe } else if p == nil || ano == nil { return false } - if !p.Field1DeepEqual(ano.ExptResultExportRecord) { + if !p.Field1DeepEqual(ano.ExptResultExportRecords) { return false } if !p.Field255DeepEqual(ano.BaseResp) { @@ -28069,7 +28069,7 @@ func (p *GetExptResultExportRecordResponse) DeepEqual(ano *GetExptResultExportRe func (p *GetExptResultExportRecordResponse) Field1DeepEqual(src *expt.ExptResultExportRecord) bool { - if !p.ExptResultExportRecord.DeepEqual(src) { + if !p.ExptResultExportRecords.DeepEqual(src) { return false } return true diff --git a/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt_validator.go b/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt_validator.go index dc8db02af..dbd1fb866 100644 --- a/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt_validator.go +++ b/backend/kitex_gen/coze/loop/evaluation/expt/coze.loop.evaluation.expt_validator.go @@ -748,9 +748,9 @@ func (p *GetExptResultExportRecordRequest) IsValid() error { return nil } func (p *GetExptResultExportRecordResponse) IsValid() error { - if p.ExptResultExportRecord != nil { - if err := p.ExptResultExportRecord.IsValid(); err != nil { - return fmt.Errorf("field ExptResultExportRecord not valid, %w", err) + if p.ExptResultExportRecords != nil { + if err := p.ExptResultExportRecords.IsValid(); err != nil { + return fmt.Errorf("field ExptResultExportRecords not valid, %w", err) } } if p.BaseResp != nil { diff --git a/backend/kitex_gen/coze/loop/evaluation/expt/k-coze.loop.evaluation.expt.go b/backend/kitex_gen/coze/loop/evaluation/expt/k-coze.loop.evaluation.expt.go index 5bf1546f8..b7d619af5 100644 --- a/backend/kitex_gen/coze/loop/evaluation/expt/k-coze.loop.evaluation.expt.go +++ b/backend/kitex_gen/coze/loop/evaluation/expt/k-coze.loop.evaluation.expt.go @@ -9412,7 +9412,7 @@ func (p *BatchGetExperimentAggrResultResponse) FastReadField1(buf []byte) (int, _field = append(_field, _elem) } - p.ExptAggregateResults = _field + p.ExptAggregateResult_ = _field return offset, nil } @@ -9454,12 +9454,12 @@ func (p *BatchGetExperimentAggrResultResponse) BLength() int { func (p *BatchGetExperimentAggrResultResponse) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - if p.IsSetExptAggregateResults() { + if p.IsSetExptAggregateResult_() { offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 1) listBeginOffset := offset offset += thrift.Binary.ListBeginLength() var length int - for _, v := range p.ExptAggregateResults { + for _, v := range p.ExptAggregateResult_ { length++ offset += v.FastWriteNocopy(buf[offset:], w) } @@ -9477,10 +9477,10 @@ func (p *BatchGetExperimentAggrResultResponse) fastWriteField255(buf []byte, w t func (p *BatchGetExperimentAggrResultResponse) field1Length() int { l := 0 - if p.IsSetExptAggregateResults() { + if p.IsSetExptAggregateResult_() { l += thrift.Binary.FieldBeginLength() l += thrift.Binary.ListBeginLength() - for _, v := range p.ExptAggregateResults { + for _, v := range p.ExptAggregateResult_ { _ = v l += v.BLength() } @@ -9501,9 +9501,9 @@ func (p *BatchGetExperimentAggrResultResponse) DeepCopy(s interface{}) error { return fmt.Errorf("%T's type not matched %T", s, p) } - if src.ExptAggregateResults != nil { - p.ExptAggregateResults = make([]*expt.ExptAggregateResult_, 0, len(src.ExptAggregateResults)) - for _, elem := range src.ExptAggregateResults { + if src.ExptAggregateResult_ != nil { + p.ExptAggregateResult_ = make([]*expt.ExptAggregateResult_, 0, len(src.ExptAggregateResult_)) + for _, elem := range src.ExptAggregateResult_ { var _elem *expt.ExptAggregateResult_ if elem != nil { _elem = &expt.ExptAggregateResult_{} @@ -9512,7 +9512,7 @@ func (p *BatchGetExperimentAggrResultResponse) DeepCopy(s interface{}) error { } } - p.ExptAggregateResults = append(p.ExptAggregateResults, _elem) + p.ExptAggregateResult_ = append(p.ExptAggregateResult_, _elem) } } @@ -20478,7 +20478,7 @@ func (p *GetExptResultExportRecordResponse) FastReadField1(buf []byte) (int, err } else { offset += l } - p.ExptResultExportRecord = _field + p.ExptResultExportRecords = _field return offset, nil } @@ -20520,9 +20520,9 @@ func (p *GetExptResultExportRecordResponse) BLength() int { func (p *GetExptResultExportRecordResponse) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - if p.IsSetExptResultExportRecord() { + if p.IsSetExptResultExportRecords() { offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) - offset += p.ExptResultExportRecord.FastWriteNocopy(buf[offset:], w) + offset += p.ExptResultExportRecords.FastWriteNocopy(buf[offset:], w) } return offset } @@ -20536,9 +20536,9 @@ func (p *GetExptResultExportRecordResponse) fastWriteField255(buf []byte, w thri func (p *GetExptResultExportRecordResponse) field1Length() int { l := 0 - if p.IsSetExptResultExportRecord() { + if p.IsSetExptResultExportRecords() { l += thrift.Binary.FieldBeginLength() - l += p.ExptResultExportRecord.BLength() + l += p.ExptResultExportRecords.BLength() } return l } @@ -20556,14 +20556,14 @@ func (p *GetExptResultExportRecordResponse) DeepCopy(s interface{}) error { return fmt.Errorf("%T's type not matched %T", s, p) } - var _exptResultExportRecord *expt.ExptResultExportRecord - if src.ExptResultExportRecord != nil { - _exptResultExportRecord = &expt.ExptResultExportRecord{} - if err := _exptResultExportRecord.DeepCopy(src.ExptResultExportRecord); err != nil { + var _exptResultExportRecords *expt.ExptResultExportRecord + if src.ExptResultExportRecords != nil { + _exptResultExportRecords = &expt.ExptResultExportRecord{} + if err := _exptResultExportRecords.DeepCopy(src.ExptResultExportRecords); err != nil { return err } } - p.ExptResultExportRecord = _exptResultExportRecord + p.ExptResultExportRecords = _exptResultExportRecords var _baseResp *base.BaseResp if src.BaseResp != nil { diff --git a/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go b/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go index d501fc201..e836d8c83 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go +++ b/backend/kitex_gen/coze/loop/llm/domain/manage/k-manage.go @@ -170,6 +170,174 @@ func (p *Model) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 10: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField10(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 11: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField11(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 12: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField12(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 13: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField13(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 14: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField14(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 15: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField15(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 16: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField16(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 17: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField17(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 100: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 101: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField101(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 102: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField102(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 103: + if fieldTypeId == thrift.I64 { + l, err = p.FastReadField103(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -327,97 +495,295 @@ func (p *Model) FastReadField9(buf []byte) (int, error) { return offset, nil } -func (p *Model) FastWrite(buf []byte) int { - return p.FastWriteNocopy(buf, nil) -} - -func (p *Model) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { +func (p *Model) FastReadField10(buf []byte) (int, error) { offset := 0 - if p != nil { - offset += p.fastWriteField1(buf[offset:], w) - offset += p.fastWriteField2(buf[offset:], w) - offset += p.fastWriteField3(buf[offset:], w) - offset += p.fastWriteField4(buf[offset:], w) - offset += p.fastWriteField5(buf[offset:], w) - offset += p.fastWriteField6(buf[offset:], w) - offset += p.fastWriteField7(buf[offset:], w) - offset += p.fastWriteField8(buf[offset:], w) - offset += p.fastWriteField9(buf[offset:], w) - } - offset += thrift.Binary.WriteFieldStop(buf[offset:]) - return offset -} -func (p *Model) BLength() int { - l := 0 - if p != nil { - l += p.field1Length() - l += p.field2Length() - l += p.field3Length() - l += p.field4Length() - l += p.field5Length() - l += p.field6Length() - l += p.field7Length() - l += p.field8Length() - l += p.field9Length() + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v } - l += thrift.Binary.FieldStopLength() - return l + p.Identification = _field + return offset, nil } -func (p *Model) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { +func (p *Model) FastReadField11(buf []byte) (int, error) { offset := 0 - if p.IsSetModelID() { - offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 1) - offset += thrift.Binary.WriteI64(buf[offset:], *p.ModelID) + _field := NewSeries() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l } - return offset + p.Series = _field + return offset, nil } -func (p *Model) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { +func (p *Model) FastReadField12(buf []byte) (int, error) { offset := 0 - if p.IsSetWorkspaceID() { - offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 2) - offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + _field := NewVisibility() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l } - return offset + p.Visibility = _field + return offset, nil } -func (p *Model) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { +func (p *Model) FastReadField13(buf []byte) (int, error) { offset := 0 - if p.IsSetName() { - offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) - offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v } - return offset + p.Icon = _field + return offset, nil } -func (p *Model) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { +func (p *Model) FastReadField14(buf []byte) (int, error) { offset := 0 - if p.IsSetDesc() { - offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 4) - offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Desc) + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err } - return offset -} + _field := make([]string, 0, size) + for i := 0; i < size; i++ { + var _elem string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } -func (p *Model) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { - offset := 0 - if p.IsSetAbility() { - offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 5) - offset += p.Ability.FastWriteNocopy(buf[offset:], w) + _field = append(_field, _elem) } - return offset + p.Tags = _field + return offset, nil } -func (p *Model) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { +func (p *Model) FastReadField15(buf []byte) (int, error) { offset := 0 - if p.IsSetProtocol() { - offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 6) - offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Protocol) - } - return offset -} + + var _field *ModelStatus + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Status = _field + return offset, nil +} + +func (p *Model) FastReadField16(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.OriginalModelURL = _field + return offset, nil +} + +func (p *Model) FastReadField17(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PresetModel = _field + return offset, nil +} + +func (p *Model) FastReadField100(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.CreatedBy = _field + return offset, nil +} + +func (p *Model) FastReadField101(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.CreatedAt = _field + return offset, nil +} + +func (p *Model) FastReadField102(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.UpdatedBy = _field + return offset, nil +} + +func (p *Model) FastReadField103(buf []byte) (int, error) { + offset := 0 + + var _field *int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.UpdatedAt = _field + return offset, nil +} + +func (p *Model) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *Model) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField17(buf[offset:], w) + offset += p.fastWriteField101(buf[offset:], w) + offset += p.fastWriteField103(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField6(buf[offset:], w) + offset += p.fastWriteField7(buf[offset:], w) + offset += p.fastWriteField8(buf[offset:], w) + offset += p.fastWriteField9(buf[offset:], w) + offset += p.fastWriteField10(buf[offset:], w) + offset += p.fastWriteField11(buf[offset:], w) + offset += p.fastWriteField12(buf[offset:], w) + offset += p.fastWriteField13(buf[offset:], w) + offset += p.fastWriteField14(buf[offset:], w) + offset += p.fastWriteField15(buf[offset:], w) + offset += p.fastWriteField16(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) + offset += p.fastWriteField102(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *Model) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field5Length() + l += p.field6Length() + l += p.field7Length() + l += p.field8Length() + l += p.field9Length() + l += p.field10Length() + l += p.field11Length() + l += p.field12Length() + l += p.field13Length() + l += p.field14Length() + l += p.field15Length() + l += p.field16Length() + l += p.field17Length() + l += p.field100Length() + l += p.field101Length() + l += p.field102Length() + l += p.field103Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *Model) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetModelID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 1) + offset += thrift.Binary.WriteI64(buf[offset:], *p.ModelID) + } + return offset +} + +func (p *Model) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetWorkspaceID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 2) + offset += thrift.Binary.WriteI64(buf[offset:], *p.WorkspaceID) + } + return offset +} + +func (p *Model) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + } + return offset +} + +func (p *Model) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetDesc() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 4) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Desc) + } + return offset +} + +func (p *Model) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetAbility() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 5) + offset += p.Ability.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *Model) fastWriteField6(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetProtocol() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 6) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Protocol) + } + return offset +} func (p *Model) fastWriteField7(buf []byte, w thrift.NocopyWriter) int { offset := 0 @@ -454,174 +820,1319 @@ func (p *Model) fastWriteField9(buf []byte, w thrift.NocopyWriter) int { return offset } -func (p *Model) field1Length() int { +func (p *Model) fastWriteField10(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetIdentification() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 10) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Identification) + } + return offset +} + +func (p *Model) fastWriteField11(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSeries() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 11) + offset += p.Series.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *Model) fastWriteField12(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetVisibility() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 12) + offset += p.Visibility.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *Model) fastWriteField13(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetIcon() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 13) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Icon) + } + return offset +} + +func (p *Model) fastWriteField14(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetTags() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 14) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Tags { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *Model) fastWriteField15(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetStatus() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 15) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Status) + } + return offset +} + +func (p *Model) fastWriteField16(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetOriginalModelURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 16) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.OriginalModelURL) + } + return offset +} + +func (p *Model) fastWriteField17(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPresetModel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 17) + offset += thrift.Binary.WriteBool(buf[offset:], *p.PresetModel) + } + return offset +} + +func (p *Model) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCreatedBy() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 100) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.CreatedBy) + } + return offset +} + +func (p *Model) fastWriteField101(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCreatedAt() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 101) + offset += thrift.Binary.WriteI64(buf[offset:], *p.CreatedAt) + } + return offset +} + +func (p *Model) fastWriteField102(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetUpdatedBy() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 102) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.UpdatedBy) + } + return offset +} + +func (p *Model) fastWriteField103(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetUpdatedAt() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I64, 103) + offset += thrift.Binary.WriteI64(buf[offset:], *p.UpdatedAt) + } + return offset +} + +func (p *Model) field1Length() int { + l := 0 + if p.IsSetModelID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *Model) field2Length() int { + l := 0 + if p.IsSetWorkspaceID() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *Model) field3Length() int { + l := 0 + if p.IsSetName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Name) + } + return l +} + +func (p *Model) field4Length() int { + l := 0 + if p.IsSetDesc() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Desc) + } + return l +} + +func (p *Model) field5Length() int { + l := 0 + if p.IsSetAbility() { + l += thrift.Binary.FieldBeginLength() + l += p.Ability.BLength() + } + return l +} + +func (p *Model) field6Length() int { + l := 0 + if p.IsSetProtocol() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Protocol) + } + return l +} + +func (p *Model) field7Length() int { + l := 0 + if p.IsSetProtocolConfig() { + l += thrift.Binary.FieldBeginLength() + l += p.ProtocolConfig.BLength() + } + return l +} + +func (p *Model) field8Length() int { + l := 0 + if p.IsSetScenarioConfigs() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.MapBeginLength() + for k, v := range p.ScenarioConfigs { + _, _ = k, v + + l += thrift.Binary.StringLengthNocopy(k) + l += v.BLength() + } + } + return l +} + +func (p *Model) field9Length() int { + l := 0 + if p.IsSetParamConfig() { + l += thrift.Binary.FieldBeginLength() + l += p.ParamConfig.BLength() + } + return l +} + +func (p *Model) field10Length() int { + l := 0 + if p.IsSetIdentification() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Identification) + } + return l +} + +func (p *Model) field11Length() int { + l := 0 + if p.IsSetSeries() { + l += thrift.Binary.FieldBeginLength() + l += p.Series.BLength() + } + return l +} + +func (p *Model) field12Length() int { + l := 0 + if p.IsSetVisibility() { + l += thrift.Binary.FieldBeginLength() + l += p.Visibility.BLength() + } + return l +} + +func (p *Model) field13Length() int { + l := 0 + if p.IsSetIcon() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Icon) + } + return l +} + +func (p *Model) field14Length() int { + l := 0 + if p.IsSetTags() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Tags { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *Model) field15Length() int { + l := 0 + if p.IsSetStatus() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Status) + } + return l +} + +func (p *Model) field16Length() int { + l := 0 + if p.IsSetOriginalModelURL() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.OriginalModelURL) + } + return l +} + +func (p *Model) field17Length() int { + l := 0 + if p.IsSetPresetModel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + +func (p *Model) field100Length() int { + l := 0 + if p.IsSetCreatedBy() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.CreatedBy) + } + return l +} + +func (p *Model) field101Length() int { + l := 0 + if p.IsSetCreatedAt() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *Model) field102Length() int { + l := 0 + if p.IsSetUpdatedBy() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.UpdatedBy) + } + return l +} + +func (p *Model) field103Length() int { + l := 0 + if p.IsSetUpdatedAt() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I64Length() + } + return l +} + +func (p *Model) DeepCopy(s interface{}) error { + src, ok := s.(*Model) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.ModelID != nil { + tmp := *src.ModelID + p.ModelID = &tmp + } + + if src.WorkspaceID != nil { + tmp := *src.WorkspaceID + p.WorkspaceID = &tmp + } + + if src.Name != nil { + var tmp string + if *src.Name != "" { + tmp = kutils.StringDeepCopy(*src.Name) + } + p.Name = &tmp + } + + if src.Desc != nil { + var tmp string + if *src.Desc != "" { + tmp = kutils.StringDeepCopy(*src.Desc) + } + p.Desc = &tmp + } + + var _ability *Ability + if src.Ability != nil { + _ability = &Ability{} + if err := _ability.DeepCopy(src.Ability); err != nil { + return err + } + } + p.Ability = _ability + + if src.Protocol != nil { + tmp := *src.Protocol + p.Protocol = &tmp + } + + var _protocolConfig *ProtocolConfig + if src.ProtocolConfig != nil { + _protocolConfig = &ProtocolConfig{} + if err := _protocolConfig.DeepCopy(src.ProtocolConfig); err != nil { + return err + } + } + p.ProtocolConfig = _protocolConfig + + if src.ScenarioConfigs != nil { + p.ScenarioConfigs = make(map[common.Scenario]*ScenarioConfig, len(src.ScenarioConfigs)) + for key, val := range src.ScenarioConfigs { + var _key common.Scenario + _key = key + + var _val *ScenarioConfig + if val != nil { + _val = &ScenarioConfig{} + if err := _val.DeepCopy(val); err != nil { + return err + } + } + + p.ScenarioConfigs[_key] = _val + } + } + + var _paramConfig *ParamConfig + if src.ParamConfig != nil { + _paramConfig = &ParamConfig{} + if err := _paramConfig.DeepCopy(src.ParamConfig); err != nil { + return err + } + } + p.ParamConfig = _paramConfig + + if src.Identification != nil { + var tmp string + if *src.Identification != "" { + tmp = kutils.StringDeepCopy(*src.Identification) + } + p.Identification = &tmp + } + + var _series *Series + if src.Series != nil { + _series = &Series{} + if err := _series.DeepCopy(src.Series); err != nil { + return err + } + } + p.Series = _series + + var _visibility *Visibility + if src.Visibility != nil { + _visibility = &Visibility{} + if err := _visibility.DeepCopy(src.Visibility); err != nil { + return err + } + } + p.Visibility = _visibility + + if src.Icon != nil { + var tmp string + if *src.Icon != "" { + tmp = kutils.StringDeepCopy(*src.Icon) + } + p.Icon = &tmp + } + + if src.Tags != nil { + p.Tags = make([]string, 0, len(src.Tags)) + for _, elem := range src.Tags { + var _elem string + if elem != "" { + _elem = kutils.StringDeepCopy(elem) + } + p.Tags = append(p.Tags, _elem) + } + } + + if src.Status != nil { + tmp := *src.Status + p.Status = &tmp + } + + if src.OriginalModelURL != nil { + var tmp string + if *src.OriginalModelURL != "" { + tmp = kutils.StringDeepCopy(*src.OriginalModelURL) + } + p.OriginalModelURL = &tmp + } + + if src.PresetModel != nil { + tmp := *src.PresetModel + p.PresetModel = &tmp + } + + if src.CreatedBy != nil { + var tmp string + if *src.CreatedBy != "" { + tmp = kutils.StringDeepCopy(*src.CreatedBy) + } + p.CreatedBy = &tmp + } + + if src.CreatedAt != nil { + tmp := *src.CreatedAt + p.CreatedAt = &tmp + } + + if src.UpdatedBy != nil { + var tmp string + if *src.UpdatedBy != "" { + tmp = kutils.StringDeepCopy(*src.UpdatedBy) + } + p.UpdatedBy = &tmp + } + + if src.UpdatedAt != nil { + tmp := *src.UpdatedAt + p.UpdatedAt = &tmp + } + + return nil +} + +func (p *Series) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Series[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *Series) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Name = _field + return offset, nil +} + +func (p *Series) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Icon = _field + return offset, nil +} + +func (p *Series) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *Family + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Family = _field + return offset, nil +} + +func (p *Series) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *Series) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *Series) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *Series) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + } + return offset +} + +func (p *Series) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetIcon() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Icon) + } + return offset +} + +func (p *Series) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFamily() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Family) + } + return offset +} + +func (p *Series) field1Length() int { + l := 0 + if p.IsSetName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Name) + } + return l +} + +func (p *Series) field2Length() int { + l := 0 + if p.IsSetIcon() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Icon) + } + return l +} + +func (p *Series) field3Length() int { + l := 0 + if p.IsSetFamily() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Family) + } + return l +} + +func (p *Series) DeepCopy(s interface{}) error { + src, ok := s.(*Series) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Name != nil { + var tmp string + if *src.Name != "" { + tmp = kutils.StringDeepCopy(*src.Name) + } + p.Name = &tmp + } + + if src.Icon != nil { + var tmp string + if *src.Icon != "" { + tmp = kutils.StringDeepCopy(*src.Icon) + } + p.Icon = &tmp + } + + if src.Family != nil { + tmp := *src.Family + p.Family = &tmp + } + + return nil +} + +func (p *Visibility) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Visibility[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *Visibility) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *VisibleMode + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Mode = _field + return offset, nil +} + +func (p *Visibility) FastReadField2(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]int64, 0, size) + for i := 0; i < size; i++ { + var _elem int64 + if v, l, err := thrift.Binary.ReadI64(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.SpaceIDs = _field + return offset, nil +} + +func (p *Visibility) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *Visibility) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *Visibility) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *Visibility) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMode() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Mode) + } + return offset +} + +func (p *Visibility) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetSpaceIDs() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 2) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.SpaceIDs { + length++ + offset += thrift.Binary.WriteI64(buf[offset:], v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.I64, length) + } + return offset +} + +func (p *Visibility) field1Length() int { + l := 0 + if p.IsSetMode() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Mode) + } + return l +} + +func (p *Visibility) field2Length() int { + l := 0 + if p.IsSetSpaceIDs() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + l += + thrift.Binary.I64Length() * len(p.SpaceIDs) + } + return l +} + +func (p *Visibility) DeepCopy(s interface{}) error { + src, ok := s.(*Visibility) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Mode != nil { + tmp := *src.Mode + p.Mode = &tmp + } + + if src.SpaceIDs != nil { + p.SpaceIDs = make([]int64, 0, len(src.SpaceIDs)) + for _, elem := range src.SpaceIDs { + var _elem int64 + _elem = elem + p.SpaceIDs = append(p.SpaceIDs, _elem) + } + } + + return nil +} + +func (p *ProviderInfo) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ProviderInfo[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ProviderInfo) FastReadField1(buf []byte) (int, error) { + offset := 0 + _field := NewMaaSInfo() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.MaasInfo = _field + return offset, nil +} + +func (p *ProviderInfo) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ProviderInfo) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ProviderInfo) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ProviderInfo) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetMaasInfo() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 1) + offset += p.MaasInfo.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ProviderInfo) field1Length() int { + l := 0 + if p.IsSetMaasInfo() { + l += thrift.Binary.FieldBeginLength() + l += p.MaasInfo.BLength() + } + return l +} + +func (p *ProviderInfo) DeepCopy(s interface{}) error { + src, ok := s.(*ProviderInfo) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + var _maasInfo *MaaSInfo + if src.MaasInfo != nil { + _maasInfo = &MaaSInfo{} + if err := _maasInfo.DeepCopy(src.MaasInfo); err != nil { + return err + } + } + p.MaasInfo = _maasInfo + + return nil +} + +func (p *MaaSInfo) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MaaSInfo[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *MaaSInfo) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Host = _field + return offset, nil +} + +func (p *MaaSInfo) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Region = _field + return offset, nil +} + +func (p *MaaSInfo) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.BaseURL = _field + return offset, nil +} + +func (p *MaaSInfo) FastReadField4(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.CustomizationJobsID = _field + return offset, nil +} + +func (p *MaaSInfo) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *MaaSInfo) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *MaaSInfo) BLength() int { l := 0 - if p.IsSetModelID() { - l += thrift.Binary.FieldBeginLength() - l += thrift.Binary.I64Length() + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() } + l += thrift.Binary.FieldStopLength() return l } -func (p *Model) field2Length() int { - l := 0 - if p.IsSetWorkspaceID() { - l += thrift.Binary.FieldBeginLength() - l += thrift.Binary.I64Length() +func (p *MaaSInfo) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetHost() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Host) } - return l + return offset } -func (p *Model) field3Length() int { - l := 0 - if p.IsSetName() { - l += thrift.Binary.FieldBeginLength() - l += thrift.Binary.StringLengthNocopy(*p.Name) +func (p *MaaSInfo) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetRegion() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Region) } - return l + return offset } -func (p *Model) field4Length() int { - l := 0 - if p.IsSetDesc() { - l += thrift.Binary.FieldBeginLength() - l += thrift.Binary.StringLengthNocopy(*p.Desc) +func (p *MaaSInfo) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetBaseURL() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.BaseURL) } - return l + return offset } -func (p *Model) field5Length() int { - l := 0 - if p.IsSetAbility() { - l += thrift.Binary.FieldBeginLength() - l += p.Ability.BLength() +func (p *MaaSInfo) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCustomizationJobsID() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 4) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.CustomizationJobsID) } - return l + return offset } -func (p *Model) field6Length() int { +func (p *MaaSInfo) field1Length() int { l := 0 - if p.IsSetProtocol() { + if p.IsSetHost() { l += thrift.Binary.FieldBeginLength() - l += thrift.Binary.StringLengthNocopy(*p.Protocol) + l += thrift.Binary.StringLengthNocopy(*p.Host) } return l } -func (p *Model) field7Length() int { +func (p *MaaSInfo) field2Length() int { l := 0 - if p.IsSetProtocolConfig() { + if p.IsSetRegion() { l += thrift.Binary.FieldBeginLength() - l += p.ProtocolConfig.BLength() + l += thrift.Binary.StringLengthNocopy(*p.Region) } return l } -func (p *Model) field8Length() int { +func (p *MaaSInfo) field3Length() int { l := 0 - if p.IsSetScenarioConfigs() { + if p.IsSetBaseURL() { l += thrift.Binary.FieldBeginLength() - l += thrift.Binary.MapBeginLength() - for k, v := range p.ScenarioConfigs { - _, _ = k, v - - l += thrift.Binary.StringLengthNocopy(k) - l += v.BLength() - } + l += thrift.Binary.StringLengthNocopy(*p.BaseURL) } return l } -func (p *Model) field9Length() int { +func (p *MaaSInfo) field4Length() int { l := 0 - if p.IsSetParamConfig() { + if p.IsSetCustomizationJobsID() { l += thrift.Binary.FieldBeginLength() - l += p.ParamConfig.BLength() + l += thrift.Binary.StringLengthNocopy(*p.CustomizationJobsID) } return l } -func (p *Model) DeepCopy(s interface{}) error { - src, ok := s.(*Model) +func (p *MaaSInfo) DeepCopy(s interface{}) error { + src, ok := s.(*MaaSInfo) if !ok { return fmt.Errorf("%T's type not matched %T", s, p) } - if src.ModelID != nil { - tmp := *src.ModelID - p.ModelID = &tmp - } - - if src.WorkspaceID != nil { - tmp := *src.WorkspaceID - p.WorkspaceID = &tmp - } - - if src.Name != nil { + if src.Host != nil { var tmp string - if *src.Name != "" { - tmp = kutils.StringDeepCopy(*src.Name) + if *src.Host != "" { + tmp = kutils.StringDeepCopy(*src.Host) } - p.Name = &tmp + p.Host = &tmp } - if src.Desc != nil { + if src.Region != nil { var tmp string - if *src.Desc != "" { - tmp = kutils.StringDeepCopy(*src.Desc) - } - p.Desc = &tmp - } - - var _ability *Ability - if src.Ability != nil { - _ability = &Ability{} - if err := _ability.DeepCopy(src.Ability); err != nil { - return err - } - } - p.Ability = _ability - - if src.Protocol != nil { - tmp := *src.Protocol - p.Protocol = &tmp - } - - var _protocolConfig *ProtocolConfig - if src.ProtocolConfig != nil { - _protocolConfig = &ProtocolConfig{} - if err := _protocolConfig.DeepCopy(src.ProtocolConfig); err != nil { - return err + if *src.Region != "" { + tmp = kutils.StringDeepCopy(*src.Region) } + p.Region = &tmp } - p.ProtocolConfig = _protocolConfig - - if src.ScenarioConfigs != nil { - p.ScenarioConfigs = make(map[common.Scenario]*ScenarioConfig, len(src.ScenarioConfigs)) - for key, val := range src.ScenarioConfigs { - var _key common.Scenario - _key = key - var _val *ScenarioConfig - if val != nil { - _val = &ScenarioConfig{} - if err := _val.DeepCopy(val); err != nil { - return err - } - } - - p.ScenarioConfigs[_key] = _val + if src.BaseURL != nil { + var tmp string + if *src.BaseURL != "" { + tmp = kutils.StringDeepCopy(*src.BaseURL) } + p.BaseURL = &tmp } - var _paramConfig *ParamConfig - if src.ParamConfig != nil { - _paramConfig = &ParamConfig{} - if err := _paramConfig.DeepCopy(src.ParamConfig); err != nil { - return err + if src.CustomizationJobsID != nil { + var tmp string + if *src.CustomizationJobsID != "" { + tmp = kutils.StringDeepCopy(*src.CustomizationJobsID) } + p.CustomizationJobsID = &tmp } - p.ParamConfig = _paramConfig return nil } @@ -741,6 +2252,20 @@ func (p *Ability) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 8: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField8(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -855,6 +2380,20 @@ func (p *Ability) FastReadField7(buf []byte) (int, error) { return offset, nil } +func (p *Ability) FastReadField8(buf []byte) (int, error) { + offset := 0 + + var _field *InterfaceCategory + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.InterfaceCategory = _field + return offset, nil +} + func (p *Ability) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -869,6 +2408,7 @@ func (p *Ability) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField6(buf[offset:], w) offset += p.fastWriteField7(buf[offset:], w) + offset += p.fastWriteField8(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -884,6 +2424,7 @@ func (p *Ability) BLength() int { l += p.field5Length() l += p.field6Length() l += p.field7Length() + l += p.field8Length() } l += thrift.Binary.FieldStopLength() return l @@ -952,6 +2493,15 @@ func (p *Ability) fastWriteField7(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *Ability) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetInterfaceCategory() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 8) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.InterfaceCategory) + } + return offset +} + func (p *Ability) field1Length() int { l := 0 if p.IsSetMaxContextTokens() { @@ -1015,6 +2565,15 @@ func (p *Ability) field7Length() int { return l } +func (p *Ability) field8Length() int { + l := 0 + if p.IsSetInterfaceCategory() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.InterfaceCategory) + } + return l +} + func (p *Ability) DeepCopy(s interface{}) error { src, ok := s.(*Ability) if !ok { @@ -1060,6 +2619,11 @@ func (p *Ability) DeepCopy(s interface{}) error { } p.AbilityMultiModal = _abilityMultiModal + if src.InterfaceCategory != nil { + tmp := *src.InterfaceCategory + p.InterfaceCategory = &tmp + } + return nil } @@ -5772,6 +7336,48 @@ func (p *ParamSchema) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 9: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField9(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 10: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField10(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 11: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField11(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -5877,18 +7483,43 @@ func (p *ParamSchema) FastReadField6(buf []byte) (int, error) { func (p *ParamSchema) FastReadField7(buf []byte) (int, error) { offset := 0 - var _field *string - if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { - return offset, err - } else { - offset += l - _field = &v + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.DefaultValue = _field + return offset, nil +} + +func (p *ParamSchema) FastReadField8(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]*ParamOption, 0, size) + values := make([]ParamOption, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field = append(_field, _elem) } - p.DefaultValue = _field + p.Options = _field return offset, nil } -func (p *ParamSchema) FastReadField8(buf []byte) (int, error) { +func (p *ParamSchema) FastReadField9(buf []byte) (int, error) { offset := 0 _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) @@ -5896,8 +7527,8 @@ func (p *ParamSchema) FastReadField8(buf []byte) (int, error) { if err != nil { return offset, err } - _field := make([]*ParamOption, 0, size) - values := make([]ParamOption, size) + _field := make([]*ParamSchema, 0, size) + values := make([]ParamSchema, size) for i := 0; i < size; i++ { _elem := &values[i] _elem.InitDefault() @@ -5909,7 +7540,33 @@ func (p *ParamSchema) FastReadField8(buf []byte) (int, error) { _field = append(_field, _elem) } - p.Options = _field + p.Properties = _field + return offset, nil +} + +func (p *ParamSchema) FastReadField10(buf []byte) (int, error) { + offset := 0 + _field := NewReaction() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Reaction = _field + return offset, nil +} + +func (p *ParamSchema) FastReadField11(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Jsonpath = _field return offset, nil } @@ -5928,6 +7585,9 @@ func (p *ParamSchema) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField6(buf[offset:], w) offset += p.fastWriteField7(buf[offset:], w) offset += p.fastWriteField8(buf[offset:], w) + offset += p.fastWriteField9(buf[offset:], w) + offset += p.fastWriteField10(buf[offset:], w) + offset += p.fastWriteField11(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -5944,6 +7604,9 @@ func (p *ParamSchema) BLength() int { l += p.field6Length() l += p.field7Length() l += p.field8Length() + l += p.field9Length() + l += p.field10Length() + l += p.field11Length() } l += thrift.Binary.FieldStopLength() return l @@ -6028,6 +7691,40 @@ func (p *ParamSchema) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ParamSchema) fastWriteField9(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetProperties() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 9) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Properties { + length++ + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) + } + return offset +} + +func (p *ParamSchema) fastWriteField10(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetReaction() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 10) + offset += p.Reaction.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ParamSchema) fastWriteField11(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetJsonpath() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 11) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Jsonpath) + } + return offset +} + func (p *ParamSchema) field1Length() int { l := 0 if p.IsSetName() { @@ -6104,6 +7801,37 @@ func (p *ParamSchema) field8Length() int { return l } +func (p *ParamSchema) field9Length() int { + l := 0 + if p.IsSetProperties() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Properties { + _ = v + l += v.BLength() + } + } + return l +} + +func (p *ParamSchema) field10Length() int { + l := 0 + if p.IsSetReaction() { + l += thrift.Binary.FieldBeginLength() + l += p.Reaction.BLength() + } + return l +} + +func (p *ParamSchema) field11Length() int { + l := 0 + if p.IsSetJsonpath() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Jsonpath) + } + return l +} + func (p *ParamSchema) DeepCopy(s interface{}) error { src, ok := s.(*ParamSchema) if !ok { @@ -6178,6 +7906,214 @@ func (p *ParamSchema) DeepCopy(s interface{}) error { } } + if src.Properties != nil { + p.Properties = make([]*ParamSchema, 0, len(src.Properties)) + for _, elem := range src.Properties { + var _elem *ParamSchema + if elem != nil { + _elem = &ParamSchema{} + if err := _elem.DeepCopy(elem); err != nil { + return err + } + } + + p.Properties = append(p.Properties, _elem) + } + } + + var _reaction *Reaction + if src.Reaction != nil { + _reaction = &Reaction{} + if err := _reaction.DeepCopy(src.Reaction); err != nil { + return err + } + } + p.Reaction = _reaction + + if src.Jsonpath != nil { + var tmp string + if *src.Jsonpath != "" { + tmp = kutils.StringDeepCopy(*src.Jsonpath) + } + p.Jsonpath = &tmp + } + + return nil +} + +func (p *Reaction) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Reaction[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *Reaction) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Dependency = _field + return offset, nil +} + +func (p *Reaction) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Visible = _field + return offset, nil +} + +func (p *Reaction) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *Reaction) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *Reaction) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *Reaction) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetDependency() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Dependency) + } + return offset +} + +func (p *Reaction) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetVisible() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Visible) + } + return offset +} + +func (p *Reaction) field1Length() int { + l := 0 + if p.IsSetDependency() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Dependency) + } + return l +} + +func (p *Reaction) field2Length() int { + l := 0 + if p.IsSetVisible() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Visible) + } + return l +} + +func (p *Reaction) DeepCopy(s interface{}) error { + src, ok := s.(*Reaction) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Dependency != nil { + var tmp string + if *src.Dependency != "" { + tmp = kutils.StringDeepCopy(*src.Dependency) + } + p.Dependency = &tmp + } + + if src.Visible != nil { + var tmp string + if *src.Visible != "" { + tmp = kutils.StringDeepCopy(*src.Visible) + } + p.Visible = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go b/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go index 6f0e4dbaf..ba2134f52 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go +++ b/backend/kitex_gen/coze/loop/llm/domain/manage/manage.go @@ -36,6 +36,112 @@ const ( ParamTypeString = "string" + ParamTypeVoid = "void" + + ParamTypeObject = "object" + + FamilyUndefined = "undefined" + + FamilyGpt = "gpt" + + FamilySeed = "seed" + + FamilyGemini = "gemini" + + FamilyClaude = "claude" + + FamilyErnie = "ernie" + + FamilyBaichuan = "baichuan" + + FamilyQwen = "qwen" + + FamilyGlm = "glm" + + FamilySkylark = "skylark" + + FamilyMoonshot = "moonshot" + + FamilyMinimax = "minimax" + + FamilyDoubao = "doubao" + + FamilyBaichuan2 = "baichuan2" + + FamilyDeepseekv2 = "deepseekv2" + + FamilyDeepseekCoderV2 = "deepseek_coder_v2" + + FamilyDeepseekCoder = "deepseek_coder" + + FamilyInternalm25 = "internalm2_5" + + FamilyQwen2 = "qwen2" + + FamilyQwen25 = "qwen2.5" + + FamilyQwen25Coder = "qwen2.5_coder" + + FamilyMiniCpm = "mini_cpm" + + FamilyMiniCpm3 = "mini_cpm_3" + + FamilyChatGlm3 = "chat_glm_3" + + FamilyMistra = "mistral" + + FamilyGemma = "gemma" + + FamilyGemma2 = "gemma_2" + + FamilyInternVl2 = "intern_vl2" + + FamilyInternVl25 = "intern_vl2.5" + + FamilyDeepseekV3 = "deepseek_v3" + + FamilyDeepseekR1 = "deepseek_r1" + + FamilyKimi = "kimi" + + FamilySeedream = "seedream" + + FamilyInternVl3 = "intern_vl3" + + FamilyDeepseek = "deepseek" + + ProviderUndefined = "undefined" + + ProviderMaas = "maas" + + VisibleModeDefault = "default" + + VisibleModeSpecified = "specified" + + VisibleModeUndefined = "undefined" + + VisibleModeAll = "all" + + ModelStatusUndefined = "undefined" + + ModelStatusAvailable = "available" + + ModelStatusUnavailable = "unavailable" + + InterfaceCategoryUndefined = "undefined" + + InterfaceCategoryChatCompletionAPI = "chat_completion_api" + + InterfaceCategoryResponseAPI = "response_api" + + AbilityUndefined = "undefined" + + AbilityJSONMode = "json_mode" + + AbilityFunctionCall = "function_call" + + AbilityMultiModal_ = "multi_modal" + VideoFormatUndefined = "undefined" VideoFormatMp4 = "mp4" @@ -71,6 +177,18 @@ type Protocol = string type ParamType = string +type Family = string + +type Provider = string + +type VisibleMode = string + +type ModelStatus = string + +type InterfaceCategory = string + +type AbilityEnum = string + type VideoFormat = string type Model struct { @@ -83,6 +201,25 @@ type Model struct { ProtocolConfig *ProtocolConfig `thrift:"protocol_config,7,optional" frugal:"7,optional,ProtocolConfig" form:"protocol_config" json:"protocol_config,omitempty" query:"protocol_config"` ScenarioConfigs map[common.Scenario]*ScenarioConfig `thrift:"scenario_configs,8,optional" frugal:"8,optional,map" form:"scenario_configs" json:"scenario_configs,omitempty" query:"scenario_configs"` ParamConfig *ParamConfig `thrift:"param_config,9,optional" frugal:"9,optional,ParamConfig" form:"param_config" json:"param_config,omitempty" query:"param_config"` + // 模型表示 (name, endpoint) + Identification *string `thrift:"identification,10,optional" frugal:"10,optional,string" form:"identification" json:"identification,omitempty" query:"identification"` + // 模型 + Series *Series `thrift:"series,11,optional" frugal:"11,optional,Series" form:"series" json:"series,omitempty" query:"series"` + Visibility *Visibility `thrift:"visibility,12,optional" frugal:"12,optional,Visibility" form:"visibility" json:"visibility,omitempty" query:"visibility"` + // 模型图标 + Icon *string `thrift:"icon,13,optional" frugal:"13,optional,string" form:"icon" json:"icon,omitempty" query:"icon"` + //模型标签 + Tags []string `thrift:"tags,14,optional" frugal:"14,optional,list" form:"tags" json:"tags,omitempty" query:"tags"` + // 模型状态 + Status *ModelStatus `thrift:"status,15,optional" frugal:"15,optional,string" form:"status" json:"status,omitempty" query:"status"` + // 模型跳转链接 + OriginalModelURL *string `thrift:"original_model_url,16,optional" frugal:"16,optional,string" form:"original_model_url" json:"original_model_url,omitempty" query:"original_model_url"` + // 是否为预置模型 + PresetModel *bool `thrift:"preset_model,17,optional" frugal:"17,optional,bool" form:"preset_model" json:"preset_model,omitempty" query:"preset_model"` + CreatedBy *string `thrift:"created_by,100,optional" frugal:"100,optional,string" form:"created_by" json:"created_by,omitempty" query:"created_by"` + CreatedAt *int64 `thrift:"created_at,101,optional" frugal:"101,optional,i64" form:"created_at" json:"created_at,omitempty" query:"created_at"` + UpdatedBy *string `thrift:"updated_by,102,optional" frugal:"102,optional,string" form:"updated_by" json:"updated_by,omitempty" query:"updated_by"` + UpdatedAt *int64 `thrift:"updated_at,103,optional" frugal:"103,optional,i64" form:"updated_at" json:"updated_at,omitempty" query:"updated_at"` } func NewModel() *Model { @@ -199,6 +336,150 @@ func (p *Model) GetParamConfig() (v *ParamConfig) { } return p.ParamConfig } + +var Model_Identification_DEFAULT string + +func (p *Model) GetIdentification() (v string) { + if p == nil { + return + } + if !p.IsSetIdentification() { + return Model_Identification_DEFAULT + } + return *p.Identification +} + +var Model_Series_DEFAULT *Series + +func (p *Model) GetSeries() (v *Series) { + if p == nil { + return + } + if !p.IsSetSeries() { + return Model_Series_DEFAULT + } + return p.Series +} + +var Model_Visibility_DEFAULT *Visibility + +func (p *Model) GetVisibility() (v *Visibility) { + if p == nil { + return + } + if !p.IsSetVisibility() { + return Model_Visibility_DEFAULT + } + return p.Visibility +} + +var Model_Icon_DEFAULT string + +func (p *Model) GetIcon() (v string) { + if p == nil { + return + } + if !p.IsSetIcon() { + return Model_Icon_DEFAULT + } + return *p.Icon +} + +var Model_Tags_DEFAULT []string + +func (p *Model) GetTags() (v []string) { + if p == nil { + return + } + if !p.IsSetTags() { + return Model_Tags_DEFAULT + } + return p.Tags +} + +var Model_Status_DEFAULT ModelStatus + +func (p *Model) GetStatus() (v ModelStatus) { + if p == nil { + return + } + if !p.IsSetStatus() { + return Model_Status_DEFAULT + } + return *p.Status +} + +var Model_OriginalModelURL_DEFAULT string + +func (p *Model) GetOriginalModelURL() (v string) { + if p == nil { + return + } + if !p.IsSetOriginalModelURL() { + return Model_OriginalModelURL_DEFAULT + } + return *p.OriginalModelURL +} + +var Model_PresetModel_DEFAULT bool + +func (p *Model) GetPresetModel() (v bool) { + if p == nil { + return + } + if !p.IsSetPresetModel() { + return Model_PresetModel_DEFAULT + } + return *p.PresetModel +} + +var Model_CreatedBy_DEFAULT string + +func (p *Model) GetCreatedBy() (v string) { + if p == nil { + return + } + if !p.IsSetCreatedBy() { + return Model_CreatedBy_DEFAULT + } + return *p.CreatedBy +} + +var Model_CreatedAt_DEFAULT int64 + +func (p *Model) GetCreatedAt() (v int64) { + if p == nil { + return + } + if !p.IsSetCreatedAt() { + return Model_CreatedAt_DEFAULT + } + return *p.CreatedAt +} + +var Model_UpdatedBy_DEFAULT string + +func (p *Model) GetUpdatedBy() (v string) { + if p == nil { + return + } + if !p.IsSetUpdatedBy() { + return Model_UpdatedBy_DEFAULT + } + return *p.UpdatedBy +} + +var Model_UpdatedAt_DEFAULT int64 + +func (p *Model) GetUpdatedAt() (v int64) { + if p == nil { + return + } + if !p.IsSetUpdatedAt() { + return Model_UpdatedAt_DEFAULT + } + return *p.UpdatedAt +} func (p *Model) SetModelID(val *int64) { p.ModelID = val } @@ -226,17 +507,65 @@ func (p *Model) SetScenarioConfigs(val map[common.Scenario]*ScenarioConfig) { func (p *Model) SetParamConfig(val *ParamConfig) { p.ParamConfig = val } +func (p *Model) SetIdentification(val *string) { + p.Identification = val +} +func (p *Model) SetSeries(val *Series) { + p.Series = val +} +func (p *Model) SetVisibility(val *Visibility) { + p.Visibility = val +} +func (p *Model) SetIcon(val *string) { + p.Icon = val +} +func (p *Model) SetTags(val []string) { + p.Tags = val +} +func (p *Model) SetStatus(val *ModelStatus) { + p.Status = val +} +func (p *Model) SetOriginalModelURL(val *string) { + p.OriginalModelURL = val +} +func (p *Model) SetPresetModel(val *bool) { + p.PresetModel = val +} +func (p *Model) SetCreatedBy(val *string) { + p.CreatedBy = val +} +func (p *Model) SetCreatedAt(val *int64) { + p.CreatedAt = val +} +func (p *Model) SetUpdatedBy(val *string) { + p.UpdatedBy = val +} +func (p *Model) SetUpdatedAt(val *int64) { + p.UpdatedAt = val +} var fieldIDToName_Model = map[int16]string{ - 1: "model_id", - 2: "workspace_id", - 3: "name", - 4: "desc", - 5: "ability", - 6: "protocol", - 7: "protocol_config", - 8: "scenario_configs", - 9: "param_config", + 1: "model_id", + 2: "workspace_id", + 3: "name", + 4: "desc", + 5: "ability", + 6: "protocol", + 7: "protocol_config", + 8: "scenario_configs", + 9: "param_config", + 10: "identification", + 11: "series", + 12: "visibility", + 13: "icon", + 14: "tags", + 15: "status", + 16: "original_model_url", + 17: "preset_model", + 100: "created_by", + 101: "created_at", + 102: "updated_by", + 103: "updated_at", } func (p *Model) IsSetModelID() bool { @@ -275,6 +604,54 @@ func (p *Model) IsSetParamConfig() bool { return p.ParamConfig != nil } +func (p *Model) IsSetIdentification() bool { + return p.Identification != nil +} + +func (p *Model) IsSetSeries() bool { + return p.Series != nil +} + +func (p *Model) IsSetVisibility() bool { + return p.Visibility != nil +} + +func (p *Model) IsSetIcon() bool { + return p.Icon != nil +} + +func (p *Model) IsSetTags() bool { + return p.Tags != nil +} + +func (p *Model) IsSetStatus() bool { + return p.Status != nil +} + +func (p *Model) IsSetOriginalModelURL() bool { + return p.OriginalModelURL != nil +} + +func (p *Model) IsSetPresetModel() bool { + return p.PresetModel != nil +} + +func (p *Model) IsSetCreatedBy() bool { + return p.CreatedBy != nil +} + +func (p *Model) IsSetCreatedAt() bool { + return p.CreatedAt != nil +} + +func (p *Model) IsSetUpdatedBy() bool { + return p.UpdatedBy != nil +} + +func (p *Model) IsSetUpdatedAt() bool { + return p.UpdatedAt != nil +} + func (p *Model) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -365,79 +742,175 @@ func (p *Model) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } - default: - if err = iprot.Skip(fieldTypeId); err != nil { + case 10: + if fieldTypeId == thrift.STRING { + if err = p.ReadField10(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } - } - if err = iprot.ReadFieldEnd(); err != nil { - goto ReadFieldEndError - } - } - if err = iprot.ReadStructEnd(); err != nil { - goto ReadStructEndError - } - - return nil -ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) -ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) -ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Model[fieldId]), err) -SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) - -ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) -} - -func (p *Model) ReadField1(iprot thrift.TProtocol) error { - - var _field *int64 - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - _field = &v - } - p.ModelID = _field - return nil -} -func (p *Model) ReadField2(iprot thrift.TProtocol) error { - - var _field *int64 - if v, err := iprot.ReadI64(); err != nil { - return err - } else { - _field = &v - } - p.WorkspaceID = _field - return nil -} -func (p *Model) ReadField3(iprot thrift.TProtocol) error { - - var _field *string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _field = &v - } - p.Name = _field - return nil -} -func (p *Model) ReadField4(iprot thrift.TProtocol) error { - - var _field *string - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _field = &v - } - p.Desc = _field - return nil -} + case 11: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField11(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 12: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField12(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 13: + if fieldTypeId == thrift.STRING { + if err = p.ReadField13(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 14: + if fieldTypeId == thrift.LIST { + if err = p.ReadField14(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 15: + if fieldTypeId == thrift.STRING { + if err = p.ReadField15(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 16: + if fieldTypeId == thrift.STRING { + if err = p.ReadField16(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 17: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField17(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 100: + if fieldTypeId == thrift.STRING { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 101: + if fieldTypeId == thrift.I64 { + if err = p.ReadField101(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 102: + if fieldTypeId == thrift.STRING { + if err = p.ReadField102(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 103: + if fieldTypeId == thrift.I64 { + if err = p.ReadField103(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Model[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Model) ReadField1(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.ModelID = _field + return nil +} +func (p *Model) ReadField2(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.WorkspaceID = _field + return nil +} +func (p *Model) ReadField3(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Name = _field + return nil +} +func (p *Model) ReadField4(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Desc = _field + return nil +} func (p *Model) ReadField5(iprot thrift.TProtocol) error { _field := NewAbility() if err := _field.Read(iprot); err != nil { @@ -462,50 +935,2018 @@ func (p *Model) ReadField7(iprot thrift.TProtocol) error { if err := _field.Read(iprot); err != nil { return err } - p.ProtocolConfig = _field + p.ProtocolConfig = _field + return nil +} +func (p *Model) ReadField8(iprot thrift.TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return err + } + _field := make(map[common.Scenario]*ScenarioConfig, size) + values := make([]ScenarioConfig, size) + for i := 0; i < size; i++ { + var _key common.Scenario + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _key = v + } + + _val := &values[i] + _val.InitDefault() + if err := _val.Read(iprot); err != nil { + return err + } + + _field[_key] = _val + } + if err := iprot.ReadMapEnd(); err != nil { + return err + } + p.ScenarioConfigs = _field + return nil +} +func (p *Model) ReadField9(iprot thrift.TProtocol) error { + _field := NewParamConfig() + if err := _field.Read(iprot); err != nil { + return err + } + p.ParamConfig = _field + return nil +} +func (p *Model) ReadField10(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Identification = _field + return nil +} +func (p *Model) ReadField11(iprot thrift.TProtocol) error { + _field := NewSeries() + if err := _field.Read(iprot); err != nil { + return err + } + p.Series = _field + return nil +} +func (p *Model) ReadField12(iprot thrift.TProtocol) error { + _field := NewVisibility() + if err := _field.Read(iprot); err != nil { + return err + } + p.Visibility = _field + return nil +} +func (p *Model) ReadField13(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Icon = _field + return nil +} +func (p *Model) ReadField14(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]string, 0, size) + for i := 0; i < size; i++ { + + var _elem string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Tags = _field + return nil +} +func (p *Model) ReadField15(iprot thrift.TProtocol) error { + + var _field *ModelStatus + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Status = _field + return nil +} +func (p *Model) ReadField16(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.OriginalModelURL = _field + return nil +} +func (p *Model) ReadField17(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.PresetModel = _field + return nil +} +func (p *Model) ReadField100(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.CreatedBy = _field + return nil +} +func (p *Model) ReadField101(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.CreatedAt = _field + return nil +} +func (p *Model) ReadField102(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.UpdatedBy = _field + return nil +} +func (p *Model) ReadField103(iprot thrift.TProtocol) error { + + var _field *int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _field = &v + } + p.UpdatedAt = _field + return nil +} + +func (p *Model) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("Model"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } + if err = p.writeField6(oprot); err != nil { + fieldId = 6 + goto WriteFieldError + } + if err = p.writeField7(oprot); err != nil { + fieldId = 7 + goto WriteFieldError + } + if err = p.writeField8(oprot); err != nil { + fieldId = 8 + goto WriteFieldError + } + if err = p.writeField9(oprot); err != nil { + fieldId = 9 + goto WriteFieldError + } + if err = p.writeField10(oprot); err != nil { + fieldId = 10 + goto WriteFieldError + } + if err = p.writeField11(oprot); err != nil { + fieldId = 11 + goto WriteFieldError + } + if err = p.writeField12(oprot); err != nil { + fieldId = 12 + goto WriteFieldError + } + if err = p.writeField13(oprot); err != nil { + fieldId = 13 + goto WriteFieldError + } + if err = p.writeField14(oprot); err != nil { + fieldId = 14 + goto WriteFieldError + } + if err = p.writeField15(oprot); err != nil { + fieldId = 15 + goto WriteFieldError + } + if err = p.writeField16(oprot); err != nil { + fieldId = 16 + goto WriteFieldError + } + if err = p.writeField17(oprot); err != nil { + fieldId = 17 + goto WriteFieldError + } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } + if err = p.writeField101(oprot); err != nil { + fieldId = 101 + goto WriteFieldError + } + if err = p.writeField102(oprot); err != nil { + fieldId = 102 + goto WriteFieldError + } + if err = p.writeField103(oprot); err != nil { + fieldId = 103 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *Model) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetModelID() { + if err = oprot.WriteFieldBegin("model_id", thrift.I64, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.ModelID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *Model) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetWorkspaceID() { + if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *Model) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetName() { + if err = oprot.WriteFieldBegin("name", thrift.STRING, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Name); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *Model) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetDesc() { + if err = oprot.WriteFieldBegin("desc", thrift.STRING, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Desc); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *Model) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetAbility() { + if err = oprot.WriteFieldBegin("ability", thrift.STRUCT, 5); err != nil { + goto WriteFieldBeginError + } + if err := p.Ability.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} +func (p *Model) writeField6(oprot thrift.TProtocol) (err error) { + if p.IsSetProtocol() { + if err = oprot.WriteFieldBegin("protocol", thrift.STRING, 6); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Protocol); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) +} +func (p *Model) writeField7(oprot thrift.TProtocol) (err error) { + if p.IsSetProtocolConfig() { + if err = oprot.WriteFieldBegin("protocol_config", thrift.STRUCT, 7); err != nil { + goto WriteFieldBeginError + } + if err := p.ProtocolConfig.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) +} +func (p *Model) writeField8(oprot thrift.TProtocol) (err error) { + if p.IsSetScenarioConfigs() { + if err = oprot.WriteFieldBegin("scenario_configs", thrift.MAP, 8); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRUCT, len(p.ScenarioConfigs)); err != nil { + return err + } + for k, v := range p.ScenarioConfigs { + if err := oprot.WriteString(k); err != nil { + return err + } + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteMapEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) +} +func (p *Model) writeField9(oprot thrift.TProtocol) (err error) { + if p.IsSetParamConfig() { + if err = oprot.WriteFieldBegin("param_config", thrift.STRUCT, 9); err != nil { + goto WriteFieldBeginError + } + if err := p.ParamConfig.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) +} +func (p *Model) writeField10(oprot thrift.TProtocol) (err error) { + if p.IsSetIdentification() { + if err = oprot.WriteFieldBegin("identification", thrift.STRING, 10); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Identification); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 10 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) +} +func (p *Model) writeField11(oprot thrift.TProtocol) (err error) { + if p.IsSetSeries() { + if err = oprot.WriteFieldBegin("series", thrift.STRUCT, 11); err != nil { + goto WriteFieldBeginError + } + if err := p.Series.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) +} +func (p *Model) writeField12(oprot thrift.TProtocol) (err error) { + if p.IsSetVisibility() { + if err = oprot.WriteFieldBegin("visibility", thrift.STRUCT, 12); err != nil { + goto WriteFieldBeginError + } + if err := p.Visibility.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 12 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 12 end error: ", p), err) +} +func (p *Model) writeField13(oprot thrift.TProtocol) (err error) { + if p.IsSetIcon() { + if err = oprot.WriteFieldBegin("icon", thrift.STRING, 13); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Icon); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 13 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) +} +func (p *Model) writeField14(oprot thrift.TProtocol) (err error) { + if p.IsSetTags() { + if err = oprot.WriteFieldBegin("tags", thrift.LIST, 14); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.Tags)); err != nil { + return err + } + for _, v := range p.Tags { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 14 end error: ", p), err) +} +func (p *Model) writeField15(oprot thrift.TProtocol) (err error) { + if p.IsSetStatus() { + if err = oprot.WriteFieldBegin("status", thrift.STRING, 15); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Status); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 15 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 15 end error: ", p), err) +} +func (p *Model) writeField16(oprot thrift.TProtocol) (err error) { + if p.IsSetOriginalModelURL() { + if err = oprot.WriteFieldBegin("original_model_url", thrift.STRING, 16); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.OriginalModelURL); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 16 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 16 end error: ", p), err) +} +func (p *Model) writeField17(oprot thrift.TProtocol) (err error) { + if p.IsSetPresetModel() { + if err = oprot.WriteFieldBegin("preset_model", thrift.BOOL, 17); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.PresetModel); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 17 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 17 end error: ", p), err) +} +func (p *Model) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetCreatedBy() { + if err = oprot.WriteFieldBegin("created_by", thrift.STRING, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.CreatedBy); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} +func (p *Model) writeField101(oprot thrift.TProtocol) (err error) { + if p.IsSetCreatedAt() { + if err = oprot.WriteFieldBegin("created_at", thrift.I64, 101); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.CreatedAt); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 101 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 101 end error: ", p), err) +} +func (p *Model) writeField102(oprot thrift.TProtocol) (err error) { + if p.IsSetUpdatedBy() { + if err = oprot.WriteFieldBegin("updated_by", thrift.STRING, 102); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.UpdatedBy); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 102 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 102 end error: ", p), err) +} +func (p *Model) writeField103(oprot thrift.TProtocol) (err error) { + if p.IsSetUpdatedAt() { + if err = oprot.WriteFieldBegin("updated_at", thrift.I64, 103); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI64(*p.UpdatedAt); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 103 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 103 end error: ", p), err) +} + +func (p *Model) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("Model(%+v)", *p) + +} + +func (p *Model) DeepEqual(ano *Model) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.ModelID) { + return false + } + if !p.Field2DeepEqual(ano.WorkspaceID) { + return false + } + if !p.Field3DeepEqual(ano.Name) { + return false + } + if !p.Field4DeepEqual(ano.Desc) { + return false + } + if !p.Field5DeepEqual(ano.Ability) { + return false + } + if !p.Field6DeepEqual(ano.Protocol) { + return false + } + if !p.Field7DeepEqual(ano.ProtocolConfig) { + return false + } + if !p.Field8DeepEqual(ano.ScenarioConfigs) { + return false + } + if !p.Field9DeepEqual(ano.ParamConfig) { + return false + } + if !p.Field10DeepEqual(ano.Identification) { + return false + } + if !p.Field11DeepEqual(ano.Series) { + return false + } + if !p.Field12DeepEqual(ano.Visibility) { + return false + } + if !p.Field13DeepEqual(ano.Icon) { + return false + } + if !p.Field14DeepEqual(ano.Tags) { + return false + } + if !p.Field15DeepEqual(ano.Status) { + return false + } + if !p.Field16DeepEqual(ano.OriginalModelURL) { + return false + } + if !p.Field17DeepEqual(ano.PresetModel) { + return false + } + if !p.Field100DeepEqual(ano.CreatedBy) { + return false + } + if !p.Field101DeepEqual(ano.CreatedAt) { + return false + } + if !p.Field102DeepEqual(ano.UpdatedBy) { + return false + } + if !p.Field103DeepEqual(ano.UpdatedAt) { + return false + } + return true +} + +func (p *Model) Field1DeepEqual(src *int64) bool { + + if p.ModelID == src { + return true + } else if p.ModelID == nil || src == nil { + return false + } + if *p.ModelID != *src { + return false + } + return true +} +func (p *Model) Field2DeepEqual(src *int64) bool { + + if p.WorkspaceID == src { + return true + } else if p.WorkspaceID == nil || src == nil { + return false + } + if *p.WorkspaceID != *src { + return false + } + return true +} +func (p *Model) Field3DeepEqual(src *string) bool { + + if p.Name == src { + return true + } else if p.Name == nil || src == nil { + return false + } + if strings.Compare(*p.Name, *src) != 0 { + return false + } + return true +} +func (p *Model) Field4DeepEqual(src *string) bool { + + if p.Desc == src { + return true + } else if p.Desc == nil || src == nil { + return false + } + if strings.Compare(*p.Desc, *src) != 0 { + return false + } + return true +} +func (p *Model) Field5DeepEqual(src *Ability) bool { + + if !p.Ability.DeepEqual(src) { + return false + } + return true +} +func (p *Model) Field6DeepEqual(src *Protocol) bool { + + if p.Protocol == src { + return true + } else if p.Protocol == nil || src == nil { + return false + } + if strings.Compare(*p.Protocol, *src) != 0 { + return false + } + return true +} +func (p *Model) Field7DeepEqual(src *ProtocolConfig) bool { + + if !p.ProtocolConfig.DeepEqual(src) { + return false + } + return true +} +func (p *Model) Field8DeepEqual(src map[common.Scenario]*ScenarioConfig) bool { + + if len(p.ScenarioConfigs) != len(src) { + return false + } + for k, v := range p.ScenarioConfigs { + _src := src[k] + if !v.DeepEqual(_src) { + return false + } + } + return true +} +func (p *Model) Field9DeepEqual(src *ParamConfig) bool { + + if !p.ParamConfig.DeepEqual(src) { + return false + } + return true +} +func (p *Model) Field10DeepEqual(src *string) bool { + + if p.Identification == src { + return true + } else if p.Identification == nil || src == nil { + return false + } + if strings.Compare(*p.Identification, *src) != 0 { + return false + } + return true +} +func (p *Model) Field11DeepEqual(src *Series) bool { + + if !p.Series.DeepEqual(src) { + return false + } + return true +} +func (p *Model) Field12DeepEqual(src *Visibility) bool { + + if !p.Visibility.DeepEqual(src) { + return false + } + return true +} +func (p *Model) Field13DeepEqual(src *string) bool { + + if p.Icon == src { + return true + } else if p.Icon == nil || src == nil { + return false + } + if strings.Compare(*p.Icon, *src) != 0 { + return false + } + return true +} +func (p *Model) Field14DeepEqual(src []string) bool { + + if len(p.Tags) != len(src) { + return false + } + for i, v := range p.Tags { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} +func (p *Model) Field15DeepEqual(src *ModelStatus) bool { + + if p.Status == src { + return true + } else if p.Status == nil || src == nil { + return false + } + if strings.Compare(*p.Status, *src) != 0 { + return false + } + return true +} +func (p *Model) Field16DeepEqual(src *string) bool { + + if p.OriginalModelURL == src { + return true + } else if p.OriginalModelURL == nil || src == nil { + return false + } + if strings.Compare(*p.OriginalModelURL, *src) != 0 { + return false + } + return true +} +func (p *Model) Field17DeepEqual(src *bool) bool { + + if p.PresetModel == src { + return true + } else if p.PresetModel == nil || src == nil { + return false + } + if *p.PresetModel != *src { + return false + } + return true +} +func (p *Model) Field100DeepEqual(src *string) bool { + + if p.CreatedBy == src { + return true + } else if p.CreatedBy == nil || src == nil { + return false + } + if strings.Compare(*p.CreatedBy, *src) != 0 { + return false + } + return true +} +func (p *Model) Field101DeepEqual(src *int64) bool { + + if p.CreatedAt == src { + return true + } else if p.CreatedAt == nil || src == nil { + return false + } + if *p.CreatedAt != *src { + return false + } + return true +} +func (p *Model) Field102DeepEqual(src *string) bool { + + if p.UpdatedBy == src { + return true + } else if p.UpdatedBy == nil || src == nil { + return false + } + if strings.Compare(*p.UpdatedBy, *src) != 0 { + return false + } + return true +} +func (p *Model) Field103DeepEqual(src *int64) bool { + + if p.UpdatedAt == src { + return true + } else if p.UpdatedAt == nil || src == nil { + return false + } + if *p.UpdatedAt != *src { + return false + } + return true +} + +type Series struct { + // series name + Name *string `thrift:"name,1,optional" frugal:"1,optional,string" form:"name" json:"name,omitempty" query:"name"` + // series icon url + Icon *string `thrift:"icon,2,optional" frugal:"2,optional,string" form:"icon" json:"icon,omitempty" query:"icon"` + // family name + Family *Family `thrift:"family,3,optional" frugal:"3,optional,string" form:"family" json:"family,omitempty" query:"family"` +} + +func NewSeries() *Series { + return &Series{} +} + +func (p *Series) InitDefault() { +} + +var Series_Name_DEFAULT string + +func (p *Series) GetName() (v string) { + if p == nil { + return + } + if !p.IsSetName() { + return Series_Name_DEFAULT + } + return *p.Name +} + +var Series_Icon_DEFAULT string + +func (p *Series) GetIcon() (v string) { + if p == nil { + return + } + if !p.IsSetIcon() { + return Series_Icon_DEFAULT + } + return *p.Icon +} + +var Series_Family_DEFAULT Family + +func (p *Series) GetFamily() (v Family) { + if p == nil { + return + } + if !p.IsSetFamily() { + return Series_Family_DEFAULT + } + return *p.Family +} +func (p *Series) SetName(val *string) { + p.Name = val +} +func (p *Series) SetIcon(val *string) { + p.Icon = val +} +func (p *Series) SetFamily(val *Family) { + p.Family = val +} + +var fieldIDToName_Series = map[int16]string{ + 1: "name", + 2: "icon", + 3: "family", +} + +func (p *Series) IsSetName() bool { + return p.Name != nil +} + +func (p *Series) IsSetIcon() bool { + return p.Icon != nil +} + +func (p *Series) IsSetFamily() bool { + return p.Family != nil +} + +func (p *Series) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRING { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Series[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Series) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Name = _field + return nil +} +func (p *Series) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Icon = _field + return nil +} +func (p *Series) ReadField3(iprot thrift.TProtocol) error { + + var _field *Family + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Family = _field + return nil +} + +func (p *Series) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("Series"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *Series) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetName() { + if err = oprot.WriteFieldBegin("name", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Name); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *Series) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetIcon() { + if err = oprot.WriteFieldBegin("icon", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Icon); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *Series) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetFamily() { + if err = oprot.WriteFieldBegin("family", thrift.STRING, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Family); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} + +func (p *Series) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("Series(%+v)", *p) + +} + +func (p *Series) DeepEqual(ano *Series) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Name) { + return false + } + if !p.Field2DeepEqual(ano.Icon) { + return false + } + if !p.Field3DeepEqual(ano.Family) { + return false + } + return true +} + +func (p *Series) Field1DeepEqual(src *string) bool { + + if p.Name == src { + return true + } else if p.Name == nil || src == nil { + return false + } + if strings.Compare(*p.Name, *src) != 0 { + return false + } + return true +} +func (p *Series) Field2DeepEqual(src *string) bool { + + if p.Icon == src { + return true + } else if p.Icon == nil || src == nil { + return false + } + if strings.Compare(*p.Icon, *src) != 0 { + return false + } + return true +} +func (p *Series) Field3DeepEqual(src *Family) bool { + + if p.Family == src { + return true + } else if p.Family == nil || src == nil { + return false + } + if strings.Compare(*p.Family, *src) != 0 { + return false + } + return true +} + +type Visibility struct { + Mode *VisibleMode `thrift:"mode,1,optional" frugal:"1,optional,string" form:"mode" json:"mode,omitempty" query:"mode"` + // Mode为Specified有效,配置为除模型所属空间外的其他空间 + SpaceIDs []int64 `thrift:"spaceIDs,2,optional" frugal:"2,optional,list" form:"spaceIDs" json:"spaceIDs,omitempty" query:"spaceIDs"` +} + +func NewVisibility() *Visibility { + return &Visibility{} +} + +func (p *Visibility) InitDefault() { +} + +var Visibility_Mode_DEFAULT VisibleMode + +func (p *Visibility) GetMode() (v VisibleMode) { + if p == nil { + return + } + if !p.IsSetMode() { + return Visibility_Mode_DEFAULT + } + return *p.Mode +} + +var Visibility_SpaceIDs_DEFAULT []int64 + +func (p *Visibility) GetSpaceIDs() (v []int64) { + if p == nil { + return + } + if !p.IsSetSpaceIDs() { + return Visibility_SpaceIDs_DEFAULT + } + return p.SpaceIDs +} +func (p *Visibility) SetMode(val *VisibleMode) { + p.Mode = val +} +func (p *Visibility) SetSpaceIDs(val []int64) { + p.SpaceIDs = val +} + +var fieldIDToName_Visibility = map[int16]string{ + 1: "mode", + 2: "spaceIDs", +} + +func (p *Visibility) IsSetMode() bool { + return p.Mode != nil +} + +func (p *Visibility) IsSetSpaceIDs() bool { + return p.SpaceIDs != nil +} + +func (p *Visibility) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.LIST { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Visibility[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Visibility) ReadField1(iprot thrift.TProtocol) error { + + var _field *VisibleMode + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Mode = _field + return nil +} +func (p *Visibility) ReadField2(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]int64, 0, size) + for i := 0; i < size; i++ { + + var _elem int64 + if v, err := iprot.ReadI64(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.SpaceIDs = _field + return nil +} + +func (p *Visibility) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("Visibility"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *Visibility) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetMode() { + if err = oprot.WriteFieldBegin("mode", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Mode); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *Visibility) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetSpaceIDs() { + if err = oprot.WriteFieldBegin("spaceIDs", thrift.LIST, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.I64, len(p.SpaceIDs)); err != nil { + return err + } + for _, v := range p.SpaceIDs { + if err := oprot.WriteI64(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *Visibility) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("Visibility(%+v)", *p) + +} + +func (p *Visibility) DeepEqual(ano *Visibility) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Mode) { + return false + } + if !p.Field2DeepEqual(ano.SpaceIDs) { + return false + } + return true +} + +func (p *Visibility) Field1DeepEqual(src *VisibleMode) bool { + + if p.Mode == src { + return true + } else if p.Mode == nil || src == nil { + return false + } + if strings.Compare(*p.Mode, *src) != 0 { + return false + } + return true +} +func (p *Visibility) Field2DeepEqual(src []int64) bool { + + if len(p.SpaceIDs) != len(src) { + return false + } + for i, v := range p.SpaceIDs { + _src := src[i] + if v != _src { + return false + } + } + return true +} + +type ProviderInfo struct { + MaasInfo *MaaSInfo `thrift:"maas_info,1,optional" frugal:"1,optional,MaaSInfo" form:"maas_info" json:"maas_info,omitempty" query:"maas_info"` +} + +func NewProviderInfo() *ProviderInfo { + return &ProviderInfo{} +} + +func (p *ProviderInfo) InitDefault() { +} + +var ProviderInfo_MaasInfo_DEFAULT *MaaSInfo + +func (p *ProviderInfo) GetMaasInfo() (v *MaaSInfo) { + if p == nil { + return + } + if !p.IsSetMaasInfo() { + return ProviderInfo_MaasInfo_DEFAULT + } + return p.MaasInfo +} +func (p *ProviderInfo) SetMaasInfo(val *MaaSInfo) { + p.MaasInfo = val +} + +var fieldIDToName_ProviderInfo = map[int16]string{ + 1: "maas_info", +} + +func (p *ProviderInfo) IsSetMaasInfo() bool { + return p.MaasInfo != nil +} + +func (p *ProviderInfo) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ProviderInfo[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ProviderInfo) ReadField1(iprot thrift.TProtocol) error { + _field := NewMaaSInfo() + if err := _field.Read(iprot); err != nil { + return err + } + p.MaasInfo = _field + return nil +} + +func (p *ProviderInfo) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ProviderInfo"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ProviderInfo) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetMaasInfo() { + if err = oprot.WriteFieldBegin("maas_info", thrift.STRUCT, 1); err != nil { + goto WriteFieldBeginError + } + if err := p.MaasInfo.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} + +func (p *ProviderInfo) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ProviderInfo(%+v)", *p) + +} + +func (p *ProviderInfo) DeepEqual(ano *ProviderInfo) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.MaasInfo) { + return false + } + return true +} + +func (p *ProviderInfo) Field1DeepEqual(src *MaaSInfo) bool { + + if !p.MaasInfo.DeepEqual(src) { + return false + } + return true +} + +type MaaSInfo struct { + Host *string `thrift:"host,1,optional" frugal:"1,optional,string" form:"host" json:"host,omitempty" query:"host"` + Region *string `thrift:"region,2,optional" frugal:"2,optional,string" form:"region" json:"region,omitempty" query:"region"` + // v3 sdk + BaseURL *string `thrift:"baseURL,3,optional" frugal:"3,optional,string" form:"baseURL" json:"baseURL,omitempty" query:"baseURL"` + // 精调模型任务的 ID + CustomizationJobsID *string `thrift:"customizationJobsID,4,optional" frugal:"4,optional,string" form:"customizationJobsID" json:"customizationJobsID,omitempty" query:"customizationJobsID"` +} + +func NewMaaSInfo() *MaaSInfo { + return &MaaSInfo{} +} + +func (p *MaaSInfo) InitDefault() { +} + +var MaaSInfo_Host_DEFAULT string + +func (p *MaaSInfo) GetHost() (v string) { + if p == nil { + return + } + if !p.IsSetHost() { + return MaaSInfo_Host_DEFAULT + } + return *p.Host +} + +var MaaSInfo_Region_DEFAULT string + +func (p *MaaSInfo) GetRegion() (v string) { + if p == nil { + return + } + if !p.IsSetRegion() { + return MaaSInfo_Region_DEFAULT + } + return *p.Region +} + +var MaaSInfo_BaseURL_DEFAULT string + +func (p *MaaSInfo) GetBaseURL() (v string) { + if p == nil { + return + } + if !p.IsSetBaseURL() { + return MaaSInfo_BaseURL_DEFAULT + } + return *p.BaseURL +} + +var MaaSInfo_CustomizationJobsID_DEFAULT string + +func (p *MaaSInfo) GetCustomizationJobsID() (v string) { + if p == nil { + return + } + if !p.IsSetCustomizationJobsID() { + return MaaSInfo_CustomizationJobsID_DEFAULT + } + return *p.CustomizationJobsID +} +func (p *MaaSInfo) SetHost(val *string) { + p.Host = val +} +func (p *MaaSInfo) SetRegion(val *string) { + p.Region = val +} +func (p *MaaSInfo) SetBaseURL(val *string) { + p.BaseURL = val +} +func (p *MaaSInfo) SetCustomizationJobsID(val *string) { + p.CustomizationJobsID = val +} + +var fieldIDToName_MaaSInfo = map[int16]string{ + 1: "host", + 2: "region", + 3: "baseURL", + 4: "customizationJobsID", +} + +func (p *MaaSInfo) IsSetHost() bool { + return p.Host != nil +} + +func (p *MaaSInfo) IsSetRegion() bool { + return p.Region != nil +} + +func (p *MaaSInfo) IsSetBaseURL() bool { + return p.BaseURL != nil +} + +func (p *MaaSInfo) IsSetCustomizationJobsID() bool { + return p.CustomizationJobsID != nil +} + +func (p *MaaSInfo) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRING { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRING { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_MaaSInfo[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *MaaSInfo) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Host = _field + return nil +} +func (p *MaaSInfo) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Region = _field return nil } -func (p *Model) ReadField8(iprot thrift.TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() - if err != nil { - return err - } - _field := make(map[common.Scenario]*ScenarioConfig, size) - values := make([]ScenarioConfig, size) - for i := 0; i < size; i++ { - var _key common.Scenario - if v, err := iprot.ReadString(); err != nil { - return err - } else { - _key = v - } +func (p *MaaSInfo) ReadField3(iprot thrift.TProtocol) error { - _val := &values[i] - _val.InitDefault() - if err := _val.Read(iprot); err != nil { - return err - } - - _field[_key] = _val - } - if err := iprot.ReadMapEnd(); err != nil { + var _field *string + if v, err := iprot.ReadString(); err != nil { return err + } else { + _field = &v } - p.ScenarioConfigs = _field + p.BaseURL = _field return nil } -func (p *Model) ReadField9(iprot thrift.TProtocol) error { - _field := NewParamConfig() - if err := _field.Read(iprot); err != nil { +func (p *MaaSInfo) ReadField4(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { return err + } else { + _field = &v } - p.ParamConfig = _field + p.CustomizationJobsID = _field return nil } -func (p *Model) Write(oprot thrift.TProtocol) (err error) { +func (p *MaaSInfo) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 - if err = oprot.WriteStructBegin("Model"); err != nil { + if err = oprot.WriteStructBegin("MaaSInfo"); err != nil { goto WriteStructBeginError } if p != nil { @@ -525,26 +2966,6 @@ func (p *Model) Write(oprot thrift.TProtocol) (err error) { fieldId = 4 goto WriteFieldError } - if err = p.writeField5(oprot); err != nil { - fieldId = 5 - goto WriteFieldError - } - if err = p.writeField6(oprot); err != nil { - fieldId = 6 - goto WriteFieldError - } - if err = p.writeField7(oprot); err != nil { - fieldId = 7 - goto WriteFieldError - } - if err = p.writeField8(oprot); err != nil { - fieldId = 8 - goto WriteFieldError - } - if err = p.writeField9(oprot); err != nil { - fieldId = 9 - goto WriteFieldError - } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -563,12 +2984,12 @@ WriteStructEndError: return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } -func (p *Model) writeField1(oprot thrift.TProtocol) (err error) { - if p.IsSetModelID() { - if err = oprot.WriteFieldBegin("model_id", thrift.I64, 1); err != nil { +func (p *MaaSInfo) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetHost() { + if err = oprot.WriteFieldBegin("host", thrift.STRING, 1); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteI64(*p.ModelID); err != nil { + if err := oprot.WriteString(*p.Host); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -581,12 +3002,12 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } -func (p *Model) writeField2(oprot thrift.TProtocol) (err error) { - if p.IsSetWorkspaceID() { - if err = oprot.WriteFieldBegin("workspace_id", thrift.I64, 2); err != nil { +func (p *MaaSInfo) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetRegion() { + if err = oprot.WriteFieldBegin("region", thrift.STRING, 2); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteI64(*p.WorkspaceID); err != nil { + if err := oprot.WriteString(*p.Region); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -599,12 +3020,12 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } -func (p *Model) writeField3(oprot thrift.TProtocol) (err error) { - if p.IsSetName() { - if err = oprot.WriteFieldBegin("name", thrift.STRING, 3); err != nil { +func (p *MaaSInfo) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetBaseURL() { + if err = oprot.WriteFieldBegin("baseURL", thrift.STRING, 3); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteString(*p.Name); err != nil { + if err := oprot.WriteString(*p.BaseURL); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -617,12 +3038,12 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } -func (p *Model) writeField4(oprot thrift.TProtocol) (err error) { - if p.IsSetDesc() { - if err = oprot.WriteFieldBegin("desc", thrift.STRING, 4); err != nil { +func (p *MaaSInfo) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetCustomizationJobsID() { + if err = oprot.WriteFieldBegin("customizationJobsID", thrift.STRING, 4); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteString(*p.Desc); err != nil { + if err := oprot.WriteString(*p.CustomizationJobsID); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -635,242 +3056,80 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) } -func (p *Model) writeField5(oprot thrift.TProtocol) (err error) { - if p.IsSetAbility() { - if err = oprot.WriteFieldBegin("ability", thrift.STRUCT, 5); err != nil { - goto WriteFieldBeginError - } - if err := p.Ability.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) -} -func (p *Model) writeField6(oprot thrift.TProtocol) (err error) { - if p.IsSetProtocol() { - if err = oprot.WriteFieldBegin("protocol", thrift.STRING, 6); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteString(*p.Protocol); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) -} -func (p *Model) writeField7(oprot thrift.TProtocol) (err error) { - if p.IsSetProtocolConfig() { - if err = oprot.WriteFieldBegin("protocol_config", thrift.STRUCT, 7); err != nil { - goto WriteFieldBeginError - } - if err := p.ProtocolConfig.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) -} -func (p *Model) writeField8(oprot thrift.TProtocol) (err error) { - if p.IsSetScenarioConfigs() { - if err = oprot.WriteFieldBegin("scenario_configs", thrift.MAP, 8); err != nil { - goto WriteFieldBeginError - } - if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRUCT, len(p.ScenarioConfigs)); err != nil { - return err - } - for k, v := range p.ScenarioConfigs { - if err := oprot.WriteString(k); err != nil { - return err - } - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteMapEnd(); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) -} -func (p *Model) writeField9(oprot thrift.TProtocol) (err error) { - if p.IsSetParamConfig() { - if err = oprot.WriteFieldBegin("param_config", thrift.STRUCT, 9); err != nil { - goto WriteFieldBeginError - } - if err := p.ParamConfig.Write(oprot); err != nil { - return err - } - if err = oprot.WriteFieldEnd(); err != nil { - goto WriteFieldEndError - } - } - return nil -WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) -WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) -} -func (p *Model) String() string { +func (p *MaaSInfo) String() string { if p == nil { return "" } - return fmt.Sprintf("Model(%+v)", *p) + return fmt.Sprintf("MaaSInfo(%+v)", *p) } -func (p *Model) DeepEqual(ano *Model) bool { +func (p *MaaSInfo) DeepEqual(ano *MaaSInfo) bool { if p == ano { return true } else if p == nil || ano == nil { return false } - if !p.Field1DeepEqual(ano.ModelID) { - return false - } - if !p.Field2DeepEqual(ano.WorkspaceID) { - return false - } - if !p.Field3DeepEqual(ano.Name) { - return false - } - if !p.Field4DeepEqual(ano.Desc) { - return false - } - if !p.Field5DeepEqual(ano.Ability) { - return false - } - if !p.Field6DeepEqual(ano.Protocol) { - return false - } - if !p.Field7DeepEqual(ano.ProtocolConfig) { - return false - } - if !p.Field8DeepEqual(ano.ScenarioConfigs) { - return false - } - if !p.Field9DeepEqual(ano.ParamConfig) { - return false - } - return true -} - -func (p *Model) Field1DeepEqual(src *int64) bool { - - if p.ModelID == src { - return true - } else if p.ModelID == nil || src == nil { - return false - } - if *p.ModelID != *src { + if !p.Field1DeepEqual(ano.Host) { return false } - return true -} -func (p *Model) Field2DeepEqual(src *int64) bool { - - if p.WorkspaceID == src { - return true - } else if p.WorkspaceID == nil || src == nil { - return false - } - if *p.WorkspaceID != *src { - return false - } - return true -} -func (p *Model) Field3DeepEqual(src *string) bool { - - if p.Name == src { - return true - } else if p.Name == nil || src == nil { - return false - } - if strings.Compare(*p.Name, *src) != 0 { - return false - } - return true -} -func (p *Model) Field4DeepEqual(src *string) bool { - - if p.Desc == src { - return true - } else if p.Desc == nil || src == nil { + if !p.Field2DeepEqual(ano.Region) { return false } - if strings.Compare(*p.Desc, *src) != 0 { + if !p.Field3DeepEqual(ano.BaseURL) { return false } - return true -} -func (p *Model) Field5DeepEqual(src *Ability) bool { - - if !p.Ability.DeepEqual(src) { + if !p.Field4DeepEqual(ano.CustomizationJobsID) { return false } return true } -func (p *Model) Field6DeepEqual(src *Protocol) bool { - if p.Protocol == src { +func (p *MaaSInfo) Field1DeepEqual(src *string) bool { + + if p.Host == src { return true - } else if p.Protocol == nil || src == nil { + } else if p.Host == nil || src == nil { return false } - if strings.Compare(*p.Protocol, *src) != 0 { + if strings.Compare(*p.Host, *src) != 0 { return false } return true } -func (p *Model) Field7DeepEqual(src *ProtocolConfig) bool { +func (p *MaaSInfo) Field2DeepEqual(src *string) bool { - if !p.ProtocolConfig.DeepEqual(src) { + if p.Region == src { + return true + } else if p.Region == nil || src == nil { + return false + } + if strings.Compare(*p.Region, *src) != 0 { return false } return true } -func (p *Model) Field8DeepEqual(src map[common.Scenario]*ScenarioConfig) bool { +func (p *MaaSInfo) Field3DeepEqual(src *string) bool { - if len(p.ScenarioConfigs) != len(src) { + if p.BaseURL == src { + return true + } else if p.BaseURL == nil || src == nil { return false } - for k, v := range p.ScenarioConfigs { - _src := src[k] - if !v.DeepEqual(_src) { - return false - } + if strings.Compare(*p.BaseURL, *src) != 0 { + return false } return true } -func (p *Model) Field9DeepEqual(src *ParamConfig) bool { +func (p *MaaSInfo) Field4DeepEqual(src *string) bool { - if !p.ParamConfig.DeepEqual(src) { + if p.CustomizationJobsID == src { + return true + } else if p.CustomizationJobsID == nil || src == nil { + return false + } + if strings.Compare(*p.CustomizationJobsID, *src) != 0 { return false } return true @@ -884,6 +3143,7 @@ type Ability struct { JSONMode *bool `thrift:"json_mode,5,optional" frugal:"5,optional,bool" form:"json_mode" json:"json_mode,omitempty" query:"json_mode"` MultiModal *bool `thrift:"multi_modal,6,optional" frugal:"6,optional,bool" form:"multi_modal" json:"multi_modal,omitempty" query:"multi_modal"` AbilityMultiModal *AbilityMultiModal `thrift:"ability_multi_modal,7,optional" frugal:"7,optional,AbilityMultiModal" form:"ability_multi_modal" json:"ability_multi_modal,omitempty" query:"ability_multi_modal"` + InterfaceCategory *InterfaceCategory `thrift:"interface_category,8,optional" frugal:"8,optional,string" form:"interface_category" json:"interface_category,omitempty" query:"interface_category"` } func NewAbility() *Ability { @@ -976,6 +3236,18 @@ func (p *Ability) GetAbilityMultiModal() (v *AbilityMultiModal) { } return p.AbilityMultiModal } + +var Ability_InterfaceCategory_DEFAULT InterfaceCategory + +func (p *Ability) GetInterfaceCategory() (v InterfaceCategory) { + if p == nil { + return + } + if !p.IsSetInterfaceCategory() { + return Ability_InterfaceCategory_DEFAULT + } + return *p.InterfaceCategory +} func (p *Ability) SetMaxContextTokens(val *int64) { p.MaxContextTokens = val } @@ -997,6 +3269,9 @@ func (p *Ability) SetMultiModal(val *bool) { func (p *Ability) SetAbilityMultiModal(val *AbilityMultiModal) { p.AbilityMultiModal = val } +func (p *Ability) SetInterfaceCategory(val *InterfaceCategory) { + p.InterfaceCategory = val +} var fieldIDToName_Ability = map[int16]string{ 1: "max_context_tokens", @@ -1006,6 +3281,7 @@ var fieldIDToName_Ability = map[int16]string{ 5: "json_mode", 6: "multi_modal", 7: "ability_multi_modal", + 8: "interface_category", } func (p *Ability) IsSetMaxContextTokens() bool { @@ -1036,6 +3312,10 @@ func (p *Ability) IsSetAbilityMultiModal() bool { return p.AbilityMultiModal != nil } +func (p *Ability) IsSetInterfaceCategory() bool { + return p.InterfaceCategory != nil +} + func (p *Ability) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -1110,6 +3390,14 @@ func (p *Ability) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 8: + if fieldTypeId == thrift.STRING { + if err = p.ReadField8(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -1213,6 +3501,17 @@ func (p *Ability) ReadField7(iprot thrift.TProtocol) error { p.AbilityMultiModal = _field return nil } +func (p *Ability) ReadField8(iprot thrift.TProtocol) error { + + var _field *InterfaceCategory + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.InterfaceCategory = _field + return nil +} func (p *Ability) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -1248,6 +3547,10 @@ func (p *Ability) Write(oprot thrift.TProtocol) (err error) { fieldId = 7 goto WriteFieldError } + if err = p.writeField8(oprot); err != nil { + fieldId = 8 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -1392,6 +3695,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) } +func (p *Ability) writeField8(oprot thrift.TProtocol) (err error) { + if p.IsSetInterfaceCategory() { + if err = oprot.WriteFieldBegin("interface_category", thrift.STRING, 8); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.InterfaceCategory); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) +} func (p *Ability) String() string { if p == nil { @@ -1428,6 +3749,9 @@ func (p *Ability) DeepEqual(ano *Ability) bool { if !p.Field7DeepEqual(ano.AbilityMultiModal) { return false } + if !p.Field8DeepEqual(ano.InterfaceCategory) { + return false + } return true } @@ -1510,10 +3834,24 @@ func (p *Ability) Field7DeepEqual(src *AbilityMultiModal) bool { } return true } +func (p *Ability) Field8DeepEqual(src *InterfaceCategory) bool { + + if p.InterfaceCategory == src { + return true + } else if p.InterfaceCategory == nil || src == nil { + return false + } + if strings.Compare(*p.InterfaceCategory, *src) != 0 { + return false + } + return true +} type AbilityMultiModal struct { + // 图片 Image *bool `thrift:"image,1,optional" frugal:"1,optional,bool" form:"image" json:"image,omitempty" query:"image"` AbilityImage *AbilityImage `thrift:"ability_image,2,optional" frugal:"2,optional,AbilityImage" form:"ability_image" json:"ability_image,omitempty" query:"ability_image"` + // 视频 Video *bool `thrift:"video,3,optional" frugal:"3,optional,bool" form:"video" json:"video,omitempty" query:"video"` AbilityVideo *AbilityVideo `thrift:"ability_video,4,optional" frugal:"4,optional,AbilityVideo" form:"ability_video" json:"ability_video,omitempty" query:"ability_video"` } @@ -7991,6 +10329,11 @@ type ParamSchema struct { Max *string `thrift:"max,6,optional" frugal:"6,optional,string" form:"max" json:"max,omitempty" query:"max"` DefaultValue *string `thrift:"default_value,7,optional" frugal:"7,optional,string" form:"default_value" json:"default_value,omitempty" query:"default_value"` Options []*ParamOption `thrift:"options,8,optional" frugal:"8,optional,list" form:"options" json:"options,omitempty" query:"options"` + Properties []*ParamSchema `thrift:"properties,9,optional" frugal:"9,optional,list" form:"properties" json:"properties,omitempty" query:"properties"` + // 依赖参数 + Reaction *Reaction `thrift:"reaction,10,optional" frugal:"10,optional,Reaction" form:"reaction" json:"reaction,omitempty" query:"reaction"` + // 赋值路径 + Jsonpath *string `thrift:"jsonpath,11,optional" frugal:"11,optional,string" form:"jsonpath" json:"jsonpath,omitempty" query:"jsonpath"` } func NewParamSchema() *ParamSchema { @@ -8095,6 +10438,42 @@ func (p *ParamSchema) GetOptions() (v []*ParamOption) { } return p.Options } + +var ParamSchema_Properties_DEFAULT []*ParamSchema + +func (p *ParamSchema) GetProperties() (v []*ParamSchema) { + if p == nil { + return + } + if !p.IsSetProperties() { + return ParamSchema_Properties_DEFAULT + } + return p.Properties +} + +var ParamSchema_Reaction_DEFAULT *Reaction + +func (p *ParamSchema) GetReaction() (v *Reaction) { + if p == nil { + return + } + if !p.IsSetReaction() { + return ParamSchema_Reaction_DEFAULT + } + return p.Reaction +} + +var ParamSchema_Jsonpath_DEFAULT string + +func (p *ParamSchema) GetJsonpath() (v string) { + if p == nil { + return + } + if !p.IsSetJsonpath() { + return ParamSchema_Jsonpath_DEFAULT + } + return *p.Jsonpath +} func (p *ParamSchema) SetName(val *string) { p.Name = val } @@ -8119,16 +10498,28 @@ func (p *ParamSchema) SetDefaultValue(val *string) { func (p *ParamSchema) SetOptions(val []*ParamOption) { p.Options = val } +func (p *ParamSchema) SetProperties(val []*ParamSchema) { + p.Properties = val +} +func (p *ParamSchema) SetReaction(val *Reaction) { + p.Reaction = val +} +func (p *ParamSchema) SetJsonpath(val *string) { + p.Jsonpath = val +} var fieldIDToName_ParamSchema = map[int16]string{ - 1: "name", - 2: "label", - 3: "desc", - 4: "type", - 5: "min", - 6: "max", - 7: "default_value", - 8: "options", + 1: "name", + 2: "label", + 3: "desc", + 4: "type", + 5: "min", + 6: "max", + 7: "default_value", + 8: "options", + 9: "properties", + 10: "reaction", + 11: "jsonpath", } func (p *ParamSchema) IsSetName() bool { @@ -8163,6 +10554,18 @@ func (p *ParamSchema) IsSetOptions() bool { return p.Options != nil } +func (p *ParamSchema) IsSetProperties() bool { + return p.Properties != nil +} + +func (p *ParamSchema) IsSetReaction() bool { + return p.Reaction != nil +} + +func (p *ParamSchema) IsSetJsonpath() bool { + return p.Jsonpath != nil +} + func (p *ParamSchema) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -8245,6 +10648,30 @@ func (p *ParamSchema) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 9: + if fieldTypeId == thrift.LIST { + if err = p.ReadField9(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 10: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField10(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 11: + if fieldTypeId == thrift.STRING { + if err = p.ReadField11(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -8374,6 +10801,48 @@ func (p *ParamSchema) ReadField8(iprot thrift.TProtocol) error { p.Options = _field return nil } +func (p *ParamSchema) ReadField9(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]*ParamSchema, 0, size) + values := make([]ParamSchema, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + + if err := _elem.Read(iprot); err != nil { + return err + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Properties = _field + return nil +} +func (p *ParamSchema) ReadField10(iprot thrift.TProtocol) error { + _field := NewReaction() + if err := _field.Read(iprot); err != nil { + return err + } + p.Reaction = _field + return nil +} +func (p *ParamSchema) ReadField11(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Jsonpath = _field + return nil +} func (p *ParamSchema) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -8413,6 +10882,18 @@ func (p *ParamSchema) Write(oprot thrift.TProtocol) (err error) { fieldId = 8 goto WriteFieldError } + if err = p.writeField9(oprot); err != nil { + fieldId = 9 + goto WriteFieldError + } + if err = p.writeField10(oprot); err != nil { + fieldId = 10 + goto WriteFieldError + } + if err = p.writeField11(oprot); err != nil { + fieldId = 11 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -8539,12 +11020,82 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) } -func (p *ParamSchema) writeField7(oprot thrift.TProtocol) (err error) { - if p.IsSetDefaultValue() { - if err = oprot.WriteFieldBegin("default_value", thrift.STRING, 7); err != nil { +func (p *ParamSchema) writeField7(oprot thrift.TProtocol) (err error) { + if p.IsSetDefaultValue() { + if err = oprot.WriteFieldBegin("default_value", thrift.STRING, 7); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.DefaultValue); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) +} +func (p *ParamSchema) writeField8(oprot thrift.TProtocol) (err error) { + if p.IsSetOptions() { + if err = oprot.WriteFieldBegin("options", thrift.LIST, 8); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.Options)); err != nil { + return err + } + for _, v := range p.Options { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) +} +func (p *ParamSchema) writeField9(oprot thrift.TProtocol) (err error) { + if p.IsSetProperties() { + if err = oprot.WriteFieldBegin("properties", thrift.LIST, 9); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.Properties)); err != nil { + return err + } + for _, v := range p.Properties { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 9 end error: ", p), err) +} +func (p *ParamSchema) writeField10(oprot thrift.TProtocol) (err error) { + if p.IsSetReaction() { + if err = oprot.WriteFieldBegin("reaction", thrift.STRUCT, 10); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteString(*p.DefaultValue); err != nil { + if err := p.Reaction.Write(oprot); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -8553,24 +11104,16 @@ func (p *ParamSchema) writeField7(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 begin error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field 10 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 7 end error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) } -func (p *ParamSchema) writeField8(oprot thrift.TProtocol) (err error) { - if p.IsSetOptions() { - if err = oprot.WriteFieldBegin("options", thrift.LIST, 8); err != nil { +func (p *ParamSchema) writeField11(oprot thrift.TProtocol) (err error) { + if p.IsSetJsonpath() { + if err = oprot.WriteFieldBegin("jsonpath", thrift.STRING, 11); err != nil { goto WriteFieldBeginError } - if err := oprot.WriteListBegin(thrift.STRUCT, len(p.Options)); err != nil { - return err - } - for _, v := range p.Options { - if err := v.Write(oprot); err != nil { - return err - } - } - if err := oprot.WriteListEnd(); err != nil { + if err := oprot.WriteString(*p.Jsonpath); err != nil { return err } if err = oprot.WriteFieldEnd(); err != nil { @@ -8579,9 +11122,9 @@ func (p *ParamSchema) writeField8(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 8 begin error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) } func (p *ParamSchema) String() string { @@ -8622,6 +11165,15 @@ func (p *ParamSchema) DeepEqual(ano *ParamSchema) bool { if !p.Field8DeepEqual(ano.Options) { return false } + if !p.Field9DeepEqual(ano.Properties) { + return false + } + if !p.Field10DeepEqual(ano.Reaction) { + return false + } + if !p.Field11DeepEqual(ano.Jsonpath) { + return false + } return true } @@ -8722,6 +11274,298 @@ func (p *ParamSchema) Field8DeepEqual(src []*ParamOption) bool { } return true } +func (p *ParamSchema) Field9DeepEqual(src []*ParamSchema) bool { + + if len(p.Properties) != len(src) { + return false + } + for i, v := range p.Properties { + _src := src[i] + if !v.DeepEqual(_src) { + return false + } + } + return true +} +func (p *ParamSchema) Field10DeepEqual(src *Reaction) bool { + + if !p.Reaction.DeepEqual(src) { + return false + } + return true +} +func (p *ParamSchema) Field11DeepEqual(src *string) bool { + + if p.Jsonpath == src { + return true + } else if p.Jsonpath == nil || src == nil { + return false + } + if strings.Compare(*p.Jsonpath, *src) != 0 { + return false + } + return true +} + +type Reaction struct { + // 依赖的字段 + Dependency *string `thrift:"dependency,1,optional" frugal:"1,optional,string" form:"dependency" json:"dependency,omitempty" query:"dependency"` + // 可见性表达式 + Visible *string `thrift:"visible,2,optional" frugal:"2,optional,string" form:"visible" json:"visible,omitempty" query:"visible"` +} + +func NewReaction() *Reaction { + return &Reaction{} +} + +func (p *Reaction) InitDefault() { +} + +var Reaction_Dependency_DEFAULT string + +func (p *Reaction) GetDependency() (v string) { + if p == nil { + return + } + if !p.IsSetDependency() { + return Reaction_Dependency_DEFAULT + } + return *p.Dependency +} + +var Reaction_Visible_DEFAULT string + +func (p *Reaction) GetVisible() (v string) { + if p == nil { + return + } + if !p.IsSetVisible() { + return Reaction_Visible_DEFAULT + } + return *p.Visible +} +func (p *Reaction) SetDependency(val *string) { + p.Dependency = val +} +func (p *Reaction) SetVisible(val *string) { + p.Visible = val +} + +var fieldIDToName_Reaction = map[int16]string{ + 1: "dependency", + 2: "visible", +} + +func (p *Reaction) IsSetDependency() bool { + return p.Dependency != nil +} + +func (p *Reaction) IsSetVisible() bool { + return p.Visible != nil +} + +func (p *Reaction) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Reaction[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Reaction) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Dependency = _field + return nil +} +func (p *Reaction) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Visible = _field + return nil +} + +func (p *Reaction) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("Reaction"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *Reaction) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetDependency() { + if err = oprot.WriteFieldBegin("dependency", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Dependency); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *Reaction) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetVisible() { + if err = oprot.WriteFieldBegin("visible", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Visible); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *Reaction) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("Reaction(%+v)", *p) + +} + +func (p *Reaction) DeepEqual(ano *Reaction) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Dependency) { + return false + } + if !p.Field2DeepEqual(ano.Visible) { + return false + } + return true +} + +func (p *Reaction) Field1DeepEqual(src *string) bool { + + if p.Dependency == src { + return true + } else if p.Dependency == nil || src == nil { + return false + } + if strings.Compare(*p.Dependency, *src) != 0 { + return false + } + return true +} +func (p *Reaction) Field2DeepEqual(src *string) bool { + + if p.Visible == src { + return true + } else if p.Visible == nil || src == nil { + return false + } + if strings.Compare(*p.Visible, *src) != 0 { + return false + } + return true +} type ParamOption struct { // 实际值 diff --git a/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go b/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go index b10464682..e5cabe55e 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go +++ b/backend/kitex_gen/coze/loop/llm/domain/manage/manage_validator.go @@ -37,6 +37,33 @@ func (p *Model) IsValid() error { return fmt.Errorf("field ParamConfig not valid, %w", err) } } + if p.Series != nil { + if err := p.Series.IsValid(); err != nil { + return fmt.Errorf("field Series not valid, %w", err) + } + } + if p.Visibility != nil { + if err := p.Visibility.IsValid(); err != nil { + return fmt.Errorf("field Visibility not valid, %w", err) + } + } + return nil +} +func (p *Series) IsValid() error { + return nil +} +func (p *Visibility) IsValid() error { + return nil +} +func (p *ProviderInfo) IsValid() error { + if p.MaasInfo != nil { + if err := p.MaasInfo.IsValid(); err != nil { + return fmt.Errorf("field MaasInfo not valid, %w", err) + } + } + return nil +} +func (p *MaaSInfo) IsValid() error { return nil } func (p *Ability) IsValid() error { @@ -156,6 +183,14 @@ func (p *ParamConfig) IsValid() error { return nil } func (p *ParamSchema) IsValid() error { + if p.Reaction != nil { + if err := p.Reaction.IsValid(); err != nil { + return fmt.Errorf("field Reaction not valid, %w", err) + } + } + return nil +} +func (p *Reaction) IsValid() error { return nil } func (p *ParamOption) IsValid() error { diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go index d6baa6312..3c0a4b2dc 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go @@ -188,6 +188,48 @@ func (p *ModelConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 11: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField11(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 12: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField12(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 13: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField13(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 100: if fieldTypeId == thrift.LIST { l, err = p.FastReadField100(buf[offset:]) @@ -202,6 +244,20 @@ func (p *ModelConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 101: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField101(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -374,6 +430,48 @@ func (p *ModelConfig) FastReadField10(buf []byte) (int, error) { return offset, nil } +func (p *ModelConfig) FastReadField11(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Identification = _field + return offset, nil +} + +func (p *ModelConfig) FastReadField12(buf []byte) (int, error) { + offset := 0 + + var _field *manage.Protocol + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Protocol = _field + return offset, nil +} + +func (p *ModelConfig) FastReadField13(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PresetModel = _field + return offset, nil +} + func (p *ModelConfig) FastReadField100(buf []byte) (int, error) { offset := 0 @@ -399,6 +497,20 @@ func (p *ModelConfig) FastReadField100(buf []byte) (int, error) { return offset, nil } +func (p *ModelConfig) FastReadField101(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Extra = _field + return offset, nil +} + func (p *ModelConfig) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -413,10 +525,14 @@ func (p *ModelConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField8(buf[offset:], w) offset += p.fastWriteField9(buf[offset:], w) offset += p.fastWriteField10(buf[offset:], w) + offset += p.fastWriteField13(buf[offset:], w) offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField6(buf[offset:], w) offset += p.fastWriteField7(buf[offset:], w) + offset += p.fastWriteField11(buf[offset:], w) + offset += p.fastWriteField12(buf[offset:], w) offset += p.fastWriteField100(buf[offset:], w) + offset += p.fastWriteField101(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -435,7 +551,11 @@ func (p *ModelConfig) BLength() int { l += p.field8Length() l += p.field9Length() l += p.field10Length() + l += p.field11Length() + l += p.field12Length() + l += p.field13Length() l += p.field100Length() + l += p.field101Length() } l += thrift.Binary.FieldStopLength() return l @@ -536,6 +656,33 @@ func (p *ModelConfig) fastWriteField10(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ModelConfig) fastWriteField11(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetIdentification() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 11) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Identification) + } + return offset +} + +func (p *ModelConfig) fastWriteField12(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetProtocol() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 12) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Protocol) + } + return offset +} + +func (p *ModelConfig) fastWriteField13(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPresetModel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 13) + offset += thrift.Binary.WriteBool(buf[offset:], *p.PresetModel) + } + return offset +} + func (p *ModelConfig) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetParamConfigValues() { @@ -552,6 +699,15 @@ func (p *ModelConfig) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ModelConfig) fastWriteField101(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetExtra() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 101) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Extra) + } + return offset +} + func (p *ModelConfig) field1Length() int { l := 0 l += thrift.Binary.FieldBeginLength() @@ -644,6 +800,33 @@ func (p *ModelConfig) field10Length() int { return l } +func (p *ModelConfig) field11Length() int { + l := 0 + if p.IsSetIdentification() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Identification) + } + return l +} + +func (p *ModelConfig) field12Length() int { + l := 0 + if p.IsSetProtocol() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Protocol) + } + return l +} + +func (p *ModelConfig) field13Length() int { + l := 0 + if p.IsSetPresetModel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + func (p *ModelConfig) field100Length() int { l := 0 if p.IsSetParamConfigValues() { @@ -657,6 +840,15 @@ func (p *ModelConfig) field100Length() int { return l } +func (p *ModelConfig) field101Length() int { + l := 0 + if p.IsSetExtra() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Extra) + } + return l +} + func (p *ModelConfig) DeepCopy(s interface{}) error { src, ok := s.(*ModelConfig) if !ok { @@ -720,6 +912,24 @@ func (p *ModelConfig) DeepCopy(s interface{}) error { p.FrequencyPenalty = &tmp } + if src.Identification != nil { + var tmp string + if *src.Identification != "" { + tmp = kutils.StringDeepCopy(*src.Identification) + } + p.Identification = &tmp + } + + if src.Protocol != nil { + tmp := *src.Protocol + p.Protocol = &tmp + } + + if src.PresetModel != nil { + tmp := *src.PresetModel + p.PresetModel = &tmp + } + if src.ParamConfigValues != nil { p.ParamConfigValues = make([]*ParamConfigValue, 0, len(src.ParamConfigValues)) for _, elem := range src.ParamConfigValues { @@ -735,6 +945,14 @@ func (p *ModelConfig) DeepCopy(s interface{}) error { } } + if src.Extra != nil { + var tmp string + if *src.Extra != "" { + tmp = kutils.StringDeepCopy(*src.Extra) + } + p.Extra = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go index 3bc6f8e85..717a9b88a 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go @@ -80,8 +80,14 @@ type ModelConfig struct { TopK *int32 `thrift:"top_k,8,optional" frugal:"8,optional,i32" form:"top_k" json:"top_k,omitempty" query:"top_k"` PresencePenalty *float64 `thrift:"presence_penalty,9,optional" frugal:"9,optional,double" form:"presence_penalty" json:"presence_penalty,omitempty" query:"presence_penalty"` FrequencyPenalty *float64 `thrift:"frequency_penalty,10,optional" frugal:"10,optional,double" form:"frequency_penalty" json:"frequency_penalty,omitempty" query:"frequency_penalty"` + Identification *string `thrift:"identification,11,optional" frugal:"11,optional,string" form:"identification" json:"identification,omitempty" query:"identification"` + // 模型提供方 + Protocol *manage.Protocol `thrift:"protocol,12,optional" frugal:"12,optional,string" form:"protocol" json:"protocol,omitempty" query:"protocol"` + // 是否为预置模型 + PresetModel *bool `thrift:"preset_model,13,optional" frugal:"13,optional,bool" form:"preset_model" json:"preset_model,omitempty" query:"preset_model"` // 与ParamSchema对应 ParamConfigValues []*ParamConfigValue `thrift:"param_config_values,100,optional" frugal:"100,optional,list" form:"param_config_values" json:"param_config_values,omitempty" query:"param_config_values"` + Extra *string `thrift:"extra,101,optional" frugal:"101,optional,string" form:"extra" json:"extra,omitempty" query:"extra"` } func NewModelConfig() *ModelConfig { @@ -206,6 +212,42 @@ func (p *ModelConfig) GetFrequencyPenalty() (v float64) { return *p.FrequencyPenalty } +var ModelConfig_Identification_DEFAULT string + +func (p *ModelConfig) GetIdentification() (v string) { + if p == nil { + return + } + if !p.IsSetIdentification() { + return ModelConfig_Identification_DEFAULT + } + return *p.Identification +} + +var ModelConfig_Protocol_DEFAULT manage.Protocol + +func (p *ModelConfig) GetProtocol() (v manage.Protocol) { + if p == nil { + return + } + if !p.IsSetProtocol() { + return ModelConfig_Protocol_DEFAULT + } + return *p.Protocol +} + +var ModelConfig_PresetModel_DEFAULT bool + +func (p *ModelConfig) GetPresetModel() (v bool) { + if p == nil { + return + } + if !p.IsSetPresetModel() { + return ModelConfig_PresetModel_DEFAULT + } + return *p.PresetModel +} + var ModelConfig_ParamConfigValues_DEFAULT []*ParamConfigValue func (p *ModelConfig) GetParamConfigValues() (v []*ParamConfigValue) { @@ -217,6 +259,18 @@ func (p *ModelConfig) GetParamConfigValues() (v []*ParamConfigValue) { } return p.ParamConfigValues } + +var ModelConfig_Extra_DEFAULT string + +func (p *ModelConfig) GetExtra() (v string) { + if p == nil { + return + } + if !p.IsSetExtra() { + return ModelConfig_Extra_DEFAULT + } + return *p.Extra +} func (p *ModelConfig) SetModelID(val int64) { p.ModelID = val } @@ -247,9 +301,21 @@ func (p *ModelConfig) SetPresencePenalty(val *float64) { func (p *ModelConfig) SetFrequencyPenalty(val *float64) { p.FrequencyPenalty = val } +func (p *ModelConfig) SetIdentification(val *string) { + p.Identification = val +} +func (p *ModelConfig) SetProtocol(val *manage.Protocol) { + p.Protocol = val +} +func (p *ModelConfig) SetPresetModel(val *bool) { + p.PresetModel = val +} func (p *ModelConfig) SetParamConfigValues(val []*ParamConfigValue) { p.ParamConfigValues = val } +func (p *ModelConfig) SetExtra(val *string) { + p.Extra = val +} var fieldIDToName_ModelConfig = map[int16]string{ 1: "model_id", @@ -262,7 +328,11 @@ var fieldIDToName_ModelConfig = map[int16]string{ 8: "top_k", 9: "presence_penalty", 10: "frequency_penalty", + 11: "identification", + 12: "protocol", + 13: "preset_model", 100: "param_config_values", + 101: "extra", } func (p *ModelConfig) IsSetTemperature() bool { @@ -301,10 +371,26 @@ func (p *ModelConfig) IsSetFrequencyPenalty() bool { return p.FrequencyPenalty != nil } +func (p *ModelConfig) IsSetIdentification() bool { + return p.Identification != nil +} + +func (p *ModelConfig) IsSetProtocol() bool { + return p.Protocol != nil +} + +func (p *ModelConfig) IsSetPresetModel() bool { + return p.PresetModel != nil +} + func (p *ModelConfig) IsSetParamConfigValues() bool { return p.ParamConfigValues != nil } +func (p *ModelConfig) IsSetExtra() bool { + return p.Extra != nil +} + func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -405,6 +491,30 @@ func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 11: + if fieldTypeId == thrift.STRING { + if err = p.ReadField11(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 12: + if fieldTypeId == thrift.STRING { + if err = p.ReadField12(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 13: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField13(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 100: if fieldTypeId == thrift.LIST { if err = p.ReadField100(iprot); err != nil { @@ -413,6 +523,14 @@ func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 101: + if fieldTypeId == thrift.STRING { + if err = p.ReadField101(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -567,6 +685,39 @@ func (p *ModelConfig) ReadField10(iprot thrift.TProtocol) error { p.FrequencyPenalty = _field return nil } +func (p *ModelConfig) ReadField11(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Identification = _field + return nil +} +func (p *ModelConfig) ReadField12(iprot thrift.TProtocol) error { + + var _field *manage.Protocol + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Protocol = _field + return nil +} +func (p *ModelConfig) ReadField13(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.PresetModel = _field + return nil +} func (p *ModelConfig) ReadField100(iprot thrift.TProtocol) error { _, size, err := iprot.ReadListBegin() if err != nil { @@ -590,6 +741,17 @@ func (p *ModelConfig) ReadField100(iprot thrift.TProtocol) error { p.ParamConfigValues = _field return nil } +func (p *ModelConfig) ReadField101(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Extra = _field + return nil +} func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -637,10 +799,26 @@ func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 10 goto WriteFieldError } + if err = p.writeField11(oprot); err != nil { + fieldId = 11 + goto WriteFieldError + } + if err = p.writeField12(oprot); err != nil { + fieldId = 12 + goto WriteFieldError + } + if err = p.writeField13(oprot); err != nil { + fieldId = 13 + goto WriteFieldError + } if err = p.writeField100(oprot); err != nil { fieldId = 100 goto WriteFieldError } + if err = p.writeField101(oprot); err != nil { + fieldId = 101 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -845,6 +1023,60 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) } +func (p *ModelConfig) writeField11(oprot thrift.TProtocol) (err error) { + if p.IsSetIdentification() { + if err = oprot.WriteFieldBegin("identification", thrift.STRING, 11); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Identification); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 11 end error: ", p), err) +} +func (p *ModelConfig) writeField12(oprot thrift.TProtocol) (err error) { + if p.IsSetProtocol() { + if err = oprot.WriteFieldBegin("protocol", thrift.STRING, 12); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Protocol); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 12 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 12 end error: ", p), err) +} +func (p *ModelConfig) writeField13(oprot thrift.TProtocol) (err error) { + if p.IsSetPresetModel() { + if err = oprot.WriteFieldBegin("preset_model", thrift.BOOL, 13); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.PresetModel); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 13 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 13 end error: ", p), err) +} func (p *ModelConfig) writeField100(oprot thrift.TProtocol) (err error) { if p.IsSetParamConfigValues() { if err = oprot.WriteFieldBegin("param_config_values", thrift.LIST, 100); err != nil { @@ -871,6 +1103,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) } +func (p *ModelConfig) writeField101(oprot thrift.TProtocol) (err error) { + if p.IsSetExtra() { + if err = oprot.WriteFieldBegin("extra", thrift.STRING, 101); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Extra); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 101 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 101 end error: ", p), err) +} func (p *ModelConfig) String() string { if p == nil { @@ -916,9 +1166,21 @@ func (p *ModelConfig) DeepEqual(ano *ModelConfig) bool { if !p.Field10DeepEqual(ano.FrequencyPenalty) { return false } + if !p.Field11DeepEqual(ano.Identification) { + return false + } + if !p.Field12DeepEqual(ano.Protocol) { + return false + } + if !p.Field13DeepEqual(ano.PresetModel) { + return false + } if !p.Field100DeepEqual(ano.ParamConfigValues) { return false } + if !p.Field101DeepEqual(ano.Extra) { + return false + } return true } @@ -1033,6 +1295,42 @@ func (p *ModelConfig) Field10DeepEqual(src *float64) bool { } return true } +func (p *ModelConfig) Field11DeepEqual(src *string) bool { + + if p.Identification == src { + return true + } else if p.Identification == nil || src == nil { + return false + } + if strings.Compare(*p.Identification, *src) != 0 { + return false + } + return true +} +func (p *ModelConfig) Field12DeepEqual(src *manage.Protocol) bool { + + if p.Protocol == src { + return true + } else if p.Protocol == nil || src == nil { + return false + } + if strings.Compare(*p.Protocol, *src) != 0 { + return false + } + return true +} +func (p *ModelConfig) Field13DeepEqual(src *bool) bool { + + if p.PresetModel == src { + return true + } else if p.PresetModel == nil || src == nil { + return false + } + if *p.PresetModel != *src { + return false + } + return true +} func (p *ModelConfig) Field100DeepEqual(src []*ParamConfigValue) bool { if len(p.ParamConfigValues) != len(src) { @@ -1046,6 +1344,18 @@ func (p *ModelConfig) Field100DeepEqual(src []*ParamConfigValue) bool { } return true } +func (p *ModelConfig) Field101DeepEqual(src *string) bool { + + if p.Extra == src { + return true + } else if p.Extra == nil || src == nil { + return false + } + if strings.Compare(*p.Extra, *src) != 0 { + return false + } + return true +} type ParamConfigValue struct { // 传给下游模型的key,与ParamSchema.name对齐 diff --git a/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage.go b/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage.go index a2676d2f8..58bc08dc2 100644 --- a/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage.go +++ b/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage.go @@ -12,12 +12,492 @@ import ( "strings" ) +type Filter struct { + NameLike *string `thrift:"name_like,1,optional" frugal:"1,optional,string" form:"name_like" json:"name_like,omitempty" query:"name_like"` + Families []manage.Family `thrift:"families,2,optional" frugal:"2,optional,list" form:"families" json:"families,omitempty" query:"families"` + Statuses []manage.ModelStatus `thrift:"statuses,3,optional" frugal:"3,optional,list" form:"statuses" json:"statuses,omitempty" query:"statuses"` + Abilities []manage.AbilityEnum `thrift:"abilities,4,optional" frugal:"4,optional,list" form:"abilities" json:"abilities,omitempty" query:"abilities"` +} + +func NewFilter() *Filter { + return &Filter{} +} + +func (p *Filter) InitDefault() { +} + +var Filter_NameLike_DEFAULT string + +func (p *Filter) GetNameLike() (v string) { + if p == nil { + return + } + if !p.IsSetNameLike() { + return Filter_NameLike_DEFAULT + } + return *p.NameLike +} + +var Filter_Families_DEFAULT []manage.Family + +func (p *Filter) GetFamilies() (v []manage.Family) { + if p == nil { + return + } + if !p.IsSetFamilies() { + return Filter_Families_DEFAULT + } + return p.Families +} + +var Filter_Statuses_DEFAULT []manage.ModelStatus + +func (p *Filter) GetStatuses() (v []manage.ModelStatus) { + if p == nil { + return + } + if !p.IsSetStatuses() { + return Filter_Statuses_DEFAULT + } + return p.Statuses +} + +var Filter_Abilities_DEFAULT []manage.AbilityEnum + +func (p *Filter) GetAbilities() (v []manage.AbilityEnum) { + if p == nil { + return + } + if !p.IsSetAbilities() { + return Filter_Abilities_DEFAULT + } + return p.Abilities +} +func (p *Filter) SetNameLike(val *string) { + p.NameLike = val +} +func (p *Filter) SetFamilies(val []manage.Family) { + p.Families = val +} +func (p *Filter) SetStatuses(val []manage.ModelStatus) { + p.Statuses = val +} +func (p *Filter) SetAbilities(val []manage.AbilityEnum) { + p.Abilities = val +} + +var fieldIDToName_Filter = map[int16]string{ + 1: "name_like", + 2: "families", + 3: "statuses", + 4: "abilities", +} + +func (p *Filter) IsSetNameLike() bool { + return p.NameLike != nil +} + +func (p *Filter) IsSetFamilies() bool { + return p.Families != nil +} + +func (p *Filter) IsSetStatuses() bool { + return p.Statuses != nil +} + +func (p *Filter) IsSetAbilities() bool { + return p.Abilities != nil +} + +func (p *Filter) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.LIST { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.LIST { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.LIST { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Filter[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Filter) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.NameLike = _field + return nil +} +func (p *Filter) ReadField2(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]manage.Family, 0, size) + for i := 0; i < size; i++ { + + var _elem manage.Family + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Families = _field + return nil +} +func (p *Filter) ReadField3(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]manage.ModelStatus, 0, size) + for i := 0; i < size; i++ { + + var _elem manage.ModelStatus + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Statuses = _field + return nil +} +func (p *Filter) ReadField4(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]manage.AbilityEnum, 0, size) + for i := 0; i < size; i++ { + + var _elem manage.AbilityEnum + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _elem = v + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.Abilities = _field + return nil +} + +func (p *Filter) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("Filter"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *Filter) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetNameLike() { + if err = oprot.WriteFieldBegin("name_like", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.NameLike); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *Filter) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetFamilies() { + if err = oprot.WriteFieldBegin("families", thrift.LIST, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.Families)); err != nil { + return err + } + for _, v := range p.Families { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *Filter) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetStatuses() { + if err = oprot.WriteFieldBegin("statuses", thrift.LIST, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.Statuses)); err != nil { + return err + } + for _, v := range p.Statuses { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *Filter) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetAbilities() { + if err = oprot.WriteFieldBegin("abilities", thrift.LIST, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRING, len(p.Abilities)); err != nil { + return err + } + for _, v := range p.Abilities { + if err := oprot.WriteString(v); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} + +func (p *Filter) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("Filter(%+v)", *p) + +} + +func (p *Filter) DeepEqual(ano *Filter) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.NameLike) { + return false + } + if !p.Field2DeepEqual(ano.Families) { + return false + } + if !p.Field3DeepEqual(ano.Statuses) { + return false + } + if !p.Field4DeepEqual(ano.Abilities) { + return false + } + return true +} + +func (p *Filter) Field1DeepEqual(src *string) bool { + + if p.NameLike == src { + return true + } else if p.NameLike == nil || src == nil { + return false + } + if strings.Compare(*p.NameLike, *src) != 0 { + return false + } + return true +} +func (p *Filter) Field2DeepEqual(src []manage.Family) bool { + + if len(p.Families) != len(src) { + return false + } + for i, v := range p.Families { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} +func (p *Filter) Field3DeepEqual(src []manage.ModelStatus) bool { + + if len(p.Statuses) != len(src) { + return false + } + for i, v := range p.Statuses { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} +func (p *Filter) Field4DeepEqual(src []manage.AbilityEnum) bool { + + if len(p.Abilities) != len(src) { + return false + } + for i, v := range p.Abilities { + _src := src[i] + if strings.Compare(v, _src) != 0 { + return false + } + } + return true +} + type ListModelsRequest struct { WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` Scenario *common.Scenario `thrift:"scenario,2,optional" frugal:"2,optional,string" form:"scenario" json:"scenario,omitempty" query:"scenario"` - PageSize *int32 `thrift:"page_size,127,optional" frugal:"127,optional,i32" form:"page_size" json:"page_size,omitempty" query:"page_size"` - PageToken *string `thrift:"page_token,128,optional" frugal:"128,optional,string" form:"page_token" json:"page_token,omitempty" query:"page_token"` - Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` + Filter *Filter `thrift:"filter,3,optional" frugal:"3,optional,Filter" form:"filter" json:"filter,omitempty" query:"filter"` + // 是否为预置模型 + PresetModel *bool `thrift:"preset_model,4,optional" frugal:"4,optional,bool" form:"preset_model" json:"preset_model,omitempty" query:"preset_model"` + Cookie *string `thrift:"cookie,100,optional" frugal:"100,optional,string" header:"cookie" json:"cookie,omitempty"` + PageSize *int32 `thrift:"page_size,127,optional" frugal:"127,optional,i32" form:"page_size" json:"page_size,omitempty" query:"page_size"` + PageToken *string `thrift:"page_token,128,optional" frugal:"128,optional,string" form:"page_token" json:"page_token,omitempty" query:"page_token"` + Page *int32 `thrift:"page,129,optional" frugal:"129,optional,i32" form:"page" json:"page,omitempty" query:"page"` + Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } func NewListModelsRequest() *ListModelsRequest { @@ -51,6 +531,42 @@ func (p *ListModelsRequest) GetScenario() (v common.Scenario) { return *p.Scenario } +var ListModelsRequest_Filter_DEFAULT *Filter + +func (p *ListModelsRequest) GetFilter() (v *Filter) { + if p == nil { + return + } + if !p.IsSetFilter() { + return ListModelsRequest_Filter_DEFAULT + } + return p.Filter +} + +var ListModelsRequest_PresetModel_DEFAULT bool + +func (p *ListModelsRequest) GetPresetModel() (v bool) { + if p == nil { + return + } + if !p.IsSetPresetModel() { + return ListModelsRequest_PresetModel_DEFAULT + } + return *p.PresetModel +} + +var ListModelsRequest_Cookie_DEFAULT string + +func (p *ListModelsRequest) GetCookie() (v string) { + if p == nil { + return + } + if !p.IsSetCookie() { + return ListModelsRequest_Cookie_DEFAULT + } + return *p.Cookie +} + var ListModelsRequest_PageSize_DEFAULT int32 func (p *ListModelsRequest) GetPageSize() (v int32) { @@ -75,6 +591,18 @@ func (p *ListModelsRequest) GetPageToken() (v string) { return *p.PageToken } +var ListModelsRequest_Page_DEFAULT int32 + +func (p *ListModelsRequest) GetPage() (v int32) { + if p == nil { + return + } + if !p.IsSetPage() { + return ListModelsRequest_Page_DEFAULT + } + return *p.Page +} + var ListModelsRequest_Base_DEFAULT *base.Base func (p *ListModelsRequest) GetBase() (v *base.Base) { @@ -92,12 +620,24 @@ func (p *ListModelsRequest) SetWorkspaceID(val *int64) { func (p *ListModelsRequest) SetScenario(val *common.Scenario) { p.Scenario = val } +func (p *ListModelsRequest) SetFilter(val *Filter) { + p.Filter = val +} +func (p *ListModelsRequest) SetPresetModel(val *bool) { + p.PresetModel = val +} +func (p *ListModelsRequest) SetCookie(val *string) { + p.Cookie = val +} func (p *ListModelsRequest) SetPageSize(val *int32) { p.PageSize = val } func (p *ListModelsRequest) SetPageToken(val *string) { p.PageToken = val } +func (p *ListModelsRequest) SetPage(val *int32) { + p.Page = val +} func (p *ListModelsRequest) SetBase(val *base.Base) { p.Base = val } @@ -105,8 +645,12 @@ func (p *ListModelsRequest) SetBase(val *base.Base) { var fieldIDToName_ListModelsRequest = map[int16]string{ 1: "workspace_id", 2: "scenario", + 3: "filter", + 4: "preset_model", + 100: "cookie", 127: "page_size", 128: "page_token", + 129: "page", 255: "Base", } @@ -118,6 +662,18 @@ func (p *ListModelsRequest) IsSetScenario() bool { return p.Scenario != nil } +func (p *ListModelsRequest) IsSetFilter() bool { + return p.Filter != nil +} + +func (p *ListModelsRequest) IsSetPresetModel() bool { + return p.PresetModel != nil +} + +func (p *ListModelsRequest) IsSetCookie() bool { + return p.Cookie != nil +} + func (p *ListModelsRequest) IsSetPageSize() bool { return p.PageSize != nil } @@ -126,6 +682,10 @@ func (p *ListModelsRequest) IsSetPageToken() bool { return p.PageToken != nil } +func (p *ListModelsRequest) IsSetPage() bool { + return p.Page != nil +} + func (p *ListModelsRequest) IsSetBase() bool { return p.Base != nil } @@ -164,6 +724,30 @@ func (p *ListModelsRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 3: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 100: + if fieldTypeId == thrift.STRING { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 127: if fieldTypeId == thrift.I32 { if err = p.ReadField127(iprot); err != nil { @@ -180,6 +764,14 @@ func (p *ListModelsRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 129: + if fieldTypeId == thrift.I32 { + if err = p.ReadField129(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 255: if fieldTypeId == thrift.STRUCT { if err = p.ReadField255(iprot); err != nil { @@ -228,15 +820,45 @@ func (p *ListModelsRequest) ReadField1(iprot thrift.TProtocol) error { p.WorkspaceID = _field return nil } -func (p *ListModelsRequest) ReadField2(iprot thrift.TProtocol) error { +func (p *ListModelsRequest) ReadField2(iprot thrift.TProtocol) error { + + var _field *common.Scenario + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Scenario = _field + return nil +} +func (p *ListModelsRequest) ReadField3(iprot thrift.TProtocol) error { + _field := NewFilter() + if err := _field.Read(iprot); err != nil { + return err + } + p.Filter = _field + return nil +} +func (p *ListModelsRequest) ReadField4(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.PresetModel = _field + return nil +} +func (p *ListModelsRequest) ReadField100(iprot thrift.TProtocol) error { - var _field *common.Scenario + var _field *string if v, err := iprot.ReadString(); err != nil { return err } else { _field = &v } - p.Scenario = _field + p.Cookie = _field return nil } func (p *ListModelsRequest) ReadField127(iprot thrift.TProtocol) error { @@ -261,6 +883,17 @@ func (p *ListModelsRequest) ReadField128(iprot thrift.TProtocol) error { p.PageToken = _field return nil } +func (p *ListModelsRequest) ReadField129(iprot thrift.TProtocol) error { + + var _field *int32 + if v, err := iprot.ReadI32(); err != nil { + return err + } else { + _field = &v + } + p.Page = _field + return nil +} func (p *ListModelsRequest) ReadField255(iprot thrift.TProtocol) error { _field := base.NewBase() if err := _field.Read(iprot); err != nil { @@ -284,6 +917,18 @@ func (p *ListModelsRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 2 goto WriteFieldError } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } if err = p.writeField127(oprot); err != nil { fieldId = 127 goto WriteFieldError @@ -292,6 +937,10 @@ func (p *ListModelsRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 128 goto WriteFieldError } + if err = p.writeField129(oprot); err != nil { + fieldId = 129 + goto WriteFieldError + } if err = p.writeField255(oprot); err != nil { fieldId = 255 goto WriteFieldError @@ -350,6 +999,60 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } +func (p *ListModelsRequest) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetFilter() { + if err = oprot.WriteFieldBegin("filter", thrift.STRUCT, 3); err != nil { + goto WriteFieldBeginError + } + if err := p.Filter.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *ListModelsRequest) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetPresetModel() { + if err = oprot.WriteFieldBegin("preset_model", thrift.BOOL, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.PresetModel); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *ListModelsRequest) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetCookie() { + if err = oprot.WriteFieldBegin("cookie", thrift.STRING, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Cookie); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *ListModelsRequest) writeField127(oprot thrift.TProtocol) (err error) { if p.IsSetPageSize() { if err = oprot.WriteFieldBegin("page_size", thrift.I32, 127); err != nil { @@ -386,6 +1089,24 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 128 end error: ", p), err) } +func (p *ListModelsRequest) writeField129(oprot thrift.TProtocol) (err error) { + if p.IsSetPage() { + if err = oprot.WriteFieldBegin("page", thrift.I32, 129); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteI32(*p.Page); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 129 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 129 end error: ", p), err) +} func (p *ListModelsRequest) writeField255(oprot thrift.TProtocol) (err error) { if p.IsSetBase() { if err = oprot.WriteFieldBegin("Base", thrift.STRUCT, 255); err != nil { @@ -425,12 +1146,24 @@ func (p *ListModelsRequest) DeepEqual(ano *ListModelsRequest) bool { if !p.Field2DeepEqual(ano.Scenario) { return false } + if !p.Field3DeepEqual(ano.Filter) { + return false + } + if !p.Field4DeepEqual(ano.PresetModel) { + return false + } + if !p.Field100DeepEqual(ano.Cookie) { + return false + } if !p.Field127DeepEqual(ano.PageSize) { return false } if !p.Field128DeepEqual(ano.PageToken) { return false } + if !p.Field129DeepEqual(ano.Page) { + return false + } if !p.Field255DeepEqual(ano.Base) { return false } @@ -461,6 +1194,37 @@ func (p *ListModelsRequest) Field2DeepEqual(src *common.Scenario) bool { } return true } +func (p *ListModelsRequest) Field3DeepEqual(src *Filter) bool { + + if !p.Filter.DeepEqual(src) { + return false + } + return true +} +func (p *ListModelsRequest) Field4DeepEqual(src *bool) bool { + + if p.PresetModel == src { + return true + } else if p.PresetModel == nil || src == nil { + return false + } + if *p.PresetModel != *src { + return false + } + return true +} +func (p *ListModelsRequest) Field100DeepEqual(src *string) bool { + + if p.Cookie == src { + return true + } else if p.Cookie == nil || src == nil { + return false + } + if strings.Compare(*p.Cookie, *src) != 0 { + return false + } + return true +} func (p *ListModelsRequest) Field127DeepEqual(src *int32) bool { if p.PageSize == src { @@ -485,6 +1249,18 @@ func (p *ListModelsRequest) Field128DeepEqual(src *string) bool { } return true } +func (p *ListModelsRequest) Field129DeepEqual(src *int32) bool { + + if p.Page == src { + return true + } else if p.Page == nil || src == nil { + return false + } + if *p.Page != *src { + return false + } + return true +} func (p *ListModelsRequest) Field255DeepEqual(src *base.Base) bool { if !p.Base.DeepEqual(src) { @@ -994,8 +1770,12 @@ func (p *ListModelsResponse) Field255DeepEqual(src *base.BaseResp) bool { } type GetModelRequest struct { - WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` - ModelID *int64 `thrift:"model_id,2,optional" frugal:"2,optional,i64" json:"model_id" path:"model_id" ` + WorkspaceID *int64 `thrift:"workspace_id,1,optional" frugal:"1,optional,i64" json:"workspace_id" form:"workspace_id" query:"workspace_id"` + ModelID *int64 `thrift:"model_id,2,optional" frugal:"2,optional,i64" json:"model_id" path:"model_id" ` + Identification *string `thrift:"identification,3,optional" frugal:"3,optional,string" form:"identification" json:"identification,omitempty" query:"identification"` + Protocol *manage.Protocol `thrift:"protocol,4,optional" frugal:"4,optional,string" form:"protocol" json:"protocol,omitempty" query:"protocol"` + // 是否为预置模型 + PresetModel *bool `thrift:"preset_model,5,optional" frugal:"5,optional,bool" form:"preset_model" json:"preset_model,omitempty" query:"preset_model"` Base *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" form:"Base" json:"Base,omitempty" query:"Base"` } @@ -1030,6 +1810,42 @@ func (p *GetModelRequest) GetModelID() (v int64) { return *p.ModelID } +var GetModelRequest_Identification_DEFAULT string + +func (p *GetModelRequest) GetIdentification() (v string) { + if p == nil { + return + } + if !p.IsSetIdentification() { + return GetModelRequest_Identification_DEFAULT + } + return *p.Identification +} + +var GetModelRequest_Protocol_DEFAULT manage.Protocol + +func (p *GetModelRequest) GetProtocol() (v manage.Protocol) { + if p == nil { + return + } + if !p.IsSetProtocol() { + return GetModelRequest_Protocol_DEFAULT + } + return *p.Protocol +} + +var GetModelRequest_PresetModel_DEFAULT bool + +func (p *GetModelRequest) GetPresetModel() (v bool) { + if p == nil { + return + } + if !p.IsSetPresetModel() { + return GetModelRequest_PresetModel_DEFAULT + } + return *p.PresetModel +} + var GetModelRequest_Base_DEFAULT *base.Base func (p *GetModelRequest) GetBase() (v *base.Base) { @@ -1047,6 +1863,15 @@ func (p *GetModelRequest) SetWorkspaceID(val *int64) { func (p *GetModelRequest) SetModelID(val *int64) { p.ModelID = val } +func (p *GetModelRequest) SetIdentification(val *string) { + p.Identification = val +} +func (p *GetModelRequest) SetProtocol(val *manage.Protocol) { + p.Protocol = val +} +func (p *GetModelRequest) SetPresetModel(val *bool) { + p.PresetModel = val +} func (p *GetModelRequest) SetBase(val *base.Base) { p.Base = val } @@ -1054,6 +1879,9 @@ func (p *GetModelRequest) SetBase(val *base.Base) { var fieldIDToName_GetModelRequest = map[int16]string{ 1: "workspace_id", 2: "model_id", + 3: "identification", + 4: "protocol", + 5: "preset_model", 255: "Base", } @@ -1065,6 +1893,18 @@ func (p *GetModelRequest) IsSetModelID() bool { return p.ModelID != nil } +func (p *GetModelRequest) IsSetIdentification() bool { + return p.Identification != nil +} + +func (p *GetModelRequest) IsSetProtocol() bool { + return p.Protocol != nil +} + +func (p *GetModelRequest) IsSetPresetModel() bool { + return p.PresetModel != nil +} + func (p *GetModelRequest) IsSetBase() bool { return p.Base != nil } @@ -1103,6 +1943,30 @@ func (p *GetModelRequest) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 3: + if fieldTypeId == thrift.STRING { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 4: + if fieldTypeId == thrift.STRING { + if err = p.ReadField4(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 5: + if fieldTypeId == thrift.BOOL { + if err = p.ReadField5(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } case 255: if fieldTypeId == thrift.STRUCT { if err = p.ReadField255(iprot); err != nil { @@ -1162,6 +2026,39 @@ func (p *GetModelRequest) ReadField2(iprot thrift.TProtocol) error { p.ModelID = _field return nil } +func (p *GetModelRequest) ReadField3(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Identification = _field + return nil +} +func (p *GetModelRequest) ReadField4(iprot thrift.TProtocol) error { + + var _field *manage.Protocol + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Protocol = _field + return nil +} +func (p *GetModelRequest) ReadField5(iprot thrift.TProtocol) error { + + var _field *bool + if v, err := iprot.ReadBool(); err != nil { + return err + } else { + _field = &v + } + p.PresetModel = _field + return nil +} func (p *GetModelRequest) ReadField255(iprot thrift.TProtocol) error { _field := base.NewBase() if err := _field.Read(iprot); err != nil { @@ -1185,6 +2082,18 @@ func (p *GetModelRequest) Write(oprot thrift.TProtocol) (err error) { fieldId = 2 goto WriteFieldError } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + if err = p.writeField4(oprot); err != nil { + fieldId = 4 + goto WriteFieldError + } + if err = p.writeField5(oprot); err != nil { + fieldId = 5 + goto WriteFieldError + } if err = p.writeField255(oprot); err != nil { fieldId = 255 goto WriteFieldError @@ -1243,6 +2152,60 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } +func (p *GetModelRequest) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetIdentification() { + if err = oprot.WriteFieldBegin("identification", thrift.STRING, 3); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Identification); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} +func (p *GetModelRequest) writeField4(oprot thrift.TProtocol) (err error) { + if p.IsSetProtocol() { + if err = oprot.WriteFieldBegin("protocol", thrift.STRING, 4); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Protocol); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) +} +func (p *GetModelRequest) writeField5(oprot thrift.TProtocol) (err error) { + if p.IsSetPresetModel() { + if err = oprot.WriteFieldBegin("preset_model", thrift.BOOL, 5); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteBool(*p.PresetModel); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) +} func (p *GetModelRequest) writeField255(oprot thrift.TProtocol) (err error) { if p.IsSetBase() { if err = oprot.WriteFieldBegin("Base", thrift.STRUCT, 255); err != nil { @@ -1282,6 +2245,15 @@ func (p *GetModelRequest) DeepEqual(ano *GetModelRequest) bool { if !p.Field2DeepEqual(ano.ModelID) { return false } + if !p.Field3DeepEqual(ano.Identification) { + return false + } + if !p.Field4DeepEqual(ano.Protocol) { + return false + } + if !p.Field5DeepEqual(ano.PresetModel) { + return false + } if !p.Field255DeepEqual(ano.Base) { return false } @@ -1312,6 +2284,42 @@ func (p *GetModelRequest) Field2DeepEqual(src *int64) bool { } return true } +func (p *GetModelRequest) Field3DeepEqual(src *string) bool { + + if p.Identification == src { + return true + } else if p.Identification == nil || src == nil { + return false + } + if strings.Compare(*p.Identification, *src) != 0 { + return false + } + return true +} +func (p *GetModelRequest) Field4DeepEqual(src *manage.Protocol) bool { + + if p.Protocol == src { + return true + } else if p.Protocol == nil || src == nil { + return false + } + if strings.Compare(*p.Protocol, *src) != 0 { + return false + } + return true +} +func (p *GetModelRequest) Field5DeepEqual(src *bool) bool { + + if p.PresetModel == src { + return true + } else if p.PresetModel == nil || src == nil { + return false + } + if *p.PresetModel != *src { + return false + } + return true +} func (p *GetModelRequest) Field255DeepEqual(src *base.Base) bool { if !p.Base.DeepEqual(src) { diff --git a/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage_validator.go b/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage_validator.go index 7fdd1dbf0..f2a625cda 100644 --- a/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage_validator.go +++ b/backend/kitex_gen/coze/loop/llm/manage/coze.loop.llm.manage_validator.go @@ -21,6 +21,9 @@ var ( _ = time.Nanosecond ) +func (p *Filter) IsValid() error { + return nil +} func (p *ListModelsRequest) IsValid() error { if p.WorkspaceID == nil { return fmt.Errorf("field WorkspaceID not_nil rule failed") @@ -28,6 +31,11 @@ func (p *ListModelsRequest) IsValid() error { if *p.WorkspaceID <= int64(0) { return fmt.Errorf("field WorkspaceID gt rule failed, current value: %v", *p.WorkspaceID) } + if p.Filter != nil { + if err := p.Filter.IsValid(); err != nil { + return fmt.Errorf("field Filter not valid, %w", err) + } + } if p.Base != nil { if err := p.Base.IsValid(); err != nil { return fmt.Errorf("field Base not valid, %w", err) diff --git a/backend/kitex_gen/coze/loop/llm/manage/k-coze.loop.llm.manage.go b/backend/kitex_gen/coze/loop/llm/manage/k-coze.loop.llm.manage.go index fa64d1cdc..a28041d95 100644 --- a/backend/kitex_gen/coze/loop/llm/manage/k-coze.loop.llm.manage.go +++ b/backend/kitex_gen/coze/loop/llm/manage/k-coze.loop.llm.manage.go @@ -31,6 +31,360 @@ var ( _ = thrift.STOP ) +func (p *Filter) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Filter[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *Filter) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.NameLike = _field + return offset, nil +} + +func (p *Filter) FastReadField2(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]manage.Family, 0, size) + for i := 0; i < size; i++ { + var _elem manage.Family + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.Families = _field + return offset, nil +} + +func (p *Filter) FastReadField3(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]manage.ModelStatus, 0, size) + for i := 0; i < size; i++ { + var _elem manage.ModelStatus + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.Statuses = _field + return offset, nil +} + +func (p *Filter) FastReadField4(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]manage.AbilityEnum, 0, size) + for i := 0; i < size; i++ { + var _elem manage.AbilityEnum + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _elem = v + } + + _field = append(_field, _elem) + } + p.Abilities = _field + return offset, nil +} + +func (p *Filter) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *Filter) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *Filter) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *Filter) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetNameLike() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.NameLike) + } + return offset +} + +func (p *Filter) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFamilies() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 2) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Families { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *Filter) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetStatuses() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 3) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Statuses { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *Filter) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetAbilities() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 4) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.Abilities { + length++ + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, v) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRING, length) + } + return offset +} + +func (p *Filter) field1Length() int { + l := 0 + if p.IsSetNameLike() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.NameLike) + } + return l +} + +func (p *Filter) field2Length() int { + l := 0 + if p.IsSetFamilies() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Families { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *Filter) field3Length() int { + l := 0 + if p.IsSetStatuses() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Statuses { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *Filter) field4Length() int { + l := 0 + if p.IsSetAbilities() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.Abilities { + _ = v + l += thrift.Binary.StringLengthNocopy(v) + } + } + return l +} + +func (p *Filter) DeepCopy(s interface{}) error { + src, ok := s.(*Filter) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.NameLike != nil { + var tmp string + if *src.NameLike != "" { + tmp = kutils.StringDeepCopy(*src.NameLike) + } + p.NameLike = &tmp + } + + if src.Families != nil { + p.Families = make([]manage.Family, 0, len(src.Families)) + for _, elem := range src.Families { + var _elem manage.Family + _elem = elem + p.Families = append(p.Families, _elem) + } + } + + if src.Statuses != nil { + p.Statuses = make([]manage.ModelStatus, 0, len(src.Statuses)) + for _, elem := range src.Statuses { + var _elem manage.ModelStatus + _elem = elem + p.Statuses = append(p.Statuses, _elem) + } + } + + if src.Abilities != nil { + p.Abilities = make([]manage.AbilityEnum, 0, len(src.Abilities)) + for _, elem := range src.Abilities { + var _elem manage.AbilityEnum + _elem = elem + p.Abilities = append(p.Abilities, _elem) + } + } + + return nil +} + func (p *ListModelsRequest) FastRead(buf []byte) (int, error) { var err error @@ -62,9 +416,51 @@ func (p *ListModelsRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } - case 2: + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 100: if fieldTypeId == thrift.STRING { - l, err = p.FastReadField2(buf[offset:]) + l, err = p.FastReadField100(buf[offset:]) offset += l if err != nil { goto ReadFieldError @@ -104,6 +500,20 @@ func (p *ListModelsRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 129: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField129(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 255: if fieldTypeId == thrift.STRUCT { l, err = p.FastReadField255(buf[offset:]) @@ -164,6 +574,46 @@ func (p *ListModelsRequest) FastReadField2(buf []byte) (int, error) { return offset, nil } +func (p *ListModelsRequest) FastReadField3(buf []byte) (int, error) { + offset := 0 + _field := NewFilter() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Filter = _field + return offset, nil +} + +func (p *ListModelsRequest) FastReadField4(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PresetModel = _field + return offset, nil +} + +func (p *ListModelsRequest) FastReadField100(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Cookie = _field + return offset, nil +} + func (p *ListModelsRequest) FastReadField127(buf []byte) (int, error) { offset := 0 @@ -192,6 +642,20 @@ func (p *ListModelsRequest) FastReadField128(buf []byte) (int, error) { return offset, nil } +func (p *ListModelsRequest) FastReadField129(buf []byte) (int, error) { + offset := 0 + + var _field *int32 + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Page = _field + return offset, nil +} + func (p *ListModelsRequest) FastReadField255(buf []byte) (int, error) { offset := 0 _field := base.NewBase() @@ -212,8 +676,12 @@ func (p *ListModelsRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) i offset := 0 if p != nil { offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField127(buf[offset:], w) + offset += p.fastWriteField129(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) offset += p.fastWriteField128(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) } @@ -226,8 +694,12 @@ func (p *ListModelsRequest) BLength() int { if p != nil { l += p.field1Length() l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field100Length() l += p.field127Length() l += p.field128Length() + l += p.field129Length() l += p.field255Length() } l += thrift.Binary.FieldStopLength() @@ -252,6 +724,33 @@ func (p *ListModelsRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) i return offset } +func (p *ListModelsRequest) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetFilter() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 3) + offset += p.Filter.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ListModelsRequest) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPresetModel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 4) + offset += thrift.Binary.WriteBool(buf[offset:], *p.PresetModel) + } + return offset +} + +func (p *ListModelsRequest) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetCookie() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 100) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Cookie) + } + return offset +} + func (p *ListModelsRequest) fastWriteField127(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetPageSize() { @@ -270,6 +769,15 @@ func (p *ListModelsRequest) fastWriteField128(buf []byte, w thrift.NocopyWriter) return offset } +func (p *ListModelsRequest) fastWriteField129(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPage() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 129) + offset += thrift.Binary.WriteI32(buf[offset:], *p.Page) + } + return offset +} + func (p *ListModelsRequest) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetBase() { @@ -297,6 +805,33 @@ func (p *ListModelsRequest) field2Length() int { return l } +func (p *ListModelsRequest) field3Length() int { + l := 0 + if p.IsSetFilter() { + l += thrift.Binary.FieldBeginLength() + l += p.Filter.BLength() + } + return l +} + +func (p *ListModelsRequest) field4Length() int { + l := 0 + if p.IsSetPresetModel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + +func (p *ListModelsRequest) field100Length() int { + l := 0 + if p.IsSetCookie() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Cookie) + } + return l +} + func (p *ListModelsRequest) field127Length() int { l := 0 if p.IsSetPageSize() { @@ -315,6 +850,15 @@ func (p *ListModelsRequest) field128Length() int { return l } +func (p *ListModelsRequest) field129Length() int { + l := 0 + if p.IsSetPage() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() + } + return l +} + func (p *ListModelsRequest) field255Length() int { l := 0 if p.IsSetBase() { @@ -340,6 +884,28 @@ func (p *ListModelsRequest) DeepCopy(s interface{}) error { p.Scenario = &tmp } + var _filter *Filter + if src.Filter != nil { + _filter = &Filter{} + if err := _filter.DeepCopy(src.Filter); err != nil { + return err + } + } + p.Filter = _filter + + if src.PresetModel != nil { + tmp := *src.PresetModel + p.PresetModel = &tmp + } + + if src.Cookie != nil { + var tmp string + if *src.Cookie != "" { + tmp = kutils.StringDeepCopy(*src.Cookie) + } + p.Cookie = &tmp + } + if src.PageSize != nil { tmp := *src.PageSize p.PageSize = &tmp @@ -353,6 +919,11 @@ func (p *ListModelsRequest) DeepCopy(s interface{}) error { p.PageToken = &tmp } + if src.Page != nil { + tmp := *src.Page + p.Page = &tmp + } + var _base *base.Base if src.Base != nil { _base = &base.Base{} @@ -772,6 +1343,48 @@ func (p *GetModelRequest) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 3: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 4: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField4(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 5: + if fieldTypeId == thrift.BOOL { + l, err = p.FastReadField5(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } case 255: if fieldTypeId == thrift.STRUCT { l, err = p.FastReadField255(buf[offset:]) @@ -832,6 +1445,48 @@ func (p *GetModelRequest) FastReadField2(buf []byte) (int, error) { return offset, nil } +func (p *GetModelRequest) FastReadField3(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Identification = _field + return offset, nil +} + +func (p *GetModelRequest) FastReadField4(buf []byte) (int, error) { + offset := 0 + + var _field *manage.Protocol + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Protocol = _field + return offset, nil +} + +func (p *GetModelRequest) FastReadField5(buf []byte) (int, error) { + offset := 0 + + var _field *bool + if v, l, err := thrift.Binary.ReadBool(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.PresetModel = _field + return offset, nil +} + func (p *GetModelRequest) FastReadField255(buf []byte) (int, error) { offset := 0 _field := base.NewBase() @@ -853,6 +1508,9 @@ func (p *GetModelRequest) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int if p != nil { offset += p.fastWriteField1(buf[offset:], w) offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField5(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + offset += p.fastWriteField4(buf[offset:], w) offset += p.fastWriteField255(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) @@ -864,6 +1522,9 @@ func (p *GetModelRequest) BLength() int { if p != nil { l += p.field1Length() l += p.field2Length() + l += p.field3Length() + l += p.field4Length() + l += p.field5Length() l += p.field255Length() } l += thrift.Binary.FieldStopLength() @@ -888,6 +1549,33 @@ func (p *GetModelRequest) fastWriteField2(buf []byte, w thrift.NocopyWriter) int return offset } +func (p *GetModelRequest) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetIdentification() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 3) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Identification) + } + return offset +} + +func (p *GetModelRequest) fastWriteField4(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetProtocol() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 4) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Protocol) + } + return offset +} + +func (p *GetModelRequest) fastWriteField5(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetPresetModel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.BOOL, 5) + offset += thrift.Binary.WriteBool(buf[offset:], *p.PresetModel) + } + return offset +} + func (p *GetModelRequest) fastWriteField255(buf []byte, w thrift.NocopyWriter) int { offset := 0 if p.IsSetBase() { @@ -915,6 +1603,33 @@ func (p *GetModelRequest) field2Length() int { return l } +func (p *GetModelRequest) field3Length() int { + l := 0 + if p.IsSetIdentification() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Identification) + } + return l +} + +func (p *GetModelRequest) field4Length() int { + l := 0 + if p.IsSetProtocol() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Protocol) + } + return l +} + +func (p *GetModelRequest) field5Length() int { + l := 0 + if p.IsSetPresetModel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.BoolLength() + } + return l +} + func (p *GetModelRequest) field255Length() int { l := 0 if p.IsSetBase() { @@ -940,6 +1655,24 @@ func (p *GetModelRequest) DeepCopy(s interface{}) error { p.ModelID = &tmp } + if src.Identification != nil { + var tmp string + if *src.Identification != "" { + tmp = kutils.StringDeepCopy(*src.Identification) + } + p.Identification = &tmp + } + + if src.Protocol != nil { + tmp := *src.Protocol + p.Protocol = &tmp + } + + if src.PresetModel != nil { + tmp := *src.PresetModel + p.PresetModel = &tmp + } + var _base *base.Base if src.Base != nil { _base = &base.Base{} diff --git a/backend/modules/evaluation/application/convertor/common/common.go b/backend/modules/evaluation/application/convertor/common/common.go index f4166b3ba..18cb91b4f 100644 --- a/backend/modules/evaluation/application/convertor/common/common.go +++ b/backend/modules/evaluation/application/convertor/common/common.go @@ -432,11 +432,14 @@ func ConvertModelConfigDTO2DO(config *commondto.ModelConfig) *commonentity.Model } return &commonentity.ModelConfig{ - ModelID: config.GetModelID(), - ModelName: gptr.Indirect(config.ModelName), - Temperature: config.Temperature, - MaxTokens: config.MaxTokens, - TopP: config.TopP, + ModelID: config.ModelID, + ModelName: gptr.Indirect(config.ModelName), + Temperature: config.Temperature, + MaxTokens: config.MaxTokens, + TopP: config.TopP, + Protocol: config.Protocol, + Identification: config.Identification, + PresetModel: config.PresetModel, } } @@ -447,14 +450,17 @@ func ConvertModelConfigDO2DTO(config *commonentity.ModelConfig) *commondto.Model } dto := &commondto.ModelConfig{ - ModelID: gptr.Of(config.ModelID), - ModelName: gptr.Of(config.ModelName), - Temperature: config.Temperature, - MaxTokens: config.MaxTokens, - TopP: config.TopP, - } - if config.ModelID > 0 { - dto.ModelID = gptr.Of(config.ModelID) + ModelID: config.ModelID, + ModelName: gptr.Of(config.ModelName), + Temperature: config.Temperature, + MaxTokens: config.MaxTokens, + TopP: config.TopP, + Protocol: config.Protocol, + Identification: config.Identification, + PresetModel: config.PresetModel, + } + if config.GetModelID() > 0 { + dto.ModelID = config.ModelID } else if config.ProviderModelID != nil && len(gptr.Indirect(config.ProviderModelID)) > 0 { pModelID, err := strconv.ParseInt(gptr.Indirect(config.ProviderModelID), 10, 64) if err != nil { diff --git a/backend/modules/evaluation/application/convertor/common/common_test.go b/backend/modules/evaluation/application/convertor/common/common_test.go index bfeea9ade..cfffbfb68 100755 --- a/backend/modules/evaluation/application/convertor/common/common_test.go +++ b/backend/modules/evaluation/application/convertor/common/common_test.go @@ -1258,7 +1258,7 @@ func TestConvertModelConfigDTO2DO(t *testing.T) { TopP: gptr.Of(0.9), }, expected: &commonentity.ModelConfig{ - ModelID: 123, + ModelID: gptr.Of(int64(123)), ModelName: "gpt-4", Temperature: gptr.Of(0.7), MaxTokens: gptr.Of(int32(2048)), @@ -1271,7 +1271,7 @@ func TestConvertModelConfigDTO2DO(t *testing.T) { ModelID: gptr.Of(int64(456)), }, expected: &commonentity.ModelConfig{ - ModelID: 456, + ModelID: gptr.Of(int64(456)), }, }, } @@ -1301,7 +1301,7 @@ func TestConvertModelConfigDO2DTO(t *testing.T) { { name: "complete model config with model ID", input: &commonentity.ModelConfig{ - ModelID: 123, + ModelID: gptr.Of(int64(123)), ModelName: "gpt-4", Temperature: gptr.Of(0.7), MaxTokens: gptr.Of(int32(2048)), @@ -1318,7 +1318,7 @@ func TestConvertModelConfigDO2DTO(t *testing.T) { { name: "model config with provider model ID", input: &commonentity.ModelConfig{ - ModelID: 0, + ModelID: gptr.Of(int64(0)), ProviderModelID: gptr.Of("456"), ModelName: "claude-3", Temperature: gptr.Of(0.5), @@ -1332,7 +1332,7 @@ func TestConvertModelConfigDO2DTO(t *testing.T) { { name: "model config with invalid provider model ID", input: &commonentity.ModelConfig{ - ModelID: 0, + ModelID: gptr.Of(int64(0)), ProviderModelID: gptr.Of("invalid"), ModelName: "claude-3", }, diff --git a/backend/modules/evaluation/application/convertor/experiment/expt_result.go b/backend/modules/evaluation/application/convertor/experiment/expt_result.go index 6da902eec..e830820d9 100644 --- a/backend/modules/evaluation/application/convertor/experiment/expt_result.go +++ b/backend/modules/evaluation/application/convertor/experiment/expt_result.go @@ -371,6 +371,7 @@ func ExportRecordDO2DTO(from *entity.ExptResultExportRecord) *domain_expt.ExptRe }, }, URL: from.URL, + URL_: from.URL, Expired: ptr.Of(from.Expired), } diff --git a/backend/modules/evaluation/application/convertor/experiment/expt_result_test.go b/backend/modules/evaluation/application/convertor/experiment/expt_result_test.go index ad00a75aa..13a412f6d 100644 --- a/backend/modules/evaluation/application/convertor/experiment/expt_result_test.go +++ b/backend/modules/evaluation/application/convertor/experiment/expt_result_test.go @@ -5,10 +5,14 @@ package experiment import ( "testing" + "time" "github.com/bytedance/gg/gptr" "github.com/stretchr/testify/assert" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/data/domain/dataset" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/evaluator" + domain_expt "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/expt" "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" ) @@ -96,7 +100,7 @@ func TestItemResultsDO2DTO_ExtField(t *testing.T) { wantNil bool }{ { - name: "Ext字段有值", + name: "Ext field has value", from: &entity.ItemResult{ ItemID: 1, TurnResults: []*entity.TurnResult{}, @@ -113,7 +117,7 @@ func TestItemResultsDO2DTO_ExtField(t *testing.T) { wantNil: false, }, { - name: "Ext字段为空map", + name: "Ext field is empty map", from: &entity.ItemResult{ ItemID: 1, TurnResults: []*entity.TurnResult{}, @@ -124,7 +128,7 @@ func TestItemResultsDO2DTO_ExtField(t *testing.T) { wantNil: true, }, { - name: "Ext字段为nil", + name: "Ext field is nil", from: &entity.ItemResult{ ItemID: 1, TurnResults: []*entity.TurnResult{}, @@ -135,7 +139,7 @@ func TestItemResultsDO2DTO_ExtField(t *testing.T) { wantNil: true, }, { - name: "Ext字段有多个值", + name: "Ext field has multiple values", from: &entity.ItemResult{ ItemID: 1, TurnResults: []*entity.TurnResult{}, @@ -172,76 +176,237 @@ func TestItemResultsDO2DTO_ExtField(t *testing.T) { } } -func TestItemResultsDO2DTOs_ExtField(t *testing.T) { - tests := []struct { - name string - from []*entity.ItemResult - want []map[string]string - }{ +func TestColumnEvalSetFieldsDO2DTO(t *testing.T) { + from := &entity.ColumnEvalSetField{ + Key: gptr.Of("key1"), + Name: gptr.Of("name1"), + Description: gptr.Of("desc1"), + ContentType: "Text", + TextSchema: gptr.Of("schema1"), + SchemaKey: gptr.Of(entity.SchemaKey(dataset.SchemaKey_String)), + } + + got := ColumnEvalSetFieldsDO2DTO(from) + + assert.Equal(t, *from.Key, *got.Key) + assert.Equal(t, *from.Name, *got.Name) + assert.Equal(t, *from.Description, *got.Description) + assert.Equal(t, from.TextSchema, got.TextSchema) + assert.Equal(t, dataset.SchemaKey_String, *got.SchemaKey) +} + +func TestColumnEvalSetFieldsDO2DTOs(t *testing.T) { + from := []*entity.ColumnEvalSetField{ { - name: "多个ItemResult,Ext字段都有值", - from: []*entity.ItemResult{ - { - ItemID: 1, - TurnResults: []*entity.TurnResult{}, - Ext: map[string]string{ - "key1": "value1", - }, - }, + Key: gptr.Of("key1"), + }, + { + Key: gptr.Of("key2"), + }, + } + + got := ColumnEvalSetFieldsDO2DTOs(from) + + assert.Len(t, got, 2) + assert.Equal(t, *from[0].Key, *got[0].Key) + assert.Equal(t, *from[1].Key, *got[1].Key) +} + +func TestColumnEvaluatorsDO2DTO(t *testing.T) { + from := &entity.ColumnEvaluator{ + EvaluatorVersionID: 1, + EvaluatorID: 2, + EvaluatorType: 3, + Name: gptr.Of("name1"), + Version: gptr.Of("v1"), + Description: gptr.Of("desc1"), + Builtin: gptr.Of(true), + } + + got := ColumnEvaluatorsDO2DTO(from) + + assert.Equal(t, from.EvaluatorVersionID, got.EvaluatorVersionID) + assert.Equal(t, from.EvaluatorID, got.EvaluatorID) + assert.Equal(t, evaluator.EvaluatorType(from.EvaluatorType), got.EvaluatorType) + assert.Equal(t, *from.Name, *got.Name) + assert.Equal(t, *from.Version, *got.Version) + assert.Equal(t, *from.Description, *got.Description) + assert.Equal(t, *from.Builtin, *got.Builtin) +} + +func TestExptColumnEvaluatorsDO2DTOs(t *testing.T) { + from := []*entity.ExptColumnEvaluator{ + { + ExptID: 101, + ColumnEvaluators: []*entity.ColumnEvaluator{ { - ItemID: 2, - TurnResults: []*entity.TurnResult{}, - Ext: map[string]string{ - "key2": "value2", - }, + Name: gptr.Of("eval1"), }, }, - want: []map[string]string{ - {"key1": "value1"}, - {"key2": "value2"}, - }, }, + } + + got := ExptColumnEvaluatorsDO2DTOs(from) + + assert.Len(t, got, 1) + assert.Equal(t, from[0].ExptID, got[0].ExperimentID) + assert.Len(t, got[0].ColumnEvaluators, 1) + assert.Equal(t, *from[0].ColumnEvaluators[0].Name, *got[0].ColumnEvaluators[0].Name) +} + +func TestTagValueDO2DtO(t *testing.T) { + from := &entity.TagValue{ + TagValueId: 1, + TagValueName: "tag1", + Status: "active", + } + + got := TagValueDO2DtO(from) + + assert.Equal(t, from.TagValueId, *got.TagValueID) + assert.Equal(t, from.TagValueName, *got.TagValueName) + assert.Equal(t, from.Status, *got.Status) +} + +func TestExptColumnAnnotationDO2DTOs(t *testing.T) { + from := []*entity.ExptColumnAnnotation{ { - name: "多个ItemResult,部分Ext字段为空", - from: []*entity.ItemResult{ + ExptID: 101, + ColumnAnnotations: []*entity.ColumnAnnotation{ { - ItemID: 1, - TurnResults: []*entity.TurnResult{}, - Ext: map[string]string{ - "key1": "value1", + TagName: "tag1", + TagContentSpec: &entity.TagContentSpec{ + ContinuousNumberSpec: &entity.ContinuousNumberSpec{ + MinValue: gptr.Of(float64(1)), + }, }, + TagStatus: "active", }, - { - ItemID: 2, - TurnResults: []*entity.TurnResult{}, - Ext: map[string]string{}, - }, - }, - want: []map[string]string{ - {"key1": "value1"}, - nil, }, }, - { - name: "空列表", - from: []*entity.ItemResult{}, - want: []map[string]string{}, + } + + got := ExptColumnAnnotationDO2DTOs(from) + + assert.Len(t, got, 1) + assert.Equal(t, from[0].ExptID, got[0].ExperimentID) + assert.Len(t, got[0].ColumnAnnotations, 1) + assert.Equal(t, from[0].ColumnAnnotations[0].TagName, *got[0].ColumnAnnotations[0].TagKeyName) + assert.NotNil(t, got[0].ColumnAnnotations[0].ContentSpec) + assert.Equal(t, *from[0].ColumnAnnotations[0].TagContentSpec.ContinuousNumberSpec.MinValue, *got[0].ColumnAnnotations[0].ContentSpec.ContinuousNumberSpec.MinValue) +} + +func TestTurnResultsDO2DTO(t *testing.T) { + from := &entity.TurnResult{ + TurnID: 1, + TurnIndex: gptr.Of(int64(0)), + ExperimentResults: []*entity.ExperimentResult{ + { + ExperimentID: 101, + Payload: &entity.ExperimentTurnPayload{ + TurnID: 1, + }, + }, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := ItemResultsDO2DTOs(tt.from) - assert.Equal(t, len(tt.from), len(got)) - for i, item := range got { - if i < len(tt.want) { - if tt.want[i] == nil { - assert.Nil(t, item.Ext) - } else { - assert.Equal(t, tt.want[i], item.Ext) - } - } - } - }) + got := TurnResultsDO2DTO(from) + + assert.Equal(t, from.TurnID, got.TurnID) + assert.Equal(t, from.TurnIndex, got.TurnIndex) + assert.Len(t, got.ExperimentResults, 1) +} + +func TestTurnAnnotationDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got := TurnAnnotationDO2DTO(nil) + assert.NotNil(t, got) + assert.Empty(t, got.AnnotateRecords) + }) + + t.Run("with records", func(t *testing.T) { + from := &entity.TurnAnnotateResult{ + AnnotateRecords: map[int64]*entity.AnnotateRecord{ + 1: { + ID: 1, + TagKeyID: 2, + AnnotateData: &entity.AnnotateData{ + Score: gptr.Of(float64(4.5)), + }, + }, + }, + } + + got := TurnAnnotationDO2DTO(from) + + assert.Len(t, got.AnnotateRecords, 1) + assert.Equal(t, "4.5", *got.AnnotateRecords[1].Score) + }) +} + +func TestTurnTrajectoryAnalysisResultDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got := TurnTrajectoryAnalysisResultDO2DTO(nil) + assert.NotNil(t, got) + }) + + t.Run("with data", func(t *testing.T) { + from := &entity.AnalysisRecord{ + ID: 1, + Status: 1, + } + got := TurnTrajectoryAnalysisResultDO2DTO(from) + assert.Equal(t, from.ID, *got.RecordID) + }) +} + +func TestTurnSystemInfoDO2DTO(t *testing.T) { + from := &entity.TurnSystemInfo{ + TurnRunState: 1, + LogID: gptr.Of("log1"), + Error: &entity.RunError{ + Code: 1, + Message: gptr.Of("msg1"), + }, } + + got := TurnSystemInfoDO2DTO(from) + + assert.Equal(t, int32(from.TurnRunState), int32(*got.TurnRunState)) + assert.Equal(t, from.LogID, got.LogID) + assert.Equal(t, from.Error.Code, got.Error.Code) +} + +func TestExportRecordDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got := ExportRecordDO2DTO(nil) + assert.Nil(t, got) + }) + + t.Run("with data", func(t *testing.T) { + now := time.Now() + from := &entity.ExptResultExportRecord{ + ID: 1, + SpaceID: 2, + ExptID: 101, + CsvExportStatus: entity.CSVExportStatus_Success, + CreatedBy: "3", + URL: gptr.Of("http://test.com"), + Expired: false, + StartAt: &now, + EndAt: &now, + ErrMsg: "", + } + + got := ExportRecordDO2DTO(from) + + assert.Equal(t, from.ID, got.ExportID) + assert.Equal(t, from.SpaceID, got.WorkspaceID) + assert.Equal(t, from.ExptID, got.ExptID) + assert.Equal(t, domain_expt.CSVExportStatusSuccess, got.CsvExportStatus) + assert.Equal(t, *from.URL, *got.URL) + assert.Equal(t, *from.URL, *got.URL_) + assert.Equal(t, from.StartAt.Unix(), *got.StartTime) + assert.Equal(t, from.EndAt.Unix(), *got.EndTime) + }) } diff --git a/backend/modules/evaluation/application/convertor/target/eval_target.go b/backend/modules/evaluation/application/convertor/target/eval_target.go index 4c35eb97f..1ddc355ec 100644 --- a/backend/modules/evaluation/application/convertor/target/eval_target.go +++ b/backend/modules/evaluation/application/convertor/target/eval_target.go @@ -166,7 +166,7 @@ func EvalTargetVersionDO2DTO(targetVersionDO *do.EvalTargetVersion) (targetVersi BaseInfo: commonconvertor.ConvertBaseInfoDO2DTO(targetVersionDO.CozeWorkflow.BaseInfo), } } - case do.EvalTargetTypeVolcengineAgent: + case do.EvalTargetTypeVolcengineAgent, do.EvalTargetTypeVolcengineAgentAgentkit: targetVersionDTO.EvalTargetContent = &dto.EvalTargetContent{ InputSchemas: make([]*commondto.ArgsSchema, 0), OutputSchemas: make([]*commondto.ArgsSchema, 0), @@ -186,6 +186,7 @@ func EvalTargetVersionDO2DTO(targetVersionDO *do.EvalTargetVersion) (targetVersi VolcengineAgentEndpoints: endpoints, Protocol: gptr.Of(gptr.Indirect(targetVersionDO.VolcengineAgent.Protocol)), BaseInfo: commonconvertor.ConvertBaseInfoDO2DTO(targetVersionDO.VolcengineAgent.BaseInfo), + RuntimeID: targetVersionDO.VolcengineAgent.RuntimeID, } } case do.EvalTargetTypeCustomRPCServer: diff --git a/backend/modules/evaluation/application/evaluator_app.go b/backend/modules/evaluation/application/evaluator_app.go index 0b3944593..6e43778fe 100644 --- a/backend/modules/evaluation/application/evaluator_app.go +++ b/backend/modules/evaluation/application/evaluator_app.go @@ -268,6 +268,9 @@ func (e *EvaluatorHandlerImpl) GetEvaluator(ctx context.Context, request *evalua // CreateEvaluator 创建 evaluator_version func (e *EvaluatorHandlerImpl) CreateEvaluator(ctx context.Context, request *evaluatorservice.CreateEvaluatorRequest) (resp *evaluatorservice.CreateEvaluatorResponse, err error) { + if request.GetEvaluator() != nil && request.GetEvaluator().GetWorkspaceID() == 0 { + request.Evaluator.WorkspaceID = request.WorkspaceID + } // 校验参数 if err = e.checkCreateEvaluatorRequest(ctx, request); err != nil { return nil, err diff --git a/backend/modules/evaluation/application/evaluator_app_test.go b/backend/modules/evaluation/application/evaluator_app_test.go index b4f50dbc7..130a2dc4e 100644 --- a/backend/modules/evaluation/application/evaluator_app_test.go +++ b/backend/modules/evaluation/application/evaluator_app_test.go @@ -9,6 +9,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "testing" "time" @@ -22,7 +23,7 @@ import ( benefitmocks "github.com/coze-dev/coze-loop/backend/infra/external/benefit/mocks" idgenmocks "github.com/coze-dev/coze-loop/backend/infra/idgen/mocks" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common" + common "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/common" evaluatordto "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/domain/evaluator" evaluatorservice "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/evaluation/evaluator" "github.com/coze-dev/coze-loop/backend/modules/evaluation/application/convertor/evaluator" @@ -315,21 +316,33 @@ func TestEvaluatorHandlerImpl_ListEvaluators(t *testing.T) { wantErr: false, }, { - name: "success - with evaluator type filter", + name: "success - builtin evaluators request with filters", req: &evaluatorservice.ListEvaluatorsRequest{ WorkspaceID: validSpaceID, + Builtin: gptr.Of(true), + SearchName: gptr.Of("builtin"), EvaluatorType: []evaluatordto.EvaluatorType{evaluatordto.EvaluatorType_Prompt}, }, mockSetup: func() { - mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) - mockEvaluatorService.EXPECT().ListEvaluator(gomock.Any(), gomock.Any()). - Return(validEvaluators[:1], int64(1), nil) + // Mock auth + mockAuth.EXPECT().Authorization(gomock.Any(), &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(validSpaceID, 10), + SpaceID: validSpaceID, + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("listLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }).Return(nil) + + // Mock builtin evaluator service call + mockEvaluatorService.EXPECT().ListBuiltinEvaluator(gomock.Any(), gomock.Any()). + Return(validEvaluators, int64(2), nil) + + // Mock user info service mockUserInfoService.EXPECT().PackUserInfo(gomock.Any(), gomock.Any()).Return() }, wantResp: &evaluatorservice.ListEvaluatorsResponse{ - Total: gptr.Of(int64(1)), + Total: gptr.Of(int64(2)), Evaluators: []*evaluatordto.Evaluator{ evaluator.ConvertEvaluatorDO2DTO(validEvaluators[0]), + evaluator.ConvertEvaluatorDO2DTO(validEvaluators[1]), }, }, wantErr: false, @@ -1056,370 +1069,692 @@ func TestEvaluatorHandlerImpl_BatchGetEvaluatorVersions(t *testing.T) { } } -// 新增的复杂业务逻辑测试 +func TestEvaluatorHandlerImpl_ListEvaluatorVersions(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() -// TestEvaluatorHandlerImpl_ComplexBusinessScenarios 测试复杂业务场景 -func TestEvaluatorHandlerImpl_ComplexBusinessScenarios(t *testing.T) { - t.Parallel() + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) + + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + userInfoService: mockUserInfoService, + } + + workspaceID := int64(100) + evaluatorID := int64(200) + evaluators := []*entity.Evaluator{ + { + ID: evaluatorID, + SpaceID: workspaceID, + EvaluatorType: entity.EvaluatorTypePrompt, + PromptEvaluatorVersion: &entity.PromptEvaluatorVersion{ + ID: 1, + EvaluatorID: evaluatorID, + Version: "1.0.0", + }, + }, + } tests := []struct { - name string - testFunc func(t *testing.T) + name string + req *evaluatorservice.ListEvaluatorVersionsRequest + mockSetup func() + wantErr bool + wantErrCode int32 }{ { - name: "多层依赖服务交互测试", - testFunc: func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // 创建所有依赖的 mock - mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) - mockConfiger := confmocks.NewMockIConfiger(ctrl) - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) - mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) - mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) - mockAuditClient := auditmocks.NewMockIAuditService(ctrl) - mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) - mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) - - mockExptResultService := mocks.NewMockExptResultService(ctrl) - handler := NewEvaluatorHandlerImpl( - mockIDGen, - mockConfiger, - mockAuth, - mockEvaluatorService, - mockEvaluatorRecordService, - nil, // mockEvaluatorTemplateService - 暂时设为nil - mockMetrics, - mockUserInfoService, - mockAuditClient, - mockBenefitService, - mockFileProvider, - make(map[entity.EvaluatorType]service.EvaluatorSourceService), - mockExptResultService, - ) + name: "success", + req: &evaluatorservice.ListEvaluatorVersionsRequest{ + WorkspaceID: workspaceID, + EvaluatorID: &evaluatorID, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().ListEvaluatorVersion(gomock.Any(), gomock.Any()). + Return(evaluators, int64(1), nil) + mockUserInfoService.EXPECT().PackUserInfo(gomock.Any(), gomock.Any()).Return() + }, + wantErr: false, + }, + { + name: "auth_failed", + req: &evaluatorservice.ListEvaluatorVersionsRequest{ + WorkspaceID: workspaceID, + EvaluatorID: &evaluatorID, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()). + Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + }, + wantErr: true, + wantErrCode: errno.CommonNoPermissionCode, + }, + { + name: "service_failed", + req: &evaluatorservice.ListEvaluatorVersionsRequest{ + WorkspaceID: workspaceID, + EvaluatorID: &evaluatorID, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().ListEvaluatorVersion(gomock.Any(), gomock.Any()). + Return(nil, int64(0), errors.New("db error")) + }, + wantErr: true, + }, + { + name: "success_with_params", + req: &evaluatorservice.ListEvaluatorVersionsRequest{ + WorkspaceID: workspaceID, + EvaluatorID: &evaluatorID, + PageSize: gptr.Of(int32(10)), + PageNumber: gptr.Of(int32(2)), + OrderBys: []*common.OrderBy{ + {Field: gptr.Of("id"), IsAsc: gptr.Of(true)}, + }, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().ListEvaluatorVersion(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *entity.ListEvaluatorVersionRequest) ([]*entity.Evaluator, int64, error) { + assert.Equal(t, int32(10), req.PageSize) + assert.Equal(t, int32(2), req.PageNum) + assert.Equal(t, "id", *req.OrderBys[0].Field) + return evaluators, int64(1), nil + }) + mockUserInfoService.EXPECT().PackUserInfo(gomock.Any(), gomock.Any()).Return() + }, + wantErr: false, + }, + } - // 测试复杂的调试场景,涉及多个服务交互 - request := &evaluatorservice.DebugEvaluatorRequest{ - WorkspaceID: 123, - EvaluatorType: evaluatordto.EvaluatorType_Prompt, - EvaluatorContent: &evaluatordto.EvaluatorContent{ - PromptEvaluator: &evaluatordto.PromptEvaluator{ - MessageList: []*common.Message{ - { - Role: common.RolePtr(common.Role_User), - Content: &common.Content{ - ContentType: gptr.Of(common.ContentTypeMultiPart), - MultiPart: []*common.Content{ - { - ContentType: gptr.Of(common.ContentTypeText), - Text: gptr.Of("请分析这张图片:"), - }, - { - ContentType: gptr.Of(common.ContentTypeImage), - Image: &common.Image{ - URI: gptr.Of("test-image-uri"), - }, - }, - }, - }, - }, - }, - }, - }, - InputData: &evaluatordto.EvaluatorInputData{ - InputFields: map[string]*common.Content{ - "image": { - ContentType: gptr.Of(common.ContentTypeImage), - Image: &common.Image{ - URI: gptr.Of("input-image-uri"), - }, - }, - }, - }, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.ListEvaluatorVersions(context.Background(), tt.req) + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, tt.wantErrCode, statusErr.Code()) } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, int64(1), *resp.Total) + assert.Len(t, resp.EvaluatorVersions, 1) + } + }) + } +} - // 设置复杂的 mock 期望 - // 1. 鉴权 - mockAuth.EXPECT(). - Authorization(gomock.Any(), &rpc.AuthorizationParam{ - ObjectID: "123", - SpaceID: int64(123), - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("debugLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, - }). - Return(nil). - Times(1) - - // 2. 权益检查 - mockBenefitService.EXPECT(). - CheckEvaluatorBenefit(gomock.Any(), &benefit.CheckEvaluatorBenefitParams{ - ConnectorUID: "", - SpaceID: 123, - }). - Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil). - Times(1) - - // 3. 文件 URI 转 URL - mockFileProvider.EXPECT(). - MGetFileURL(gomock.Any(), []string{"input-image-uri"}). - Return(map[string]string{"input-image-uri": "https://example.com/image.jpg"}, nil). - Times(1) - - // 4. 评估器调试 - mockEvaluatorService.EXPECT(). - DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, evaluator *entity.Evaluator, input *entity.EvaluatorInputData, evaluatorRunConf *entity.EvaluatorRunConfig, exptSpaceID int64) (*entity.EvaluatorOutputData, error) { - // 验证输入数据已被正确处理 - assert.Equal(t, int64(123), evaluator.SpaceID) - assert.Equal(t, entity.EvaluatorTypePrompt, evaluator.EvaluatorType) +func TestEvaluatorHandlerImpl_SubmitEvaluatorVersion(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - // 验证 URI 已转换为 URL - imageContent := input.InputFields["image"] - assert.NotNil(t, imageContent) - assert.NotNil(t, imageContent.Image) - assert.Equal(t, "https://example.com/image.jpg", gptr.Indirect(imageContent.Image.URL)) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) - return &entity.EvaluatorOutputData{ - EvaluatorResult: &entity.EvaluatorResult{ - Score: gptr.Of(0.85), - Reasoning: "多模态内容分析完成", - }, - }, nil - }). - Times(1) + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + auditClient: mockAuditClient, + } - ctx := context.Background() - resp, err := handler.DebugEvaluator(ctx, request) + workspaceID := int64(100) + evaluatorID := int64(200) + version := "1.0.0" + evaluatorDO := &entity.Evaluator{ + ID: evaluatorID, + SpaceID: workspaceID, + } - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.EvaluatorOutputData) - assert.Equal(t, 0.85, gptr.Indirect(resp.EvaluatorOutputData.EvaluatorResult_.Score)) + tests := []struct { + name string + req *evaluatorservice.SubmitEvaluatorVersionRequest + mockSetup func() + wantErr bool + wantErrCode int32 + }{ + { + name: "success", + req: &evaluatorservice.SubmitEvaluatorVersionRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Version: version, }, + mockSetup: func() { + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()). + Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false). + Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().SubmitEvaluatorVersion(gomock.Any(), evaluatorDO, version, gomock.Any(), gomock.Any()). + Return(evaluatorDO, nil) + }, + wantErr: false, }, { - name: "权限验证和审核流程测试", - testFunc: func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockAuditClient := auditmocks.NewMockIAuditService(ctrl) - mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + name: "invalid_version", + req: &evaluatorservice.SubmitEvaluatorVersionRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Version: "invalid", + }, + mockSetup: func() {}, + wantErr: true, + }, + { + name: "audit_rejected", + req: &evaluatorservice.SubmitEvaluatorVersionRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Version: version, + }, + mockSetup: func() { + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()). + Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Rejected}, nil) + }, + wantErr: true, + wantErrCode: errno.RiskContentDetectedCode, + }, + { + name: "version_too_long", + req: &evaluatorservice.SubmitEvaluatorVersionRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Version: string(make([]byte, consts.MaxEvaluatorVersionLength+1)), + }, + mockSetup: func() {}, + wantErr: true, + }, + { + name: "description_too_long", + req: &evaluatorservice.SubmitEvaluatorVersionRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Version: version, + Description: gptr.Of(string(make([]byte, consts.MaxEvaluatorVersionDescLength+1))), + }, + mockSetup: func() {}, + wantErr: true, + }, + { + name: "evaluator_not_found", + req: &evaluatorservice.SubmitEvaluatorVersionRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Version: version, + }, + mockSetup: func() { + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()). + Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false). + Return(nil, nil) + }, + wantErr: true, + wantErrCode: errno.EvaluatorNotExistCode, + }, + } - handler := &EvaluatorHandlerImpl{ - auth: mockAuth, - evaluatorService: mockEvaluatorService, - auditClient: mockAuditClient, - metrics: mockMetrics, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.SubmitEvaluatorVersion(context.Background(), tt.req) + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, tt.wantErrCode, statusErr.Code()) } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + }) + } +} - // 测试包含敏感内容的创建请求 - request := &evaluatorservice.CreateEvaluatorRequest{ - Evaluator: &evaluatordto.Evaluator{ - WorkspaceID: gptr.Of(int64(123)), - Name: gptr.Of("敏感内容评估器"), - Description: gptr.Of("包含敏感词汇的描述"), - EvaluatorType: gptr.Of(evaluatordto.EvaluatorType_Prompt), - CurrentVersion: &evaluatordto.EvaluatorVersion{ - Version: gptr.Of("1.0.0"), - Description: gptr.Of("版本描述包含敏感内容"), - EvaluatorContent: &evaluatordto.EvaluatorContent{ - PromptEvaluator: &evaluatordto.PromptEvaluator{}, - }, - }, - }, - } +func TestEvaluatorHandlerImpl_CheckEvaluatorName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - // 设置审核被拒绝的场景 - mockAuditClient.EXPECT(). - Audit(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, param audit.AuditParam) (audit.AuditRecord, error) { - // 验证审核参数 - assert.Equal(t, audit.AuditType_CozeLoopEvaluatorModify, param.AuditType) - assert.Contains(t, param.AuditData["texts"], "敏感内容评估器") + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - return audit.AuditRecord{ - AuditStatus: audit.AuditStatus_Rejected, - FailedReason: gptr.Of("内容包含敏感词汇"), - }, nil - }). - Times(1) + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + } - ctx := context.Background() - _, err := handler.CreateEvaluator(ctx, request) - assert.Error(t, err) + workspaceID := int64(100) + name := "test-name" - // 验证错误类型 - statusErr, ok := errorx.FromStatusError(err) - assert.True(t, ok) - assert.Equal(t, int32(errno.RiskContentDetectedCode), statusErr.Code()) + tests := []struct { + name string + req *evaluatorservice.CheckEvaluatorNameRequest + mockSetup func() + wantPass bool + wantErr bool + }{ + { + name: "pass", + req: &evaluatorservice.CheckEvaluatorNameRequest{ + WorkspaceID: workspaceID, + Name: name, }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().CheckNameExist(gomock.Any(), workspaceID, gomock.Any(), name). + Return(false, nil) + }, + wantPass: true, + wantErr: false, }, { - name: "并发安全和数据一致性测试", - testFunc: func(t *testing.T) { - t.Parallel() + name: "name_exists", + req: &evaluatorservice.CheckEvaluatorNameRequest{ + WorkspaceID: workspaceID, + Name: name, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().CheckNameExist(gomock.Any(), workspaceID, gomock.Any(), name). + Return(true, nil) + }, + wantPass: false, + wantErr: false, + }, + { + name: "auth_failed", + req: &evaluatorservice.CheckEvaluatorNameRequest{ + WorkspaceID: workspaceID, + Name: name, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errors.New("auth failed")) + }, + wantErr: true, + }, + } - ctrl := gomock.NewController(t) - defer ctrl.Finish() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.CheckEvaluatorName(context.Background(), tt.req) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantPass, *resp.Pass) + } + }) + } +} - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) +func TestEvaluatorHandlerImpl_GetEvaluatorRecord(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - handler := &EvaluatorHandlerImpl{ - auth: mockAuth, - evaluatorService: mockEvaluatorService, - userInfoService: mockUserInfoService, - } + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) + mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) - // 模拟并发访问同一个评估器 - evaluatorID := int64(123) - spaceID := int64(456) + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + evaluatorRecordService: mockEvaluatorRecordService, + userInfoService: mockUserInfoService, + } - evaluator := &entity.Evaluator{ - ID: evaluatorID, - SpaceID: spaceID, - Name: "并发测试评估器", - } + recordID := int64(10) + versionID := int64(20) + spaceID := int64(100) + record := &entity.EvaluatorRecord{ + ID: recordID, + EvaluatorVersionID: versionID, + SpaceID: spaceID, + } + evaluatorDO := &entity.Evaluator{ + ID: 1, + SpaceID: spaceID, + } - // 设置并发调用的期望 - mockEvaluatorService.EXPECT(). - GetEvaluator(gomock.Any(), spaceID, evaluatorID, false). - Return(evaluator, nil). - Times(10) // 10个并发请求 + tests := []struct { + name string + req *evaluatorservice.GetEvaluatorRecordRequest + mockSetup func() + wantErr bool + }{ + { + name: "success", + req: &evaluatorservice.GetEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false). + Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Any(), versionID, false, false). + Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockUserInfoService.EXPECT().PackUserInfo(gomock.Any(), gomock.Any()).Return() + }, + wantErr: false, + }, + { + name: "record_not_found", + req: &evaluatorservice.GetEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false). + Return(nil, nil) + }, + wantErr: false, + }, + { + name: "evaluator_not_found", + req: &evaluatorservice.GetEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false). + Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Any(), versionID, false, false). + Return(nil, nil) + }, + wantErr: false, + }, + { + name: "auth_failed", + req: &evaluatorservice.GetEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false). + Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Any(), versionID, false, false). + Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errors.New("auth failed")) + }, + wantErr: true, + }, + } - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(nil). - Times(10) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.GetEvaluatorRecord(context.Background(), tt.req) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + }) + } +} - mockUserInfoService.EXPECT(). - PackUserInfo(gomock.Any(), gomock.Any()). - Times(10) +func TestEvaluatorHandlerImpl_BatchGetEvaluatorRecords(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - // 并发调用 - const numGoroutines = 10 - results := make(chan error, numGoroutines) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) - for i := 0; i < numGoroutines; i++ { - go func() { - ctx := context.Background() - request := &evaluatorservice.GetEvaluatorRequest{ - WorkspaceID: spaceID, - EvaluatorID: &evaluatorID, - } + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorRecordService: mockEvaluatorRecordService, + } - resp, err := handler.GetEvaluator(ctx, request) - if err != nil { - results <- err - return - } + recordIDs := []int64{10, 11} + spaceID := int64(100) + records := []*entity.EvaluatorRecord{ + {ID: 10, SpaceID: spaceID}, + {ID: 11, SpaceID: spaceID}, + } - // 验证响应数据一致性 - if resp.Evaluator.GetEvaluatorID() != evaluatorID { - results <- fmt.Errorf("inconsistent evaluator ID: expected %d, got %d", - evaluatorID, resp.Evaluator.GetEvaluatorID()) - return - } + tests := []struct { + name string + req *evaluatorservice.BatchGetEvaluatorRecordsRequest + mockSetup func() + wantErr bool + }{ + { + name: "success", + req: &evaluatorservice.BatchGetEvaluatorRecordsRequest{ + EvaluatorRecordIds: recordIDs, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().BatchGetEvaluatorRecord(gomock.Any(), recordIDs, false). + Return(records, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "empty_records", + req: &evaluatorservice.BatchGetEvaluatorRecordsRequest{ + EvaluatorRecordIds: recordIDs, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().BatchGetEvaluatorRecord(gomock.Any(), recordIDs, false). + Return(nil, nil) + }, + wantErr: false, + }, + { + name: "service_error", + req: &evaluatorservice.BatchGetEvaluatorRecordsRequest{ + EvaluatorRecordIds: recordIDs, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().BatchGetEvaluatorRecord(gomock.Any(), recordIDs, false). + Return(nil, errors.New("db error")) + }, + wantErr: true, + }, + } - results <- nil - }() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.BatchGetEvaluatorRecords(context.Background(), tt.req) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + if tt.name == "success" { + assert.Len(t, resp.Records, 2) + } else { + assert.Len(t, resp.Records, 0) } + } + }) + } +} - // 收集结果 - for i := 0; i < numGoroutines; i++ { - select { - case err := <-results: - assert.NoError(t, err) - case <-time.After(5 * time.Second): - t.Fatal("Timeout waiting for concurrent calls") - } - } - }, +func TestEvaluatorHandlerImpl_GetDefaultPromptEvaluatorTools(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConfiger := confmocks.NewMockIConfiger(ctrl) + + app := &EvaluatorHandlerImpl{ + configer: mockConfiger, + } + + toolsConf := map[string]*evaluatordto.Tool{ + consts.DefaultEvaluatorToolKey: { + Type: evaluatordto.ToolType_Function, + Function: &evaluatordto.Function{Name: "default-tool"}, }, + } + + mockConfiger.EXPECT().GetEvaluatorToolConf(gomock.Any()).Return(toolsConf) + + resp, err := app.GetDefaultPromptEvaluatorTools(context.Background(), &evaluatorservice.GetDefaultPromptEvaluatorToolsRequest{}) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Len(t, resp.Tools, 1) + assert.Equal(t, "default-tool", resp.Tools[0].Function.Name) +} + +// 新增的复杂业务逻辑测试 + +// TestEvaluatorHandlerImpl_ComplexBusinessScenarios 测试复杂业务场景 +func TestEvaluatorHandlerImpl_ComplexBusinessScenarios(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + testFunc func(t *testing.T) + }{ { - name: "错误处理和恢复机制测试", + name: "多层依赖服务交互测试", testFunc: func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() + // 创建所有依赖的 mock + mockIDGen := idgenmocks.NewMockIIDGenerator(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) + mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) + mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) - handler := &EvaluatorHandlerImpl{ - auth: mockAuth, - evaluatorService: mockEvaluatorService, - evaluatorRecordService: mockEvaluatorRecordService, - } + mockExptResultService := mocks.NewMockExptResultService(ctrl) + handler := NewEvaluatorHandlerImpl( + mockIDGen, + mockConfiger, + mockAuth, + mockEvaluatorService, + mockEvaluatorRecordService, + nil, // mockEvaluatorTemplateService - 暂时设为nil + mockMetrics, + mockUserInfoService, + mockAuditClient, + mockBenefitService, + mockFileProvider, + make(map[entity.EvaluatorType]service.EvaluatorSourceService), + mockExptResultService, + ) - // 测试运行评估器时的错误恢复 - request := &evaluatorservice.RunEvaluatorRequest{ - EvaluatorVersionID: 123, - WorkspaceID: 456, + // 测试复杂的调试场景,涉及多个服务交互 + request := &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: 123, + EvaluatorType: evaluatordto.EvaluatorType_Prompt, + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{ + MessageList: []*common.Message{ + { + Role: common.RolePtr(common.Role_User), + Content: &common.Content{ + ContentType: gptr.Of(common.ContentTypeMultiPart), + MultiPart: []*common.Content{ + { + ContentType: gptr.Of(common.ContentTypeText), + Text: gptr.Of("请分析这张图片:"), + }, + { + ContentType: gptr.Of(common.ContentTypeImage), + Image: &common.Image{ + URI: gptr.Of("test-image-uri"), + }, + }, + }, + }, + }, + }, + }, + }, InputData: &evaluatordto.EvaluatorInputData{ - InputFields: map[string]*common.Content{}, + InputFields: map[string]*common.Content{ + "image": { + ContentType: gptr.Of(common.ContentTypeImage), + Image: &common.Image{ + URI: gptr.Of("input-image-uri"), + }, + }, + }, }, } - // 第一次调用失败,第二次成功(模拟重试机制) - callCount := 0 - mockEvaluatorService.EXPECT(). - GetEvaluatorVersion(gomock.Any(), gomock.Any(), int64(123), false, gomock.Any()). - DoAndReturn(func(ctx context.Context, spaceID *int64, evaluatorVersionID int64, includeDeleted bool, withTags bool) (*entity.Evaluator, error) { - callCount++ - if callCount == 1 { - return nil, errors.New("temporary database error") - } - return &entity.Evaluator{ - ID: 1, - SpaceID: 456, - Name: "test-evaluator", - }, nil - }). - Times(2) - + // 设置复杂的 mock 期望 + // 1. 鉴权 mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). + Authorization(gomock.Any(), &rpc.AuthorizationParam{ + ObjectID: "123", + SpaceID: int64(123), + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("debugLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }). Return(nil). Times(1) + // 2. 权益检查 + mockBenefitService.EXPECT(). + CheckEvaluatorBenefit(gomock.Any(), &benefit.CheckEvaluatorBenefitParams{ + ConnectorUID: "", + SpaceID: 123, + }). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil). + Times(1) + + // 3. 文件 URI 转 URL + mockFileProvider.EXPECT(). + MGetFileURL(gomock.Any(), []string{"input-image-uri"}). + Return(map[string]string{"input-image-uri": "https://example.com/image.jpg"}, nil). + Times(1) + + // 4. 评估器调试 mockEvaluatorService.EXPECT(). - RunEvaluator(gomock.Any(), gomock.Any()). - Return(&entity.EvaluatorRecord{ - ID: 789, - EvaluatorVersionID: 123, - SpaceID: 456, - }, nil). + DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, evaluator *entity.Evaluator, input *entity.EvaluatorInputData, evaluatorRunConf *entity.EvaluatorRunConfig, exptSpaceID int64) (*entity.EvaluatorOutputData, error) { + // 验证输入数据已被正确处理 + assert.Equal(t, int64(123), evaluator.SpaceID) + assert.Equal(t, entity.EvaluatorTypePrompt, evaluator.EvaluatorType) + + // 验证 URI 已转换为 URL + imageContent := input.InputFields["image"] + assert.NotNil(t, imageContent) + assert.NotNil(t, imageContent.Image) + assert.Equal(t, "https://example.com/image.jpg", gptr.Indirect(imageContent.Image.URL)) + + return &entity.EvaluatorOutputData{ + EvaluatorResult: &entity.EvaluatorResult{ + Score: gptr.Of(0.85), + Reasoning: "多模态内容分析完成", + }, + }, nil + }). Times(1) ctx := context.Background() + resp, err := handler.DebugEvaluator(ctx, request) - // 第一次调用应该失败 - resp1, err1 := handler.RunEvaluator(ctx, request) - assert.Error(t, err1) - assert.Nil(t, resp1) - - // 第二次调用应该成功 - resp2, err2 := handler.RunEvaluator(ctx, request) - assert.NoError(t, err2) - assert.NotNil(t, resp2) - assert.Equal(t, int64(789), resp2.Record.GetID()) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.NotNil(t, resp.EvaluatorOutputData) + assert.Equal(t, 0.85, gptr.Indirect(resp.EvaluatorOutputData.EvaluatorResult_.Score)) }, }, { - name: "大数据量处理性能测试", + name: "权限验证和审核流程测试", testFunc: func(t *testing.T) { t.Parallel() @@ -1428,22 +1763,239 @@ func TestEvaluatorHandlerImpl_ComplexBusinessScenarios(t *testing.T) { mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) handler := &EvaluatorHandlerImpl{ auth: mockAuth, evaluatorService: mockEvaluatorService, - userInfoService: mockUserInfoService, + auditClient: mockAuditClient, + metrics: mockMetrics, } - // 创建大量评估器数据 - const numEvaluators = 1000 - evaluators := make([]*entity.Evaluator, numEvaluators) - for i := 0; i < numEvaluators; i++ { - evaluators[i] = &entity.Evaluator{ - ID: int64(i + 1), - SpaceID: 123, - Name: fmt.Sprintf("evaluator-%d", i+1), + // 测试包含敏感内容的创建请求 + request := &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(int64(123)), + Name: gptr.Of("敏感内容评估器"), + Description: gptr.Of("包含敏感词汇的描述"), + EvaluatorType: gptr.Of(evaluatordto.EvaluatorType_Prompt), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + Description: gptr.Of("版本描述包含敏感内容"), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{}, + }, + }, + }, + } + + // 设置审核被拒绝的场景 + mockAuditClient.EXPECT(). + Audit(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, param audit.AuditParam) (audit.AuditRecord, error) { + // 验证审核参数 + assert.Equal(t, audit.AuditType_CozeLoopEvaluatorModify, param.AuditType) + assert.Contains(t, param.AuditData["texts"], "敏感内容评估器") + + return audit.AuditRecord{ + AuditStatus: audit.AuditStatus_Rejected, + FailedReason: gptr.Of("内容包含敏感词汇"), + }, nil + }). + Times(1) + + ctx := context.Background() + _, err := handler.CreateEvaluator(ctx, request) + assert.Error(t, err) + + // 验证错误类型 + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, int32(errno.RiskContentDetectedCode), statusErr.Code()) + }, + }, + { + name: "并发安全和数据一致性测试", + testFunc: func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) + + handler := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + userInfoService: mockUserInfoService, + } + + // 模拟并发访问同一个评估器 + evaluatorID := int64(123) + spaceID := int64(456) + + evaluator := &entity.Evaluator{ + ID: evaluatorID, + SpaceID: spaceID, + Name: "并发测试评估器", + } + + // 设置并发调用的期望 + mockEvaluatorService.EXPECT(). + GetEvaluator(gomock.Any(), spaceID, evaluatorID, false). + Return(evaluator, nil). + Times(10) // 10个并发请求 + + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil). + Times(10) + + mockUserInfoService.EXPECT(). + PackUserInfo(gomock.Any(), gomock.Any()). + Times(10) + + // 并发调用 + const numGoroutines = 10 + results := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + ctx := context.Background() + request := &evaluatorservice.GetEvaluatorRequest{ + WorkspaceID: spaceID, + EvaluatorID: &evaluatorID, + } + + resp, err := handler.GetEvaluator(ctx, request) + if err != nil { + results <- err + return + } + + // 验证响应数据一致性 + if resp.Evaluator.GetEvaluatorID() != evaluatorID { + results <- fmt.Errorf("inconsistent evaluator ID: expected %d, got %d", + evaluatorID, resp.Evaluator.GetEvaluatorID()) + return + } + + results <- nil + }() + } + + // 收集结果 + for i := 0; i < numGoroutines; i++ { + select { + case err := <-results: + assert.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for concurrent calls") + } + } + }, + }, + { + name: "错误处理和恢复机制测试", + testFunc: func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) + + handler := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + evaluatorRecordService: mockEvaluatorRecordService, + } + + // 测试运行评估器时的错误恢复 + request := &evaluatorservice.RunEvaluatorRequest{ + EvaluatorVersionID: 123, + WorkspaceID: 456, + InputData: &evaluatordto.EvaluatorInputData{ + InputFields: map[string]*common.Content{}, + }, + } + + // 第一次调用失败,第二次成功(模拟重试机制) + callCount := 0 + mockEvaluatorService.EXPECT(). + GetEvaluatorVersion(gomock.Any(), gomock.Any(), int64(123), false, gomock.Any()). + DoAndReturn(func(ctx context.Context, spaceID *int64, evaluatorVersionID int64, includeDeleted bool, withTags bool) (*entity.Evaluator, error) { + callCount++ + if callCount == 1 { + return nil, errors.New("temporary database error") + } + return &entity.Evaluator{ + ID: 1, + SpaceID: 456, + Name: "test-evaluator", + }, nil + }). + Times(2) + + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + + mockEvaluatorService.EXPECT(). + RunEvaluator(gomock.Any(), gomock.Any()). + Return(&entity.EvaluatorRecord{ + ID: 789, + EvaluatorVersionID: 123, + SpaceID: 456, + }, nil). + Times(1) + + ctx := context.Background() + + // 第一次调用应该失败 + resp1, err1 := handler.RunEvaluator(ctx, request) + assert.Error(t, err1) + assert.Nil(t, resp1) + + // 第二次调用应该成功 + resp2, err2 := handler.RunEvaluator(ctx, request) + assert.NoError(t, err2) + assert.NotNil(t, resp2) + assert.Equal(t, int64(789), resp2.Record.GetID()) + }, + }, + { + name: "大数据量处理性能测试", + testFunc: func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) + + handler := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + userInfoService: mockUserInfoService, + } + + // 创建大量评估器数据 + const numEvaluators = 1000 + evaluators := make([]*entity.Evaluator, numEvaluators) + for i := 0; i < numEvaluators; i++ { + evaluators[i] = &entity.Evaluator{ + ID: int64(i + 1), + SpaceID: 123, + Name: fmt.Sprintf("evaluator-%d", i+1), } } @@ -1776,7 +2328,6 @@ func TestEvaluatorHandlerImpl_ComplexBusinessScenarios(t *testing.T) { } } -// TestEvaluatorHandlerImpl_EdgeCasesAndBoundaryConditions 测试边界条件 func TestEvaluatorHandlerImpl_EdgeCasesAndBoundaryConditions(t *testing.T) { t.Parallel() @@ -1990,6 +2541,7 @@ func TestEvaluatorHandlerImpl_ListTemplates_Code(t *testing.T) { tests := []struct { name string request *evaluatorservice.ListTemplatesRequest + mockSetup func() expectedKeys []string }{ { @@ -2006,11 +2558,55 @@ func TestEvaluatorHandlerImpl_ListTemplates_Code(t *testing.T) { }, expectedKeys: []string{"js_template_1", "python_template_1", "python_template_2"}, // 按template_key去重后排序 }, + { + name: "Code类型-配置为空", + request: &evaluatorservice.ListTemplatesRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Code, + }, + mockSetup: func() { + mockConfiger.EXPECT().GetCodeEvaluatorTemplateConf(gomock.Any()).Return(nil) + }, + expectedKeys: []string{}, + }, + { + name: "Prompt类型", + request: &evaluatorservice.ListTemplatesRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Prompt, + }, + mockSetup: func() { + promptTemplates := map[string]map[string]*evaluatordto.EvaluatorContent{ + "prompt": { + "key1": { + PromptEvaluator: &evaluatordto.PromptEvaluator{ + PromptTemplateKey: gptr.Of("key1"), + PromptTemplateName: gptr.Of("name1"), + }, + }, + }, + } + mockConfiger.EXPECT().GetEvaluatorTemplateConf(gomock.Any()).Return(promptTemplates) + }, + expectedKeys: []string{"key1"}, + }, + { + name: "Prompt类型-配置为空", + request: &evaluatorservice.ListTemplatesRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Prompt, + }, + mockSetup: func() { + mockConfiger.EXPECT().GetEvaluatorTemplateConf(gomock.Any()).Return(make(map[string]map[string]*evaluatordto.EvaluatorContent)) + }, + expectedKeys: []string{}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockConfiger.EXPECT().GetCodeEvaluatorTemplateConf(gomock.Any()).Return(codeTemplateConf) + if tt.mockSetup != nil { + tt.mockSetup() + } else { + mockConfiger.EXPECT().GetCodeEvaluatorTemplateConf(gomock.Any()).Return(codeTemplateConf) + } resp, err := handler.ListTemplates(context.Background(), tt.request) @@ -2023,6 +2619,8 @@ func TestEvaluatorHandlerImpl_ListTemplates_Code(t *testing.T) { for i, template := range resp.BuiltinTemplateKeys { if template.GetCodeEvaluator() != nil { actualKeys[i] = template.GetCodeEvaluator().GetCodeTemplateKey() + } else if template.GetPromptEvaluator() != nil { + actualKeys[i] = template.GetPromptEvaluator().GetPromptTemplateKey() } } @@ -2033,16 +2631,295 @@ func TestEvaluatorHandlerImpl_ListTemplates_Code(t *testing.T) { } } -// 新增:运行配置参数透传与扩展字段注入 -func TestEvaluatorHandlerImpl_DebugEvaluator_RuntimeParamExt(t *testing.T) { - t.Skip("暂时跳过:依赖外部 benefitService 行为,已通过 buildRunEvaluatorRequest 的单测验证 runtime_param 注入") +func TestEvaluatorHandlerImpl_GetTemplateInfo(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) - + mockConfiger := confmocks.NewMockIConfiger(ctrl) + handler := &EvaluatorHandlerImpl{ + configer: mockConfiger, + } + + ctx := context.Background() + + t.Run("Prompt success", func(t *testing.T) { + promptTemplates := map[string]map[string]*evaluatordto.EvaluatorContent{ + "prompt": { + "key1": { + PromptEvaluator: &evaluatordto.PromptEvaluator{ + PromptTemplateKey: gptr.Of("key1"), + }, + }, + }, + } + mockConfiger.EXPECT().GetEvaluatorTemplateConf(gomock.Any()).Return(promptTemplates) + resp, err := handler.GetTemplateInfo(ctx, &evaluatorservice.GetTemplateInfoRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Prompt, + BuiltinTemplateKey: "key1", + }) + assert.NoError(t, err) + assert.Equal(t, "key1", *resp.EvaluatorContent.PromptEvaluator.PromptTemplateKey) + }) + + t.Run("Code success python", func(t *testing.T) { + codeTemplates := map[string]map[string]*evaluatordto.EvaluatorContent{ + "key1": { + "Python": { + CodeEvaluator: &evaluatordto.CodeEvaluator{ + CodeTemplateKey: gptr.Of("key1"), + }, + }, + }, + } + mockConfiger.EXPECT().GetCodeEvaluatorTemplateConf(gomock.Any()).Return(codeTemplates) + resp, err := handler.GetTemplateInfo(ctx, &evaluatorservice.GetTemplateInfoRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Code, + BuiltinTemplateKey: "key1", + }) + assert.NoError(t, err) + assert.Equal(t, "key1", *resp.EvaluatorContent.CodeEvaluator.CodeTemplateKey) + }) + + t.Run("Code custom", func(t *testing.T) { + customTemplates := map[string]map[string]*evaluatordto.EvaluatorContent{ + "custom": { + "Python": { + CodeEvaluator: &evaluatordto.CodeEvaluator{ + CodeTemplateKey: gptr.Of("custom"), + }, + }, + }, + } + mockConfiger.EXPECT().GetCustomCodeEvaluatorTemplateConf(gomock.Any()).Return(customTemplates) + resp, err := handler.GetTemplateInfo(ctx, &evaluatorservice.GetTemplateInfoRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Code, + BuiltinTemplateKey: "custom", + }) + assert.NoError(t, err) + assert.Equal(t, "custom", *resp.EvaluatorContent.CodeEvaluator.CodeTemplateKey) + }) + + t.Run("not found", func(t *testing.T) { + mockConfiger.EXPECT().GetEvaluatorTemplateConf(gomock.Any()).Return(make(map[string]map[string]*evaluatordto.EvaluatorContent)) + _, err := handler.GetTemplateInfo(ctx, &evaluatorservice.GetTemplateInfoRequest{ + BuiltinTemplateType: evaluatordto.TemplateType_Prompt, + BuiltinTemplateKey: "non-existent", + }) + assert.Error(t, err) + }) +} + +// 新增:运行配置参数透传与扩展字段注入 +func TestEvaluatorHandlerImpl_DebugEvaluator_Comprehensive(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) + + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + benefitService: mockBenefitService, + configer: mockConfiger, + fileProvider: mockFileProvider, + } + + workspaceID := int64(100) + ctx := context.Background() + + tests := []struct { + name string + req *evaluatorservice.DebugEvaluatorRequest + mockSetup func() + wantErr bool + }{ + { + name: "success_prompt", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorType: evaluatordto.EvaluatorType_Prompt, + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{}, + }, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(&entity.EvaluatorOutputData{}, nil) + }, + wantErr: false, + }, + { + name: "success_prompt_with_uris", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorType: evaluatordto.EvaluatorType_Prompt, + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{ + Tools: []*evaluatordto.Tool{ + {Function: &evaluatordto.Function{Name: "test_tool"}}, + }, + }, + }, + InputData: &evaluatordto.EvaluatorInputData{ + InputFields: map[string]*common.Content{ + "field1": { + ContentType: gptr.Of(common.ContentTypeImage), + Image: &common.Image{ + URI: gptr.Of("uri1"), + }, + }, + }, + }, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockFileProvider.EXPECT().MGetFileURL(gomock.Any(), []string{"uri1"}).Return(map[string]string{"uri1": "url1"}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(&entity.EvaluatorOutputData{}, nil) + }, + wantErr: false, + }, + { + name: "success_prompt_with_multipart_uris", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorType: evaluatordto.EvaluatorType_Prompt, + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{ + Tools: []*evaluatordto.Tool{ + {Function: &evaluatordto.Function{Name: "test_tool"}}, + }, + }, + }, + InputData: &evaluatordto.EvaluatorInputData{ + InputFields: map[string]*common.Content{ + "field1": { + ContentType: gptr.Of(common.ContentTypeMultiPart), + MultiPart: []*common.Content{ + { + ContentType: gptr.Of(common.ContentTypeImage), + Image: &common.Image{ + URI: gptr.Of("uri1"), + }, + }, + }, + }, + }, + }, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockFileProvider.EXPECT().MGetFileURL(gomock.Any(), []string{"uri1"}).Return(map[string]string{"uri1": "url1"}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(&entity.EvaluatorOutputData{}, nil) + }, + wantErr: false, + }, + { + name: "success_custom_rpc", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorType: evaluatordto.EvaluatorType_CustomRPC, + EvaluatorContent: &evaluatordto.EvaluatorContent{ + CustomRPCEvaluator: &evaluatordto.CustomRPCEvaluator{}, + }, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + // authCustomRPCEvaluatorContentWritable + mockConfiger.EXPECT().GetBuiltinEvaluatorSpaceConf(gomock.Any()).Return([]string{"100"}) + mockConfiger.EXPECT().CheckCustomRPCEvaluatorWritable(gomock.Any(), "100", []string{"100"}).Return(true, nil) + + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(&entity.EvaluatorOutputData{}, nil) + }, + wantErr: false, + }, + { + name: "benefit_denied", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: gptr.Of(benefit.DenyReason(1))}, nil) + }, + wantErr: true, + }, + { + name: "benefit_error", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(nil, errors.New("benefit service error")) + }, + wantErr: true, + }, + { + name: "auth_failed", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errors.New("auth failed")) + }, + wantErr: true, + }, + { + name: "custom_rpc_auth_failed", + req: &evaluatorservice.DebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorType: evaluatordto.EvaluatorType_CustomRPC, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + // authCustomRPCEvaluatorContentWritable failed + mockConfiger.EXPECT().GetBuiltinEvaluatorSpaceConf(gomock.Any()).Return([]string{"100"}) + mockConfiger.EXPECT().CheckCustomRPCEvaluatorWritable(gomock.Any(), "100", []string{"100"}).Return(false, nil) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.DebugEvaluator(ctx, tt.req) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + }) + } +} + +func TestEvaluatorHandlerImpl_DebugEvaluator_RuntimeParamExt(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) + handler := &EvaluatorHandlerImpl{ auth: mockAuth, evaluatorService: mockEvaluatorService, @@ -3234,254 +4111,698 @@ func TestEvaluatorHandlerImpl_BatchDebugEvaluator(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - // 为每个测试用例创建独立的 mock - ctrl := gomock.NewController(t) - defer ctrl.Finish() + // 为每个测试用例创建独立的 mock + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) + mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) + + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + benefitService: mockBenefitService, + evaluatorService: mockEvaluatorService, + fileProvider: mockFileProvider, + } + + tt.mockSetup(mockAuth, mockBenefitService, mockEvaluatorService, mockFileProvider) + + resp, err := app.BatchDebugEvaluator(context.Background(), tt.req) + + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, tt.wantErrCode, statusErr.Code()) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, len(tt.wantResp.EvaluatorOutputData), len(resp.EvaluatorOutputData)) + + // 验证结果数量 + assert.Equal(t, len(tt.wantResp.EvaluatorOutputData), len(resp.EvaluatorOutputData)) + + // 对于特定测试用例,验证错误处理逻辑 + if tt.name == "edge case - evaluator service returns nil output with error" { + assert.NotNil(t, resp.EvaluatorOutputData[0].EvaluatorRunError) + assert.Equal(t, int32(500), *resp.EvaluatorOutputData[0].EvaluatorRunError.Code) + assert.Equal(t, "code execution failed", *resp.EvaluatorOutputData[0].EvaluatorRunError.Message) + } + } + }) + } +} + +// TestEvaluatorHandlerImpl_ListTemplatesV2 测试 ListTemplatesV2 方法 +func TestEvaluatorHandlerImpl_ListTemplatesV2(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) + + app := &EvaluatorHandlerImpl{ + evaluatorTemplateService: mockTemplateService, + } + + tests := []struct { + name string + req *evaluatorservice.ListTemplatesV2Request + mockSetup func() + wantResp *evaluatorservice.ListTemplatesV2Response + wantErr bool + wantErrCode int32 + }{ + { + name: "success - normal request", + req: &evaluatorservice.ListTemplatesV2Request{ + PageSize: gptr.Of(int32(20)), + PageNumber: gptr.Of(int32(1)), + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + ListEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.ListEvaluatorTemplateResponse{ + Templates: []*entity.EvaluatorTemplate{ + { + ID: 1, + Name: "template1", + Description: "test template 1", + }, + { + ID: 2, + Name: "template2", + Description: "test template 2", + }, + }, + TotalCount: 2, + }, nil) + }, + wantResp: &evaluatorservice.ListTemplatesV2Response{ + Total: gptr.Of(int64(2)), + }, + wantErr: false, + }, + { + name: "success - with pagination", + req: &evaluatorservice.ListTemplatesV2Request{ + PageSize: gptr.Of(int32(10)), + PageNumber: gptr.Of(int32(2)), + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + ListEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.ListEvaluatorTemplateResponse{ + Templates: []*entity.EvaluatorTemplate{}, + TotalCount: 25, + }, nil) + }, + wantResp: &evaluatorservice.ListTemplatesV2Response{ + Total: gptr.Of(int64(25)), + }, + wantErr: false, + }, + { + name: "success - with filter option", + req: &evaluatorservice.ListTemplatesV2Request{ + PageSize: gptr.Of(int32(20)), + PageNumber: gptr.Of(int32(1)), + FilterOption: &evaluatordto.EvaluatorFilterOption{}, + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + ListEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.ListEvaluatorTemplateResponse{ + Templates: []*entity.EvaluatorTemplate{}, + TotalCount: 0, + }, nil) + }, + wantResp: &evaluatorservice.ListTemplatesV2Response{ + Total: gptr.Of(int64(0)), + }, + wantErr: false, + }, + { + name: "error - service failure", + req: &evaluatorservice.ListTemplatesV2Request{ + PageSize: gptr.Of(int32(20)), + PageNumber: gptr.Of(int32(1)), + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + ListEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) + }, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonInternalErrorCode, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + + resp, err := app.ListTemplatesV2(context.Background(), tt.req) + + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, tt.wantErrCode, statusErr.Code()) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + if tt.wantResp.Total != nil { + assert.Equal(t, *tt.wantResp.Total, *resp.Total) + } + } + }) + } +} + +// TestEvaluatorHandlerImpl_GetTemplateV2 测试 GetTemplateV2 方法 +func TestEvaluatorHandlerImpl_GetTemplateV2(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) + + app := &EvaluatorHandlerImpl{ + evaluatorTemplateService: mockTemplateService, + } + + templateID := int64(123) + template := &entity.EvaluatorTemplate{ + ID: templateID, + Name: "test template", + Description: "test description", + } + + tests := []struct { + name string + req *evaluatorservice.GetTemplateV2Request + mockSetup func() + wantResp *evaluatorservice.GetTemplateV2Response + wantErr bool + wantErrCode int32 + }{ + { + name: "success - normal request", + req: &evaluatorservice.GetTemplateV2Request{ + EvaluatorTemplateID: gptr.Of(templateID), + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + GetEvaluatorTemplate(gomock.Any(), &entity.GetEvaluatorTemplateRequest{ + ID: templateID, + IncludeDeleted: false, + }). + Return(&entity.GetEvaluatorTemplateResponse{ + Template: template, + }, nil) + }, + wantResp: &evaluatorservice.GetTemplateV2Response{ + EvaluatorTemplate: evaluator.ConvertEvaluatorTemplateDO2DTO(template), + }, + wantErr: false, + }, + { + name: "custom code", + req: &evaluatorservice.GetTemplateV2Request{ + CustomCode: gptr.Of(true), + }, + mockSetup: func() { + customTemplates := map[string]map[string]*evaluatordto.EvaluatorContent{ + "custom": { + "Python": { + CodeEvaluator: &evaluatordto.CodeEvaluator{ + CodeContent: gptr.Of("print(1)"), + }, + }, + }, + } + mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockConfiger.EXPECT().GetCustomCodeEvaluatorTemplateConf(gomock.Any()).Return(customTemplates) + app.configer = mockConfiger + }, + wantResp: &evaluatorservice.GetTemplateV2Response{ + EvaluatorTemplate: &evaluatordto.EvaluatorTemplate{ + EvaluatorType: evaluatordto.EvaluatorTypePtr(evaluatordto.EvaluatorType_Code), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + CodeEvaluator: &evaluatordto.CodeEvaluator{ + Lang2CodeContent: map[string]string{"Python": "print(1)"}, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "success - template not found", + req: &evaluatorservice.GetTemplateV2Request{ + EvaluatorTemplateID: gptr.Of(templateID), + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + GetEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.GetEvaluatorTemplateResponse{ + Template: nil, + }, nil) + }, + wantResp: &evaluatorservice.GetTemplateV2Response{}, + wantErr: false, + }, + { + name: "error - service failure", + req: &evaluatorservice.GetTemplateV2Request{ + EvaluatorTemplateID: gptr.Of(templateID), + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + GetEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) + }, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonInternalErrorCode, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + + resp, err := app.GetTemplateV2(context.Background(), tt.req) + + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, tt.wantErrCode, statusErr.Code()) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + if tt.wantResp.EvaluatorTemplate != nil { + assert.Equal(t, tt.wantResp.GetEvaluatorTemplate().GetID(), resp.GetEvaluatorTemplate().GetID()) + } + } + }) + } +} + +func TestEvaluatorHandlerImpl_CreateEvaluator_CustomRPC(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Setup mocks + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + auditClient: mockAuditClient, + metrics: mockMetrics, + configer: mockConfiger, + } + + ctx := context.Background() + workspaceID := int64(123456) + + tests := []struct { + name string + evaluatorType evaluatordto.EvaluatorType + allowedSpaceIDs []string + checkWritable bool + checkError error + wantErr bool + wantErrCode int32 + }{ + { + name: "成功 - CustomRPC类型且空间有权限", + evaluatorType: evaluatordto.EvaluatorType_CustomRPC, + allowedSpaceIDs: []string{"123456", "789012"}, + checkWritable: true, + checkError: nil, + wantErr: false, + }, + { + name: "失败 - CustomRPC类型但空间无权限", + evaluatorType: evaluatordto.EvaluatorType_CustomRPC, + allowedSpaceIDs: []string{"789012", "345678"}, + checkWritable: false, + checkError: nil, + wantErr: true, + wantErrCode: errno.CommonInvalidParamCode, + }, + { + name: "失败 - CustomRPC类型但配置检查失败", + evaluatorType: evaluatordto.EvaluatorType_CustomRPC, + allowedSpaceIDs: []string{"123456"}, + checkWritable: false, + checkError: errors.New("配置检查失败"), + wantErr: true, + wantErrCode: 0, + }, + { + name: "成功 - 非CustomRPC类型无需额外权限校验", + evaluatorType: evaluatordto.EvaluatorType_Prompt, + allowedSpaceIDs: []string{}, + checkWritable: false, + checkError: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var request *evaluatorservice.CreateEvaluatorRequest + if tt.evaluatorType == evaluatordto.EvaluatorType_CustomRPC { + request = &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("测试CustomRPC评估器"), + Description: gptr.Of("测试描述"), + EvaluatorType: gptr.Of(tt.evaluatorType), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + Description: gptr.Of("版本描述"), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + CustomRPCEvaluator: &evaluatordto.CustomRPCEvaluator{ + ServiceName: gptr.Of("test.psm.service"), + AccessProtocol: evaluatordto.EvaluatorAccessProtocolRPC, + }, + }, + }, + }, + } + } else { + request = &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("测试CustomRPC评估器"), + Description: gptr.Of("测试描述"), + EvaluatorType: gptr.Of(tt.evaluatorType), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + Description: gptr.Of("版本描述"), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{ + PromptTemplateKey: gptr.Of("test_template"), + }, + }, + }, + }, + } + } + + // Mock 基础权限校验 + mockAuth.EXPECT(). + Authorization(gomock.Any(), &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(workspaceID, 10), + SpaceID: workspaceID, + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("createLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }). + Return(nil). + Times(1) + + // Mock 机审 + mockAuditClient.EXPECT(). + Audit(gomock.Any(), gomock.Any()). + Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil). + Times(1) + + // 如果是CustomRPC类型,需要Mock额外的权限校验 + if tt.evaluatorType == evaluatordto.EvaluatorType_CustomRPC { + mockConfiger.EXPECT(). + GetBuiltinEvaluatorSpaceConf(gomock.Any()). + Return(tt.allowedSpaceIDs). + Times(1) + + mockConfiger.EXPECT(). + CheckCustomRPCEvaluatorWritable(gomock.Any(), strconv.FormatInt(workspaceID, 10), tt.allowedSpaceIDs). + Return(tt.checkWritable, tt.checkError). + Times(1) + } - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) - mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) + // Mock 创建评估器 + if !tt.wantErr { + mockMetrics.EXPECT(). + EmitCreate(workspaceID, nil). + Times(1) - app := &EvaluatorHandlerImpl{ - auth: mockAuth, - benefitService: mockBenefitService, - evaluatorService: mockEvaluatorService, - fileProvider: mockFileProvider, + mockEvaluatorService.EXPECT(). + CreateEvaluator(gomock.Any(), gomock.Any(), gomock.Any()). + Return(int64(12345), nil). + Times(1) } - tt.mockSetup(mockAuth, mockBenefitService, mockEvaluatorService, mockFileProvider) - - resp, err := app.BatchDebugEvaluator(context.Background(), tt.req) + resp, err := app.CreateEvaluator(ctx, request) if tt.wantErr { assert.Error(t, err) if tt.wantErrCode != 0 { - statusErr, ok := errorx.FromStatusError(err) - assert.True(t, ok) - assert.Equal(t, tt.wantErrCode, statusErr.Code()) + if statusErr, ok := errorx.FromStatusError(err); ok { + assert.Equal(t, tt.wantErrCode, statusErr.Code()) + } } } else { assert.NoError(t, err) assert.NotNil(t, resp) - assert.Equal(t, len(tt.wantResp.EvaluatorOutputData), len(resp.EvaluatorOutputData)) - - // 验证结果数量 - assert.Equal(t, len(tt.wantResp.EvaluatorOutputData), len(resp.EvaluatorOutputData)) - - // 对于特定测试用例,验证错误处理逻辑 - if tt.name == "edge case - evaluator service returns nil output with error" { - assert.NotNil(t, resp.EvaluatorOutputData[0].EvaluatorRunError) - assert.Equal(t, int32(500), *resp.EvaluatorOutputData[0].EvaluatorRunError.Code) - assert.Equal(t, "code execution failed", *resp.EvaluatorOutputData[0].EvaluatorRunError.Message) - } + assert.Equal(t, int64(12345), gptr.Indirect(resp.EvaluatorID)) } }) } } -// TestEvaluatorHandlerImpl_ListTemplatesV2 测试 ListTemplatesV2 方法 -func TestEvaluatorHandlerImpl_ListTemplatesV2(t *testing.T) { - t.Parallel() - +func TestEvaluatorHandlerImpl_UpdateEvaluatorDraft_CustomRPC(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) + // Setup mocks + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) app := &EvaluatorHandlerImpl{ - evaluatorTemplateService: mockTemplateService, + auth: mockAuth, + evaluatorService: mockEvaluatorService, + configer: mockConfiger, + userInfoService: mockUserInfoService, } + ctx := context.Background() + workspaceID := int64(123456) + evaluatorID := int64(789) + tests := []struct { - name string - req *evaluatorservice.ListTemplatesV2Request - mockSetup func() - wantResp *evaluatorservice.ListTemplatesV2Response - wantErr bool - wantErrCode int32 + name string + spaceID int64 + allowedSpaceIDs []string + checkWritable bool + checkError error + wantErr bool + wantErrCode int32 }{ { - name: "success - normal request", - req: &evaluatorservice.ListTemplatesV2Request{ - PageSize: gptr.Of(int32(20)), - PageNumber: gptr.Of(int32(1)), - }, - mockSetup: func() { - mockTemplateService.EXPECT(). - ListEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.ListEvaluatorTemplateResponse{ - Templates: []*entity.EvaluatorTemplate{ - { - ID: 1, - Name: "template1", - Description: "test template 1", - }, - { - ID: 2, - Name: "template2", - Description: "test template 2", - }, - }, - TotalCount: 2, - }, nil) - }, - wantResp: &evaluatorservice.ListTemplatesV2Response{ - Total: gptr.Of(int64(2)), - }, - wantErr: false, - }, - { - name: "success - with pagination", - req: &evaluatorservice.ListTemplatesV2Request{ - PageSize: gptr.Of(int32(10)), - PageNumber: gptr.Of(int32(2)), - }, - mockSetup: func() { - mockTemplateService.EXPECT(). - ListEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.ListEvaluatorTemplateResponse{ - Templates: []*entity.EvaluatorTemplate{}, - TotalCount: 25, - }, nil) - }, - wantResp: &evaluatorservice.ListTemplatesV2Response{ - Total: gptr.Of(int64(25)), - }, - wantErr: false, - }, - { - name: "success - with filter option", - req: &evaluatorservice.ListTemplatesV2Request{ - PageSize: gptr.Of(int32(20)), - PageNumber: gptr.Of(int32(1)), - FilterOption: &evaluatordto.EvaluatorFilterOption{}, - }, - mockSetup: func() { - mockTemplateService.EXPECT(). - ListEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.ListEvaluatorTemplateResponse{ - Templates: []*entity.EvaluatorTemplate{}, - TotalCount: 0, - }, nil) - }, - wantResp: &evaluatorservice.ListTemplatesV2Response{ - Total: gptr.Of(int64(0)), - }, - wantErr: false, + name: "失败 - CustomRPC类型但空间无权限", + spaceID: workspaceID, + allowedSpaceIDs: []string{"789012", "345678"}, + checkWritable: false, + checkError: nil, + wantErr: true, + wantErrCode: errno.CommonInvalidParamCode, }, { - name: "error - service failure", - req: &evaluatorservice.ListTemplatesV2Request{ - PageSize: gptr.Of(int32(20)), - PageNumber: gptr.Of(int32(1)), - }, - mockSetup: func() { - mockTemplateService.EXPECT(). - ListEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) - }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonInternalErrorCode, + name: "失败 - CustomRPC类型但配置检查失败", + spaceID: workspaceID, + allowedSpaceIDs: []string{"123456"}, + checkWritable: false, + checkError: errors.New("配置检查失败"), + wantErr: true, + wantErrCode: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.mockSetup() + request := &evaluatorservice.UpdateEvaluatorDraftRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + EvaluatorType: evaluatordto.EvaluatorType_CustomRPC, + EvaluatorContent: &evaluatordto.EvaluatorContent{ + CustomRPCEvaluator: &evaluatordto.CustomRPCEvaluator{ + ServiceName: gptr.Of("test.psm.service"), + AccessProtocol: evaluatordto.EvaluatorAccessProtocolRPC, + }, + }, + } - resp, err := app.ListTemplatesV2(context.Background(), tt.req) + // Mock 获取评估器信息 + evaluatorDO := &entity.Evaluator{ + ID: evaluatorID, + SpaceID: tt.spaceID, + Name: "测试评估器", + } - if tt.wantErr { - assert.Error(t, err) - if tt.wantErrCode != 0 { - statusErr, ok := errorx.FromStatusError(err) - assert.True(t, ok) + mockEvaluatorService.EXPECT(). + GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false). + Return(evaluatorDO, nil). + Times(1) + + // Mock 基础权限校验 + mockAuth.EXPECT(). + Authorization(gomock.Any(), &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(evaluatorID, 10), + SpaceID: tt.spaceID, + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Edit), EntityType: gptr.Of(rpc.AuthEntityType_Evaluator)}}, + }). + Return(nil). + Times(1) + + // Mock 额外的权限校验 + mockConfiger.EXPECT(). + GetBuiltinEvaluatorSpaceConf(gomock.Any()). + Return(tt.allowedSpaceIDs). + Times(1) + + mockConfiger.EXPECT(). + CheckCustomRPCEvaluatorWritable(gomock.Any(), strconv.FormatInt(tt.spaceID, 10), tt.allowedSpaceIDs). + Return(tt.checkWritable, tt.checkError). + Times(1) + + resp, err := app.UpdateEvaluatorDraft(ctx, request) + + assert.Error(t, err) + if tt.wantErrCode != 0 { + if statusErr, ok := errorx.FromStatusError(err); ok { assert.Equal(t, tt.wantErrCode, statusErr.Code()) } - } else { - assert.NoError(t, err) - assert.NotNil(t, resp) - if tt.wantResp.Total != nil { - assert.Equal(t, *tt.wantResp.Total, *resp.Total) - } } + assert.Nil(t, resp) }) } } -// TestEvaluatorHandlerImpl_GetTemplateV2 测试 GetTemplateV2 方法 -func TestEvaluatorHandlerImpl_GetTemplateV2(t *testing.T) { +// TestEvaluatorHandlerImpl_CreateEvaluatorTemplate 测试 CreateEvaluatorTemplate 方法 +func TestEvaluatorHandlerImpl_CreateEvaluatorTemplate(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) app := &EvaluatorHandlerImpl{ evaluatorTemplateService: mockTemplateService, + configer: mockConfiger, + auth: mockAuth, } - templateID := int64(123) - template := &entity.EvaluatorTemplate{ - ID: templateID, - Name: "test template", - Description: "test description", + workspaceID := int64(123) + templateDTO := &evaluatordto.EvaluatorTemplate{ + ID: gptr.Of(int64(1)), + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test template"), + Description: gptr.Of("test description"), } tests := []struct { name string - req *evaluatorservice.GetTemplateV2Request + req *evaluatorservice.CreateEvaluatorTemplateRequest mockSetup func() - wantResp *evaluatorservice.GetTemplateV2Response + wantResp *evaluatorservice.CreateEvaluatorTemplateResponse wantErr bool wantErrCode int32 }{ { name: "success - normal request", - req: &evaluatorservice.GetTemplateV2Request{ - EvaluatorTemplateID: gptr.Of(templateID), + req: &evaluatorservice.CreateEvaluatorTemplateRequest{ + EvaluatorTemplate: templateDTO, }, mockSetup: func() { + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) + + mockConfiger.EXPECT(). + GetEvaluatorTemplateSpaceConf(gomock.Any()). + Return([]string{"123"}) + mockTemplateService.EXPECT(). - GetEvaluatorTemplate(gomock.Any(), &entity.GetEvaluatorTemplateRequest{ - ID: templateID, - IncludeDeleted: false, - }). - Return(&entity.GetEvaluatorTemplateResponse{ - Template: template, + CreateEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.CreateEvaluatorTemplateResponse{ + Template: evaluator.ConvertEvaluatorTemplateDTO2DO(templateDTO), }, nil) }, - wantResp: &evaluatorservice.GetTemplateV2Response{ - EvaluatorTemplate: evaluator.ConvertEvaluatorTemplateDO2DTO(template), + wantResp: &evaluatorservice.CreateEvaluatorTemplateResponse{ + EvaluatorTemplate: templateDTO, }, wantErr: false, }, { - name: "success - template not found", - req: &evaluatorservice.GetTemplateV2Request{ - EvaluatorTemplateID: gptr.Of(templateID), + name: "error - nil template", + req: &evaluatorservice.CreateEvaluatorTemplateRequest{ + EvaluatorTemplate: nil, + }, + mockSetup: func() {}, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonInvalidParamCode, + }, + { + name: "error - auth failed", + req: &evaluatorservice.CreateEvaluatorTemplateRequest{ + EvaluatorTemplate: &evaluatordto.EvaluatorTemplate{ + ID: gptr.Of(int64(1)), + WorkspaceID: gptr.Of(int64(789)), // 不在允许列表中 + Name: gptr.Of("test template"), + Description: gptr.Of("test description"), + }, }, mockSetup: func() { - mockTemplateService.EXPECT(). - GetEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.GetEvaluatorTemplateResponse{ - Template: nil, - }, nil) + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(errorx.NewByCode(errno.CommonNoPermissionCode)) }, - wantResp: &evaluatorservice.GetTemplateV2Response{}, - wantErr: false, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonNoPermissionCode, }, { name: "error - service failure", - req: &evaluatorservice.GetTemplateV2Request{ - EvaluatorTemplateID: gptr.Of(templateID), + req: &evaluatorservice.CreateEvaluatorTemplateRequest{ + EvaluatorTemplate: templateDTO, }, mockSetup: func() { + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) + + mockConfiger.EXPECT(). + GetEvaluatorTemplateSpaceConf(gomock.Any()). + Return([]string{"123"}) + mockTemplateService.EXPECT(). - GetEvaluatorTemplate(gomock.Any(), gomock.Any()). + CreateEvaluatorTemplate(gomock.Any(), gomock.Any()). Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) }, wantResp: nil, @@ -3494,7 +4815,7 @@ func TestEvaluatorHandlerImpl_GetTemplateV2(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tt.mockSetup() - resp, err := app.GetTemplateV2(context.Background(), tt.req) + resp, err := app.CreateEvaluatorTemplate(context.Background(), tt.req) if tt.wantErr { assert.Error(t, err) @@ -3506,395 +4827,421 @@ func TestEvaluatorHandlerImpl_GetTemplateV2(t *testing.T) { } else { assert.NoError(t, err) assert.NotNil(t, resp) - if tt.wantResp.EvaluatorTemplate != nil { - assert.Equal(t, templateID, resp.GetEvaluatorTemplate().GetID()) - } + assert.NotNil(t, resp.EvaluatorTemplate) } }) } } -func TestEvaluatorHandlerImpl_CreateEvaluator_CustomRPC(t *testing.T) { +// TestEvaluatorHandlerImpl_UpdateEvaluatorTemplate 测试 UpdateEvaluatorTemplate 方法 +func TestEvaluatorHandlerImpl_UpdateEvaluatorTemplate(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) defer ctrl.Finish() - // Setup mocks - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockAuditClient := auditmocks.NewMockIAuditService(ctrl) - mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) app := &EvaluatorHandlerImpl{ - auth: mockAuth, - evaluatorService: mockEvaluatorService, - auditClient: mockAuditClient, - metrics: mockMetrics, - configer: mockConfiger, + evaluatorTemplateService: mockTemplateService, + configer: mockConfiger, + auth: mockAuth, } - ctx := context.Background() - workspaceID := int64(123456) + templateID := int64(123) + workspaceID := int64(456) + templateDTO := &evaluatordto.EvaluatorTemplate{ + ID: gptr.Of(templateID), + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("updated template"), + Description: gptr.Of("updated description"), + } tests := []struct { - name string - evaluatorType evaluatordto.EvaluatorType - allowedSpaceIDs []string - checkWritable bool - checkError error - wantErr bool - wantErrCode int32 + name string + req *evaluatorservice.UpdateEvaluatorTemplateRequest + mockSetup func() + wantResp *evaluatorservice.UpdateEvaluatorTemplateResponse + wantErr bool + wantErrCode int32 }{ { - name: "成功 - CustomRPC类型且空间有权限", - evaluatorType: evaluatordto.EvaluatorType_CustomRPC, - allowedSpaceIDs: []string{"123456", "789012"}, - checkWritable: true, - checkError: nil, - wantErr: false, - }, - { - name: "失败 - CustomRPC类型但空间无权限", - evaluatorType: evaluatordto.EvaluatorType_CustomRPC, - allowedSpaceIDs: []string{"789012", "345678"}, - checkWritable: false, - checkError: nil, - wantErr: true, - wantErrCode: errno.CommonInvalidParamCode, - }, - { - name: "失败 - CustomRPC类型但配置检查失败", - evaluatorType: evaluatordto.EvaluatorType_CustomRPC, - allowedSpaceIDs: []string{"123456"}, - checkWritable: false, - checkError: errors.New("配置检查失败"), - wantErr: true, - wantErrCode: 0, - }, - { - name: "成功 - 非CustomRPC类型无需额外权限校验", - evaluatorType: evaluatordto.EvaluatorType_Prompt, - allowedSpaceIDs: []string{}, - checkWritable: false, - checkError: nil, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var request *evaluatorservice.CreateEvaluatorRequest - if tt.evaluatorType == evaluatordto.EvaluatorType_CustomRPC { - request = &evaluatorservice.CreateEvaluatorRequest{ - Evaluator: &evaluatordto.Evaluator{ - WorkspaceID: gptr.Of(workspaceID), - Name: gptr.Of("测试CustomRPC评估器"), - Description: gptr.Of("测试描述"), - EvaluatorType: gptr.Of(tt.evaluatorType), - CurrentVersion: &evaluatordto.EvaluatorVersion{ - Version: gptr.Of("1.0.0"), - Description: gptr.Of("版本描述"), - EvaluatorContent: &evaluatordto.EvaluatorContent{ - CustomRPCEvaluator: &evaluatordto.CustomRPCEvaluator{ - ServiceName: gptr.Of("test.psm.service"), - AccessProtocol: evaluatordto.EvaluatorAccessProtocolRPC, - }, - }, - }, - }, - } - } else { - request = &evaluatorservice.CreateEvaluatorRequest{ - Evaluator: &evaluatordto.Evaluator{ - WorkspaceID: gptr.Of(workspaceID), - Name: gptr.Of("测试CustomRPC评估器"), - Description: gptr.Of("测试描述"), - EvaluatorType: gptr.Of(tt.evaluatorType), - CurrentVersion: &evaluatordto.EvaluatorVersion{ - Version: gptr.Of("1.0.0"), - Description: gptr.Of("版本描述"), - EvaluatorContent: &evaluatordto.EvaluatorContent{ - PromptEvaluator: &evaluatordto.PromptEvaluator{ - PromptTemplateKey: gptr.Of("test_template"), - }, - }, - }, - }, - } - } - - // Mock 基础权限校验 - mockAuth.EXPECT(). - Authorization(gomock.Any(), &rpc.AuthorizationParam{ - ObjectID: strconv.FormatInt(workspaceID, 10), - SpaceID: workspaceID, - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("createLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, - }). - Return(nil). - Times(1) - - // Mock 机审 - mockAuditClient.EXPECT(). - Audit(gomock.Any(), gomock.Any()). - Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil). - Times(1) + name: "success - normal request", + req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + EvaluatorTemplate: templateDTO, + }, + mockSetup: func() { + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) - // 如果是CustomRPC类型,需要Mock额外的权限校验 - if tt.evaluatorType == evaluatordto.EvaluatorType_CustomRPC { mockConfiger.EXPECT(). - GetBuiltinEvaluatorSpaceConf(gomock.Any()). - Return(tt.allowedSpaceIDs). - Times(1) + GetEvaluatorTemplateSpaceConf(gomock.Any()). + Return([]string{"456"}) + + mockTemplateService.EXPECT(). + UpdateEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.UpdateEvaluatorTemplateResponse{ + Template: evaluator.ConvertEvaluatorTemplateDTO2DO(templateDTO), + }, nil) + }, + wantResp: &evaluatorservice.UpdateEvaluatorTemplateResponse{ + EvaluatorTemplate: templateDTO, + }, + wantErr: false, + }, + { + name: "error - nil template", + req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + EvaluatorTemplate: nil, + }, + mockSetup: func() {}, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonInvalidParamCode, + }, + { + name: "error - auth failed", + req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + EvaluatorTemplate: &evaluatordto.EvaluatorTemplate{ + ID: gptr.Of(templateID), + WorkspaceID: gptr.Of(int64(789)), // 不在允许列表中 + Name: gptr.Of("updated template"), + Description: gptr.Of("updated description"), + }, + }, + mockSetup: func() { + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + }, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonNoPermissionCode, + }, + { + name: "error - service failure", + req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + EvaluatorTemplate: templateDTO, + }, + mockSetup: func() { + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) mockConfiger.EXPECT(). - CheckCustomRPCEvaluatorWritable(gomock.Any(), strconv.FormatInt(workspaceID, 10), tt.allowedSpaceIDs). - Return(tt.checkWritable, tt.checkError). - Times(1) - } + GetEvaluatorTemplateSpaceConf(gomock.Any()). + Return([]string{"456"}) - // Mock 创建评估器 - if !tt.wantErr { - mockMetrics.EXPECT(). - EmitCreate(workspaceID, nil). - Times(1) + mockTemplateService.EXPECT(). + UpdateEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) + }, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonInternalErrorCode, + }, + } - mockEvaluatorService.EXPECT(). - CreateEvaluator(gomock.Any(), gomock.Any(), gomock.Any()). - Return(int64(12345), nil). - Times(1) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() - resp, err := app.CreateEvaluator(ctx, request) + resp, err := app.UpdateEvaluatorTemplate(context.Background(), tt.req) if tt.wantErr { assert.Error(t, err) if tt.wantErrCode != 0 { - if statusErr, ok := errorx.FromStatusError(err); ok { - assert.Equal(t, tt.wantErrCode, statusErr.Code()) - } + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) + assert.Equal(t, tt.wantErrCode, statusErr.Code()) } } else { assert.NoError(t, err) assert.NotNil(t, resp) - assert.Equal(t, int64(12345), gptr.Indirect(resp.EvaluatorID)) + assert.NotNil(t, resp.EvaluatorTemplate) } }) } } -func TestEvaluatorHandlerImpl_UpdateEvaluatorDraft_CustomRPC(t *testing.T) { +// TestEvaluatorHandlerImpl_DeleteEvaluatorTemplate 测试 DeleteEvaluatorTemplate 方法 +func TestEvaluatorHandlerImpl_DeleteEvaluatorTemplate(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) defer ctrl.Finish() - // Setup mocks - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) mockConfiger := confmocks.NewMockIConfiger(ctrl) - mockUserInfoService := userinfomocks.NewMockUserInfoService(ctrl) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) app := &EvaluatorHandlerImpl{ - auth: mockAuth, - evaluatorService: mockEvaluatorService, - configer: mockConfiger, - userInfoService: mockUserInfoService, + evaluatorTemplateService: mockTemplateService, + configer: mockConfiger, + auth: mockAuth, } - ctx := context.Background() - workspaceID := int64(123456) - evaluatorID := int64(789) + templateID := int64(123) + workspaceID := int64(456) + template := &entity.EvaluatorTemplate{ + ID: templateID, + SpaceID: workspaceID, + Name: "test template", + } tests := []struct { - name string - spaceID int64 - allowedSpaceIDs []string - checkWritable bool - checkError error - wantErr bool - wantErrCode int32 + name string + req *evaluatorservice.DeleteEvaluatorTemplateRequest + mockSetup func() + wantResp *evaluatorservice.DeleteEvaluatorTemplateResponse + wantErr bool + wantErrCode int32 }{ { - name: "失败 - CustomRPC类型但空间无权限", - spaceID: workspaceID, - allowedSpaceIDs: []string{"789012", "345678"}, - checkWritable: false, - checkError: nil, - wantErr: true, - wantErrCode: errno.CommonInvalidParamCode, + name: "success - normal request", + req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + GetEvaluatorTemplate(gomock.Any(), &entity.GetEvaluatorTemplateRequest{ + ID: templateID, + IncludeDeleted: false, + }). + Return(&entity.GetEvaluatorTemplateResponse{ + Template: template, + }, nil) + + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) + + mockConfiger.EXPECT(). + GetEvaluatorTemplateSpaceConf(gomock.Any()). + Return([]string{"456"}) + + mockTemplateService.EXPECT(). + DeleteEvaluatorTemplate(gomock.Any(), &entity.DeleteEvaluatorTemplateRequest{ + ID: templateID, + }). + Return(&entity.DeleteEvaluatorTemplateResponse{}, nil) + }, + wantResp: &evaluatorservice.DeleteEvaluatorTemplateResponse{}, + wantErr: false, }, { - name: "失败 - CustomRPC类型但配置检查失败", - spaceID: workspaceID, - allowedSpaceIDs: []string{"123456"}, - checkWritable: false, - checkError: errors.New("配置检查失败"), - wantErr: true, - wantErrCode: 0, + name: "error - template id is 0", + req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ + EvaluatorTemplateID: 0, + }, + mockSetup: func() {}, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonInvalidParamCode, + }, + { + name: "error - template not found", + req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + }, + mockSetup: func() { + mockTemplateService.EXPECT(). + GetEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.GetEvaluatorTemplateResponse{ + Template: nil, + }, nil) + }, + wantResp: nil, + wantErr: true, + wantErrCode: errno.ResourceNotFoundCode, + }, + { + name: "error - auth failed", + req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ + EvaluatorTemplateID: templateID, + }, + mockSetup: func() { + // 使用不在允许列表中的workspaceID的template + testTemplate := &entity.EvaluatorTemplate{ + ID: templateID, + SpaceID: 789, // 不在允许列表中 + Name: "test template", + } + mockTemplateService.EXPECT(). + GetEvaluatorTemplate(gomock.Any(), gomock.Any()). + Return(&entity.GetEvaluatorTemplateResponse{ + Template: testTemplate, + }, nil) + + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + }, + wantResp: nil, + wantErr: true, + wantErrCode: errno.CommonNoPermissionCode, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := &evaluatorservice.UpdateEvaluatorDraftRequest{ - WorkspaceID: workspaceID, - EvaluatorID: evaluatorID, - EvaluatorType: evaluatordto.EvaluatorType_CustomRPC, - EvaluatorContent: &evaluatordto.EvaluatorContent{ - CustomRPCEvaluator: &evaluatordto.CustomRPCEvaluator{ - ServiceName: gptr.Of("test.psm.service"), - AccessProtocol: evaluatordto.EvaluatorAccessProtocolRPC, - }, - }, - } - - // Mock 获取评估器信息 - evaluatorDO := &entity.Evaluator{ - ID: evaluatorID, - SpaceID: tt.spaceID, - Name: "测试评估器", - } - - mockEvaluatorService.EXPECT(). - GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false). - Return(evaluatorDO, nil). - Times(1) - - // Mock 基础权限校验 - mockAuth.EXPECT(). - Authorization(gomock.Any(), &rpc.AuthorizationParam{ - ObjectID: strconv.FormatInt(evaluatorID, 10), - SpaceID: tt.spaceID, - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Edit), EntityType: gptr.Of(rpc.AuthEntityType_Evaluator)}}, - }). - Return(nil). - Times(1) - - // Mock 额外的权限校验 - mockConfiger.EXPECT(). - GetBuiltinEvaluatorSpaceConf(gomock.Any()). - Return(tt.allowedSpaceIDs). - Times(1) - - mockConfiger.EXPECT(). - CheckCustomRPCEvaluatorWritable(gomock.Any(), strconv.FormatInt(tt.spaceID, 10), tt.allowedSpaceIDs). - Return(tt.checkWritable, tt.checkError). - Times(1) + tt.mockSetup() - resp, err := app.UpdateEvaluatorDraft(ctx, request) + resp, err := app.DeleteEvaluatorTemplate(context.Background(), tt.req) - assert.Error(t, err) - if tt.wantErrCode != 0 { - if statusErr, ok := errorx.FromStatusError(err); ok { + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrCode != 0 { + statusErr, ok := errorx.FromStatusError(err) + assert.True(t, ok) assert.Equal(t, tt.wantErrCode, statusErr.Code()) } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) } - assert.Nil(t, resp) }) } } -// TestEvaluatorHandlerImpl_CreateEvaluatorTemplate 测试 CreateEvaluatorTemplate 方法 -func TestEvaluatorHandlerImpl_CreateEvaluatorTemplate(t *testing.T) { +// TestEvaluatorHandlerImpl_DebugBuiltinEvaluator 测试 DebugBuiltinEvaluator 方法 +func TestEvaluatorHandlerImpl_DebugBuiltinEvaluator(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() - mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) - mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) app := &EvaluatorHandlerImpl{ - evaluatorTemplateService: mockTemplateService, - configer: mockConfiger, - auth: mockAuth, + evaluatorService: mockEvaluatorService, + auth: mockAuth, } - workspaceID := int64(123) - templateDTO := &evaluatordto.EvaluatorTemplate{ - ID: gptr.Of(int64(1)), - WorkspaceID: gptr.Of(workspaceID), - Name: gptr.Of("test template"), - Description: gptr.Of("test description"), + evaluatorID := int64(123) + workspaceID := int64(456) + builtinEvaluator := &entity.Evaluator{ + ID: evaluatorID, + SpaceID: workspaceID, + Name: "builtin evaluator", + Builtin: true, + } + + inputData := &evaluatordto.EvaluatorInputData{ + InputFields: map[string]*common.Content{ + "input": { + ContentType: gptr.Of(common.ContentTypeText), + Text: gptr.Of("test input"), + }, + }, + } + + outputData := &entity.EvaluatorOutputData{ + EvaluatorResult: &entity.EvaluatorResult{ + Score: gptr.Of(0.85), + Reasoning: "test result", + }, } tests := []struct { name string - req *evaluatorservice.CreateEvaluatorTemplateRequest + req *evaluatorservice.DebugBuiltinEvaluatorRequest mockSetup func() - wantResp *evaluatorservice.CreateEvaluatorTemplateResponse + wantResp *evaluatorservice.DebugBuiltinEvaluatorResponse wantErr bool wantErrCode int32 }{ { name: "success - normal request", - req: &evaluatorservice.CreateEvaluatorTemplateRequest{ - EvaluatorTemplate: templateDTO, + req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: workspaceID, + InputData: inputData, }, mockSetup: func() { mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). + Authorization(gomock.Any(), &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(workspaceID, 10), + SpaceID: workspaceID, + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("listLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }). Return(nil) - mockConfiger.EXPECT(). - GetEvaluatorTemplateSpaceConf(gomock.Any()). - Return([]string{"123"}) + mockEvaluatorService.EXPECT(). + GetBuiltinEvaluator(gomock.Any(), evaluatorID). + Return(builtinEvaluator, nil) - mockTemplateService.EXPECT(). - CreateEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.CreateEvaluatorTemplateResponse{ - Template: evaluator.ConvertEvaluatorTemplateDTO2DO(templateDTO), - }, nil) + mockEvaluatorService.EXPECT(). + DebugEvaluator(gomock.Any(), builtinEvaluator, gomock.Any(), gomock.Any(), gomock.Any()). + Return(outputData, nil) }, - wantResp: &evaluatorservice.CreateEvaluatorTemplateResponse{ - EvaluatorTemplate: templateDTO, + wantResp: &evaluatorservice.DebugBuiltinEvaluatorResponse{ + OutputData: evaluator.ConvertEvaluatorOutputDataDO2DTO(outputData), }, wantErr: false, }, { - name: "error - nil template", - req: &evaluatorservice.CreateEvaluatorTemplateRequest{ - EvaluatorTemplate: nil, + name: "error - auth failed", + req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: workspaceID, + InputData: inputData, + }, + mockSetup: func() { + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(errorx.NewByCode(errno.CommonNoPermissionCode)) }, - mockSetup: func() {}, wantResp: nil, wantErr: true, - wantErrCode: errno.CommonInvalidParamCode, + wantErrCode: errno.CommonNoPermissionCode, }, { - name: "error - auth failed", - req: &evaluatorservice.CreateEvaluatorTemplateRequest{ - EvaluatorTemplate: &evaluatordto.EvaluatorTemplate{ - ID: gptr.Of(int64(1)), - WorkspaceID: gptr.Of(int64(789)), // 不在允许列表中 - Name: gptr.Of("test template"), - Description: gptr.Of("test description"), - }, + name: "error - evaluator not found", + req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: workspaceID, + InputData: inputData, }, mockSetup: func() { mockAuth.EXPECT(). Authorization(gomock.Any(), gomock.Any()). - Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + Return(nil) + + mockEvaluatorService.EXPECT(). + GetBuiltinEvaluator(gomock.Any(), evaluatorID). + Return(nil, nil) }, wantResp: nil, wantErr: true, - wantErrCode: errno.CommonNoPermissionCode, + wantErrCode: errno.EvaluatorNotExistCode, }, { - name: "error - service failure", - req: &evaluatorservice.CreateEvaluatorTemplateRequest{ - EvaluatorTemplate: templateDTO, + name: "error - debug failure", + req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: workspaceID, + InputData: inputData, }, mockSetup: func() { mockAuth.EXPECT(). Authorization(gomock.Any(), gomock.Any()). Return(nil) - mockConfiger.EXPECT(). - GetEvaluatorTemplateSpaceConf(gomock.Any()). - Return([]string{"123"}) + mockEvaluatorService.EXPECT(). + GetBuiltinEvaluator(gomock.Any(), evaluatorID). + Return(builtinEvaluator, nil) - mockTemplateService.EXPECT(). - CreateEvaluatorTemplate(gomock.Any(), gomock.Any()). + mockEvaluatorService.EXPECT(). + DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) }, wantResp: nil, @@ -3907,7 +5254,7 @@ func TestEvaluatorHandlerImpl_CreateEvaluatorTemplate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tt.mockSetup() - resp, err := app.CreateEvaluatorTemplate(context.Background(), tt.req) + resp, err := app.DebugBuiltinEvaluator(context.Background(), tt.req) if tt.wantErr { assert.Error(t, err) @@ -3919,282 +5266,485 @@ func TestEvaluatorHandlerImpl_CreateEvaluatorTemplate(t *testing.T) { } else { assert.NoError(t, err) assert.NotNil(t, resp) - assert.NotNil(t, resp.EvaluatorTemplate) + assert.NotNil(t, resp.OutputData) } }) } } -// TestEvaluatorHandlerImpl_UpdateEvaluatorTemplate 测试 UpdateEvaluatorTemplate 方法 -func TestEvaluatorHandlerImpl_UpdateEvaluatorTemplate(t *testing.T) { - t.Parallel() - +// TestEvaluatorHandlerImpl_UpdateEvaluatorRecord 测试 UpdateEvaluatorRecord 方法 +func TestEvaluatorHandlerImpl_UpdateEvaluatorRecord(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) - mockConfiger := confmocks.NewMockIConfiger(ctrl) mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) app := &EvaluatorHandlerImpl{ - evaluatorTemplateService: mockTemplateService, - configer: mockConfiger, - auth: mockAuth, + auth: mockAuth, + evaluatorService: mockEvaluatorService, + evaluatorRecordService: mockEvaluatorRecordService, + auditClient: mockAuditClient, + configer: mockConfiger, } - templateID := int64(123) - workspaceID := int64(456) - templateDTO := &evaluatordto.EvaluatorTemplate{ - ID: gptr.Of(templateID), - WorkspaceID: gptr.Of(workspaceID), - Name: gptr.Of("updated template"), - Description: gptr.Of("updated description"), + recordID := int64(10) + versionID := int64(20) + spaceID := int64(100) + ctx := context.Background() + + record := &entity.EvaluatorRecord{ + ID: recordID, + EvaluatorVersionID: versionID, + SpaceID: spaceID, + } + evaluatorDO := &entity.Evaluator{ + ID: 1, + SpaceID: spaceID, + Builtin: false, + } + builtinEvaluatorDO := &entity.Evaluator{ + ID: 1, + SpaceID: spaceID, + Builtin: true, + } + + tests := []struct { + name string + req *evaluatorservice.UpdateEvaluatorRecordRequest + mockSetup func() + wantErr bool + }{ + { + name: "success_custom", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + Correction: &evaluatordto.Correction{ + Score: gptr.Of(0.95), + }, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Nil(), versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), record, gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "success_builtin", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Nil(), versionID, false, false).Return(builtinEvaluatorDO, nil) + mockConfiger.EXPECT().GetBuiltinEvaluatorSpaceConf(gomock.Any()).Return([]string{"100"}) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), record, gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "record_not_found", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(nil, nil) + }, + wantErr: true, + }, + { + name: "evaluator_not_found", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Nil(), versionID, false, false).Return(nil, nil) + }, + wantErr: false, + }, + { + name: "audit_rejected", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Nil(), versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Rejected}, nil) + }, + wantErr: true, + }, + { + name: "audit_service_error_still_pass", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Nil(), versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{}, errors.New("audit error")) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), record, gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "correct_service_error", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(record, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), gomock.Nil(), versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), record, gomock.Any()).Return(errors.New("db error")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := app.UpdateEvaluatorRecord(ctx, tt.req) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + }) + } +} + +// TestEvaluatorHandlerImpl_UpdateBuiltinEvaluatorTags 测试 UpdateBuiltinEvaluatorTags 方法 + +func TestEvaluatorHandlerImpl_BatchDebugEvaluator_Comprehensive(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockBenefitService := benefitmocks.NewMockIBenefitService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + benefitService: mockBenefitService, + configer: mockConfiger, } + workspaceID := int64(100) + ctx := context.Background() + tests := []struct { - name string - req *evaluatorservice.UpdateEvaluatorTemplateRequest - mockSetup func() - wantResp *evaluatorservice.UpdateEvaluatorTemplateResponse - wantErr bool - wantErrCode int32 + name string + req *evaluatorservice.BatchDebugEvaluatorRequest + mockSetup func() + wantErr bool }{ { - name: "success - normal request", - req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, - EvaluatorTemplate: templateDTO, + name: "success_custom_rpc", + req: &evaluatorservice.BatchDebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorType: evaluatordto.EvaluatorType_CustomRPC, + InputData: []*evaluatordto.EvaluatorInputData{ + {}, + }, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(nil) - - mockConfiger.EXPECT(). - GetEvaluatorTemplateSpaceConf(gomock.Any()). - Return([]string{"456"}) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + // authCustomRPCEvaluatorContentWritable + mockConfiger.EXPECT().GetBuiltinEvaluatorSpaceConf(gomock.Any()).Return([]string{"100"}) + mockConfiger.EXPECT().CheckCustomRPCEvaluatorWritable(gomock.Any(), "100", []string{"100"}).Return(true, nil) - mockTemplateService.EXPECT(). - UpdateEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.UpdateEvaluatorTemplateResponse{ - Template: evaluator.ConvertEvaluatorTemplateDTO2DO(templateDTO), - }, nil) - }, - wantResp: &evaluatorservice.UpdateEvaluatorTemplateResponse{ - EvaluatorTemplate: templateDTO, + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(&entity.EvaluatorOutputData{}, nil) }, wantErr: false, }, { - name: "error - nil template", - req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, - EvaluatorTemplate: nil, + name: "benefit_denied", + req: &evaluatorservice.BatchDebugEvaluatorRequest{ + WorkspaceID: workspaceID, }, - mockSetup: func() {}, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonInvalidParamCode, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: gptr.Of(benefit.DenyReason(1))}, nil) + }, + wantErr: true, }, { - name: "error - auth failed", - req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, - EvaluatorTemplate: &evaluatordto.EvaluatorTemplate{ - ID: gptr.Of(templateID), - WorkspaceID: gptr.Of(int64(789)), // 不在允许列表中 - Name: gptr.Of("updated template"), - Description: gptr.Of("updated description"), + name: "success_with_runtime_param", + req: &evaluatorservice.BatchDebugEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorRunConf: &evaluatordto.EvaluatorRunConfig{ + EvaluatorRuntimeParam: &common.RuntimeParam{ + JSONValue: gptr.Of(`{"key":"val"}`), + }, + }, + InputData: []*evaluatordto.EvaluatorInputData{ + {}, }, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(&entity.EvaluatorOutputData{}, nil) }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonNoPermissionCode, + wantErr: false, }, { - name: "error - service failure", - req: &evaluatorservice.UpdateEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, - EvaluatorTemplate: templateDTO, + name: "success_with_debug_error", + req: &evaluatorservice.BatchDebugEvaluatorRequest{ + WorkspaceID: workspaceID, + InputData: []*evaluatordto.EvaluatorInputData{ + {}, + }, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(nil) - - mockConfiger.EXPECT(). - GetEvaluatorTemplateSpaceConf(gomock.Any()). - Return([]string{"456"}) - - mockTemplateService.EXPECT(). - UpdateEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockBenefitService.EXPECT().CheckEvaluatorBenefit(gomock.Any(), gomock.Any()). + Return(&benefit.CheckEvaluatorBenefitResult{DenyReason: nil}, nil) + mockEvaluatorService.EXPECT().DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), workspaceID). + Return(nil, errors.New("debug error")) }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonInternalErrorCode, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.mockSetup() - - resp, err := app.UpdateEvaluatorTemplate(context.Background(), tt.req) - + resp, err := app.BatchDebugEvaluator(ctx, tt.req) if tt.wantErr { assert.Error(t, err) - if tt.wantErrCode != 0 { - statusErr, ok := errorx.FromStatusError(err) - assert.True(t, ok) - assert.Equal(t, tt.wantErrCode, statusErr.Code()) - } } else { assert.NoError(t, err) assert.NotNil(t, resp) - assert.NotNil(t, resp.EvaluatorTemplate) } }) } } -// TestEvaluatorHandlerImpl_DeleteEvaluatorTemplate 测试 DeleteEvaluatorTemplate 方法 -func TestEvaluatorHandlerImpl_DeleteEvaluatorTemplate(t *testing.T) { - t.Parallel() - +func TestEvaluatorHandlerImpl_CreateEvaluator_Comprehensive(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockTemplateService := mocks.NewMockEvaluatorTemplateService(ctrl) - mockConfiger := confmocks.NewMockIConfiger(ctrl) mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) app := &EvaluatorHandlerImpl{ - evaluatorTemplateService: mockTemplateService, - configer: mockConfiger, - auth: mockAuth, + auth: mockAuth, + evaluatorService: mockEvaluatorService, + auditClient: mockAuditClient, + configer: mockConfiger, + metrics: mockMetrics, + fileProvider: mockFileProvider, } - templateID := int64(123) - workspaceID := int64(456) - template := &entity.EvaluatorTemplate{ - ID: templateID, - SpaceID: workspaceID, - Name: "test template", - } + workspaceID := int64(100) + ctx := context.Background() tests := []struct { - name string - req *evaluatorservice.DeleteEvaluatorTemplateRequest - mockSetup func() - wantResp *evaluatorservice.DeleteEvaluatorTemplateResponse - wantErr bool - wantErrCode int32 + name string + req *evaluatorservice.CreateEvaluatorRequest + mockSetup func() + wantErr bool }{ { - name: "success - normal request", - req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, + name: "success", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + EvaluatorType: gptr.Of(evaluatordto.EvaluatorType_Prompt), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{}, + }, + }, + }, }, mockSetup: func() { - mockTemplateService.EXPECT(). - GetEvaluatorTemplate(gomock.Any(), &entity.GetEvaluatorTemplateRequest{ - ID: templateID, - IncludeDeleted: false, - }). - Return(&entity.GetEvaluatorTemplateResponse{ - Template: template, - }, nil) - - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(nil) - - mockConfiger.EXPECT(). - GetEvaluatorTemplateSpaceConf(gomock.Any()). - Return([]string{"456"}) - - mockTemplateService.EXPECT(). - DeleteEvaluatorTemplate(gomock.Any(), &entity.DeleteEvaluatorTemplateRequest{ - ID: templateID, - }). - Return(&entity.DeleteEvaluatorTemplateResponse{}, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().CreateEvaluator(gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(1), nil) + mockMetrics.EXPECT().EmitCreate(gomock.Any(), gomock.Any()).AnyTimes() }, - wantResp: &evaluatorservice.DeleteEvaluatorTemplateResponse{}, - wantErr: false, + wantErr: false, }, { - name: "error - template id is 0", - req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ - EvaluatorTemplateID: 0, + name: "req_nil", + req: nil, + mockSetup: func() { }, - mockSetup: func() {}, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonInvalidParamCode, + wantErr: true, }, { - name: "error - template not found", - req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, + name: "workspace_id_zero", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(int64(0)), + }, }, mockSetup: func() { - mockTemplateService.EXPECT(). - GetEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.GetEvaluatorTemplateResponse{ - Template: nil, - }, nil) }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.ResourceNotFoundCode, + wantErr: true, }, { - name: "error - auth failed", - req: &evaluatorservice.DeleteEvaluatorTemplateRequest{ - EvaluatorTemplateID: templateID, + name: "name_empty", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of(""), + }, }, mockSetup: func() { - // 使用不在允许列表中的workspaceID的template - testTemplate := &entity.EvaluatorTemplate{ - ID: templateID, - SpaceID: 789, // 不在允许列表中 - Name: "test template", - } - mockTemplateService.EXPECT(). - GetEvaluatorTemplate(gomock.Any(), gomock.Any()). - Return(&entity.GetEvaluatorTemplateResponse{ - Template: testTemplate, - }, nil) + }, + wantErr: true, + }, + { + name: "version_nil", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + CurrentVersion: nil, + }, + }, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "content_nil", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + EvaluatorContent: nil, + }, + }, + }, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "builtin_success", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + Builtin: gptr.Of(true), + EvaluatorType: gptr.Of(evaluatordto.EvaluatorType_Prompt), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{}, + }, + }, + }, + }, + mockSetup: func() { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().CreateEvaluator(gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(1), nil) + mockMetrics.EXPECT().EmitCreate(gomock.Any(), gomock.Any()).AnyTimes() + }, + wantErr: false, + }, + { + name: "name_too_long", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of(strings.Repeat("a", consts.MaxEvaluatorNameLength+1)), + }, + }, + mockSetup: func() {}, + wantErr: true, + }, + { + name: "desc_too_long", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + Description: gptr.Of(strings.Repeat("a", consts.MaxEvaluatorDescLength+1)), + }, + }, + mockSetup: func() {}, + wantErr: true, + }, + { + name: "version_too_long", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of(strings.Repeat("a", consts.MaxEvaluatorVersionLength+1)), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{}, + }, + }, + }, + }, + mockSetup: func() {}, + wantErr: true, + }, + { + name: "version_desc_too_long", + req: &evaluatorservice.CreateEvaluatorRequest{ + Evaluator: &evaluatordto.Evaluator{ + WorkspaceID: gptr.Of(workspaceID), + Name: gptr.Of("test"), + CurrentVersion: &evaluatordto.EvaluatorVersion{ + Version: gptr.Of("1.0.0"), + Description: gptr.Of(strings.Repeat("a", consts.MaxEvaluatorVersionDescLength+1)), + EvaluatorContent: &evaluatordto.EvaluatorContent{ + PromptEvaluator: &evaluatordto.PromptEvaluator{}, + }, + }, + }, }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonNoPermissionCode, + mockSetup: func() {}, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.mockSetup() - - resp, err := app.DeleteEvaluatorTemplate(context.Background(), tt.req) - + resp, err := app.CreateEvaluator(ctx, tt.req) if tt.wantErr { assert.Error(t, err) - if tt.wantErrCode != 0 { - statusErr, ok := errorx.FromStatusError(err) - assert.True(t, ok) - assert.Equal(t, tt.wantErrCode, statusErr.Code()) - } } else { assert.NoError(t, err) assert.NotNil(t, resp) @@ -4203,283 +5753,344 @@ func TestEvaluatorHandlerImpl_DeleteEvaluatorTemplate(t *testing.T) { } } -// TestEvaluatorHandlerImpl_DebugBuiltinEvaluator 测试 DebugBuiltinEvaluator 方法 -func TestEvaluatorHandlerImpl_DebugBuiltinEvaluator(t *testing.T) { - t.Parallel() - +func TestEvaluatorHandlerImpl_DeleteEvaluator_Comprehensive(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) app := &EvaluatorHandlerImpl{ - evaluatorService: mockEvaluatorService, auth: mockAuth, + evaluatorService: mockEvaluatorService, } - evaluatorID := int64(123) - workspaceID := int64(456) - builtinEvaluator := &entity.Evaluator{ + workspaceID := int64(100) + evaluatorID := int64(1) + ctx := context.Background() + + evaluatorDO := &entity.Evaluator{ ID: evaluatorID, SpaceID: workspaceID, - Name: "builtin evaluator", - Builtin: true, - } - - inputData := &evaluatordto.EvaluatorInputData{ - InputFields: map[string]*common.Content{ - "input": { - ContentType: gptr.Of(common.ContentTypeText), - Text: gptr.Of("test input"), - }, - }, - } - - outputData := &entity.EvaluatorOutputData{ - EvaluatorResult: &entity.EvaluatorResult{ - Score: gptr.Of(0.85), - Reasoning: "test result", - }, } tests := []struct { - name string - req *evaluatorservice.DebugBuiltinEvaluatorRequest - mockSetup func() - wantResp *evaluatorservice.DebugBuiltinEvaluatorResponse - wantErr bool - wantErrCode int32 + name string + req *evaluatorservice.DeleteEvaluatorRequest + mockSetup func() + wantErr bool }{ { - name: "success - normal request", - req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ - EvaluatorID: evaluatorID, + name: "success", + req: &evaluatorservice.DeleteEvaluatorRequest{ WorkspaceID: workspaceID, - InputData: inputData, + EvaluatorID: &evaluatorID, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), &rpc.AuthorizationParam{ - ObjectID: strconv.FormatInt(workspaceID, 10), - SpaceID: workspaceID, - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of("listLoopEvaluator"), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, - }). - Return(nil) - - mockEvaluatorService.EXPECT(). - GetBuiltinEvaluator(gomock.Any(), evaluatorID). - Return(builtinEvaluator, nil) - - mockEvaluatorService.EXPECT(). - DebugEvaluator(gomock.Any(), builtinEvaluator, gomock.Any(), gomock.Any(), gomock.Any()). - Return(outputData, nil) - }, - wantResp: &evaluatorservice.DebugBuiltinEvaluatorResponse{ - OutputData: evaluator.ConvertEvaluatorOutputDataDO2DTO(outputData), + mockEvaluatorService.EXPECT().BatchGetEvaluator(gomock.Any(), workspaceID, []int64{evaluatorID}, false). + Return([]*entity.Evaluator{evaluatorDO}, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().DeleteEvaluator(gomock.Any(), []int64{evaluatorID}, gomock.Any()).Return(nil) }, wantErr: false, }, { - name: "error - auth failed", - req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ - EvaluatorID: evaluatorID, + name: "get_error", + req: &evaluatorservice.DeleteEvaluatorRequest{ WorkspaceID: workspaceID, - InputData: inputData, + EvaluatorID: &evaluatorID, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + mockEvaluatorService.EXPECT().BatchGetEvaluator(gomock.Any(), workspaceID, []int64{evaluatorID}, false). + Return(nil, errors.New("db error")) }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonNoPermissionCode, + wantErr: true, }, { - name: "error - evaluator not found", - req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ - EvaluatorID: evaluatorID, + name: "auth_failed", + req: &evaluatorservice.DeleteEvaluatorRequest{ WorkspaceID: workspaceID, - InputData: inputData, + EvaluatorID: &evaluatorID, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(nil) - - mockEvaluatorService.EXPECT(). - GetBuiltinEvaluator(gomock.Any(), evaluatorID). - Return(nil, nil) + mockEvaluatorService.EXPECT().BatchGetEvaluator(gomock.Any(), workspaceID, []int64{evaluatorID}, false). + Return([]*entity.Evaluator{evaluatorDO}, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errors.New("auth failed")) }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.EvaluatorNotExistCode, + wantErr: true, }, { - name: "error - debug failure", - req: &evaluatorservice.DebugBuiltinEvaluatorRequest{ - EvaluatorID: evaluatorID, + name: "delete_error", + req: &evaluatorservice.DeleteEvaluatorRequest{ WorkspaceID: workspaceID, - InputData: inputData, + EvaluatorID: &evaluatorID, }, mockSetup: func() { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - Return(nil) - - mockEvaluatorService.EXPECT(). - GetBuiltinEvaluator(gomock.Any(), evaluatorID). - Return(builtinEvaluator, nil) - - mockEvaluatorService.EXPECT(). - DebugEvaluator(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(nil, errorx.NewByCode(errno.CommonInternalErrorCode)) + mockEvaluatorService.EXPECT().BatchGetEvaluator(gomock.Any(), workspaceID, []int64{evaluatorID}, false). + Return([]*entity.Evaluator{evaluatorDO}, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().DeleteEvaluator(gomock.Any(), []int64{evaluatorID}, gomock.Any()). + Return(errors.New("delete error")) }, - wantResp: nil, - wantErr: true, - wantErrCode: errno.CommonInternalErrorCode, + wantErr: true, + }, + { + name: "evaluator_not_found_skip_delete", + req: &evaluatorservice.DeleteEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: &evaluatorID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().BatchGetEvaluator(gomock.Any(), workspaceID, []int64{evaluatorID}, false). + Return([]*entity.Evaluator{nil}, nil) + mockEvaluatorService.EXPECT().DeleteEvaluator(gomock.Any(), []int64{evaluatorID}, gomock.Any()).Return(nil) + }, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.mockSetup() - - resp, err := app.DebugBuiltinEvaluator(context.Background(), tt.req) - + resp, err := app.DeleteEvaluator(ctx, tt.req) if tt.wantErr { assert.Error(t, err) - if tt.wantErrCode != 0 { - statusErr, ok := errorx.FromStatusError(err) - assert.True(t, ok) - assert.Equal(t, tt.wantErrCode, statusErr.Code()) - } } else { assert.NoError(t, err) assert.NotNil(t, resp) - assert.NotNil(t, resp.OutputData) } }) } } -// TestEvaluatorHandlerImpl_UpdateEvaluatorRecord 测试 UpdateEvaluatorRecord 方法 -func TestEvaluatorHandlerImpl_UpdateEvaluatorRecord(t *testing.T) { - t.Parallel() +func TestEvaluatorHandlerImpl_UpdateEvaluator_Comprehensive(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) + + app := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + auditClient: mockAuditClient, + configer: mockConfiger, + metrics: mockMetrics, + fileProvider: mockFileProvider, + } + + workspaceID := int64(100) + evaluatorID := int64(1) + ctx := context.Background() - const ( - workspaceID = int64(101) - evaluatorID = int64(202) - evaluatorVersionID = int64(303) - recordID = int64(404) - ) + evaluatorDO := &entity.Evaluator{ + ID: evaluatorID, + SpaceID: workspaceID, + } tests := []struct { - name string - evaluator *entity.Evaluator - setupAuth func(t *testing.T, mockAuth *rpcmocks.MockIAuthProvider, mockConfiger *confmocks.MockIConfiger) - expectConfig bool + name string + req *evaluatorservice.UpdateEvaluatorRequest + mockSetup func() + wantErr bool }{ { - name: "success - custom evaluator uses evaluator authorization", - evaluator: &entity.Evaluator{ - ID: evaluatorID, - SpaceID: workspaceID, - Builtin: false, + name: "success", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Name: gptr.Of("new name"), + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().UpdateEvaluatorMeta(gomock.Any(), gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "req_nil", + req: nil, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "id_zero", + req: &evaluatorservice.UpdateEvaluatorRequest{ + EvaluatorID: 0, + }, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "workspace_id_zero", + req: &evaluatorservice.UpdateEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: 0, + }, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "name_too_long", + req: &evaluatorservice.UpdateEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: workspaceID, + Name: gptr.Of(strings.Repeat("a", 101)), + }, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "description_too_long", + req: &evaluatorservice.UpdateEvaluatorRequest{ + EvaluatorID: evaluatorID, + WorkspaceID: workspaceID, + Description: gptr.Of(strings.Repeat("a", 1001)), + }, + mockSetup: func() { + }, + wantErr: true, + }, + { + name: "evaluator_not_found", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(nil, nil) + }, + wantErr: true, + }, + { + name: "auth_failed", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errors.New("auth failed")) + }, + wantErr: true, + }, + { + name: "builtin_auth_failed", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Builtin: gptr.Of(true), + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) // First auth at line 385 + // authBuiltinManagement calls second auth at line 1839 + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errors.New("builtin management auth failed")) + }, + wantErr: true, + }, + { + name: "audit_rejected", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Rejected}, nil) + }, + wantErr: true, + }, + { + name: "audit_service_error_passed", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, }, - setupAuth: func(t *testing.T, mockAuth *rpcmocks.MockIAuthProvider, _ *confmocks.MockIConfiger) { - mockAuth.EXPECT(). - Authorization(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, param *rpc.AuthorizationParam) error { - assert.Equal(t, strconv.FormatInt(evaluatorID, 10), param.ObjectID) - assert.Equal(t, workspaceID, param.SpaceID) - if assert.Len(t, param.ActionObjects, 1) { - assert.Equal(t, consts.Edit, gptr.Indirect(param.ActionObjects[0].Action)) - assert.Equal(t, rpc.AuthEntityType_Evaluator, gptr.Indirect(param.ActionObjects[0].EntityType)) - } - return nil - }) + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{}, errors.New("audit error")) + mockEvaluatorService.EXPECT().UpdateEvaluatorMeta(gomock.Any(), gomock.Any()).Return(nil) }, + wantErr: false, }, { - name: "success - builtin evaluator uses builtin space validation", - evaluator: &entity.Evaluator{ - ID: evaluatorID, - SpaceID: workspaceID, - Builtin: true, + name: "success_with_prompt_content", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + Name: gptr.Of("new name"), + Description: gptr.Of("new desc"), }, - setupAuth: func(t *testing.T, mockAuth *rpcmocks.MockIAuthProvider, mockConfiger *confmocks.MockIConfiger) { - // authWrite 为 false 时,不会调用 Authorization,只检查空间配置 - mockConfiger.EXPECT(). - GetBuiltinEvaluatorSpaceConf(gomock.Any()). - Return([]string{strconv.FormatInt(workspaceID, 10)}) + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().UpdateEvaluatorMeta(gomock.Any(), gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "success_custom_rpc", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + EvaluatorType: evaluatordto.EvaluatorType_CustomRPC, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().UpdateEvaluatorMeta(gomock.Any(), gomock.Any()).Return(nil) + }, + wantErr: false, + }, + { + name: "success_with_info_and_box_type", + req: &evaluatorservice.UpdateEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorID: evaluatorID, + EvaluatorInfo: &evaluatordto.EvaluatorInfo{ + Benchmark: gptr.Of("bench"), + }, + BoxType: gptr.Of(evaluatordto.EvaluatorBoxType("Black")), + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluator(gomock.Any(), workspaceID, evaluatorID, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorService.EXPECT().UpdateEvaluatorMeta(gomock.Any(), gomock.Any()).Return(nil) }, + wantErr: false, }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) - mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) - mockAuditClient := auditmocks.NewMockIAuditService(ctrl) - mockConfiger := confmocks.NewMockIConfiger(ctrl) - - handler := &EvaluatorHandlerImpl{ - auth: mockAuth, - evaluatorService: mockEvaluatorService, - evaluatorRecordService: mockEvaluatorRecordService, - auditClient: mockAuditClient, - configer: mockConfiger, - } - - tt.setupAuth(t, mockAuth, mockConfiger) - - evaluatorRecord := &entity.EvaluatorRecord{ - ID: recordID, - EvaluatorVersionID: evaluatorVersionID, - } - mockEvaluatorRecordService.EXPECT(). - GetEvaluatorRecord(gomock.Any(), recordID, false). - Return(evaluatorRecord, nil) - - mockEvaluatorService.EXPECT(). - GetEvaluatorVersion(gomock.Any(), gomock.Nil(), evaluatorVersionID, false, false). - Return(tt.evaluator, nil) - - mockAuditClient.EXPECT(). - Audit(gomock.Any(), gomock.Any()). - Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) - - mockEvaluatorRecordService.EXPECT(). - CorrectEvaluatorRecord(gomock.Any(), evaluatorRecord, gomock.Any()). - Return(nil) - - req := &evaluatorservice.UpdateEvaluatorRecordRequest{ - WorkspaceID: workspaceID, - EvaluatorRecordID: recordID, - Correction: &evaluatordto.Correction{ - Explain: gptr.Of("need update"), - UpdatedBy: gptr.Of("tester"), - }, + tt.mockSetup() + resp, err := app.UpdateEvaluator(ctx, tt.req) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) } - - resp, err := handler.UpdateEvaluatorRecord(context.Background(), req) - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.Record) }) } } -// TestEvaluatorHandlerImpl_UpdateBuiltinEvaluatorTags 测试 UpdateBuiltinEvaluatorTags 方法 func TestEvaluatorHandlerImpl_UpdateBuiltinEvaluatorTags(t *testing.T) { t.Parallel() @@ -4974,7 +6585,7 @@ func TestEvaluatorHandlerImpl_UpdateEvaluatorDraft(t *testing.T) { {Role: entity.RoleSystem, Content: &entity.Content{Text: gptr.Of("old content")}}, }, ModelConfig: &entity.ModelConfig{ - ModelID: int64(1), + ModelID: gptr.Of(int64(1)), }, }, } @@ -5219,12 +6830,16 @@ func TestEvaluatorHandlerImpl_UpdateEvaluator(t *testing.T) { mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) mockAuditClient := auditmocks.NewMockIAuditService(ctrl) mockConfiger := confmocks.NewMockIConfiger(ctrl) + mockMetrics := metricsmock.NewMockEvaluatorExecMetrics(ctrl) + mockFileProvider := rpcmocks.NewMockIFileProvider(ctrl) app := &EvaluatorHandlerImpl{ auth: mockAuth, evaluatorService: mockEvaluatorService, auditClient: mockAuditClient, configer: mockConfiger, + metrics: mockMetrics, + fileProvider: mockFileProvider, } // Test data @@ -5718,3 +7333,300 @@ func TestEvaluatorHandlerImpl_BatchGetEvaluators(t *testing.T) { }) } } + +func TestEvaluatorHandlerImpl_RunEvaluator_Comprehensive(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + + handler := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + configer: mockConfiger, + } + + ctx := context.Background() + versionID := int64(123) + workspaceID := int64(456) + evaluatorName := "test-eval" + + evaluatorDO := &entity.Evaluator{ + ID: 1, + SpaceID: workspaceID, + Name: evaluatorName, + Builtin: false, + } + + builtinEvaluatorDO := &entity.Evaluator{ + ID: 2, + SpaceID: workspaceID, + Name: evaluatorName, + Builtin: true, + } + + tests := []struct { + name string + req *evaluatorservice.RunEvaluatorRequest + mockSetup func() + wantErr bool + errCode int32 + }{ + { + name: "success_normal", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().RunEvaluator(gomock.Any(), gomock.Any()).Return(&entity.EvaluatorRecord{ID: 789}, nil) + }, + }, + { + name: "success_builtin", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(builtinEvaluatorDO, nil) + // skips auth + mockEvaluatorService.EXPECT().RunEvaluator(gomock.Any(), gomock.Any()).Return(&entity.EvaluatorRecord{ID: 789}, nil) + }, + }, + { + name: "error_not_found", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(nil, nil) + }, + wantErr: true, + errCode: errno.EvaluatorNotExistCode, + }, + { + name: "error_get_version_failed", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(nil, errors.New("db error")) + }, + wantErr: true, + }, + { + name: "error_auth_failed", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(errorx.NewByCode(errno.CommonNoPermissionCode)) + }, + wantErr: true, + errCode: errno.CommonNoPermissionCode, + }, + { + name: "error_run_failed", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().RunEvaluator(gomock.Any(), gomock.Any()).Return(nil, errors.New("run error")) + }, + wantErr: true, + }, + { + name: "with_runtime_param", + req: &evaluatorservice.RunEvaluatorRequest{ + WorkspaceID: workspaceID, + EvaluatorVersionID: versionID, + EvaluatorRunConf: &evaluatordto.EvaluatorRunConfig{ + EvaluatorRuntimeParam: &common.RuntimeParam{ + JSONValue: gptr.Of(`{"key":"val"}`), + }, + }, + }, + mockSetup: func() { + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockEvaluatorService.EXPECT().RunEvaluator(gomock.Any(), gomock.Any()).Return(&entity.EvaluatorRecord{ID: 789}, nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := handler.RunEvaluator(ctx, tt.req) + if tt.wantErr { + assert.Error(t, err) + if tt.errCode != 0 { + statusErr, _ := errorx.FromStatusError(err) + assert.Equal(t, tt.errCode, statusErr.Code()) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + }) + } +} + +func TestEvaluatorHandlerImpl_UpdateEvaluatorRecord_Comprehensive(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockEvaluatorService := mocks.NewMockEvaluatorService(ctrl) + mockEvaluatorRecordService := mocks.NewMockEvaluatorRecordService(ctrl) + mockAuditClient := auditmocks.NewMockIAuditService(ctrl) + mockConfiger := confmocks.NewMockIConfiger(ctrl) + + handler := &EvaluatorHandlerImpl{ + auth: mockAuth, + evaluatorService: mockEvaluatorService, + evaluatorRecordService: mockEvaluatorRecordService, + auditClient: mockAuditClient, + configer: mockConfiger, + } + + ctx := context.Background() + recordID := int64(789) + workspaceID := int64(456) + versionID := int64(123) + + recordDO := &entity.EvaluatorRecord{ + ID: recordID, + EvaluatorVersionID: versionID, + SpaceID: workspaceID, + } + + evaluatorDO := &entity.Evaluator{ + ID: 1, + SpaceID: workspaceID, + Builtin: false, + } + + builtinEvaluatorDO := &entity.Evaluator{ + ID: 2, + SpaceID: workspaceID, + Builtin: true, + } + + tests := []struct { + name string + req *evaluatorservice.UpdateEvaluatorRecordRequest + mockSetup func() + wantErr bool + errCode int32 + }{ + { + name: "success_normal", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + Correction: &evaluatordto.Correction{ + Score: gptr.Of(float64(0.85)), + Explain: gptr.Of("good"), + }, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(recordDO, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + }, + }, + { + name: "success_builtin", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(recordDO, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(builtinEvaluatorDO, nil) + mockConfiger.EXPECT().GetBuiltinEvaluatorSpaceConf(gomock.Any()).Return([]string{strconv.FormatInt(workspaceID, 10)}) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Approved}, nil) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + }, + }, + { + name: "error_record_not_found", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(nil, nil) + }, + wantErr: true, + errCode: errno.EvaluatorRecordNotFoundCode, + }, + { + name: "error_evaluator_not_found", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(recordDO, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(nil, nil) + }, + wantErr: false, // returns empty resp + }, + { + name: "error_audit_failed", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(recordDO, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{AuditStatus: audit.AuditStatus_Rejected}, nil) + }, + wantErr: true, + errCode: errno.RiskContentDetectedCode, + }, + { + name: "audit_service_error_passed", + req: &evaluatorservice.UpdateEvaluatorRecordRequest{ + EvaluatorRecordID: recordID, + }, + mockSetup: func() { + mockEvaluatorRecordService.EXPECT().GetEvaluatorRecord(gomock.Any(), recordID, false).Return(recordDO, nil) + mockEvaluatorService.EXPECT().GetEvaluatorVersion(gomock.Any(), nil, versionID, false, false).Return(evaluatorDO, nil) + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockAuditClient.EXPECT().Audit(gomock.Any(), gomock.Any()).Return(audit.AuditRecord{}, errors.New("audit service down")) + mockEvaluatorRecordService.EXPECT().CorrectEvaluatorRecord(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup() + resp, err := handler.UpdateEvaluatorRecord(ctx, tt.req) + if tt.wantErr { + assert.Error(t, err) + if tt.errCode != 0 { + statusErr, _ := errorx.FromStatusError(err) + assert.Equal(t, tt.errCode, statusErr.Code()) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + }) + } +} diff --git a/backend/modules/evaluation/application/experiment_app.go b/backend/modules/evaluation/application/experiment_app.go index 83c62ae3c..ccdcf4807 100644 --- a/backend/modules/evaluation/application/experiment_app.go +++ b/backend/modules/evaluation/application/experiment_app.go @@ -208,6 +208,15 @@ func (e *experimentApplication) BatchGetExperimentTemplate(ctx context.Context, session := entity.NewSession(ctx) logs.CtxInfo(ctx, "BatchGetExperimentTemplate template_ids: %v, workspace_id: %d", req.GetTemplateIds(), req.GetWorkspaceID()) + // 权限校验,与 ListExperimentTemplates 一致:空间级 listLoopExptTemplate + if err = e.auth.Authorization(ctx, &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(req.GetWorkspaceID(), 10), + SpaceID: req.GetWorkspaceID(), + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.ActionReadExptTemplate), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }); err != nil { + return nil, err + } + templateIDs := req.GetTemplateIds() if len(templateIDs) == 0 { return &expt.BatchGetExperimentTemplateResponse{ @@ -221,11 +230,6 @@ func (e *experimentApplication) BatchGetExperimentTemplate(ctx context.Context, return nil, err } - // 权限校验,抽象成与实验类似的批量鉴权方法 - if err := e.AuthReadExptTemplates(ctx, templates, req.GetWorkspaceID()); err != nil { - return nil, err - } - dtos := experiment.ToExptTemplateDTOs(templates) // 填充完整的用户信息 e.mPackExptTemplateUserInfo(ctx, dtos) @@ -248,23 +252,23 @@ func (e *experimentApplication) UpdateExperimentTemplate(ctx context.Context, re logs.CtxInfo(ctx, "UpdateExperimentTemplate template_id: %d, workspace_id: %d", templateID, workspaceID) - // 获取现有模板用于权限校验 - got, err := e.templateManager.Get(ctx, templateID, workspaceID, session) - if err != nil { + // 权限校验,与 ListExperimentTemplates 一致:空间级 listLoopExptTemplate + if err = e.auth.Authorization(ctx, &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(workspaceID, 10), + SpaceID: workspaceID, + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.ActionReadExptTemplate), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }); err != nil { return nil, err } - // 权限校验 - err = e.auth.AuthorizationWithoutSPI(ctx, &rpc.AuthorizationWithoutSPIParam{ - ObjectID: strconv.FormatInt(templateID, 10), - SpaceID: workspaceID, - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Edit), EntityType: gptr.Of(rpc.AuthEntityType_EvaluationExptTemplate)}}, - OwnerID: gptr.Of(got.GetCreatedBy()), - ResourceSpaceID: workspaceID, - }) + // 获取现有模板用于业务逻辑 + got, err := e.templateManager.Get(ctx, templateID, workspaceID, session) if err != nil { return nil, err } + if got == nil { + return nil, errorx.NewByCode(errno.ResourceNotFoundCode, errorx.WithExtraMsg("template not found")) + } // 转换请求参数 param, err := experiment.ConvertUpdateExptTemplateReq(req) @@ -299,23 +303,23 @@ func (e *experimentApplication) UpdateExperimentTemplateMeta(ctx context.Context logs.CtxInfo(ctx, "UpdateExperimentTemplateMeta template_id: %d, workspace_id: %d", templateID, workspaceID) - // 获取现有模板用于权限校验 - got, err := e.templateManager.Get(ctx, templateID, workspaceID, session) - if err != nil { + // 权限校验,与 ListExperimentTemplates 一致:空间级 listLoopExptTemplate + if err = e.auth.Authorization(ctx, &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(workspaceID, 10), + SpaceID: workspaceID, + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.ActionReadExptTemplate), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }); err != nil { return nil, err } - // 权限校验 - err = e.auth.AuthorizationWithoutSPI(ctx, &rpc.AuthorizationWithoutSPIParam{ - ObjectID: strconv.FormatInt(templateID, 10), - SpaceID: workspaceID, - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Edit), EntityType: gptr.Of(rpc.AuthEntityType_EvaluationExptTemplate)}}, - OwnerID: gptr.Of(got.GetCreatedBy()), - ResourceSpaceID: workspaceID, - }) + // 获取现有模板用于业务逻辑 + got, err := e.templateManager.Get(ctx, templateID, workspaceID, session) if err != nil { return nil, err } + if got == nil { + return nil, errorx.NewByCode(errno.ResourceNotFoundCode, errorx.WithExtraMsg("template not found")) + } // 转换请求参数 param, err := experiment.ConvertUpdateExptTemplateMetaReq(req) @@ -351,24 +355,12 @@ func (e *experimentApplication) DeleteExperimentTemplate(ctx context.Context, re session := entity.NewSession(ctx) logs.CtxInfo(ctx, "DeleteExperimentTemplate template_id: %d, workspace_id: %d", req.GetTemplateID(), req.GetWorkspaceID()) - // 获取现有模板用于权限校验 - existingTemplate, err := e.templateManager.Get(ctx, req.GetTemplateID(), req.GetWorkspaceID(), session) - if err != nil { - return nil, err - } - if existingTemplate == nil { - return nil, errorx.NewByCode(errno.CommonInvalidParamCode, errorx.WithExtraMsg("template not found")) - } - - // 权限校验 - err = e.auth.AuthorizationWithoutSPI(ctx, &rpc.AuthorizationWithoutSPIParam{ - ObjectID: strconv.FormatInt(req.GetTemplateID(), 10), - SpaceID: req.GetWorkspaceID(), - ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.Edit), EntityType: gptr.Of(rpc.AuthEntityType_EvaluationExptTemplate)}}, - OwnerID: gptr.Of(existingTemplate.GetCreatedBy()), - ResourceSpaceID: req.GetWorkspaceID(), - }) - if err != nil { + // 权限校验,与 ListExperimentTemplates 一致:空间级 listLoopExptTemplate + if err = e.auth.Authorization(ctx, &rpc.AuthorizationParam{ + ObjectID: strconv.FormatInt(req.GetWorkspaceID(), 10), + SpaceID: req.GetWorkspaceID(), + ActionObjects: []*rpc.ActionObject{{Action: gptr.Of(consts.ActionReadExptTemplate), EntityType: gptr.Of(rpc.AuthEntityType_Space)}}, + }); err != nil { return nil, err } @@ -1177,7 +1169,7 @@ func (e *experimentApplication) BatchGetExperimentAggrResult_(ctx context.Contex } return &expt.BatchGetExperimentAggrResultResponse{ - ExptAggregateResults: exptAggregateResultDTOs, + ExptAggregateResult_: exptAggregateResultDTOs, }, nil } @@ -1652,8 +1644,8 @@ func (e *experimentApplication) GetExptResultExportRecord(ctx context.Context, r } return &expt.GetExptResultExportRecordResponse{ - ExptResultExportRecord: experiment.ExportRecordDO2DTO(record), - BaseResp: base.NewBaseResp(), + ExptResultExportRecords: experiment.ExportRecordDO2DTO(record), + BaseResp: base.NewBaseResp(), }, nil } diff --git a/backend/modules/evaluation/application/experiment_app_test.go b/backend/modules/evaluation/application/experiment_app_test.go index 9147ee2bd..86ba0f37f 100644 --- a/backend/modules/evaluation/application/experiment_app_test.go +++ b/backend/modules/evaluation/application/experiment_app_test.go @@ -2889,20 +2889,13 @@ func TestExperimentApplication_CreateExperimentTemplate(t *testing.T) { } func TestExperimentApplication_BatchGetExperimentTemplate(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockTemplateManager := servicemocks.NewMockIExptTemplateManager(ctrl) - mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) - mockUserInfo := userinfomocks.NewMockUserInfoService(ctrl) - workspaceID := int64(1001) templateID := int64(2001) tests := []struct { name string req *exptpb.BatchGetExperimentTemplateRequest - mockSetup func() + mockSetup func(mockAuth *rpcmocks.MockIAuthProvider, mockTemplateManager *servicemocks.MockIExptTemplateManager, mockUserInfo *userinfomocks.MockUserInfoService) wantLen int wantErr bool }{ @@ -2912,8 +2905,11 @@ func TestExperimentApplication_BatchGetExperimentTemplate(t *testing.T) { WorkspaceID: workspaceID, TemplateIds: nil, }, - mockSetup: func() { - // 当前实现在 template_ids 为空时直接返回,不触发任何鉴权 / MGet 调用 + mockSetup: func(mockAuth *rpcmocks.MockIAuthProvider, mockTemplateManager *servicemocks.MockIExptTemplateManager, mockUserInfo *userinfomocks.MockUserInfoService) { + // 即使 ID 为空,也会先触发空间级鉴权 + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) }, wantLen: 0, wantErr: false, @@ -2924,7 +2920,7 @@ func TestExperimentApplication_BatchGetExperimentTemplate(t *testing.T) { WorkspaceID: workspaceID, TemplateIds: []int64{templateID}, }, - mockSetup: func() { + mockSetup: func(mockAuth *rpcmocks.MockIAuthProvider, mockTemplateManager *servicemocks.MockIExptTemplateManager, mockUserInfo *userinfomocks.MockUserInfoService) { templates := []*entity.ExptTemplate{ { Meta: &entity.ExptTemplateMeta{ @@ -2945,20 +2941,8 @@ func TestExperimentApplication_BatchGetExperimentTemplate(t *testing.T) { // 批量模板读权限校验 mockAuth.EXPECT(). - MAuthorizeWithoutSPI(gomock.Any(), workspaceID, gomock.Any()). - DoAndReturn(func(_ context.Context, spaceID int64, params []*rpc.AuthorizationWithoutSPIParam) error { - assert.Equal(t, workspaceID, spaceID) - assert.Len(t, params, 1) - p := params[0] - assert.Equal(t, strconv.FormatInt(templateID, 10), p.ObjectID) - assert.Equal(t, workspaceID, p.SpaceID) - assert.Len(t, p.ActionObjects, 1) - assert.Equal(t, consts.Read, *p.ActionObjects[0].Action) - assert.Equal(t, rpc.AuthEntityType_EvaluationExptTemplate, *p.ActionObjects[0].EntityType) - assert.Equal(t, "u1", *p.OwnerID) - assert.Equal(t, workspaceID, p.ResourceSpaceID) - return nil - }) + Authorization(gomock.Any(), gomock.Any()). + Return(nil) mockUserInfo.EXPECT().PackUserInfo(gomock.Any(), gomock.Any()) }, @@ -2969,7 +2953,14 @@ func TestExperimentApplication_BatchGetExperimentTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.mockSetup() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTemplateManager := servicemocks.NewMockIExptTemplateManager(ctrl) + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockUserInfo := userinfomocks.NewMockUserInfoService(ctrl) + + tt.mockSetup(mockAuth, mockTemplateManager, mockUserInfo) app := NewExperimentApplication( nil, // aggResultSvc nil, // resultSvc @@ -3070,19 +3061,10 @@ func TestExperimentApplication_UpdateExperimentTemplate(t *testing.T) { Get(gomock.Any(), templateID, workspaceID, gomock.Any()). Return(existing, nil) - // 使用 AuthorizationWithoutSPI 做模板级编辑权限校验 + // 使用 Authorization 做空间级模板读权限校验 mockAuth.EXPECT(). - AuthorizationWithoutSPI(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, param *rpc.AuthorizationWithoutSPIParam) error { - assert.Equal(t, strconv.FormatInt(templateID, 10), param.ObjectID) - assert.Equal(t, workspaceID, param.SpaceID) - assert.Len(t, param.ActionObjects, 1) - assert.Equal(t, consts.Edit, *param.ActionObjects[0].Action) - assert.Equal(t, rpc.AuthEntityType_EvaluationExptTemplate, *param.ActionObjects[0].EntityType) - assert.Equal(t, "u1", *param.OwnerID) - assert.Equal(t, workspaceID, param.ResourceSpaceID) - return nil - }) + Authorization(gomock.Any(), gomock.Any()). + Return(nil) mockTemplateManager.EXPECT(). Update(gomock.Any(), gomock.Any(), gomock.Any()). @@ -3178,23 +3160,14 @@ func TestExperimentApplication_UpdateExperimentTemplateMeta(t *testing.T) { BaseInfo: existing.BaseInfo, } + mockAuth.EXPECT(). + Authorization(gomock.Any(), gomock.Any()). + Return(nil) + mockTemplateManager.EXPECT(). Get(gomock.Any(), templateID, workspaceID, gomock.Any()). Return(existing, nil) - mockAuth.EXPECT(). - AuthorizationWithoutSPI(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, param *rpc.AuthorizationWithoutSPIParam) error { - assert.Equal(t, strconv.FormatInt(templateID, 10), param.ObjectID) - assert.Equal(t, workspaceID, param.SpaceID) - assert.Len(t, param.ActionObjects, 1) - assert.Equal(t, consts.Edit, *param.ActionObjects[0].Action) - assert.Equal(t, rpc.AuthEntityType_EvaluationExptTemplate, *param.ActionObjects[0].EntityType) - assert.Equal(t, "u1", *param.OwnerID) - assert.Equal(t, workspaceID, param.ResourceSpaceID) - return nil - }) - mockTemplateManager.EXPECT(). UpdateMeta(gomock.Any(), gomock.Any(), gomock.Any()). Return(updated, nil) @@ -3239,33 +3212,9 @@ func TestExperimentApplication_DeleteExperimentTemplate(t *testing.T) { TemplateID: templateID, } - existing := &entity.ExptTemplate{ - Meta: &entity.ExptTemplateMeta{ - ID: templateID, - WorkspaceID: workspaceID, - }, - BaseInfo: &entity.BaseInfo{ - CreatedBy: &entity.UserInfo{UserID: gptr.Of("u1")}, - UpdatedBy: &entity.UserInfo{UserID: gptr.Of("u1")}, - }, - } - - mockTemplateManager.EXPECT(). - Get(gomock.Any(), templateID, workspaceID, gomock.Any()). - Return(existing, nil) - mockAuth.EXPECT(). - AuthorizationWithoutSPI(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, param *rpc.AuthorizationWithoutSPIParam) error { - assert.Equal(t, strconv.FormatInt(templateID, 10), param.ObjectID) - assert.Equal(t, workspaceID, param.SpaceID) - assert.Len(t, param.ActionObjects, 1) - assert.Equal(t, consts.Edit, *param.ActionObjects[0].Action) - assert.Equal(t, rpc.AuthEntityType_EvaluationExptTemplate, *param.ActionObjects[0].EntityType) - assert.Equal(t, "u1", *param.OwnerID) - assert.Equal(t, workspaceID, param.ResourceSpaceID) - return nil - }) + Authorization(gomock.Any(), gomock.Any()). + Return(nil) mockTemplateManager.EXPECT(). Delete(gomock.Any(), templateID, workspaceID, gomock.Any()). @@ -3908,7 +3857,7 @@ func TestExperimentApplication_BatchGetExperimentAggrResult_(t *testing.T) { }, wantResp: &exptpb.BatchGetExperimentAggrResultResponse{ - ExptAggregateResults: []*expt.ExptAggregateResult_{ + ExptAggregateResult_: []*expt.ExptAggregateResult_{ { ExperimentID: validExptID, EvaluatorResults: map[int64]*expt.EvaluatorAggregateResult_{ @@ -3993,12 +3942,12 @@ func TestExperimentApplication_BatchGetExperimentAggrResult_(t *testing.T) { return } if !tt.wantErr { - // 比较 ExptAggregateResults - if len(got.ExptAggregateResults) != len(tt.wantResp.ExptAggregateResults) { - t.Errorf("ExptAggregateResults length mismatch: got %v, want %v", len(got.ExptAggregateResults), len(tt.wantResp.ExptAggregateResults)) + // 比较 ExptAggregateResult_ + if len(got.ExptAggregateResult_) != len(tt.wantResp.ExptAggregateResult_) { + t.Errorf("ExptAggregateResult_ length mismatch: got %v, want %v", len(got.ExptAggregateResult_), len(tt.wantResp.ExptAggregateResult_)) } else { - for i, gotResult := range got.ExptAggregateResults { - wantResult := tt.wantResp.ExptAggregateResults[i] + for i, gotResult := range got.ExptAggregateResult_ { + wantResult := tt.wantResp.ExptAggregateResult_[i] if gotResult.ExperimentID != wantResult.ExperimentID { t.Errorf("ExperimentID mismatch at index %d: got %v, want %v", i, gotResult.ExperimentID, wantResult.ExperimentID) } @@ -4475,7 +4424,7 @@ func TestExperimentApplication_GetExptResultExportRecord(t *testing.T) { Return(&entity.ExptExportWhiteList{UserIDs: []int64{}}).AnyTimes() }, wantResp: &exptpb.GetExptResultExportRecordResponse{ - ExptResultExportRecord: &expt.ExptResultExportRecord{ + ExptResultExportRecords: &expt.ExptResultExportRecord{ ExportID: validExportID, ExptID: int64(789), CsvExportStatus: experiment.CSVExportStatusDO2DTO(entity.CSVExportStatus_Success), @@ -4529,8 +4478,8 @@ func TestExperimentApplication_GetExptResultExportRecord(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, gotResp) - assert.Equal(t, tt.wantResp.ExptResultExportRecord.GetExportID(), gotResp.ExptResultExportRecord.GetExportID()) - assert.Equal(t, tt.wantResp.ExptResultExportRecord.GetCsvExportStatus(), gotResp.ExptResultExportRecord.GetCsvExportStatus()) + assert.Equal(t, tt.wantResp.ExptResultExportRecords.GetExportID(), gotResp.ExptResultExportRecords.GetExportID()) + assert.Equal(t, tt.wantResp.ExptResultExportRecords.GetCsvExportStatus(), gotResp.ExptResultExportRecords.GetCsvExportStatus()) }) } } @@ -5758,3 +5707,109 @@ func TestGetAnalysisRecordFeedbackVote(t *testing.T) { } }) } + +func TestExperimentApplication_ListExperimentStats(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + mockManager := servicemocks.NewMockIExptManager(ctrl) + mockResultSvc := servicemocks.NewMockExptResultService(ctrl) + mockEvalTargetSvc := servicemocks.NewMockIEvalTargetService(ctrl) + + app := &experimentApplication{ + auth: mockAuth, + manager: mockManager, + resultSvc: mockResultSvc, + evalTargetService: mockEvalTargetSvc, + } + + workspaceID := int64(123) + exptID := int64(456) + userID := int64(789) + + req := &exptpb.ListExperimentStatsRequest{ + WorkspaceID: workspaceID, + Session: &common.Session{UserID: gptr.Of(userID)}, + PageNumber: gptr.Of(int32(1)), + PageSize: gptr.Of(int32(10)), + } + + t.Run("success", func(t *testing.T) { + mockAuth.EXPECT().Authorization(gomock.Any(), gomock.Any()).Return(nil) + mockManager.EXPECT().ListExptRaw(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return([]*entity.Experiment{{ID: exptID}}, int64(1), nil) + mockResultSvc.EXPECT().MGetStats(gomock.Any(), []int64{exptID}, workspaceID, gomock.Any()). + Return([]*entity.ExptStats{{ExptID: exptID}}, nil) + + resp, err := app.ListExperimentStats(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, int32(1), resp.GetTotal()) + assert.Len(t, resp.GetExptStatsInfos(), 1) + }) +} + +func TestExperimentApplication_AuthReadExptTemplates(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAuth := rpcmocks.NewMockIAuthProvider(ctrl) + app := &experimentApplication{auth: mockAuth} + + workspaceID := int64(123) + templateID := int64(456) + + t.Run("success", func(t *testing.T) { + mockAuth.EXPECT().MAuthorizeWithoutSPI(gomock.Any(), workspaceID, gomock.Any()).Return(nil) + err := app.AuthReadExptTemplates(context.Background(), []*entity.ExptTemplate{{Meta: &entity.ExptTemplateMeta{ID: templateID}}}, workspaceID) + assert.NoError(t, err) + }) +} + +func TestExperimentApplication_UpsertExptTurnResultFilter(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockResultSvc := servicemocks.NewMockExptResultService(ctrl) + app := &experimentApplication{resultSvc: mockResultSvc} + + workspaceID := int64(123) + exptID := int64(456) + + t.Run("manual type", func(t *testing.T) { + req := &exptpb.UpsertExptTurnResultFilterRequest{ + WorkspaceID: gptr.Of(workspaceID), + ExperimentID: gptr.Of(exptID), + FilterType: gptr.Of(exptpb.UpsertExptTurnResultFilterTypeMANUAL), + ItemIds: []int64{1}, + } + mockResultSvc.EXPECT().ManualUpsertExptTurnResultFilter(gomock.Any(), workspaceID, exptID, []int64{1}).Return(nil) + _, err := app.UpsertExptTurnResultFilter(context.Background(), req) + assert.NoError(t, err) + }) + + t.Run("check type", func(t *testing.T) { + req := &exptpb.UpsertExptTurnResultFilterRequest{ + WorkspaceID: gptr.Of(workspaceID), + ExperimentID: gptr.Of(exptID), + FilterType: gptr.Of(exptpb.UpsertExptTurnResultFilterTypeCHECK), + ItemIds: []int64{1}, + RetryTimes: gptr.Of(int32(3)), + } + mockResultSvc.EXPECT().CompareExptTurnResultFilters(gomock.Any(), workspaceID, exptID, []int64{1}, int32(3)).Return(nil) + _, err := app.UpsertExptTurnResultFilter(context.Background(), req) + assert.NoError(t, err) + }) + + t.Run("default type", func(t *testing.T) { + req := &exptpb.UpsertExptTurnResultFilterRequest{ + WorkspaceID: gptr.Of(workspaceID), + ExperimentID: gptr.Of(exptID), + ItemIds: []int64{1}, + } + mockResultSvc.EXPECT().UpsertExptTurnResultFilter(gomock.Any(), workspaceID, exptID, []int64{1}).Return(nil) + _, err := app.UpsertExptTurnResultFilter(context.Background(), req) + assert.NoError(t, err) + }) +} diff --git a/backend/modules/evaluation/domain/entity/common.go b/backend/modules/evaluation/domain/entity/common.go index 398634b9d..37e8e97a6 100644 --- a/backend/modules/evaluation/domain/entity/common.go +++ b/backend/modules/evaluation/domain/entity/common.go @@ -286,18 +286,28 @@ type FunctionCall struct { } type ModelConfig struct { - ModelID int64 `json:"model_id"` - ModelName string `json:"model_name"` - MaxTokens *int32 `json:"max_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - ToolChoice ToolChoiceType `json:"tool_choice" jsonschema:"-"` + ModelID *int64 `json:"model_id"` + ModelName string `json:"model_name"` + MaxTokens *int32 `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + ToolChoice ToolChoiceType `json:"tool_choice" jsonschema:"-"` + Protocol *string `json:"protocol,omitempty"` + Identification *string `json:"identification,omitempty"` + PresetModel *bool `json:"preset_model,omitempty"` ProviderModelID *string `json:"provider_model_id,omitempty" jsonschema:"-"` JSONExt *string `json:"json_ext,omitempty"` } +func (m *ModelConfig) GetModelID() int64 { + if m != nil && m.ModelID != nil { + return *m.ModelID + } + return 0 +} + type Reply struct { Item *ReplyItem `json:"item,omitempty"` DebugID int64 `json:"debug_id"` diff --git a/backend/modules/evaluation/domain/entity/evaluator_test.go b/backend/modules/evaluation/domain/entity/evaluator_test.go index 7d2eb7cc2..3c7a99dc4 100644 --- a/backend/modules/evaluation/domain/entity/evaluator_test.go +++ b/backend/modules/evaluation/domain/entity/evaluator_test.go @@ -452,7 +452,7 @@ func TestEvaluator_GetPromptTemplateKey(t *testing.T) { func TestEvaluator_GetModelConfig(t *testing.T) { t.Parallel() modelConfig := &ModelConfig{ - ModelID: 123, + ModelID: gptr.Of(int64(123)), ModelName: "test_model", } @@ -599,7 +599,7 @@ func TestEvaluator_ValidateBaseInfo(t *testing.T) { EvaluatorType: EvaluatorTypePrompt, PromptEvaluatorVersion: &PromptEvaluatorVersion{ MessageList: []*Message{{Role: RoleUser}}, - ModelConfig: &ModelConfig{ModelID: 123}, + ModelConfig: &ModelConfig{ModelID: gptr.Of(int64(123))}, }, }, expectErr: false, @@ -1661,7 +1661,7 @@ func TestPromptEvaluatorVersion_GetPromptTemplateKey(t *testing.T) { } func TestPromptEvaluatorVersion_GetModelConfig(t *testing.T) { - mc := &ModelConfig{ModelID: 123} + mc := &ModelConfig{ModelID: gptr.Of(int64(123))} ver := &PromptEvaluatorVersion{ModelConfig: mc} assert.Equal(t, mc, ver.GetModelConfig()) } @@ -1698,7 +1698,7 @@ func TestPromptEvaluatorVersion_ValidateBaseInfo(t *testing.T) { assert.Error(t, ver.ValidateBaseInfo()) // message list 为空 - ver = &PromptEvaluatorVersion{ModelConfig: &ModelConfig{ModelID: 1}} + ver = &PromptEvaluatorVersion{ModelConfig: &ModelConfig{ModelID: gptr.Of(int64(1))}} assert.Error(t, ver.ValidateBaseInfo()) // model config 为空 @@ -1710,6 +1710,6 @@ func TestPromptEvaluatorVersion_ValidateBaseInfo(t *testing.T) { assert.Error(t, ver.ValidateBaseInfo()) // 正常 - ver = &PromptEvaluatorVersion{MessageList: []*Message{{Role: RoleUser}}, ModelConfig: &ModelConfig{ModelID: 1}} + ver = &PromptEvaluatorVersion{MessageList: []*Message{{Role: RoleUser}}, ModelConfig: &ModelConfig{ModelID: gptr.Of(int64(1))}} assert.NoError(t, ver.ValidateBaseInfo()) } diff --git a/backend/modules/evaluation/domain/entity/evaluator_version_prompt.go b/backend/modules/evaluation/domain/entity/evaluator_version_prompt.go index 7985cf39d..caa7b3108 100644 --- a/backend/modules/evaluation/domain/entity/evaluator_version_prompt.go +++ b/backend/modules/evaluation/domain/entity/evaluator_version_prompt.go @@ -147,7 +147,7 @@ func (do *PromptEvaluatorVersion) ValidateBaseInfo() error { if do.ModelConfig == nil { return errorx.NewByCode(errno.InvalidModelConfigCode, errorx.WithExtraMsg("model config is nil")) } - if do.ModelConfig.ModelID == 0 && do.ModelConfig.ProviderModelID == nil { + if do.ModelConfig.ModelID == nil && do.ModelConfig.ProviderModelID == nil { return errorx.NewByCode(errno.InvalidModelConfigCode, errorx.WithExtraMsg("model id is empty")) } return nil diff --git a/backend/modules/evaluation/domain/entity/runtime_param_test.go b/backend/modules/evaluation/domain/entity/runtime_param_test.go index ae2bacc69..cae18a277 100755 --- a/backend/modules/evaluation/domain/entity/runtime_param_test.go +++ b/backend/modules/evaluation/domain/entity/runtime_param_test.go @@ -25,7 +25,7 @@ func TestPromptRuntimeParam_GetJSONDemo(t *testing.T) { func TestPromptRuntimeParam_GetJSONValue(t *testing.T) { param := &PromptRuntimeParam{ ModelConfig: &ModelConfig{ - ModelID: 123, + ModelID: gptr.Of(int64(123)), ModelName: "test_model", MaxTokens: gptr.Of(int32(100)), Temperature: gptr.Of(0.7), @@ -92,7 +92,7 @@ func TestPromptRuntimeParam_ParseFromJSON(t *testing.T) { func TestNewPromptRuntimeParam(t *testing.T) { modelConfig := &ModelConfig{ - ModelID: 123, + ModelID: gptr.Of(int64(123)), ModelName: "test_model", } diff --git a/backend/modules/evaluation/domain/entity/target.go b/backend/modules/evaluation/domain/entity/target.go index a1e0e18cc..e6c9af814 100644 --- a/backend/modules/evaluation/domain/entity/target.go +++ b/backend/modules/evaluation/domain/entity/target.go @@ -53,6 +53,9 @@ const ( EvalTargetTypeVolcengineAgent EvalTargetType = 5 // 自定义服务 for内场 EvalTargetTypeCustomRPCServer EvalTargetType = 6 + + // 火山智能体Agentkit + EvalTargetTypeVolcengineAgentAgentkit EvalTargetType = 7 ) func (p EvalTargetType) String() string { @@ -69,6 +72,8 @@ func (p EvalTargetType) String() string { return "VolcengineAgent" case EvalTargetTypeCustomRPCServer: return "CustomRPCServer" + case EvalTargetTypeVolcengineAgentAgentkit: + return "VolcengineAgentKit" } return "" } diff --git a/backend/modules/evaluation/domain/entity/target_builtin_volcengine_agent.go b/backend/modules/evaluation/domain/entity/target_builtin_volcengine_agent.go index c79fdd9c4..db6446e2f 100644 --- a/backend/modules/evaluation/domain/entity/target_builtin_volcengine_agent.go +++ b/backend/modules/evaluation/domain/entity/target_builtin_volcengine_agent.go @@ -11,6 +11,7 @@ type VolcengineAgent struct { VolcengineAgentEndpoints []*VolcengineAgentEndpoint BaseInfo *BaseInfo `json:"-"` // 基础信息 Protocol *VolcengineAgentProtocol + RuntimeID *string } type VolcengineAgentEndpoint struct { diff --git a/backend/modules/evaluation/domain/service/evaluator_impl_test.go b/backend/modules/evaluation/domain/service/evaluator_impl_test.go index 540495b1f..f0e80f169 100644 --- a/backend/modules/evaluation/domain/service/evaluator_impl_test.go +++ b/backend/modules/evaluation/domain/service/evaluator_impl_test.go @@ -1493,7 +1493,7 @@ func TestEvaluatorServiceImpl_SubmitEvaluatorVersion(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, MessageList: []*entity.Message{ @@ -1672,7 +1672,7 @@ func TestEvaluatorServiceImpl_RunEvaluator(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, MessageList: []*entity.Message{ @@ -1910,7 +1910,7 @@ func Test_EvaluatorServiceImpl_DebugEvaluator(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, MessageList: []*entity.Message{ @@ -2051,7 +2051,7 @@ func TestEvaluatorServiceImpl_RunEvaluator_DisableTracing(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, }, diff --git a/backend/modules/evaluation/domain/service/evaluator_record_impl.go b/backend/modules/evaluation/domain/service/evaluator_record_impl.go index 1471a6d41..02219018b 100644 --- a/backend/modules/evaluation/domain/service/evaluator_record_impl.go +++ b/backend/modules/evaluation/domain/service/evaluator_record_impl.go @@ -199,6 +199,8 @@ func (s *EvaluatorRecordServiceImpl) recalculateWeightedScoreForTurn(ctx context version2Record[r.EvaluatorVersionID] = r } } + // 用当前已校正的 record 覆盖,避免主从延迟或读从库时 BatchGet 拿到旧数据,导致重算加权分仍用旧分 + version2Record[rec.EvaluatorVersionID] = rec // 6. 构建权重映射 scoreWeights := make(map[int64]float64) diff --git a/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl.go b/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl.go index 17476f9e2..c97c87792 100644 --- a/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl.go +++ b/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl.go @@ -130,10 +130,10 @@ func (p *EvaluatorSourcePromptServiceImpl) Run(ctx context.Context, evaluator *e } defer func() { var modelID string - if evaluator.PromptEvaluatorVersion.ModelConfig.ModelID == 0 { + if evaluator.PromptEvaluatorVersion.ModelConfig.GetModelID() == 0 { modelID = ptr.From(evaluator.PromptEvaluatorVersion.ModelConfig.ProviderModelID) } else { - modelID = strconv.FormatInt(evaluator.PromptEvaluatorVersion.ModelConfig.ModelID, 10) + modelID = strconv.FormatInt(evaluator.PromptEvaluatorVersion.ModelConfig.GetModelID(), 10) } p.metric.EmitRun(exptSpaceID, err, startTime, modelID) @@ -697,7 +697,7 @@ func (p *EvaluatorSourcePromptServiceImpl) injectParseType(ctx context.Context, return } - if suffixKey, ok := p.configer.GetEvaluatorPromptSuffixMapping(ctx)[strconv.FormatInt(evaluatorDO.GetModelConfig().ModelID, 10)]; ok { + if suffixKey, ok := p.configer.GetEvaluatorPromptSuffixMapping(ctx)[strconv.FormatInt(evaluatorDO.GetModelConfig().GetModelID(), 10)]; ok { evaluatorDO.SetPromptSuffix(p.configer.GetEvaluatorPromptSuffix(ctx)[suffixKey]) evaluatorDO.SetParseType(entity.ParseType(suffixKey)) } else { diff --git a/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl_test.go b/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl_test.go index d036258f8..a3b3d0df6 100755 --- a/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl_test.go +++ b/backend/modules/evaluation/domain/service/evaluator_source_prompt_impl_test.go @@ -54,7 +54,7 @@ func TestEvaluatorSourcePromptServiceImpl_Run(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, MessageList: []*entity.Message{ @@ -259,7 +259,7 @@ func TestEvaluatorSourcePromptServiceImpl_PreHandle(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, }, @@ -351,7 +351,7 @@ func TestEvaluatorSourcePromptServiceImpl_Debug(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, MessageList: []*entity.Message{ @@ -1577,7 +1577,7 @@ func TestEvaluatorSourcePromptServiceImpl_Run_DisableTracing(t *testing.T) { PromptTemplateKey: "test-template-key", PromptSuffix: "test-prompt-suffix", ModelConfig: &entity.ModelConfig{ - ModelID: 1, + ModelID: gptr.Of(int64(1)), }, ParseType: entity.ParseTypeFunctionCall, MessageList: []*entity.Message{ diff --git a/backend/modules/evaluation/domain/service/evaluator_template_impl.go b/backend/modules/evaluation/domain/service/evaluator_template_impl.go index 245275b9f..b33048de7 100644 --- a/backend/modules/evaluation/domain/service/evaluator_template_impl.go +++ b/backend/modules/evaluation/domain/service/evaluator_template_impl.go @@ -237,7 +237,7 @@ func (s *EvaluatorTemplateServiceImpl) ListEvaluatorTemplate(ctx context.Context // 调用repo层查询 repoResp, err := s.templateRepo.ListEvaluatorTemplate(ctx, repoReq) if err != nil { - return nil, errorx.NewByCode(errno.CommonInternalErrorCode) + return nil, errorx.WrapByCode(err, errno.CommonInternalErrorCode) } // 计算总页数 diff --git a/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl.go b/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl.go index 3536c240a..0fecb3f7c 100644 --- a/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl.go +++ b/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl.go @@ -282,7 +282,11 @@ func (e *ExptSchedulerImpl) schedule(ctx context.Context, event *entity.ExptSche logs.CtxInfo(ctx, "[ExptEval] expt daemon with next tick, expt_id: %v, event: %v", event.ExptID, event) - time.Sleep(time.Second * 3) + select { + case <-time.After(time.Second * 3): + case <-ctx.Done(): + return ctx.Err() + } return mode.NextTick(ctx, event, nextTick) } diff --git a/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl_test.go b/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl_test.go index 3a39e7b1f..f30893002 100644 --- a/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl_test.go +++ b/backend/modules/evaluation/domain/service/expt_run_scheduler_event_impl_test.go @@ -1228,3 +1228,40 @@ func TestExptSchedulerImpl_handleZombies(t *testing.T) { }) } } + +func TestExptSchedulerImpl_Schedule_ContextCancelled(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManager := svcmocks.NewMockIExptManager(ctrl) + mockFactory := svcmocks.NewMockSchedulerModeFactory(ctrl) + mockConfiger := configmocks.NewMockIConfiger(ctrl) + mockResultSvc := svcmocks.NewMockExptResultService(ctrl) + + svc := &ExptSchedulerImpl{ + Manager: mockManager, + schedulerModeFactory: mockFactory, + Configer: mockConfiger, + ResultSvc: mockResultSvc, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + event := &entity.ExptScheduleEvent{ExptID: 1, SpaceID: 1, ExptRunMode: 1} + exptDetail := &entity.Experiment{ID: 1} + mockMode := entitymocks.NewMockExptSchedulerMode(ctrl) + + mockManager.EXPECT().GetDetail(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(exptDetail, nil) + mockFactory.EXPECT().NewSchedulerMode(gomock.Any()).Return(mockMode, nil) + mockMode.EXPECT().ExptStart(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockMode.EXPECT().ScheduleStart(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockMode.EXPECT().ScanEvalItems(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, nil, nil) + mockConfiger.EXPECT().GetConsumerConf(gomock.Any()).Return(&entity.ExptConsumerConf{}).AnyTimes() + mockMode.EXPECT().ScheduleEnd(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockMode.EXPECT().ExptEnd(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + + err := svc.schedule(ctx, event) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) +} diff --git a/backend/modules/evaluation/domain/service/target_source_loopprompt_impl_test.go b/backend/modules/evaluation/domain/service/target_source_loopprompt_impl_test.go index 4518a8c28..8b95cc94e 100644 --- a/backend/modules/evaluation/domain/service/target_source_loopprompt_impl_test.go +++ b/backend/modules/evaluation/domain/service/target_source_loopprompt_impl_test.go @@ -616,7 +616,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSource(t *testing.T) { Description: "Desc 1", SubmitStatus: entity.SubmitStatus_Submitted, }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, }, @@ -664,7 +664,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSource(t *testing.T) { Description: "Desc 2", SubmitStatus: entity.SubmitStatus_UnSubmit, }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, }, @@ -701,7 +701,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSource(t *testing.T) { Name: "", // Default from gptr.From Description: "", // Default from gptr.From }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, }, @@ -785,7 +785,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSource(t *testing.T) { Description: "", SubmitStatus: entity.SubmitStatus_UnSubmit, }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, }, @@ -885,7 +885,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSourceVersion(t *testing.T) { SubmitStatus: entity.SubmitStatus_Submitted, Description: "Version 1.0 desc", }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, wantNextCursor: "cursor_next", @@ -932,7 +932,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSourceVersion(t *testing.T) { SubmitStatus: entity.SubmitStatus_UnSubmit, Description: "Version 0.1 desc", }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, wantNextCursor: "cursor_final", @@ -970,7 +970,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSourceVersion(t *testing.T) { PromptKey: "key_no_basic", Description: "Desc", }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, wantNextCursor: "next", @@ -1103,7 +1103,7 @@ func TestPromptSourceEvalTargetServiceImpl_ListSourceVersion(t *testing.T) { SubmitStatus: entity.SubmitStatus_Submitted, Description: "Desc A.1", }, - RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":\"0\",\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), + RuntimeParamDemo: gptr.Of("{\"model_config\":{\"model_id\":null,\"model_name\":\"\",\"max_tokens\":0,\"temperature\":0,\"top_p\":0,\"tool_choice\":\"\",\"json_ext\":\"{}\"}}"), }, }, wantNextCursor: "cursor_pagesize_nil", diff --git a/backend/modules/evaluation/infra/mq/rocket/conf.go b/backend/modules/evaluation/infra/mq/rocket/conf.go index 5ed63d8b7..4030179c0 100644 --- a/backend/modules/evaluation/infra/mq/rocket/conf.go +++ b/backend/modules/evaluation/infra/mq/rocket/conf.go @@ -34,6 +34,9 @@ type RMQConf struct { ConsumerGroup string `json:"consumer_group" mapstructure:"consumer_group"` WorkerNum int `json:"worker_num" mapstructure:"worker_num"` ConsumeTimeout time.Duration `json:"consume_timeout" mapstructure:"consume_timeout"` + + AccessKey *string `json:"access_key" mapstructure:"access_key"` + AccessSecret *string `json:"access_secret" mapstructure:"access_secret"` } func (c *RMQConf) Valid() bool { @@ -47,6 +50,8 @@ func (c *RMQConf) ToProducerCfg() mq.ProducerConfig { ProduceTimeout: c.ProduceTimeout, RetryTimes: c.RetryTimes, ProducerGroup: gptr.Of(c.ProducerGroup), + AccessKey: c.AccessKey, + AccessSecret: c.AccessSecret, } } @@ -58,5 +63,7 @@ func (c *RMQConf) ToConsumerCfg() mq.ConsumerConfig { ConsumerGroup: c.ConsumerGroup, ConsumeGoroutineNums: c.WorkerNum, ConsumeTimeout: c.ConsumeTimeout, + AccessKey: c.AccessKey, + AccessSecret: c.AccessSecret, } } diff --git a/backend/modules/evaluation/infra/repo/evaluator/mysql/convertor/evaluator_test.go b/backend/modules/evaluation/infra/repo/evaluator/mysql/convertor/evaluator_test.go index 3081027dc..4063cc5fc 100644 --- a/backend/modules/evaluation/infra/repo/evaluator/mysql/convertor/evaluator_test.go +++ b/backend/modules/evaluation/infra/repo/evaluator/mysql/convertor/evaluator_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/bytedance/gg/gptr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" @@ -98,7 +99,7 @@ func TestConvertEvaluatorVersionPO2DO(t *testing.T) { }, }, ModelConfig: &evaluatordo.ModelConfig{ - ModelID: 12345, + ModelID: gptr.Of(int64(12345)), ModelName: "test-model", Temperature: ptr.Of(float64(0.7)), MaxTokens: ptr.Of(int32(1000)), @@ -282,7 +283,7 @@ func TestConvertEvaluatorVersionPO2DO(t *testing.T) { }, }, ModelConfig: &evaluatordo.ModelConfig{ - ModelID: 1749615085, + ModelID: gptr.Of(int64(1749615085)), ModelName: "豆包·1.6·深度思考", MaxTokens: ptr.Of(int32(4096)), Temperature: ptr.Of(float64(0.1)), @@ -365,7 +366,7 @@ func TestConvertEvaluatorVersionPO2DO(t *testing.T) { // 验证 ModelConfig if tt.want.PromptEvaluatorVersion.ModelConfig != nil { require.NotNil(t, got.PromptEvaluatorVersion.ModelConfig) - assert.Equal(t, tt.want.PromptEvaluatorVersion.ModelConfig.ModelID, got.PromptEvaluatorVersion.ModelConfig.ModelID) + assert.Equal(t, tt.want.PromptEvaluatorVersion.ModelConfig.GetModelID(), got.PromptEvaluatorVersion.ModelConfig.GetModelID()) assert.Equal(t, tt.want.PromptEvaluatorVersion.ModelConfig.ModelName, got.PromptEvaluatorVersion.ModelConfig.ModelName) if tt.want.PromptEvaluatorVersion.ModelConfig.Temperature != nil { assert.Equal(t, *tt.want.PromptEvaluatorVersion.ModelConfig.Temperature, *got.PromptEvaluatorVersion.ModelConfig.Temperature) diff --git a/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target.go b/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target.go index 94c4eb363..3820b5b2b 100644 --- a/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target.go +++ b/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target.go @@ -58,7 +58,7 @@ func EvalTargetVersionDO2PO(do *entity.EvalTargetVersion) (po *model.TargetVersi if err != nil { return nil, err } - case entity.EvalTargetTypeVolcengineAgent: + case entity.EvalTargetTypeVolcengineAgent, entity.EvalTargetTypeVolcengineAgentAgentkit: meta, err = json.Marshal(do.VolcengineAgent) if err != nil { return nil, err @@ -68,6 +68,7 @@ func EvalTargetVersionDO2PO(do *entity.EvalTargetVersion) (po *model.TargetVersi if err != nil { return nil, err } + default: } if do.InputSchema != nil { inputSchema, err = json.Marshal(do.InputSchema) @@ -200,7 +201,7 @@ func EvalTargetVersionPO2DO(targetVersionPO *model.TargetVersion, targetType ent if err := json.Unmarshal(*targetVersionPO.TargetMeta, meta); err == nil { targetVersionDO.CozeWorkflow = meta } - case entity.EvalTargetTypeVolcengineAgent: + case entity.EvalTargetTypeVolcengineAgent, entity.EvalTargetTypeVolcengineAgentAgentkit: meta := &entity.VolcengineAgent{} if err := json.Unmarshal(*targetVersionPO.TargetMeta, meta); err == nil { targetVersionDO.VolcengineAgent = meta diff --git a/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_record_test.go b/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_record_test.go new file mode 100644 index 000000000..b1bc06204 --- /dev/null +++ b/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_record_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package convertor + +import ( + "testing" + "time" + + "github.com/bytedance/gg/gptr" + "github.com/stretchr/testify/assert" + + "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" + "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/repo/target/mysql/gorm_gen/model" +) + +func TestEvalTargetRecordConvert(t *testing.T) { + t.Run("DO2PO", func(t *testing.T) { + do := &entity.EvalTargetRecord{ + ID: 1, + SpaceID: 2, + Status: gptr.Of(entity.EvalTargetRunStatusSuccess), + EvalTargetInputData: &entity.EvalTargetInputData{ + InputFields: map[string]*entity.Content{"k": {ContentType: gptr.Of(entity.ContentTypeText), Text: gptr.Of("v")}}, + }, + EvalTargetOutputData: &entity.EvalTargetOutputData{ + OutputFields: map[string]*entity.Content{"res": {ContentType: gptr.Of(entity.ContentTypeText), Text: gptr.Of("resp")}}, + }, + BaseInfo: &entity.BaseInfo{ + CreatedAt: gptr.Of(int64(123456789)), + }, + } + po, err := EvalTargetRecordDO2PO(do) + assert.NoError(t, err) + assert.Equal(t, int64(1), po.ID) + assert.Equal(t, int32(entity.EvalTargetRunStatusSuccess), po.Status) + assert.NotNil(t, po.InputData) + assert.NotNil(t, po.OutputData) + + poNil, errNil := EvalTargetRecordDO2PO(nil) + assert.NoError(t, errNil) + assert.Nil(t, poNil) + }) + + t.Run("PO2DO", func(t *testing.T) { + input := []byte(`{"InputFields":{"k":{"text":"v"}}}`) + output := []byte(`{"OutputFields":{"res":{"text":"resp"}},"EvalTargetUsage":{"InputTokens":10,"OutputTokens":20}}`) + po := &model.TargetRecord{ + ID: 1, + Status: int32(entity.EvalTargetRunStatusSuccess), + InputData: &input, + OutputData: &output, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + do, err := EvalTargetRecordPO2DO(po) + assert.NoError(t, err) + assert.NotNil(t, do) + assert.NotNil(t, do.EvalTargetInputData) + assert.NotNil(t, do.EvalTargetInputData.InputFields["k"]) + assert.Equal(t, int64(1), do.ID) + assert.Equal(t, entity.EvalTargetRunStatusSuccess, *do.Status) + assert.Equal(t, "v", *do.EvalTargetInputData.InputFields["k"].Text) + assert.Equal(t, int64(30), do.EvalTargetOutputData.EvalTargetUsage.TotalTokens) + + doNil, errNilPo := EvalTargetRecordPO2DO(nil) + assert.NoError(t, errNilPo) + assert.Nil(t, doNil) + }) + + t.Run("PO2DO_unmarshal_error", func(t *testing.T) { + input := []byte(`{invalid}`) + po := &model.TargetRecord{InputData: &input} + _, err := EvalTargetRecordPO2DO(po) + assert.Error(t, err) + }) +} diff --git a/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_test.go b/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_test.go index 5824d387c..49c6c968c 100755 --- a/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_test.go +++ b/backend/modules/evaluation/infra/repo/target/mysql/convertor/eval_target_test.go @@ -178,6 +178,34 @@ func TestEvalTargetVersionDO2PO(t *testing.T) { assert.NotNil(t, po.TargetMeta) }, }, + { + name: "CozeWorkflow类型的版本转换", + do: &entity.EvalTargetVersion{ + ID: 1, + EvalTargetType: entity.EvalTargetTypeCozeWorkflow, + CozeWorkflow: &entity.CozeWorkflow{ID: "wf1"}, + InputSchema: []*entity.ArgsSchema{{Key: gptr.Of("in")}}, + OutputSchema: []*entity.ArgsSchema{{Key: gptr.Of("out")}}, + }, + expectError: false, + checkResult: func(t *testing.T, po *model.TargetVersion) { + assert.NotNil(t, po.TargetMeta) + assert.NotNil(t, po.InputSchema) + assert.NotNil(t, po.OutputSchema) + }, + }, + { + name: "VolcengineAgentAgentkit类型的版本转换", + do: &entity.EvalTargetVersion{ + ID: 1, + EvalTargetType: entity.EvalTargetTypeVolcengineAgentAgentkit, + VolcengineAgent: &entity.VolcengineAgent{Name: "agent"}, + }, + expectError: false, + checkResult: func(t *testing.T, po *model.TargetVersion) { + assert.NotNil(t, po.TargetMeta) + }, + }, } for _, tt := range tests { @@ -393,6 +421,43 @@ func TestEvalTargetVersionPO2DO(t *testing.T) { assert.Equal(t, int64(1), do.ID) }, }, + { + name: "CozeWorkflow类型的版本转换", + targetVersionPO: &model.TargetVersion{ + ID: 1, + TargetMeta: gptr.Of([]byte(`{"id":"wf1"}`)), + }, + targetType: entity.EvalTargetTypeCozeWorkflow, + checkResult: func(t *testing.T, do *entity.EvalTargetVersion) { + assert.NotNil(t, do) + assert.Equal(t, "wf1", do.CozeWorkflow.ID) + }, + }, + { + name: "VolcengineAgentAgentkit类型的版本转换", + targetVersionPO: &model.TargetVersion{ + ID: 1, + TargetMeta: gptr.Of([]byte(`{"RuntimeID":"agent"}`)), + }, + targetType: entity.EvalTargetTypeVolcengineAgentAgentkit, + checkResult: func(t *testing.T, do *entity.EvalTargetVersion) { + assert.NotNil(t, do) + assert.Equal(t, "agent", *do.VolcengineAgent.RuntimeID) + }, + }, + { + name: "Schema转换测试", + targetVersionPO: &model.TargetVersion{ + ID: 1, + InputSchema: gptr.Of([]byte(`[{"key":"in"}]`)), + OutputSchema: gptr.Of([]byte(`[{"key":"out"}]`)), + }, + targetType: entity.EvalTargetTypeCozeBot, + checkResult: func(t *testing.T, do *entity.EvalTargetVersion) { + assert.Len(t, do.InputSchema, 1) + assert.Len(t, do.OutputSchema, 1) + }, + }, } for _, tt := range tests { diff --git a/backend/modules/evaluation/infra/rpc/llm/convertor.go b/backend/modules/evaluation/infra/rpc/llm/convertor.go index 86698a65e..d7d67a3e8 100644 --- a/backend/modules/evaluation/infra/rpc/llm/convertor.go +++ b/backend/modules/evaluation/infra/rpc/llm/convertor.go @@ -36,11 +36,14 @@ func ModelConfigDO2DTO(modelConfig *commonentity.ModelConfig, toolCallConfig *co toolChoice = gptr.Of(ToolChoiceTypeDO2DTO(toolCallConfig.ToolChoice)) } return &runtimedto.ModelConfig{ - ModelID: modelConfig.ModelID, - Temperature: modelConfig.Temperature, - MaxTokens: gptr.Of(int64(gptr.Indirect(modelConfig.MaxTokens))), - TopP: modelConfig.TopP, - ToolChoice: toolChoice, + ModelID: modelConfig.GetModelID(), + Temperature: modelConfig.Temperature, + MaxTokens: gptr.Of(int64(gptr.Indirect(modelConfig.MaxTokens))), + TopP: modelConfig.TopP, + ToolChoice: toolChoice, + Protocol: modelConfig.Protocol, + Identification: modelConfig.Identification, + PresetModel: modelConfig.PresetModel, } } diff --git a/backend/modules/evaluation/infra/rpc/llm/convertor_test.go b/backend/modules/evaluation/infra/rpc/llm/convertor_test.go new file mode 100644 index 000000000..c9045b664 --- /dev/null +++ b/backend/modules/evaluation/infra/rpc/llm/convertor_test.go @@ -0,0 +1,218 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package llm + +import ( + "testing" + + "github.com/bytedance/gg/gptr" + "github.com/stretchr/testify/assert" + + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" + runtimedto "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" + commonentity "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/entity" +) + +func TestLLMCallParamConvert(t *testing.T) { + param := &commonentity.LLMCallParam{ + SpaceID: 1, + EvaluatorID: "100", + UserID: gptr.Of("user1"), + Scenario: commonentity.ScenarioEvaluator, + Messages: []*commonentity.Message{ + { + Role: commonentity.RoleUser, + Content: &commonentity.Content{Text: gptr.Of("hello")}, + }, + }, + ModelConfig: &commonentity.ModelConfig{ + ModelID: gptr.Of(int64(100)), + Temperature: gptr.Of(0.7), + }, + ToolCallConfig: &commonentity.ToolCallConfig{ + ToolChoice: commonentity.ToolChoiceTypeAuto, + }, + } + + got := LLMCallParamConvert(param) + assert.Equal(t, int64(1), *got.BizParam.WorkspaceID) + assert.Equal(t, "user1", *got.BizParam.UserID) + assert.Equal(t, common.ScenarioEvaluator, *got.BizParam.Scenario) + assert.Equal(t, "100", *got.BizParam.ScenarioEntityID) + assert.Equal(t, runtimedto.RoleUser, got.Messages[0].Role) +} + +func TestModelConfigDO2DTO(t *testing.T) { + t.Run("nil_input", func(t *testing.T) { + assert.Nil(t, ModelConfigDO2DTO(nil, nil)) + }) + + t.Run("full_input", func(t *testing.T) { + mc := &commonentity.ModelConfig{ + ModelID: gptr.Of(int64(1)), + Temperature: gptr.Of(0.5), + MaxTokens: gptr.Of(int32(100)), + TopP: gptr.Of(0.9), + } + tcc := &commonentity.ToolCallConfig{ + ToolChoice: commonentity.ToolChoiceTypeRequired, + } + got := ModelConfigDO2DTO(mc, tcc) + assert.Equal(t, int64(1), got.ModelID) + assert.Equal(t, 0.5, *got.Temperature) + assert.Equal(t, int64(100), *got.MaxTokens) + assert.Equal(t, 0.9, *got.TopP) + assert.Equal(t, runtimedto.ToolChoiceRequired, *got.ToolChoice) + }) +} + +func TestToolChoiceTypeDO2DTO(t *testing.T) { + assert.Equal(t, runtimedto.ToolChoiceNone, ToolChoiceTypeDO2DTO(commonentity.ToolChoiceTypeNone)) + assert.Equal(t, runtimedto.ToolChoiceAuto, ToolChoiceTypeDO2DTO(commonentity.ToolChoiceTypeAuto)) + assert.Equal(t, runtimedto.ToolChoiceRequired, ToolChoiceTypeDO2DTO(commonentity.ToolChoiceTypeRequired)) + assert.Equal(t, runtimedto.ToolChoiceAuto, ToolChoiceTypeDO2DTO(commonentity.ToolChoiceType("unknown"))) +} + +func TestRoleDO2DTO(t *testing.T) { + assert.Equal(t, runtimedto.RoleSystem, RoleDO2DTO(commonentity.RoleSystem)) + assert.Equal(t, runtimedto.RoleUser, RoleDO2DTO(commonentity.RoleUser)) + assert.Equal(t, runtimedto.RoleAssistant, RoleDO2DTO(commonentity.RoleAssistant)) + assert.Equal(t, runtimedto.RoleTool, RoleDO2DTO(commonentity.RoleTool)) + assert.Equal(t, runtimedto.RoleUser, RoleDO2DTO(commonentity.Role(999))) +} + +func TestToolCallConvert(t *testing.T) { + t.Run("nil_input", func(t *testing.T) { + assert.Nil(t, ToolCallDO2DTO(nil)) + assert.Nil(t, ToolCallDTO2DO(nil)) + }) + + t.Run("valid_input", func(t *testing.T) { + do := &commonentity.ToolCall{ + Index: 0, + ID: "call_1", + Type: commonentity.ToolTypeFunction, + FunctionCall: &commonentity.FunctionCall{ + Name: "test_func", + Arguments: gptr.Of("{}"), + }, + } + + dto := ToolCallDO2DTO(do) + assert.Equal(t, int64(0), *dto.Index) + assert.Equal(t, "call_1", *dto.ID) + assert.Equal(t, runtimedto.ToolTypeFunction, *dto.Type) + assert.Equal(t, "test_func", *dto.FunctionCall.Name) + + back := ToolCallDTO2DO(dto) + assert.Equal(t, do.Index, back.Index) + assert.Equal(t, do.ID, back.ID) + assert.Equal(t, do.FunctionCall.Name, back.FunctionCall.Name) + }) +} + +func TestReplyItemDTO2DO(t *testing.T) { + t.Run("nil_input", func(t *testing.T) { + assert.Nil(t, ReplyItemDTO2DO(nil)) + }) + + t.Run("full_input", func(t *testing.T) { + dto := &runtimedto.Message{ + Content: gptr.Of("content"), + ReasoningContent: gptr.Of("reasoning"), + ResponseMeta: &runtimedto.ResponseMeta{ + FinishReason: gptr.Of("stop"), + Usage: &runtimedto.TokenUsage{ + PromptTokens: gptr.Of(int64(10)), + CompletionTokens: gptr.Of(int64(20)), + }, + }, + } + got := ReplyItemDTO2DO(dto) + assert.Equal(t, "content", *got.Content) + assert.Equal(t, "reasoning", *got.ReasoningContent) + assert.Equal(t, "stop", got.FinishReason) + assert.Equal(t, int64(10), got.TokenUsage.InputTokens) + assert.Equal(t, int64(20), got.TokenUsage.OutputTokens) + }) +} + +func TestScenarioDO2DTO(t *testing.T) { + assert.Equal(t, common.ScenarioEvalTarget, ScenarioDO2DTO(commonentity.ScenarioEvalTarget)) + assert.Equal(t, common.ScenarioEvaluator, ScenarioDO2DTO(commonentity.ScenarioEvaluator)) + assert.Equal(t, common.ScenarioDefault, ScenarioDO2DTO(commonentity.Scenario("unknown"))) +} + +func TestToolDO2DTO(t *testing.T) { + t.Run("nil_input", func(t *testing.T) { + assert.Nil(t, ToolDO2DTO(nil)) + }) + t.Run("valid_input", func(t *testing.T) { + do := &commonentity.Tool{ + Function: &commonentity.Function{ + Name: "test", + Description: "desc", + Parameters: "params", + }, + } + got := ToolDO2DTO(do) + assert.Equal(t, "test", *got.Name) + assert.Equal(t, "desc", *got.Desc) + assert.Equal(t, "params", *got.Def) + }) +} + +func TestToolsDO2DTO(t *testing.T) { + t.Run("empty_input", func(t *testing.T) { + assert.Nil(t, ToolsDO2DTO(nil)) + }) + t.Run("valid_input", func(t *testing.T) { + dos := []*commonentity.Tool{{Function: &commonentity.Function{Name: "t1"}}} + got := ToolsDO2DTO(dos) + assert.Len(t, got, 1) + }) +} + +func TestToolCallsDO2DTO(t *testing.T) { + t.Run("empty_input", func(t *testing.T) { + assert.Nil(t, ToolCallsDO2DTO(nil)) + }) + t.Run("valid_input", func(t *testing.T) { + dos := []*commonentity.ToolCall{{ID: "c1"}} + got := ToolCallsDO2DTO(dos) + assert.Len(t, got, 1) + }) +} + +func TestToolCallsDTO2DO(t *testing.T) { + t.Run("empty_input", func(t *testing.T) { + assert.Nil(t, ToolCallsDTO2DO(nil)) + }) + t.Run("valid_input", func(t *testing.T) { + dtos := []*runtimedto.ToolCall{{ID: gptr.Of("c1")}} + got := ToolCallsDTO2DO(dtos) + assert.Len(t, got, 1) + }) +} + +func TestContentTypeDO2DTO(t *testing.T) { + assert.Equal(t, runtimedto.ChatMessagePartTypeText, ContentTypeDO2DTO(commonentity.ContentTypeText)) + assert.Equal(t, runtimedto.ChatMessagePartTypeText, ContentTypeDO2DTO(commonentity.ContentType("unknown"))) +} + +func TestMessageDO2DTO(t *testing.T) { + assert.Nil(t, MessageDO2DTO(nil)) +} + +func TestFunctionDO2DTO(t *testing.T) { + assert.Nil(t, FunctionDO2DTO(nil)) +} + +func TestFunctionDTO2DO(t *testing.T) { + assert.Nil(t, FunctionDTO2DO(nil)) +} + +func TestTokenUsageDTO2DO(t *testing.T) { + assert.Nil(t, TokenUsageDTO2DO(nil)) +} diff --git a/backend/modules/evaluation/infra/rpc/prompt/convert_test.go b/backend/modules/evaluation/infra/rpc/prompt/convert_test.go index 70ab70f37..6d7f5b934 100644 --- a/backend/modules/evaluation/infra/rpc/prompt/convert_test.go +++ b/backend/modules/evaluation/infra/rpc/prompt/convert_test.go @@ -298,3 +298,135 @@ func TestConvertFromContent(t *testing.T) { }) } } + +func TestConvertToLoopPrompts(t *testing.T) { + assert.Nil(t, ConvertToLoopPrompts(nil)) + res := ConvertToLoopPrompts([]*prompt.Prompt{{ID: gptr.Of(int64(1))}}) + assert.Len(t, res, 1) + assert.Equal(t, int64(1), res[0].ID) +} + +func TestConvertToLoopPrompt(t *testing.T) { + assert.Nil(t, ConvertToLoopPrompt(nil)) + p := &prompt.Prompt{ + ID: gptr.Of(int64(1)), + PromptKey: gptr.Of("key"), + PromptBasic: &prompt.PromptBasic{ + DisplayName: gptr.Of("name"), + Description: gptr.Of("desc"), + LatestVersion: gptr.Of("v1"), + }, + PromptCommit: &prompt.PromptCommit{ + Detail: &prompt.PromptDetail{ + PromptTemplate: &prompt.PromptTemplate{ + VariableDefs: []*prompt.VariableDef{ + {Key: gptr.Of("k1"), Type: gptr.Of("t1"), TypeTags: []string{"tag1"}}, + }, + }, + }, + CommitInfo: &prompt.CommitInfo{ + Version: gptr.Of("v1"), + BaseVersion: gptr.Of("v0"), + Description: gptr.Of("commit desc"), + CommittedAt: gptr.Of(int64(123456789)), + CommittedBy: gptr.Of("1001"), + }, + }, + } + res := ConvertToLoopPrompt(p) + assert.NotNil(t, res) + assert.Equal(t, int64(1), res.ID) + assert.Equal(t, "key", res.PromptKey) + assert.Equal(t, "name", *res.PromptBasic.DisplayName) + assert.Len(t, res.PromptCommit.Detail.PromptTemplate.VariableDefs, 1) + assert.Equal(t, "v1", *res.PromptCommit.CommitInfo.Version) +} + +func TestConvertVariables2Prompt(t *testing.T) { + assert.Nil(t, ConvertVariables2Prompt(nil)) + vars := []*entity.VariableVal{ + { + Key: gptr.Of("k1"), + Value: gptr.Of("v1"), + PlaceholderMessages: []*entity.Message{ + {Role: entity.RoleUser, Content: &entity.Content{ContentType: gptr.Of(entity.ContentTypeText), Text: gptr.Of("msg")}}, + }, + }, + } + res := ConvertVariables2Prompt(vars) + assert.Len(t, res, 1) + assert.Equal(t, "k1", *res[0].Key) + assert.Len(t, res[0].PlaceholderMessages, 1) +} + +func TestConvertPromptToolCalls2Eval(t *testing.T) { + assert.Nil(t, ConvertPromptToolCalls2Eval(nil)) + calls := []*prompt.ToolCall{ + { + Index: gptr.Of(int64(0)), + ID: gptr.Of("id1"), + FunctionCall: &prompt.FunctionCall{ + Name: gptr.Of("func1"), + Arguments: gptr.Of(`{"a":1}`), + }, + }, + } + res := ConvertPromptToolCalls2Eval(calls) + assert.Len(t, res, 1) + assert.Equal(t, int64(0), res[0].Index) + assert.Equal(t, "id1", res[0].ID) + assert.Equal(t, "func1", res[0].FunctionCall.Name) +} + +func TestRole2PromptRole(t *testing.T) { + assert.Equal(t, prompt.RoleSystem, Role2PromptRole(entity.RoleSystem)) + assert.Equal(t, prompt.RoleUser, Role2PromptRole(entity.RoleUser)) + assert.Equal(t, prompt.RoleAssistant, Role2PromptRole(entity.RoleAssistant)) + assert.Equal(t, prompt.RoleTool, Role2PromptRole(entity.RoleTool)) + assert.Equal(t, prompt.RoleUser, Role2PromptRole(entity.Role(99))) +} + +func TestConvertContent(t *testing.T) { + assert.Nil(t, ConvertContent(nil)) + + t.Run("text", func(t *testing.T) { + c := &entity.Content{ContentType: gptr.Of(entity.ContentTypeText), Text: gptr.Of("text")} + res := ConvertContent(c) + assert.Len(t, res, 1) + assert.Equal(t, prompt.ContentTypeText, *res[0].Type) + }) + + t.Run("image", func(t *testing.T) { + c := &entity.Content{ + ContentType: gptr.Of(entity.ContentTypeImage), + Image: &entity.Image{ + URL: gptr.Of("url"), + URI: gptr.Of("uri"), + }, + } + res := ConvertContent(c) + assert.Len(t, res, 1) + assert.Equal(t, prompt.ContentTypeImageURL, *res[0].Type) + assert.Equal(t, "url", *res[0].ImageURL.URL) + }) + + t.Run("multipart", func(t *testing.T) { + c := &entity.Content{ + ContentType: gptr.Of(entity.ContentTypeMultipart), + MultiPart: []*entity.Content{ + {ContentType: gptr.Of(entity.ContentTypeText), Text: gptr.Of("t1")}, + {ContentType: gptr.Of(entity.ContentTypeImage), Image: &entity.Image{URL: gptr.Of("u1")}}, + }, + } + res := ConvertContent(c) + assert.Len(t, res, 2) + assert.Equal(t, prompt.ContentTypeText, *res[0].Type) + assert.Equal(t, prompt.ContentTypeImageURL, *res[1].Type) + }) + + t.Run("default", func(t *testing.T) { + c := &entity.Content{ContentType: gptr.Of(entity.ContentType("unknown"))} + res := ConvertContent(c) + assert.Len(t, res, 0) + }) +} diff --git a/backend/modules/evaluation/infra/rpc/prompt/mocks/prompt_execute_client.go b/backend/modules/evaluation/infra/rpc/prompt/mocks/prompt_execute_client.go new file mode 100644 index 000000000..b184f1c51 --- /dev/null +++ b/backend/modules/evaluation/infra/rpc/prompt/mocks/prompt_execute_client.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/apis/promptexecuteservice (interfaces: Client) +// +// Generated by this command: +// +// mockgen -package=mocks -mock_names=Client=MockPromptExecuteClient github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/apis/promptexecuteservice Client +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + callopt "github.com/cloudwego/kitex/client/callopt" + execute "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/execute" + gomock "go.uber.org/mock/gomock" +) + +// MockPromptExecuteClient is a mock of Client interface. +type MockPromptExecuteClient struct { + ctrl *gomock.Controller + recorder *MockPromptExecuteClientMockRecorder + isgomock struct{} +} + +// MockPromptExecuteClientMockRecorder is the mock recorder for MockPromptExecuteClient. +type MockPromptExecuteClientMockRecorder struct { + mock *MockPromptExecuteClient +} + +// NewMockPromptExecuteClient creates a new mock instance. +func NewMockPromptExecuteClient(ctrl *gomock.Controller) *MockPromptExecuteClient { + mock := &MockPromptExecuteClient{ctrl: ctrl} + mock.recorder = &MockPromptExecuteClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPromptExecuteClient) EXPECT() *MockPromptExecuteClientMockRecorder { + return m.recorder +} + +// ExecuteInternal mocks base method. +func (m *MockPromptExecuteClient) ExecuteInternal(ctx context.Context, req *execute.ExecuteInternalRequest, callOptions ...callopt.Option) (*execute.ExecuteInternalResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, req} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ExecuteInternal", varargs...) + ret0, _ := ret[0].(*execute.ExecuteInternalResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecuteInternal indicates an expected call of ExecuteInternal. +func (mr *MockPromptExecuteClientMockRecorder) ExecuteInternal(ctx, req any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, req}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteInternal", reflect.TypeOf((*MockPromptExecuteClient)(nil).ExecuteInternal), varargs...) +} diff --git a/backend/modules/evaluation/infra/rpc/prompt/mocks/prompt_manage_client.go b/backend/modules/evaluation/infra/rpc/prompt/mocks/prompt_manage_client.go new file mode 100644 index 000000000..9eb57e767 --- /dev/null +++ b/backend/modules/evaluation/infra/rpc/prompt/mocks/prompt_manage_client.go @@ -0,0 +1,363 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/promptmanageservice (interfaces: Client) +// +// Generated by this command: +// +// mockgen -package=mocks -mock_names=Client=MockPromptManageClient github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/promptmanageservice Client +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + callopt "github.com/cloudwego/kitex/client/callopt" + manage "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/manage" + gomock "go.uber.org/mock/gomock" +) + +// MockPromptManageClient is a mock of Client interface. +type MockPromptManageClient struct { + ctrl *gomock.Controller + recorder *MockPromptManageClientMockRecorder + isgomock struct{} +} + +// MockPromptManageClientMockRecorder is the mock recorder for MockPromptManageClient. +type MockPromptManageClientMockRecorder struct { + mock *MockPromptManageClient +} + +// NewMockPromptManageClient creates a new mock instance. +func NewMockPromptManageClient(ctrl *gomock.Controller) *MockPromptManageClient { + mock := &MockPromptManageClient{ctrl: ctrl} + mock.recorder = &MockPromptManageClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPromptManageClient) EXPECT() *MockPromptManageClientMockRecorder { + return m.recorder +} + +// BatchGetLabel mocks base method. +func (m *MockPromptManageClient) BatchGetLabel(ctx context.Context, request *manage.BatchGetLabelRequest, callOptions ...callopt.Option) (*manage.BatchGetLabelResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchGetLabel", varargs...) + ret0, _ := ret[0].(*manage.BatchGetLabelResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchGetLabel indicates an expected call of BatchGetLabel. +func (mr *MockPromptManageClientMockRecorder) BatchGetLabel(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetLabel", reflect.TypeOf((*MockPromptManageClient)(nil).BatchGetLabel), varargs...) +} + +// BatchGetPrompt mocks base method. +func (m *MockPromptManageClient) BatchGetPrompt(ctx context.Context, request *manage.BatchGetPromptRequest, callOptions ...callopt.Option) (*manage.BatchGetPromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchGetPrompt", varargs...) + ret0, _ := ret[0].(*manage.BatchGetPromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchGetPrompt indicates an expected call of BatchGetPrompt. +func (mr *MockPromptManageClientMockRecorder) BatchGetPrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetPrompt", reflect.TypeOf((*MockPromptManageClient)(nil).BatchGetPrompt), varargs...) +} + +// ClonePrompt mocks base method. +func (m *MockPromptManageClient) ClonePrompt(ctx context.Context, request *manage.ClonePromptRequest, callOptions ...callopt.Option) (*manage.ClonePromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ClonePrompt", varargs...) + ret0, _ := ret[0].(*manage.ClonePromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ClonePrompt indicates an expected call of ClonePrompt. +func (mr *MockPromptManageClientMockRecorder) ClonePrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClonePrompt", reflect.TypeOf((*MockPromptManageClient)(nil).ClonePrompt), varargs...) +} + +// CommitDraft mocks base method. +func (m *MockPromptManageClient) CommitDraft(ctx context.Context, request *manage.CommitDraftRequest, callOptions ...callopt.Option) (*manage.CommitDraftResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CommitDraft", varargs...) + ret0, _ := ret[0].(*manage.CommitDraftResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CommitDraft indicates an expected call of CommitDraft. +func (mr *MockPromptManageClientMockRecorder) CommitDraft(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitDraft", reflect.TypeOf((*MockPromptManageClient)(nil).CommitDraft), varargs...) +} + +// CreateLabel mocks base method. +func (m *MockPromptManageClient) CreateLabel(ctx context.Context, request *manage.CreateLabelRequest, callOptions ...callopt.Option) (*manage.CreateLabelResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateLabel", varargs...) + ret0, _ := ret[0].(*manage.CreateLabelResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateLabel indicates an expected call of CreateLabel. +func (mr *MockPromptManageClientMockRecorder) CreateLabel(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLabel", reflect.TypeOf((*MockPromptManageClient)(nil).CreateLabel), varargs...) +} + +// CreatePrompt mocks base method. +func (m *MockPromptManageClient) CreatePrompt(ctx context.Context, request *manage.CreatePromptRequest, callOptions ...callopt.Option) (*manage.CreatePromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreatePrompt", varargs...) + ret0, _ := ret[0].(*manage.CreatePromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreatePrompt indicates an expected call of CreatePrompt. +func (mr *MockPromptManageClientMockRecorder) CreatePrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePrompt", reflect.TypeOf((*MockPromptManageClient)(nil).CreatePrompt), varargs...) +} + +// DeletePrompt mocks base method. +func (m *MockPromptManageClient) DeletePrompt(ctx context.Context, request *manage.DeletePromptRequest, callOptions ...callopt.Option) (*manage.DeletePromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DeletePrompt", varargs...) + ret0, _ := ret[0].(*manage.DeletePromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeletePrompt indicates an expected call of DeletePrompt. +func (mr *MockPromptManageClientMockRecorder) DeletePrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePrompt", reflect.TypeOf((*MockPromptManageClient)(nil).DeletePrompt), varargs...) +} + +// GetPrompt mocks base method. +func (m *MockPromptManageClient) GetPrompt(ctx context.Context, request *manage.GetPromptRequest, callOptions ...callopt.Option) (*manage.GetPromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetPrompt", varargs...) + ret0, _ := ret[0].(*manage.GetPromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrompt indicates an expected call of GetPrompt. +func (mr *MockPromptManageClientMockRecorder) GetPrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrompt", reflect.TypeOf((*MockPromptManageClient)(nil).GetPrompt), varargs...) +} + +// ListCommit mocks base method. +func (m *MockPromptManageClient) ListCommit(ctx context.Context, request *manage.ListCommitRequest, callOptions ...callopt.Option) (*manage.ListCommitResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListCommit", varargs...) + ret0, _ := ret[0].(*manage.ListCommitResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListCommit indicates an expected call of ListCommit. +func (mr *MockPromptManageClientMockRecorder) ListCommit(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListCommit", reflect.TypeOf((*MockPromptManageClient)(nil).ListCommit), varargs...) +} + +// ListLabel mocks base method. +func (m *MockPromptManageClient) ListLabel(ctx context.Context, request *manage.ListLabelRequest, callOptions ...callopt.Option) (*manage.ListLabelResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListLabel", varargs...) + ret0, _ := ret[0].(*manage.ListLabelResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListLabel indicates an expected call of ListLabel. +func (mr *MockPromptManageClientMockRecorder) ListLabel(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListLabel", reflect.TypeOf((*MockPromptManageClient)(nil).ListLabel), varargs...) +} + +// ListParentPrompt mocks base method. +func (m *MockPromptManageClient) ListParentPrompt(ctx context.Context, request *manage.ListParentPromptRequest, callOptions ...callopt.Option) (*manage.ListParentPromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListParentPrompt", varargs...) + ret0, _ := ret[0].(*manage.ListParentPromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListParentPrompt indicates an expected call of ListParentPrompt. +func (mr *MockPromptManageClientMockRecorder) ListParentPrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListParentPrompt", reflect.TypeOf((*MockPromptManageClient)(nil).ListParentPrompt), varargs...) +} + +// ListPrompt mocks base method. +func (m *MockPromptManageClient) ListPrompt(ctx context.Context, request *manage.ListPromptRequest, callOptions ...callopt.Option) (*manage.ListPromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListPrompt", varargs...) + ret0, _ := ret[0].(*manage.ListPromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListPrompt indicates an expected call of ListPrompt. +func (mr *MockPromptManageClientMockRecorder) ListPrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPrompt", reflect.TypeOf((*MockPromptManageClient)(nil).ListPrompt), varargs...) +} + +// RevertDraftFromCommit mocks base method. +func (m *MockPromptManageClient) RevertDraftFromCommit(ctx context.Context, request *manage.RevertDraftFromCommitRequest, callOptions ...callopt.Option) (*manage.RevertDraftFromCommitResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RevertDraftFromCommit", varargs...) + ret0, _ := ret[0].(*manage.RevertDraftFromCommitResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RevertDraftFromCommit indicates an expected call of RevertDraftFromCommit. +func (mr *MockPromptManageClientMockRecorder) RevertDraftFromCommit(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevertDraftFromCommit", reflect.TypeOf((*MockPromptManageClient)(nil).RevertDraftFromCommit), varargs...) +} + +// SaveDraft mocks base method. +func (m *MockPromptManageClient) SaveDraft(ctx context.Context, request *manage.SaveDraftRequest, callOptions ...callopt.Option) (*manage.SaveDraftResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "SaveDraft", varargs...) + ret0, _ := ret[0].(*manage.SaveDraftResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveDraft indicates an expected call of SaveDraft. +func (mr *MockPromptManageClientMockRecorder) SaveDraft(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveDraft", reflect.TypeOf((*MockPromptManageClient)(nil).SaveDraft), varargs...) +} + +// UpdateCommitLabels mocks base method. +func (m *MockPromptManageClient) UpdateCommitLabels(ctx context.Context, request *manage.UpdateCommitLabelsRequest, callOptions ...callopt.Option) (*manage.UpdateCommitLabelsResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdateCommitLabels", varargs...) + ret0, _ := ret[0].(*manage.UpdateCommitLabelsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateCommitLabels indicates an expected call of UpdateCommitLabels. +func (mr *MockPromptManageClientMockRecorder) UpdateCommitLabels(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCommitLabels", reflect.TypeOf((*MockPromptManageClient)(nil).UpdateCommitLabels), varargs...) +} + +// UpdatePrompt mocks base method. +func (m *MockPromptManageClient) UpdatePrompt(ctx context.Context, request *manage.UpdatePromptRequest, callOptions ...callopt.Option) (*manage.UpdatePromptResponse, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, request} + for _, a := range callOptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdatePrompt", varargs...) + ret0, _ := ret[0].(*manage.UpdatePromptResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdatePrompt indicates an expected call of UpdatePrompt. +func (mr *MockPromptManageClientMockRecorder) UpdatePrompt(ctx, request any, callOptions ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, request}, callOptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePrompt", reflect.TypeOf((*MockPromptManageClient)(nil).UpdatePrompt), varargs...) +} diff --git a/backend/modules/evaluation/infra/rpc/prompt/prompt.go b/backend/modules/evaluation/infra/rpc/prompt/prompt.go index 86f1dbb16..48e62bae8 100644 --- a/backend/modules/evaluation/infra/rpc/prompt/prompt.go +++ b/backend/modules/evaluation/infra/rpc/prompt/prompt.go @@ -59,7 +59,7 @@ func (p PromptRPCAdapter) ExecutePrompt(ctx context.Context, spaceID int64, para if runtimeParam != nil && runtimeParam.ModelConfig != nil { req.OverridePromptParams = &prompt.OverridePromptParams{ ModelConfig: &prompt.ModelConfig{ - ModelID: gptr.Of(runtimeParam.ModelConfig.ModelID), + ModelID: runtimeParam.ModelConfig.ModelID, MaxTokens: runtimeParam.ModelConfig.MaxTokens, Temperature: runtimeParam.ModelConfig.Temperature, TopP: runtimeParam.ModelConfig.TopP, diff --git a/backend/modules/evaluation/infra/rpc/prompt/prompt_impl_test.go b/backend/modules/evaluation/infra/rpc/prompt/prompt_impl_test.go new file mode 100644 index 000000000..e55245a82 --- /dev/null +++ b/backend/modules/evaluation/infra/rpc/prompt/prompt_impl_test.go @@ -0,0 +1,244 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package prompt + +import ( + "context" + "testing" + + "github.com/bytedance/gg/gptr" + "github.com/cloudwego/kitex/client/callopt" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/coze-dev/coze-loop/backend/kitex_gen/base" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/domain/prompt" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/execute" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/prompt/manage" + "github.com/coze-dev/coze-loop/backend/modules/evaluation/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/evaluation/infra/rpc/prompt/mocks" +) + +func TestPromptRPCAdapter_ExecutePrompt(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManage := mocks.NewMockPromptManageClient(ctrl) + mockExecute := mocks.NewMockPromptExecuteClient(ctrl) + + adapter := NewPromptRPCAdapter(mockManage, mockExecute) + + ctx := context.Background() + param := &rpc.ExecutePromptParam{ + PromptID: 1, + PromptVersion: "v1", + } + + t.Run("success", func(t *testing.T) { + mockExecute.EXPECT().ExecuteInternal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&execute.ExecuteInternalResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Message: &prompt.Message{ + Content: gptr.Of("resp"), + }, + }, nil) + + res, err := adapter.ExecutePrompt(ctx, 1, param) + assert.NoError(t, err) + assert.Equal(t, "resp", *res.Content) + }) + + t.Run("error_base_resp", func(t *testing.T) { + mockExecute.EXPECT().ExecuteInternal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&execute.ExecuteInternalResponse{ + BaseResp: &base.BaseResp{StatusCode: 500, StatusMessage: "error"}, + }, nil) + + _, err := adapter.ExecutePrompt(ctx, 1, param) + assert.Error(t, err) + }) + + t.Run("with_runtime_param", func(t *testing.T) { + paramWithRuntime := &rpc.ExecutePromptParam{ + PromptID: 1, + PromptVersion: "v1", + RuntimeParam: gptr.Of(`{"model_config":{"model_id":123,"max_tokens":100}}`), + } + mockExecute.EXPECT().ExecuteInternal(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *execute.ExecuteInternalRequest, opts ...callopt.Option) (*execute.ExecuteInternalResponse, error) { + assert.Equal(t, int64(123), *req.OverridePromptParams.ModelConfig.ModelID) + assert.Equal(t, int32(100), *req.OverridePromptParams.ModelConfig.MaxTokens) + return &execute.ExecuteInternalResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Message: &prompt.Message{Content: gptr.Of("resp")}, + }, nil + }) + + _, err := adapter.ExecutePrompt(ctx, 1, paramWithRuntime) + assert.NoError(t, err) + }) + + t.Run("invalid_runtime_param", func(t *testing.T) { + paramInvalid := &rpc.ExecutePromptParam{ + RuntimeParam: gptr.Of(`{invalid}`), + } + // It logs the error but continues without override + mockExecute.EXPECT().ExecuteInternal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&execute.ExecuteInternalResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Message: &prompt.Message{Content: gptr.Of("resp")}, + }, nil) + _, err := adapter.ExecutePrompt(ctx, 1, paramInvalid) + assert.NoError(t, err) + }) + + t.Run("resp_nil", func(t *testing.T) { + mockExecute.EXPECT().ExecuteInternal(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + _, err := adapter.ExecutePrompt(ctx, 1, param) + assert.Error(t, err) + }) +} + +func TestPromptRPCAdapter_GetPrompt(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManage := mocks.NewMockPromptManageClient(ctrl) + mockExecute := mocks.NewMockPromptExecuteClient(ctrl) + + adapter := NewPromptRPCAdapter(mockManage, mockExecute) + + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + mockManage.EXPECT().GetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.GetPromptResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Prompt: &prompt.Prompt{ID: gptr.Of(int64(1))}, + }, nil) + + res, err := adapter.GetPrompt(ctx, 1, 1, rpc.GetPromptParams{}) + assert.NoError(t, err) + assert.Equal(t, int64(1), res.ID) + }) + + t.Run("success_with_version", func(t *testing.T) { + mockManage.EXPECT().GetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.GetPromptResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Prompt: &prompt.Prompt{ID: gptr.Of(int64(1))}, + }, nil) + + res, err := adapter.GetPrompt(ctx, 1, 1, rpc.GetPromptParams{CommitVersion: gptr.Of("v1")}) + assert.NoError(t, err) + assert.Equal(t, int64(1), res.ID) + }) + + t.Run("not_found", func(t *testing.T) { + mockManage.EXPECT().GetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.GetPromptResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Prompt: nil, + }, nil) + + res, err := adapter.GetPrompt(ctx, 1, 1, rpc.GetPromptParams{}) + assert.NoError(t, err) + assert.Nil(t, res) + }) +} + +func TestPromptRPCAdapter_MGetPrompt(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManage := mocks.NewMockPromptManageClient(ctrl) + mockExecute := mocks.NewMockPromptExecuteClient(ctrl) + + adapter := NewPromptRPCAdapter(mockManage, mockExecute) + + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + mockManage.EXPECT().BatchGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.BatchGetPromptResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Results: []*manage.PromptResult_{ + {Prompt: &prompt.Prompt{ID: gptr.Of(int64(1))}}, + {Prompt: nil}, + }, + }, nil) + + res, err := adapter.MGetPrompt(ctx, 1, []*rpc.MGetPromptQuery{{PromptID: 1, Version: gptr.Of("v1")}}) + assert.NoError(t, err) + assert.Len(t, res, 1) + }) + + t.Run("rpc_error", func(t *testing.T) { + mockManage.EXPECT().BatchGetPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + _, err := adapter.MGetPrompt(ctx, 1, []*rpc.MGetPromptQuery{{PromptID: 1}}) + assert.Error(t, err) + }) +} + +func TestPromptRPCAdapter_ListPrompt(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManage := mocks.NewMockPromptManageClient(ctrl) + mockExecute := mocks.NewMockPromptExecuteClient(ctrl) + + adapter := NewPromptRPCAdapter(mockManage, mockExecute) + + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + mockManage.EXPECT().ListPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.ListPromptResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + Prompts: []*prompt.Prompt{{ID: gptr.Of(int64(1))}}, + Total: gptr.Of(int32(1)), + }, nil) + + res, total, err := adapter.ListPrompt(ctx, &rpc.ListPromptParam{SpaceID: gptr.Of(int64(1))}) + assert.NoError(t, err) + assert.Equal(t, int32(1), *total) + assert.Len(t, res, 1) + }) + + t.Run("error", func(t *testing.T) { + mockManage.EXPECT().ListPrompt(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + _, _, err := adapter.ListPrompt(ctx, &rpc.ListPromptParam{}) + assert.Error(t, err) + }) +} + +func TestPromptRPCAdapter_ListPromptVersion(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManage := mocks.NewMockPromptManageClient(ctrl) + mockExecute := mocks.NewMockPromptExecuteClient(ctrl) + + adapter := NewPromptRPCAdapter(mockManage, mockExecute) + + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + mockManage.EXPECT().ListCommit(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.ListCommitResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + PromptCommitInfos: []*prompt.CommitInfo{ + {Version: gptr.Of("v1")}, + }, + NextPageToken: gptr.Of("next"), + }, nil) + + res, next, err := adapter.ListPromptVersion(ctx, &rpc.ListPromptVersionParam{PromptID: 1}) + assert.NoError(t, err) + assert.Equal(t, "next", next) + assert.Len(t, res, 1) + }) + + t.Run("no_next_page", func(t *testing.T) { + mockManage.EXPECT().ListCommit(gomock.Any(), gomock.Any(), gomock.Any()).Return(&manage.ListCommitResponse{ + BaseResp: &base.BaseResp{StatusCode: 0}, + NextPageToken: nil, + }, nil) + + res, next, err := adapter.ListPromptVersion(ctx, &rpc.ListPromptVersionParam{PromptID: 1}) + assert.NoError(t, err) + assert.Equal(t, "", next) + assert.Len(t, res, 0) + }) +} diff --git a/backend/modules/llm/application/convertor/manage.go b/backend/modules/llm/application/convertor/manage.go index 979410d27..035788dba 100644 --- a/backend/modules/llm/application/convertor/manage.go +++ b/backend/modules/llm/application/convertor/manage.go @@ -4,8 +4,12 @@ package convertor import ( + "github.com/bytedance/gg/gptr" + "github.com/bytedance/gg/gslice" + "github.com/bytedance/gg/gvalue" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" + manage2 "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/manage" "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" @@ -25,16 +29,127 @@ func ModelDO2DTO(model *entity.Model, mask bool) *manage.Model { if !mask { pc = ProtocolConfigDO2DTO(model.ProtocolConfig) } - return &manage.Model{ - ModelID: ptr.Of(model.ID), - WorkspaceID: ptr.Of(model.WorkspaceID), - Name: ptr.Of(model.Name), - Desc: ptr.Of(model.Desc), - Ability: AbilityDO2DTO(model.Ability), - Protocol: ptr.Of(manage.Protocol(model.Protocol)), - ProtocolConfig: pc, - ScenarioConfigs: ScenarioConfigMapDO2DTO(model.ScenarioConfigs), - ParamConfig: ParamConfigDO2DTO(model.ParamConfig), + resp := &manage.Model{ + ModelID: ptr.Of(model.ID), + WorkspaceID: ptr.Of(model.WorkspaceID), + Name: ptr.Of(model.Name), + Desc: ptr.Of(model.Desc), + Ability: AbilityDO2DTO(model.Ability), + Protocol: ptr.Of(manage.Protocol(model.Protocol)), + ProtocolConfig: pc, + Identification: ptr.Of(model.Identification), + Icon: ptr.Of(model.Icon), + Status: ptr.Of(ModelStatusDO2DTO(model.Status)), + Tags: model.Tags, + Series: SeriesDO2DTO(model.Series), + Visibility: VisibilityDO2DTO(model.Visibility), + ScenarioConfigs: ScenarioConfigMapDO2DTO(model.ScenarioConfigs), + ParamConfig: ParamConfigDO2DTO(model.ParamConfig), + OriginalModelURL: ptr.Of(model.OriginalModelURL), + PresetModel: ptr.Of(model.PresetModel), + } + if gvalue.IsNotZero(model.CreatedAt) { + resp.CreatedAt = gptr.Of(model.CreatedAt) + } + if gvalue.IsNotZero(model.UpdatedAt) { + resp.UpdatedAt = gptr.Of(model.UpdatedAt) + } + if gvalue.IsNotZero(model.CreatedBy) { + resp.CreatedBy = gptr.Of(model.CreatedBy) + } + if gvalue.IsNotZero(model.UpdatedBy) { + resp.UpdatedBy = gptr.Of(model.UpdatedBy) + } + return resp +} + +func SeriesDO2DTO(v *entity.Series) *manage.Series { + if v == nil { + return nil + } + return &manage.Series{ + Name: ptr.Of(v.Name), + Icon: ptr.Of(v.Icon), + Family: ptr.Of(FamilyDO2DTO(v.Family)), + } +} + +func FamilyDO2DTO(v entity.Family) manage.Family { + switch v { + case entity.FamilySeed: + return manage.FamilySeed + case entity.FamilyGLM: + return manage.FamilyGlm + case entity.FamilyKimi: + return manage.FamilyKimi + case entity.FamilyDeepSeek: + return manage.FamilyDeepseek + case entity.FamilyDoubao: + return manage.FamilyDoubao + default: + return manage.FamilyUndefined + } +} + +func FamilyDTO2DO(val manage.Family) entity.Family { + switch val { + case manage.FamilySeed: + return entity.FamilySeed + case manage.FamilyDeepseek: + return entity.FamilyDeepSeek + case manage.FamilyGlm: + return entity.FamilyGLM + case manage.FamilyKimi: + return entity.FamilyKimi + case manage.FamilyDoubao: + return entity.FamilyDoubao + default: + return entity.FamilyUndefined + } +} + +func VisibilityDO2DTO(v *entity.Visibility) *manage.Visibility { + if v == nil { + return nil + } + return &manage.Visibility{ + Mode: ptr.Of(VisibleModelDO2DTO(v.Mode)), + SpaceIDs: v.SpaceIDs, + } +} + +func VisibleModelDO2DTO(v entity.VisibleMode) manage.VisibleMode { + switch v { + case entity.VisibleModelAll: + return manage.VisibleModeAll + case entity.VisibleModelSpecified: + return manage.VisibleModeSpecified + case entity.VisibleModelDefault: + return manage.VisibleModeDefault + default: + return manage.VisibleModeUndefined + } +} + +func ModelStatusDO2DTO(status entity.ModelStatus) manage.ModelStatus { + switch status { + case entity.ModelStatusDisabled: + return manage.ModelStatusUnavailable + case entity.ModelStatusEnabled: + return manage.ModelStatusAvailable + default: + return manage.ModelStatusUndefined + } +} + +func ModelStatusDTO2DO(val manage.ModelStatus) entity.ModelStatus { + switch val { + case manage.ModelStatusUnavailable: + return entity.ModelStatusDisabled + case manage.ModelStatusAvailable: + return entity.ModelStatusEnabled + default: + return entity.ModelStatusUndefined } } @@ -269,6 +384,19 @@ func ParamSchemaDO2DTO(ps *entity.ParamSchema) *manage.ParamSchema { Max: ptr.Of(ps.Max), DefaultValue: ptr.Of(ps.DefaultValue), Options: ParamOptionsDO2DTO(ps.Options), + Properties: gslice.Map(ps.Properties, ParamSchemaDO2DTO), + Jsonpath: ptr.Of(ps.JsonPath), + Reaction: ReactionDO2DTO(ps.Reaction), + } +} + +func ReactionDO2DTO(r *entity.Reaction) *manage.Reaction { + if r == nil { + return nil + } + return &manage.Reaction{ + Dependency: ptr.Of(r.Dependency), + Visible: ptr.Of(r.Visible), } } @@ -287,3 +415,28 @@ func ParamOptionDO2DTO(o *entity.ParamOption) *manage.ParamOption { Label: ptr.Of(o.Label), } } + +func AbilityEnumDTO2DO(val manage.AbilityEnum) entity.AbilityEnum { + switch val { + case manage.AbilityJSONMode: + return entity.AbilityEnumJsonMode + case manage.AbilityFunctionCall: + return entity.AbilityEnumFunctionCall + case manage.AbilityMultiModal_: + return entity.AbilityEnumMultiModal + default: + return entity.AbilityEnumUndefined + } +} + +func ListModelsFilterDTO2DO(val *manage2.Filter) *entity.ListModelsFilter { + if val == nil { + return nil + } + return &entity.ListModelsFilter{ + NameLike: val.NameLike, + Families: gslice.Map(val.Families, FamilyDTO2DO), + ModelStatuses: gslice.Map(val.Statuses, ModelStatusDTO2DO), + Abilities: gslice.Map(val.Abilities, AbilityEnumDTO2DO), + } +} diff --git a/backend/modules/llm/application/convertor/manage_test.go b/backend/modules/llm/application/convertor/manage_test.go new file mode 100644 index 000000000..6d2b8dc48 --- /dev/null +++ b/backend/modules/llm/application/convertor/manage_test.go @@ -0,0 +1,342 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package convertor + +import ( + "testing" + + "github.com/bytedance/gg/gptr" + "github.com/stretchr/testify/assert" + + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" + manage2 "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/manage" + "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity" +) + +func TestModelDO2DTO(t *testing.T) { + model := &entity.Model{ + ID: 1, + WorkspaceID: 2, + Name: "model1", + Desc: "desc1", + Protocol: "ark", + ProtocolConfig: &entity.ProtocolConfig{ + BaseURL: "http://test.com", + }, + Status: entity.ModelStatusEnabled, + PresetModel: true, + CreatedAt: 123456, + } + + t.Run("no mask", func(t *testing.T) { + got := ModelDO2DTO(model, false) + assert.NotNil(t, got) + assert.Equal(t, model.ID, *got.ModelID) + assert.NotNil(t, got.ProtocolConfig) + assert.Equal(t, model.ProtocolConfig.BaseURL, *got.ProtocolConfig.BaseURL) + assert.Equal(t, int64(123456), *got.CreatedAt) + }) + + t.Run("with mask", func(t *testing.T) { + got := ModelDO2DTO(model, true) + assert.NotNil(t, got) + assert.Nil(t, got.ProtocolConfig) + }) + + t.Run("nil input", func(t *testing.T) { + got := ModelDO2DTO(nil, false) + assert.Nil(t, got) + }) +} + +func TestModelsDO2DTO(t *testing.T) { + models := []*entity.Model{ + {ID: 1}, + {ID: 2}, + } + got := ModelsDO2DTO(models, false) + assert.Len(t, got, 2) + assert.Equal(t, int64(1), *got[0].ModelID) + assert.Equal(t, int64(2), *got[1].ModelID) +} + +func TestSeriesDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got := SeriesDO2DTO(nil) + assert.Nil(t, got) + }) + + t.Run("valid input", func(t *testing.T) { + v := &entity.Series{ + Name: "series1", + Icon: "icon1", + Family: entity.FamilySeed, + } + got := SeriesDO2DTO(v) + assert.Equal(t, v.Name, *got.Name) + assert.Equal(t, v.Icon, *got.Icon) + assert.Equal(t, manage.FamilySeed, *got.Family) + }) +} + +func TestFamilyDO2DTO(t *testing.T) { + tests := []struct { + name string + from entity.Family + want manage.Family + }{ + {"seed", entity.FamilySeed, manage.FamilySeed}, + {"deepseek", entity.FamilyDeepSeek, manage.FamilyDeepseek}, + {"glm", entity.FamilyGLM, manage.FamilyGlm}, + {"kimi", entity.FamilyKimi, manage.FamilyKimi}, + {"doubao", entity.FamilyDoubao, manage.FamilyDoubao}, + {"undefined", "other", manage.FamilyUndefined}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, FamilyDO2DTO(tt.from)) + }) + } +} + +func TestFamilyDTO2DO(t *testing.T) { + tests := []struct { + name string + from manage.Family + want entity.Family + }{ + {"seed", manage.FamilySeed, entity.FamilySeed}, + {"deepseek", manage.FamilyDeepseek, entity.FamilyDeepSeek}, + {"glm", manage.FamilyGlm, entity.FamilyGLM}, + {"kimi", manage.FamilyKimi, entity.FamilyKimi}, + {"doubao", manage.FamilyDoubao, entity.FamilyDoubao}, + {"undefined", manage.FamilyUndefined, entity.FamilyUndefined}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, FamilyDTO2DO(tt.from)) + }) + } +} + +func TestVisibilityDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, VisibilityDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + v := &entity.Visibility{ + Mode: entity.VisibleModelAll, + SpaceIDs: []int64{1, 2}, + } + got := VisibilityDO2DTO(v) + assert.Equal(t, manage.VisibleModeAll, *got.Mode) + assert.Equal(t, v.SpaceIDs, got.SpaceIDs) + }) +} + +func TestModelStatusConvert(t *testing.T) { + assert.Equal(t, manage.ModelStatusAvailable, ModelStatusDO2DTO(entity.ModelStatusEnabled)) + assert.Equal(t, entity.ModelStatusEnabled, ModelStatusDTO2DO(manage.ModelStatusAvailable)) +} + +func TestAbilityDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, AbilityDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + a := &entity.Ability{ + MaxContextTokens: gptr.Of(int64(100)), + FunctionCall: true, + } + got := AbilityDO2DTO(a) + assert.Equal(t, a.MaxContextTokens, got.MaxContextTokens) + assert.True(t, *got.FunctionCall) + }) +} + +func TestProtocolConfigArkDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ProtocolConfigArkDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + p := &entity.ProtocolConfigArk{ + Region: "cn-beijing", + AccessKey: "ak", + } + got := ProtocolConfigArkDO2DTO(p) + assert.Equal(t, p.Region, *got.Region) + assert.Equal(t, p.AccessKey, *got.AccessKey) + }) +} + +func TestListModelsFilterDTO2DO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ListModelsFilterDTO2DO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + f := &manage2.Filter{ + NameLike: gptr.Of("test"), + Families: []manage.Family{manage.FamilySeed}, + } + got := ListModelsFilterDTO2DO(f) + assert.Equal(t, f.NameLike, got.NameLike) + assert.Equal(t, entity.FamilySeed, got.Families[0]) + }) +} + +func TestAbilityImageDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, AbilityImageDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + a := &entity.AbilityImage{ + URLEnabled: true, + BinaryEnabled: false, + MaxImageSize: 1024, + MaxImageCount: 5, + } + got := AbilityImageDO2DTO(a) + assert.True(t, *got.URLEnabled) + assert.False(t, *got.BinaryEnabled) + assert.Equal(t, a.MaxImageSize, *got.MaxImageSize) + }) +} + +func TestProtocolConfigDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ProtocolConfigDO2DTO(nil)) + }) + t.Run("full input", func(t *testing.T) { + p := &entity.ProtocolConfig{ + BaseURL: "http://test.com", + APIKey: "key", + Model: "model", + ProtocolConfigArk: &entity.ProtocolConfigArk{ + Region: "region", + }, + ProtocolConfigOpenAI: &entity.ProtocolConfigOpenAI{ + ByAzure: true, + }, + ProtocolConfigClaude: &entity.ProtocolConfigClaude{ + ByBedrock: true, + }, + ProtocolConfigDeepSeek: &entity.ProtocolConfigDeepSeek{ + ResponseFormatType: "json", + }, + ProtocolConfigGemini: &entity.ProtocolConfigGemini{ + EnableCodeExecution: true, + SafetySettings: []entity.ProtocolConfigGeminiSafetySetting{ + {Category: 1, Threshold: 2}, + }, + }, + ProtocolConfigQwen: &entity.ProtocolConfigQwen{ + ResponseFormatType: gptr.Of("json"), + }, + ProtocolConfigQianfan: &entity.ProtocolConfigQianfan{ + LLMRetryCount: gptr.Of(3), + LLMRetryTimeout: gptr.Of(float32(1.5)), + LLMRetryBackoffFactor: gptr.Of(float32(2.0)), + }, + ProtocolConfigOllama: &entity.ProtocolConfigOllama{ + Format: gptr.Of("json"), + }, + ProtocolConfigArkBot: &entity.ProtocolConfigArkBot{ + Region: "region", + }, + } + got := ProtocolConfigDO2DTO(p) + assert.NotNil(t, got) + assert.Equal(t, p.BaseURL, *got.BaseURL) + assert.True(t, *got.ProtocolConfigOpenai.ByAzure) + assert.Len(t, got.ProtocolConfigGemini.SafetySettings, 1) + assert.Equal(t, int32(1), *got.ProtocolConfigGemini.SafetySettings[0].Category) + }) +} + +func TestVisibleModelDO2DTO(t *testing.T) { + assert.Equal(t, manage.VisibleModeAll, VisibleModelDO2DTO(entity.VisibleModelAll)) + assert.Equal(t, manage.VisibleModeSpecified, VisibleModelDO2DTO(entity.VisibleModelSpecified)) + assert.Equal(t, manage.VisibleModeDefault, VisibleModelDO2DTO(entity.VisibleModelDefault)) + assert.Equal(t, manage.VisibleModeUndefined, VisibleModelDO2DTO(entity.VisibleModeUndefined)) +} + +func TestModelStatusDO2DTO(t *testing.T) { + assert.Equal(t, manage.ModelStatusAvailable, ModelStatusDO2DTO(entity.ModelStatusEnabled)) + assert.Equal(t, manage.ModelStatusUnavailable, ModelStatusDO2DTO(entity.ModelStatusDisabled)) + assert.Equal(t, manage.ModelStatusUndefined, ModelStatusDO2DTO(entity.ModelStatusUndefined)) +} + +func TestModelStatusDTO2DO(t *testing.T) { + assert.Equal(t, entity.ModelStatusEnabled, ModelStatusDTO2DO(manage.ModelStatusAvailable)) + assert.Equal(t, entity.ModelStatusDisabled, ModelStatusDTO2DO(manage.ModelStatusUnavailable)) + assert.Equal(t, entity.ModelStatusUndefined, ModelStatusDTO2DO(manage.ModelStatusUndefined)) +} + +func TestAbilityMultiModalDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, AbilityMultiModalDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + a := &entity.AbilityMultiModal{ + Image: true, + AbilityImage: &entity.AbilityImage{ + URLEnabled: true, + }, + } + got := AbilityMultiModalDO2DTO(a) + assert.True(t, *got.Image) + assert.True(t, *got.AbilityImage.URLEnabled) + }) +} + +func TestScenarioConfigMapDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ScenarioConfigMapDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + m := map[entity.Scenario]*entity.ScenarioConfig{ + entity.ScenarioDefault: { + Scenario: entity.ScenarioDefault, + Quota: &entity.Quota{ + Qpm: 10, + }, + }, + } + got := ScenarioConfigMapDO2DTO(m) + assert.Len(t, got, 1) + }) +} + +func TestParamConfigDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ParamConfigDO2DTO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + p := &entity.ParamConfig{ + ParamSchemas: []*entity.ParamSchema{ + { + Name: "param1", + Options: []*entity.ParamOption{ + {Value: "v1"}, + }, + Reaction: &entity.Reaction{ + Dependency: "dep1", + }, + }, + }, + } + got := ParamConfigDO2DTO(p) + assert.Len(t, got.ParamSchemas, 1) + assert.Equal(t, "param1", *got.ParamSchemas[0].Name) + assert.Len(t, got.ParamSchemas[0].Options, 1) + assert.Equal(t, "dep1", *got.ParamSchemas[0].Reaction.Dependency) + }) +} + +func TestAbilityEnumDTO2DO(t *testing.T) { + assert.Equal(t, entity.AbilityEnumFunctionCall, AbilityEnumDTO2DO(manage.AbilityFunctionCall)) + assert.Equal(t, entity.AbilityEnumMultiModal, AbilityEnumDTO2DO(manage.AbilityMultiModal_)) + assert.Equal(t, entity.AbilityEnumJsonMode, AbilityEnumDTO2DO(manage.AbilityJSONMode)) + assert.Equal(t, entity.AbilityEnumUndefined, AbilityEnumDTO2DO(manage.AbilityUndefined)) +} diff --git a/backend/modules/llm/application/convertor/runtime_option.go b/backend/modules/llm/application/convertor/runtime_option.go index 731b112ac..18cf71c63 100644 --- a/backend/modules/llm/application/convertor/runtime_option.go +++ b/backend/modules/llm/application/convertor/runtime_option.go @@ -10,7 +10,7 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" ) -func ModelAndTools2OptionDOs(modelCfg *druntime.ModelConfig, tools []*druntime.Tool, parameters map[string]string) []entity.Option { +func ModelAndTools2OptionDOs(modelCfg *druntime.ModelConfig, tools []*druntime.Tool, parameters map[string]string, paramValues map[string]*entity.ParamValue) []entity.Option { var opts []entity.Option if modelCfg != nil { if modelCfg.Temperature != nil { @@ -50,6 +50,9 @@ func ModelAndTools2OptionDOs(modelCfg *druntime.ModelConfig, tools []*druntime.T if parameters != nil { opts = append(opts, entity.WithParameters(parameters)) } + if paramValues != nil { + opts = append(opts, entity.WithParamValues(paramValues)) + } return opts } diff --git a/backend/modules/llm/application/convertor/runtime_option_test.go b/backend/modules/llm/application/convertor/runtime_option_test.go new file mode 100644 index 000000000..7372e1697 --- /dev/null +++ b/backend/modules/llm/application/convertor/runtime_option_test.go @@ -0,0 +1,95 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package convertor + +import ( + "testing" + + "github.com/bytedance/gg/gptr" + "github.com/stretchr/testify/assert" + + druntime "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" + "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity" +) + +func TestModelAndTools2OptionDOs(t *testing.T) { + t.Run("full input", func(t *testing.T) { + modelCfg := &druntime.ModelConfig{ + Temperature: gptr.Of(float64(0.7)), + MaxTokens: gptr.Of(int64(100)), + TopP: gptr.Of(float64(0.9)), + Stop: []string{"stop1"}, + ToolChoice: gptr.Of(druntime.ToolChoiceAuto), + ResponseFormat: &druntime.ResponseFormat{Type: gptr.Of(druntime.ResponseFormatJSONObject)}, + TopK: gptr.Of(int32(10)), + PresencePenalty: gptr.Of(float64(0.5)), + FrequencyPenalty: gptr.Of(float64(0.6)), + } + tools := []*druntime.Tool{ + { + Name: gptr.Of("tool1"), + }, + } + parameters := map[string]string{"key1": "value1"} + paramValues := map[string]*entity.ParamValue{"pv1": {Value: "v1"}} + + got := ModelAndTools2OptionDOs(modelCfg, tools, parameters, paramValues) + assert.NotEmpty(t, got) + }) + + t.Run("nil input", func(t *testing.T) { + got := ModelAndTools2OptionDOs(nil, nil, nil, nil) + assert.Empty(t, got) + }) +} + +func TestToolsDTO2DO(t *testing.T) { + ts := []*druntime.Tool{ + { + Name: gptr.Of("tool1"), + }, + } + got := ToolsDTO2DO(ts) + assert.Len(t, got, 1) + assert.Equal(t, "tool1", got[0].Name) +} + +func TestResponseFormatDTO2DO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ResponseFormatDTO2DO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + r := &druntime.ResponseFormat{ + Type: gptr.Of(druntime.ResponseFormatJSONObject), + } + got := ResponseFormatDTO2DO(r) + assert.Equal(t, entity.ResponseFormatType(r.GetType()), got.Type) + }) +} + +func TestToolDTO2DO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ToolDTO2DO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + t1 := &druntime.Tool{ + Name: gptr.Of("tool1"), + Desc: gptr.Of("desc1"), + } + got := ToolDTO2DO(t1) + assert.Equal(t, *t1.Name, got.Name) + assert.Equal(t, *t1.Desc, got.Desc) + }) +} + +func TestToolChoiceDTO2DO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, ToolChoiceDTO2DO(nil)) + }) + t.Run("valid input", func(t *testing.T) { + tc := druntime.ToolChoiceAuto + got := ToolChoiceDTO2DO(&tc) + assert.Equal(t, entity.ToolChoice(tc), *got) + }) +} diff --git a/backend/modules/llm/application/convertor/runtime_test.go b/backend/modules/llm/application/convertor/runtime_test.go new file mode 100644 index 000000000..3a9f397d1 --- /dev/null +++ b/backend/modules/llm/application/convertor/runtime_test.go @@ -0,0 +1,148 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package convertor + +import ( + "testing" + + "github.com/bytedance/gg/gptr" + "github.com/stretchr/testify/assert" + + druntime "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" + "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity" +) + +func TestMessagesDTO2DO(t *testing.T) { + dtos := []*druntime.Message{ + { + Role: druntime.RoleUser, + Content: gptr.Of("hello"), + }, + } + got := MessagesDTO2DO(dtos) + assert.Len(t, got, 1) + assert.Equal(t, entity.RoleUser, got[0].Role) + assert.Equal(t, "hello", got[0].Content) +} + +func TestMessageDTO2DO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, MessageDTO2DO(nil)) + }) + t.Run("full input", func(t *testing.T) { + dto := &druntime.Message{ + Role: druntime.RoleAssistant, + Content: gptr.Of("content"), + ReasoningContent: gptr.Of("reasoning"), + MultimodalContents: []*druntime.ChatMessagePart{ + { + Type: gptr.Of(druntime.ChatMessagePartTypeText), + Text: gptr.Of("text"), + }, + }, + ToolCalls: []*druntime.ToolCall{ + { + ID: gptr.Of("call1"), + }, + }, + ToolCallID: gptr.Of("tc1"), + ResponseMeta: &druntime.ResponseMeta{ + FinishReason: gptr.Of("stop"), + Usage: &druntime.TokenUsage{ + PromptTokens: gptr.Of(int64(10)), + }, + }, + } + got := MessageDTO2DO(dto) + assert.NotNil(t, got) + assert.Equal(t, entity.RoleAssistant, got.Role) + assert.Equal(t, "content", got.Content) + assert.Equal(t, "reasoning", got.ReasoningContent) + assert.Len(t, got.MultiModalContent, 1) + assert.Len(t, got.ToolCalls, 1) + assert.Equal(t, "tc1", got.ToolCallID) + assert.Equal(t, "stop", got.ResponseMeta.FinishReason) + assert.Equal(t, 10, got.ResponseMeta.Usage.PromptTokens) + }) +} + +func TestMessageDO2DTO(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + assert.Nil(t, MessageDO2DTO(nil)) + }) + t.Run("full input", func(t *testing.T) { + do := &entity.Message{ + Role: entity.RoleUser, + Content: "content", + ReasoningContent: "reasoning", + MultiModalContent: []*entity.ChatMessagePart{ + { + Type: entity.ChatMessagePartTypeText, + Text: "text", + }, + }, + ToolCalls: []*entity.ToolCall{ + { + ID: "call1", + }, + }, + ToolCallID: "tc1", + ResponseMeta: &entity.ResponseMeta{ + FinishReason: "stop", + Usage: &entity.TokenUsage{ + PromptTokens: 10, + }, + }, + } + got := MessageDO2DTO(do) + assert.NotNil(t, got) + assert.Equal(t, druntime.RoleUser, got.Role) + assert.Equal(t, "content", *got.Content) + assert.Equal(t, "reasoning", *got.ReasoningContent) + assert.Len(t, got.MultimodalContents, 1) + assert.Len(t, got.ToolCalls, 1) + assert.Equal(t, "tc1", *got.ToolCallID) + assert.Equal(t, "stop", *got.ResponseMeta.FinishReason) + assert.Equal(t, int64(10), *got.ResponseMeta.Usage.PromptTokens) + }) +} + +func TestToolCallConvert(t *testing.T) { + dto := &druntime.ToolCall{ + Index: gptr.Of(int64(0)), + ID: gptr.Of("id1"), + Type: gptr.Of(druntime.ToolTypeFunction), + FunctionCall: &druntime.FunctionCall{ + Name: gptr.Of("func1"), + Arguments: gptr.Of("{}"), + }, + } + do := ToolCallDTO2DO(dto) + assert.Equal(t, "id1", do.ID) + assert.Equal(t, "func1", do.Function.Name) + + dto2 := ToolCallDO2DTO(do) + assert.Equal(t, "id1", *dto2.ID) + assert.Equal(t, "func1", *dto2.FunctionCall.Name) +} + +func TestChatMessagePartConvert(t *testing.T) { + dto := &druntime.ChatMessagePart{ + Type: gptr.Of(druntime.ChatMessagePartTypeImageURL), + ImageURL: &druntime.ChatMessageImageURL{ + URL: gptr.Of("http://img.com"), + Detail: gptr.Of(druntime.ImageURLDetailHigh), + MimeType: gptr.Of("image/png"), + }, + } + do := ChatMessagePartDTO2DO(dto) + assert.Equal(t, entity.ChatMessagePartTypeImageURL, do.Type) + assert.Equal(t, "http://img.com", do.ImageURL.URL) + assert.Equal(t, entity.ImageURLDetailHigh, do.ImageURL.Detail) + + dto2 := ChatMessagePartDO2DTO(do) + assert.Equal(t, druntime.ChatMessagePartTypeImageURL, *dto2.Type) + assert.Equal(t, "http://img.com", *dto2.ImageURL.URL) + assert.Equal(t, druntime.ImageURLDetailHigh, *dto2.ImageURL.Detail) +} diff --git a/backend/modules/llm/application/manage_test.go b/backend/modules/llm/application/manage_test.go index 725df6bcd..99e0ed52a 100644 --- a/backend/modules/llm/application/manage_test.go +++ b/backend/modules/llm/application/manage_test.go @@ -3,16 +3,107 @@ package application -//func TestListModel(t *testing.T) { -// app, _ := InitManageApplication(context.Background(), viper.NewFileConfigLoaderFactory()) -// resp, err := app.ListModels(context.Background(), &manage.ListModelsRequest{ -// WorkspaceID: nil, -// Scenario: nil, -// PageSize: nil, -// PageToken: nil, -// Base: nil, -// }) -// fmt.Println(err) -// respStr, _ := sonic.MarshalString(&resp) -// fmt.Println(respStr) -//} +import ( + "context" + "testing" + + "github.com/bytedance/gg/gptr" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/manage" + "github.com/coze-dev/coze-loop/backend/modules/llm/domain/component/rpc/mocks" + "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity" + serviceMocks "github.com/coze-dev/coze-loop/backend/modules/llm/domain/service/mocks" +) + +func TestManageApp_ListModels(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSrv := serviceMocks.NewMockIManage(ctrl) + mockAuth := mocks.NewMockIAuthProvider(ctrl) + app := NewManageApplication(mockSrv, mockAuth) + + ctx := context.Background() + + t.Run("auth_error", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "listModels").Return(assert.AnError) + req := &manage.ListModelsRequest{WorkspaceID: gptr.Of(int64(1))} + res, err := app.ListModels(ctx, req) + assert.Error(t, err) + assert.NotNil(t, res) + }) + + t.Run("success", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "listModels").Return(nil) + mockSrv.EXPECT().ListModels(ctx, gomock.Any()).Return([]*entity.Model{{ID: 1}}, int64(1), false, int64(0), nil) + + req := &manage.ListModelsRequest{ + WorkspaceID: gptr.Of(int64(1)), + Scenario: gptr.Of(common.ScenarioEvaluator), + PageToken: gptr.Of("0"), + PageSize: gptr.Of(int32(10)), + } + res, err := app.ListModels(ctx, req) + assert.NoError(t, err) + assert.Len(t, res.Models, 1) + assert.Equal(t, int32(1), *res.Total) + }) + + t.Run("success_no_scenario", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "listModels").Return(nil) + mockSrv.EXPECT().ListModels(ctx, gomock.Any()).Return(nil, int64(0), false, int64(0), nil) + + req := &manage.ListModelsRequest{WorkspaceID: gptr.Of(int64(1))} + _, err := app.ListModels(ctx, req) + assert.NoError(t, err) + }) + + t.Run("srv_error", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "listModels").Return(nil) + mockSrv.EXPECT().ListModels(ctx, gomock.Any()).Return(nil, int64(0), false, int64(0), assert.AnError) + + req := &manage.ListModelsRequest{WorkspaceID: gptr.Of(int64(1))} + _, err := app.ListModels(ctx, req) + assert.Error(t, err) + }) +} + +func TestManageApp_GetModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSrv := serviceMocks.NewMockIManage(ctrl) + mockAuth := mocks.NewMockIAuthProvider(ctrl) + app := NewManageApplication(mockSrv, mockAuth) + + ctx := context.Background() + + t.Run("auth_error", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "getModel").Return(assert.AnError) + req := &manage.GetModelRequest{WorkspaceID: gptr.Of(int64(1)), ModelID: gptr.Of(int64(1))} + _, err := app.GetModel(ctx, req) + assert.Error(t, err) + }) + + t.Run("success", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "getModel").Return(nil) + mockSrv.EXPECT().GetModelByID(ctx, int64(100)).Return(&entity.Model{ID: 100}, nil) + + req := &manage.GetModelRequest{WorkspaceID: gptr.Of(int64(1)), ModelID: gptr.Of(int64(100))} + res, err := app.GetModel(ctx, req) + assert.NoError(t, err) + assert.Equal(t, int64(100), *res.Model.ModelID) + }) + + t.Run("srv_error", func(t *testing.T) { + mockAuth.EXPECT().CheckSpacePermission(ctx, int64(1), "getModel").Return(nil) + mockSrv.EXPECT().GetModelByID(ctx, int64(100)).Return(nil, assert.AnError) + + req := &manage.GetModelRequest{WorkspaceID: gptr.Of(int64(1)), ModelID: gptr.Of(int64(100))} + _, err := app.GetModel(ctx, req) + assert.Error(t, err) + }) +} diff --git a/backend/modules/llm/application/runtime.go b/backend/modules/llm/application/runtime.go index 3605c3589..cd25cc792 100644 --- a/backend/modules/llm/application/runtime.go +++ b/backend/modules/llm/application/runtime.go @@ -76,7 +76,7 @@ func (r *runtimeApp) Chat(ctx context.Context, req *runtime.ChatRequest) (resp * if err != nil { return resp, errorx.NewByCode(llm_errorx.RequestNotValidCode, errorx.WithExtraMsg(err.Error())) } - options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools(), nil) + options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools(), nil, nil) var respMsg *entity.Message // 5. start span var span looptracer.Span @@ -136,7 +136,7 @@ func (r *runtimeApp) ChatStream(ctx context.Context, req *runtime.ChatRequest, s if err != nil { return errorx.NewByCode(llm_errorx.RequestNotValidCode, errorx.WithExtraMsg(err.Error())) } - options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools(), nil) + options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools(), nil, nil) // 4. start trace var span looptracer.Span ctx, span = looptracer.GetTracer().StartSpan(ctx, model.Name, tracespec.VModelSpanType, looptracer.WithSpanWorkspaceID(strconv.FormatInt(req.GetBizParam().GetWorkspaceID(), 10))) diff --git a/backend/modules/llm/application/runtime_test.go b/backend/modules/llm/application/runtime_test.go index 109933ffa..099430c59 100644 --- a/backend/modules/llm/application/runtime_test.go +++ b/backend/modules/llm/application/runtime_test.go @@ -5,8 +5,10 @@ package application import ( "context" + "io" "testing" + "github.com/cloudwego/kitex/pkg/streaming" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -18,10 +20,14 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/runtime" "github.com/coze-dev/coze-loop/backend/modules/llm/application/convertor" "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity" + entitymocks "github.com/coze-dev/coze-loop/backend/modules/llm/domain/entity/mocks" "github.com/coze-dev/coze-loop/backend/modules/llm/domain/service" llmservicemocks "github.com/coze-dev/coze-loop/backend/modules/llm/domain/service/mocks" + llm_errorx "github.com/coze-dev/coze-loop/backend/modules/llm/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/unittest" + "github.com/pkg/errors" ) func Test_runtimeApp_Chat(t *testing.T) { @@ -174,10 +180,71 @@ func Test_runtimeApp_Chat(t *testing.T) { }, wantErr: nil, }, + { + name: "validate_fail", + fieldsGetter: func(ctrl *gomock.Controller) fields { + return fields{} + }, + args: args{ + ctx: context.Background(), + req: &runtime.ChatRequest{}, + }, + wantErr: errorx.NewByCode(llm_errorx.RequestNotValidCode), + }, + { + name: "get_model_fail", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManage := llmservicemocks.NewMockIManage(ctrl) + mockManage.EXPECT().GetModelByID(gomock.Any(), gomock.Any()).Return(nil, errors.New("err")) + return fields{manageSrv: mockManage} + }, + args: args{ + ctx: context.Background(), + req: req, + }, + wantErr: errors.New("err"), + }, + { + name: "model_invalid", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManage := llmservicemocks.NewMockIManage(ctrl) + mockManage.EXPECT().GetModelByID(gomock.Any(), gomock.Any()).Return(&entity.Model{}, nil) + return fields{manageSrv: mockManage} + }, + args: args{ + ctx: context.Background(), + req: req, + }, + wantErr: errorx.NewByCode(llm_errorx.ModelInvalidCode), + }, + { + name: "rate_limit_blocked", + fieldsGetter: func(ctrl *gomock.Controller) fields { + mockManage := llmservicemocks.NewMockIManage(ctrl) + mockLimiter := limitermocks.NewMockIRateLimiter(ctrl) + model := &entity.Model{ + ID: 1, Name: "model", Ability: &entity.Ability{}, + Protocol: "ark", + ProtocolConfig: &entity.ProtocolConfig{ + BaseURL: "http://test.com", + }, + ScenarioConfigs: map[entity.Scenario]*entity.ScenarioConfig{ + entity.ScenarioDefault: {Scenario: entity.ScenarioDefault, Quota: &entity.Quota{Qpm: 10}}, + }, + } + mockManage.EXPECT().GetModelByID(gomock.Any(), gomock.Any()).Return(model, nil) + mockLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{Allowed: false}, nil) + return fields{manageSrv: mockManage, rateLimiter: mockLimiter} + }, + args: args{ + ctx: context.Background(), + req: req, + }, + wantErr: errorx.NewByCode(llm_errorx.ModelQPMLimitCode), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -197,3 +264,171 @@ func Test_runtimeApp_Chat(t *testing.T) { }) } } + +func TestNewRuntimeApplication(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockFactory := limitermocks.NewMockIRateLimiterFactory(ctrl) + mockFactory.EXPECT().NewRateLimiter().Return(nil) + got := NewRuntimeApplication(nil, nil, nil, mockFactory) + assert.NotNil(t, got) +} + +func Test_runtimeApp_validateChatReq(t *testing.T) { + r := &runtimeApp{} + tests := []struct { + name string + req *runtime.ChatRequest + }{ + {"nil model config", &runtime.ChatRequest{}}, + {"nil biz param", &runtime.ChatRequest{ModelConfig: &druntime.ModelConfig{ModelID: 1}, Messages: []*druntime.Message{{}}}}, + {"missing workspace id", &runtime.ChatRequest{ + ModelConfig: &druntime.ModelConfig{ModelID: 1}, + Messages: []*druntime.Message{{}}, + BizParam: &druntime.BizParam{Scenario: ptr.Of(common.ScenarioPromptDebug), ScenarioEntityID: ptr.Of("id")}, + }}, + {"missing scenario", &runtime.ChatRequest{ + ModelConfig: &druntime.ModelConfig{ModelID: 1}, + Messages: []*druntime.Message{{}}, + BizParam: &druntime.BizParam{WorkspaceID: ptr.Of(int64(1)), ScenarioEntityID: ptr.Of("id")}, + }}, + {"missing entity id", &runtime.ChatRequest{ + ModelConfig: &druntime.ModelConfig{ModelID: 1}, + Messages: []*druntime.Message{{}}, + BizParam: &druntime.BizParam{WorkspaceID: ptr.Of(int64(1)), Scenario: ptr.Of(common.ScenarioPromptDebug)}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.validateChatReq(context.Background(), tt.req) + assert.Error(t, err) + }) + } +} + +type mockChatStreamServer struct{} + +func (m *mockChatStreamServer) Send(ctx context.Context, resp *runtime.ChatResponse) error { + return nil +} + +func (m *mockChatStreamServer) RecvMsg(ctx context.Context, msg interface{}) error { + return nil +} + +func (m *mockChatStreamServer) SendMsg(ctx context.Context, msg interface{}) error { + return nil +} + +func (m *mockChatStreamServer) SetHeader(header streaming.Header) error { + return nil +} + +func (m *mockChatStreamServer) SendHeader(header streaming.Header) error { + return nil +} + +func (m *mockChatStreamServer) SetTrailer(trailer streaming.Trailer) error { + return nil +} + +func Test_runtimeApp_ChatStream(t *testing.T) { + req := &runtime.ChatRequest{ + ModelConfig: &druntime.ModelConfig{ + ModelID: 1, + }, + Messages: []*druntime.Message{ + {Role: druntime.RoleUser, Content: ptr.Of("hi")}, + }, + BizParam: &druntime.BizParam{ + WorkspaceID: ptr.Of(int64(1)), + Scenario: ptr.Of(common.ScenarioPromptDebug), + ScenarioEntityID: ptr.Of("entity_id"), + }, + } + + model := &entity.Model{ + ID: 1, + Name: "model", + Ability: &entity.Ability{ + MultiModal: true, + AbilityMultiModal: &entity.AbilityMultiModal{ + Image: true, + AbilityImage: &entity.AbilityImage{ + URLEnabled: true, + }, + }, + }, + Protocol: "ark", + ProtocolConfig: &entity.ProtocolConfig{ + BaseURL: "http://test.com", + }, + ScenarioConfigs: map[entity.Scenario]*entity.ScenarioConfig{ + entity.ScenarioPromptDebug: { + Scenario: entity.ScenarioPromptDebug, + Quota: &entity.Quota{ + Qpm: 10, + }, + }, + }, + } + + t.Run("success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManage := llmservicemocks.NewMockIManage(ctrl) + mockRuntime := llmservicemocks.NewMockIRuntime(ctrl) + mockLimiter := limitermocks.NewMockIRateLimiter(ctrl) + mockStream := entitymocks.NewMockIStreamReader(ctrl) + + r := &runtimeApp{ + manageSrv: mockManage, + runtimeSrv: mockRuntime, + rateLimiter: mockLimiter, + } + + mockManage.EXPECT().GetModelByID(gomock.Any(), gomock.Any()).Return(model, nil) + mockLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{Allowed: true}, nil).AnyTimes() + mockRuntime.EXPECT().HandleMsgsPreCallModel(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockRuntime.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStream, nil) + mockStream.EXPECT().Recv().Return(&entity.Message{Content: "h"}, nil) + mockStream.EXPECT().Recv().Return(nil, io.EOF) + mockRuntime.EXPECT().CreateModelRequestRecord(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + + err := r.ChatStream(context.Background(), req, &mockChatStreamServer{}) + assert.NoError(t, err) + }) + + t.Run("validate_fail", func(t *testing.T) { + r := &runtimeApp{} + err := r.ChatStream(context.Background(), &runtime.ChatRequest{}, &mockChatStreamServer{}) + assert.Error(t, err) + }) + + t.Run("get_model_fail", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockManage := llmservicemocks.NewMockIManage(ctrl) + r := &runtimeApp{manageSrv: mockManage} + mockManage.EXPECT().GetModelByID(gomock.Any(), gomock.Any()).Return(nil, errors.New("err")) + err := r.ChatStream(context.Background(), req, &mockChatStreamServer{}) + assert.Error(t, err) + }) + + t.Run("stream_fail", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockManage := llmservicemocks.NewMockIManage(ctrl) + mockRuntime := llmservicemocks.NewMockIRuntime(ctrl) + mockLimiter := limitermocks.NewMockIRateLimiter(ctrl) + r := &runtimeApp{manageSrv: mockManage, runtimeSrv: mockRuntime, rateLimiter: mockLimiter} + mockManage.EXPECT().GetModelByID(gomock.Any(), gomock.Any()).Return(model, nil) + mockLimiter.EXPECT().AllowN(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&limiter.Result{Allowed: true}, nil).AnyTimes() + mockRuntime.EXPECT().HandleMsgsPreCallModel(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockRuntime.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("err")) + mockRuntime.EXPECT().CreateModelRequestRecord(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + err := r.ChatStream(context.Background(), req, &mockChatStreamServer{}) + assert.Error(t, err) + }) +} diff --git a/backend/modules/llm/domain/entity/convertor.go b/backend/modules/llm/domain/entity/convertor.go index e52b82834..93730a9e8 100644 --- a/backend/modules/llm/domain/entity/convertor.go +++ b/backend/modules/llm/domain/entity/convertor.go @@ -4,9 +4,9 @@ package entity import ( - "github.com/coze-dev/cozeloop-go/spec/tracespec" - + druntime "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" + "github.com/coze-dev/cozeloop-go/spec/tracespec" ) func MergeStreamMsgs(msgs []*Message) *Message { @@ -221,3 +221,25 @@ func OptionsToTrace(os []Option) *tracespec.ModelCallOption { } return res } + +func ConvertToParamValues(model *Model, paramValues []*druntime.ParamConfigValue) map[string]*ParamValue { + if model == nil || model.ParamConfig == nil { + return nil + } + schemaMap := make(map[string]*ParamSchema) + for _, item := range model.ParamConfig.ParamSchemas { + schemaMap[item.Name] = item + } + resp := make(map[string]*ParamValue) + for _, item := range paramValues { + if v, ok := schemaMap[item.GetName()]; ok { + resp[item.GetName()] = &ParamValue{ + Name: item.GetName(), + ParamType: v.Type, + Value: item.GetValue().GetValue(), + JsonPath: v.JsonPath, + } + } + } + return resp +} diff --git a/backend/modules/llm/domain/entity/convertor_test.go b/backend/modules/llm/domain/entity/convertor_test.go index cef58f7c8..3a37b9e98 100644 --- a/backend/modules/llm/domain/entity/convertor_test.go +++ b/backend/modules/llm/domain/entity/convertor_test.go @@ -6,10 +6,11 @@ package entity import ( "testing" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" + druntime "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/cozeloop-go/spec/tracespec" "github.com/stretchr/testify/assert" - - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" ) func TestMergeStreamMsgs(t *testing.T) { @@ -407,3 +408,135 @@ func TestToTraceModelInput(t *testing.T) { }) } } + +func TestConvertToParamValues(t *testing.T) { + type args struct { + model *Model + paramValues []*druntime.ParamConfigValue + } + tests := []struct { + name string + args args + want map[string]*ParamValue + }{ + { + name: "model is nil", + args: args{ + model: nil, + paramValues: []*druntime.ParamConfigValue{}, + }, + want: nil, + }, + { + name: "model param config is nil", + args: args{ + model: &Model{}, + paramValues: []*druntime.ParamConfigValue{}, + }, + want: nil, + }, + { + name: "param values is empty", + args: args{ + model: &Model{ + ParamConfig: &ParamConfig{ + ParamSchemas: []*ParamSchema{ + { + Name: "temperature", + Type: ParamTypeFloat, + JsonPath: "temperature", + }, + }, + }, + }, + paramValues: []*druntime.ParamConfigValue{}, + }, + want: map[string]*ParamValue{}, + }, + { + name: "param values with non-existent param", + args: args{ + model: &Model{ + ParamConfig: &ParamConfig{ + ParamSchemas: []*ParamSchema{ + { + Name: "temperature", + Type: ParamTypeFloat, + JsonPath: "temperature", + }, + }, + }, + }, + paramValues: []*druntime.ParamConfigValue{ + { + Name: ptr.Of("non_existent"), + Value: &manage.ParamOption{ + Value: ptr.Of("0.5"), + }, + }, + }, + }, + want: map[string]*ParamValue{}, + }, + { + name: "param values with existing param", + args: args{ + model: &Model{ + ParamConfig: &ParamConfig{ + ParamSchemas: []*ParamSchema{ + { + Name: "temperature", + Type: ParamTypeFloat, + JsonPath: "temperature", + }, + { + Name: "max_tokens", + Type: ParamTypeInt, + JsonPath: "max_tokens", + }, + }, + }, + }, + paramValues: []*druntime.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Value: &manage.ParamOption{ + Value: ptr.Of("0.5"), + }, + }, + { + Name: ptr.Of("max_tokens"), + Value: &manage.ParamOption{ + Value: ptr.Of("1000"), + }, + }, + { + Name: ptr.Of("non_existent"), + Value: &manage.ParamOption{ + Value: ptr.Of("test"), + }, + }, + }, + }, + want: map[string]*ParamValue{ + "temperature": { + Name: "temperature", + ParamType: ParamTypeFloat, + Value: "0.5", + JsonPath: "temperature", + }, + "max_tokens": { + Name: "max_tokens", + ParamType: ParamTypeInt, + Value: "1000", + JsonPath: "max_tokens", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, ConvertToParamValues(tt.args.model, tt.args.paramValues), "ConvertToParamValues(%v, %v)", tt.args.model, tt.args.paramValues) + }) + } +} diff --git a/backend/modules/llm/domain/entity/eino_convertor_test.go b/backend/modules/llm/domain/entity/eino_convertor_test.go index 031b3a925..e5a9d18ee 100644 --- a/backend/modules/llm/domain/entity/eino_convertor_test.go +++ b/backend/modules/llm/domain/entity/eino_convertor_test.go @@ -229,3 +229,78 @@ func TestToDOToolCalls(t *testing.T) { }) } } + +func TestEinoConvertor_MoreFromDO(t *testing.T) { + t.Run("FromDOMessages", func(t *testing.T) { + dos := []*Message{ + { + Role: RoleUser, + Content: "hello", + MultiModalContent: []*ChatMessagePart{ + {Type: ChatMessagePartTypeImageURL, ImageURL: &ChatMessageImageURL{URL: "url"}}, + }, + }, + } + res := FromDOMessages(dos) + assert.Len(t, res, 1) + assert.Equal(t, schema.User, res[0].Role) + }) + + t.Run("FromDOImageURL_nil", func(t *testing.T) { + assert.Nil(t, FromDOImageURL(nil)) + }) + + t.Run("FromDOOptions", func(t *testing.T) { + temp := float32(0.7) + maxT := 100 + opts := &Options{ + Temperature: &temp, + MaxTokens: &maxT, + } + res, err := FromDOOptions(opts) + assert.NoError(t, err) + assert.Len(t, res, 2) + }) + + t.Run("FromDOTools", func(t *testing.T) { + toolDef := `{"type": "object"}` + dos := []*ToolInfo{ + { + Name: "t1", + Desc: "d1", + ToolDefType: ToolDefTypeOpenAPIV3, + Def: toolDef, + }, + } + res, err := FromDOTools(dos) + assert.NoError(t, err) + assert.Len(t, res, 1) + + _, err = FromDOTools([]*ToolInfo{{ToolDefType: "unknown"}}) + assert.Error(t, err) + }) +} + +func TestEinoConvertor_MoreToDO(t *testing.T) { + t.Run("ToDOMessage_nil", func(t *testing.T) { + res, err := ToDOMessage(nil) + assert.NoError(t, err) + assert.Nil(t, res) + }) + + t.Run("ToDOMultiContents", func(t *testing.T) { + cms := []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeText, Text: "txt"}, + {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: "url"}}, + } + res := ToDOMultiContents(cms) + assert.Len(t, res, 2) + assert.Equal(t, ChatMessagePartTypeText, res[0].Type) + assert.Equal(t, ChatMessagePartTypeImageURL, res[1].Type) + }) + + t.Run("GetReasoningContent", func(t *testing.T) { + msg := &schema.Message{} + assert.Equal(t, "", GetReasoningContent(msg)) + }) +} diff --git a/backend/modules/llm/domain/entity/manage.go b/backend/modules/llm/domain/entity/manage.go index 3b08d96bc..e7ce589c8 100644 --- a/backend/modules/llm/domain/entity/manage.go +++ b/backend/modules/llm/domain/entity/manage.go @@ -7,9 +7,8 @@ import ( "strconv" "github.com/bytedance/sonic" - "github.com/pkg/errors" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/pkg/errors" ) type Model struct { @@ -20,11 +19,24 @@ type Model struct { Ability *Ability `json:"ability" yaml:"ability" mapstructure:"ability"` // 模型能力 - Frame Frame `json:"frame" yaml:"frame" mapstructure:"frame"` // 该模型使用的外部框架,目前只支持eino - Protocol Protocol `json:"protocol" yaml:"protocol" mapstructure:"protocol"` // 该模型的协议类型,如ark/deepseek/openai等 - ProtocolConfig *ProtocolConfig `json:"protocol_config" yaml:"protocol_config" mapstructure:"protocol_config"` // 该模型的协议配置 - ScenarioConfigs map[Scenario]*ScenarioConfig `json:"scenario_configs" yaml:"scenario_configs" mapstructure:"scenario_configs"` // 该模型的场景配置 - ParamConfig *ParamConfig `json:"param_config" yaml:"param_config" mapstructure:"param_config"` // 该模型的参数配置 + Frame Frame `json:"frame" yaml:"frame" mapstructure:"frame"` // 该模型使用的外部框架,目前只支持eino + Protocol Protocol `json:"protocol" yaml:"protocol" mapstructure:"protocol"` // 该模型的协议类型,如ark/deepseek/openai等 + ProtocolConfig *ProtocolConfig `json:"protocol_config" yaml:"protocol_config" mapstructure:"protocol_config"` // 该模型的协议配置 + ScenarioConfigs map[Scenario]*ScenarioConfig `json:"scenario_configs" yaml:"scenario_configs" mapstructure:"scenario_configs"` // 该模型的场景配置 + ParamConfig *ParamConfig `json:"param_config" yaml:"param_config" mapstructure:"param_config"` // 该模型的参数配置 + Identification string `json:"identification" yaml:"identification"` + Series *Series `json:"series" yaml:"series"` + Visibility *Visibility `json:"visibility" yaml:"visibility"` + Icon string `json:"icon" yaml:"icon" mapstructure:"icon"` // 模型图标 + Tags []string `json:"tags" yaml:"tags" mapstructure:"tags"` // 模型标签 + Status ModelStatus `json:"status" yaml:"status" mapstructure:"status"` // 模型状态 + OriginalModelURL string `json:"original_model_url" yaml:"original_model_url" mapstructure:"original_model_url"` // 模型跳转链接 + PresetModel bool `json:"preset_model" yaml:"preset_model" mapstructure:"preset_model"` // 是否为预置模型 + + CreatedBy string `json:"created_by" yaml:"created_by" mapstructure:"created_by"` // 创建人 + CreatedAt int64 `json:"created_at" yaml:"created_at" mapstructure:"created_at"` // 创建时间 + UpdatedBy string `json:"updated_by" yaml:"updated_by" mapstructure:"updated_by"` // 更新人 + UpdatedAt int64 `json:"updated_at" yaml:"updated_at" mapstructure:"updated_at"` // 更新时间 } func (m *Model) Valid() error { @@ -143,6 +155,27 @@ type Ability struct { JsonMode bool `json:"json_mode" yaml:"json_mode" mapstructure:"json_mode"` MultiModal bool `json:"multi_modal" yaml:"multi_modal" mapstructure:"multi_modal"` AbilityMultiModal *AbilityMultiModal `json:"ability_multi_modal" yaml:"ability_multi_modal" mapstructure:"ability_multi_modal"` + Thinking bool `json:"thinking" mapstructure:"thinking"` +} + +func (a *Ability) GetAbilityEnums() []AbilityEnum { + var resp []AbilityEnum + if a == nil { + return resp + } + if a.FunctionCall { + resp = append(resp, AbilityEnumFunctionCall) + } + if a.JsonMode { + resp = append(resp, AbilityEnumJsonMode) + } + if a.MultiModal { + resp = append(resp, AbilityEnumMultiModal) + } + if a.Thinking { + resp = append(resp, AbilityEnumThinking) + } + return resp } type AbilityMultiModal struct { @@ -327,6 +360,14 @@ type ParamSchema struct { Max string `json:"max" yaml:"max" mapstructure:"max"` DefaultValue string `json:"default_value" yaml:"default_value" mapstructure:"default_value"` Options []*ParamOption `json:"options" yaml:"options" mapstructure:"options"` + Properties []*ParamSchema `json:"properties" mapstructrue:"properties"` + JsonPath string `json:"json_path" mapstructrue:"json_path"` + Reaction *Reaction `json:"reaction" mapstructrue:"reaction"` +} + +type Reaction struct { + Dependency string `json:"dependency"` + Visible string `json:"visible"` } type ParamOption struct { @@ -341,6 +382,8 @@ const ( ParamTypeInt ParamType = "int" ParamTypeBoolean ParamType = "boolean" ParamTypeString ParamType = "string" + ParamTypeVoid ParamType = "void" + ParamTypeObject ParamType = "object" ) type Frame string @@ -353,17 +396,35 @@ const ( type Protocol string const ( - ProtocolArk Protocol = "ark" - ProtocolOpenAI Protocol = "openai" - ProtocolDeepseek Protocol = "deepseek" - ProtocolClaude Protocol = "claude" - ProtocolOllama Protocol = "ollama" - ProtocolGemini Protocol = "gemini" - ProtocolQwen Protocol = "qwen" - ProtocolQianfan Protocol = "qianfan" - ProtocolArkBot Protocol = "arkbot" + ProtocolUndefined Protocol = "undefined" + ProtocolArk Protocol = "ark" + ProtocolOpenAI Protocol = "openai" + ProtocolDeepseek Protocol = "deepseek" + ProtocolClaude Protocol = "claude" + ProtocolOllama Protocol = "ollama" + ProtocolGemini Protocol = "gemini" + ProtocolQwen Protocol = "qwen" + ProtocolQianfan Protocol = "qianfan" + ProtocolArkBot Protocol = "arkbot" +) + +type Family string + +const ( + FamilyUndefined Family = "undefined" + FamilySeed Family = "seed" + FamilyGLM Family = "glm" + FamilyKimi Family = "kimi" + FamilyDeepSeek Family = "deepseek" + FamilyDoubao Family = "doubao" ) +type Series struct { + Name string `json:"name"` + Icon string `json:"icon"` + Family Family `json:"family"` +} + type ListModelReq struct { WorkspaceID *int64 Scenario *Scenario @@ -375,3 +436,42 @@ type GetModelReq struct { WorkspaceID *int64 ModelID int64 } + +type VisibleMode string + +const ( + VisibleModeUndefined VisibleMode = "undefined" + VisibleModelAll VisibleMode = "all" + VisibleModelSpecified VisibleMode = "specified" + VisibleModelDefault VisibleMode = "default" +) + +type Visibility struct { + Mode VisibleMode `json:"mode"` + SpaceIDs []int64 `json:"space_ids"` // model为specified时有效 +} + +type ModelStatus string + +const ( + ModelStatusUndefined ModelStatus = "undefined" + ModelStatusEnabled ModelStatus = "enabled" + ModelStatusDisabled ModelStatus = "disabled" +) + +type ListModelsFilter struct { + NameLike *string `json:"name_like,omitempty"` + Families []Family `json:"families,omitempty"` + ModelStatuses []ModelStatus `json:"model_statuses,omitempty"` + Abilities []AbilityEnum `json:"abilities,omitempty"` +} + +type AbilityEnum string + +const ( + AbilityEnumUndefined AbilityEnum = "undefined" + AbilityEnumFunctionCall AbilityEnum = "function_call" + AbilityEnumMultiModal AbilityEnum = "multi_modal" + AbilityEnumJsonMode AbilityEnum = "json_mode" + AbilityEnumThinking AbilityEnum = "thinking" +) diff --git a/backend/modules/llm/domain/entity/manage_test.go b/backend/modules/llm/domain/entity/manage_test.go index c98814745..8c1de4aee 100644 --- a/backend/modules/llm/domain/entity/manage_test.go +++ b/backend/modules/llm/domain/entity/manage_test.go @@ -6,9 +6,8 @@ package entity import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/stretchr/testify/assert" ) func TestModel_Available(t *testing.T) { @@ -415,3 +414,111 @@ func TestParamConfig_GetCommonParamDefaultVal(t *testing.T) { }) } } + +func TestAbility_GetAbilityEnums(t *testing.T) { + type fields struct { + ability *Ability + } + tests := []struct { + name string + fields fields + want []AbilityEnum + }{ + { + name: "ability is nil", + fields: fields{ + ability: nil, + }, + want: nil, + }, + { + name: "ability has no enabled abilities", + fields: fields{ + ability: &Ability{ + FunctionCall: false, + JsonMode: false, + MultiModal: false, + Thinking: false, + }, + }, + want: nil, + }, + { + name: "ability has function call enabled", + fields: fields{ + ability: &Ability{ + FunctionCall: true, + JsonMode: false, + MultiModal: false, + Thinking: false, + }, + }, + want: []AbilityEnum{AbilityEnumFunctionCall}, + }, + { + name: "ability has json mode enabled", + fields: fields{ + ability: &Ability{ + FunctionCall: false, + JsonMode: true, + MultiModal: false, + Thinking: false, + }, + }, + want: []AbilityEnum{AbilityEnumJsonMode}, + }, + { + name: "ability has multi modal enabled", + fields: fields{ + ability: &Ability{ + FunctionCall: false, + JsonMode: false, + MultiModal: true, + Thinking: false, + }, + }, + want: []AbilityEnum{AbilityEnumMultiModal}, + }, + { + name: "ability has thinking enabled", + fields: fields{ + ability: &Ability{ + FunctionCall: false, + JsonMode: false, + MultiModal: false, + Thinking: true, + }, + }, + want: []AbilityEnum{AbilityEnumThinking}, + }, + { + name: "ability has multiple abilities enabled", + fields: fields{ + ability: &Ability{ + FunctionCall: true, + JsonMode: true, + MultiModal: false, + Thinking: false, + }, + }, + want: []AbilityEnum{AbilityEnumFunctionCall, AbilityEnumJsonMode}, + }, + { + name: "ability has all abilities enabled", + fields: fields{ + ability: &Ability{ + FunctionCall: true, + JsonMode: true, + MultiModal: true, + Thinking: true, + }, + }, + want: []AbilityEnum{AbilityEnumFunctionCall, AbilityEnumJsonMode, AbilityEnumMultiModal, AbilityEnumThinking}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, tt.fields.ability.GetAbilityEnums(), "GetAbilityEnums()") + }) + } +} diff --git a/backend/modules/llm/domain/entity/runtime.go b/backend/modules/llm/domain/entity/runtime.go index a2d3afc42..d2940e50e 100644 --- a/backend/modules/llm/domain/entity/runtime.go +++ b/backend/modules/llm/domain/entity/runtime.go @@ -3,7 +3,11 @@ package entity -import "time" +import ( + "fmt" + "strconv" + "time" +) type Message struct { Role Role `json:"role"` @@ -239,3 +243,25 @@ const ( type ResponseFormat struct { Type ResponseFormatType `json:"type,omitempty"` } + +type ParamValue struct { + Name string `json:"name"` + ParamType ParamType `json:"param_type"` + Value string `json:"value"` + JsonPath string `json:"json_path"` +} + +func (p *ParamValue) GetValue() (any, error) { + switch p.ParamType { + case ParamTypeBoolean: + return strconv.ParseBool(p.Value) + case ParamTypeFloat: + return strconv.ParseFloat(p.Value, 64) + case ParamTypeInt: + return strconv.ParseInt(p.Value, 10, 64) + case ParamTypeString: + return p.Value, nil + default: + return nil, fmt.Errorf("unsupported param type: %s", p.ParamType) + } +} diff --git a/backend/modules/llm/domain/entity/runtime_option.go b/backend/modules/llm/domain/entity/runtime_option.go index b18ee5702..a3ed97f57 100644 --- a/backend/modules/llm/domain/entity/runtime_option.go +++ b/backend/modules/llm/domain/entity/runtime_option.go @@ -28,6 +28,8 @@ type Options struct { FrequencyPenalty *float32 // Parameters is the extra parameters for the model. Parameters map[string]string + // ParamValues + ParamValues map[string]*ParamValue } type Option struct { @@ -177,3 +179,11 @@ func WithParameters(p map[string]string) Option { }, } } + +func WithParamValues(p map[string]*ParamValue) Option { + return Option{ + apply: func(opts *Options) { + opts.ParamValues = p + }, + } +} diff --git a/backend/modules/llm/domain/entity/runtime_test.go b/backend/modules/llm/domain/entity/runtime_test.go new file mode 100644 index 000000000..a05b738bb --- /dev/null +++ b/backend/modules/llm/domain/entity/runtime_test.go @@ -0,0 +1,146 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package entity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMessage_TokenGetters(t *testing.T) { + t.Run("nil_message", func(t *testing.T) { + var m *Message + assert.Equal(t, 0, m.GetInputToken()) + assert.Equal(t, 0, m.GetOutputToken()) + }) + + t.Run("nil_response_meta", func(t *testing.T) { + m := &Message{} + assert.Equal(t, 0, m.GetInputToken()) + assert.Equal(t, 0, m.GetOutputToken()) + }) + + t.Run("nil_usage", func(t *testing.T) { + m := &Message{ResponseMeta: &ResponseMeta{}} + assert.Equal(t, 0, m.GetInputToken()) + assert.Equal(t, 0, m.GetOutputToken()) + }) + + t.Run("success", func(t *testing.T) { + m := &Message{ + ResponseMeta: &ResponseMeta{ + Usage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 20, + }, + }, + } + assert.Equal(t, 10, m.GetInputToken()) + assert.Equal(t, 20, m.GetOutputToken()) + }) +} + +func TestMessage_MultiModal(t *testing.T) { + t.Run("has_multi_modal", func(t *testing.T) { + assert.False(t, (*Message)(nil).HasMultiModalContent()) + assert.False(t, (&Message{}).HasMultiModalContent()) + + m := &Message{ + MultiModalContent: []*ChatMessagePart{ + {Type: ChatMessagePartTypeText, Text: "text"}, + }, + } + assert.False(t, m.HasMultiModalContent()) + + m.MultiModalContent = append(m.MultiModalContent, &ChatMessagePart{ + Type: ChatMessagePartTypeImageURL, + }) + assert.True(t, m.HasMultiModalContent()) + }) + + t.Run("get_image_count_and_max_size", func(t *testing.T) { + m := &Message{ + MultiModalContent: []*ChatMessagePart{ + { + Type: ChatMessagePartTypeImageURL, + ImageURL: &ChatMessageImageURL{ + URL: "http://example.com/a.jpg", + }, + }, + { + Type: ChatMessagePartTypeImageURL, + ImageURL: &ChatMessageImageURL{ + URL: "base64data", // simplified base64 + MIMEType: "image/png", + }, + }, + }, + } + hasUrl, hasBinary, cnt, maxSize := m.GetImageCountAndMaxSize() + assert.True(t, hasUrl) + assert.True(t, hasBinary) + assert.Equal(t, int64(2), cnt) + assert.True(t, maxSize > 0) + }) +} + +func TestChatMessagePart_Checks(t *testing.T) { + t.Run("is_multi_modal", func(t *testing.T) { + assert.False(t, (*ChatMessagePart)(nil).IsMultiModal()) + assert.False(t, (&ChatMessagePart{Type: ChatMessagePartTypeText}).IsMultiModal()) + assert.True(t, (&ChatMessagePart{Type: ChatMessagePartTypeImageURL}).IsMultiModal()) + }) + + t.Run("is_url_binary", func(t *testing.T) { + assert.False(t, (*ChatMessagePart)(nil).IsURL()) + assert.False(t, (*ChatMessagePart)(nil).IsBinary()) + + p := &ChatMessagePart{ + Type: ChatMessagePartTypeImageURL, + ImageURL: &ChatMessageImageURL{ + URL: "url", + }, + } + assert.True(t, p.IsURL()) + assert.False(t, p.IsBinary()) + + p.ImageURL.MIMEType = "image/png" + assert.False(t, p.IsURL()) + assert.True(t, p.IsBinary()) + }) +} + +func TestParamValue_GetValue(t *testing.T) { + tests := []struct { + name string + paramType ParamType + value string + want interface{} + wantErr bool + }{ + {"bool_true", ParamTypeBoolean, "true", true, false}, + {"bool_invalid", ParamTypeBoolean, "not_bool", nil, true}, + {"float", ParamTypeFloat, "1.23", 1.23, false}, + {"int", ParamTypeInt, "123", int64(123), false}, + {"string", ParamTypeString, "hello", "hello", false}, + {"unsupported", ParamType("unknown"), "val", nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &ParamValue{ + ParamType: tt.paramType, + Value: tt.value, + } + got, err := p.GetValue() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index 517ba2000..01a56a588 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -342,7 +342,6 @@ func (h *TraceHubServiceImpl) fetchSpans(ctx context.Context, listParam *repo.Li logs.CtxInfo(ctx, "Completed listing spans, task_id=%d", sub.t.ID) return spans, "", nil } - return spans, result.PageToken, nil } diff --git a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_set.thrift b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_set.thrift index 7c1933fd6..64ba18719 100644 --- a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_set.thrift +++ b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_set.thrift @@ -365,30 +365,68 @@ struct GetEvaluationSetItemFieldResponse { service EvaluationSetService { // 基本信息管理 - CreateEvaluationSetResponse CreateEvaluationSet(1: CreateEvaluationSetRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets") - UpdateEvaluationSetResponse UpdateEvaluationSet(1: UpdateEvaluationSetRequest req) (api.category="evaluation_set", api.patch = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id") - DeleteEvaluationSetResponse DeleteEvaluationSet(1: DeleteEvaluationSetRequest req) (api.category="evaluation_set", api.delete = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id"), - GetEvaluationSetResponse GetEvaluationSet(1: GetEvaluationSetRequest req) (api.category="evaluation_set", api.get = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id"), - ListEvaluationSetsResponse ListEvaluationSets(1: ListEvaluationSetsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/list"), - CreateEvaluationSetWithImportResponse CreateEvaluationSetWithImport(1: CreateEvaluationSetWithImportRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/create_with_import") - ParseImportSourceFileResponse ParseImportSourceFile(1: ParseImportSourceFileRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/parse_import_source_file") + CreateEvaluationSetResponse CreateEvaluationSet(1: CreateEvaluationSetRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets", api.op_type = 'create', api.tag = 'volc-agentkit,open' + ) + UpdateEvaluationSetResponse UpdateEvaluationSet(1: UpdateEvaluationSetRequest req) ( + api.category="evaluation_set", api.patch = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id", api.op_type = 'update', api.tag = 'volc-agentkit,open' + ) + DeleteEvaluationSetResponse DeleteEvaluationSet(1: DeleteEvaluationSetRequest req) ( + api.category="evaluation_set", api.delete = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id", api.op_type = 'delete', api.tag = 'volc-agentkit,open' + ) + GetEvaluationSetResponse GetEvaluationSet(1: GetEvaluationSetRequest req) ( + api.category="evaluation_set", api.get = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id", api.op_type = 'query', api.tag = 'volc-agentkit,open' + ) + ListEvaluationSetsResponse ListEvaluationSets(1: ListEvaluationSetsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/list", api.op_type = 'list', api.tag = 'volc-agentkit,open' + ) + CreateEvaluationSetWithImportResponse CreateEvaluationSetWithImport(1: CreateEvaluationSetWithImportRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/create_with_import", api.op_type = 'create', api.tag = 'volc-agentkit' + ) + ParseImportSourceFileResponse ParseImportSourceFile(1: ParseImportSourceFileRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/parse_import_source_file", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // 版本管理 - CreateEvaluationSetVersionResponse CreateEvaluationSetVersion(1: CreateEvaluationSetVersionRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/versions"), - GetEvaluationSetVersionResponse GetEvaluationSetVersion(1: GetEvaluationSetVersionRequest req) (api.category="evaluation_set", api.get = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/versions/:version_id"), - ListEvaluationSetVersionsResponse ListEvaluationSetVersions(1: ListEvaluationSetVersionsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/versions/list"), - BatchGetEvaluationSetVersionsResponse BatchGetEvaluationSetVersions(1: BatchGetEvaluationSetVersionsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_set_versions/batch_get"), + CreateEvaluationSetVersionResponse CreateEvaluationSetVersion(1: CreateEvaluationSetVersionRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/versions", api.op_type = 'create', api.tag = 'volc-agentkit,open' + ) + GetEvaluationSetVersionResponse GetEvaluationSetVersion(1: GetEvaluationSetVersionRequest req) ( + api.category="evaluation_set", api.get = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/versions/:version_id", api.op_type = 'query', api.tag = 'volc-agentkit' + ) + ListEvaluationSetVersionsResponse ListEvaluationSetVersions(1: ListEvaluationSetVersionsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/versions/list", api.op_type = 'list', api.tag = 'volc-agentkit,open' + ) + BatchGetEvaluationSetVersionsResponse BatchGetEvaluationSetVersions(1: BatchGetEvaluationSetVersionsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_set_versions/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // 字段管理 - UpdateEvaluationSetSchemaResponse UpdateEvaluationSetSchema(1: UpdateEvaluationSetSchemaRequest req) (api.category="evaluation_set", api.put = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/schema"), + UpdateEvaluationSetSchemaResponse UpdateEvaluationSetSchema(1: UpdateEvaluationSetSchemaRequest req) ( + api.category="evaluation_set", api.put = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/schema", api.op_type = 'update', api.tag = 'volc-agentkit,open' + ) // 数据管理 - BatchCreateEvaluationSetItemsResponse BatchCreateEvaluationSetItems(1: BatchCreateEvaluationSetItemsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/batch_create") - UpdateEvaluationSetItemResponse UpdateEvaluationSetItem(1: UpdateEvaluationSetItemRequest req) (api.category="evaluation_set", api.put = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/:item_id") - BatchDeleteEvaluationSetItemsResponse BatchDeleteEvaluationSetItems(1: BatchDeleteEvaluationSetItemsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/batch_delete") - ListEvaluationSetItemsResponse ListEvaluationSetItems(1: ListEvaluationSetItemsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/list") - BatchGetEvaluationSetItemsResponse BatchGetEvaluationSetItems(1: BatchGetEvaluationSetItemsRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/batch_get") - ClearEvaluationSetDraftItemResponse ClearEvaluationSetDraftItem(1: ClearEvaluationSetDraftItemRequest req) (api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/clear") - GetEvaluationSetItemFieldResponse GetEvaluationSetItemField(1: GetEvaluationSetItemFieldRequest req) (api.get = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/:item_pk/field") + BatchCreateEvaluationSetItemsResponse BatchCreateEvaluationSetItems(1: BatchCreateEvaluationSetItemsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/batch_create", api.op_type = 'query', api.tag = 'volc-agentkit,open' + ) + UpdateEvaluationSetItemResponse UpdateEvaluationSetItem(1: UpdateEvaluationSetItemRequest req) ( + api.category="evaluation_set", api.put = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/:item_id", api.op_type = 'update', api.tag = 'volc-agentkit,open' + ) + BatchDeleteEvaluationSetItemsResponse BatchDeleteEvaluationSetItems(1: BatchDeleteEvaluationSetItemsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/batch_delete", api.op_type = 'delete', api.tag = 'volc-agentkit,open' + ) + ListEvaluationSetItemsResponse ListEvaluationSetItems(1: ListEvaluationSetItemsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/list", api.op_type = 'list', api.tag = 'volc-agentkit,open' + ) + BatchGetEvaluationSetItemsResponse BatchGetEvaluationSetItems(1: BatchGetEvaluationSetItemsRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit' + ) + ClearEvaluationSetDraftItemResponse ClearEvaluationSetDraftItem(1: ClearEvaluationSetDraftItemRequest req) ( + api.category="evaluation_set", api.post = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/clear", api.op_type = 'update', api.tag = 'volc-agentkit' + ) + GetEvaluationSetItemFieldResponse GetEvaluationSetItemField(1: GetEvaluationSetItemFieldRequest req) ( + api.category="evaluation_set", api.get = "/api/evaluation/v1/evaluation_sets/:evaluation_set_id/items/:item_pk/field", api.op_type = 'query', api.tag = 'volc-agentkit,open' + ) } diff --git a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_target.thrift b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_target.thrift index cd3c9b577..b43011157 100644 --- a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_target.thrift +++ b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.eval_target.thrift @@ -271,31 +271,51 @@ struct MockEvalTargetOutputResponse { service EvalTargetService { // 创建评测对象 - CreateEvalTargetResponse CreateEvalTarget(1: CreateEvalTargetRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets") + CreateEvalTargetResponse CreateEvalTarget(1: CreateEvalTargetRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets", api.op_type = 'create', api.tag = 'volc-agentkit' + ) // 根据source target获取评测对象信息 - BatchGetEvalTargetsBySourceResponse BatchGetEvalTargetsBySource(1: BatchGetEvalTargetsBySourceRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/batch_get_by_source") + BatchGetEvalTargetsBySourceResponse BatchGetEvalTargetsBySource(1: BatchGetEvalTargetsBySourceRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/batch_get_by_source", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // 获取评测对象+版本 - GetEvalTargetVersionResponse GetEvalTargetVersion(1: GetEvalTargetVersionRequest request) (api.category="eval_target", api.get = "/api/evaluation/v1/eval_target_versions/:eval_target_version_id") + GetEvalTargetVersionResponse GetEvalTargetVersion(1: GetEvalTargetVersionRequest request) ( + api.category="eval_target", api.get = "/api/evaluation/v1/eval_target_versions/:eval_target_version_id", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // 批量获取+版本 - BatchGetEvalTargetVersionsResponse BatchGetEvalTargetVersions(1: BatchGetEvalTargetVersionsRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_target_versions/batch_get") + BatchGetEvalTargetVersionsResponse BatchGetEvalTargetVersions(1: BatchGetEvalTargetVersionsRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_target_versions/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // Source评测对象列表 - ListSourceEvalTargetsResponse ListSourceEvalTargets(1: ListSourceEvalTargetsRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/list_source") + ListSourceEvalTargetsResponse ListSourceEvalTargets(1: ListSourceEvalTargetsRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/list_source", api.op_type = 'list', api.tag = 'volc-agentkit' + ) // Source评测对象版本列表 - ListSourceEvalTargetVersionsResponse ListSourceEvalTargetVersions(1: ListSourceEvalTargetVersionsRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/list_source_version") - BatchGetSourceEvalTargetsResponse BatchGetSourceEvalTargets (1: BatchGetSourceEvalTargetsRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/batch_get_source") + ListSourceEvalTargetVersionsResponse ListSourceEvalTargetVersions(1: ListSourceEvalTargetVersionsRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/list_source_version", api.op_type = 'list', api.tag = 'volc-agentkit' + ) + BatchGetSourceEvalTargetsResponse BatchGetSourceEvalTargets (1: BatchGetSourceEvalTargetsRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/batch_get_source", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // 搜索自定义评测对象 SearchCustomEvalTargetResponse SearchCustomEvalTarget(1: SearchCustomEvalTargetRequest req) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/search_custom") // 执行 ExecuteEvalTargetResponse ExecuteEvalTarget(1: ExecuteEvalTargetRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/:eval_target_id/versions/:eval_target_version_id/execute") AsyncExecuteEvalTargetResponse AsyncExecuteEvalTarget(1: AsyncExecuteEvalTargetRequest request) - GetEvalTargetRecordResponse GetEvalTargetRecord(1: GetEvalTargetRecordRequest request) (api.category="eval_target", api.get = "/api/evaluation/v1/eval_target_records/:eval_target_record_id") - BatchGetEvalTargetRecordsResponse BatchGetEvalTargetRecords(1: BatchGetEvalTargetRecordsRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_target_records/batch_get") + GetEvalTargetRecordResponse GetEvalTargetRecord(1: GetEvalTargetRecordRequest request) ( + api.category="eval_target", api.get = "/api/evaluation/v1/eval_target_records/:eval_target_record_id", api.op_type = 'query', api.tag = 'volc-agentkit' + ) + BatchGetEvalTargetRecordsResponse BatchGetEvalTargetRecords(1: BatchGetEvalTargetRecordsRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_target_records/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit' + ) // debug DebugEvalTargetResponse DebugEvalTarget(1: DebugEvalTargetRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/debug") AsyncDebugEvalTargetResponse AsyncDebugEvalTarget(1: AsyncDebugEvalTargetRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/async_debug") // mock输出数据 - MockEvalTargetOutputResponse MockEvalTargetOutput(1: MockEvalTargetOutputRequest request) (api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/mock_output") -} (api.js_conv="true" ) + MockEvalTargetOutputResponse MockEvalTargetOutput(1: MockEvalTargetOutputRequest request) ( + api.category="eval_target", api.post = "/api/evaluation/v1/eval_targets/mock_output", api.op_type = 'query', api.tag = 'volc-agentkit' + ) +} diff --git a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.evaluator.thrift b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.evaluator.thrift index 72b756f24..e102032a1 100644 --- a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.evaluator.thrift +++ b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.evaluator.thrift @@ -57,6 +57,8 @@ struct GetEvaluatorResponse { struct CreateEvaluatorRequest { 1: required evaluator.Evaluator evaluator (api.body='evaluator') + 2: optional i64 workspace_id (api.body='workspace_id', api.js_conv='true', go.tag='json:"workspace_id"') + 100: optional string cid (api.body='cid') 255: optional base.Base Base @@ -474,56 +476,110 @@ struct ListEvaluatorTagsResponse { service EvaluatorService { // 评估器 - ListEvaluatorsResponse ListEvaluators(1: ListEvaluatorsRequest request) (api.post= "/api/evaluation/v1/evaluators/list") // 按查询条件查询evaluator - BatchGetEvaluatorsResponse BatchGetEvaluators(1: BatchGetEvaluatorsRequest request) (api.post= "/api/evaluation/v1/evaluators/batch_get") // 按id批量查询evaluator - GetEvaluatorResponse GetEvaluator(1: GetEvaluatorRequest request) (api.get= "/api/evaluation/v1/evaluators/:evaluator_id") // 按id单个查询evaluator - CreateEvaluatorResponse CreateEvaluator(1: CreateEvaluatorRequest request) (api.post= "/api/evaluation/v1/evaluators") // 创建evaluator - UpdateEvaluatorResponse UpdateEvaluator(1: UpdateEvaluatorRequest request) (api.patch= "/api/evaluation/v1/evaluators/:evaluator_id") // 修改evaluator元信息 - UpdateEvaluatorDraftResponse UpdateEvaluatorDraft(1: UpdateEvaluatorDraftRequest request) (api.patch= "/api/evaluation/v1/evaluators/:evaluator_id/update_draft") // 修改evaluator草稿 - DeleteEvaluatorResponse DeleteEvaluator(1: DeleteEvaluatorRequest request) (api.delete= "/api/evaluation/v1/evaluators/:evaluator_id") // 批量删除evaluator - CheckEvaluatorNameResponse CheckEvaluatorName(1: CheckEvaluatorNameRequest request) (api.post= "/api/evaluation/v1/evaluators/check_name") // 校验evaluator名称是否重复 + ListEvaluatorsResponse ListEvaluators(1: ListEvaluatorsRequest request) ( + api.post= "/api/evaluation/v1/evaluators/list", api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按查询条件查询evaluator + BatchGetEvaluatorsResponse BatchGetEvaluators(1: BatchGetEvaluatorsRequest request) ( + api.post= "/api/evaluation/v1/evaluators/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按id批量查询evaluator + GetEvaluatorResponse GetEvaluator(1: GetEvaluatorRequest request) ( + api.get= "/api/evaluation/v1/evaluators/:evaluator_id", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按id单个查询evaluator + CreateEvaluatorResponse CreateEvaluator(1: CreateEvaluatorRequest request) ( + api.post= "/api/evaluation/v1/evaluators", api.op_type = 'create', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 创建evaluator + UpdateEvaluatorResponse UpdateEvaluator(1: UpdateEvaluatorRequest request) ( + api.patch= "/api/evaluation/v1/evaluators/:evaluator_id", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 修改evaluator元信息 + UpdateEvaluatorDraftResponse UpdateEvaluatorDraft(1: UpdateEvaluatorDraftRequest request) ( + api.patch= "/api/evaluation/v1/evaluators/:evaluator_id/update_draft", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 修改evaluator草稿 + DeleteEvaluatorResponse DeleteEvaluator(1: DeleteEvaluatorRequest request) ( + api.delete= "/api/evaluation/v1/evaluators/:evaluator_id", api.op_type = 'delete', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 批量删除evaluator + CheckEvaluatorNameResponse CheckEvaluatorName(1: CheckEvaluatorNameRequest request) ( + api.post= "/api/evaluation/v1/evaluators/check_name", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 校验evaluator名称是否重复 // 评估器版本 - ListEvaluatorVersionsResponse ListEvaluatorVersions(1: ListEvaluatorVersionsRequest request) (api.post= "/api/evaluation/v1/evaluators/:evaluator_id/versions/list") // 按evaluator id查询evaluator version - GetEvaluatorVersionResponse GetEvaluatorVersion(1: GetEvaluatorVersionRequest request) (api.get= "/api/evaluation/v1/evaluators_versions/:evaluator_version_id") // 按版本id单个查询evaluator version - BatchGetEvaluatorVersionsResponse BatchGetEvaluatorVersions(1: BatchGetEvaluatorVersionsRequest request) (api.post= "/api/evaluation/v1/evaluators_versions/batch_get") // 按版本id批量查询evaluator version - SubmitEvaluatorVersionResponse SubmitEvaluatorVersion(1: SubmitEvaluatorVersionRequest request) (api.post= "/api/evaluation/v1/evaluators/:evaluator_id/submit_version") // 提交evaluator版本 + ListEvaluatorVersionsResponse ListEvaluatorVersions(1: ListEvaluatorVersionsRequest request) ( + api.post= "/api/evaluation/v1/evaluators/:evaluator_id/versions/list", api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按evaluator id查询evaluator version + GetEvaluatorVersionResponse GetEvaluatorVersion(1: GetEvaluatorVersionRequest request) ( + api.get= "/api/evaluation/v1/evaluators_versions/:evaluator_version_id", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按版本id单个查询evaluator version + BatchGetEvaluatorVersionsResponse BatchGetEvaluatorVersions(1: BatchGetEvaluatorVersionsRequest request) ( + api.post= "/api/evaluation/v1/evaluators_versions/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按版本id批量查询evaluator version + SubmitEvaluatorVersionResponse SubmitEvaluatorVersion(1: SubmitEvaluatorVersionRequest request) ( + api.post= "/api/evaluation/v1/evaluators/:evaluator_id/submit_version", api.op_type = 'create', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 提交evaluator版本 // 评估器预置模版 - ListTemplatesResponse ListTemplates(1: ListTemplatesRequest request) (api.post= "/api/evaluation/v1/evaluators/list_template") // 获取内置评估器模板列表(不含具体内容) - GetTemplateInfoResponse GetTemplateInfo(1: GetTemplateInfoRequest request) (api.post= "/api/evaluation/v1/evaluators/get_template_info") // 按key单个查询内置评估器模板详情 - GetDefaultPromptEvaluatorToolsResponse GetDefaultPromptEvaluatorTools(1: GetDefaultPromptEvaluatorToolsRequest req) (api.post="/api/evaluation/v1/evaluators/default_prompt_evaluator_tools") // 获取prompt evaluator tools配置 + ListTemplatesResponse ListTemplates(1: ListTemplatesRequest request) ( + api.post= "/api/evaluation/v1/evaluators/list_template", api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 获取内置评估器模板列表(不含具体内容) + GetTemplateInfoResponse GetTemplateInfo(1: GetTemplateInfoRequest request) ( + api.post= "/api/evaluation/v1/evaluators/get_template_info", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 按key单个查询内置评估器模板详情 + GetDefaultPromptEvaluatorToolsResponse GetDefaultPromptEvaluatorTools(1: GetDefaultPromptEvaluatorToolsRequest req) ( + api.post="/api/evaluation/v1/evaluators/default_prompt_evaluator_tools", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 获取prompt evaluator tools配置 // 评估器执行 - RunEvaluatorResponse RunEvaluator(1: RunEvaluatorRequest req) (api.post="/api/evaluation/v1/evaluators_versions/:evaluator_version_id/run")// evaluator 运行 - DebugEvaluatorResponse DebugEvaluator(1: DebugEvaluatorRequest req) (api.post="/api/evaluation/v1/evaluators/debug")// evaluator 调试 - BatchDebugEvaluatorResponse BatchDebugEvaluator(1: BatchDebugEvaluatorRequest req) (api.post="/api/evaluation/v1/evaluators/batch_debug")// evaluator 调试 + RunEvaluatorResponse RunEvaluator(1: RunEvaluatorRequest req) ( + api.post="/api/evaluation/v1/evaluators_versions/:evaluator_version_id/run", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator' + )// evaluator 运行 + DebugEvaluatorResponse DebugEvaluator(1: DebugEvaluatorRequest req) ( + api.post="/api/evaluation/v1/evaluators/debug", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator', api.timeout = '300000' + )// evaluator 调试 + BatchDebugEvaluatorResponse BatchDebugEvaluator(1: BatchDebugEvaluatorRequest req) ( + api.post="/api/evaluation/v1/evaluators/batch_debug", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator', api.timeout = '300000' + )// evaluator 调试 // 评估器执行结果 - UpdateEvaluatorRecordResponse UpdateEvaluatorRecord(1: UpdateEvaluatorRecordRequest req) (api.patch="/api/evaluation/v1/evaluator_records/:evaluator_record_id") // 修正evaluator运行分数 + UpdateEvaluatorRecordResponse UpdateEvaluatorRecord(1: UpdateEvaluatorRecordRequest req) ( + api.patch="/api/evaluation/v1/evaluator_records/:evaluator_record_id", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 修正evaluator运行分数 GetEvaluatorRecordResponse GetEvaluatorRecord(1: GetEvaluatorRecordRequest req) BatchGetEvaluatorRecordsResponse BatchGetEvaluatorRecords(1: BatchGetEvaluatorRecordsRequest req) // 评估器验证 - ValidateEvaluatorResponse ValidateEvaluator(1: ValidateEvaluatorRequest request) (api.post="/api/evaluation/v1/evaluators/validate") + ValidateEvaluatorResponse ValidateEvaluator(1: ValidateEvaluatorRequest request) ( + api.post="/api/evaluation/v1/evaluators/validate", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 查询评估器模板 - ListTemplatesV2Response ListTemplatesV2(1: ListTemplatesV2Request request) (api.post="/api/evaluation/v1/evaluator_template/list") - GetTemplateV2Response GetTemplateV2(1: GetTemplateV2Request request) (api.get="/api/evaluation/v1/evaluator_template/:evaluator_template_id") + ListTemplatesV2Response ListTemplatesV2(1: ListTemplatesV2Request request) ( + api.post="/api/evaluation/v1/evaluator_template/list", api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) + GetTemplateV2Response GetTemplateV2(1: GetTemplateV2Request request) ( + api.get="/api/evaluation/v1/evaluator_template/:evaluator_template_id", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 创建评估器模板 - CreateEvaluatorTemplateResponse CreateEvaluatorTemplate(1: CreateEvaluatorTemplateRequest request) (api.post="/api/evaluation/v1/evaluator_template") + CreateEvaluatorTemplateResponse CreateEvaluatorTemplate(1: CreateEvaluatorTemplateRequest request) ( + api.post="/api/evaluation/v1/evaluator_template", api.op_type = 'create', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 更新评估器模板 - UpdateEvaluatorTemplateResponse UpdateEvaluatorTemplate(1: UpdateEvaluatorTemplateRequest request) (api.patch="/api/evaluation/v1/evaluator_template/:evaluator_template_id") + UpdateEvaluatorTemplateResponse UpdateEvaluatorTemplate(1: UpdateEvaluatorTemplateRequest request) ( + api.patch="/api/evaluation/v1/evaluator_template/:evaluator_template_id", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 删除 - DeleteEvaluatorTemplateResponse DeleteEvaluatorTemplate(1: DeleteEvaluatorTemplateRequest request) (api.delete="/api/evaluation/v1/evaluator_template/:evaluator_template_id") + DeleteEvaluatorTemplateResponse DeleteEvaluatorTemplate(1: DeleteEvaluatorTemplateRequest request) ( + api.delete="/api/evaluation/v1/evaluator_template/:evaluator_template_id", api.op_type = 'delete', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) // 调试预置评估器 - DebugBuiltinEvaluatorResponse DebugBuiltinEvaluator(1: DebugBuiltinEvaluatorRequest req) (api.post="/api/evaluation/v1/evaluators/debug_builtin")// 调试预置评估器 + DebugBuiltinEvaluatorResponse DebugBuiltinEvaluator(1: DebugBuiltinEvaluatorRequest req) ( + api.post="/api/evaluation/v1/evaluators/debug_builtin", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'evaluator', api.timeout = '300000' + )// 调试预置评估器 // 更新预置评估器tag UpdateBuiltinEvaluatorTagsResponse UpdateBuiltinEvaluatorTags(1: UpdateBuiltinEvaluatorTagsRequest req) // 查询Tag - ListEvaluatorTagsResponse ListEvaluatorTags(1: ListEvaluatorTagsRequest req) (api.post="/api/evaluation/v1/evaluators/list_tags") + ListEvaluatorTagsResponse ListEvaluatorTags(1: ListEvaluatorTagsRequest req) ( + api.post="/api/evaluation/v1/evaluators/list_tags", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'evaluator' + ) -} (api.js_conv="true" ) \ No newline at end of file +} \ No newline at end of file diff --git a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.expt.thrift b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.expt.thrift index ce9b0be7d..4f3f329c3 100644 --- a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.expt.thrift +++ b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.expt.thrift @@ -260,7 +260,7 @@ struct BatchGetExperimentAggrResultRequest { } struct BatchGetExperimentAggrResultResponse { - 1: optional list expt_aggregate_results (api.body = 'expt_aggregate_result') + 1: optional list expt_aggregate_result (api.body = 'expt_aggregate_result') 255: base.BaseResp BaseResp } @@ -470,7 +470,7 @@ struct ListExperimentTemplatesResponse { struct CheckExperimentTemplateNameRequest { 1: required i64 workspace_id (api.body='workspace_id', api.js_conv='true', go.tag='json:"workspace_id"') - 2: required string name (api.body='name', api.js_conv='true', go.tag='json:"name"') + 2: required string name (api.body='name') 3: optional i64 template_id (api.body='template_id', api.js_conv='true', go.tag='json:"template_id"') 255: optional base.Base Base @@ -604,7 +604,7 @@ struct GetExptResultExportRecordRequest { } struct GetExptResultExportRecordResponse { - 1: optional expt.ExptResultExportRecord expt_result_export_record (api.body = "expt_result_export_records") + 1: optional expt.ExptResultExportRecord expt_result_export_records (api.body = "expt_result_export_records") 255: base.BaseResp BaseResp } @@ -722,38 +722,65 @@ struct GetAnalysisRecordFeedbackVoteResponse { service ExperimentService { - CheckExperimentNameResponse CheckExperimentName(1: CheckExperimentNameRequest req) (api.post = '/api/evaluation/v1/experiments/check_name') + CheckExperimentNameResponse CheckExperimentName(1: CheckExperimentNameRequest req) ( + api.post = '/api/evaluation/v1/experiments/check_name', api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'experiment' + ) // CreateExperiment 只创建,不提交运行 CreateExperimentResponse CreateExperiment(1: CreateExperimentRequest req) // SubmitExperiment 创建并提交运行 - SubmitExperimentResponse SubmitExperiment(1: SubmitExperimentRequest req) (api.post = '/api/evaluation/v1/experiments/submit') + SubmitExperimentResponse SubmitExperiment(1: SubmitExperimentRequest req) ( + api.post = '/api/evaluation/v1/experiments/submit', api.op_type = 'create', api.tag = 'volc-agentkit,open', api.category = 'experiment' + ) - BatchGetExperimentsResponse BatchGetExperiments(1: BatchGetExperimentsRequest req) (api.post = '/api/evaluation/v1/experiments/batch_get') + BatchGetExperimentsResponse BatchGetExperiments(1: BatchGetExperimentsRequest req) ( + api.post = '/api/evaluation/v1/experiments/batch_get', api.op_type = 'query', api.tag = 'volc-agentkit,open', api.category = 'experiment' + ) - ListExperimentsResponse ListExperiments(1: ListExperimentsRequest req) (api.post = '/api/evaluation/v1/experiments/list') + ListExperimentsResponse ListExperiments(1: ListExperimentsRequest req) ( + api.post = '/api/evaluation/v1/experiments/list', api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'experiment' + ) - UpdateExperimentResponse UpdateExperiment(1: UpdateExperimentRequest req) (api.patch = '/api/evaluation/v1/experiments/:expt_id') + UpdateExperimentResponse UpdateExperiment(1: UpdateExperimentRequest req) ( + api.patch = '/api/evaluation/v1/experiments/:expt_id', api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'experiment' + ) - DeleteExperimentResponse DeleteExperiment(1: DeleteExperimentRequest req) (api.delete = '/api/evaluation/v1/experiments/:expt_id') + DeleteExperimentResponse DeleteExperiment(1: DeleteExperimentRequest req) ( + api.delete = '/api/evaluation/v1/experiments/:expt_id', api.op_type = 'delete', api.tag = 'volc-agentkit', api.category = 'experiment' + ) - BatchDeleteExperimentsResponse BatchDeleteExperiments(1: BatchDeleteExperimentsRequest req) (api.delete = '/api/evaluation/v1/experiments/batch_delete') + BatchDeleteExperimentsResponse BatchDeleteExperiments(1: BatchDeleteExperimentsRequest req) ( + api.delete = '/api/evaluation/v1/experiments/batch_delete', api.op_type = 'delete', api.tag = 'volc-agentkit', api.category = 'experiment' + ) - CloneExperimentResponse CloneExperiment(1: CloneExperimentRequest req) (api.post = '/api/evaluation/v1/experiments/:expt_id/clone') + CloneExperimentResponse CloneExperiment(1: CloneExperimentRequest req) ( + api.post = '/api/evaluation/v1/experiments/:expt_id/clone', api.op_type = 'create', api.tag = 'volc-agentkit', api.category = 'experiment' + ) // RunExperiment 运行已创建的实验 RunExperimentResponse RunExperiment(1: RunExperimentRequest req) - RetryExperimentResponse RetryExperiment(1: RetryExperimentRequest req) (api.post = '/api/evaluation/v1/experiments/:expt_id/retry') + RetryExperimentResponse RetryExperiment(1: RetryExperimentRequest req) ( + api.post = '/api/evaluation/v1/experiments/:expt_id/retry', api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'experiment' + ) - KillExperimentResponse KillExperiment(1: KillExperimentRequest req) (api.post = '/api/evaluation/v1/experiments/:expt_id/kill') + KillExperimentResponse KillExperiment(1: KillExperimentRequest req) ( + api.post = '/api/evaluation/v1/experiments/:expt_id/kill', api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'experiment' + ) // MGetExperimentResult 获取实验结果 - BatchGetExperimentResultResponse BatchGetExperimentResult(1: BatchGetExperimentResultRequest req) (api.post = "/api/evaluation/v1/experiments/results/batch_get") + BatchGetExperimentResultResponse BatchGetExperimentResult(1: BatchGetExperimentResultRequest req) ( + api.post = "/api/evaluation/v1/experiments/results/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit,open', api.category = 'experiment' + ) - CalculateExperimentAggrResultResponse CalculateExperimentAggrResult(1: CalculateExperimentAggrResultRequest req) (api.post = "/api/evaluation/v1/experiments/:expt_id/aggr_results") - BatchGetExperimentAggrResultResponse BatchGetExperimentAggrResult(1: BatchGetExperimentAggrResultRequest req) (api.post = "/api/evaluation/v1/experiments/aggr_results/batch_get") + CalculateExperimentAggrResultResponse CalculateExperimentAggrResult(1: CalculateExperimentAggrResultRequest req) ( + api.post = "/api/evaluation/v1/experiments/:expt_id/aggr_results", api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + + BatchGetExperimentAggrResultResponse BatchGetExperimentAggrResult(1: BatchGetExperimentAggrResultRequest req) ( + api.post = "/api/evaluation/v1/experiments/aggr_results/batch_get", api.op_type = 'query', api.tag = 'volc-agentkit,open', api.category = 'experiment' + ) // 在线实验 InvokeExperimentResponse InvokeExperiment(1: InvokeExperimentRequest req) @@ -772,12 +799,18 @@ service ExperimentService { UpdateAnnotateRecordResp UpdateAnnotateRecord(1: UpdateAnnotateRecordReq req) (api.post = "/api/evaluation/v1/experiments/:expt_id/annotate_record/update") // 报告下载 - ExportExptResultResponse ExportExptResult(1: ExportExptResultRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/results/export") - ListExptResultExportRecordResponse ListExptResultExportRecord(1: ListExptResultExportRecordRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/export_records/list") - GetExptResultExportRecordResponse GetExptResultExportRecord(1: GetExptResultExportRecordRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/export_records/:export_id") + ExportExptResultResponse ExportExptResult(1: ExportExptResultRequest req) ( + api.post="/api/evaluation/v1/experiments/:expt_id/results/export", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + ListExptResultExportRecordResponse ListExptResultExportRecord(1: ListExptResultExportRecordRequest req) ( + api.post="/api/evaluation/v1/experiments/:expt_id/export_records/list", api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + GetExptResultExportRecordResponse GetExptResultExportRecord(1: GetExptResultExportRecordRequest req) ( + api.post="/api/evaluation/v1/experiments/:expt_id/export_records/:export_id", api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'experiment' + ) // 报告分析 - InsightAnalysisExperimentResponse InsightAnalysisExperiment(1: InsightAnalysisExperimentRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/insight_analysis") + InsightAnalysisExperimentResponse InsightAnalysisExperiment(1: InsightAnalysisExperimentRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/insight_analysis" ) ListExptInsightAnalysisRecordResponse ListExptInsightAnalysisRecord(1: ListExptInsightAnalysisRecordRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/insight_analysis_records/list") DeleteExptInsightAnalysisRecordResponse DeleteExptInsightAnalysisRecord(1: DeleteExptInsightAnalysisRecordRequest req) (api.delete="/api/evaluation/v1/experiments/:expt_id/insight_analysis_records/:insight_analysis_record_id") GetExptInsightAnalysisRecordResponse GetExptInsightAnalysisRecord(1: GetExptInsightAnalysisRecordRequest req) (api.post="/api/evaluation/v1/experiments/:expt_id/insight_analysis_records/:insight_analysis_record_id") @@ -786,12 +819,26 @@ service ExperimentService { GetAnalysisRecordFeedbackVoteResponse GetAnalysisRecordFeedbackVote(1: GetAnalysisRecordFeedbackVoteRequest req) (api.get="/api/evaluation/v1/experiments/insight_analysis_records/:insight_analysis_record_id/feedback_vote") // 实验模板 - CreateExperimentTemplateResponse CreateExperimentTemplate(1: CreateExperimentTemplateRequest req) (api.post = '/api/evaluation/v1/experiment_templates') - BatchGetExperimentTemplateResponse BatchGetExperimentTemplate(1: BatchGetExperimentTemplateRequest req) (api.post = '/api/evaluation/v1/experiment_templates/batch_get') - UpdateExperimentTemplateMetaResponse UpdateExperimentTemplateMeta(1: UpdateExperimentTemplateMetaRequest req) (api.post = '/api/evaluation/v1/experiment_templates/update_meta') - UpdateExperimentTemplateResponse UpdateExperimentTemplate(1: UpdateExperimentTemplateRequest req) (api.patch = '/api/evaluation/v1/experiment_templates/:template_id') // 更新实验模板(不允许修改关联的评测对象 / 评测集,仅允许修改默认版本、映射、评估器与配置) - DeleteExperimentTemplateResponse DeleteExperimentTemplate(1: DeleteExperimentTemplateRequest req) (api.delete = '/api/evaluation/v1/experiment_templates/:template_id') - ListExperimentTemplatesResponse ListExperimentTemplates(1: ListExperimentTemplatesRequest req) (api.post = '/api/evaluation/v1/experiment_templates/list') - CheckExperimentTemplateNameResponse CheckExperimentTemplateName(1: CheckExperimentTemplateNameRequest req) (api.post = '/api/evaluation/v1/experiment_templates/check_name') + CreateExperimentTemplateResponse CreateExperimentTemplate(1: CreateExperimentTemplateRequest req) ( + api.post = '/api/evaluation/v1/experiment_templates', api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + BatchGetExperimentTemplateResponse BatchGetExperimentTemplate(1: BatchGetExperimentTemplateRequest req) ( + api.post = '/api/evaluation/v1/experiment_templates/batch_get', api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + UpdateExperimentTemplateMetaResponse UpdateExperimentTemplateMeta(1: UpdateExperimentTemplateMetaRequest req) ( + api.post = '/api/evaluation/v1/experiment_templates/update_meta', api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + UpdateExperimentTemplateResponse UpdateExperimentTemplate(1: UpdateExperimentTemplateRequest req) ( + api.patch = '/api/evaluation/v1/experiment_templates/:template_id', api.op_type = 'update', api.tag = 'volc-agentkit', api.category = 'experiment' + ) // 更新实验模板(不允许修改关联的评测对象 / 评测集,仅允许修改默认版本、映射、评估器与配置) + DeleteExperimentTemplateResponse DeleteExperimentTemplate(1: DeleteExperimentTemplateRequest req) ( + api.delete = '/api/evaluation/v1/experiment_templates/:template_id', api.op_type = 'delete', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + ListExperimentTemplatesResponse ListExperimentTemplates(1: ListExperimentTemplatesRequest req) ( + api.post = '/api/evaluation/v1/experiment_templates/list', api.op_type = 'list', api.tag = 'volc-agentkit', api.category = 'experiment' + ) + CheckExperimentTemplateNameResponse CheckExperimentTemplateName(1: CheckExperimentTemplateNameRequest req) ( + api.post = '/api/evaluation/v1/experiment_templates/check_name', api.op_type = 'query', api.tag = 'volc-agentkit', api.category = 'experiment' + ) } diff --git a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.openapi.thrift b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.openapi.thrift index f728a9f69..c8b3883b4 100644 --- a/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.openapi.thrift +++ b/idl/thrift/coze/loop/evaluation/coze.loop.evaluation.openapi.thrift @@ -461,43 +461,43 @@ struct GetExperimentAggrResultOpenAPIData { service EvaluationOpenAPIService { // 评测集接口 // 创建评测集 - CreateEvaluationSetOApiResponse CreateEvaluationSetOApi(1: CreateEvaluationSetOApiRequest req) (api.tag="openapi", api.post = "/v1/loop/evaluation/evaluation_sets") + CreateEvaluationSetOApiResponse CreateEvaluationSetOApi(1: CreateEvaluationSetOApiRequest req) (api.category="openapi", api.post = "/v1/loop/evaluation/evaluation_sets") // 获取评测集详情 - GetEvaluationSetOApiResponse GetEvaluationSetOApi(1: GetEvaluationSetOApiRequest req) (api.tag="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id") + GetEvaluationSetOApiResponse GetEvaluationSetOApi(1: GetEvaluationSetOApiRequest req) (api.category="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id") // 更新评测集详情 - UpdateEvaluationSetOApiResponse UpdateEvaluationSetOApi(1: UpdateEvaluationSetOApiRequest req) (api.tag="openapi", api.put = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id") + UpdateEvaluationSetOApiResponse UpdateEvaluationSetOApi(1: UpdateEvaluationSetOApiRequest req) (api.category="openapi", api.put = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id") // 删除评测集 - DeleteEvaluationSetOApiResponse DeleteEvaluationSetOApi(1: DeleteEvaluationSetOApiRequest req) (api.tag="openapi", api.delete = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id") + DeleteEvaluationSetOApiResponse DeleteEvaluationSetOApi(1: DeleteEvaluationSetOApiRequest req) (api.category="openapi", api.delete = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id") // 查询评测集列表 - ListEvaluationSetsOApiResponse ListEvaluationSetsOApi(1: ListEvaluationSetsOApiRequest req) (api.tag="openapi", api.get = "/v1/loop/evaluation/evaluation_sets") + ListEvaluationSetsOApiResponse ListEvaluationSetsOApi(1: ListEvaluationSetsOApiRequest req) (api.category="openapi", api.get = "/v1/loop/evaluation/evaluation_sets") // 创建评测集版本 - CreateEvaluationSetVersionOApiResponse CreateEvaluationSetVersionOApi(1: CreateEvaluationSetVersionOApiRequest req) (api.tag="openapi", api.post = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/versions") + CreateEvaluationSetVersionOApiResponse CreateEvaluationSetVersionOApi(1: CreateEvaluationSetVersionOApiRequest req) (api.category="openapi", api.post = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/versions") // 获取评测集版本列表 - ListEvaluationSetVersionsOApiResponse ListEvaluationSetVersionsOApi(1: ListEvaluationSetVersionsOApiRequest req) (api.tag="evaluation_set", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/versions") + ListEvaluationSetVersionsOApiResponse ListEvaluationSetVersionsOApi(1: ListEvaluationSetVersionsOApiRequest req) (api.category="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/versions") // 批量添加评测集数据 - BatchCreateEvaluationSetItemsOApiResponse BatchCreateEvaluationSetItemsOApi(1: BatchCreateEvaluationSetItemsOApiRequest req) (api.tag="openapi", api.post = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") + BatchCreateEvaluationSetItemsOApiResponse BatchCreateEvaluationSetItemsOApi(1: BatchCreateEvaluationSetItemsOApiRequest req) (api.category="openapi", api.post = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") // 批量更新评测集数据 - BatchUpdateEvaluationSetItemsOApiResponse BatchUpdateEvaluationSetItemsOApi(1: BatchUpdateEvaluationSetItemsOApiRequest req) (api.tag="openapi", api.put = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") + BatchUpdateEvaluationSetItemsOApiResponse BatchUpdateEvaluationSetItemsOApi(1: BatchUpdateEvaluationSetItemsOApiRequest req) (api.category="openapi", api.put = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") // 批量删除评测集数据 - BatchDeleteEvaluationSetItemsOApiResponse BatchDeleteEvaluationSetItemsOApi(1: BatchDeleteEvaluationSetItemsOApiRequest req) (api.tag="openapi", api.delete = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") + BatchDeleteEvaluationSetItemsOApiResponse BatchDeleteEvaluationSetItemsOApi(1: BatchDeleteEvaluationSetItemsOApiRequest req) (api.category="openapi", api.delete = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") // 查询评测集特定版本数据 - ListEvaluationSetVersionItemsOApiResponse ListEvaluationSetVersionItemsOApi(1: ListEvaluationSetVersionItemsOApiRequest req) (api.tag="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") + ListEvaluationSetVersionItemsOApiResponse ListEvaluationSetVersionItemsOApi(1: ListEvaluationSetVersionItemsOApiRequest req) (api.category="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items") // 查询评测集某个filed值,用于获取超长文本的内容 - GetEvaluationItemFieldOApiResponse GetEvaluationItemFieldOApi(1: GetEvaluationItemFieldOApiRequest req) (api.tag="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items/:item_id/field") + GetEvaluationItemFieldOApiResponse GetEvaluationItemFieldOApi(1: GetEvaluationItemFieldOApiRequest req) (api.category="openapi", api.get = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/items/:item_id/field") // 更新评测集字段信息 - UpdateEvaluationSetSchemaOApiResponse UpdateEvaluationSetSchemaOApi(1: UpdateEvaluationSetSchemaOApiRequest req) (api.tag="openapi", api.put = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/schema"), + UpdateEvaluationSetSchemaOApiResponse UpdateEvaluationSetSchemaOApi(1: UpdateEvaluationSetSchemaOApiRequest req) (api.category="openapi", api.put = "/v1/loop/evaluation/evaluation_sets/:evaluation_set_id/schema"), // 评测目标调用结果上报接口 ReportEvalTargetInvokeResultResponse ReportEvalTargetInvokeResult(1: ReportEvalTargetInvokeResultRequest req) (api.category="openapi", api.post = "/v1/loop/eval_targets/result") // 评测实验接口 // 创建评测实验 - SubmitExperimentOApiResponse SubmitExperimentOApi(1: SubmitExperimentOApiRequest req) (api.tag="openapi", api.post = "/v1/loop/evaluation/experiments") + SubmitExperimentOApiResponse SubmitExperimentOApi(1: SubmitExperimentOApiRequest req) (api.category="openapi", api.post = "/v1/loop/evaluation/experiments") // 获取评测实验 - GetExperimentsOApiResponse GetExperimentsOApi(1: GetExperimentsOApiRequest req) (api.tag="openapi", api.get = '/v1/loop/evaluation/experiments/:experiment_id') + GetExperimentsOApiResponse GetExperimentsOApi(1: GetExperimentsOApiRequest req) (api.category="openapi", api.get = '/v1/loop/evaluation/experiments/:experiment_id') // 查询评测实验结果 - ListExperimentResultOApiResponse ListExperimentResultOApi(1: ListExperimentResultOApiRequest req) (api.tag="openapi", api.post = "/v1/loop/evaluation/experiments/:experiment_id/results") + ListExperimentResultOApiResponse ListExperimentResultOApi(1: ListExperimentResultOApiRequest req) (api.category="openapi", api.post = "/v1/loop/evaluation/experiments/:experiment_id/results") // 获取聚合结果 - GetExperimentAggrResultOApiResponse GetExperimentAggrResultOApi(1: GetExperimentAggrResultOApiRequest req) (api.tag="openapi", api.post = "/v1/loop/evaluation/experiments/:experiment_id/aggr_results") + GetExperimentAggrResultOApiResponse GetExperimentAggrResultOApi(1: GetExperimentAggrResultOApiRequest req) (api.category="openapi", api.post = "/v1/loop/evaluation/experiments/:experiment_id/aggr_results") } diff --git a/idl/thrift/coze/loop/evaluation/domain/common.thrift b/idl/thrift/coze/loop/evaluation/domain/common.thrift index bca9cd8ce..4ed4dbe11 100644 --- a/idl/thrift/coze/loop/evaluation/domain/common.thrift +++ b/idl/thrift/coze/loop/evaluation/domain/common.thrift @@ -1,6 +1,7 @@ namespace go coze.loop.evaluation.domain.common include "../../data/domain/dataset.thrift" +include "../../llm/domain/manage.thrift" typedef string ContentType(ts.enum="true") @@ -117,6 +118,9 @@ struct ModelConfig { 3: optional double temperature 4: optional i32 max_tokens 5: optional double top_p + 6: optional manage.Protocol protocol + 7: optional string identification + 8: optional bool preset_model 50: optional string json_ext } diff --git a/idl/thrift/coze/loop/evaluation/domain/eval_target.thrift b/idl/thrift/coze/loop/evaluation/domain/eval_target.thrift index 8ad39c246..f97dfad74 100644 --- a/idl/thrift/coze/loop/evaluation/domain/eval_target.thrift +++ b/idl/thrift/coze/loop/evaluation/domain/eval_target.thrift @@ -53,6 +53,8 @@ enum EvalTargetType { CozeWorkflow = 4 VolcengineAgent = 5 // 火山智能体 CustomRPCServer = 6 // 自定义RPC服务 for内场 + + VolcengineAgentAgentkit = 7 // 火山智能体Agentkit } // Agent协议类型 @@ -125,6 +127,7 @@ struct VolcengineAgent { 11: optional string description // DTO使用,不存数据库 12: optional list volcengine_agent_endpoints // DTO使用,不存数据库 13: optional VolcengineAgentProtocol protocol // 注册协议 + 14: optional string runtime_id 100: optional common.BaseInfo base_info } diff --git a/idl/thrift/coze/loop/evaluation/domain/evaluator.thrift b/idl/thrift/coze/loop/evaluation/domain/evaluator.thrift index 645718073..7d5dd5a13 100644 --- a/idl/thrift/coze/loop/evaluation/domain/evaluator.thrift +++ b/idl/thrift/coze/loop/evaluation/domain/evaluator.thrift @@ -127,7 +127,7 @@ struct EvaluatorContent { // 明确有顺序的 evaluator 与版本映射元素 struct EvaluatorIDVersionItem { 1: optional i64 evaluator_id (api.js_conv = 'true', go.tag = 'json:"evaluator_id"') - 2: optional string version (api.js_conv = 'true', go.tag = 'json:"version"') + 2: optional string version 3: optional EvaluatorRunConfig run_config (go.tag = 'json:"run_config"') 4: optional i64 evaluator_version_id (api.js_conv = 'true', go.tag = 'json:"evaluator_version_id"') 5: optional double score_weight (go.tag = 'json:"score_weight"') diff --git a/idl/thrift/coze/loop/evaluation/domain/expt.thrift b/idl/thrift/coze/loop/evaluation/domain/expt.thrift index 01f8d34b3..fb247339b 100644 --- a/idl/thrift/coze/loop/evaluation/domain/expt.thrift +++ b/idl/thrift/coze/loop/evaluation/domain/expt.thrift @@ -570,9 +570,11 @@ struct ExptResultExportRecord { 5: optional common.BaseInfo base_info 6: optional i64 start_time (api.js_conv='true', go.tag='json:"start_time"') 7: optional i64 end_time (api.js_conv='true', go.tag='json:"end_time"') + // deprecated, cause not match snake name 8: optional string URL 9: optional bool expired 10: optional RunError error + 11: optional string url } // 分析任务状态 diff --git a/idl/thrift/coze/loop/llm/coze.loop.llm.manage.thrift b/idl/thrift/coze/loop/llm/coze.loop.llm.manage.thrift index 447a6c79c..39ceadb26 100644 --- a/idl/thrift/coze/loop/llm/coze.loop.llm.manage.thrift +++ b/idl/thrift/coze/loop/llm/coze.loop.llm.manage.thrift @@ -4,11 +4,24 @@ include "../../../base.thrift" include "./domain/manage.thrift" include "./domain/common.thrift" +struct Filter { + 1: optional string name_like + 2: optional list families + 3: optional list statuses + 4: optional list abilities +} + struct ListModelsRequest { 1: optional i64 workspace_id (api.js_conv='true', vt.not_nil='true', vt.gt='0', go.tag='json:"workspace_id"') 2: optional common.Scenario scenario + 3: optional Filter filter + 4: optional bool preset_model // 是否为预置模型 + + + 100: optional string cookie (api.header='cookie') 127: optional i32 page_size 128: optional string page_token + 129: optional i32 page 255: optional base.Base Base } @@ -25,6 +38,9 @@ struct ListModelsResponse { struct GetModelRequest { 1: optional i64 workspace_id (api.js_conv='true', vt.not_nil='true', vt.gt='0', go.tag='json:"workspace_id"') 2: optional i64 model_id (api.js_conv='true', api.path='model_id', go.tag='json:"model_id"') + 3: optional string identification + 4: optional manage.Protocol protocol + 5: optional bool preset_model // 是否为预置模型 255: optional base.Base Base } diff --git a/idl/thrift/coze/loop/llm/domain/manage.thrift b/idl/thrift/coze/loop/llm/domain/manage.thrift index a3462c26e..a00862d38 100644 --- a/idl/thrift/coze/loop/llm/domain/manage.thrift +++ b/idl/thrift/coze/loop/llm/domain/manage.thrift @@ -12,8 +12,44 @@ struct Model { 7: optional ProtocolConfig protocol_config 8: optional map scenario_configs 9: optional ParamConfig param_config + 10: optional string identification // 模型表示 (name, endpoint) + 11: optional Series series // 模型 + 12: optional Visibility visibility + 13: optional string icon // 模型图标 + 14: optional list tags //模型标签 + 15: optional ModelStatus status // 模型状态 + 16: optional string original_model_url // 模型跳转链接 + 17: optional bool preset_model // 是否为预置模型 + + 100: optional string created_by + 101: optional i64 created_at + 102: optional string updated_by + 103: optional i64 updated_at +} + +struct Series { + 1: optional string name // series name + 2: optional string icon // series icon url + 3: optional Family family // family name } +struct Visibility { + 1: optional VisibleMode mode + 2: optional list spaceIDs // Mode为Specified有效,配置为除模型所属空间外的其他空间 +} + +struct ProviderInfo { + 1: optional MaaSInfo maas_info +} + +struct MaaSInfo { + 1: optional string host + 2: optional string region + 3: optional string baseURL // v3 sdk + 4: optional string customizationJobsID // 精调模型任务的 ID +} + + struct Ability { 1: optional i64 max_context_tokens (api.js_conv='true', go.tag='json:"max_context_tokens"') 2: optional i64 max_input_tokens (api.js_conv='true', go.tag='json:"max_input_tokens"') @@ -22,11 +58,14 @@ struct Ability { 5: optional bool json_mode 6: optional bool multi_modal 7: optional AbilityMultiModal ability_multi_modal + 8: optional InterfaceCategory interface_category } struct AbilityMultiModal { + // 图片 1: optional bool image 2: optional AbilityImage ability_image + // 视频 3: optional bool video 4: optional AbilityVideo ability_video } @@ -143,6 +182,14 @@ struct ParamSchema { 6: optional string max 7: optional string default_value 8: optional list options + 9: optional list properties + 10: optional Reaction reaction // 依赖参数 + 11: optional string jsonpath // 赋值路径 +} + +struct Reaction { + 1: optional string dependency // 依赖的字段 + 2: optional string visible // 可见性表达式 } struct ParamOption { @@ -171,6 +218,74 @@ const ParamType param_type_float = "float" const ParamType param_type_int = "int" const ParamType param_type_boolean = "boolean" const ParamType param_type_string = "string" +const ParamType param_type_void = "void" +const ParamType param_type_object = "object" + +typedef string Family (ts.enum="true") +const Family family_undefined = "undefined" +const Family family_gpt = "gpt" +const Family family_seed = "seed" +const Family family_gemini = "gemini" +const Family family_claude = "claude" +const Family family_ernie = "ernie" +const Family family_baichuan = "baichuan" +const Family family_qwen = "qwen" +const Family family_glm = "glm" +const Family family_skylark = "skylark" +const Family family_moonshot = "moonshot" +const Family family_minimax = "minimax" +const Family family_doubao = "doubao" +const Family family_baichuan2 = "baichuan2" +const Family family_deepseekv2 = "deepseekv2" +const Family family_deepseek_coder_v2 = "deepseek_coder_v2" +const Family family_deepseek_coder = "deepseek_coder" +const Family family_internalm25 = "internalm2_5" +const Family family_qwen2 = "qwen2" +const Family family_qwen25 = "qwen2.5" +const Family family_qwen25_coder = "qwen2.5_coder" +const Family family_mini_cpm = "mini_cpm" +const Family family_mini_cpm3 = "mini_cpm_3" +const Family family_chat_glm3 = "chat_glm_3" +const Family family_mistra = "mistral" +const Family family_gemma = "gemma" +const Family family_gemma_2 = "gemma_2" +const Family family_intern_vl2 = "intern_vl2" +const Family family_intern_vl25 = "intern_vl2.5" +const Family family_deepseek_v3 = "deepseek_v3" +const Family family_deepseek_r1 = "deepseek_r1" +const Family family_kimi = "kimi" +const Family family_seedream = "seedream" +const Family family_intern_vl3 = "intern_vl3" +const Family family_deepseek = "deepseek" + + +typedef string Provider (ts.enum="true") +const Provider provider_undefined = "undefined" +const Provider provider_maas = "maas" + +typedef string VisibleMode (ts.enum="true") +const VisibleMode visible_mode_default = "default" +const VisibleMode visible_mode_specified = "specified" +const VisibleMode visible_mode_undefined = "undefined" +const VisibleMode visible_mode_all = "all" + +typedef string ModelStatus (ts.enum="true") +const ModelStatus model_status_undefined = "undefined" +const ModelStatus model_status_available = "available" //可用 +const ModelStatus model_status_unavailable = "unavailable" //不可用 + +typedef string InterfaceCategory (ts.enum="true") +const InterfaceCategory interface_category_undefined = "undefined" +const InterfaceCategory interface_category_chat_completion_api = "chat_completion_api" +const InterfaceCategory interface_category_response_api = "response_api" + +typedef string AbilityEnum (ts.enum="true") +const AbilityEnum ability_undefined = "undefined" +const AbilityEnum ability_json_mode = "json_mode" +const AbilityEnum ability_function_call = "function_call" +const AbilityEnum ability_multi_modal = "multi_modal" + + typedef string VideoFormat (ts.enum="true") const VideoFormat video_format_undefined = "undefined" diff --git a/idl/thrift/coze/loop/llm/domain/runtime.thrift b/idl/thrift/coze/loop/llm/domain/runtime.thrift index e645f3f65..7ff5f4800 100644 --- a/idl/thrift/coze/loop/llm/domain/runtime.thrift +++ b/idl/thrift/coze/loop/llm/domain/runtime.thrift @@ -14,9 +14,13 @@ struct ModelConfig { 8: optional i32 top_k 9: optional double presence_penalty 10: optional double frequency_penalty + 11: optional string identification + 12: optional manage.Protocol protocol // 模型提供方 + 13: optional bool preset_model // 是否为预置模型 // 与ParamSchema对应 100: optional list param_config_values + 101: optional string extra } struct ParamConfigValue {