Skip to content

Commit 021837f

Browse files
committed
Rebase onto updated PR #320; remove ForceEnableTelemetry; fix test alignment
- Remove ForceEnableTelemetry from telemetry Config, driver_integration.go, and all call sites (connector.go) - Update feature flag tests to use new connector-service endpoint format ({"flags": [{"name": ..., "value": ...}]} instead of {"flags": {...}}) - Update exporter/integration tests to use new TelemetryRequest payload format - Update config/connector tests to reflect EnableTelemetry=true default - Fix rows_test.go NewRows calls to include telemetryCtx and telemetryUpdate args
1 parent 64dab01 commit 021837f

25 files changed

+820
-549
lines changed

connection.go

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1717
"github.com/databricks/databricks-sql-go/internal/cli_service"
1818
"github.com/databricks/databricks-sql-go/internal/client"
19-
context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
2019
"github.com/databricks/databricks-sql-go/internal/config"
2120
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2221
"github.com/databricks/databricks-sql-go/internal/rows"
@@ -53,18 +52,21 @@ func (c *conn) Close() error {
5352
ctx := driverctx.NewContextWithConnId(context.Background(), c.id)
5453

5554
// Close telemetry and release resources
55+
closeStart := time.Now()
56+
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
57+
SessionHandle: c.session.SessionHandle,
58+
})
59+
closeLatencyMs := time.Since(closeStart).Milliseconds()
60+
5661
if c.telemetry != nil {
62+
c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeDeleteSession, closeLatencyMs)
5763
_ = c.telemetry.Close(ctx)
5864
telemetry.ReleaseForConnection(c.cfg.Host)
5965
}
6066

61-
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
62-
SessionHandle: c.session.SessionHandle,
63-
})
64-
6567
if err != nil {
6668
log.Err(err).Msg("databricks: failed to close connection")
67-
return dbsqlerrint.NewBadConnectionError(err)
69+
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
6870
}
6971
return nil
7072
}
@@ -123,15 +125,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
123125

124126
corrId := driverctx.CorrelationIdFromContext(ctx)
125127

126-
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
128+
var pollCount int
129+
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args, &pollCount)
127130
log, ctx = client.LoggerAndContext(ctx, exStmtResp)
128131
stagingErr := c.execStagingOperation(exStmtResp, ctx)
129132

130133
// Telemetry: track statement execution
131134
var statementID string
132135
if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil {
133136
statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)
134-
ctx = c.telemetry.BeforeExecute(ctx, statementID)
137+
ctx = c.telemetry.BeforeExecute(ctx, c.id, statementID)
135138
defer func() {
136139
finalErr := err
137140
if stagingErr != nil {
@@ -140,6 +143,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
140143
c.telemetry.AfterExecute(ctx, finalErr)
141144
c.telemetry.CompleteStatement(ctx, statementID, finalErr != nil)
142145
}()
146+
c.telemetry.AddTag(ctx, "poll_count", pollCount)
143147
}
144148

145149
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
@@ -181,34 +185,61 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
181185
log, _ := client.LoggerAndContext(ctx, nil)
182186
msg, start := log.Track("QueryContext")
183187

184-
// first we try to get the results synchronously.
185-
// at any point in time that the context is done we must cancel and return
186-
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
188+
// Capture execution start time for telemetry before running the query
189+
executeStart := time.Now()
190+
var pollCount int
191+
exStmtResp, opStatusResp, pollCount, err := c.runQueryWithTelemetry(ctx, query, args, &pollCount)
187192
log, ctx = client.LoggerAndContext(ctx, exStmtResp)
188193
defer log.Duration(msg, start)
189194

190-
// Telemetry: track statement execution
191195
var statementID string
192196
if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil {
193197
statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)
194-
ctx = c.telemetry.BeforeExecute(ctx, statementID)
198+
// Use BeforeExecuteWithTime to set the correct start time (before execution)
199+
ctx = c.telemetry.BeforeExecuteWithTime(ctx, c.id, statementID, executeStart)
195200
defer func() {
196201
c.telemetry.AfterExecute(ctx, err)
197202
c.telemetry.CompleteStatement(ctx, statementID, err != nil)
198203
}()
204+
205+
c.telemetry.AddTag(ctx, "poll_count", pollCount)
206+
c.telemetry.AddTag(ctx, "operation_type", telemetry.OperationTypeExecuteStatement)
207+
208+
if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil {
209+
resultFormat := exStmtResp.DirectResults.ResultSetMetadata.GetResultFormat()
210+
c.telemetry.AddTag(ctx, "result.format", resultFormat.String())
211+
}
199212
}
200213

