@@ -2,6 +2,7 @@ package googleai
22
33import (
44 "context"
5+ "encoding/json"
56 "net/http"
67 "os"
78 "strings"
@@ -296,6 +297,84 @@ func TestGoogleAIWithTools(t *testing.T) {
296297 assert .NotEmpty (t , toolCall .ID , "ToolCall ID should not be empty" )
297298 assert .Equal (t , "getWeather" , toolCall .FunctionCall .Name )
298299 assert .Contains (t , toolCall .FunctionCall .Arguments , "New York" )
300+ } else {
301+ t .Fail ()
302+ }
303+ }
304+
305+ func TestGoogleAIWithToolsCustomType (t * testing.T ) {
306+ llm := newHTTPRRClient (t )
307+
308+ type toolCallArgs struct {
309+ Location string `json:"location"`
310+ }
311+
312+ type property struct {
313+ Type string `json:"type"`
314+ Description string `json:"description"`
315+ }
316+
317+ type properties struct {
318+ Location property `json:"location"`
319+ }
320+
321+ type parameters struct {
322+ Type string `json:"type"`
323+ Properties properties `json:"properties"`
324+ Required []string `json:"required"`
325+ }
326+
327+ tools := []llms.Tool {
328+ {
329+ Type : "function" ,
330+ Function : & llms.FunctionDefinition {
331+ Name : "getWeather" ,
332+ Description : "Get the weather for a location" ,
333+ Parameters : parameters {
334+ Type : "object" ,
335+ Properties : properties {
336+ Location : property {
337+ Type : "string" ,
338+ Description : "The location to get weather for" ,
339+ },
340+ },
341+ Required : []string {"location" },
342+ },
343+ },
344+ },
345+ }
346+
347+ content := []llms.MessageContent {
348+ {
349+ Role : llms .ChatMessageTypeHuman ,
350+ Parts : []llms.ContentPart {
351+ llms .TextPart ("What's the weather in New York?" ),
352+ },
353+ },
354+ }
355+
356+ resp , err := llm .GenerateContent (
357+ t .Context (),
358+ content ,
359+ llms .WithTools (tools ),
360+ )
361+
362+ require .NoError (t , err )
363+ require .NotNil (t , resp )
364+ assert .NotEmpty (t , resp .Choices )
365+
366+ // Check if tool call was made
367+ if len (resp .Choices [0 ].ToolCalls ) > 0 {
368+ toolCall := resp .Choices [0 ].ToolCalls [0 ]
369+ assert .NotEmpty (t , toolCall .ID , "ToolCall ID should not be empty" )
370+ assert .Equal (t , "getWeather" , toolCall .FunctionCall .Name )
371+ assert .Contains (t , toolCall .FunctionCall .Arguments , "New York" )
372+ var args toolCallArgs
373+ err := json .Unmarshal ([]byte (toolCall .FunctionCall .Arguments ), & args )
374+ require .NoError (t , err )
375+ assert .Equal (t , "New York" , args .Location )
376+ } else {
377+ t .Fail ()
299378 }
300379}
301380
0 commit comments