Skip to content

Commit f850b47

Browse files
joohoyeo-devjoohoyeovikrantpuppala
authored
Add statement-level query tag support (#341)
## Summary

- Adds per-statement query tag support via `driverctx.NewContextWithQueryTags`, allowing users to attach query tags to individual SQL statements through context
- Tags are serialized into `TExecuteStatementReq.ConfOverlay["query_tags"]`, consistent with the Python ([#736](databricks/databricks-sql-python#736)) and NodeJS ([#339](databricks/databricks-sql-nodejs#339)) connector implementations
- Previously only session-level query tags were supported (set once via `WithSessionParams` at connection time)

## Usage

```go
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
    "team": "data-eng",
    "app":  "etl-pipeline",
})
rows, err := db.QueryContext(ctx, "SELECT * FROM table")
```

## Changes

| File | Description |
|------|-------------|
| `driverctx/ctx.go` | `NewContextWithQueryTags`, `QueryTagsFromContext`, propagation in `NewContextFromBackground` |
| `query_tags.go` *(new)* | `SerializeQueryTags` — map to wire format with escaping |
| `connection.go` | Read tags from context → serialize → set `ConfOverlay["query_tags"]` |
| `driverctx/ctx_test.go` | 5 tests for context helpers |
| `query_tags_test.go` *(new)* | 13 tests for serialization (escaping, edge cases) |
| `connection_test.go` | 6 integration tests verifying ConfOverlay behavior |
| `examples/query_tags/main.go` | Updated with session + statement-level examples |

## Test plan

- [x] Unit tests for `SerializeQueryTags` covering nil, empty, single/multi tags, escaping of `\`, `:`, `,` in values and keys
- [x] Unit tests for `NewContextWithQueryTags` / `QueryTagsFromContext` including nil context, missing key, timeout preservation, background propagation
- [x] Integration tests verifying `ConfOverlay["query_tags"]` is correctly set (or absent) in captured `TExecuteStatementReq`
- [ ] Verify existing tests still pass (CI)

This pull request was AI-assisted by Isaac.
---------

Signed-off-by: Jooho Yeo <jooho.yeo@databricks.com>
Co-authored-by: Jooho Yeo <jooho.yeo@databricks.com>
Co-authored-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 6fb037c commit f850b47

File tree

9 files changed

+523
-10
lines changed

9 files changed

+523
-10
lines changed

connection.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,17 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
331331
req.Parameters = parameters
332332
}
333333

334+
// Add per-statement query tags if provided via context
335+
if queryTags := driverctx.QueryTagsFromContext(ctx); len(queryTags) > 0 {
336+
serialized := SerializeQueryTags(queryTags)
337+
if serialized != "" {
338+
if req.ConfOverlay == nil {
339+
req.ConfOverlay = make(map[string]string)
340+
}
341+
req.ConfOverlay["query_tags"] = serialized
342+
}
343+
}
344+
334345
resp, err := c.client.ExecuteStatement(ctx, &req)
335346
var log *logger.DBSQLLogger
336347
log, ctx = client.LoggerAndContext(ctx, resp)

connection_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/apache/thrift/lib/go/thrift"
1111
"github.com/pkg/errors"
1212

13+
"github.com/databricks/databricks-sql-go/driverctx"
1314
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1415
"github.com/databricks/databricks-sql-go/internal/cli_service"
1516
"github.com/databricks/databricks-sql-go/internal/client"
@@ -493,6 +494,209 @@ func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
493494
}
494495
}
495496

497+
func TestConn_executeStatement_QueryTags(t *testing.T) {
498+
t.Parallel()
499+
500+
makeTestConn := func(captureReq *(*cli_service.TExecuteStatementReq)) *conn {
501+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
502+
*captureReq = req
503+
return &cli_service.TExecuteStatementResp{
504+
Status: &cli_service.TStatus{
505+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
506+
},
507+
OperationHandle: &cli_service.TOperationHandle{
508+
OperationId: &cli_service.THandleIdentifier{
509+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
510+
Secret: []byte("secret"),
511+
},
512+
},
513+
DirectResults: &cli_service.TSparkDirectResults{
514+
OperationStatus: &cli_service.TGetOperationStatusResp{
515+
Status: &cli_service.TStatus{
516+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
517+
},
518+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
519+
},
520+
},
521+
}, nil
522+
}
523+
524+
return &conn{
525+
session: getTestSession(),
526+
client: &client.TestClient{
527+
FnExecuteStatement: executeStatement,
528+
},
529+
cfg: config.WithDefaults(),
530+
}
531+
}
532+
533+
t.Run("query tags from context are set in ConfOverlay", func(t *testing.T) {
534+
var capturedReq *cli_service.TExecuteStatementReq
535+
testConn := makeTestConn(&capturedReq)
536+
537+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
538+
"team": "engineering",
539+
"app": "etl",
540+
})
541+
542+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
543+
assert.NoError(t, err)
544+
assert.NotNil(t, capturedReq.ConfOverlay)
545+
// Map iteration is non-deterministic, so check both possible orderings
546+
queryTags := capturedReq.ConfOverlay["query_tags"]
547+
assert.True(t,
548+
queryTags == "team:engineering,app:etl" || queryTags == "app:etl,team:engineering",
549+
"unexpected query_tags value: %s", queryTags)
550+
})
551+
552+
t.Run("no query tags in context means no ConfOverlay", func(t *testing.T) {
553+
var capturedReq *cli_service.TExecuteStatementReq
554+
testConn := makeTestConn(&capturedReq)
555+
556+
_, err := testConn.executeStatement(context.Background(), "SELECT 1", nil)
557+
assert.NoError(t, err)
558+
assert.Nil(t, capturedReq.ConfOverlay)
559+
})
560+
561+
t.Run("empty query tags map means no ConfOverlay", func(t *testing.T) {
562+
var capturedReq *cli_service.TExecuteStatementReq
563+
testConn := makeTestConn(&capturedReq)
564+
565+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{})
566+
567+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
568+
assert.NoError(t, err)
569+
assert.Nil(t, capturedReq.ConfOverlay)
570+
})
571+
572+
t.Run("single query tag", func(t *testing.T) {
573+
var capturedReq *cli_service.TExecuteStatementReq
574+
testConn := makeTestConn(&capturedReq)
575+
576+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
577+
"team": "data-eng",
578+
})
579+
580+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
581+
assert.NoError(t, err)
582+
assert.Equal(t, "team:data-eng", capturedReq.ConfOverlay["query_tags"])
583+
})
584+
585+
t.Run("query tags with special characters in values", func(t *testing.T) {
586+
var capturedReq *cli_service.TExecuteStatementReq
587+
testConn := makeTestConn(&capturedReq)
588+
589+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
590+
"url": "http://host:8080",
591+
})
592+
593+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
594+
assert.NoError(t, err)
595+
assert.Equal(t, `url:http\://host\:8080`, capturedReq.ConfOverlay["query_tags"])
596+
})
597+
598+
t.Run("query tags with empty value", func(t *testing.T) {
599+
var capturedReq *cli_service.TExecuteStatementReq
600+
testConn := makeTestConn(&capturedReq)
601+
602+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
603+
"flag": "",
604+
})
605+
606+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
607+
assert.NoError(t, err)
608+
assert.Equal(t, "flag", capturedReq.ConfOverlay["query_tags"])
609+
})
610+
611+
t.Run("session-level and statement-level query tags coexist", func(t *testing.T) {
612+
// Session-level tags are sent via TOpenSessionReq.Configuration at connect time.
613+
// Statement-level tags are sent via TExecuteStatementReq.ConfOverlay at query time.
614+
// They are independent fields on different requests, so both should work together.
615+
616+
var capturedOpenReq *cli_service.TOpenSessionReq
617+
var capturedExecReq *cli_service.TExecuteStatementReq
618+
619+
testClient := &client.TestClient{
620+
FnOpenSession: func(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
621+
capturedOpenReq = req
622+
return &cli_service.TOpenSessionResp{
623+
Status: &cli_service.TStatus{
624+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
625+
},
626+
SessionHandle: &cli_service.TSessionHandle{
627+
SessionId: &cli_service.THandleIdentifier{
628+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
629+
},
630+
},
631+
}, nil
632+
},
633+
FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
634+
capturedExecReq = req
635+
return &cli_service.TExecuteStatementResp{
636+
Status: &cli_service.TStatus{
637+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
638+
},
639+
OperationHandle: &cli_service.TOperationHandle{
640+
OperationId: &cli_service.THandleIdentifier{
641+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
642+
Secret: []byte("secret"),
643+
},
644+
},
645+
DirectResults: &cli_service.TSparkDirectResults{
646+
OperationStatus: &cli_service.TGetOperationStatusResp{
647+
Status: &cli_service.TStatus{
648+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
649+
},
650+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
651+
},
652+
},
653+
}, nil
654+
},
655+
}
656+
657+
// Simulate what connector.Connect() does: pass session params to OpenSession
658+
sessionParams := map[string]string{
659+
"QUERY_TAGS": "team:platform,env:prod",
660+
"ansi_mode": "false",
661+
}
662+
protocolVersion := int64(cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8)
663+
session, err := testClient.OpenSession(context.Background(), &cli_service.TOpenSessionReq{
664+
ClientProtocolI64: &protocolVersion,
665+
Configuration: sessionParams,
666+
})
667+
assert.NoError(t, err)
668+
669+
// Verify session-level tags were sent in OpenSession
670+
assert.Equal(t, "team:platform,env:prod", capturedOpenReq.Configuration["QUERY_TAGS"])
671+
assert.Equal(t, "false", capturedOpenReq.Configuration["ansi_mode"])
672+
673+
// Create conn with session that has session-level tags
674+
cfg := config.WithDefaults()
675+
cfg.SessionParams = sessionParams
676+
testConn := &conn{
677+
session: session,
678+
client: testClient,
679+
cfg: cfg,
680+
}
681+
682+
// Execute with statement-level tags
683+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
684+
"job": "nightly-etl",
685+
})
686+
_, err = testConn.executeStatement(ctx, "SELECT 1", nil)
687+
assert.NoError(t, err)
688+
689+
// Statement-level tags should be in ConfOverlay
690+
assert.Equal(t, "job:nightly-etl", capturedExecReq.ConfOverlay["query_tags"])
691+
692+
// ConfOverlay should ONLY have query_tags, not session params
693+
_, hasAnsiMode := capturedExecReq.ConfOverlay["ansi_mode"]
694+
assert.False(t, hasAnsiMode, "session params should not leak into ConfOverlay")
695+
_, hasSessionQueryTags := capturedExecReq.ConfOverlay["QUERY_TAGS"]
696+
assert.False(t, hasSessionQueryTags, "session-level QUERY_TAGS should not be in ConfOverlay")
697+
})
698+
}
699+
496700
func TestConn_pollOperation(t *testing.T) {
497701
t.Parallel()
498702
t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) {

connector.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,23 @@ func WithSessionParams(params map[string]string) ConnOption {
247247
}
248248
}
249249

