Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions internal/tools/bigquery/bigquerycommon/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
func BQTypeStringFromToolType(toolType string) (string, error) {
switch toolType {
case "string":
case parameters.TypeString:
return "STRING", nil
case "integer":
case parameters.TypeInt:
return "INT64", nil
case "float":
case parameters.TypeFloat:
return "FLOAT64", nil
case "boolean":
case parameters.TypeBool:
return "BOOL", nil
case parameters.TypeMap:
return "STRUCT", nil
default:
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
Expand Down
215 changes: 137 additions & 78 deletions internal/tools/bigquery/bigquerysql/bigquerysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"fmt"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"

bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
Expand Down Expand Up @@ -117,43 +117,119 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
}

highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))

paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, util.NewAgentError("unable to extract template params", err)
}

for _, p := range t.Parameters {
highLevelParams, lowLevelParams, tbErr := buildQueryParameters(t.Parameters, paramsMap, newStatement)
if tbErr != nil {
return nil, tbErr
}

connProps := []*bigqueryapi.ConnectionProperty{}
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
}
if session != nil {
// Add session ID to the connection properties for subsequent calls.
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
}
}

bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
}

dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled())
if err != nil {
return nil, util.ProcessGcpError(err)
}

statementType := dryRunJob.Statistics.Query.StatementType
resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
if err != nil {
return nil, util.ProcessGcpError(err)
}
return resp, nil
}

func buildQueryParameters(paramsMetadata parameters.Parameters, paramsMap map[string]any, statement string) ([]bigqueryapi.QueryParameter, []*bigqueryrestapi.QueryParameter, util.ToolboxError) {
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(paramsMetadata))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(paramsMetadata))

for _, p := range paramsMetadata {
name := p.GetName()
value := paramsMap[name]

// This block for converting []any to typed slices is still necessary and correct.
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
arrayParamValue, ok := value.([]any)
if !ok {
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), nil)
}
itemType := arrayParam.GetItems().GetType()
var err error
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
if err != nil {
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err)
// Handle array types: convert []any to typed slices if necessary.
if arrayParam, ok := p.(*parameters.ArrayParameter); ok && value != nil {
if arrayParamValue, ok := value.([]any); ok {
itemType := arrayParam.GetItems().GetType()
var err error
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
if err != nil {
return nil, nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err)
}
}
}

// Determine if the parameter is named or positional for the high-level client.
var paramNameForHighLevel string
if strings.Contains(newStatement, "@"+name) {
isNamed, _ := regexp.MatchString("@"+name+"\\b", statement)
if isNamed {
paramNameForHighLevel = name
}

// Handle nil values for optional parameters by providing typed NULLs.
// BigQuery high-level client requires objects like NullString for NULLs.
// BigQuery low-level REST client requires setting the Null fields.
finalValue := value
isNull := value == nil

if isNull {
if p.GetEmbeddedBy() != "" {
finalValue = []float64(nil)
} else {
switch p.GetType() {
case parameters.TypeString:
finalValue = bigqueryapi.NullString{Valid: false}
case parameters.TypeInt:
finalValue = bigqueryapi.NullInt64{Valid: false}
case parameters.TypeFloat:
finalValue = bigqueryapi.NullFloat64{Valid: false}
case parameters.TypeBool:
finalValue = bigqueryapi.NullBool{Valid: false}
case parameters.TypeArray:
// For arrays, provide a typed nil slice based on items type.
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
switch arrayParam.GetItems().GetType() {
case parameters.TypeString:
finalValue = []string(nil)
case parameters.TypeInt:
finalValue = []int64(nil)
case parameters.TypeFloat:
finalValue = []float64(nil)
case parameters.TypeBool:
finalValue = []bool(nil)
default:
finalValue = []any(nil)
}
}
case parameters.TypeMap:
finalValue = map[string]any(nil)
}
}
}

// 1. Create the high-level parameter for the final query execution.
highLevelParams = append(highLevelParams, bigqueryapi.QueryParameter{
Name: paramNameForHighLevel,
Value: value,
Value: finalValue,
})

// 2. Create the low-level parameter for the dry run.
Expand All @@ -163,80 +239,63 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
ParameterValue: &bigqueryrestapi.QueryParameterValue{},
}

rv := reflect.ValueOf(value)
if rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8 {
lowLevelParam.ParameterType.Type = "ARRAY"
if isNull {
lowLevelParam.ParameterValue.NullFields = []string{"Value"}
}

// Default item type to FLOAT64 for embeddings, or use config if available.
itemType := "FLOAT64"
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
if bqType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType()); err == nil {
itemType = bqType
}
// Check if this parameter is an array type.
// It is an array if its metadata type is Array, or if it is used for embedding,
var isArray bool
var itemType = "FLOAT64" // Default to FLOAT64 for embeddings
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
isArray = true
if bqType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType()); err == nil {
itemType = bqType
}
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}
} else if p.GetEmbeddedBy() != "" {
isArray = true
}

// Build the array values.
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, rv.Len())
for i := 0; i < rv.Len(); i++ {
val := rv.Index(i).Interface()

// Prevent precision loss and scientific notation issues
var valStr string
switch v := val.(type) {
case float64:
valStr = strconv.FormatFloat(v, 'f', -1, 64)
case float32:
valStr = strconv.FormatFloat(float64(v), 'f', -1, 32)
default:
valStr = fmt.Sprintf("%v", val)
}
if isArray {
lowLevelParam.ParameterType.Type = "ARRAY"
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}

arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
Value: valStr,
if !isNull {
sliceVal := reflect.ValueOf(value)
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, sliceVal.Len())
for i := 0; i < sliceVal.Len(); i++ {
val := sliceVal.Index(i).Interface()

// Prevent precision loss and scientific notation issues
var valStr string
switch v := val.(type) {
case float64:
valStr = strconv.FormatFloat(v, 'f', -1, 64)
case float32:
valStr = strconv.FormatFloat(float64(v), 'f', -1, 32)
default:
valStr = fmt.Sprintf("%v", val)
}

arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
Value: valStr,
}
}
lowLevelParam.ParameterValue.ArrayValues = arrayValues
}
lowLevelParam.ParameterValue.ArrayValues = arrayValues
} else {
// Handle scalar types based on their defined type.
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
if err != nil {
return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err)
return nil, nil, util.NewAgentError(fmt.Sprintf("unable to get BigQuery type for parameter %q", name), err)
}
lowLevelParam.ParameterType.Type = bqType
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
if !isNull {
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
}
}
lowLevelParams = append(lowLevelParams, lowLevelParam)
}

connProps := []*bigqueryapi.ConnectionProperty{}
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
}
if session != nil {
// Add session ID to the connection properties for subsequent calls.
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
}
}

bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
}

dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled())
if err != nil {
return nil, util.ProcessGcpError(err)
}

statementType := dryRunJob.Statistics.Query.StatementType
resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
if err != nil {
return nil, util.ProcessGcpError(err)
}
return resp, nil
return highLevelParams, lowLevelParams, nil
}

func formatVectorForBigQuery(vectorFloats []float32) any {
Expand Down
Loading
Loading