diff --git a/.gitignore b/.gitignore index 3f62ca4..2adeb23 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ go.work.sum /.vscode /output coverage.out -/apb \ No newline at end of file +/apb +.coda/ \ No newline at end of file diff --git a/client.go b/client.go index b04d448..be08dd2 100644 --- a/client.go +++ b/client.go @@ -318,7 +318,8 @@ func GetPrompt(ctx context.Context, param GetPromptParam, options ...GetPromptOp // PromptFormat format prompt with variables func PromptFormat(ctx context.Context, prompt *entity.Prompt, variables map[string]any, options ...PromptFormatOption) ( - messages []*entity.Message, err error) { + messages []*entity.Message, err error, +) { return getDefaultClient().PromptFormat(ctx, prompt, variables, options...) } @@ -495,6 +496,20 @@ func (c *loopClient) PromptFormat(ctx context.Context, loopPrompt *entity.Prompt return c.promptProvider.PromptFormat(ctx, loopPrompt, variables, config) } +func (c *loopClient) Execute(ctx context.Context, req *entity.ExecuteParam, options ...ExecuteOption) (entity.ExecuteResult, error) { + if c.closed { + return entity.ExecuteResult{}, consts.ErrClientClosed + } + return c.promptProvider.Execute(ctx, req, options...) +} + +func (c *loopClient) ExecuteStreaming(ctx context.Context, req *entity.ExecuteParam, options ...ExecuteStreamingOption) (entity.StreamReader[entity.ExecuteResult], error) { + if c.closed { + return nil, consts.ErrClientClosed + } + return c.promptProvider.ExecuteStreaming(ctx, req, options...) +} + func (c *loopClient) StartSpan(ctx context.Context, name, spanType string, opts ...StartSpanOption) (context.Context, Span) { if c.closed { return ctx, DefaultNoopSpan diff --git a/entity/prompt.go b/entity/prompt.go index 593bfbf..ea006a5 100644 --- a/entity/prompt.go +++ b/entity/prompt.go @@ -3,7 +3,9 @@ package entity -import "github.com/coze-dev/cozeloop-go/internal/util" +import ( + "github.com/coze-dev/cozeloop-go/internal/util" +) type Prompt struct { WorkspaceID string `json:"workspace_id"` @@ -29,9 +31,24 @@ const ( ) type Message struct { - Role Role `json:"role"` - Content *string `json:"content,omitempty"` - Parts []*ContentPart `json:"parts,omitempty"` + Role Role `json:"role"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + Content *string `json:"content,omitempty"` + Parts []*ContentPart `json:"parts,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` + ToolCalls []*ToolCall `json:"tool_calls,omitempty"` +} + +type ToolCall struct { + Index int32 `json:"index"` + ID string `json:"id"` + Type ToolType `json:"type"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments *string `json:"arguments,omitempty"` } type Role string @@ -45,9 +62,10 @@ const ( ) type ContentPart struct { - Type ContentType `json:"type"` - Text *string `json:"text,omitempty"` - ImageURL *string `json:"image_url,omitempty"` + Type ContentType `json:"type"` + Text *string `json:"text,omitempty"` + ImageURL *string `json:"image_url,omitempty"` + Base64Data *string `json:"base64_data,omitempty"` } type ContentType string @@ -55,6 +73,7 @@ type ContentType string const ( ContentTypeText ContentType = "text" ContentTypeImageURL ContentType = "image_url" + ContentTypeBase64Data ContentType = "base64_data" ContentTypeMultiPartVariable ContentType = "multi_part_variable" ) @@ -119,6 +138,25 @@ type LLMConfig struct { JSONMode *bool `json:"json_mode,omitempty"` } +type ExecuteParam struct { + PromptKey string `json:"prompt_key"` + Version string `json:"version,omitempty"` + Label string `json:"label,omitempty"` + VariableVals map[string]any `json:"variable_vals,omitempty"` + Messages []*Message `json:"messages,omitempty"` +} + +type ExecuteResult struct { + Message *Message `json:"message,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + Usage *TokenUsage `json:"usage,omitempty"` +} + +type TokenUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + func (p *Prompt) DeepCopy() *Prompt { if p == nil { return nil @@ -181,12 +219,14 @@ func (cp *ContentPart) DeepCopy() *ContentPart { return nil } copied := &ContentPart{ - Type: cp.Type, - ImageURL: cp.ImageURL, + Type: cp.Type, } if cp.Text != nil { copied.Text = util.Ptr(*cp.Text) } + if cp.ImageURL != nil { + copied.ImageURL = util.Ptr(*cp.ImageURL) + } return copied } diff --git a/entity/stream.go b/entity/stream.go new file mode 100644 index 0000000..3c1ee90 --- /dev/null +++ b/entity/stream.go @@ -0,0 +1,8 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package entity + +type StreamReader[T any] interface { + Recv() (T, error) +} diff --git a/error.go b/error.go index 26f395f..1e889a3 100644 --- a/error.go +++ b/error.go @@ -16,5 +16,7 @@ var ( ErrParsePrivateKey = consts.ErrParsePrivateKey ) -type AuthError = consts.AuthError -type RemoteServiceError = consts.RemoteServiceError +type ( + AuthError = consts.AuthError + RemoteServiceError = consts.RemoteServiceError +) diff --git a/examples/init/log/log.go b/examples/init/log/log.go index b872cdd..cbfe036 100644 --- a/examples/init/log/log.go +++ b/examples/init/log/log.go @@ -31,8 +31,7 @@ func main() { cozeloop.Close(ctx) } -type CustomLogger struct { -} +type CustomLogger struct{} func (l *CustomLogger) CtxDebugf(ctx context.Context, format string, v ...interface{}) { fmt.Printf("[Custom] [DEBUG] "+format+"\n", v...) diff --git a/examples/prompt/prompt_hub.go b/examples/prompt/prompt_hub/prompt_hub.go similarity index 99% rename from examples/prompt/prompt_hub.go rename to examples/prompt/prompt_hub/prompt_hub.go index 52c50b0..c96081d 100644 --- a/examples/prompt/prompt_hub.go +++ b/examples/prompt/prompt_hub/prompt_hub.go @@ -131,8 +131,8 @@ func (r *llmRunner) llmCall(ctx context.Context, messages []*entity.Message) (er defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/prompt/prompt_hub_jinja/prompt_hub_jinja.go b/examples/prompt/prompt_hub/prompt_hub_jinja/prompt_hub_jinja.go similarity index 99% rename from examples/prompt/prompt_hub_jinja/prompt_hub_jinja.go rename to examples/prompt/prompt_hub/prompt_hub_jinja/prompt_hub_jinja.go index 294a48c..6310f40 100644 --- a/examples/prompt/prompt_hub_jinja/prompt_hub_jinja.go +++ b/examples/prompt/prompt_hub/prompt_hub_jinja/prompt_hub_jinja.go @@ -160,8 +160,8 @@ func (r *llmRunner) llmCall(ctx context.Context, messages []*entity.Message) (er defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/prompt/prompt_hub_label/prompt_hub_label.go b/examples/prompt/prompt_hub/prompt_hub_label/prompt_hub_label.go similarity index 99% rename from examples/prompt/prompt_hub_label/prompt_hub_label.go rename to examples/prompt/prompt_hub/prompt_hub_label/prompt_hub_label.go index 1c40485..a2622a3 100644 --- a/examples/prompt/prompt_hub_label/prompt_hub_label.go +++ b/examples/prompt/prompt_hub/prompt_hub_label/prompt_hub_label.go @@ -125,8 +125,8 @@ func (r *llmRunner) llmCall(ctx context.Context, messages []*entity.Message) (er defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/prompt/prompt_hub_multipart/prompt_hub_multipart.go b/examples/prompt/prompt_hub/prompt_hub_multipart/prompt_hub_multipart.go similarity index 98% rename from examples/prompt/prompt_hub_multipart/prompt_hub_multipart.go rename to examples/prompt/prompt_hub/prompt_hub_multipart/prompt_hub_multipart.go index ca66336..ac1be4d 100644 --- a/examples/prompt/prompt_hub_multipart/prompt_hub_multipart.go +++ b/examples/prompt/prompt_hub/prompt_hub_multipart/prompt_hub_multipart.go @@ -77,7 +77,7 @@ func main() { // 4.Format messages of the prompt imageText := "图片样例" - imageURL := "https://example.com" //公网访问地址 + imageURL := "https://example.com" // 公网访问地址 messages, err := llmRunner.client.PromptFormat(ctx, prompt, map[string]any{ "num": "2", "count": 10, @@ -131,8 +131,8 @@ func (r *llmRunner) llmCall(ctx context.Context, messages []*entity.Message) (er defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/prompt/ptaas/ptaas.go b/examples/prompt/ptaas/ptaas.go new file mode 100644 index 0000000..8d7f4f5 --- /dev/null +++ b/examples/prompt/ptaas/ptaas.go @@ -0,0 +1,92 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package main + +import ( + "context" + "fmt" + "io" + + "github.com/coze-dev/cozeloop-go" + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/util" +) + +func main() { + // 1.Create a prompt on the platform + // Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + // add the following messages to the template, submit a version. + // System: You are a helpful assistant for {{topic}}. + // User: Please help me with {{user_request}} + ctx := context.Background() + + // Set the following environment variables first. + // COZELOOP_WORKSPACE_ID=your workspace id + // COZELOOP_API_TOKEN=your token + // 2.New loop client + client, err := cozeloop.NewClient() + if err != nil { + panic(err) + } + defer client.Close(ctx) + + // 3. Execute prompt + executeRequest := &entity.ExecuteParam{ + PromptKey: "ptaas_demo", + Version: "0.0.1", + VariableVals: map[string]any{ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning", + }, + // You can also append messages to the prompt. + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Content: util.Ptr("Keep the answer brief."), + }, + }, + } + // 3.1 non stream + nonStream(ctx, client, executeRequest) + // 3.2 stream + stream(ctx, client, executeRequest) +} + +func nonStream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + result, err := client.Execute(ctx, executeRequest) + if err != nil { + panic(err) + } + printExecuteResult(result) +} + +func stream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + streamReader, err := client.ExecuteStreaming(ctx, executeRequest) + if err != nil { + panic(err) + } + for { + result, err := streamReader.Recv() + if err != nil { + if err == io.EOF { + fmt.Println("\nStream finished.") + break + } + panic(err) + } + printExecuteResult(result) + } +} + +func printExecuteResult(result entity.ExecuteResult) { + if result.Message != nil { + fmt.Printf("Message: %s\n", util.ToJSON(result.Message)) + } + if util.PtrValue(result.FinishReason) != "" { + fmt.Printf("FinishReason: %s\n", util.PtrValue(result.FinishReason)) + } + if result.Usage != nil { + fmt.Printf("Usage: %s\n", util.ToJSON(result.Usage)) + } +} diff --git a/examples/prompt/ptaas/ptaas_jinja/ptaas_jinja.go b/examples/prompt/ptaas/ptaas_jinja/ptaas_jinja.go new file mode 100644 index 0000000..0b15c67 --- /dev/null +++ b/examples/prompt/ptaas/ptaas_jinja/ptaas_jinja.go @@ -0,0 +1,71 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package main + +import ( + "context" + "fmt" + + "github.com/coze-dev/cozeloop-go" + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/util" +) + +// The explanation of jinja2 template is based on non-streaming execution, and it also applies to streaming execution. +func main() { + // 1.Create a prompt using jinja2 template on the platform + // Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + // add the following messages to the template, submit a version, and set a label (e.g., 'production') for that version. + // System: You are a helpful assistant for {{param.topic}}. Your audience is {{param.age}} years old. + // User: Please help me with {{param.user_request}} + ctx := context.Background() + + // Set the following environment variables first. + // COZELOOP_WORKSPACE_ID=your workspace id + // COZELOOP_API_TOKEN=your token + // 2.New loop client + client, err := cozeloop.NewClient() + if err != nil { + panic(err) + } + defer client.Close(ctx) + + // 3. Execute prompt + executeRequest := &entity.ExecuteParam{ + PromptKey: "ptaas_demo", + Version: "0.0.2", + VariableVals: map[string]any{ + "param": struct { + Topic string `json:"topic"` + Age int `json:"age"` + UserRequest string `json:"user_request"` + }{ + Topic: "artificial intelligence", + Age: 10, + UserRequest: "explain what is machine learning", + }, + }, + } + nonStream(ctx, client, executeRequest) +} + +func nonStream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + result, err := client.Execute(ctx, executeRequest) + if err != nil { + panic(err) + } + printExecuteResult(result) +} + +func printExecuteResult(result entity.ExecuteResult) { + if result.Message != nil { + fmt.Printf("Message: %s\n", util.ToJSON(result.Message)) + } + if util.PtrValue(result.FinishReason) != "" { + fmt.Printf("FinishReason: %s\n", util.PtrValue(result.FinishReason)) + } + if result.Usage != nil { + fmt.Printf("Usage: %s\n", util.ToJSON(result.Usage)) + } +} diff --git a/examples/prompt/ptaas/ptaas_multi_modal/ptaas_multi_modal.go b/examples/prompt/ptaas/ptaas_multi_modal/ptaas_multi_modal.go new file mode 100644 index 0000000..fc7812b --- /dev/null +++ b/examples/prompt/ptaas/ptaas_multi_modal/ptaas_multi_modal.go @@ -0,0 +1,107 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + "github.com/coze-dev/cozeloop-go" + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/util" +) + +// The explanation of multi modal is based on non-streaming execution, and it also applies to streaming execution. +func main() { + // 1.Create a prompt on the platform + // Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + // add the following messages to the template, submit a version. example1 and example2 are the multi modal variables. + // System: You can quickly identify the location where a photo was taken. + // User: 例如:{{example1}} + // Assistant: {{city1}} + // User: 例如:{{example2}} + // Assistant: {{city2}} + ctx := context.Background() + + // Set the following environment variables first. + // COZELOOP_WORKSPACE_ID=your workspace id + // COZELOOP_API_TOKEN=your token + // 2.New loop client + client, err := cozeloop.NewClient() + if err != nil { + panic(err) + } + defer client.Close(ctx) + + // 3. Execute prompt + imagePath := "your image path" + imageBytes, err := os.ReadFile(imagePath) + if err != nil { + panic(err) + } + base64Image := base64.StdEncoding.EncodeToString(imageBytes) + base64data := fmt.Sprintf("data:image/jpeg;base64,%s", base64Image) + executeRequest := &entity.ExecuteParam{ + PromptKey: "ptaas_demo", + Version: "0.0.8", + // multi modal variable can be []*entity.ContentPart(recommend)/[]entity.ContentPart/*entity.ContentPart/entity.ContentPart + // Images can be provided via URL or in base64 encoded format. + // Image URL needs to be publicly accessible. + // Base64-formatted data should follow the standard data URI format, like "data:[][;base64],". + VariableVals: map[string]any{ + "example1": []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: util.Ptr("https://p8.itc.cn/q_70/images03/20221219/61785c89cd17421ca0d007c7a87d09fb.jpeg"), + }, + }, + "city1": "Beijing", + "example2": []*entity.ContentPart{ + { + Type: entity.ContentTypeBase64Data, + Base64Data: util.Ptr(base64data), + }, + }, + "city2": "Shanghai", + }, + Messages: []*entity.Message{ + { + Role: entity.RoleUser, + Parts: []*entity.ContentPart{ + { + Type: entity.ContentTypeImageURL, + ImageURL: util.Ptr("https://img0.baidu.com/it/u=1402951118,1660594928&fm=253&app=138&f=JPEG?w=800&h=1200"), + }, + { + Type: entity.ContentTypeText, + Text: util.Ptr("Where is this photo taken?"), + }, + }, + }, + }, + } + nonStream(ctx, client, executeRequest) +} + +func nonStream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + result, err := client.Execute(ctx, executeRequest) + if err != nil { + panic(err) + } + printExecuteResult(result) +} + +func printExecuteResult(result entity.ExecuteResult) { + if result.Message != nil { + fmt.Printf("Message: %s\n", util.ToJSON(result.Message)) + } + if util.PtrValue(result.FinishReason) != "" { + fmt.Printf("FinishReason: %s\n", util.PtrValue(result.FinishReason)) + } + if result.Usage != nil { + fmt.Printf("Usage: %s\n", util.ToJSON(result.Usage)) + } +} diff --git a/examples/prompt/ptaas/ptaas_placeholder_variable/ptaas_placeholder_variable.go b/examples/prompt/ptaas/ptaas_placeholder_variable/ptaas_placeholder_variable.go new file mode 100644 index 0000000..f4ad5bc --- /dev/null +++ b/examples/prompt/ptaas/ptaas_placeholder_variable/ptaas_placeholder_variable.go @@ -0,0 +1,76 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package main + +import ( + "context" + "fmt" + + "github.com/coze-dev/cozeloop-go" + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/util" +) + +// The explanation of placeholder variable is based on non-streaming execution, and it also applies to streaming execution. +func main() { + // 1.Create a prompt on the platform + // Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + // add the following messages to the template, submit a version. + // System: You are a helpful assistant for {{topic}}. + // Placeholder: {{chat_history}} + // User: Please help me with {{user_request}} + ctx := context.Background() + + // Set the following environment variables first. + // COZELOOP_WORKSPACE_ID=your workspace id + // COZELOOP_API_TOKEN=your token + // 2.New loop client + client, err := cozeloop.NewClient() + if err != nil { + panic(err) + } + defer client.Close(ctx) + + // 3. Execute prompt + executeRequest := &entity.ExecuteParam{ + PromptKey: "ptaas_demo", + Version: "0.0.5", + VariableVals: map[string]any{ + "topic": "artificial intelligence", + // chat_history is a placeholder variable, and it can be []*entity.Message(recommend)/[]entity.Message/*entity.Message/entity.Message. + "chat_history": []*entity.Message{ + { + Role: entity.RoleUser, + Content: util.Ptr("hello"), + }, + { + Role: entity.RoleAssistant, + Content: util.Ptr("hello"), + }, + }, + "user_request": "explain what is machine learning", + }, + } + nonStream(ctx, client, executeRequest) +} + +func nonStream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + result, err := client.Execute(ctx, executeRequest) + if err != nil { + panic(err) + } + printExecuteResult(result) +} + +func printExecuteResult(result entity.ExecuteResult) { + if result.Message != nil { + fmt.Printf("Message: %s\n", util.ToJSON(result.Message)) + } + if util.PtrValue(result.FinishReason) != "" { + fmt.Printf("FinishReason: %s\n", util.PtrValue(result.FinishReason)) + } + if result.Usage != nil { + fmt.Printf("Usage: %s\n", util.ToJSON(result.Usage)) + } +} diff --git a/examples/prompt/ptaas/ptaas_time_out/ptaas_time_out.go b/examples/prompt/ptaas/ptaas_time_out/ptaas_time_out.go new file mode 100644 index 0000000..2837fd4 --- /dev/null +++ b/examples/prompt/ptaas/ptaas_time_out/ptaas_time_out.go @@ -0,0 +1,73 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package main + +import ( + "context" + "fmt" + "time" + + "github.com/coze-dev/cozeloop-go" + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/util" +) + +// The explanation of timeout settings is based on non-streaming execution, and it also applies to streaming execution. +func main() { + setCtxTimeout() +} + +func setCtxTimeout() { + // 1.Create a prompt on the platform + // Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + // add the following messages to the template, submit a version, and set a label (e.g., 'production') for that version. + // System: You are a helpful assistant for {{topic}}. + // User: Please help me with {{user_request}} + ctx := context.Background() + + // Set the following environment variables first. + // COZELOOP_WORKSPACE_ID=your workspace id + // COZELOOP_API_TOKEN=your token + // 2.New loop client + client, err := cozeloop.NewClient() + if err != nil { + panic(err) + } + defer client.Close(ctx) + + // 3. Set context timeout, default is 600s, max is 600s. + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + // 4. Execute prompt + executeRequest := &entity.ExecuteParam{ + PromptKey: "ptaas_demo", + Version: "0.0.1", + VariableVals: map[string]any{ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning", + }, + } + nonStream(ctx, client, executeRequest) +} + +func nonStream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + result, err := client.Execute(ctx, executeRequest) + if err != nil { + panic(err) + } + printExecuteResult(result) +} + +func printExecuteResult(result entity.ExecuteResult) { + if result.Message != nil { + fmt.Printf("Message: %s\n", util.ToJSON(result.Message)) + } + if util.PtrValue(result.FinishReason) != "" { + fmt.Printf("FinishReason: %s\n", util.PtrValue(result.FinishReason)) + } + if result.Usage != nil { + fmt.Printf("Usage: %s\n", util.ToJSON(result.Usage)) + } +} diff --git a/examples/prompt/ptaas/ptaas_with_label/ptaas_with_label.go b/examples/prompt/ptaas/ptaas_with_label/ptaas_with_label.go new file mode 100644 index 0000000..6a42e7e --- /dev/null +++ b/examples/prompt/ptaas/ptaas_with_label/ptaas_with_label.go @@ -0,0 +1,64 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package main + +import ( + "context" + "fmt" + + "github.com/coze-dev/cozeloop-go" + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/util" +) + +// The explanation of label is based on non-streaming execution, and it also applies to streaming execution. +func main() { + // 1.Create a prompt on the platform + // Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + // add the following messages to the template, submit a version, and set a label (e.g., 'production') for that version. + // System: You are a helpful assistant for {{topic}}. + // User: Please help me with {{user_request}} + ctx := context.Background() + + // Set the following environment variables first. + // COZELOOP_WORKSPACE_ID=your workspace id + // COZELOOP_API_TOKEN=your token + // 2.New loop client + client, err := cozeloop.NewClient() + if err != nil { + panic(err) + } + defer client.Close(ctx) + + // 3. Execute prompt + executeRequest := &entity.ExecuteParam{ + PromptKey: "ptaas_demo", + Label: "production", // Note: When Version is specified, Label field will be ignored + VariableVals: map[string]any{ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning", + }, + } + nonStream(ctx, client, executeRequest) +} + +func nonStream(ctx context.Context, client cozeloop.Client, executeRequest *entity.ExecuteParam) { + result, err := client.Execute(ctx, executeRequest) + if err != nil { + panic(err) + } + printExecuteResult(result) +} + +func printExecuteResult(result entity.ExecuteResult) { + if result.Message != nil { + fmt.Printf("Message: %s\n", util.ToJSON(result.Message)) + } + if util.PtrValue(result.FinishReason) != "" { + fmt.Printf("FinishReason: %s\n", util.PtrValue(result.FinishReason)) + } + if result.Usage != nil { + fmt.Printf("Usage: %s\n", util.ToJSON(result.Usage)) + } +} diff --git a/examples/trace/benchmark_test.go b/examples/trace/benchmark_test.go index 32e78b4..8a5d334 100644 --- a/examples/trace/benchmark_test.go +++ b/examples/trace/benchmark_test.go @@ -46,7 +46,7 @@ func BenchmarkMyFunctionWithQPS(b *testing.B) { select { case <-ticker.C: go func() { - //logger.CtxInfof(ctx, "run span demo ######################################################################################") + // logger.CtxInfof(ctx, "run span demo ######################################################################################") runner.llmRunner(ctx, "test input") }() case <-done: diff --git a/examples/trace/large_text/large_text.go b/examples/trace/large_text/large_text.go index 1d2d8b9..fd23cd0 100644 --- a/examples/trace/large_text/large_text.go +++ b/examples/trace/large_text/large_text.go @@ -89,7 +89,7 @@ func main() { // -- close trace, do flush and close client // Warning! Once Close is executed, the client will become unavailable and a new client needs // to be created via NewClient! Use it only when you need to release resources, such as shutting down an instance! - //client.Close(ctx) + // client.Close(ctx) } func (r *llmRunner) llmCall(ctx context.Context, input string) (err error) { @@ -97,8 +97,8 @@ func (r *llmRunner) llmCall(ctx context.Context, input string) (err error) { defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" //maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/trace/multi_modality/multi_modality.go b/examples/trace/multi_modality/multi_modality.go index f94a7b0..2306e86 100644 --- a/examples/trace/multi_modality/multi_modality.go +++ b/examples/trace/multi_modality/multi_modality.go @@ -85,7 +85,7 @@ func main() { // -- close trace, do flush and close client // Warning! Once Close is executed, the client will become unavailable and a new client needs // to be created via NewClient! Use it only when you need to release resources, such as shutting down an instance! - //client.Close(ctx) + // client.Close(ctx) } func (r *llmRunner) llmCall(ctx context.Context) (err error) { @@ -93,8 +93,8 @@ func (r *llmRunner) llmCall(ctx context.Context) (err error) { defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/trace/parent_child/parent_child.go b/examples/trace/parent_child/parent_child.go index 0d1e8f0..49558c2 100644 --- a/examples/trace/parent_child/parent_child.go +++ b/examples/trace/parent_child/parent_child.go @@ -84,7 +84,7 @@ func main() { // -- close trace, do flush and close client // Warning! Once Close is executed, the client will become unavailable and a new client needs // to be created via NewClient! Use it only when you need to release resources, such as shutting down an instance! - //client.Close(ctx) + // client.Close(ctx) } func (r *llmRunner) llmCall(ctx context.Context) (err error) { @@ -93,8 +93,8 @@ func (r *llmRunner) llmCall(ctx context.Context) (err error) { defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" //maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/trace/prompt/prompt.go b/examples/trace/prompt/prompt.go index 3849ed7..6afa601 100644 --- a/examples/trace/prompt/prompt.go +++ b/examples/trace/prompt/prompt.go @@ -85,7 +85,7 @@ func main() { // -- close trace, do flush and close client // Warning! Once Close is executed, the client will become unavailable and a new client needs // to be created via NewClient! Use it only when you need to release resources, such as shutting down an instance! - //client.Close(ctx) + // client.Close(ctx) } func (r *getPromptRunner) getPrompt(ctx context.Context) (prompt *entity.Prompt, err error) { diff --git a/examples/trace/simple/simple.go b/examples/trace/simple/simple.go index e1a6304..f62ec98 100644 --- a/examples/trace/simple/simple.go +++ b/examples/trace/simple/simple.go @@ -78,7 +78,7 @@ func main() { // -- close trace, do flush and close client // Warning! Once Close is executed, the client will become unavailable and a new client needs // to be created via NewClient! Use it only when you need to release resources, such as shutting down an instance! - //client.Close(ctx) + // client.Close(ctx) } func (r *llmRunner) llmCall(ctx context.Context) (err error) { @@ -86,8 +86,8 @@ func (r *llmRunner) llmCall(ctx context.Context) (err error) { defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" //maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/examples/trace/transfer_between_services/transfer_between_services.go b/examples/trace/transfer_between_services/transfer_between_services.go index fee94fb..c198aaa 100644 --- a/examples/trace/transfer_between_services/transfer_between_services.go +++ b/examples/trace/transfer_between_services/transfer_between_services.go @@ -90,7 +90,7 @@ func main() { // -- close trace, do flush and close client // Warning! Once Close is executed, the client will become unavailable and a new client needs // to be created via NewClient! Use it only when you need to release resources, such as shutting down an instance! - //client.Close(ctx) + // client.Close(ctx) } func (r *llmRunner) llmCall(ctx context.Context) (err error) { @@ -98,8 +98,8 @@ func (r *llmRunner) llmCall(ctx context.Context) (err error) { defer span.Finish(ctx) // llm is processing - //baseURL := "https://xxx" - //ak := "****" + // baseURL := "https://xxx" + // ak := "****" modelName := "gpt-4o-2024-05-13" //maxTokens := 1000 // range: [0, 4096] //transport := &MyTransport{ diff --git a/go.mod b/go.mod index 36f820e..96c506a 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.18 require ( github.com/bluele/gcache v0.0.2 github.com/bytedance/mockey v1.2.14 - github.com/coze-dev/cozeloop-go/spec v0.1.4 + github.com/coze-dev/cozeloop-go/spec v0.1.4-0.20250829072213-3812ddbfb735 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/nikolalohinski/gonja/v2 v2.3.1 github.com/smartystreets/goconvey v1.8.1 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/sync v0.11.0 + golang.org/x/sync v0.16.0 ) require ( @@ -20,14 +20,14 @@ require ( github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pkg/errors v0.9.1 // indirect + github.com/pkg/errors v0.9.2-0.20201214064552-5dd12d0cfe7f // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/smarty/assertions v1.15.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/arch v0.11.0 // indirect - golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/arch v0.15.0 // indirect + golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.27.0 // indirect ) replace github.com/coze-dev/cozeloop-go/spec => ./spec diff --git a/go.sum b/go.sum index 84ac02c..1eefd6e 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,5 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= +github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/bytedance/mockey v1.2.14 h1:KZaFgPdiUwW+jOWFieo3Lr7INM1P+6adO3hxZhDswY8= @@ -9,12 +10,16 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -29,9 +34,11 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/nikolalohinski/gonja/v2 v2.3.1 h1:UGyLa6NDNq6dCGkFY33sziUssjTdh95xrYslxZdqNVU= github.com/nikolalohinski/gonja/v2 v2.3.1/go.mod h1:1Wcc/5huTu6y36e0sOFR1XQoFlylw3c3H3L5WOz0RDg= github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= +github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ= +github.com/pkg/errors v0.9.2-0.20201214064552-5dd12d0cfe7f h1:lJqhwddJVYAkyp72a4pwzMClI20xTwL7miDdm2W/KBM= +github.com/pkg/errors v0.9.2-0.20201214064552-5dd12d0cfe7f/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -44,23 +51,27 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= -golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= -golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= +golang.org/x/arch v0.15.0 h1:QtOrQd0bTUnhNVNndMpLHNWrDmYzZ2KDqSrEymqInZw= +golang.org/x/arch v0.15.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE= +golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 h1:bsqhLWFR6G6xiQcb+JoGqdKdRU6WzPWmK8E0jxTjzo4= +golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/consts/consts.go b/internal/consts/consts.go index 177fe15..2774fb5 100644 --- a/internal/consts/consts.go +++ b/internal/consts/consts.go @@ -55,8 +55,10 @@ const ( ) const ( - TracePromptHubSpanName = "PromptHub" - TracePromptTemplateSpanName = "PromptTemplate" + TracePromptHubSpanName = "PromptHub" + TracePromptTemplateSpanName = "PromptTemplate" + TracePromptExecuteSpanName = "PromptExecute" + TracePromptExecuteStreamingSpanName = "PromptExecuteStreaming" ) const ( diff --git a/internal/consts/vars.go b/internal/consts/vars.go index a475693..4a21239 100644 --- a/internal/consts/vars.go +++ b/internal/consts/vars.go @@ -11,22 +11,20 @@ import ( // span -var ( - BaggageSpecialChars = []string{"=", ","} -) +var BaggageSpecialChars = []string{"=", ","} + +var TagValueSizeLimit = map[string]int{ + tracespec.Input: MaxBytesOfOneTagValueOfInputOutput, + tracespec.Output: MaxBytesOfOneTagValueOfInputOutput, +} var ( - TagValueSizeLimit = map[string]int{ - tracespec.Input: MaxBytesOfOneTagValueOfInputOutput, - tracespec.Output: MaxBytesOfOneTagValueOfInputOutput, - } + typeInt64 int64 + typeStr string + typeInt int + typeInt32 int32 ) -var typeInt64 int64 -var typeStr string -var typeInt int -var typeInt32 int32 - // ReserveFieldTypes Define the allowed types for each reserved field. var ReserveFieldTypes = map[string][]reflect.Type{ UserID: {reflect.TypeOf(typeStr)}, diff --git a/internal/httpclient/backoff.go b/internal/httpclient/backoff.go index 943e028..3bf3c40 100644 --- a/internal/httpclient/backoff.go +++ b/internal/httpclient/backoff.go @@ -12,9 +12,7 @@ import ( "github.com/coze-dev/cozeloop-go/internal/consts" ) -var ( - defaultBackoff = NewBackoff(defaultBaseDelay, defaultMaxDelay) -) +var defaultBackoff = NewBackoff(defaultBaseDelay, defaultMaxDelay) const ( defaultBaseDelay = 200 * time.Millisecond diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index aff9076..75d1cf8 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -90,7 +90,7 @@ func (c *Client) PostWithRetry(ctx context.Context, path string, body any, resp func (c *Client) Post(ctx context.Context, path string, body any, resp OpenAPIResponse) error { var cancel context.CancelFunc - if c.timeout > 0 { + if _, ok := ctx.Deadline(); !ok && c.timeout > 0 { ctx, cancel = context.WithTimeout(ctx, c.timeout) defer cancel() } @@ -123,6 +123,45 @@ func (c *Client) Post(ctx context.Context, path string, body any, resp OpenAPIRe return parseResponse(ctx, url, response, resp) } +func (c *Client) PostStream(ctx context.Context, path string, body any) (*http.Response, error) { + if _, ok := ctx.Deadline(); !ok && c.timeout > 0 { + ctx, _ = context.WithTimeout(ctx, c.timeout) + } + + var bodyReader io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal body: %w", err) + } + bodyReader = bytes.NewReader(data) + } + + url := c.baseURL + path + request, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + headers := map[string]string{"Content-Type": "application/json"} + if err := c.setHeaders(ctx, request, headers); err != nil { + return nil, err + } + + response, err := c.httpClient.Do(request) + if err != nil { + logger.CtxErrorf(ctx, "http client PostStream failed, url: %v, err: %v", url, err) + return nil, consts.ErrRemoteService.Wrap(err) + } + if response.StatusCode != http.StatusOK { + logger.CtxErrorf(ctx, "http client PostStream failed, url: %v, status code: %v", url, response.StatusCode) + // 非200不会返回流式,而是直接返回错误信息 + return nil, parseResponse(ctx, url, response, &BaseResponse{}) + } + + return response, nil +} + func (c *Client) UploadFile(ctx context.Context, path string, fileName string, reader io.Reader, form map[string]string, resp OpenAPIResponse) error { var cancel context.CancelFunc if c.uploadTimeout > 0 { diff --git a/internal/prompt/convert.go b/internal/prompt/convert.go index 634cfdb..ab7f0bc 100644 --- a/internal/prompt/convert.go +++ b/internal/prompt/convert.go @@ -46,15 +46,25 @@ func toModelMessages(messages []*Message) []*entity.Message { if msg == nil { continue } - result[i] = &entity.Message{ - Role: toModelRole(msg.Role), - Content: msg.Content, - Parts: toContentParts(msg.Parts), - } + result[i] = toModelMessage(msg) } return result } +func toModelMessage(message *Message) *entity.Message { + if message == nil { + return nil + } + return &entity.Message{ + Role: toModelRole(message.Role), + ReasoningContent: message.ReasoningContent, + Content: message.Content, + Parts: toContentParts(message.Parts), + ToolCallID: message.ToolCallID, + ToolCalls: toModelToolCalls(message.ToolCalls), + } +} + func toContentParts(dos []*ContentPart) []*entity.ContentPart { if dos == nil { return nil @@ -83,6 +93,10 @@ func toContentType(do ContentType) entity.ContentType { switch do { case ContentTypeText: return entity.ContentTypeText + case ContentTypeImageURL: + return entity.ContentTypeMultiPartVariable + case ContentTypeBase64Data: + return entity.ContentTypeBase64Data case ContentTypeMultiPartVariable: return entity.ContentTypeMultiPartVariable default: @@ -90,6 +104,42 @@ func toContentType(do ContentType) entity.ContentType { } } +func toModelToolCalls(toolCalls []*ToolCall) []*entity.ToolCall { + if toolCalls == nil { + return nil + } + result := make([]*entity.ToolCall, 0, len(toolCalls)) + for _, toolCall := range toolCalls { + if toolCall == nil { + continue + } + result = append(result, toModelToolCall(toolCall)) + } + return result +} + +func toModelToolCall(toolCall *ToolCall) *entity.ToolCall { + if toolCall == nil { + return nil + } + return &entity.ToolCall{ + Index: util.PtrValue(toolCall.Index), + ID: util.PtrValue(toolCall.ID), + Type: toModelToolType(toolCall.Type), + FunctionCall: toModelFunctionCall(toolCall.FunctionCall), + } +} + +func toModelFunctionCall(fc *FunctionCall) *entity.FunctionCall { + if fc == nil { + return nil + } + return &entity.FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + } +} + func toModelVariableDefs(defs []*VariableDef) []*entity.VariableDef { if defs == nil { return nil @@ -239,6 +289,16 @@ func toModelToolChoiceType(tct ToolChoiceType) entity.ToolChoiceType { } } +func toModelTokenUsage(usage *TokenUsage) *entity.TokenUsage { + if usage == nil { + return nil + } + return &entity.TokenUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + } +} + // ===============to span model================ func toSpanPromptInput(messages []*entity.Message, arguments map[string]any) *tracespec.PromptInput { return &tracespec.PromptInput{ @@ -340,3 +400,145 @@ func ToSpanPartType(partType entity.ContentType) tracespec.ModelMessagePartType return tracespec.ModelMessagePartType(partType) } } + +// Reverse conversion functions: from entity to openapi types +func toOpenAPIMessages(messages []*entity.Message) []*Message { + if messages == nil { + return nil + } + result := make([]*Message, 0, len(messages)) + for _, msg := range messages { + if msg == nil { + continue + } + result = append(result, toOpenAPIMessage(msg)) + } + return result +} + +// toOpenAPIMessage converts entity.Message to openapi Message +func toOpenAPIMessage(message *entity.Message) *Message { + if message == nil { + return nil + } + return &Message{ + Role: toOpenAPIRole(message.Role), + ReasoningContent: message.ReasoningContent, + Content: message.Content, + Parts: toOpenAPIContentParts(message.Parts), + ToolCallID: message.ToolCallID, + ToolCalls: toOpenAPIToolCalls(message.ToolCalls), + } +} + +// toOpenAPIRole converts entity.Role to openapi Role +func toOpenAPIRole(r entity.Role) Role { + switch r { + case entity.RoleSystem: + return RoleSystem + case entity.RoleUser: + return RoleUser + case entity.RoleAssistant: + return RoleAssistant + case entity.RoleTool: + return RoleTool + case entity.RolePlaceholder: + return RolePlaceholder + default: + return RoleUser + } +} + +// toOpenAPIContentParts converts entity.ContentPart slice to openapi ContentPart slice +func toOpenAPIContentParts(parts []*entity.ContentPart) []*ContentPart { + if parts == nil { + return nil + } + result := make([]*ContentPart, 0, len(parts)) + for _, part := range parts { + if part == nil { + continue + } + result = append(result, toOpenAPIContentPart(part)) + } + return result +} + +// toOpenAPIContentPart converts entity.ContentPart to openapi ContentPart +func toOpenAPIContentPart(part *entity.ContentPart) *ContentPart { + if part == nil { + return nil + } + contentType := toOpenAPIContentType(part.Type) + return &ContentPart{ + Type: &contentType, + Text: part.Text, + ImageURL: part.ImageURL, + Base64Data: part.Base64Data, + } +} + +// toOpenAPIContentType converts entity.ContentType to openapi ContentType +func toOpenAPIContentType(ct entity.ContentType) ContentType { + switch ct { + case entity.ContentTypeText: + return ContentTypeText + case entity.ContentTypeImageURL: + return ContentTypeImageURL + case entity.ContentTypeBase64Data: + return ContentTypeBase64Data + case entity.ContentTypeMultiPartVariable: + return ContentTypeMultiPartVariable + default: + return ContentTypeText + } +} + +// toOpenAPIToolCalls converts entity.ToolCall slice to openapi ToolCall slice +func toOpenAPIToolCalls(toolCalls []*entity.ToolCall) []*ToolCall { + if toolCalls == nil { + return nil + } + result := make([]*ToolCall, 0, len(toolCalls)) + for _, toolCall := range toolCalls { + if toolCall == nil { + continue + } + result = append(result, toOpenAPIToolCall(toolCall)) + } + return result +} + +// toOpenAPIToolCall converts entity.ToolCall to openapi ToolCall +func toOpenAPIToolCall(toolCall *entity.ToolCall) *ToolCall { + if toolCall == nil { + return nil + } + return &ToolCall{ + Index: &toolCall.Index, + ID: &toolCall.ID, + Type: toOpenAPIToolType(toolCall.Type), + FunctionCall: toOpenAPIFunctionCall(toolCall.FunctionCall), + } +} + +// toOpenAPIToolType converts entity.ToolType to openapi ToolType +func toOpenAPIToolType(tt entity.ToolType) ToolType { + switch tt { + case entity.ToolTypeFunction: + return ToolTypeFunction + default: + return ToolTypeFunction + } +} + +// toOpenAPIFunctionCall converts entity.FunctionCall to openapi FunctionCall +func toOpenAPIFunctionCall(fc *entity.FunctionCall) *FunctionCall { + if fc == nil { + return nil + } + return &FunctionCall{ + Name: fc.Name, + Arguments: fc.Arguments, + } +} diff --git a/internal/prompt/execute_stream.go b/internal/prompt/execute_stream.go new file mode 100755 index 0000000..ce0d816 --- /dev/null +++ b/internal/prompt/execute_stream.go @@ -0,0 +1,95 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package prompt + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/consts" + "github.com/coze-dev/cozeloop-go/internal/httpclient" + "github.com/coze-dev/cozeloop-go/internal/stream" +) + +// ExecuteSSEParser implements SSEParser for ExecuteResult +type ExecuteSSEParser struct { + logID string +} + +// NewExecuteSSEParser creates a new ExecuteSSEParser +func NewExecuteSSEParser(logID string) *ExecuteSSEParser { + return &ExecuteSSEParser{ + logID: logID, + } +} + +// Parse parses SSE event into ExecuteResult +func (p *ExecuteSSEParser) Parse(sse *stream.ServerSentEvent) (entity.ExecuteResult, error) { + // Skip empty data + if sse.Data == "" { + return entity.ExecuteResult{}, nil + } + + // Parse streaming response + var executeStreamingData ExecuteStreamingData + if err := json.Unmarshal([]byte(sse.Data), &executeStreamingData); err != nil { + return entity.ExecuteResult{}, fmt.Errorf("failed to unmarshal streaming response: %w", err) + } + + // Convert to ExecuteResult + result := entity.ExecuteResult{} + result.Message = toModelMessage(executeStreamingData.Message) + result.FinishReason = executeStreamingData.FinishReason + result.Usage = toModelTokenUsage(executeStreamingData.Usage) + + return result, nil +} + +// HandleError checks if the SSE event contains an error +func (p *ExecuteSSEParser) HandleError(sse *stream.ServerSentEvent) error { + // Check if event field contains "error" (case-insensitive) + if sse.Event != "" && bytes.Contains([]byte(strings.ToLower(sse.Event)), []byte("error")) { + // This is an error event, parse the data field for error information + data := sse.Data + if data == "" { + // Event indicates error but no data, return generic error + return consts.NewRemoteServiceError(http.StatusOK, -1, "Error event received without data", p.logID) + } + + // Try to parse as error response + var errResp httpclient.BaseResponse + if err := json.Unmarshal([]byte(data), &errResp); err == nil { + return consts.NewRemoteServiceError(http.StatusOK, errResp.Code, errResp.Msg, p.logID) + } + + // If no structured error found, return raw data as error message + return consts.NewRemoteServiceError(http.StatusOK, -1, data, p.logID) + } + + // Event field doesn't contain "error", this is not an error event + return nil +} + +// ExecuteStreamReader wraps BaseStreamReader for ExecuteResult +type ExecuteStreamReader struct { + *stream.BaseStreamReader[entity.ExecuteResult] +} + +// NewExecuteStreamReader creates a new ExecuteStreamReader +func NewExecuteStreamReader(ctx context.Context, resp *http.Response) (*ExecuteStreamReader, error) { + // 从响应头中获取logID + logID := resp.Header.Get(consts.LogIDHeader) + + parser := NewExecuteSSEParser(logID) + baseReader := stream.NewBaseStreamReader[entity.ExecuteResult](ctx, resp, parser) + + return &ExecuteStreamReader{ + BaseStreamReader: baseReader, + }, nil +} diff --git a/internal/prompt/openapi.go b/internal/prompt/openapi.go index 36ae148..413227c 100644 --- a/internal/prompt/openapi.go +++ b/internal/prompt/openapi.go @@ -6,7 +6,9 @@ package prompt import ( "context" "encoding/json" + "net/http" "sort" + "time" "golang.org/x/sync/singleflight" @@ -14,8 +16,12 @@ import ( ) const ( - mpullPromptPath = "/v1/loop/prompts/mget" - maxPromptQueryBatchSize = 25 + mpullPromptPath = "/v1/loop/prompts/mget" + executePromptPath = "/v1/loop/prompts/execute" + executeStreamingPromptPath = "/v1/loop/prompts/execute_streaming" + maxPromptQueryBatchSize = 25 + + defaultExecuteTimeout = 10 * time.Minute ) type Prompt struct { @@ -42,9 +48,12 @@ const ( ) type Message struct { - Role Role `json:"role"` - Content *string `json:"content,omitempty"` - Parts []*ContentPart `json:"parts,omitempty"` + Role Role `json:"role"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + Content *string `json:"content,omitempty"` + Parts []*ContentPart `json:"parts,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` + ToolCalls []*ToolCall `json:"tool_calls,omitempty"` } type Role string @@ -58,14 +67,18 @@ const ( ) type ContentPart struct { - Type *ContentType `json:"type"` - Text *string `json:"text,omitempty"` + Type *ContentType `json:"type"` + Text *string `json:"text,omitempty"` + ImageURL *string `json:"image_url,omitempty"` + Base64Data *string `json:"base64_data,omitempty"` } type ContentType string const ( ContentTypeText ContentType = "text" + ContentTypeImageURL ContentType = "image_url" + ContentTypeBase64Data ContentType = "base64_data" ContentTypeMultiPartVariable ContentType = "multi_part_variable" ) @@ -130,6 +143,30 @@ type LLMConfig struct { JSONMode *bool `json:"json_mode,omitempty"` } +type VariableVal struct { + Key string `json:"key"` + Value *string `json:"value,omitempty"` + PlaceholderMessages []*Message `json:"placeholder_messages,omitempty"` + MultiPartValues []*ContentPart `json:"multi_part_values,omitempty"` +} + +type ToolCall struct { + Index *int32 `json:"index,omitempty"` + ID *string `json:"id,omitempty"` + Type ToolType `json:"type"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments *string `json:"arguments,omitempty"` +} + +type TokenUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + type OpenAPIClient struct { httpClient *httpclient.Client sf singleflight.Group @@ -225,3 +262,48 @@ func (o *OpenAPIClient) doMPullPrompt(ctx context.Context, req MPullPromptReques } return resp.Data.Items, nil } + +type ExecuteRequest struct { + WorkspaceID string `json:"workspace_id"` + PromptIdentifier *PromptQuery `json:"prompt_identifier,omitempty"` + VariableVals []*VariableVal `json:"variable_vals,omitempty"` + Messages []*Message `json:"messages,omitempty"` +} + +type ExecuteResponse struct { + httpclient.BaseResponse + Data *ExecuteData `json:"data"` +} + +type ExecuteData struct { + Message *Message `json:"message,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + Usage *TokenUsage `json:"usage,omitempty"` +} + +// ExecuteStreamingData 流式执行响应数据结构体 +type ExecuteStreamingData struct { + Code *int32 `json:"code,omitempty"` + Msg *string `json:"msg,omitempty"` + Message *Message `json:"message,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + Usage *TokenUsage `json:"usage,omitempty"` +} + +// Execute 执行Prompt请求 +func (o *OpenAPIClient) Execute(ctx context.Context, req ExecuteRequest) (*ExecuteData, error) { + ctx, cancel := context.WithTimeout(ctx, defaultExecuteTimeout) + defer cancel() + var response ExecuteResponse + err := o.httpClient.Post(ctx, executePromptPath, req, &response) + if err != nil { + return nil, err + } + return response.Data, nil +} + +// ExecuteStreaming 流式执行Prompt请求 +func (o *OpenAPIClient) ExecuteStreaming(ctx context.Context, req ExecuteRequest) (*http.Response, error) { + ctx, _ = context.WithTimeout(ctx, defaultExecuteTimeout) + return o.httpClient.PostStream(ctx, executeStreamingPromptPath, req) +} diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt_hub.go similarity index 98% rename from internal/prompt/prompt.go rename to internal/prompt/prompt_hub.go index 929720a..a43e0ea 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt_hub.go @@ -40,11 +40,9 @@ type GetPromptParam struct { Label string } -type GetPromptOptions struct { -} +type GetPromptOptions struct{} -type PromptFormatOptions struct { -} +type PromptFormatOptions struct{} func NewPromptProvider(httpClient *httpclient.Client, traceProvider *trace.Provider, options Options) *Provider { openAPI := &OpenAPIClient{httpClient: httpClient} @@ -266,7 +264,8 @@ func validateVariableValuesType(variableDefs []*entity.VariableDef, variables ma func formatNormalMessages(templateType entity.TemplateType, messages []*entity.Message, variableDefs []*entity.VariableDef, - variableVals map[string]any) (results []*entity.Message, err error) { + variableVals map[string]any, +) (results []*entity.Message, err error) { variableDefMap := make(map[string]*entity.VariableDef) for _, variableDef := range variableDefs { if variableDef != nil { @@ -300,7 +299,8 @@ func formatNormalMessages(templateType entity.TemplateType, func formatMultiPart(templateType entity.TemplateType, parts []*entity.ContentPart, defMap map[string]*entity.VariableDef, - valMap map[string]any) []*entity.ContentPart { + valMap map[string]any, +) []*entity.ContentPart { var formatedParts []*entity.ContentPart // render text for _, part := range parts { @@ -370,7 +370,8 @@ func formatPlaceholderMessages(messages []*entity.Message, variableVals map[stri func renderTextContent(templateType entity.TemplateType, templateStr string, variableDefMap map[string]*entity.VariableDef, - variableVals map[string]any) (string, error) { + variableVals map[string]any, +) (string, error) { switch templateType { case entity.TemplateTypeNormal: return fasttemplate.ExecuteFuncString(templateStr, consts.PromptNormalTemplateStartTag, consts.PromptNormalTemplateEndTag, func(w io.Writer, tag string) (int, error) { diff --git a/internal/prompt/prompt_test.go b/internal/prompt/prompt_hub_test.go similarity index 100% rename from internal/prompt/prompt_test.go rename to internal/prompt/prompt_hub_test.go diff --git a/internal/prompt/ptaas.go b/internal/prompt/ptaas.go new file mode 100644 index 0000000..360423d --- /dev/null +++ b/internal/prompt/ptaas.go @@ -0,0 +1,181 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package prompt + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/coze-dev/cozeloop-go/entity" + "github.com/coze-dev/cozeloop-go/internal/consts" +) + +// ExecuteOptions Execute选项 +type ExecuteOptions struct{} + +// ExecuteStreamingOptions ExecuteStreaming选项 +type ExecuteStreamingOptions struct{} + +// ExecuteOption Execute选项函数 +type ExecuteOption func(option *ExecuteOptions) + +// ExecuteStreamingOption ExecuteStreaming选项函数 +type ExecuteStreamingOption func(option *ExecuteStreamingOptions) + +// Execute 执行Prompt并返回结果 +func (p *Provider) Execute(ctx context.Context, req *entity.ExecuteParam, options ...ExecuteOption) (entity.ExecuteResult, error) { + result := entity.ExecuteResult{} + // 处理选项 + opts := &ExecuteOptions{} + for _, option := range options { + option(opts) + } + + // 构建请求体 + executeReq, err := buildExecuteRequest(req, p.config.WorkspaceID) + if err != nil { + return entity.ExecuteResult{}, err + } + + // 通过OpenAPIClient发送HTTP请求 + data, err := p.openAPIClient.Execute(ctx, executeReq) + if err != nil { + return result, err + } + + if data != nil { + result.Message = toModelMessage(data.Message) + result.FinishReason = data.FinishReason + result.Usage = toModelTokenUsage(data.Usage) + } + // 转换响应 + return result, nil +} + +// ExecuteStreaming 流式执行Prompt并返回流式读取器 +func (p *Provider) ExecuteStreaming(ctx context.Context, req *entity.ExecuteParam, options ...ExecuteStreamingOption) (entity.StreamReader[entity.ExecuteResult], error) { + // 处理选项 + opts := &ExecuteStreamingOptions{} + for _, option := range options { + option(opts) + } + + // 构建请求体 + executeReq, err := buildExecuteRequest(req, p.config.WorkspaceID) + if err != nil { + return nil, err + } + + // 通过OpenAPIClient发送流式HTTP请求 + resp, err := p.openAPIClient.ExecuteStreaming(ctx, executeReq) + if err != nil { + return nil, err + } + + // 创建新的流式读取器 + streamReader, err := NewExecuteStreamReader(ctx, resp) + if err != nil { + return nil, err + } + + return streamReader, nil +} + +// buildExecuteRequest 构建Execute请求体 +func buildExecuteRequest(param *entity.ExecuteParam, workspaceID string) (ExecuteRequest, error) { + if param == nil { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("execute param is nil")) + } + if param.PromptKey == "" { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("prompt key is empty")) + } + + executeReq := ExecuteRequest{ + WorkspaceID: workspaceID, + PromptIdentifier: &PromptQuery{ + PromptKey: param.PromptKey, + Version: param.Version, + Label: param.Label, + }, + Messages: toOpenAPIMessages(param.Messages), + } + + // 添加变量值 + var variableVals []*VariableVal + for key, value := range param.VariableVals { + if value == nil { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("variable: %s val is nil", key)) + } + + variableVal := &VariableVal{Key: key} + switch v := value.(type) { + // string 类型 + case string: + variableVal.Value = &v + // string 指针类型 + case *string: + if v == nil { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("variable: %s val is nil", key)) + } + variableVal.Value = v + + // Message 相关类型 + case entity.Message: + variableVal.PlaceholderMessages = []*Message{toOpenAPIMessage(&v)} + case *entity.Message: + if v == nil { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("variable: %s val is nil", key)) + } + variableVal.PlaceholderMessages = []*Message{toOpenAPIMessage(v)} + case []*entity.Message: + var apiMsgs []*Message + for _, msg := range v { + apiMsgs = append(apiMsgs, toOpenAPIMessage(msg)) + } + variableVal.PlaceholderMessages = apiMsgs + case []entity.Message: + var apiMsgs []*Message + for _, msg := range v { + apiMsgs = append(apiMsgs, toOpenAPIMessage(&msg)) + } + variableVal.PlaceholderMessages = apiMsgs + + // ContentPart 相关类型 + case entity.ContentPart: + variableVal.MultiPartValues = []*ContentPart{toOpenAPIContentPart(&v)} + case *entity.ContentPart: + if v == nil { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("variable: %s val is nil", key)) + } + variableVal.MultiPartValues = []*ContentPart{toOpenAPIContentPart(v)} + case []*entity.ContentPart: + var apiParts []*ContentPart + for _, part := range v { + apiParts = append(apiParts, toOpenAPIContentPart(part)) + } + variableVal.MultiPartValues = apiParts + case []entity.ContentPart: + var apiParts []*ContentPart + for _, part := range v { + apiParts = append(apiParts, toOpenAPIContentPart(&part)) + } + variableVal.MultiPartValues = apiParts + + // 其他类型序列化后传入 Value 字段 + default: + jsonBytes, err := json.Marshal(value) + if err != nil { + return ExecuteRequest{}, consts.ErrInvalidParam.Wrap(fmt.Errorf("failed to marshal variable %s: %w", key, err)) + } + jsonStr := string(jsonBytes) + variableVal.Value = &jsonStr + } + + variableVals = append(variableVals, variableVal) + } + executeReq.VariableVals = variableVals + + return executeReq, nil +} diff --git a/internal/stream/base_reader.go b/internal/stream/base_reader.go new file mode 100755 index 0000000..323cb81 --- /dev/null +++ b/internal/stream/base_reader.go @@ -0,0 +1,103 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package stream + +import ( + "context" + "fmt" + "net/http" +) + +// SSEParser defines the interface for parsing SSE events into specific types +type SSEParser[T any] interface { + Parse(sse *ServerSentEvent) (T, error) + HandleError(sse *ServerSentEvent) error +} + +// BaseStreamReader provides generic SSE stream reading capabilities +type BaseStreamReader[T any] struct { + ctx context.Context + response *http.Response + decoder *SSEDecoder + parser SSEParser[T] + closed bool + events <-chan SSEEvent +} + +// NewBaseStreamReader creates a new base stream reader +func NewBaseStreamReader[T any](ctx context.Context, resp *http.Response, parser SSEParser[T]) *BaseStreamReader[T] { + decoder := NewSSEDecoder(resp.Body) + events := decoder.Decode(ctx) + + return &BaseStreamReader[T]{ + ctx: ctx, + response: resp, + decoder: decoder, + parser: parser, + closed: false, + events: events, + } +} + +// Recv receives the next item from the stream +func (r *BaseStreamReader[T]) Recv() (T, error) { + var zero T + + if r.closed { + return zero, fmt.Errorf("stream reader is closed") + } + + for { + select { + case <-r.ctx.Done(): + r.Close() + return zero, r.ctx.Err() + + case sseEvent, ok := <-r.events: + if !ok { + // Channel closed, stream ended + r.Close() + return zero, fmt.Errorf("stream ended") + } + + if sseEvent.Error != nil { + r.Close() + return zero, sseEvent.Error + } + + if sseEvent.Event == nil { + continue + } + + // Check for error events first + if err := r.parser.HandleError(sseEvent.Event); err != nil { + r.Close() + return zero, err + } + + // Parse the event + result, err := r.parser.Parse(sseEvent.Event) + if err != nil { + // Continue to next event for parsing errors + continue + } + + return result, nil + } + } +} + +// Close closes the stream reader and releases resources +func (r *BaseStreamReader[T]) Close() error { + if r.closed { + return nil + } + + r.closed = true + if r.response != nil && r.response.Body != nil { + return r.response.Body.Close() + } + + return nil +} diff --git a/internal/stream/sse.go b/internal/stream/sse.go new file mode 100755 index 0000000..3f7c872 --- /dev/null +++ b/internal/stream/sse.go @@ -0,0 +1,134 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +package stream + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/coze-dev/cozeloop-go/internal/util" +) + +// ServerSentEvent represents a Server-Sent Event +type ServerSentEvent struct { + Event string + Data string + ID string + Retry *int +} + +// JSON unmarshals the Data field into the provided interface +func (sse *ServerSentEvent) JSON(v interface{}) error { + if sse.Data == "" { + return fmt.Errorf("empty data field") + } + return json.Unmarshal([]byte(sse.Data), v) +} + +// SSEDecoder decodes Server-Sent Events from an io.Reader +type SSEDecoder struct { + scanner *bufio.Scanner +} + +// NewSSEDecoder creates a new SSE decoder +func NewSSEDecoder(reader io.Reader) *SSEDecoder { + return &SSEDecoder{ + scanner: bufio.NewScanner(reader), + } +} + +// Decode decodes SSE events from the reader and returns a channel +func (d *SSEDecoder) Decode(ctx context.Context) <-chan SSEEvent { + ch := make(chan SSEEvent, 1) + + util.GoSafe(ctx, func() { + defer close(ch) + + for { + event, err := d.DecodeEvent() + ch <- SSEEvent{ + Event: event, + Error: err, + } + } + }) + + return ch +} + +// SSEEvent wraps either an event or an error +type SSEEvent struct { + Event *ServerSentEvent + Error error +} + +// DecodeEvent decodes a single SSE event +func (d *SSEDecoder) DecodeEvent() (*ServerSentEvent, error) { + event := &ServerSentEvent{} + var dataLines []string + + for d.scanner.Scan() { + line := d.scanner.Text() + + // Empty line indicates end of event + if strings.TrimSpace(line) == "" { + if len(dataLines) > 0 || event.Event != "" || event.ID != "" || event.Retry != nil { + event.Data = strings.Join(dataLines, "\n") + return event, nil + } + continue + } + + colonIndex := strings.Index(line, ":") + if colonIndex == -1 { + // Line without colon, treat as field name with empty value + field := strings.TrimSpace(line) + d.processField(event, field, "", &dataLines) + continue + } + + field := line[:colonIndex] + value := line[colonIndex+1:] + + // Remove leading space from value + if strings.HasPrefix(value, " ") { + value = value[1:] + } + + d.processField(event, field, value, &dataLines) + } + + if err := d.scanner.Err(); err != nil { + return nil, err + } + + // If we reach here, it's EOF + if len(dataLines) > 0 || event.Event != "" || event.ID != "" || event.Retry != nil { + event.Data = strings.Join(dataLines, "\n") + return event, nil + } + + return nil, io.EOF +} + +// processField processes a single SSE field +func (d *SSEDecoder) processField(event *ServerSentEvent, field, value string, dataLines *[]string) { + switch field { + case "event": + event.Event = value + case "data": + *dataLines = append(*dataLines, value) + case "id": + event.ID = value + case "retry": + if retry, err := strconv.Atoi(value); err == nil { + event.Retry = &retry + } + } +} diff --git a/internal/trace/exporter.go b/internal/trace/exporter.go index a8bb9a5..0671c12 100644 --- a/internal/trace/exporter.go +++ b/internal/trace/exporter.go @@ -187,16 +187,14 @@ func parseTag(spanTag map[string]interface{}, isSystemTag bool) (map[string]stri return vStrMap, vLongMap, vDoubleMap, vBoolMap } -var ( - tagValueConverterMap = map[string]*tagValueConverter{ - tracespec.Input: { - convertFunc: convertInput, - }, - tracespec.Output: { - convertFunc: convertOutput, - }, - } -) +var tagValueConverterMap = map[string]*tagValueConverter{ + tracespec.Input: { + convertFunc: convertInput, + }, + tracespec.Output: { + convertFunc: convertOutput, + }, +} type tagValueConverter struct { convertFunc func(ctx context.Context, spanKey string, span *Span) (valueRes string, uploadFile []*entity.UploadFile, err error) @@ -395,7 +393,7 @@ func transferText(src string, span *Span, tagKey string) (string, *entity.Upload } if len(src) > consts.MaxBytesOfOneTagValueOfInputOutput { - //key := "traceid/spanid/tagkey/filetype/large_text" + // key := "traceid/spanid/tagkey/filetype/large_text" key := fmt.Sprintf(KeyTemplateLargeText, span.GetTraceID(), span.GetSpanID(), tagKey, fileTypeText) return util.TruncateStringByChar(src, consts.TextTruncateCharLength), &entity.UploadFile{ TosKey: key, @@ -418,7 +416,7 @@ func transferImage(src *tracespec.ModelImageURL, span *Span, tagKey string) *ent return nil } - //key := "traceid_spanid_tagkey_filetype_randomid" + // key := "traceid_spanid_tagkey_filetype_randomid" key := fmt.Sprintf(KeyTemplateMultiModality, span.GetTraceID(), span.GetSpanID(), tagKey, fileTypeImage, util.Gen16CharID()) bin, _ := base64.StdEncoding.DecodeString(src.URL) src.URL = key @@ -441,7 +439,7 @@ func transferFile(src *tracespec.ModelFileURL, span *Span, tagKey string) *entit return nil } - //key := "traceid/spanid/tagkey/filetype/randomid" + // key := "traceid/spanid/tagkey/filetype/randomid" key := fmt.Sprintf(KeyTemplateMultiModality, span.GetTraceID(), span.GetSpanID(), tagKey, fileTypeFile, util.Gen16CharID()) bin, _ := base64.StdEncoding.DecodeString(src.URL) src.URL = key diff --git a/internal/trace/exporter_test.go b/internal/trace/exporter_test.go index 42f3faf..2061e04 100644 --- a/internal/trace/exporter_test.go +++ b/internal/trace/exporter_test.go @@ -14,7 +14,7 @@ import ( func Test_ExportSpans(t *testing.T) { ctx := context.Background() - spans := []*UploadSpan{&UploadSpan{}, &UploadSpan{}} + spans := []*UploadSpan{{}, {}} PatchConvey("Test transferToUploadSpanAndFile failed", t, func() { Mock((*httpclient.Client).Post).Return(nil).Build() diff --git a/internal/trace/queue_manager.go b/internal/trace/queue_manager.go index 9247059..c89e64b 100644 --- a/internal/trace/queue_manager.go +++ b/internal/trace/queue_manager.go @@ -199,7 +199,7 @@ func (b *BatchQueueManager) Enqueue(ctx context.Context, sd interface{}, byteSiz return } var extraParams *consts.FinishEventInfoExtra - var eventType = consts.SpanFinishEventFileQueueEntryRate + eventType := consts.SpanFinishEventFileQueueEntryRate var detailMsg string var isFail bool select { diff --git a/internal/trace/span.go b/internal/trace/span.go index dbff22e..a40937b 100644 --- a/internal/trace/span.go +++ b/internal/trace/span.go @@ -636,7 +636,7 @@ func (s *Span) addDefaultTag(ctx context.Context, tagKVs map[string]interface{}) // GetRectifiedMap get rectified tag map and cut off keys func (s *Span) GetRectifiedMap(ctx context.Context, inputMap map[string]interface{}) (map[string]interface{}, []string, int64) { - var validateMap = make(map[string]interface{}) + validateMap := make(map[string]interface{}) var cutOffKeys []string var bytesSize int64 for key, value := range inputMap { @@ -777,7 +777,7 @@ func isValidBaggageItem(ctx context.Context, key, value string) bool { logger.CtxInfof(ctx, "length of Baggage is too large, key:%s, value:%s", key, value) return false } - //special char check + // special char check if hasSpecialChar(key) { logger.CtxErrorf(ctx, "Baggage should not contain special characters, key:%s, value:%s", key, value) return false diff --git a/internal/trace/span_processor.go b/internal/trace/span_processor.go index 57a6cab..320f632 100644 --- a/internal/trace/span_processor.go +++ b/internal/trace/span_processor.go @@ -79,8 +79,8 @@ func NewBatchSpanProcessor( if ex != nil { exporter = ex } - var spanQueueLength = DefaultMaxQueueLength - var spanMaxExportBatchLength = DefaultMaxExportBatchLength + spanQueueLength := DefaultMaxQueueLength + spanMaxExportBatchLength := DefaultMaxExportBatchLength if queueConf != nil { if queueConf.SpanQueueLength > 0 { spanQueueLength = queueConf.SpanQueueLength diff --git a/internal/trace/span_test.go b/internal/trace/span_test.go index 0a59834..b046ed0 100644 --- a/internal/trace/span_test.go +++ b/internal/trace/span_test.go @@ -165,7 +165,7 @@ func Test_SpanSpecialTag(t *testing.T) { now := time.Now() s := &Span{ isFinished: 0, - //spanProcessor: GetBatchSpanProcessor(httpClient, GetBatchFileProcessor(httpClient)), + // spanProcessor: GetBatchSpanProcessor(httpClient, GetBatchFileProcessor(httpClient)), lock: sync.RWMutex{}, TagMap: make(map[string]interface{}), } @@ -282,7 +282,6 @@ func Test_SpanSpecialTag(t *testing.T) { So(len(span.GetTagMap()), ShouldEqual, 14) So(len(span.GetBaggage()), ShouldEqual, 7) - }) } diff --git a/internal/util/convert.go b/internal/util/convert.go index d86dd93..22e7036 100644 --- a/internal/util/convert.go +++ b/internal/util/convert.go @@ -30,11 +30,9 @@ func RmDupStrSlice(slice []string) []string { return res } -var ( - bufferPool = sync.Pool{New: func() interface{} { - return new(bytes.Buffer) - }} -) +var bufferPool = sync.Pool{New: func() interface{} { + return new(bytes.Buffer) +}} func GetStringBuffer() *bytes.Buffer { return bufferPool.Get().(*bytes.Buffer) diff --git a/internal/util/validate_test.go b/internal/util/validate_test.go index 831f90c..8f4f198 100644 --- a/internal/util/validate_test.go +++ b/internal/util/validate_test.go @@ -12,5 +12,4 @@ func TestIsValidMDNBase64(t *testing.T) { t.Errorf("ParseValidMDNBase64() = %v", got) } }) - } diff --git a/noop.go b/noop.go index dbee28b..41a4e6a 100644 --- a/noop.go +++ b/noop.go @@ -37,6 +37,16 @@ func (c *NoopClient) PromptFormat(ctx context.Context, prompt *entity.Prompt, va return nil, c.newClientError } +func (c *NoopClient) Execute(ctx context.Context, req *entity.ExecuteParam, options ...ExecuteOption) (entity.ExecuteResult, error) { + logger.CtxWarnf(context.Background(), "Noop client not supported. %v", c.newClientError) + return entity.ExecuteResult{}, c.newClientError +} + +func (c *NoopClient) ExecuteStreaming(ctx context.Context, req *entity.ExecuteParam, options ...ExecuteStreamingOption) (entity.StreamReader[entity.ExecuteResult], error) { + logger.CtxWarnf(context.Background(), "Noop client not supported. %v", c.newClientError) + return nil, c.newClientError +} + func (c *NoopClient) StartSpan(ctx context.Context, name, spanType string, opts ...StartSpanOption) (context.Context, Span) { logger.CtxWarnf(context.Background(), "Noop client not supported. %v", c.newClientError) return ctx, DefaultNoopSpan diff --git a/prompt.go b/prompt.go index 9aa81bf..86b0b1e 100644 --- a/prompt.go +++ b/prompt.go @@ -17,6 +17,10 @@ type PromptClient interface { GetPrompt(ctx context.Context, param GetPromptParam, options ...GetPromptOption) (*entity.Prompt, error) // PromptFormat format prompt with variables PromptFormat(ctx context.Context, prompt *entity.Prompt, variables map[string]any, options ...PromptFormatOption) (messages []*entity.Message, err error) + // Execute execute prompt and return result + Execute(ctx context.Context, param *entity.ExecuteParam, options ...ExecuteOption) (entity.ExecuteResult, error) + // ExecuteStreaming execute prompt in streaming mode and return stream reader + ExecuteStreaming(ctx context.Context, param *entity.ExecuteParam, options ...ExecuteStreamingOption) (entity.StreamReader[entity.ExecuteResult], error) } type GetPromptParam = prompt.GetPromptParam @@ -24,3 +28,7 @@ type GetPromptParam = prompt.GetPromptParam type GetPromptOption func(option *prompt.GetPromptOptions) type PromptFormatOption func(option *prompt.PromptFormatOptions) + +type ExecuteOption = prompt.ExecuteOption + +type ExecuteStreamingOption = prompt.ExecuteStreamingOption diff --git a/span.go b/span.go index 34ee338..f9a4044 100644 --- a/span.go +++ b/span.go @@ -20,7 +20,7 @@ type Span interface { SetTags(ctx context.Context, tagKVs map[string]interface{}) // SetBaggage sets tags and also passes these tags to other downstream spans (assuming - //the user uses ToHeader and FromHeader to handle header passing between services). + // the user uses ToHeader and FromHeader to handle header passing between services). SetBaggage(ctx context.Context, baggageItems map[string]string) // Finish The span will be reported only after an explicit call to Finish. diff --git a/spec/tracespec/span_key.go b/spec/tracespec/span_key.go index 74110b1..54c12de 100644 --- a/spec/tracespec/span_key.go +++ b/spec/tracespec/span_key.go @@ -50,4 +50,5 @@ const ( CallType = "call_type" LogID = "log_id" + TraceID = "trace_id" ) diff --git a/spec/tracespec/span_value.go b/spec/tracespec/span_value.go index d7d6511..990bdad 100644 --- a/spec/tracespec/span_value.go +++ b/spec/tracespec/span_value.go @@ -5,11 +5,13 @@ package tracespec // SpanType tag builtin values const ( - VPromptHubSpanType = "prompt_hub" - VPromptTemplateSpanType = "prompt" - VModelSpanType = "model" - VRetrieverSpanType = "retriever" - VToolSpanType = "tool" + VPromptHubSpanType = "prompt_hub" + VPromptTemplateSpanType = "prompt" + VPromptExecuteSpanType = "prompt_execute" + VPromptExecuteStreamingSpanType = "prompt_execute_streaming" + VModelSpanType = "model" + VRetrieverSpanType = "retriever" + VToolSpanType = "tool" ) const ( @@ -40,10 +42,12 @@ const ( VLibLangChain = "langchain" VLibOpentelemetry = "opentelemetry" - VSceneCustom = "custom" // user custom, it has the same meaning as blank. - VScenePromptHub = "prompt_hub" // get_prompt - VScenePromptTemplate = "prompt_template" // prompt_template - VSceneIntegration = "integration" + VSceneCustom = "custom" // user custom, it has the same meaning as blank. + VScenePromptHub = "prompt_hub" // get_prompt + VScenePromptTemplate = "prompt_template" // prompt_template + VScenePromptExecute = "prompt_execute" // execute_prompt + VScenePromptExecuteStreaming = "prompt_execute_streaming" // execute_prompt_streaming + VSceneIntegration = "integration" ) // Tag values for prompt input.