Skip to content

Commit ef035e9

Browse files
committed
Rebase onto updated PR #320; remove ForceEnableTelemetry; fix test alignment
- Remove ForceEnableTelemetry from telemetry Config, driver_integration.go, and all call sites (connector.go)
- Update feature flag tests to use new connector-service endpoint format ({"flags": [{"name": ..., "value": ...}]} instead of {"flags": {...}})
- Update exporter/integration tests to use new TelemetryRequest payload format
- Update config/connector tests to reflect EnableTelemetry=true default
- Fix rows_test.go NewRows calls to include telemetryCtx and telemetryUpdate args
1 parent 147e70f commit ef035e9

23 files changed

+786
-517
lines changed

connection.go

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1717
"github.com/databricks/databricks-sql-go/internal/cli_service"
1818
"github.com/databricks/databricks-sql-go/internal/client"
19-
context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
2019
"github.com/databricks/databricks-sql-go/internal/config"
2120
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2221
"github.com/databricks/databricks-sql-go/internal/rows"
@@ -53,18 +52,21 @@ func (c *conn) Close() error {
5352
ctx := driverctx.NewContextWithConnId(context.Background(), c.id)
5453

5554
// Close telemetry and release resources
55+
closeStart := time.Now()
56+
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
57+
SessionHandle: c.session.SessionHandle,
58+
})
59+
closeLatencyMs := time.Since(closeStart).Milliseconds()
60+
5661
if c.telemetry != nil {
62+
c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeDeleteSession, closeLatencyMs)
5763
_ = c.telemetry.Close(ctx)
5864
telemetry.ReleaseForConnection(c.cfg.Host)
5965
}
6066

61-
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
62-
SessionHandle: c.session.SessionHandle,
63-
})
64-
6567
if err != nil {
6668
log.Err(err).Msg("databricks: failed to close connection")
67-
return dbsqlerrint.NewBadConnectionError(err)
69+
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
6870
}
6971
return nil
7072
}
@@ -123,7 +125,8 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
123125

124126
corrId := driverctx.CorrelationIdFromContext(ctx)
125127

126-
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
128+
var pollCount int
129+
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args, &pollCount)
127130
log, ctx = client.LoggerAndContext(ctx, exStmtResp)
128131
stagingErr := c.execStagingOperation(exStmtResp, ctx)
129132

@@ -132,7 +135,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
132135
var closeOpErr error // Track CloseOperation errors for telemetry
133136
if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil {
134137
statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)
135-
ctx = c.telemetry.BeforeExecute(ctx, statementID)
138+
ctx = c.telemetry.BeforeExecute(ctx, c.id, statementID)
136139
defer func() {
137140
finalErr := err
138141
if stagingErr != nil {
@@ -145,6 +148,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
145148
c.telemetry.AfterExecute(ctx, finalErr)
146149
c.telemetry.CompleteStatement(ctx, statementID, finalErr != nil)
147150
}()
151+
c.telemetry.AddTag(ctx, "poll_count", pollCount)
148152
}
149153

150154
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
@@ -187,34 +191,61 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
187191
log, _ := client.LoggerAndContext(ctx, nil)
188192
msg, start := log.Track("QueryContext")
189193

190-
// first we try to get the results synchronously.
191-
// at any point in time that the context is done we must cancel and return
192-
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
194+
// Capture execution start time for telemetry before running the query
195+
executeStart := time.Now()
196+
var pollCount int
197+
exStmtResp, opStatusResp, pollCount, err := c.runQueryWithTelemetry(ctx, query, args, &pollCount)
193198
log, ctx = client.LoggerAndContext(ctx, exStmtResp)
194199
defer log.Duration(msg, start)
195200