201214
if err != nil {
202215
log.Err(err).Msg("databricks: failed to run query") // To log query we need to redact credentials
203216
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
204217
}
205218

206-
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
219+
var telemetryUpdate func(int, int64)
220+
if c.telemetry != nil {
221+
telemetryUpdate = func(chunkCount int, bytesDownloaded int64) {
222+
c.telemetry.AddTag(ctx, "chunk_count", chunkCount)
223+
c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded)
224+
}
225+
}
226+
227+
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, ctx, telemetryUpdate)
228+
207229
return rows, err
208230

209231
}
210232

211-
func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) {
233+
func (c *conn) runQueryWithTelemetry(ctx context.Context, query string, args []driver.NamedValue, pollCount *int) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, int, error) {
234+
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args, pollCount)
235+
count := 0
236+
if pollCount != nil {
237+
count = *pollCount
238+
}
239+
return exStmtResp, opStatusResp, count, err
240+
}
241+
242+
func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue, pollCount *int) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) {
212243
// first we try to get the results synchronously.
213244
// at any point in time that the context is done we must cancel and return
214245
exStmtResp, err := c.executeStatement(ctx, query, args)
@@ -240,7 +271,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
240271
case cli_service.TOperationState_INITIALIZED_STATE,
241272
cli_service.TOperationState_PENDING_STATE,
242273
cli_service.TOperationState_RUNNING_STATE:
243-
statusResp, err := c.pollOperation(ctx, opHandle)
274+
statusResp, err := c.pollOperationWithCount(ctx, opHandle, pollCount)
244275
if err != nil {
245276
return exStmtResp, statusResp, err
246277
}
@@ -268,7 +299,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
268299
}
269300

270301
} else {
271-
statusResp, err := c.pollOperation(ctx, opHandle)
302+
statusResp, err := c.pollOperationWithCount(ctx, opHandle, pollCount)
272303
if err != nil {
273304
return exStmtResp, statusResp, err
274305
}
@@ -372,7 +403,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
372403

373404
select {
374405
default:
375-
// Non-blocking check: continue if context not done
376406
case <-ctx.Done():
377407
newCtx := driverctx.NewContextFromBackground(ctx)
378408
// in case context is done, we need to cancel the operation if necessary
@@ -396,12 +426,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
396426
return resp, err
397427
}
398428

399-
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
429+
func (c *conn) pollOperationWithCount(ctx context.Context, opHandle *cli_service.TOperationHandle, pollCount *int) (*cli_service.TGetOperationStatusResp, error) {
400430
corrId := driverctx.CorrelationIdFromContext(ctx)
401431
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
402432
var statusResp *cli_service.TGetOperationStatusResp
403433
ctx = driverctx.NewContextWithConnId(ctx, c.id)
404-
newCtx := context2.WithoutCancel(ctx)
434+
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
405435
pollSentinel := sentinel.Sentinel{
406436
OnDoneFn: func(statusResp any) (any, error) {
407437
return statusResp, nil
@@ -413,6 +443,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
413443
OperationHandle: opHandle,
414444
})
415445

446+
if pollCount != nil {
447+
*pollCount++
448+
}
449+
416450
if statusResp != nil && statusResp.OperationState != nil {
417451
log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String())
418452
}
@@ -455,6 +489,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
455489
return statusResp, nil
456490
}
457491

492+
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
493+
return c.pollOperationWithCount(ctx, opHandle, nil)
494+
}
495+
458496
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
459497
var err error
460498
if parameter, ok := nv.Value.(Parameter); ok {
@@ -622,7 +660,7 @@ func (c *conn) execStagingOperation(
622660
}
623661

624662
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
625-
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
663+
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, nil, nil)
626664
if err != nil {
627665
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
628666
}

connection_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ func TestConn_runQuery(t *testing.T) {
833833
client: testClient,
834834
cfg: config.WithDefaults(),
835835
}
836-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
836+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
837837
assert.Error(t, err)
838838
assert.Nil(t, exStmtResp)
839839
assert.Nil(t, opStatusResp)
@@ -875,7 +875,7 @@ func TestConn_runQuery(t *testing.T) {
875875
client: testClient,
876876
cfg: config.WithDefaults(),
877877
}
878-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
878+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
879879

880880
assert.Error(t, err)
881881
assert.Equal(t, 1, executeStatementCount)
@@ -921,7 +921,7 @@ func TestConn_runQuery(t *testing.T) {
921921
client: testClient,
922922
cfg: config.WithDefaults(),
923923
}
924-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
924+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
925925

926926
assert.NoError(t, err)
927927
assert.Equal(t, 1, executeStatementCount)
@@ -968,7 +968,7 @@ func TestConn_runQuery(t *testing.T) {
968968
client: testClient,
969969
cfg: config.WithDefaults(),
970970
}
971-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
971+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
972972

