Skip to content

Commit 560049b

Browse files
committed
feat: implement extra body support in chat requests and enhance testing
- Added support for an ExtraBody field in ChatRequest to allow additional parameters in API requests, enabling provider-specific configurations. - Updated createChatRequest to merge ExtraBody with standard fields, ensuring that extra parameters can override existing ones. - Introduced new tests for WithExtraBody and getExtraBody functions to validate the handling of additional fields in CallOptions. - Enhanced existing tests to cover scenarios involving extra body fields, ensuring proper integration and functionality.
1 parent f2c57c4 commit 560049b

6 files changed

Lines changed: 350 additions & 12 deletions

File tree

llms/openai/internal/openaiclient/chat.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"maps"
1112
"net/http"
1213
"strings"
1314

@@ -101,6 +102,10 @@ type ChatRequest struct {
101102
// WebSearchOptions configures web search behavior for search-enabled models
102103
// like gpt-4o-search-preview and gpt-4o-mini-search-preview.
103104
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
105+
106+
// ExtraBody allows passing additional fields that will be merged into the request body.
107+
// These fields take precedence over the standard fields.
108+
ExtraBody map[string]any `json:"-"`
104109
}
105110

106111
// ToolType is the type of a tool.
@@ -519,10 +524,26 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCom
519524

520525
// Restore original metadata
521526
payload.Metadata = originalMetadata
527+
522528
if err != nil {
523529
return nil, err
524530
}
525531

532+
// If ExtraBody is provided, merge it with the standard payload
533+
if len(payload.ExtraBody) > 0 {
534+
var baseMap map[string]any
535+
if err := json.Unmarshal(payloadBytes, &baseMap); err != nil {
536+
return nil, err
537+
}
538+
539+
// Merge ExtraBody with priority (ExtraBody overwrites existing fields)
540+
maps.Copy(baseMap, payload.ExtraBody)
541+
542+
if payloadBytes, err = json.Marshal(baseMap); err != nil {
543+
return nil, err
544+
}
545+
}
546+
526547
// Build request
527548
body := bytes.NewReader(payloadBytes)
528549
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions", payload.Model), body)

llms/openai/internal/openaiclient/marshal_test.go

Lines changed: 118 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ package openaiclient
22

33
import (
44
"encoding/json"
5+
"maps"
56
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
610
)
711