196-
// Telemetry: track statement execution
197201
var statementID string
198202
if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil {
199203
statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)
200-
ctx = c.telemetry.BeforeExecute(ctx, statementID)
204+
// Use BeforeExecuteWithTime to set the correct start time (before execution)
205+
ctx = c.telemetry.BeforeExecuteWithTime(ctx, c.id, statementID, executeStart)
201206
defer func() {
202207
c.telemetry.AfterExecute(ctx, err)
203208
c.telemetry.CompleteStatement(ctx, statementID, err != nil)
204209
}()
210+
211+
c.telemetry.AddTag(ctx, "poll_count", pollCount)
212+
c.telemetry.AddTag(ctx, "operation_type", telemetry.OperationTypeExecuteStatement)
213+
214+
if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil {
215+
resultFormat := exStmtResp.DirectResults.ResultSetMetadata.GetResultFormat()
216+
c.telemetry.AddTag(ctx, "result.format", resultFormat.String())
217+
}
205218
}
206219

207220
if err != nil {
208221
log.Err(err).Msg("databricks: failed to run query") // To log query we need to redact credentials
209222
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
210223
}
211224

212-
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
225+
var telemetryUpdate func(int, int64)
226+
if c.telemetry != nil {
227+
telemetryUpdate = func(chunkCount int, bytesDownloaded int64) {
228+
c.telemetry.AddTag(ctx, "chunk_count", chunkCount)
229+
c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded)
230+
}
231+
}
232+
233+
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, ctx, telemetryUpdate)
234+
213235
return rows, err
214236

215237
}
216238

217-
func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) {
239+
func (c *conn) runQueryWithTelemetry(ctx context.Context, query string, args []driver.NamedValue, pollCount *int) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, int, error) {
240+
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args, pollCount)
241+
count := 0
242+
if pollCount != nil {
243+
count = *pollCount
244+
}
245+
return exStmtResp, opStatusResp, count, err
246+
}
247+
248+
func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue, pollCount *int) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) {
218249
// first we try to get the results synchronously.
219250
// at any point in time that the context is done we must cancel and return
220251
exStmtResp, err := c.executeStatement(ctx, query, args)
@@ -246,7 +277,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
246277
case cli_service.TOperationState_INITIALIZED_STATE,
247278
cli_service.TOperationState_PENDING_STATE,
248279
cli_service.TOperationState_RUNNING_STATE:
249-
statusResp, err := c.pollOperation(ctx, opHandle)
280+
statusResp, err := c.pollOperationWithCount(ctx, opHandle, pollCount)
250281
if err != nil {
251282
return exStmtResp, statusResp, err
252283
}
@@ -274,7 +305,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
274305
}
275306

276307
} else {
277-
statusResp, err := c.pollOperation(ctx, opHandle)
308+
statusResp, err := c.pollOperationWithCount(ctx, opHandle, pollCount)
278309
if err != nil {
279310
return exStmtResp, statusResp, err
280311
}
@@ -389,7 +420,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
389420

390421
select {
391422
default:
392-
// Non-blocking check: continue if context not done
393423
case <-ctx.Done():
394424
newCtx := driverctx.NewContextFromBackground(ctx)
395425
// in case context is done, we need to cancel the operation if necessary
@@ -413,12 +443,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
413443
return resp, err
414444
}
415445

416-
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
446+
func (c *conn) pollOperationWithCount(ctx context.Context, opHandle *cli_service.TOperationHandle, pollCount *int) (*cli_service.TGetOperationStatusResp, error) {
417447
corrId := driverctx.CorrelationIdFromContext(ctx)
418448
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
419449
var statusResp *cli_service.TGetOperationStatusResp
420450
ctx = driverctx.NewContextWithConnId(ctx, c.id)
421-
newCtx := context2.WithoutCancel(ctx)
451+
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
422452
pollSentinel := sentinel.Sentinel{
423453
OnDoneFn: func(statusResp any) (any, error) {
424454
return statusResp, nil
@@ -430,6 +460,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
430460
OperationHandle: opHandle,
431461
})
432462

463+
if pollCount != nil {
464+
*pollCount++
465+
}
466+
433467
if statusResp != nil && statusResp.OperationState != nil {
434468
log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String())
435469
}
@@ -472,6 +506,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
472506
return statusResp, nil
473507
}
474508

509+
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
510+
return c.pollOperationWithCount(ctx, opHandle, nil)
511+
}
512+
475513
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
476514
var err error
477515
if parameter, ok := nv.Value.(Parameter); ok {
@@ -639,7 +677,7 @@ func (c *conn) execStagingOperation(
639677
}
640678

641679
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
642-
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
680+
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, nil, nil)
643681
if err != nil {
644682
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
645683
}

