diff --git a/components/model/openai-go/README.md b/components/model/openai-go/README.md new file mode 100644 index 000000000..dd9a1df21 --- /dev/null +++ b/components/model/openai-go/README.md @@ -0,0 +1,80 @@ +# OpenAI (official openai-go SDK) + +An OpenAI model implementation for [Eino](https://github.com/cloudwego/eino) using the official OpenAI Go SDK (`github.com/openai/openai-go/v3`). This is intended as a starting point for eventual replacement of the existing openai implementation (see ../openai) which is based on github.com/sashabaranov/go-openai. Newer models from OpenAI increasingly do not fully support the older chat completions API which github.com/sashabaranov/go-openai is based on. Consequently, this component targets the **Responses API only**. + +## Features + +- Implements `github.com/cloudwego/eino/components/model.ToolCallingChatModel` +- Responses API (non-stream + streaming) +- Tool calling support (function tools) +- Multimodal inputs via `schema.Message.UserInputMultiContent`: + - text + - image_url (URL or base64 via `Base64Data` + `MIMEType`) + - file_url (URL or base64) + +## Installation + +```bash +go get github.com/cloudwego/eino-ext/components/model/openai-go@latest +``` + +## Quick start + +```go +package main + +import ( + "context" + "log" + "os" + + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino-ext/components/model/openai-go" +) + +func main() { + ctx := context.Background() + + cm, err := openaigo.NewChatModel(ctx, &openaigo.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: "gpt-5.4", // any Responses API capable model + }) + if err != nil { + log.Fatal(err) + } + + out, err := cm.Generate(ctx, []*schema.Message{ + {Role: schema.User, Content: "Hello"}, + }) + if err != nil { + log.Fatal(err) + } + + log.Println(out.Content) +} +``` + +## Tool calling + +Bind tools using `WithTools()`: + +```go +cm2, err := cm.WithTools([]*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "Get weather at the given location", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "location": {Type: schema.String, Required: true}, + }), + }, +}) +``` + +Then control selection with Eino common options: + +- `model.WithTools(...)` +- `model.WithToolChoice(schema.ToolChoiceAllowed|Forced|Forbidden, allowedToolNames...)` + +## Streaming + +Use `Stream()` to receive incremental `*schema.Message` deltas. diff --git a/components/model/openai-go/chatmodel.go b/components/model/openai-go/chatmodel.go new file mode 100644 index 000000000..4647c8012 --- /dev/null +++ b/components/model/openai-go/chatmodel.go @@ -0,0 +1,189 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" +) + +var _ model.ToolCallingChatModel = (*ChatModel)(nil) + +type Config struct { + APIKey string `json:"api_key"` + + // Timeout specifies the maximum duration to wait for API responses. + // If HTTPClient is set, Timeout will not be used. + // Optional. Default: no timeout + Timeout time.Duration `json:"timeout"` + + // HTTPClient specifies the client to send HTTP requests. + // If HTTPClient is set, Timeout will not be used. + // Optional. Default &http.Client{Timeout: Timeout} + HTTPClient *http.Client `json:"http_client"` + + // BaseURL specifies the OpenAI endpoint URL + // Optional. Default: https://api.openai.com/v1 + BaseURL string `json:"base_url"` + + // Model specifies the ID of the model to use. + // Optional. + Model string `json:"model,omitempty"` + + // MaxOutputTokens is an upper bound for the number of tokens that can be generated for a response, + // including visible output tokens and reasoning tokens. + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + + TopP *float32 `json:"top_p,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + + // Reasoning config for reasoning models. + Reasoning *Reasoning `json:"reasoning,omitempty"` + + // Store indicates whether to store the generated model response for later retrieval. + Store *bool `json:"store,omitempty"` + + // Metadata set of key-value pairs that can be attached to an object. + Metadata map[string]string `json:"metadata,omitempty"` + + // ExtraFields will override any existing fields with the same key. + // Optional. Useful for experimental features not yet officially supported. + ExtraFields map[string]any `json:"extra_fields,omitempty"` +} + +type ChatModel struct { + cli openai.Client + + model string + maxOutTok *int + topP *float32 + temperature *float32 + reasoning *Reasoning + store *bool + metadata map[string]string + extraFields map[string]any + + tools []responses.ToolUnionParam + rawTools []*schema.ToolInfo + toolChoice *schema.ToolChoice +} + +func NewChatModel(_ context.Context, config *Config) (*ChatModel, error) { + if config == nil { + return nil, fmt.Errorf("config cannot be nil") + } + + opts := make([]option.RequestOption, 0, 4) + if config.APIKey != "" { + opts = append(opts, option.WithAPIKey(config.APIKey)) + } + if config.BaseURL != "" { + opts = append(opts, option.WithBaseURL(config.BaseURL)) + } + if config.HTTPClient != nil { + opts = append(opts, option.WithHTTPClient(config.HTTPClient)) + } else if config.Timeout > 0 { + opts = append(opts, option.WithHTTPClient(&http.Client{Timeout: config.Timeout})) + } + + cli := openai.NewClient(opts...) + + cm := &ChatModel{ + cli: cli, + model: config.Model, + maxOutTok: config.MaxOutputTokens, + topP: config.TopP, + temperature: config.Temperature, + reasoning: config.Reasoning, + store: config.Store, + metadata: cloneStringMap(config.Metadata), + extraFields: cloneAnyMap(config.ExtraFields), + } + + return cm, nil +} + +func (cm *ChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (outMsg *schema.Message, err error) { + ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) + + params, cbIn, err := cm.buildParams(in, false, opts...) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, cbIn) + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + resp, err := cm.cli.Responses.New(ctx, params) + if err != nil { + return nil, err + } + + outMsg, err = cm.convertResponseToMessage(resp) + if err != nil { + return nil, err + } + + callbacks.OnEnd(ctx, &model.CallbackOutput{ + Message: outMsg, + Config: cbIn.Config, + TokenUsage: toModelTokenUsage(outMsg.ResponseMeta), + Extra: map[string]any{ + callbackExtraModelName: string(resp.Model), + }, + }) + + return outMsg, nil +} + +func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + if len(tools) == 0 { + return nil, errors.New("no tools to bind") + } + openAITools, rawTools, err := toOpenAITools(tools) + if err != nil { + return nil, err + } + + tc := schema.ToolChoiceAllowed + ncm := *cm + ncm.tools = openAITools + ncm.rawTools = rawTools + ncm.toolChoice = &tc + return &ncm, nil +} + +const typ = "OpenAI" + +func (cm *ChatModel) GetType() string { return typ } + +func (cm *ChatModel) IsCallbacksEnabled() bool { return true } diff --git a/components/model/openai-go/chatmodel_test.go b/components/model/openai-go/chatmodel_test.go new file mode 100644 index 000000000..c7bbf4f72 --- /dev/null +++ b/components/model/openai-go/chatmodel_test.go @@ -0,0 +1,133 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "context" + "testing" + + "github.com/cloudwego/eino/schema" +) + +func TestNewChatModel_NilConfig(t *testing.T) { + cm, err := NewChatModel(context.Background(), nil) + if err == nil { + t.Fatalf("expected error") + } + if cm != nil { + t.Fatalf("expected nil model") + } +} + +func TestNewChatModel_Basic(t *testing.T) { + cm, err := NewChatModel(context.Background(), &Config{APIKey: "test", Model: "gpt-4o-mini"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cm == nil { + t.Fatalf("expected non-nil model") + } + if cm.GetType() != typ { + t.Fatalf("expected type %q, got %q", typ, cm.GetType()) + } + if !cm.IsCallbacksEnabled() { + t.Fatalf("expected callbacks enabled") + } +} + +func TestNewChatModel_ClonesConfigMaps(t *testing.T) { + metadata := map[string]string{"source": "config"} + extra := map[string]any{"trace_id": "abc123"} + + cm, err := NewChatModel(context.Background(), &Config{ + Model: "gpt-4o-mini", + Metadata: metadata, + ExtraFields: extra, + Reasoning: &Reasoning{ + Effort: ReasoningEffortLow, + Summary: ReasoningSummaryDetailed, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + metadata["source"] = "mutated" + extra["trace_id"] = "changed" + + if got := cm.metadata["source"]; got != "config" { + t.Fatalf("expected cloned metadata to stay unchanged, got %q", got) + } + if got := cm.extraFields["trace_id"]; got != "abc123" { + t.Fatalf("expected cloned extra fields to stay unchanged, got %#v", got) + } + if cm.reasoning == nil || cm.reasoning.Effort != ReasoningEffortLow || cm.reasoning.Summary != ReasoningSummaryDetailed { + t.Fatalf("unexpected reasoning config: %#v", cm.reasoning) + } +} + +func TestWithTools(t *testing.T) { + cm := &ChatModel{} + + if _, err := cm.WithTools(nil); err == nil { + t.Fatalf("expected error for empty tools") + } + + if _, err := cm.WithTools([]*schema.ToolInfo{nil}); err == nil { + t.Fatalf("expected error for nil tool") + } + + tool := makeWeatherTool() + binding, err := cm.WithTools([]*schema.ToolInfo{tool}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + bound, ok := binding.(*ChatModel) + if !ok { + t.Fatalf("expected *ChatModel, got %T", binding) + } + if bound == cm { + t.Fatalf("expected WithTools to return a bound copy") + } + if len(bound.tools) != 1 || len(bound.rawTools) != 1 { + t.Fatalf("expected one bound tool, got tools=%d rawTools=%d", len(bound.tools), len(bound.rawTools)) + } + if bound.rawTools[0] != tool { + t.Fatalf("expected raw tool to be preserved") + } + if bound.toolChoice == nil || *bound.toolChoice != schema.ToolChoiceAllowed { + t.Fatalf("expected allowed tool choice, got %#v", bound.toolChoice) + } + if len(cm.tools) != 0 || len(cm.rawTools) != 0 || cm.toolChoice != nil { + t.Fatalf("expected receiver to remain unchanged") + } +} + +func TestToInputItems_ToolOutputString(t *testing.T) { + items, err := toInputItems([]*schema.Message{{ + Role: schema.Tool, + ToolCallID: "call_1", + Content: "ok", + }}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d", len(items)) + } +} diff --git a/components/model/openai-go/examples/generate/generate.go b/components/model/openai-go/examples/generate/generate.go new file mode 100644 index 000000000..5586ebb45 --- /dev/null +++ b/components/model/openai-go/examples/generate/generate.go @@ -0,0 +1,49 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/cloudwego/eino/schema" + + openaigo "github.com/cloudwego/eino-ext/components/model/openai-go" +) + +func main() { + ctx := context.Background() + + cm, err := openaigo.NewChatModel(ctx, &openaigo.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: os.Getenv("OPENAI_MODEL"), + BaseURL: os.Getenv("OPENAI_BASE_URL"), + }) + if err != nil { + log.Fatalf("NewChatModel failed, err=%v", err) + } + + resp, err := cm.Generate(ctx, []*schema.Message{ + {Role: schema.User, Content: "as a machine, how do you answer user's question?"}, + }) + if err != nil { + log.Fatalf("Generate failed, err=%v", err) + } + fmt.Printf("output: \n%v", resp) +} diff --git a/components/model/openai-go/examples/generate_with_image/generate_with_image.go b/components/model/openai-go/examples/generate_with_image/generate_with_image.go new file mode 100644 index 000000000..126a05219 --- /dev/null +++ b/components/model/openai-go/examples/generate_with_image/generate_with_image.go @@ -0,0 +1,69 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/cloudwego/eino/schema" + + openaigo "github.com/cloudwego/eino-ext/components/model/openai-go" +) + +func main() { + ctx := context.Background() + + cm, err := openaigo.NewChatModel(ctx, &openaigo.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: os.Getenv("OPENAI_MODEL"), + BaseURL: os.Getenv("OPENAI_BASE_URL"), + }) + if err != nil { + log.Fatalf("NewChatModel failed, err=%v", err) + } + + multiModalMsg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "this picture is a landscape photo, what's the picture's content", + }, + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + URL: of("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT11qEDxU4X_MVKYQVU5qiAVFidA58f8GG0bQ&s"), + }, + Detail: schema.ImageURLDetailAuto, + }, + }, + }, + } + + resp, err := cm.Generate(ctx, []*schema.Message{multiModalMsg}) + if err != nil { + log.Fatalf("Generate failed, err=%v", err) + } + + fmt.Printf("output: \n%v", resp) +} + +func of[T any](a T) *T { return &a } diff --git a/components/model/openai-go/examples/image_generate/image_generate.go b/components/model/openai-go/examples/image_generate/image_generate.go new file mode 100644 index 000000000..405a60566 --- /dev/null +++ b/components/model/openai-go/examples/image_generate/image_generate.go @@ -0,0 +1,79 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "encoding/json" + "log" + "os" + + "github.com/cloudwego/eino/schema" + + openaigo "github.com/cloudwego/eino-ext/components/model/openai-go" +) + +func main() { + ctx := context.Background() + + cm, err := openaigo.NewChatModel(ctx, &openaigo.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: os.Getenv("OPENAI_MODEL"), // model should support image generation via Responses API + BaseURL: os.Getenv("OPENAI_BASE_URL"), + }) + if err != nil { + log.Fatalf("NewChatModel failed, err=%v", err) + } + + /* + The generated multimodal content is stored in the `AssistantGenMultiContent` field. + For this example, the resulting message will have a structure similar to this: + + resp := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &base64String, // The base64 encoded image data + MIMEType: "image/png", + }, + }, + }, + }, + } + */ + resp, err := cm.Generate(ctx, []*schema.Message{ + { + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "Generate an image of a cat", + }, + }, + }, + }) + if err != nil { + log.Fatalf("Generate error: %v", err) + } + + log.Printf("\ngenerate output:\n") + respBody, _ := json.MarshalIndent(resp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/model/openai-go/examples/intent_tool/intent_tool.go b/components/model/openai-go/examples/intent_tool/intent_tool.go new file mode 100644 index 000000000..960894d1b --- /dev/null +++ b/components/model/openai-go/examples/intent_tool/intent_tool.go @@ -0,0 +1,114 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "io" + "log" + "os" + + "github.com/cloudwego/eino/schema" + + openaigo "github.com/cloudwego/eino-ext/components/model/openai-go" +) + +func main() { + ctx := context.Background() + + chatModel, err := openaigo.NewChatModel(ctx, &openaigo.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: os.Getenv("OPENAI_MODEL"), + BaseURL: os.Getenv("OPENAI_BASE_URL"), + Reasoning: &openaigo.Reasoning{ + Effort: openaigo.ReasoningEffortMedium, + Summary: openaigo.ReasoningSummaryAuto, + }, + }) + if err != nil { + log.Fatalf("NewChatModel failed, err=%v", err) + } + + cm, err := chatModel.WithTools([]*schema.ToolInfo{ + { + Name: "user_company", + Desc: "Retrieve the user's company and position based on their name and email.", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "name": {Type: "string", Desc: "user's name"}, + "email": {Type: "string", Desc: "user's email"}, + }), + }, + { + Name: "user_salary", + Desc: "Retrieve the user's salary based on their name and email.", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "name": {Type: "string", Desc: "user's name"}, + "email": {Type: "string", Desc: "user's email"}, + }), + }, + }) + if err != nil { + log.Fatalf("WithTools failed, err=%v", err) + } + + resp, err := cm.Generate(ctx, []*schema.Message{ + { + Role: schema.System, + Content: "As a real estate agent, provide relevant property information based on the user's salary and job using the user_company and user_salary APIs. An email address is required.", + }, + { + Role: schema.User, + Content: "My name is John and my email is john@abc.com, please recommend some houses that suit me.", + }, + }) + if err != nil { + log.Fatalf("Generate failed, err=%v", err) + } + fmt.Printf("output: \n%v\n", resp) + + streamResp, err := cm.Stream(ctx, []*schema.Message{ + { + Role: schema.System, + Content: "As a real estate agent, provide relevant property information based on the user's salary and job using the user_company and user_salary APIs. An email address is required.", + }, + { + Role: schema.User, + Content: "My name is John and my email is john@abc.com, please recommend some houses that suit me.", + }, + }) + if err != nil { + log.Fatalf("Stream failed, err=%v", err) + } + + var messages []*schema.Message + for { + chunk, err := streamResp.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Fatalf("Recv failed, err=%v", err) + } + messages = append(messages, chunk) + } + resp2, err := schema.ConcatMessages(messages) + if err != nil { + log.Fatalf("ConcatMessages failed, err=%v", err) + } + fmt.Printf("stream output: \n%v\n", resp2) +} diff --git a/components/model/openai-go/examples/stream/stream.go b/components/model/openai-go/examples/stream/stream.go new file mode 100644 index 000000000..816f6c76c --- /dev/null +++ b/components/model/openai-go/examples/stream/stream.go @@ -0,0 +1,68 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "io" + "log" + "os" + + "github.com/cloudwego/eino/schema" + + openaigo "github.com/cloudwego/eino-ext/components/model/openai-go" +) + +func main() { + ctx := context.Background() + + cm, err := openaigo.NewChatModel(ctx, &openaigo.Config{ + APIKey: os.Getenv("OPENAI_API_KEY"), + Model: os.Getenv("OPENAI_MODEL"), + BaseURL: os.Getenv("OPENAI_BASE_URL"), + }) + if err != nil { + log.Fatalf("NewChatModel failed, err=%v", err) + } + + stream, err := cm.Stream(ctx, []*schema.Message{ + {Role: schema.User, Content: "Write a short poem about spring."}, + }) + if err != nil { + log.Fatalf("Stream error: %v", err) + } + + fmt.Println("Assistant:") + for { + chunk, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Fatalf("Stream receive error: %v", err) + } + + if chunk.Content != "" { + fmt.Print(chunk.Content) + } + if chunk.ReasoningContent != "" { + fmt.Printf("\n[reasoning]\n%s\n", chunk.ReasoningContent) + } + } + fmt.Println() +} diff --git a/components/model/openai-go/go.mod b/components/model/openai-go/go.mod new file mode 100644 index 000000000..53f89b8f0 --- /dev/null +++ b/components/model/openai-go/go.mod @@ -0,0 +1,41 @@ +module github.com/cloudwego/eino-ext/components/model/openai-go + +go 1.22.3 + +require ( + github.com/cloudwego/eino v0.8.11 + github.com/openai/openai-go/v3 v3.32.0 +) + +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/eino-contrib/jsonschema v1.0.3 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/sys v0.29.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/model/openai-go/go.sum b/components/model/openai-go/go.sum new file mode 100644 index 000000000..b00ab9524 --- /dev/null +++ b/components/model/openai-go/go.sum @@ -0,0 +1,149 @@ +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.8.11 h1:lf/j1VXQTzPV9/pXijgjAIELTxvZi6zbPobmv3h/gco= +github.com/cloudwego/eino v0.8.11/go.mod h1:+2N4nsMPxA6kGBHpH+75JuTfEcGprAMTdsZESrShKpU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/openai/openai-go/v3 v3.32.0 h1:aHp/3wkX1W6jB8zTtf9xV0aK0qPFSVDqS7AHmlJ4hXs= +github.com/openai/openai-go/v3 v3.32.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/components/model/openai-go/option.go b/components/model/openai-go/option.go new file mode 100644 index 000000000..2239f9843 --- /dev/null +++ b/components/model/openai-go/option.go @@ -0,0 +1,63 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import "github.com/cloudwego/eino/components/model" + +type options struct { + MaxOutputTokens *int + Reasoning *Reasoning + Store *bool + Metadata map[string]string + ExtraFields map[string]any +} + +// WithMaxOutputTokens sets max_output_tokens for the Responses API. +func WithMaxOutputTokens(n int) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.MaxOutputTokens = &n + }) +} + +// WithReasoning overrides the reasoning config for this request. +func WithReasoning(r *Reasoning) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.Reasoning = r + }) +} + +// WithStore sets whether to store the response. +func WithStore(store bool) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.Store = &store + }) +} + +// WithMetadata overrides request metadata. +func WithMetadata(m map[string]string) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.Metadata = cloneStringMap(m) + }) +} + +// WithExtraFields injects extra fields into the request body. +// Extra fields overwrite any existing fields with the same key. +func WithExtraFields(extra map[string]any) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.ExtraFields = cloneAnyMap(extra) + }) +} diff --git a/components/model/openai-go/stream.go b/components/model/openai-go/stream.go new file mode 100644 index 000000000..ed7b9a81f --- /dev/null +++ b/components/model/openai-go/stream.go @@ -0,0 +1,218 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "context" + "fmt" + "runtime/debug" + "strings" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/responses" +) + +func (cm *ChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (outStream *schema.StreamReader[*schema.Message], err error) { + ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) + + params, cbIn, err := cm.buildParams(in, true, opts...) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, cbIn) + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + stream := cm.cli.Responses.NewStreaming(ctx, params) + + sr, sw := schema.Pipe[*model.CallbackOutput](1) + go func() { + defer func() { + pe := recover() + _ = stream.Close() + if pe != nil { + _ = sw.Send(nil, newPanicErr(pe, debug.Stack())) + } + sw.Close() + }() + + state := newStreamState() + for stream.Next() { + ev := stream.Current() + msg, done, deltaOnly, err2 := state.consume(ev) + if err2 != nil { + _ = sw.Send(nil, err2) + return + } + if msg == nil { + continue + } + + // ensure callbacks can receive token usage on final chunk. + if !deltaOnly { + msg.ResponseMeta = ensureResponseMeta(msg.ResponseMeta) + } + + closed := sw.Send(&model.CallbackOutput{ + Message: msg, + Config: cbIn.Config, + TokenUsage: toModelTokenUsage(msg.ResponseMeta), + Extra: func() map[string]any { + if done && state.modelName != "" { + return map[string]any{callbackExtraModelName: state.modelName} + } + return nil + }(), + }, nil) + if closed { + return + } + } + + if stream.Err() != nil { + _ = sw.Send(nil, stream.Err()) + return + } + }() + + ctx, nsr := callbacks.OnEndWithStreamOutput(ctx, schema.StreamReaderWithConvert(sr, + func(src *model.CallbackOutput) (callbacks.CallbackOutput, error) { return src, nil }, + )) + + outStream = schema.StreamReaderWithConvert(nsr, func(src callbacks.CallbackOutput) (*schema.Message, error) { + s := src.(*model.CallbackOutput) + if s.Message == nil { + return nil, schema.ErrNoValue + } + return s.Message, nil + }) + + return outStream, nil +} + +// consume and map streaming events into eino messages. +type streamState struct { + modelName string + functionArgBufs map[string]*strings.Builder // key: item_id + callIDByItemID map[string]string + nameByItemID map[string]string +} + +func newStreamState() *streamState { + return &streamState{ + functionArgBufs: make(map[string]*strings.Builder), + callIDByItemID: make(map[string]string), + nameByItemID: make(map[string]string), + } +} + +// consume returns: +// - msg: message chunk (delta) +// - done: if this chunk ends the response +// - deltaOnly: whether it's a pure delta message (so no finalization) +func (s *streamState) consume(ev responses.ResponseStreamEventUnion) (msg *schema.Message, done bool, deltaOnly bool, err error) { + switch v := ev.AsAny().(type) { + case responses.ResponseErrorEvent: + return nil, false, false, fmt.Errorf("openai stream error: %s (%s)", v.Message, v.Code) + case responses.ResponseCreatedEvent: + s.modelName = string(v.Response.Model) + return nil, false, true, nil + case responses.ResponseInProgressEvent: + // ignore; model name can be here too + s.modelName = string(v.Response.Model) + return nil, false, true, nil + case responses.ResponseTextDeltaEvent: + if v.Delta == "" { + return nil, false, true, nil + } + return &schema.Message{Role: schema.Assistant, Content: v.Delta}, false, true, nil + case responses.ResponseReasoningTextDeltaEvent: + if v.Delta == "" { + return nil, false, true, nil + } + m := &schema.Message{Role: schema.Assistant, ReasoningContent: v.Delta} + return m, false, true, nil + case responses.ResponseOutputItemAddedEvent: + // function call item appears here with call_id and name + item := v.Item + if item.Type == "function_call" { + call := item.AsFunctionCall() + s.callIDByItemID[item.ID] = call.CallID + s.nameByItemID[item.ID] = call.Name + } + return nil, false, true, nil + case responses.ResponseFunctionCallArgumentsDeltaEvent: + if v.Delta == "" { + return nil, false, true, nil + } + b := s.functionArgBufs[v.ItemID] + if b == nil { + b = &strings.Builder{} + s.functionArgBufs[v.ItemID] = b + } + b.WriteString(v.Delta) + return nil, false, true, nil + case responses.ResponseFunctionCallArgumentsDoneEvent: + // Finalize args: only emit ToolCalls when arguments are complete. + callID := s.callIDByItemID[v.ItemID] + name := s.nameByItemID[v.ItemID] + if callID == "" { + callID = v.ItemID + } + + args := v.Arguments + if args == "" { + if b := s.functionArgBufs[v.ItemID]; b != nil { + args = b.String() + } + } + return &schema.Message{Role: schema.Assistant, ToolCalls: []schema.ToolCall{{ + ID: callID, + Type: "function", + Function: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }}}, false, true, nil + case responses.ResponseCompletedEvent: + // IMPORTANT: do not emit the full final assistant message content here. + // The Responses streaming API already sends the assistant text as deltas + // (ResponseTextDeltaEvent / ResponseReasoningTextDeltaEvent). Emitting the + // final full message (resp.OutputText()) would cause downstream consumers + // that concatenate chunks to duplicate output. + return &schema.Message{ + Role: schema.Assistant, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: string(v.Response.Status), + Usage: toEinoTokenUsage(v.Response.Usage), + }, + }, true, false, nil + case responses.ResponseFailedEvent: + return &schema.Message{Role: schema.Assistant, ResponseMeta: &schema.ResponseMeta{FinishReason: string(v.Response.Status)}}, true, false, nil + case responses.ResponseIncompleteEvent: + return &schema.Message{Role: schema.Assistant, ResponseMeta: &schema.ResponseMeta{FinishReason: string(v.Response.Status), Usage: toEinoTokenUsage(v.Response.Usage)}}, true, false, nil + default: + return nil, false, true, nil + } +} diff --git a/components/model/openai-go/stream_test.go b/components/model/openai-go/stream_test.go new file mode 100644 index 000000000..19626e7f8 --- /dev/null +++ b/components/model/openai-go/stream_test.go @@ -0,0 +1,189 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "strings" + "testing" + + "github.com/openai/openai-go/v3/responses" +) + +func TestStreamStateConsume(t *testing.T) { + t.Run("created and progress events update model", func(t *testing.T) { + s := newStreamState() + msg, done, deltaOnly, err := s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.created", + "response": map[string]any{"model": "gpt-created"}, + })) + if err != nil || msg != nil || done || !deltaOnly { + t.Fatalf("unexpected created event result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + if s.modelName != "gpt-created" { + t.Fatalf("expected model name to be recorded, got %q", s.modelName) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.in_progress", + "response": map[string]any{"model": "gpt-progress"}, + })) + if err != nil || msg != nil || done || !deltaOnly { + t.Fatalf("unexpected in-progress result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + if s.modelName != "gpt-progress" { + t.Fatalf("expected model name from in-progress event, got %q", s.modelName) + } + }) + + t.Run("text and reasoning deltas", func(t *testing.T) { + s := newStreamState() + msg, done, deltaOnly, err := s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.output_text.delta", + "delta": "hello", + })) + if err != nil || done || !deltaOnly || msg == nil || msg.Content != "hello" { + t.Fatalf("unexpected text delta result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.reasoning_text.delta", + "delta": "thinking", + })) + if err != nil || done || !deltaOnly || msg == nil || msg.ReasoningContent != "thinking" { + t.Fatalf("unexpected reasoning delta result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.output_text.delta", + })) + if err != nil || msg != nil || done || !deltaOnly { + t.Fatalf("expected empty text delta to be ignored, got msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + }) + + t.Run("function call lifecycle", func(t *testing.T) { + s := newStreamState() + _, _, _, err := s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.output_item.added", + "item": map[string]any{ + "type": "function_call", + "id": "item_1", + "call_id": "call_123", + "name": "lookup_weather", + "arguments": "", + }, + })) + if err != nil { + t.Fatalf("unexpected output item added error: %v", err) + } + if got := s.callIDByItemID["item_1"]; got != "call_123" { + t.Fatalf("expected call id to be tracked, got %q", got) + } + if got := s.nameByItemID["item_1"]; got != "lookup_weather" { + t.Fatalf("expected function name to be tracked, got %q", got) + } + + msg, done, deltaOnly, err := s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.function_call_arguments.delta", + "item_id": "item_1", + "delta": `{"city":"`, + })) + if err != nil || msg != nil || done || !deltaOnly { + t.Fatalf("unexpected function-call delta result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + if got := s.functionArgBufs["item_1"].String(); got != `{"city":"` { + t.Fatalf("unexpected buffered args %q", got) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.function_call_arguments.done", + "item_id": "item_1", + "arguments": `{"city":"beijing"}`, + })) + if err != nil || done || !deltaOnly || msg == nil { + t.Fatalf("unexpected function-call done result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + if len(msg.ToolCalls) != 1 || msg.ToolCalls[0].ID != "call_123" || msg.ToolCalls[0].Function.Name != "lookup_weather" || msg.ToolCalls[0].Function.Arguments != `{"city":"beijing"}` { + t.Fatalf("unexpected emitted tool call: %#v", msg.ToolCalls) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.function_call_arguments.done", + "item_id": "item_fallback", + "name": "fallback_tool", + })) + if err != nil || done || !deltaOnly || msg == nil { + t.Fatalf("unexpected fallback done result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + if msg.ToolCalls[0].ID != "item_fallback" || msg.ToolCalls[0].Function.Name != "" { + t.Fatalf("unexpected fallback tool call: %#v", msg.ToolCalls[0]) + } + }) + + t.Run("completion and terminal events", func(t *testing.T) { + s := newStreamState() + msg, done, deltaOnly, err := s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "status": "completed", + "usage": map[string]any{"input_tokens": 3, "output_tokens": 2, "total_tokens": 5}, + }, + })) + if err != nil || !done || deltaOnly || msg == nil { + t.Fatalf("unexpected completed result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + if msg.ResponseMeta == nil || msg.ResponseMeta.FinishReason != string(responses.ResponseStatusCompleted) || msg.ResponseMeta.Usage == nil || msg.ResponseMeta.Usage.TotalTokens != 5 { + t.Fatalf("unexpected completed response meta: %#v", msg.ResponseMeta) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.failed", + "response": map[string]any{"status": "failed"}, + })) + if err != nil || !done || deltaOnly || msg == nil || msg.ResponseMeta.FinishReason != string(responses.ResponseStatusFailed) { + t.Fatalf("unexpected failed result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + + msg, done, deltaOnly, err = s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "response.incomplete", + "response": map[string]any{ + "status": "incomplete", + "usage": map[string]any{"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + })) + if err != nil || !done || deltaOnly || msg == nil || msg.ResponseMeta.FinishReason != string(responses.ResponseStatusIncomplete) || msg.ResponseMeta.Usage == nil || msg.ResponseMeta.Usage.TotalTokens != 2 { + t.Fatalf("unexpected incomplete result: msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + }) + + t.Run("error and unknown events", func(t *testing.T) { + s := newStreamState() + msg, done, deltaOnly, err := s.consume(mustJSON[responses.ResponseStreamEventUnion](t, map[string]any{ + "type": "error", + "message": "boom", + "code": "bad_request", + })) + if err == nil || !strings.Contains(err.Error(), "boom") || msg != nil || done || deltaOnly { + t.Fatalf("expected stream error, got msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + + msg, done, deltaOnly, err = s.consume(responses.ResponseStreamEventUnion{Type: "response.output_item.done"}) + if err != nil || msg != nil || done || !deltaOnly { + t.Fatalf("expected unknown event to be ignored, got msg=%#v done=%v deltaOnly=%v err=%v", msg, done, deltaOnly, err) + } + }) +} diff --git a/components/model/openai-go/typeconv.go b/components/model/openai-go/typeconv.go new file mode 100644 index 000000000..3090ec48d --- /dev/null +++ b/components/model/openai-go/typeconv.go @@ -0,0 +1,499 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "fmt" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" +) + +func (cm *ChatModel) buildParams(in []*schema.Message, stream bool, opts ...model.Option) (responses.ResponseNewParams, *model.CallbackInput, error) { + common := model.GetCommonOptions(&model.Options{ + Temperature: cm.temperature, + MaxTokens: func() *int { + // Responses API uses MaxOutputTokens; keep MaxTokens in common opts unused. + return nil + }(), + Model: &cm.model, + TopP: cm.topP, + Tools: cm.rawTools, + ToolChoice: cm.toolChoice, + }, opts...) + + spec := model.GetImplSpecificOptions(&options{ + MaxOutputTokens: cm.maxOutTok, + Reasoning: cm.reasoning, + Store: cm.store, + Metadata: cm.metadata, + ExtraFields: cm.extraFields, + }, opts...) + + params := responses.ResponseNewParams{} + if common.Model != nil { + params.Model = responsesModelFromString(*common.Model) + } + if spec.MaxOutputTokens != nil { + params.MaxOutputTokens = openai.Int(int64(*spec.MaxOutputTokens)) + } + if common.Temperature != nil { + params.Temperature = openai.Float(float64(*common.Temperature)) + } + if common.TopP != nil { + params.TopP = openai.Float(float64(*common.TopP)) + } + if spec.Store != nil { + params.Store = openai.Bool(*spec.Store) + } + if len(spec.Metadata) > 0 { + params.Metadata = spec.Metadata + } + if spec.Reasoning != nil { + params.Reasoning = spec.Reasoning.toSDK() + } + if stream { + params.StreamOptions = responses.ResponseNewParamsStreamOptions{IncludeObfuscation: openai.Bool(false)} + } + + // Tools. + tools := cm.tools + cbTools := cm.rawTools + if common.Tools != nil { + var err error + tools, cbTools, err = toOpenAITools(common.Tools) + if err != nil { + return responses.ResponseNewParams{}, nil, err + } + } + if len(tools) > 0 { + params.Tools = tools + } + + if err := populateToolChoice(¶ms, common.ToolChoice, common.AllowedToolNames, tools); err != nil { + return responses.ResponseNewParams{}, nil, err + } + + // Input. + inputItems, err := toInputItems(in) + if err != nil { + return responses.ResponseNewParams{}, nil, err + } + params.Input = responses.ResponseNewParamsInputUnion{OfInputItemList: inputItems} + + if len(spec.ExtraFields) > 0 { + params.SetExtraFields(spec.ExtraFields) + } + + cbIn := &model.CallbackInput{ + Messages: in, + Tools: cbTools, + ToolChoice: common.ToolChoice, + Config: &model.Config{ + Model: string(params.Model), + MaxTokens: int(optInt64(params.MaxOutputTokens)), + Temperature: float32(optFloat64(params.Temperature)), + TopP: float32(optFloat64(params.TopP)), + }, + } + + return params, cbIn, nil +} + +func toInputItems(in []*schema.Message) (responses.ResponseInputParam, error) { + items := make([]responses.ResponseInputItemUnionParam, 0, len(in)) + for _, msg := range in { + if msg == nil { + continue + } + switch msg.Role { + case schema.User: + content, err := toInputContentFromMessage(msg) + if err != nil { + return nil, err + } + items = append(items, responses.ResponseInputItemParamOfMessage(content, responses.EasyInputMessageRoleUser)) + case schema.System: + content, err := toInputContentFromMessage(msg) + if err != nil { + return nil, err + } + items = append(items, responses.ResponseInputItemParamOfMessage(content, responses.EasyInputMessageRoleSystem)) + case schema.Assistant: + assistantText, hasAssistantText, err := extractAssistantTextForHistory(msg) + if err != nil { + return nil, err + } + if hasAssistantText { + items = append(items, responses.ResponseInputItemParamOfMessage(assistantText, responses.EasyInputMessageRoleAssistant)) + } + + // assistant tool calls + for _, tc := range msg.ToolCalls { + items = append(items, responses.ResponseInputItemParamOfFunctionCall(tc.Function.Arguments, tc.ID, tc.Function.Name)) + } + case schema.Tool: + // tool call output + if msg.ToolCallID == "" { + return nil, fmt.Errorf("tool message missing ToolCallID") + } + if len(msg.UserInputMultiContent) == 0 { + items = append(items, responses.ResponseInputItemParamOfFunctionCallOutput(msg.ToolCallID, msg.Content)) + break + } + outItems := make([]responses.ResponseFunctionCallOutputItemUnionParam, 0, len(msg.UserInputMultiContent)) + for _, part := range msg.UserInputMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + outItems = append(outItems, responses.ResponseFunctionCallOutputItemUnionParam{OfInputText: &responses.ResponseInputTextContentParam{Text: part.Text}}) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return nil, fmt.Errorf("image field must not be nil in tool message") + } + url, err := commonToDataOrURL(part.Image.MessagePartCommon) + if err != nil { + return nil, err + } + outItems = append(outItems, responses.ResponseFunctionCallOutputItemUnionParam{OfInputImage: &responses.ResponseInputImageContentParam{ImageURL: openai.String(url)}}) + case schema.ChatMessagePartTypeFileURL: + if part.File == nil { + return nil, fmt.Errorf("file field must not be nil in tool message") + } + url, err := commonToDataOrURL(part.File.MessagePartCommon) + if err != nil { + return nil, err + } + p := &responses.ResponseInputFileContentParam{} + if part.File.URL != nil { + p.FileURL = openai.String(url) + } else { + p.FileData = openai.String(url) + } + if part.File.Name != "" { + p.Filename = openai.String(part.File.Name) + } + outItems = append(outItems, responses.ResponseFunctionCallOutputItemUnionParam{OfInputFile: p}) + default: + return nil, fmt.Errorf("unsupported tool output content type: %s", part.Type) + } + } + items = append(items, responses.ResponseInputItemParamOfFunctionCallOutput(msg.ToolCallID, responses.ResponseFunctionCallOutputItemListParam(outItems))) + default: + return nil, fmt.Errorf("unknown role: %s", msg.Role) + } + } + + return items, nil +} + +func toInputContentFromMessage(msg *schema.Message) (responses.ResponseInputMessageContentListParam, error) { + if len(msg.UserInputMultiContent) > 0 && len(msg.AssistantGenMultiContent) > 0 { + return nil, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } + if len(msg.UserInputMultiContent) > 0 { + parts := make([]responses.ResponseInputContentUnionParam, 0, len(msg.UserInputMultiContent)) + for _, part := range msg.UserInputMultiContent { + p, err := toInputContentPartFromInputPart(part) + if err != nil { + return nil, err + } + parts = append(parts, p) + } + return responses.ResponseInputMessageContentListParam(parts), nil + } + if len(msg.AssistantGenMultiContent) > 0 { + // For assistant messages, only text parts can be re-sent as input. + parts := make([]responses.ResponseInputContentUnionParam, 0, len(msg.AssistantGenMultiContent)) + for _, part := range msg.AssistantGenMultiContent { + if part.Type != schema.ChatMessagePartTypeText { + return nil, fmt.Errorf("unsupported assistant output part type in re-input: %s", part.Type) + } + parts = append(parts, responses.ResponseInputContentUnionParam{OfInputText: &responses.ResponseInputTextParam{Text: part.Text}}) + } + return responses.ResponseInputMessageContentListParam(parts), nil + } + + // Backward compatible deprecated MultiContent. + if len(msg.MultiContent) > 0 { + parts := make([]responses.ResponseInputContentUnionParam, 0, len(msg.MultiContent)) + for _, c := range msg.MultiContent { + switch c.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, responses.ResponseInputContentUnionParam{OfInputText: &responses.ResponseInputTextParam{Text: c.Text}}) + case schema.ChatMessagePartTypeImageURL: + if c.ImageURL == nil { + continue + } + parts = append(parts, responses.ResponseInputContentUnionParam{OfInputImage: &responses.ResponseInputImageParam{ + Detail: responses.ResponseInputImageDetailAuto, + ImageURL: openai.String(func() string { + if c.ImageURL.URI != "" { + return c.ImageURL.URI + } + return c.ImageURL.URL + }()), + }}) + default: + return nil, fmt.Errorf("unsupported deprecated MultiContent part type: %s", c.Type) + } + } + return responses.ResponseInputMessageContentListParam(parts), nil + } + + if msg.Content == "" { + // allow empty content for assistant messages + if msg.Role == schema.Assistant { + return responses.ResponseInputMessageContentListParam([]responses.ResponseInputContentUnionParam{}), nil + } + return nil, fmt.Errorf("message content is empty") + } + return responses.ResponseInputMessageContentListParam([]responses.ResponseInputContentUnionParam{{ + OfInputText: &responses.ResponseInputTextParam{Text: msg.Content}, + }}), nil +} + +func toInputContentPartFromInputPart(part schema.MessageInputPart) (responses.ResponseInputContentUnionParam, error) { + switch part.Type { + case schema.ChatMessagePartTypeText: + return responses.ResponseInputContentUnionParam{OfInputText: &responses.ResponseInputTextParam{Text: part.Text}}, nil + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return responses.ResponseInputContentUnionParam{}, fmt.Errorf("image field must not be nil when type is %s", part.Type) + } + url, err := commonToDataOrURL(part.Image.MessagePartCommon) + if err != nil { + return responses.ResponseInputContentUnionParam{}, err + } + return responses.ResponseInputContentUnionParam{OfInputImage: &responses.ResponseInputImageParam{ + Detail: toSDKImageDetail(part.Image.Detail), + ImageURL: openai.String(url), + }}, nil + case schema.ChatMessagePartTypeFileURL: + if part.File == nil { + return responses.ResponseInputContentUnionParam{}, fmt.Errorf("file field must not be nil when type is %s", part.Type) + } + fileURL, err := commonToDataOrURL(part.File.MessagePartCommon) + if err != nil { + return responses.ResponseInputContentUnionParam{}, err + } + fileParam := &responses.ResponseInputFileParam{} + if part.File.URL != nil { + fileParam.FileURL = openai.String(fileURL) + } else if part.File.Base64Data != nil { + fileParam.FileData = openai.String(fileURL) + } + if part.File.Name != "" { + fileParam.Filename = openai.String(part.File.Name) + } + return responses.ResponseInputContentUnionParam{OfInputFile: fileParam}, nil + default: + return responses.ResponseInputContentUnionParam{}, fmt.Errorf("unsupported content type: %s", part.Type) + } +} + +// Deprecated: tool call outputs are constructed inline in toInputItems. +func toFunctionCallOutputFromToolMessage(_ *schema.Message) (any, error) { + return nil, fmt.Errorf("deprecated") +} + +func populateToolChoice(params *responses.ResponseNewParams, tc *schema.ToolChoice, allowedToolNames []string, tools []responses.ToolUnionParam) error { + if tc == nil { + return nil + } + + switch *tc { + case schema.ToolChoiceForbidden: + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{OfToolChoiceMode: openai.Opt(responses.ToolChoiceOptionsNone)} + return nil + case schema.ToolChoiceAllowed: + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{OfToolChoiceMode: openai.Opt(responses.ToolChoiceOptionsAuto)} + return nil + case schema.ToolChoiceForced: + if len(tools) == 0 { + return fmt.Errorf("tool_choice is forced but no tools are provided") + } + + // If a single allowed tool is specified (or only one tool exists), force it. + var onlyOneToolName string + if len(allowedToolNames) > 0 { + if len(allowedToolNames) > 1 { + return fmt.Errorf("only one allowed tool name can be configured") + } + allowed := allowedToolNames[0] + if !toolNameExists(tools, allowed) { + return fmt.Errorf("allowed tool name '%s' not found in tools list", allowed) + } + onlyOneToolName = allowed + } else if len(tools) == 1 { + if tools[0].OfFunction != nil { + onlyOneToolName = tools[0].OfFunction.Name + } + } + + if onlyOneToolName != "" { + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{OfFunctionTool: &responses.ToolChoiceFunctionParam{Name: onlyOneToolName}} + return nil + } + + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{OfToolChoiceMode: openai.Opt(responses.ToolChoiceOptionsRequired)} + return nil + default: + return fmt.Errorf("unknown tool choice: %s", *tc) + } +} + +func toOpenAITools(tis []*schema.ToolInfo) ([]responses.ToolUnionParam, []*schema.ToolInfo, error) { + tools := make([]responses.ToolUnionParam, len(tis)) + rawTools := make([]*schema.ToolInfo, len(tis)) + copy(rawTools, tis) + for i := range tis { + ti := tis[i] + if ti == nil { + return nil, nil, fmt.Errorf("tool info cannot be nil") + } + paramsJSONSchema, err := ti.ParamsOneOf.ToJSONSchema() + if err != nil { + return nil, nil, fmt.Errorf("failed to convert tool parameters to JSONSchema: %w", err) + } + paramsMap, err := jsonSchemaToMap(paramsJSONSchema) + if err != nil { + return nil, nil, err + } + enforceOpenAIStrictJSONSchema(paramsMap) + t := responses.ToolUnionParam{OfFunction: &responses.FunctionToolParam{ + Name: ti.Name, + Description: openai.String(ti.Desc), + Parameters: paramsMap, + Strict: openai.Bool(true), + }} + tools[i] = t + } + return tools, rawTools, nil +} + +func toSDKImageDetail(detail schema.ImageURLDetail) responses.ResponseInputImageDetail { + switch detail { + case schema.ImageURLDetailHigh: + return responses.ResponseInputImageDetailHigh + case schema.ImageURLDetailLow: + return responses.ResponseInputImageDetailLow + case schema.ImageURLDetailAuto: + return responses.ResponseInputImageDetailAuto + default: + return responses.ResponseInputImageDetailAuto + } +} + +func (cm *ChatModel) convertResponseToMessage(resp *responses.Response) (*schema.Message, error) { + if resp == nil { + return nil, fmt.Errorf("nil response") + } + + msg := &schema.Message{Role: schema.Assistant} + msg.ResponseMeta = ensureResponseMeta(msg.ResponseMeta) + msg.ResponseMeta.FinishReason = string(resp.Status) + msg.ResponseMeta.Usage = toEinoTokenUsage(resp.Usage) + + // Extract tool calls and assistant text. + msg.Content = resp.OutputText() + for _, item := range resp.Output { + switch v := item.AsAny().(type) { + case responses.ResponseFunctionToolCall: + msg.ToolCalls = append(msg.ToolCalls, schema.ToolCall{ + ID: v.CallID, + Type: "function", + Function: schema.FunctionCall{ + Name: v.Name, + Arguments: v.Arguments, + }, + }) + case responses.ResponseOutputItemImageGenerationCall: + // result is base64 image (no data: prefix) + if v.Result != "" { + b64 := v.Result + msg.AssistantGenMultiContent = append(msg.AssistantGenMultiContent, schema.MessageOutputPart{ + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &b64, + MIMEType: "image/png", + }, + }, + }) + } + case responses.ResponseReasoningItem: + // Prefer summary text when provided; otherwise content. + msg.ReasoningContent = joinReasoningText(v) + } + } + + if len(msg.Content) > 0 { + // keep assistant text as first part if no parts exist yet. + if len(msg.AssistantGenMultiContent) == 0 { + msg.AssistantGenMultiContent = append(msg.AssistantGenMultiContent, schema.MessageOutputPart{ + Type: schema.ChatMessagePartTypeText, + Text: msg.Content, + }) + } else { + // prepend text part to existing parts + msg.AssistantGenMultiContent = append([]schema.MessageOutputPart{{ + Type: schema.ChatMessagePartTypeText, + Text: msg.Content, + }}, msg.AssistantGenMultiContent...) + } + } + + return msg, nil +} + +func toEinoTokenUsage(usage responses.ResponseUsage) *schema.TokenUsage { + // usage is a value type; if it is all zeros, treat as absent. + if usage.InputTokens == 0 && usage.OutputTokens == 0 && usage.TotalTokens == 0 { + return nil + } + return &schema.TokenUsage{ + PromptTokens: int(usage.InputTokens), + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: int(usage.InputTokensDetails.CachedTokens), + }, + CompletionTokens: int(usage.OutputTokens), + TotalTokens: int(usage.TotalTokens), + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: int(usage.OutputTokensDetails.ReasoningTokens), + }, + } +} + +func toModelTokenUsage(meta *schema.ResponseMeta) *model.TokenUsage { + if meta == nil || meta.Usage == nil { + return nil + } + u := meta.Usage + return &model.TokenUsage{ + PromptTokens: u.PromptTokens, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: u.PromptTokenDetails.CachedTokens, + }, + CompletionTokens: u.CompletionTokens, + TotalTokens: u.TotalTokens, + CompletionTokensDetails: model.CompletionTokensDetails{ + ReasoningTokens: u.CompletionTokensDetails.ReasoningTokens, + }, + } +} diff --git a/components/model/openai-go/typeconv_test.go b/components/model/openai-go/typeconv_test.go new file mode 100644 index 000000000..cfe8eed7e --- /dev/null +++ b/components/model/openai-go/typeconv_test.go @@ -0,0 +1,677 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" +) + +func TestBuildParams_UsesConfigAndOverrides(t *testing.T) { + maxOut := 50 + cfgTopP := float32(0.7) + cfgTemp := float32(0.3) + store := true + cm := &ChatModel{ + model: "gpt-4o-mini", + maxOutTok: &maxOut, + topP: &cfgTopP, + temperature: &cfgTemp, + reasoning: &Reasoning{ + Effort: ReasoningEffortMedium, + Summary: ReasoningSummaryConcise, + }, + store: &store, + metadata: map[string]string{"scope": "config"}, + extraFields: map[string]any{"base_only": true}, + tools: []responses.ToolUnionParam{responses.ToolParamOfFunction("default_tool", map[string]any{"type": "object"}, true)}, + rawTools: []*schema.ToolInfo{makeWeatherTool()}, + } + + requestTool := makeLookupTool() + forcedMax := 99 + overrideReasoning := &Reasoning{Effort: ReasoningEffortHigh, Summary: ReasoningSummaryDetailed} + params, cbIn, err := cm.buildParams([]*schema.Message{{Role: schema.User, Content: "hello"}}, true, + model.WithModel("gpt-test"), + model.WithTemperature(0.9), + model.WithTopP(0.4), + WithMaxOutputTokens(forcedMax), + WithReasoning(overrideReasoning), + WithStore(false), + WithMetadata(map[string]string{"scope": "request"}), + WithExtraFields(map[string]any{"exp": "beta"}), + model.WithTools([]*schema.ToolInfo{requestTool}), + model.WithToolChoice(schema.ToolChoiceForced, requestTool.Name), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := string(params.Model); got != "gpt-test" { + t.Fatalf("expected model override, got %q", got) + } + if got := optInt64(params.MaxOutputTokens); got != int64(forcedMax) { + t.Fatalf("expected max output tokens %d, got %d", forcedMax, got) + } + if got := optFloat64(params.Temperature); got < 0.899 || got > 0.901 { + t.Fatalf("expected temperature 0.9, got %v", got) + } + if got := optFloat64(params.TopP); got < 0.399 || got > 0.401 { + t.Fatalf("expected top_p 0.4, got %v", got) + } + if !params.Store.Valid() || params.Store.Value != false { + t.Fatalf("expected store override to false, got %#v", params.Store) + } + if !params.StreamOptions.IncludeObfuscation.Valid() || params.StreamOptions.IncludeObfuscation.Value != false { + t.Fatalf("expected stream options for streaming request, got %#v", params.StreamOptions) + } + if got := params.Metadata["scope"]; got != "request" { + t.Fatalf("expected request metadata override, got %q", got) + } + if got := string(params.Reasoning.Effort); got != "high" { + t.Fatalf("expected reasoning effort high, got %q", got) + } + if got := string(params.Reasoning.Summary); got != "detailed" { + t.Fatalf("expected reasoning summary detailed, got %q", got) + } + if len(params.Tools) != 1 || params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != requestTool.Name { + t.Fatalf("expected request tool to be used, got %#v", params.Tools) + } + if params.ToolChoice.OfFunctionTool == nil || params.ToolChoice.OfFunctionTool.Name != requestTool.Name { + t.Fatalf("expected forced function tool choice, got %#v", params.ToolChoice) + } + if got := params.Input.OfInputItemList; len(got) != 1 || got[0].OfMessage == nil { + t.Fatalf("expected one input message, got %#v", got) + } + + body, err := params.MarshalJSON() + if err != nil { + t.Fatalf("marshal params: %v", err) + } + jsonBody := string(body) + if !strings.Contains(jsonBody, `"exp":"beta"`) { + t.Fatalf("expected extra fields in marshaled request, got %s", jsonBody) + } + + if cbIn == nil || cbIn.Config == nil { + t.Fatalf("expected callback input config") + } + if cbIn.Config.Model != "gpt-test" || cbIn.Config.MaxTokens != forcedMax { + t.Fatalf("unexpected callback config: %#v", cbIn.Config) + } + if len(cbIn.Tools) != 1 || cbIn.Tools[0] != requestTool { + t.Fatalf("expected callback tools to use request tools, got %#v", cbIn.Tools) + } + if cbIn.ToolChoice == nil || *cbIn.ToolChoice != schema.ToolChoiceForced { + t.Fatalf("unexpected callback tool choice: %#v", cbIn.ToolChoice) + } +} + +func TestBuildParams_ErrorCases(t *testing.T) { + cm := &ChatModel{model: "gpt-4o-mini"} + + _, _, err := cm.buildParams([]*schema.Message{{Role: schema.User, Content: "hi"}}, false, + model.WithToolChoice(schema.ToolChoiceForced), + ) + if err == nil || !strings.Contains(err.Error(), "no tools") { + t.Fatalf("expected forced tool choice error, got %v", err) + } + + _, _, err = cm.buildParams([]*schema.Message{{Role: schema.User, Content: "hi"}}, false, + model.WithTools([]*schema.ToolInfo{nil}), + ) + if err == nil || !strings.Contains(err.Error(), "tool info cannot be nil") { + t.Fatalf("expected nil tool error, got %v", err) + } +} + +func TestToInputItems_MixedMessages(t *testing.T) { + imgURL := "https://example.com/cat.png" + fileURL := "https://example.com/report.pdf" + fileData := "cGRm" + items, err := toInputItems([]*schema.Message{ + nil, + {Role: schema.System, Content: "You are helpful"}, + {Role: schema.User, UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "see image"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{URL: &imgURL}, + Detail: schema.ImageURLDetailHigh, + }}, + {Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{ + MessagePartCommon: schema.MessagePartCommon{URL: &fileURL}, + Name: "report.pdf", + }}, + }}, + {Role: schema.Assistant, Content: "working", ToolCalls: []schema.ToolCall{{ + ID: "call_1", + Type: "function", + Function: schema.FunctionCall{Name: "lookup_weather", Arguments: `{"city":"beijing"}`}, + }}}, + {Role: schema.Tool, ToolCallID: "call_1", UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "sunny"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{URL: &imgURL}, + }}, + {Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{ + MessagePartCommon: schema.MessagePartCommon{Base64Data: &fileData, MIMEType: "application/pdf"}, + Name: "inline.pdf", + }}, + }}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(items) != 5 { + t.Fatalf("expected 5 input items, got %d", len(items)) + } + if items[0].OfMessage == nil || items[0].OfMessage.Role != responses.EasyInputMessageRoleSystem { + t.Fatalf("expected first item to be a system message, got %#v", items[0]) + } + if items[1].OfMessage == nil || items[1].OfMessage.Role != responses.EasyInputMessageRoleUser { + t.Fatalf("expected second item to be a user message, got %#v", items[1]) + } + if items[2].OfMessage == nil || items[2].OfMessage.Role != responses.EasyInputMessageRoleAssistant { + t.Fatalf("expected third item to be an assistant history message, got %#v", items[2]) + } + if items[3].OfFunctionCall == nil || items[3].OfFunctionCall.CallID != "call_1" { + t.Fatalf("expected fourth item to be the assistant tool call, got %#v", items[3]) + } + if items[4].OfFunctionCallOutput == nil || items[4].OfFunctionCallOutput.CallID != "call_1" { + t.Fatalf("expected fifth item to be tool output, got %#v", items[4]) + } +} + +func TestToInputItems_ErrorCases(t *testing.T) { + tests := []struct { + name string + msg *schema.Message + want string + }{ + {name: "tool missing call id", msg: &schema.Message{Role: schema.Tool, Content: "oops"}, want: "ToolCallID"}, + {name: "tool image nil", msg: &schema.Message{Role: schema.Tool, ToolCallID: "call", UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeImageURL}}}, want: "image field must not be nil"}, + {name: "unknown role", msg: &schema.Message{Role: schema.RoleType("mystery"), Content: "hi"}, want: "unknown role"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := toInputItems([]*schema.Message{tt.msg}) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } +} + +func TestToInputContentFromMessage(t *testing.T) { + imgURL := "https://example.com/img.png" + fileData := "aGVsbG8=" + + t.Run("user multimodal", func(t *testing.T) { + parts, err := toInputContentFromMessage(&schema.Message{Role: schema.User, UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "hello"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &imgURL}}}, + {Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{Base64Data: &fileData, MIMEType: "text/plain"}, Name: "note.txt"}}, + }}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d", len(parts)) + } + }) + + t.Run("assistant output text parts", func(t *testing.T) { + parts, err := toInputContentFromMessage(&schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "one"}, {Type: schema.ChatMessagePartTypeText, Text: "two"}}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d", len(parts)) + } + }) + + t.Run("deprecated multi content", func(t *testing.T) { + parts, err := toInputContentFromMessage(&schema.Message{Role: schema.User, MultiContent: []schema.ChatMessagePart{{Type: schema.ChatMessagePartTypeText, Text: "legacy"}, {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: imgURL}}}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d", len(parts)) + } + }) + + t.Run("empty assistant content allowed", func(t *testing.T) { + parts, err := toInputContentFromMessage(&schema.Message{Role: schema.Assistant}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(parts) != 0 { + t.Fatalf("expected empty parts, got %d", len(parts)) + } + }) + + t.Run("errors", func(t *testing.T) { + tests := []struct { + name string + msg *schema.Message + want string + }{ + {name: "mixed content fields", msg: &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "u"}}, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "a"}}}, want: "cannot contain both"}, + {name: "assistant non text output part", msg: &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeImageURL}}}, want: "unsupported assistant output part type"}, + {name: "deprecated unsupported part", msg: &schema.Message{Role: schema.User, MultiContent: []schema.ChatMessagePart{{Type: schema.ChatMessagePartTypeFileURL}}}, want: "unsupported deprecated MultiContent"}, + {name: "empty user content", msg: &schema.Message{Role: schema.User}, want: "message content is empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := toInputContentFromMessage(tt.msg) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } + }) +} + +func TestToInputContentPartFromInputPart(t *testing.T) { + imgURL := "https://example.com/img.png" + fileData := "YmFzZTY0" + + if _, err := toInputContentPartFromInputPart(schema.MessageInputPart{Type: schema.ChatMessagePartTypeText, Text: "ok"}); err != nil { + t.Fatalf("text part should succeed: %v", err) + } + if _, err := toInputContentPartFromInputPart(schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &imgURL}, Detail: schema.ImageURLDetailLow}}); err != nil { + t.Fatalf("image part should succeed: %v", err) + } + if _, err := toInputContentPartFromInputPart(schema.MessageInputPart{Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{Base64Data: &fileData, MIMEType: "text/plain"}, Name: "x.txt"}}); err != nil { + t.Fatalf("file part should succeed: %v", err) + } + + tests := []struct { + name string + part schema.MessageInputPart + want string + }{ + {name: "nil image", part: schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL}, want: "image field must not be nil"}, + {name: "nil file", part: schema.MessageInputPart{Type: schema.ChatMessagePartTypeFileURL}, want: "file field must not be nil"}, + {name: "unsupported type", part: schema.MessageInputPart{Type: schema.ChatMessagePartTypeAudioURL}, want: "unsupported content type"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := toInputContentPartFromInputPart(tt.part) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } +} + +func TestToFunctionCallOutputFromToolMessageDeprecated(t *testing.T) { + if _, err := toFunctionCallOutputFromToolMessage(&schema.Message{}); err == nil || !strings.Contains(err.Error(), "deprecated") { + t.Fatalf("expected deprecated error, got %v", err) + } +} + +func TestPopulateToolChoice(t *testing.T) { + tools := []responses.ToolUnionParam{ + responses.ToolParamOfFunction("lookup_weather", map[string]any{"type": "object"}, true), + responses.ToolParamOfFunction("lookup_stock", map[string]any{"type": "object"}, true), + } + + t.Run("nil tool choice", func(t *testing.T) { + params := responses.ResponseNewParams{} + if err := populateToolChoice(¶ms, nil, nil, tools); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !param.IsOmitted(params.ToolChoice.OfToolChoiceMode) || params.ToolChoice.OfFunctionTool != nil { + t.Fatalf("expected omitted tool choice, got %#v", params.ToolChoice) + } + }) + + t.Run("forbidden", func(t *testing.T) { + params := responses.ResponseNewParams{} + choice := schema.ToolChoiceForbidden + if err := populateToolChoice(¶ms, &choice, nil, tools); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if mode := params.ToolChoice.OfToolChoiceMode.Value; mode != responses.ToolChoiceOptionsNone { + t.Fatalf("expected none mode, got %q", mode) + } + }) + + t.Run("allowed", func(t *testing.T) { + params := responses.ResponseNewParams{} + choice := schema.ToolChoiceAllowed + if err := populateToolChoice(¶ms, &choice, nil, tools); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if mode := params.ToolChoice.OfToolChoiceMode.Value; mode != responses.ToolChoiceOptionsAuto { + t.Fatalf("expected auto mode, got %q", mode) + } + }) + + t.Run("forced single allowed tool", func(t *testing.T) { + params := responses.ResponseNewParams{} + choice := schema.ToolChoiceForced + if err := populateToolChoice(¶ms, &choice, []string{"lookup_stock"}, tools); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if params.ToolChoice.OfFunctionTool == nil || params.ToolChoice.OfFunctionTool.Name != "lookup_stock" { + t.Fatalf("expected specific function tool, got %#v", params.ToolChoice) + } + }) + + t.Run("forced single tool fallback", func(t *testing.T) { + params := responses.ResponseNewParams{} + choice := schema.ToolChoiceForced + if err := populateToolChoice(¶ms, &choice, nil, tools[:1]); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if params.ToolChoice.OfFunctionTool == nil || params.ToolChoice.OfFunctionTool.Name != "lookup_weather" { + t.Fatalf("expected only tool to be forced, got %#v", params.ToolChoice) + } + }) + + t.Run("forced required mode", func(t *testing.T) { + params := responses.ResponseNewParams{} + choice := schema.ToolChoiceForced + if err := populateToolChoice(¶ms, &choice, nil, tools); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if mode := params.ToolChoice.OfToolChoiceMode.Value; mode != responses.ToolChoiceOptionsRequired { + t.Fatalf("expected required mode, got %q", mode) + } + }) + + t.Run("errors", func(t *testing.T) { + tests := []struct { + name string + choice schema.ToolChoice + allowed []string + tools []responses.ToolUnionParam + want string + }{ + {name: "forced no tools", choice: schema.ToolChoiceForced, want: "no tools"}, + {name: "multiple allowed names", choice: schema.ToolChoiceForced, tools: tools, allowed: []string{"a", "b"}, want: "only one allowed tool name"}, + {name: "allowed tool missing", choice: schema.ToolChoiceForced, tools: tools, allowed: []string{"missing"}, want: "not found"}, + {name: "unknown choice", choice: schema.ToolChoice("mystery"), tools: tools, want: "unknown tool choice"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := responses.ResponseNewParams{} + err := populateToolChoice(¶ms, &tt.choice, tt.allowed, tt.tools) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } + }) +} + +func TestToOpenAITools(t *testing.T) { + weatherTool := makeWeatherTool() + tools, rawTools, err := toOpenAITools([]*schema.ToolInfo{weatherTool}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tools) != 1 || len(rawTools) != 1 { + t.Fatalf("expected one tool, got tools=%d rawTools=%d", len(tools), len(rawTools)) + } + if rawTools[0] != weatherTool { + t.Fatalf("expected raw tool to be preserved") + } + fn := tools[0].OfFunction + if fn == nil { + t.Fatalf("expected function tool") + } + if fn.Name != weatherTool.Name || !fn.Strict.Value { + t.Fatalf("unexpected function tool: %#v", fn) + } + props, ok := fn.Parameters["properties"].(map[string]any) + if !ok { + t.Fatalf("expected properties map in parameters: %#v", fn.Parameters) + } + if _, ok := props["city"]; !ok { + t.Fatalf("expected city property, got %#v", props) + } + if fn.Parameters["type"] != "object" || fn.Parameters["additionalProperties"] != false { + t.Fatalf("expected strict object schema, got %#v", fn.Parameters) + } + if required, ok := fn.Parameters["required"].([]any); !ok || len(required) == 0 { + t.Fatalf("expected required fields, got %#v", fn.Parameters["required"]) + } + + if _, _, err := toOpenAITools([]*schema.ToolInfo{nil}); err == nil || !strings.Contains(err.Error(), "cannot be nil") { + t.Fatalf("expected nil tool error, got %v", err) + } +} + +func TestToSDKImageDetail(t *testing.T) { + if got := toSDKImageDetail(schema.ImageURLDetailHigh); got != responses.ResponseInputImageDetailHigh { + t.Fatalf("expected high detail, got %q", got) + } + if got := toSDKImageDetail(schema.ImageURLDetailLow); got != responses.ResponseInputImageDetailLow { + t.Fatalf("expected low detail, got %q", got) + } + if got := toSDKImageDetail(schema.ImageURLDetailAuto); got != responses.ResponseInputImageDetailAuto { + t.Fatalf("expected auto detail, got %q", got) + } + if got := toSDKImageDetail(schema.ImageURLDetail("unknown")); got != responses.ResponseInputImageDetailAuto { + t.Fatalf("expected unknown detail to default to auto, got %q", got) + } +} + +func TestConvertResponseToMessage(t *testing.T) { + resp := &responses.Response{ + Status: responses.ResponseStatusCompleted, + Usage: responses.ResponseUsage{ + InputTokens: 11, + OutputTokens: 7, + TotalTokens: 18, + InputTokensDetails: responses.ResponseUsageInputTokensDetails{CachedTokens: 3}, + OutputTokensDetails: responses.ResponseUsageOutputTokensDetails{ReasoningTokens: 2}, + }, + Output: []responses.ResponseOutputItemUnion{ + mustJSON[responses.ResponseOutputItemUnion](t, map[string]any{ + "type": "message", + "id": "msg_1", + "role": "assistant", + "status": "completed", + "content": []map[string]any{{ + "type": "output_text", + "text": "Hello world", + "annotations": []any{}, + }}, + }), + mustJSON[responses.ResponseOutputItemUnion](t, map[string]any{ + "type": "function_call", + "id": "item_fc_1", + "call_id": "call_1", + "name": "lookup_weather", + "arguments": `{"city":"shanghai"}`, + }), + mustJSON[responses.ResponseOutputItemUnion](t, map[string]any{ + "type": "image_generation_call", + "id": "img_1", + "status": "completed", + "result": "ZmFrZS1pbWFnZQ==", + }), + mustJSON[responses.ResponseOutputItemUnion](t, map[string]any{ + "type": "reasoning", + "id": "reason_1", + "summary": []map[string]any{{ + "type": "summary_text", + "text": "summary text", + }}, + }), + }, + } + + msg, err := (&ChatModel{}).convertResponseToMessage(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if msg.Role != schema.Assistant || msg.Content != "Hello world" { + t.Fatalf("unexpected message: %#v", msg) + } + if msg.ResponseMeta == nil || msg.ResponseMeta.FinishReason != string(responses.ResponseStatusCompleted) { + t.Fatalf("unexpected response meta: %#v", msg.ResponseMeta) + } + if msg.ResponseMeta.Usage == nil || msg.ResponseMeta.Usage.TotalTokens != 18 { + t.Fatalf("unexpected token usage: %#v", msg.ResponseMeta.Usage) + } + if len(msg.ToolCalls) != 1 || msg.ToolCalls[0].ID != "call_1" || msg.ToolCalls[0].Function.Name != "lookup_weather" { + t.Fatalf("unexpected tool calls: %#v", msg.ToolCalls) + } + if msg.ReasoningContent != "summary text" { + t.Fatalf("unexpected reasoning content: %q", msg.ReasoningContent) + } + if len(msg.AssistantGenMultiContent) != 2 { + t.Fatalf("expected text and image output parts, got %#v", msg.AssistantGenMultiContent) + } + if msg.AssistantGenMultiContent[0].Type != schema.ChatMessagePartTypeText || msg.AssistantGenMultiContent[0].Text != "Hello world" { + t.Fatalf("unexpected first output part: %#v", msg.AssistantGenMultiContent[0]) + } + if msg.AssistantGenMultiContent[1].Type != schema.ChatMessagePartTypeImageURL || msg.AssistantGenMultiContent[1].Image == nil || msg.AssistantGenMultiContent[1].Image.Base64Data == nil { + t.Fatalf("unexpected image output part: %#v", msg.AssistantGenMultiContent[1]) + } + if got := *msg.AssistantGenMultiContent[1].Image.Base64Data; got != "ZmFrZS1pbWFnZQ==" { + t.Fatalf("unexpected image result %q", got) + } + + msg, err = (&ChatModel{}).convertResponseToMessage(&responses.Response{Status: responses.ResponseStatusFailed}) + if err != nil { + t.Fatalf("unexpected error for minimal response: %v", err) + } + if msg.ResponseMeta == nil || msg.ResponseMeta.FinishReason != string(responses.ResponseStatusFailed) { + t.Fatalf("unexpected finish reason: %#v", msg.ResponseMeta) + } + + if _, err := (&ChatModel{}).convertResponseToMessage(nil); err == nil || !strings.Contains(err.Error(), "nil response") { + t.Fatalf("expected nil response error, got %v", err) + } +} + +func TestTokenUsageConversions(t *testing.T) { + if got := toEinoTokenUsage(responses.ResponseUsage{}); got != nil { + t.Fatalf("expected nil usage for zero-value response usage, got %#v", got) + } + + usage := responses.ResponseUsage{ + InputTokens: 4, + OutputTokens: 6, + TotalTokens: 10, + InputTokensDetails: responses.ResponseUsageInputTokensDetails{CachedTokens: 1}, + OutputTokensDetails: responses.ResponseUsageOutputTokensDetails{ReasoningTokens: 2}, + } + converted := toEinoTokenUsage(usage) + if converted == nil || converted.PromptTokens != 4 || converted.CompletionTokens != 6 || converted.TotalTokens != 10 { + t.Fatalf("unexpected converted usage: %#v", converted) + } + modelUsage := toModelTokenUsage(&schema.ResponseMeta{Usage: converted}) + if modelUsage == nil || modelUsage.PromptTokens != 4 || modelUsage.CompletionTokensDetails.ReasoningTokens != 2 { + t.Fatalf("unexpected model token usage: %#v", modelUsage) + } + if got := toModelTokenUsage(nil); got != nil { + t.Fatalf("expected nil model token usage when meta is nil, got %#v", got) + } + if got := toModelTokenUsage(&schema.ResponseMeta{}); got != nil { + t.Fatalf("expected nil model token usage when usage is nil, got %#v", got) + } +} + +func TestOptionHelpers(t *testing.T) { + got := model.GetImplSpecificOptions(&options{}, + WithMaxOutputTokens(123), + WithReasoning(&Reasoning{Effort: ReasoningEffortMinimal, Summary: ReasoningSummaryAuto}), + WithStore(true), + WithMetadata(map[string]string{"k": "v"}), + WithExtraFields(map[string]any{"x": 1}), + ) + if got.MaxOutputTokens == nil || *got.MaxOutputTokens != 123 { + t.Fatalf("unexpected max output tokens: %#v", got.MaxOutputTokens) + } + if got.Reasoning == nil || got.Reasoning.Effort != ReasoningEffortMinimal { + t.Fatalf("unexpected reasoning option: %#v", got.Reasoning) + } + if got.Store == nil || *got.Store != true { + t.Fatalf("unexpected store option: %#v", got.Store) + } + if got.Metadata["k"] != "v" || got.ExtraFields["x"] != 1 { + t.Fatalf("unexpected option maps: %#v %#v", got.Metadata, got.ExtraFields) + } + + meta := map[string]string{"k": "v"} + extra := map[string]any{"x": 1} + got = model.GetImplSpecificOptions(&options{}, WithMetadata(meta), WithExtraFields(extra)) + meta["k"] = "changed" + extra["x"] = 2 + if got.Metadata["k"] != "v" || got.ExtraFields["x"] != 1 { + t.Fatalf("expected cloned maps, got %#v %#v", got.Metadata, got.ExtraFields) + } +} + +func TestReasoningToSDK(t *testing.T) { + if got := (*Reasoning)(nil).toSDK(); got.Effort != "" || got.Summary != "" { + t.Fatalf("expected zero value reasoning param, got %#v", got) + } + got := (&Reasoning{Effort: ReasoningEffortXHigh, Summary: ReasoningSummaryDetailed}).toSDK() + if string(got.Effort) != "xhigh" || string(got.Summary) != "detailed" { + t.Fatalf("unexpected sdk reasoning: %#v", got) + } +} + +func makeWeatherTool() *schema.ToolInfo { + return &schema.ToolInfo{ + Name: "lookup_weather", + Desc: "Look up weather by city", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "city": {Type: schema.String, Desc: "City name", Required: true}, + "unit": {Type: schema.String, Desc: "Temperature unit", Enum: []string{"celsius", "fahrenheit"}}, + }), + } +} + +func makeLookupTool() *schema.ToolInfo { + return &schema.ToolInfo{ + Name: "lookup_stock", + Desc: "Look up stock prices", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "ticker": {Type: schema.String, Required: true}, + }), + } +} + +func mustJSON[T any](t *testing.T, v any) T { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal %T: %v", v, err) + } + var out T + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("unmarshal into target: %v", err) + } + return out +} diff --git a/components/model/openai-go/types.go b/components/model/openai-go/types.go new file mode 100644 index 000000000..36ea391dc --- /dev/null +++ b/components/model/openai-go/types.go @@ -0,0 +1,55 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import "github.com/openai/openai-go/v3/shared" + +type ReasoningEffort string + +const ( + ReasoningEffortNone ReasoningEffort = "none" + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" + ReasoningEffortXHigh ReasoningEffort = "xhigh" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +// Reasoning config for Responses API reasoning models. +// Maps to openai-go's shared.ReasoningParam. +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +func (r *Reasoning) toSDK() shared.ReasoningParam { + if r == nil { + return shared.ReasoningParam{} + } + return shared.ReasoningParam{ + Effort: shared.ReasoningEffort(r.Effort), + Summary: shared.ReasoningSummary(r.Summary), + } +} diff --git a/components/model/openai-go/util.go b/components/model/openai-go/util.go new file mode 100644 index 000000000..90724a5da --- /dev/null +++ b/components/model/openai-go/util.go @@ -0,0 +1,340 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + "github.com/openai/openai-go/v3/shared" +) + +// Assistant messages are previous model outputs. The Responses API is strict: +// when role=assistant, content parts must be of type "output_text"/"refusal", +// not "input_text". +// +// The openai-go SDK's easiest compatible representation is to send assistant +// content as a plain string (not a typed content-part list). +// We therefore: +// - allow text-only assistant history (as string) +// - reject non-text assistant multimodal content when re-sending history +func extractAssistantTextForHistory(msg *schema.Message) (text string, ok bool, err error) { + if msg == nil { + return "", false, nil + } + + // Prefer the canonical Content field. + if msg.Content != "" { + return msg.Content, true, nil + } + + // If Content is empty, attempt to derive text from multi-content. + // If any non-text part exists, we fail fast to avoid producing invalid request bodies. + if len(msg.AssistantGenMultiContent) > 0 { + var b strings.Builder + for _, part := range msg.AssistantGenMultiContent { + if part.Type != schema.ChatMessagePartTypeText { + return "", false, fmt.Errorf("assistant history contains non-text part (%s); cannot re-send as Responses API input", part.Type) + } + if part.Text == "" { + continue + } + if b.Len() > 0 { + b.WriteString("\n") + } + b.WriteString(part.Text) + } + if b.Len() > 0 { + return b.String(), true, nil + } + } + + // Deprecated MultiContent. + if len(msg.MultiContent) > 0 { + var b strings.Builder + for _, c := range msg.MultiContent { + if c.Type != schema.ChatMessagePartTypeText { + return "", false, fmt.Errorf("assistant history contains deprecated MultiContent non-text part (%s); cannot re-send", c.Type) + } + if c.Text == "" { + continue + } + if b.Len() > 0 { + b.WriteString("\n") + } + b.WriteString(c.Text) + } + if b.Len() > 0 { + return b.String(), true, nil + } + } + + // Do not attempt to re-send UserInputMultiContent on assistant messages. + if len(msg.UserInputMultiContent) > 0 { + return "", false, fmt.Errorf("assistant history contains UserInputMultiContent; cannot re-send as Responses API input") + } + + return "", false, nil +} + +func toolNameExists(tools []responses.ToolUnionParam, name string) bool { + for _, t := range tools { + if t.OfFunction != nil && t.OfFunction.Name == name { + return true + } + } + return false +} + +// The OpenAI Responses API requires strict tool schemas to include: +// - type: "object" +// - properties: {...} +// - additionalProperties: false +// - required: [all keys in properties] +// +// Many JSON Schema generators omit "required" for fields tagged with `omitempty`. +func enforceOpenAIStrictJSONSchema(schema map[string]any) { + if schema == nil { + return + } + + // Recurse into nested schemas first. + if items, ok := schema["items"]; ok { + switch v := items.(type) { + case map[string]any: + enforceOpenAIStrictJSONSchema(v) + case []any: + for _, it := range v { + if m, ok := it.(map[string]any); ok { + enforceOpenAIStrictJSONSchema(m) + } + } + } + } + if props, ok := schema["properties"].(map[string]any); ok { + for _, pv := range props { + if pm, ok := pv.(map[string]any); ok { + enforceOpenAIStrictJSONSchema(pm) + } + } + } + if oneOf, ok := schema["oneOf"].([]any); ok { + for _, ov := range oneOf { + if om, ok := ov.(map[string]any); ok { + enforceOpenAIStrictJSONSchema(om) + } + } + } + if anyOf, ok := schema["anyOf"].([]any); ok { + for _, av := range anyOf { + if am, ok := av.(map[string]any); ok { + enforceOpenAIStrictJSONSchema(am) + } + } + } + if allOf, ok := schema["allOf"].([]any); ok { + for _, av := range allOf { + if am, ok := av.(map[string]any); ok { + enforceOpenAIStrictJSONSchema(am) + } + } + } + + // Now enforce strictness for object schemas. + props, ok := schema["properties"].(map[string]any) + if !ok || len(props) == 0 { + return + } + + // Ensure type is object (some generators omit it at the top level). + if _, ok := schema["type"]; !ok { + schema["type"] = "object" + } + + // OpenAI strict schema expects additionalProperties=false. + if _, ok := schema["additionalProperties"]; !ok { + schema["additionalProperties"] = false + } + + // Ensure required includes *all* keys in properties. + existing := map[string]struct{}{} + if req, ok := schema["required"]; ok { + switch v := req.(type) { + case []any: + for _, it := range v { + if s, ok := it.(string); ok { + existing[s] = struct{}{} + } + } + case []string: + for _, s := range v { + existing[s] = struct{}{} + } + } + } + + required := make([]any, 0, len(props)) + for k := range props { + if _, ok := existing[k]; !ok { + existing[k] = struct{}{} + } + } + for k := range existing { + required = append(required, k) + } + + // If there were no required keys produced for some reason, at least include all properties. + if len(required) == 0 { + for k := range props { + required = append(required, k) + } + } + + schema["required"] = required +} + +func jsonSchemaToMap(s any) (map[string]any, error) { + if s == nil { + return map[string]any{}, nil + } + // jsonschema.Schema has json tags; encode/decode to map[string]any. + b, err := json.Marshal(s) + if err != nil { + return nil, err + } + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + return nil, err + } + return m, nil +} + +func commonToDataOrURL(common schema.MessagePartCommon) (string, error) { + if common.URL == nil && common.Base64Data == nil { + return "", fmt.Errorf("message part must have URL or Base64Data") + } + if common.URL != nil { + return *common.URL, nil + } + if common.MIMEType == "" { + return "", fmt.Errorf("message part must have MIMEType when using Base64Data") + } + if strings.HasPrefix(*common.Base64Data, "data:") { + return "", fmt.Errorf("base64Data must be raw base64 without 'data:' prefix") + } + return fmt.Sprintf("data:%s;base64,%s", common.MIMEType, *common.Base64Data), nil +} + +func joinReasoningText(item responses.ResponseReasoningItem) string { + // Summary is often what people want. + if len(item.Summary) > 0 { + var b strings.Builder + for i, s := range item.Summary { + if s.Text == "" { + continue + } + if i > 0 { + b.WriteString("\n\n") + } + b.WriteString(s.Text) + } + out := b.String() + if out != "" { + return out + } + } + + if len(item.Content) > 0 { + var b strings.Builder + for i, c := range item.Content { + if c.Text == "" { + continue + } + if i > 0 { + b.WriteString("\n\n") + } + b.WriteString(c.Text) + } + return b.String() + } + + return "" +} + +func ensureResponseMeta(meta *schema.ResponseMeta) *schema.ResponseMeta { + if meta == nil { + return &schema.ResponseMeta{} + } + return meta +} + +func responsesModelFromString(s string) responses.ResponsesModel { return shared.ResponsesModel(s) } + +func optInt64(v param.Opt[int64]) int64 { + if v.Valid() { + return v.Value + } + return 0 +} + +func optFloat64(v param.Opt[float64]) float64 { + if v.Valid() { + return v.Value + } + return 0 +} + +const callbackExtraModelName = "model_name" + +type panicErr struct { + info any + stack []byte +} + +func (p *panicErr) Error() string { + return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack)) +} + +func newPanicErr(info any, stack []byte) error { + return &panicErr{info: info, stack: stack} +} + +func cloneStringMap(in map[string]string) map[string]string { + if in == nil { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneAnyMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/components/model/openai-go/util_test.go b/components/model/openai-go/util_test.go new file mode 100644 index 000000000..b4932ea12 --- /dev/null +++ b/components/model/openai-go/util_test.go @@ -0,0 +1,217 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openaigo + +import ( + "strings" + "testing" + + "github.com/cloudwego/eino/schema" + jschema "github.com/eino-contrib/jsonschema" + openai "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" +) + +func TestExtractAssistantTextForHistory(t *testing.T) { + text, ok, err := extractAssistantTextForHistory(&schema.Message{Content: "hello"}) + if err != nil || !ok || text != "hello" { + t.Fatalf("unexpected content extraction result text=%q ok=%v err=%v", text, ok, err) + } + + text, ok, err = extractAssistantTextForHistory(&schema.Message{AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "one"}, {Type: schema.ChatMessagePartTypeText, Text: "two"}}}) + if err != nil || !ok || text != "one\ntwo" { + t.Fatalf("unexpected assistant multi-content extraction: text=%q ok=%v err=%v", text, ok, err) + } + + text, ok, err = extractAssistantTextForHistory(&schema.Message{MultiContent: []schema.ChatMessagePart{{Type: schema.ChatMessagePartTypeText, Text: "legacy1"}, {Type: schema.ChatMessagePartTypeText, Text: "legacy2"}}}) + if err != nil || !ok || text != "legacy1\nlegacy2" { + t.Fatalf("unexpected deprecated multi-content extraction: text=%q ok=%v err=%v", text, ok, err) + } + + text, ok, err = extractAssistantTextForHistory(nil) + if err != nil || ok || text != "" { + t.Fatalf("expected nil message to produce no text, got text=%q ok=%v err=%v", text, ok, err) + } + + tests := []struct { + name string + msg *schema.Message + want string + }{ + {name: "assistant non-text output", msg: &schema.Message{AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeImageURL}}}, want: "non-text part"}, + {name: "deprecated non-text output", msg: &schema.Message{MultiContent: []schema.ChatMessagePart{{Type: schema.ChatMessagePartTypeImageURL}}}, want: "deprecated MultiContent non-text part"}, + {name: "assistant user input content", msg: &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "bad"}}}, want: "UserInputMultiContent"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := extractAssistantTextForHistory(tt.msg) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } +} + +func TestToolNameExists(t *testing.T) { + tools := []responses.ToolUnionParam{responses.ToolParamOfFunction("lookup_weather", map[string]any{"type": "object"}, true)} + if !toolNameExists(tools, "lookup_weather") { + t.Fatalf("expected tool to exist") + } + if toolNameExists(tools, "missing") { + t.Fatalf("did not expect missing tool to exist") + } +} + +func TestEnforceOpenAIStrictJSONSchema(t *testing.T) { + schemaMap := map[string]any{ + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "profile": map[string]any{ + "properties": map[string]any{ + "age": map[string]any{"type": "integer"}, + }, + }, + }, + "items": map[string]any{ + "properties": map[string]any{ + "nested": map[string]any{"type": "string"}, + }, + }, + } + enforceOpenAIStrictJSONSchema(schemaMap) + if schemaMap["type"] != "object" || schemaMap["additionalProperties"] != false { + t.Fatalf("expected strict object schema, got %#v", schemaMap) + } + if required, ok := schemaMap["required"].([]any); !ok || len(required) != 2 { + t.Fatalf("expected required keys for all properties, got %#v", schemaMap["required"]) + } + profile := schemaMap["properties"].(map[string]any)["profile"].(map[string]any) + if profile["type"] != "object" || profile["additionalProperties"] != false { + t.Fatalf("expected nested strict schema, got %#v", profile) + } +} + +func TestJSONSchemaToMap(t *testing.T) { + m, err := jsonSchemaToMap(&jschema.Schema{Type: "object", Description: "demo"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m["type"] != "object" || m["description"] != "demo" { + t.Fatalf("unexpected schema map: %#v", m) + } + m, err = jsonSchemaToMap(nil) + if err != nil { + t.Fatalf("unexpected error for nil schema: %v", err) + } + if len(m) != 0 { + t.Fatalf("expected empty map for nil schema, got %#v", m) + } +} + +func TestCommonToDataOrURL(t *testing.T) { + url := "https://example.com/image.png" + got, err := commonToDataOrURL(schema.MessagePartCommon{URL: &url}) + if err != nil || got != url { + t.Fatalf("expected url passthrough, got %q err=%v", got, err) + } + b64 := "SGVsbG8=" + got, err = commonToDataOrURL(schema.MessagePartCommon{Base64Data: &b64, MIMEType: "text/plain"}) + if err != nil || got != "data:text/plain;base64,SGVsbG8=" { + t.Fatalf("unexpected data url: %q err=%v", got, err) + } + badPrefixed := "data:text/plain;base64,SGVsbG8=" + tests := []struct { + name string + common schema.MessagePartCommon + want string + }{ + {name: "missing source", common: schema.MessagePartCommon{}, want: "URL or Base64Data"}, + {name: "missing mime", common: schema.MessagePartCommon{Base64Data: &b64}, want: "MIMEType"}, + {name: "prefixed data", common: schema.MessagePartCommon{Base64Data: &badPrefixed, MIMEType: "text/plain"}, want: "raw base64"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := commonToDataOrURL(tt.common) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } +} + +func TestJoinReasoningText(t *testing.T) { + item := responses.ResponseReasoningItem{Summary: []responses.ResponseReasoningItemSummary{{Text: "s1"}, {Text: "s2"}}} + if got := joinReasoningText(item); got != "s1\n\ns2" { + t.Fatalf("unexpected summary text %q", got) + } + item = responses.ResponseReasoningItem{Content: []responses.ResponseReasoningItemContent{{Text: "c1"}, {Text: "c2"}}} + if got := joinReasoningText(item); got != "c1\n\nc2" { + t.Fatalf("unexpected content text %q", got) + } + if got := joinReasoningText(responses.ResponseReasoningItem{}); got != "" { + t.Fatalf("expected empty reasoning text, got %q", got) + } +} + +func TestEnsureResponseMetaAndOptHelpers(t *testing.T) { + meta := ensureResponseMeta(nil) + if meta == nil { + t.Fatalf("expected response meta to be initialized") + } + meta2 := &schema.ResponseMeta{} + if ensureResponseMeta(meta2) != meta2 { + t.Fatalf("expected existing meta to be returned as-is") + } + if got := responsesModelFromString("gpt-4o-mini"); got != "gpt-4o-mini" { + t.Fatalf("unexpected model conversion %q", got) + } + if got := optInt64(param.Opt[int64]{}); got != 0 { + t.Fatalf("expected zero int64 opt, got %d", got) + } + if got := optInt64(openai.Int(7)); got != 7 { + t.Fatalf("expected int64 opt value 7, got %d", got) + } + if got := optFloat64(param.Opt[float64]{}); got != 0 { + t.Fatalf("expected zero float64 opt, got %v", got) + } + if got := optFloat64(openai.Float(1.5)); got != 1.5 { + t.Fatalf("expected float64 opt value 1.5, got %v", got) + } +} + +func TestPanicAndCloneHelpers(t *testing.T) { + err := newPanicErr("boom", []byte("stacktrace")) + if err == nil || !strings.Contains(err.Error(), "boom") || !strings.Contains(err.Error(), "stacktrace") { + t.Fatalf("unexpected panic error %v", err) + } + if got := cloneStringMap(nil); got != nil { + t.Fatalf("expected nil clone for nil string map, got %#v", got) + } + if got := cloneAnyMap(nil); got != nil { + t.Fatalf("expected nil clone for nil any map, got %#v", got) + } + stringMap := map[string]string{"a": "b"} + anyMap := map[string]any{"x": 1} + cloneS := cloneStringMap(stringMap) + cloneA := cloneAnyMap(anyMap) + stringMap["a"] = "changed" + anyMap["x"] = 2 + if cloneS["a"] != "b" || cloneA["x"] != 1 { + t.Fatalf("expected clones to be independent, got %#v %#v", cloneS, cloneA) + } +}