812
func TestChatRequest_TemperatureMarshalJSON(t *testing.T) {
@@ -43,7 +47,7 @@ func TestChatRequest_TemperatureMarshalJSON(t *testing.T) {
4347
t.Fatalf("failed to marshal: %v", err)
4448
}
4549

46-
var result map[string]interface{}
50+
var result map[string]any
4751
if err := json.Unmarshal(data, &result); err != nil {
4852
t.Fatalf("failed to unmarshal: %v", err)
4953
}
@@ -76,7 +80,7 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
7680
tests := []struct {
7781
name string
7882
request ChatRequest
79-
want map[string]interface{}
83+
want map[string]any
8084
}{
8185
{
8286
name: "no web search options",
@@ -91,7 +95,7 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
9195
Model: "gpt-4o-search-preview",
9296
WebSearchOptions: &WebSearchOptions{},
9397
},
94-
want: map[string]interface{}{},
98+
want: map[string]any{},
9599
},
96100
{
97101
name: "web search with search context size",
@@ -101,7 +105,7 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
101105
SearchContextSize: "high",
102106
},
103107
},
104-
want: map[string]interface{}{
108+
want: map[string]any{
105109
"search_context_size": "high",
106110
},
107111
},
@@ -121,11 +125,11 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
121125
},
122126
},
123127
},
124-
want: map[string]interface{}{
128+
want: map[string]any{
125129
"search_context_size": "medium",
126-
"user_location": map[string]interface{}{
130+
"user_location": map[string]any{
127131
"type": "approximate",
128-
"approximate": map[string]interface{}{
132+
"approximate": map[string]any{
129133
"country": "US",
130134
"city": "San Francisco",
131135
"region": "California",
@@ -142,7 +146,7 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
142146
t.Fatalf("failed to marshal: %v", err)
143147
}
144148

145-
var result map[string]interface{}
149+
var result map[string]any
146150
if err := json.Unmarshal(data, &result); err != nil {
147151
t.Fatalf("failed to unmarshal: %v", err)
148152
}
@@ -157,7 +161,7 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
157161
t.Fatal("expected web_search_options to be present")
158162
}
159163
// Check that it's properly serialized
160-
webSearchMap, ok := webSearchOpts.(map[string]interface{})
164+
webSearchMap, ok := webSearchOpts.(map[string]any)
161165
if !ok {
162166
t.Fatalf("web_search_options is not a map: %T", webSearchOpts)
163167
}
@@ -168,11 +172,11 @@ func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
168172
}
169173
}
170174
if tt.want["user_location"] != nil {
171-
userLoc, ok := webSearchMap["user_location"].(map[string]interface{})
175+
userLoc, ok := webSearchMap["user_location"].(map[string]any)
172176
if !ok {
173177
t.Fatalf("user_location is not a map: %T", webSearchMap["user_location"])
174178
}
175-
wantUserLoc := tt.want["user_location"].(map[string]interface{})
179+
wantUserLoc := tt.want["user_location"].(map[string]any)
176180
if userLoc["type"] != wantUserLoc["type"] {
177181
t.Errorf("user_location.type: got %v, want %v", userLoc["type"], wantUserLoc["type"])
178182
}
@@ -189,3 +193,106 @@ func getFloatPointer(f float64) *float64 {
189193
func getIntPointer(i int) *int {
190194
return &i
191195
}
196+
197+
func TestChatRequest_ExtraBodyMarshal(t *testing.T) {
198+
tests := []struct {
199+
name string
200+
request ChatRequest
201+
extraBody map[string]any
202+
checkFunc func(t *testing.T, result map[string]any)
203+
}{
204+
{
205+
name: "extra body fields are added",
206+
request: ChatRequest{
207+
Model: "gpt-4",
208+
},
209+
extraBody: map[string]any{
210+
"enable_thinking": false,
211+
"top_k": 20,
212+
},
213+
checkFunc: func(t *testing.T, result map[string]any) {
214+
assert.Equal(t, "gpt-4", result["model"])
215+
assert.Equal(t, false, result["enable_thinking"])
216+
assert.Equal(t, float64(20), result["top_k"])
217+
},
218+
},
219+
{
220+
name: "extra body overrides existing fields",
221+
request: ChatRequest{
222+
Model: "gpt-4",
223+
TopK: getIntPointer(10),
224+
},
225+
extraBody: map[string]any{
226+
"top_k": 20,
227+
},
228+
checkFunc: func(t *testing.T, result map[string]any) {
229+
assert.Equal(t, float64(20), result["top_k"], "ExtraBody should override existing top_k")
230+
},
231+
},
232+
{
233+
name: "nested objects in extra body",
234+
request: ChatRequest{
235+
Model: "gpt-4",
236+
},
237+
extraBody: map[string]any{
238+
"chat_template_kwargs": map[string]any{
239+
"enable_thinking": false,
240+
"custom_setting": "value",
241+
},
242+
},
243+
checkFunc: func(t *testing.T, result map[string]any) {
244+
kwargs, ok := result["chat_template_kwargs"].(map[string]any)
245+
require.True(t, ok, "chat_template_kwargs should be a map")
246+
assert.Equal(t, false, kwargs["enable_thinking"])
247+
assert.Equal(t, "value", kwargs["custom_setting"])
248+
},
249+
},
250+
{
251+
name: "no extra body",
252+
request: ChatRequest{
253+
Model: "gpt-4",
254+
},
255+
extraBody: nil,
256+
checkFunc: func(t *testing.T, result map[string]any) {
257+
assert.Equal(t, "gpt-4", result["model"])
258+
_, hasExtraField := result["enable_thinking"]
259+
assert.False(t, hasExtraField)
260+
},
261+
},
262+
}
263+
264+
for _, tt := range tests {
265+
t.Run(tt.name, func(t *testing.T) {
266+
// This test simulates the merging logic that will be in createChat
267+
tt.request.ExtraBody = tt.extraBody
268+
269+
// Step 1: Marshal without ExtraBody (standard fields)
270+
tempExtraBody := tt.request.ExtraBody
271+
tt.request.ExtraBody = nil
272+
273+
data, err := json.Marshal(tt.request)
274+
require.NoError(t, err)
275+
276+
var result map[string]any
277+
err = json.Unmarshal(data, &result)
278+
require.NoError(t, err)
279+
280+
// Step 2: Merge ExtraBody if present
281+
if len(tempExtraBody) > 0 {
282+
maps.Copy(result, tempExtraBody)
283+
284+
// Re-marshal and unmarshal to ensure proper type conversion
285+
// (This simulates what actually happens in the real code)
286+
data, err = json.Marshal(result)
287+
require.NoError(t, err)
288+
err = json.Unmarshal(data, &result)
289+
require.NoError(t, err)
290+
}
291+
292+
// Verify the result
293+
if tt.checkFunc != nil {
294+
tt.checkFunc(t, result)
295+
}
296+
})
297+
}
298+
}

llms/openai/openaillm.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ func (o *LLM) createChatRequest(chatMsgs []*ChatMessage, opts llms.CallOptions)
231231
Seed: opts.Seed,
232232
Metadata: opts.Metadata,
233233
WebSearchOptions: webSearchOptionsFromCallOptions(opts.WebSearchOptions),
234+
ExtraBody: getExtraBody(&opts),
234235
}
235236

236237
if isLegacyMaxTokensField(&opts) {

llms/openai/openaillm_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package openai
22

33
import (
4+
"encoding/json"
45
"fmt"
6+
"io"
7+
"net/http"
8+
"net/http/httptest"
59
"testing"
610

711
"github.com/stretchr/testify/assert"
@@ -956,3 +960,102 @@ func TestCreateChatRequest_ReasoningModelTemperature(t *testing.T) {
956960
})
957961
}
958962
}
963+
964+
func TestExtraBody_Integration(t *testing.T) { //nolint:funlen
965+
t.Parallel()
966+
967+
var receivedRequest map[string]any
968+
969+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
970+
body, _ := io.ReadAll(r.Body)
971+
_ = json.Unmarshal(body, &receivedRequest)
972+
973+
response := map[string]any{
974+
"id": "test-id",
975+
"choices": []map[string]any{
976+
{
977+
"index": 0,
978+
"message": map[string]any{
979+
"role": "assistant",
980+
"content": "test response",
981+
},
982+
"finish_reason": "stop",
983+
},
984+
},
985+
}
986+
w.Header().Set("Content-Type", "application/json")
987+
_ = json.NewEncoder(w).Encode(response)
988+
}))
989+
defer server.Close()
990+
991+
llm, err := New(
992+
WithToken("test-token"),
993+
WithBaseURL(server.URL),
994+
WithModel("test-model"),
995+
)
996+
require.NoError(t, err)
997+
998+
tests := []struct {
999+
name string
1000+
extraBody map[string]any
1001+
checkFunc func(t *testing.T, req map[string]any)
1002+
}{
1003+
{
1004+
name: "simple extra fields",
1005+
extraBody: map[string]any{
1006+
"enable_thinking": false,
1007+
"top_k": 20,
1008+
},
1009+
checkFunc: func(t *testing.T, req map[string]any) {
1010+
assert.Equal(t, false, req["enable_thinking"])
1011+
assert.Equal(t, float64(20), req["top_k"])
1012+
},
1013+
},
1014+
{
1015+
name: "nested extra fields",
1016+
extraBody: map[string]any{
1017+
"chat_template_kwargs": map[string]any{
1018+
"enable_thinking": false,
1019+
},
1020+
},
1021+
checkFunc: func(t *testing.T, req map[string]any) {
1022+
kwargs, ok := req["chat_template_kwargs"].(map[string]any)
1023+
require.True(t, ok)
1024+
assert.Equal(t, false, kwargs["enable_thinking"])
1025+
},
1026+
},
1027+
{
1028+
name: "extra body overrides standard field",
1029+
extraBody: map[string]any{
1030+
"temperature": 0.9,
1031+
},
1032+
checkFunc: func(t *testing.T, req map[string]any) {
1033+
assert.Equal(t, 0.9, req["temperature"])
1034+
},
1035+
},
1036+
}
1037+
1038+
for _, tt := range tests {
1039+
t.Run(tt.name, func(t *testing.T) {
1040+
receivedRequest = nil
1041+
1042+
messages := []llms.MessageContent{
1043+
{
1044+
Role: llms.ChatMessageTypeHuman,
1045+
Parts: []llms.ContentPart{
1046+
llms.TextContent{Text: "test message"},
1047+
},
1048+
},
1049+
}
1050+
1051+
_, err := llm.GenerateContent(t.Context(), messages,
1052+
WithExtraBody(tt.extraBody),
1053+
llms.WithTemperature(0.7),
1054+
)
1055+
require.NoError(t, err)
1056+
require.NotNil(t, receivedRequest)
1057+
1058+
tt.checkFunc(t, receivedRequest)
1059+
})
1060+
}
1061+
}

0 commit comments

Comments
 (0)