Skip to content

Commit 15ce89b

Browse files
committed
Rebase onto updated PR #320; remove ForceEnableTelemetry; fix test alignment
- Remove ForceEnableTelemetry from telemetry Config, driver_integration.go, and all call sites (connector.go) - Update feature flag tests to use new connector-service endpoint format ({"flags": [{"name": ..., "value": ...}]} instead of {"flags": {...}}) - Update exporter/integration tests to use new TelemetryRequest payload format - Update config/connector tests to reflect EnableTelemetry=true default - Fix rows_test.go NewRows calls to include telemetryCtx and telemetryUpdate args
1 parent 50b0789 commit 15ce89b

23 files changed

+786
-526
lines changed

connection.go

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1717
"github.com/databricks/databricks-sql-go/internal/cli_service"
1818
"github.com/databricks/databricks-sql-go/internal/client"
19-
context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
2019
"github.com/databricks/databricks-sql-go/internal/config"
2120
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2221
"github.com/databricks/databricks-sql-go/internal/rows"
@@ -53,18 +52,21 @@ func (c *conn) Close() error {
5352
ctx := driverctx.NewContextWithConnId(context.Background(), c.id)
5453

5554
// Close telemetry and release resources
55+
closeStart := time.Now()
56+
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
57+
SessionHandle: c.session.SessionHandle,
58+
})
59+
closeLatencyMs := time.Since(closeStart).Milliseconds()
60+
5661
if c.telemetry != nil {
62+
c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeDeleteSession, closeLatencyMs)
5763
_ = c.telemetry.Close(ctx)
5864
telemetry.ReleaseForConnection(c.cfg.Host)
5965
}
6066

61-
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
62-
SessionHandle: c.session.SessionHandle,
63-
})
64-
6567
if err != nil {
6668
log.Err(err).Msg("databricks: failed to close connection")
67-
return dbsqlerrint.NewBadConnectionError(err)
69+
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
6870
}
6971
return nil
7072
}
@@ -123,15 +125,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
123125

124126
corrId := driverctx.CorrelationIdFromContext(ctx)
125127

126-
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
128+
var pollCount int
129+
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args, &pollCount)
127130
log, ctx = client.LoggerAndContext(ctx, exStmtResp)
128131
stagingErr := c.execStagingOperation(exStmtResp, ctx)
129132

130133
// Telemetry: track statement execution
131134
var statementID string
132135
if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil {
133136
statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)
134-
ctx = c.telemetry.BeforeExecute(ctx, statementID)
137+
ctx = c.telemetry.BeforeExecute(ctx, c.id, statementID)
135138
defer func() {
136139
finalErr := err
137140
if stagingErr != nil {
@@ -140,6 +143,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
140143
c.telemetry.AfterExecute(ctx, finalErr)
141144
c.telemetry.CompleteStatement(ctx, statementID, finalErr != nil)
142145
}()
146+
c.telemetry.AddTag(ctx, "poll_count", pollCount)
143147
}
144148

145149
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
@@ -181,34 +185,61 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
181185
log, _ := client.LoggerAndContext(ctx, nil)
182186
msg, start := log.Track("QueryContext")
183187

184-
// first we try to get the results synchronously.
185-
// at any point in time that the context is done we must cancel and return
186-
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
188+
// Capture execution start time for telemetry before running the query
189+
executeStart := time.Now()
190+
var pollCount int
191+
exStmtResp, opStatusResp, pollCount, err := c.runQueryWithTelemetry(ctx, query, args, &pollCount)
187192
log, ctx = client.LoggerAndContext(ctx, exStmtResp)
188193
defer log.Duration(msg, start)
189194