973973
assert.Error(t, err)
974974
assert.Equal(t, 1, executeStatementCount)
@@ -1021,7 +1021,7 @@ func TestConn_runQuery(t *testing.T) {
10211021
client: testClient,
10221022
cfg: config.WithDefaults(),
10231023
}
1024-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1024+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
10251025

10261026
assert.NoError(t, err)
10271027
assert.Equal(t, 1, executeStatementCount)
@@ -1073,7 +1073,7 @@ func TestConn_runQuery(t *testing.T) {
10731073
client: testClient,
10741074
cfg: config.WithDefaults(),
10751075
}
1076-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1076+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
10771077

10781078
assert.Error(t, err)
10791079
assert.Equal(t, 1, executeStatementCount)
@@ -1126,7 +1126,7 @@ func TestConn_runQuery(t *testing.T) {
11261126
client: testClient,
11271127
cfg: config.WithDefaults(),
11281128
}
1129-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1129+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
11301130

11311131
assert.NoError(t, err)
11321132
assert.Equal(t, 1, executeStatementCount)
@@ -1179,7 +1179,7 @@ func TestConn_runQuery(t *testing.T) {
11791179
client: testClient,
11801180
cfg: config.WithDefaults(),
11811181
}
1182-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1182+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
11831183

11841184
assert.Error(t, err)
11851185
assert.Equal(t, 1, executeStatementCount)

connector.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
5555
}
5656

5757
protocolVersion := int64(c.cfg.ThriftProtocolVersion)
58+
59+
sessionStart := time.Now()
5860
session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{
5961
ClientProtocolI64: &protocolVersion,
6062
Configuration: sessionParams,
@@ -64,6 +66,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6466
},
6567
CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs,
6668
})
69+
sessionLatencyMs := time.Since(sessionStart).Milliseconds()
70+
6771
if err != nil {
6872
return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err)
6973
}
@@ -76,21 +80,19 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
7680
}
7781
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")
7882

79-
// Initialize telemetry: pass user opt-in flag; if unset, feature flags decide
80-
var enableTelemetry *bool
83+
// Initialize telemetry if configured
8184
if c.cfg.EnableTelemetry {
82-
trueVal := true
83-
enableTelemetry = &trueVal
84-
}
85-
86-
conn.telemetry = telemetry.InitializeForConnection(
87-
ctx,
88-
c.cfg.Host,
89-
c.client,
90-
enableTelemetry,
91-
)
92-
if conn.telemetry != nil {
93-
log.Debug().Msg("telemetry initialized for connection")
85+
conn.telemetry = telemetry.InitializeForConnection(
86+
ctx,
87+
c.cfg.Host,
88+
c.cfg.DriverVersion,
89+
c.client,
90+
c.cfg.EnableTelemetry,
91+
)
92+
if conn.telemetry != nil {
93+
log.Debug().Msg("telemetry initialized for connection")
94+
conn.telemetry.RecordOperation(ctx, conn.id, telemetry.OperationTypeCreateSession, sessionLatencyMs)
95+
}
9496
}
9597

9698
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)

connector_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ func TestNewConnector(t *testing.T) {
6868
RetryWaitMax: 60 * time.Second,
6969
Transport: roundTripper,
7070
CloudFetchConfig: expectedCloudFetchConfig,
71+
EnableTelemetry: true,
7172
}
7273
expectedCfg := config.WithDefaults()
7374
expectedCfg.DriverVersion = DriverVersion
@@ -110,6 +111,7 @@ func TestNewConnector(t *testing.T) {
110111
RetryWaitMin: 1 * time.Second,
111112
RetryWaitMax: 30 * time.Second,
112113
CloudFetchConfig: expectedCloudFetchConfig,
114+
EnableTelemetry: true,
113115
}
114116
expectedCfg := config.WithDefaults()
115117
expectedCfg.UserConfig = expectedUserConfig
@@ -152,6 +154,7 @@ func TestNewConnector(t *testing.T) {
152154
RetryWaitMin: 0,
153155
RetryWaitMax: 0,
154156
CloudFetchConfig: expectedCloudFetchConfig,
157+
EnableTelemetry: true,
155158
}
156159
expectedCfg := config.WithDefaults()
157160
expectedCfg.DriverVersion = DriverVersion

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
182182
ucfg.UseLz4Compression = false
183183
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()
184184

185+
// Enable telemetry by default (respects server feature flags)
186+
ucfg.EnableTelemetry = true
187+
185188
return ucfg
186189
}
187190

0 commit comments

Comments
 (0)