@@ -16,7 +16,6 @@ import (
1616 dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1717 "github.com/databricks/databricks-sql-go/internal/cli_service"
1818 "github.com/databricks/databricks-sql-go/internal/client"
19- context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
2019 "github.com/databricks/databricks-sql-go/internal/config"
2120 dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2221 "github.com/databricks/databricks-sql-go/internal/rows"
@@ -53,18 +52,21 @@ func (c *conn) Close() error {
5352 ctx := driverctx .NewContextWithConnId (context .Background (), c .id )
5453
5554 // Close telemetry and release resources
55+ closeStart := time .Now ()
56+ _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
57+ SessionHandle : c .session .SessionHandle ,
58+ })
59+ closeLatencyMs := time .Since (closeStart ).Milliseconds ()
60+
5661 if c .telemetry != nil {
62+ c .telemetry .RecordOperation (ctx , c .id , telemetry .OperationTypeDeleteSession , closeLatencyMs )
5763 _ = c .telemetry .Close (ctx )
5864 telemetry .ReleaseForConnection (c .cfg .Host )
5965 }
6066
61- _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
62- SessionHandle : c .session .SessionHandle ,
63- })
64-
6567 if err != nil {
6668 log .Err (err ).Msg ("databricks: failed to close connection" )
67- return dbsqlerrint .NewBadConnectionError ( err )
69+ return dbsqlerrint .NewRequestError ( ctx , dbsqlerr . ErrCloseConnection , err )
6870 }
6971 return nil
7072}
@@ -123,15 +125,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
123125
124126 corrId := driverctx .CorrelationIdFromContext (ctx )
125127
126- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
128+ var pollCount int
129+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , & pollCount )
127130 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
128131 stagingErr := c .execStagingOperation (exStmtResp , ctx )
129132
130133 // Telemetry: track statement execution
131134 var statementID string
132135 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
133136 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
134- ctx = c .telemetry .BeforeExecute (ctx , statementID )
137+ ctx = c .telemetry .BeforeExecute (ctx , c . id , statementID )
135138 defer func () {
136139 finalErr := err
137140 if stagingErr != nil {
@@ -140,6 +143,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
140143 c .telemetry .AfterExecute (ctx , finalErr )
141144 c .telemetry .CompleteStatement (ctx , statementID , finalErr != nil )
142145 }()
146+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
143147 }
144148
145149 if exStmtResp != nil && exStmtResp .OperationHandle != nil {
@@ -181,34 +185,61 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
181185 log , _ := client .LoggerAndContext (ctx , nil )
182186 msg , start := log .Track ("QueryContext" )
183187
184- // first we try to get the results synchronously.
185- // at any point in time that the context is done we must cancel and return
186- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
188+ // Capture execution start time for telemetry before running the query
189+ executeStart := time .Now ()
190+ var pollCount int
191+ exStmtResp , opStatusResp , pollCount , err := c .runQueryWithTelemetry (ctx , query , args , & pollCount )
187192 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
188193 defer log .Duration (msg , start )
189194
190- // Telemetry: track statement execution
191195 var statementID string
192196 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
193197 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
194- ctx = c .telemetry .BeforeExecute (ctx , statementID )
198+ // Use BeforeExecuteWithTime to set the correct start time (before execution)
199+ ctx = c .telemetry .BeforeExecuteWithTime (ctx , c .id , statementID , executeStart )
195200 defer func () {
196201 c .telemetry .AfterExecute (ctx , err )
197202 c .telemetry .CompleteStatement (ctx , statementID , err != nil )
198203 }()
204+
205+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
206+ c .telemetry .AddTag (ctx , "operation_type" , telemetry .OperationTypeExecuteStatement )
207+
208+ if exStmtResp .DirectResults != nil && exStmtResp .DirectResults .ResultSetMetadata != nil {
209+ resultFormat := exStmtResp .DirectResults .ResultSetMetadata .GetResultFormat ()
210+ c .telemetry .AddTag (ctx , "result.format" , resultFormat .String ())
211+ }
199212 }
200213
201214 if err != nil {
202215 log .Err (err ).Msg ("databricks: failed to run query" ) // To log query we need to redact credentials
203216 return nil , dbsqlerrint .NewExecutionError (ctx , dbsqlerr .ErrQueryExecution , err , opStatusResp )
204217 }
205218
206- rows , err := rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
219+ var telemetryUpdate func (int , int64 )
220+ if c .telemetry != nil {
221+ telemetryUpdate = func (chunkCount int , bytesDownloaded int64 ) {
222+ c .telemetry .AddTag (ctx , "chunk_count" , chunkCount )
223+ c .telemetry .AddTag (ctx , "bytes_downloaded" , bytesDownloaded )
224+ }
225+ }
226+
227+ rows , err := rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , ctx , telemetryUpdate )
228+
207229 return rows , err
208230
209231}
210232
211- func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
233+ func (c * conn ) runQueryWithTelemetry (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , int , error ) {
234+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , pollCount )
235+ count := 0
236+ if pollCount != nil {
237+ count = * pollCount
238+ }
239+ return exStmtResp , opStatusResp , count , err
240+ }
241+
242+ func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
212243 // first we try to get the results synchronously.
213244 // at any point in time that the context is done we must cancel and return
214245 exStmtResp , err := c .executeStatement (ctx , query , args )
@@ -240,7 +271,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
240271 case cli_service .TOperationState_INITIALIZED_STATE ,
241272 cli_service .TOperationState_PENDING_STATE ,
242273 cli_service .TOperationState_RUNNING_STATE :
243- statusResp , err := c .pollOperation (ctx , opHandle )
274+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
244275 if err != nil {
245276 return exStmtResp , statusResp , err
246277 }
@@ -268,7 +299,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
268299 }
269300
270301 } else {
271- statusResp , err := c .pollOperation (ctx , opHandle )
302+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
272303 if err != nil {
273304 return exStmtResp , statusResp , err
274305 }
@@ -372,7 +403,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
372403
373404 select {
374405 default :
375- // Non-blocking check: continue if context not done
376406 case <- ctx .Done ():
377407 newCtx := driverctx .NewContextFromBackground (ctx )
378408 // in case context is done, we need to cancel the operation if necessary
@@ -396,12 +426,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
396426 return resp , err
397427}
398428
399- func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
429+ func (c * conn ) pollOperationWithCount (ctx context.Context , opHandle * cli_service.TOperationHandle , pollCount * int ) (* cli_service.TGetOperationStatusResp , error ) {
400430 corrId := driverctx .CorrelationIdFromContext (ctx )
401431 log := logger .WithContext (c .id , corrId , client .SprintGuid (opHandle .OperationId .GUID ))
402432 var statusResp * cli_service.TGetOperationStatusResp
403433 ctx = driverctx .NewContextWithConnId (ctx , c .id )
404- newCtx := context2 . WithoutCancel ( ctx )
434+ newCtx := driverctx . NewContextWithCorrelationId ( driverctx . NewContextWithConnId ( context . Background (), c . id ), corrId )
405435 pollSentinel := sentinel.Sentinel {
406436 OnDoneFn : func (statusResp any ) (any , error ) {
407437 return statusResp , nil
@@ -413,6 +443,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
413443 OperationHandle : opHandle ,
414444 })
415445
446+ if pollCount != nil {
447+ * pollCount ++
448+ }
449+
416450 if statusResp != nil && statusResp .OperationState != nil {
417451 log .Debug ().Msgf ("databricks: status %s" , statusResp .GetOperationState ().String ())
418452 }
@@ -455,6 +489,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
455489 return statusResp , nil
456490}
457491
492+ func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
493+ return c .pollOperationWithCount (ctx , opHandle , nil )
494+ }
495+
458496func (c * conn ) CheckNamedValue (nv * driver.NamedValue ) error {
459497 var err error
460498 if parameter , ok := nv .Value .(Parameter ); ok {
@@ -622,7 +660,7 @@ func (c *conn) execStagingOperation(
622660 }
623661
624662 if len (driverctx .StagingPathsFromContext (ctx )) != 0 {
625- row , err = rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
663+ row , err = rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , nil , nil )
626664 if err != nil {
627665 return dbsqlerrint .NewDriverError (ctx , "error reading row." , err )
628666 }
0 commit comments