190-
// Telemetry: track statement execution
191195
var statementID string
192196
if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil {
193197
statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)
194-
ctx = c.telemetry.BeforeExecute(ctx, statementID)
198+
// Use BeforeExecuteWithTime to set the correct start time (before execution)
199+
ctx = c.telemetry.BeforeExecuteWithTime(ctx, c.id, statementID, executeStart)
195200
defer func() {
196201
c.telemetry.AfterExecute(ctx, err)
197202
c.telemetry.CompleteStatement(ctx, statementID, err != nil)
198203
}()
204+
205+
c.telemetry.AddTag(ctx, "poll_count", pollCount)
206+
c.telemetry.AddTag(ctx, "operation_type", telemetry.OperationTypeExecuteStatement)
207+
208+
if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil {
209+
resultFormat := exStmtResp.DirectResults.ResultSetMetadata.GetResultFormat()
210+
c.telemetry.AddTag(ctx, "result.format", resultFormat.String())
211+
}
199212
}
200213

201214
if err != nil {
202215
log.Err(err).Msg("databricks: failed to run query") // To log query we need to redact credentials
203216
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
204217
}
205218

206-
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
219+
var telemetryUpdate func(int, int64)
220+
if c.telemetry != nil {
221+
telemetryUpdate = func(chunkCount int, bytesDownloaded int64) {
222+
c.telemetry.AddTag(ctx, "chunk_count", chunkCount)
223+
c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded)
224+
}
225+
}
226+
227+
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, ctx, telemetryUpdate)
228+
207229
return rows, err
208230

209231
}
210232

211-
func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) {
233+
func (c *conn) runQueryWithTelemetry(ctx context.Context, query string, args []driver.NamedValue, pollCount *int) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, int, error) {
234+
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args, pollCount)
235+
count := 0
236+
if pollCount != nil {
237+
count = *pollCount
238+
}
239+
return exStmtResp, opStatusResp, count, err
240+
}
241+
242+
func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue, pollCount *int) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) {
212243
// first we try to get the results synchronously.
213244
// at any point in time that the context is done we must cancel and return
214245
exStmtResp, err := c.executeStatement(ctx, query, args)
@@ -240,7 +271,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
240271
case cli_service.TOperationState_INITIALIZED_STATE,
241272
cli_service.TOperationState_PENDING_STATE,
242273
cli_service.TOperationState_RUNNING_STATE:
243-
statusResp, err := c.pollOperation(ctx, opHandle)
274+
statusResp, err := c.pollOperationWithCount(ctx, opHandle, pollCount)
244275
if err != nil {
245276
return exStmtResp, statusResp, err
246277
}
@@ -268,7 +299,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
268299
}
269300

270301
} else {
271-
statusResp, err := c.pollOperation(ctx, opHandle)
302+
statusResp, err := c.pollOperationWithCount(ctx, opHandle, pollCount)
272303
if err != nil {
273304
return exStmtResp, statusResp, err
274305
}
@@ -372,7 +403,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
372403

373404
select {
374405
default:
375-
// Non-blocking check: continue if context not done
376406
case <-ctx.Done():
377407
newCtx := driverctx.NewContextFromBackground(ctx)
378408
// in case context is done, we need to cancel the operation if necessary
@@ -396,12 +426,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
396426
return resp, err
397427
}
398428

399-
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
429+
func (c *conn) pollOperationWithCount(ctx context.Context, opHandle *cli_service.TOperationHandle, pollCount *int) (*cli_service.TGetOperationStatusResp, error) {
400430
corrId := driverctx.CorrelationIdFromContext(ctx)
401431
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
402432
var statusResp *cli_service.TGetOperationStatusResp
403433
ctx = driverctx.NewContextWithConnId(ctx, c.id)
404-
newCtx := context2.WithoutCancel(ctx)
434+
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
405435
pollSentinel := sentinel.Sentinel{
406436
OnDoneFn: func(statusResp any) (any, error) {
407437
return statusResp, nil
@@ -413,6 +443,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
413443
OperationHandle: opHandle,
414444
})
415445

446+
if pollCount != nil {
447+
*pollCount++
448+
}
449+
416450
if statusResp != nil && statusResp.OperationState != nil {
417451
log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String())
418452
}
@@ -455,6 +489,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
455489
return statusResp, nil
456490
}
457491

492+
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
493+
return c.pollOperationWithCount(ctx, opHandle, nil)
494+
}
495+
458496
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
459497
var err error
460498
if parameter, ok := nv.Value.(Parameter); ok {
@@ -622,7 +660,7 @@ func (c *conn) execStagingOperation(
622660
}
623661

624662
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
625-
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
663+
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, nil, nil)
626664
if err != nil {
627665
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
628666
}