250+
// WithQueryTags sets session-level query tags from a map.
251+
// Tags are serialized and passed as QUERY_TAGS in the session configuration.
252+
// All queries in the session will carry these tags unless overridden at the statement level.
253+
// This is the preferred way to set session-level query tags, as it handles serialization
254+
// and escaping automatically (consistent with the statement-level API).
255+
func WithQueryTags(tags map[string]string) ConnOption {
256+
return func(c *config.Config) {
257+
serialized := SerializeQueryTags(tags)
258+
if serialized != "" {
259+
if c.SessionParams == nil {
260+
c.SessionParams = make(map[string]string)
261+
}
262+
c.SessionParams["QUERY_TAGS"] = serialized
263+
}
264+
}
265+
}
266+
250267
// WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate.
251268
// WARNING:
252269
// When this option is used, TLS is susceptible to machine-in-the-middle attacks.

connector_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,65 @@ func TestNewConnector(t *testing.T) {
268268
})
269269
}
270270

271+
func TestWithQueryTags(t *testing.T) {
272+
t.Run("WithQueryTags serializes map into SessionParams QUERY_TAGS", func(t *testing.T) {
273+
con, err := NewConnector(
274+
WithQueryTags(map[string]string{
275+
"team": "data-eng",
276+
}),
277+
)
278+
require.NoError(t, err)
279+
coni, ok := con.(*connector)
280+
require.True(t, ok)
281+
assert.Equal(t, "team:data-eng", coni.cfg.SessionParams["QUERY_TAGS"])
282+
})
283+
284+
t.Run("WithQueryTags with multiple tags", func(t *testing.T) {
285+
con, err := NewConnector(
286+
WithQueryTags(map[string]string{
287+
"team": "eng",
288+
"app": "etl",
289+
}),
290+
)
291+
require.NoError(t, err)
292+
coni, ok := con.(*connector)
293+
require.True(t, ok)
294+
// Map iteration is non-deterministic
295+
qt := coni.cfg.SessionParams["QUERY_TAGS"]
296+
assert.True(t, qt == "team:eng,app:etl" || qt == "app:etl,team:eng", "got: %s", qt)
297+
})
298+
299+
t.Run("WithQueryTags with empty map does not set QUERY_TAGS", func(t *testing.T) {
300+
con, err := NewConnector(
301+
WithQueryTags(map[string]string{}),
302+
)
303+
require.NoError(t, err)
304+
coni, ok := con.(*connector)
305+
require.True(t, ok)
306+
_, exists := coni.cfg.SessionParams["QUERY_TAGS"]
307+
assert.False(t, exists)
308+
})
309+
310+
t.Run("WithQueryTags overrides WithSessionParams QUERY_TAGS", func(t *testing.T) {
311+
con, err := NewConnector(
312+
WithSessionParams(map[string]string{
313+
"QUERY_TAGS": "old:value",
314+
"ansi_mode": "false",
315+
}),
316+
WithQueryTags(map[string]string{
317+
"team": "new-team",
318+
}),
319+
)
320+
require.NoError(t, err)
321+
coni, ok := con.(*connector)
322+
require.True(t, ok)
323+
// WithQueryTags should override the QUERY_TAGS from WithSessionParams
324+
assert.Equal(t, "team:new-team", coni.cfg.SessionParams["QUERY_TAGS"])
325+
// Other session params should be preserved
326+
assert.Equal(t, "false", coni.cfg.SessionParams["ansi_mode"])
327+
})
328+
}
329+
271330
type mockRoundTripper struct{}
272331

