Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/internal/rows"
"github.com/databricks/databricks-sql-go/internal/sentinel"
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
"github.com/databricks/databricks-sql-go/logger"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -285,14 +286,30 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
Statement: query,
RunAsync: true,
QueryTimeout: int64(c.cfg.QueryTimeout / time.Second),
GetDirectResults: &cli_service.TSparkGetDirectResults{
}

// Check protocol version for feature support
serverProtocolVersion := c.session.ServerProtocolVersion

// Add direct results if supported
if thrift_protocol.SupportsDirectResults(serverProtocolVersion) {
req.GetDirectResults = &cli_service.TSparkGetDirectResults{
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
Parameters: parameters,
}
}

if c.cfg.UseArrowBatches {
// Add LZ4 compression if supported and enabled
if thrift_protocol.SupportsLz4Compression(serverProtocolVersion) && c.cfg.UseLz4Compression {
req.CanDecompressLZ4Result_ = &c.cfg.UseLz4Compression
}

// Add cloud fetch if supported and enabled
if thrift_protocol.SupportsCloudFetch(serverProtocolVersion) && c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
}

// Add Arrow support if supported and enabled
if thrift_protocol.SupportsArrow(serverProtocolVersion) && c.cfg.UseArrowBatches {
req.CanReadArrowResult_ = &c.cfg.UseArrowBatches
req.UseArrowNativeTypes = &cli_service.TSparkArrowTypes{
DecimalAsArrow: &c.cfg.UseArrowNativeDecimal,
Expand All @@ -302,8 +319,9 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
}
}

if c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
// Add parameters if supported and provided
if thrift_protocol.SupportsParameterizedQueries(serverProtocolVersion) && len(parameters) > 0 {
req.Parameters = parameters
}

resp, err := c.client.ExecuteStatement(ctx, &req)
Expand Down
2 changes: 1 addition & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
}
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")

log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath)
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)

for k, v := range c.cfg.SessionParams {
setStmt := fmt.Sprintf("SET `%s` = `%s`;", k, v)
Expand Down
5 changes: 5 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
return resp, err
}

// Log the server protocol version
if resp != nil {
log.Debug().Msgf("Server protocol version: 0x%X", resp.ServerProtocolVersion)
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
Expand Down
18 changes: 17 additions & 1 deletion internal/client/testclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
var ErrNotImplemented = errors.New("databricks: not implemented")

type TestClient struct {
// Default server protocol version to use in tests
ServerProtocolVersion cli_service.TProtocolVersion

FnOpenSession func(ctx context.Context, req *cli_service.TOpenSessionReq) (_r *cli_service.TOpenSessionResp, _err error)
FnCloseSession func(ctx context.Context, req *cli_service.TCloseSessionReq) (_r *cli_service.TCloseSessionResp, _err error)
FnGetInfo func(ctx context.Context, req *cli_service.TGetInfoReq) (_r *cli_service.TGetInfoResp, _err error)
Expand Down Expand Up @@ -39,7 +42,20 @@ func (c *TestClient) OpenSession(ctx context.Context, req *cli_service.TOpenSess
if c.FnOpenSession != nil {
return c.FnOpenSession(ctx, req)
}
return nil, ErrNotImplemented

// Default implementation for test client
resp := &cli_service.TOpenSessionResp{
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
ServerProtocolVersion: c.ServerProtocolVersion,
SessionHandle: &cli_service.TSessionHandle{
SessionId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
Secret: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
},
},
}

return resp, nil
Comment thread
shivam2680 marked this conversation as resolved.
Outdated
}
func (c *TestClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (_r *cli_service.TCloseSessionResp, _err error) {
if c.FnCloseSession != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ func TestConfig_DeepCopy(t *testing.T) {
DriverVersion: "0.9.0",
ThriftProtocol: "binary",
ThriftTransport: "http",
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
ThriftDebugClientProtocol: false,
}

Expand Down
46 changes: 46 additions & 0 deletions internal/thrift_protocol/protocol_feature_util.go
Comment thread
shivam2680 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package thrift_protocol

import "github.com/databricks/databricks-sql-go/internal/cli_service"

// Feature checks
// SupportsDirectResults checks if the server protocol version supports direct results
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V1 and above
func SupportsDirectResults(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1
}

// SupportsLz4Compression checks if the server protocol version supports LZ4 compression
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V6 and above
func SupportsLz4Compression(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6
}
Comment thread
shivam2680 marked this conversation as resolved.

// SupportsCloudFetch checks if the server protocol version supports cloud fetch
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V3 and above
func SupportsCloudFetch(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3
}

// SupportsArrow checks if the server protocol version supports Arrow format
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V5 and above
func SupportsArrow(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5
}

// SupportsCompressedArrow checks if the server protocol version supports compressed Arrow format
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V6 and above
func SupportsCompressedArrow(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6
}

// SupportsParameterizedQueries checks if the server protocol version supports parameterized queries
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V8 and above
func SupportsParameterizedQueries(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8
}

// SupportsMultipleCatalogs checks if the server protocol version supports multiple catalogs
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V4 and above
func SupportsMultipleCatalogs(version cli_service.TProtocolVersion) bool {
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4
}
Loading