connection_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ func TestConn_runQuery(t *testing.T) {
10371037
client: testClient,
10381038
cfg: config.WithDefaults(),
10391039
}
1040-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1040+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
10411041
assert.Error(t, err)
10421042
assert.Nil(t, exStmtResp)
10431043
assert.Nil(t, opStatusResp)
@@ -1079,7 +1079,7 @@ func TestConn_runQuery(t *testing.T) {
10791079
client: testClient,
10801080
cfg: config.WithDefaults(),
10811081
}
1082-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1082+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
10831083

10841084
assert.Error(t, err)
10851085
assert.Equal(t, 1, executeStatementCount)
@@ -1125,7 +1125,7 @@ func TestConn_runQuery(t *testing.T) {
11251125
client: testClient,
11261126
cfg: config.WithDefaults(),
11271127
}
1128-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1128+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
11291129

11301130
assert.NoError(t, err)
11311131
assert.Equal(t, 1, executeStatementCount)
@@ -1172,7 +1172,7 @@ func TestConn_runQuery(t *testing.T) {
11721172
client: testClient,
11731173
cfg: config.WithDefaults(),
11741174
}
1175-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1175+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
11761176

11771177
assert.Error(t, err)
11781178
assert.Equal(t, 1, executeStatementCount)
@@ -1225,7 +1225,7 @@ func TestConn_runQuery(t *testing.T) {
12251225
client: testClient,
12261226
cfg: config.WithDefaults(),
12271227
}
1228-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1228+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
12291229

12301230
assert.NoError(t, err)
12311231
assert.Equal(t, 1, executeStatementCount)
@@ -1277,7 +1277,7 @@ func TestConn_runQuery(t *testing.T) {
12771277
client: testClient,
12781278
cfg: config.WithDefaults(),
12791279
}
1280-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1280+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
12811281

12821282
assert.Error(t, err)
12831283
assert.Equal(t, 1, executeStatementCount)
@@ -1330,7 +1330,7 @@ func TestConn_runQuery(t *testing.T) {
13301330
client: testClient,
13311331
cfg: config.WithDefaults(),
13321332
}
1333-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1333+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
13341334

13351335
assert.NoError(t, err)
13361336
assert.Equal(t, 1, executeStatementCount)
@@ -1383,7 +1383,7 @@ func TestConn_runQuery(t *testing.T) {
13831383
client: testClient,
13841384
cfg: config.WithDefaults(),
13851385
}
1386-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1386+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
13871387

13881388
assert.Error(t, err)
13891389
assert.Equal(t, 1, executeStatementCount)

connector.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
5555
}
5656

5757
protocolVersion := int64(c.cfg.ThriftProtocolVersion)
58+
59+
sessionStart := time.Now()
5860
session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{
5961
ClientProtocolI64: &protocolVersion,
6062
Configuration: sessionParams,
@@ -64,6 +66,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6466
},
6567
CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs,
6668
})
69+
sessionLatencyMs := time.Since(sessionStart).Milliseconds()
70+
6771
if err != nil {
6872
return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err)
6973
}
@@ -80,11 +84,13 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
8084
conn.telemetry = telemetry.InitializeForConnection(
8185
ctx,
8286
c.cfg.Host,
87+
c.cfg.DriverVersion,
8388
c.client,
8489
c.cfg.EnableTelemetry,
8590
)
8691
if conn.telemetry != nil {
8792
log.Debug().Msg("telemetry initialized for connection")
93+
conn.telemetry.RecordOperation(ctx, conn.id, telemetry.OperationTypeCreateSession, sessionLatencyMs)
8894
}
8995

9096
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
184184
ucfg.UseLz4Compression = false
185185
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()
186186

187+
// EnableTelemetry defaults to unset (ConfigValue zero value),
188+
// meaning telemetry is controlled by server feature flags.
189+
187190
return ucfg
188191
}
189192

0 commit comments

Comments
 (0)