273332
var _ http.RoundTripper = mockRoundTripper{}

driverctx/ctx.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const (
1515
QueryIdCallbackKey
1616
ConnIdCallbackKey
1717
StagingAllowedLocalPathKey
18+
QueryTagsContextKey
1819
)
1920

2021
type IdCallbackFunc func(string)
@@ -107,16 +108,40 @@ func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []st
107108
return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath)
108109
}
109110

111+
// NewContextWithQueryTags creates a new context with per-statement query tags.
112+
// These tags are serialized and passed via confOverlay as "query_tags" in TExecuteStatementReq.
113+
// They apply only to the statement executed with this context and do not persist across queries.
114+
func NewContextWithQueryTags(ctx context.Context, queryTags map[string]string) context.Context {
115+
return context.WithValue(ctx, QueryTagsContextKey, queryTags)
116+
}
117+
118+
// QueryTagsFromContext retrieves the per-statement query tags stored in context.
119+
func QueryTagsFromContext(ctx context.Context) map[string]string {
120+
if ctx == nil {
121+
return nil
122+
}
123+
124+
queryTags, ok := ctx.Value(QueryTagsContextKey).(map[string]string)
125+
if !ok {
126+
return nil
127+
}
128+
return queryTags
129+
}
130+
110131
func NewContextFromBackground(ctx context.Context) context.Context {
111132
connId := ConnIdFromContext(ctx)
112133
corrId := CorrelationIdFromContext(ctx)
113134
queryId := QueryIdFromContext(ctx)
114135
stagingPaths := StagingPathsFromContext(ctx)
136+
queryTags := QueryTagsFromContext(ctx)
115137

116138
newCtx := NewContextWithConnId(context.Background(), connId)
117139
newCtx = NewContextWithCorrelationId(newCtx, corrId)
118140
newCtx = NewContextWithQueryId(newCtx, queryId)
119141
newCtx = NewContextWithStagingInfo(newCtx, stagingPaths)
142+
if queryTags != nil {
143+
newCtx = NewContextWithQueryTags(newCtx, queryTags)
144+
}
120145

121146
return newCtx
122147
}

0 commit comments

Comments
 (0)