Skip to content

Commit 32f3da4

Browse files
committed
feat: add custom tool type support in Google AI tests and enhance parameter conversion
1 parent 0d47609 commit 32f3da4

4 files changed

Lines changed: 195 additions & 9 deletions

File tree

llms/googleai/googleai.go

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -505,30 +505,49 @@ func convertTools(tools []llms.Tool) ([]*genai.Tool, error) {
505505
}}, nil
506506
}
507507

508-
func convertMaps(i any) any {
508+
func convertMaps(i any) (any, error) {
509+
var err error
509510
switch v := i.(type) {
510511
case map[any]any:
511512
m := make(map[string]any)
512513
for key, val := range v {
513514
sKey, ok := key.(string)
514515
if !ok {
515-
return v
516+
return v, nil
517+
}
518+
m[sKey], err = convertMaps(val)
519+
if err != nil {
520+
return nil, err
516521
}
517-
m[sKey] = convertMaps(val)
518522
}
519-
return m
523+
return m, nil
520524
case []any:
521525
s := make([]any, len(v))
522526
for idx, val := range v {
523-
s[idx] = convertMaps(val)
527+
s[idx], err = convertMaps(val)
528+
if err != nil {
529+
s[idx] = val
530+
}
524531
}
525-
return s
532+
return s, nil
533+
default:
534+
d, err := json.Marshal(i)
535+
if err != nil {
536+
return i, err
537+
}
538+
var m any
539+
if err := json.Unmarshal(d, &m); err != nil {
540+
return i, err
541+
}
542+
return m, nil
526543
}
527-
return i
528544
}
529545

530546
func convertToSchema(e any, topLevel bool) (*genai.Schema, error) {
531-
e = convertMaps(e)
547+
e, err := convertMaps(e)
548+
if err != nil {
549+
return nil, err
550+
}
532551
schema := &genai.Schema{}
533552

534553
eMap, ok := e.(map[string]any)

llms/googleai/googleai_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package googleai
22

33
import (
44
"context"
5+
"encoding/json"
56
"net/http"
67
"os"
78
"strings"
@@ -296,6 +297,84 @@ func TestGoogleAIWithTools(t *testing.T) {
296297
assert.NotEmpty(t, toolCall.ID, "ToolCall ID should not be empty")
297298
assert.Equal(t, "getWeather", toolCall.FunctionCall.Name)
298299
assert.Contains(t, toolCall.FunctionCall.Arguments, "New York")
300+
} else {
301+
t.Fail()
302+
}
303+
}
304+
305+
func TestGoogleAIWithToolsCustomType(t *testing.T) {
306+
llm := newHTTPRRClient(t)
307+
308+
type toolCallArgs struct {
309+
Location string `json:"location"`
310+
}
311+
312+
type property struct {
313+
Type string `json:"type"`
314+
Description string `json:"description"`
315+
}
316+
317+
type properties struct {
318+
Location property `json:"location"`
319+
}
320+
321+
type parameters struct {
322+
Type string `json:"type"`
323+
Properties properties `json:"properties"`
324+
Required []string `json:"required"`
325+
}
326+
327+
tools := []llms.Tool{
328+
{
329+
Type: "function",
330+
Function: &llms.FunctionDefinition{
331+
Name: "getWeather",
332+
Description: "Get the weather for a location",
333+
Parameters: parameters{
334+
Type: "object",
335+
Properties: properties{
336+
Location: property{
337+
Type: "string",
338+
Description: "The location to get weather for",
339+
},
340+
},
341+
Required: []string{"location"},
342+
},
343+
},
344+
},
345+
}
346+
347+
content := []llms.MessageContent{
348+
{
349+
Role: llms.ChatMessageTypeHuman,
350+
Parts: []llms.ContentPart{
351+
llms.TextPart("What's the weather in New York?"),
352+
},
353+
},
354+
}
355+
356+
resp, err := llm.GenerateContent(
357+
t.Context(),
358+
content,
359+
llms.WithTools(tools),
360+
)
361+
362+
require.NoError(t, err)
363+
require.NotNil(t, resp)
364+
assert.NotEmpty(t, resp.Choices)
365+
366+
// Check if tool call was made
367+
if len(resp.Choices[0].ToolCalls) > 0 {
368+
toolCall := resp.Choices[0].ToolCalls[0]
369+
assert.NotEmpty(t, toolCall.ID, "ToolCall ID should not be empty")
370+
assert.Equal(t, "getWeather", toolCall.FunctionCall.Name)
371+
assert.Contains(t, toolCall.FunctionCall.Arguments, "New York")
372+
var args toolCallArgs
373+
err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &args)
374+
require.NoError(t, err)
375+
assert.Equal(t, "New York", args.Location)
376+
} else {
377+
t.Fail()
299378
}
300379
}
301380

llms/googleai/testdata/TestGoogleAIWithToolsCustomType.httprr

Lines changed: 82 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llms/googleai/vertex/vertex.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,13 @@ func convertTools(tools []llms.Tool) ([]*genai.Tool, error) {
397397
// extract properties to populate the schema.
398398
params, ok := tool.Function.Parameters.(map[string]any)
399399
if !ok {
400-
return nil, fmt.Errorf("tool [%d]: unsupported type %T of Parameters", i, tool.Function.Parameters)
400+
paramsData, err := json.Marshal(tool.Function.Parameters)
401+
if err != nil {
402+
return nil, fmt.Errorf("tool [%d]: failed to marshal parameters: %w", i, err)
403+
}
404+
if err := json.Unmarshal(paramsData, &params); err != nil {
405+
return nil, fmt.Errorf("tool [%d]: failed to unmarshal parameters: %w", i, err)
406+
}
401407
}
402408

403409
schema := &genai.Schema{}

0 commit comments

Comments
 (0)