@@ -16,7 +16,6 @@ import (
1616 dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1717 "github.com/databricks/databricks-sql-go/internal/cli_service"
1818 "github.com/databricks/databricks-sql-go/internal/client"
19- context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
2019 "github.com/databricks/databricks-sql-go/internal/config"
2120 dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2221 "github.com/databricks/databricks-sql-go/internal/rows"
@@ -53,18 +52,21 @@ func (c *conn) Close() error {
5352 ctx := driverctx .NewContextWithConnId (context .Background (), c .id )
5453
5554 // Close telemetry and release resources
55+ closeStart := time .Now ()
56+ _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
57+ SessionHandle : c .session .SessionHandle ,
58+ })
59+ closeLatencyMs := time .Since (closeStart ).Milliseconds ()
60+
5661 if c .telemetry != nil {
62+ c .telemetry .RecordOperation (ctx , c .id , telemetry .OperationTypeDeleteSession , closeLatencyMs )
5763 _ = c .telemetry .Close (ctx )
5864 telemetry .ReleaseForConnection (c .cfg .Host )
5965 }
6066
61- _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
62- SessionHandle : c .session .SessionHandle ,
63- })
64-
6567 if err != nil {
6668 log .Err (err ).Msg ("databricks: failed to close connection" )
67- return dbsqlerrint .NewBadConnectionError ( err )
69+ return dbsqlerrint .NewRequestError ( ctx , dbsqlerr . ErrCloseConnection , err )
6870 }
6971 return nil
7072}
@@ -123,7 +125,8 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
123125
124126 corrId := driverctx .CorrelationIdFromContext (ctx )
125127
126- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
128+ var pollCount int
129+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , & pollCount )
127130 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
128131 stagingErr := c .execStagingOperation (exStmtResp , ctx )
129132
@@ -132,7 +135,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
132135 var closeOpErr error // Track CloseOperation errors for telemetry
133136 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
134137 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
135- ctx = c .telemetry .BeforeExecute (ctx , statementID )
138+ ctx = c .telemetry .BeforeExecute (ctx , c . id , statementID )
136139 defer func () {
137140 finalErr := err
138141 if stagingErr != nil {
@@ -145,6 +148,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
145148 c .telemetry .AfterExecute (ctx , finalErr )
146149 c .telemetry .CompleteStatement (ctx , statementID , finalErr != nil )
147150 }()
151+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
148152 }
149153
150154 if exStmtResp != nil && exStmtResp .OperationHandle != nil {
@@ -187,34 +191,61 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
187191 log , _ := client .LoggerAndContext (ctx , nil )
188192 msg , start := log .Track ("QueryContext" )
189193
190- // first we try to get the results synchronously.
191- // at any point in time that the context is done we must cancel and return
192- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
194+ // Capture execution start time for telemetry before running the query
195+ executeStart := time .Now ()
196+ var pollCount int
197+ exStmtResp , opStatusResp , pollCount , err := c .runQueryWithTelemetry (ctx , query , args , & pollCount )
193198 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
194199 defer log .Duration (msg , start )
195200
196- // Telemetry: track statement execution
197201 var statementID string
198202 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
199203 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
200- ctx = c .telemetry .BeforeExecute (ctx , statementID )
204+ // Use BeforeExecuteWithTime to set the correct start time (before execution)
205+ ctx = c .telemetry .BeforeExecuteWithTime (ctx , c .id , statementID , executeStart )
201206 defer func () {
202207 c .telemetry .AfterExecute (ctx , err )
203208 c .telemetry .CompleteStatement (ctx , statementID , err != nil )
204209 }()
210+
211+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
212+ c .telemetry .AddTag (ctx , "operation_type" , telemetry .OperationTypeExecuteStatement )
213+
214+ if exStmtResp .DirectResults != nil && exStmtResp .DirectResults .ResultSetMetadata != nil {
215+ resultFormat := exStmtResp .DirectResults .ResultSetMetadata .GetResultFormat ()
216+ c .telemetry .AddTag (ctx , "result.format" , resultFormat .String ())
217+ }
205218 }
206219
207220 if err != nil {
208221 log .Err (err ).Msg ("databricks: failed to run query" ) // To log query we need to redact credentials
209222 return nil , dbsqlerrint .NewExecutionError (ctx , dbsqlerr .ErrQueryExecution , err , opStatusResp )
210223 }
211224
212- rows , err := rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
225+ var telemetryUpdate func (int , int64 )
226+ if c .telemetry != nil {
227+ telemetryUpdate = func (chunkCount int , bytesDownloaded int64 ) {
228+ c .telemetry .AddTag (ctx , "chunk_count" , chunkCount )
229+ c .telemetry .AddTag (ctx , "bytes_downloaded" , bytesDownloaded )
230+ }
231+ }
232+
233+ rows , err := rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , ctx , telemetryUpdate )
234+
213235 return rows , err
214236
215237}
216238
217- func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
239+ func (c * conn ) runQueryWithTelemetry (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , int , error ) {
240+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , pollCount )
241+ count := 0
242+ if pollCount != nil {
243+ count = * pollCount
244+ }
245+ return exStmtResp , opStatusResp , count , err
246+ }
247+
248+ func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
218249 // first we try to get the results synchronously.
219250 // at any point in time that the context is done we must cancel and return
220251 exStmtResp , err := c .executeStatement (ctx , query , args )
@@ -246,7 +277,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
246277 case cli_service .TOperationState_INITIALIZED_STATE ,
247278 cli_service .TOperationState_PENDING_STATE ,
248279 cli_service .TOperationState_RUNNING_STATE :
249- statusResp , err := c .pollOperation (ctx , opHandle )
280+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
250281 if err != nil {
251282 return exStmtResp , statusResp , err
252283 }
@@ -274,7 +305,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
274305 }
275306
276307 } else {
277- statusResp , err := c .pollOperation (ctx , opHandle )
308+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
278309 if err != nil {
279310 return exStmtResp , statusResp , err
280311 }
@@ -389,7 +420,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
389420
390421 select {
391422 default :
392- // Non-blocking check: continue if context not done
393423 case <- ctx .Done ():
394424 newCtx := driverctx .NewContextFromBackground (ctx )
395425 // in case context is done, we need to cancel the operation if necessary
@@ -413,12 +443,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
413443 return resp , err
414444}
415445
416- func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
446+ func (c * conn ) pollOperationWithCount (ctx context.Context , opHandle * cli_service.TOperationHandle , pollCount * int ) (* cli_service.TGetOperationStatusResp , error ) {
417447 corrId := driverctx .CorrelationIdFromContext (ctx )
418448 log := logger .WithContext (c .id , corrId , client .SprintGuid (opHandle .OperationId .GUID ))
419449 var statusResp * cli_service.TGetOperationStatusResp
420450 ctx = driverctx .NewContextWithConnId (ctx , c .id )
421- newCtx := context2 . WithoutCancel ( ctx )
451+ newCtx := driverctx . NewContextWithCorrelationId ( driverctx . NewContextWithConnId ( context . Background (), c . id ), corrId )
422452 pollSentinel := sentinel.Sentinel {
423453 OnDoneFn : func (statusResp any ) (any , error ) {
424454 return statusResp , nil
@@ -430,6 +460,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
430460 OperationHandle : opHandle ,
431461 })
432462
463+ if pollCount != nil {
464+ * pollCount ++
465+ }
466+
433467 if statusResp != nil && statusResp .OperationState != nil {
434468 log .Debug ().Msgf ("databricks: status %s" , statusResp .GetOperationState ().String ())
435469 }
@@ -472,6 +506,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
472506 return statusResp , nil
473507}
474508
509+ func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
510+ return c .pollOperationWithCount (ctx , opHandle , nil )
511+ }
512+
475513func (c * conn ) CheckNamedValue (nv * driver.NamedValue ) error {
476514 var err error
477515 if parameter , ok := nv .Value .(Parameter ); ok {
@@ -639,7 +677,7 @@ func (c *conn) execStagingOperation(
639677 }
640678
641679 if len (driverctx .StagingPathsFromContext (ctx )) != 0 {
642- row , err = rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
680+ row , err = rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , nil , nil )
643681 if err != nil {
644682 return dbsqlerrint .NewDriverError (ctx , "error reading row." , err )
645683 }
0 commit comments