connection_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ func TestConn_runQuery(t *testing.T) {
833833
client: testClient,
834834
cfg: config.WithDefaults(),
835835
}
836-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
836+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
837837
assert.Error(t, err)
838838
assert.Nil(t, exStmtResp)
839839
assert.Nil(t, opStatusResp)
@@ -875,7 +875,7 @@ func TestConn_runQuery(t *testing.T) {
875875
client: testClient,
876876
cfg: config.WithDefaults(),
877877
}
878-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
878+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
879879

880880
assert.Error(t, err)
881881
assert.Equal(t, 1, executeStatementCount)
@@ -921,7 +921,7 @@ func TestConn_runQuery(t *testing.T) {
921921
client: testClient,
922922
cfg: config.WithDefaults(),
923923
}
924-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
924+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
925925

926926
assert.NoError(t, err)
927927
assert.Equal(t, 1, executeStatementCount)
@@ -968,7 +968,7 @@ func TestConn_runQuery(t *testing.T) {
968968
client: testClient,
969969
cfg: config.WithDefaults(),
970970
}
971-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
971+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
972972

973973
assert.Error(t, err)
974974
assert.Equal(t, 1, executeStatementCount)
@@ -1021,7 +1021,7 @@ func TestConn_runQuery(t *testing.T) {
10211021
client: testClient,
10221022
cfg: config.WithDefaults(),
10231023
}
1024-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1024+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
10251025

10261026
assert.NoError(t, err)
10271027
assert.Equal(t, 1, executeStatementCount)
@@ -1073,7 +1073,7 @@ func TestConn_runQuery(t *testing.T) {
10731073
client: testClient,
10741074
cfg: config.WithDefaults(),
10751075
}
1076-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1076+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
10771077

10781078
assert.Error(t, err)
10791079
assert.Equal(t, 1, executeStatementCount)
@@ -1126,7 +1126,7 @@ func TestConn_runQuery(t *testing.T) {
11261126
client: testClient,
11271127
cfg: config.WithDefaults(),
11281128
}
1129-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1129+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
11301130

11311131
assert.NoError(t, err)
11321132
assert.Equal(t, 1, executeStatementCount)
@@ -1179,7 +1179,7 @@ func TestConn_runQuery(t *testing.T) {
11791179
client: testClient,
11801180
cfg: config.WithDefaults(),
11811181
}
1182-
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{})
1182+
exStmtResp, opStatusResp, err := testConn.runQuery(context.Background(), "select 1", []driver.NamedValue{}, nil)
11831183

11841184
assert.Error(t, err)
11851185
assert.Equal(t, 1, executeStatementCount)

connector.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
5555
}
5656

5757
protocolVersion := int64(c.cfg.ThriftProtocolVersion)
58+
59+
sessionStart := time.Now()
5860
session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{
5961
ClientProtocolI64: &protocolVersion,
6062
Configuration: sessionParams,
@@ -64,6 +66,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6466
},
6567
CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs,
6668
})
69+
sessionLatencyMs := time.Since(sessionStart).Milliseconds()
70+
6771
if err != nil {
6872
return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err)
6973
}
@@ -80,11 +84,13 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
8084
conn.telemetry = telemetry.InitializeForConnection(
8185
ctx,
8286
c.cfg.Host,
87+
c.cfg.DriverVersion,
8388
c.client,
8489
c.cfg.EnableTelemetry,
8590
)
8691
if conn.telemetry != nil {
8792
log.Debug().Msg("telemetry initialized for connection")
93+
conn.telemetry.RecordOperation(ctx, conn.id, telemetry.OperationTypeCreateSession, sessionLatencyMs)
8894
}
8995

9096
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
184184
ucfg.UseLz4Compression = false
185185
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()
186186

187+
// EnableTelemetry defaults to unset (ConfigValue zero value),
188+
// meaning telemetry is controlled by server feature flags.
189+
187190
return ucfg
188191
}
189192

0 commit comments

Comments
 (0)