Skip to content

Commit ad8c237

Browse files
committed
Support Databricks query tags from session properties
1 parent 3e6bb81 commit ad8c237

2 files changed

Lines changed: 177 additions & 0 deletions

File tree

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,43 @@
3030
logger = logging.getLogger(__name__)
3131

3232

33+
def _query_tags(query_tags: t.Optional[exp.Expression]) -> t.Optional[t.Dict[str, t.Optional[str]]]:
    """Convert the `query_tags` session property into a plain dictionary.

    Args:
        query_tags: The value of `session_properties.query_tags`. Expected to be a map
            expression whose `keys` and `values` arguments are both array expressions.

    Returns:
        A dictionary of tag name to tag value, where a NULL value is represented as None.
        Returns None when `query_tags` is falsy (i.e. not provided).

    Raises:
        SQLMeshError: If the value is not a map with array keys and array values, if the
            number of keys and values differ, if a key is not a string literal, or if a
            value is neither a string literal nor NULL.
    """
    if not query_tags:
        return None

    if not isinstance(query_tags, exp.Map):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Must be a map."
        )

    keys = query_tags.args.get("keys")
    values = query_tags.args.get("values")
    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Must be a map with array "
            "keys and array values."
        )

    # zip() stops at the shorter input, which would silently drop the unpaired keys or
    # values. Reject the mismatch explicitly instead of producing a partial tag set.
    if len(keys.expressions) != len(values.expressions):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Must be a map with the same "
            "number of keys and values."
        )

    tags: t.Dict[str, t.Optional[str]] = {}
    for key, value in zip(keys.expressions, values.expressions):
        if not isinstance(key, exp.Literal) or not key.is_string:
            raise SQLMeshError(
                "Invalid key in `session_properties.query_tags`. Keys must be string literals."
            )

        if isinstance(value, exp.Null):
            tags[key.this] = None
        elif isinstance(value, exp.Literal) and value.is_string:
            tags[key.this] = value.this
        else:
            raise SQLMeshError(
                "Invalid value in `session_properties.query_tags`. Values must be string "
                "literals or NULL."
            )

    return tags
68+
69+
3370
class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
3471
DIALECT = "databricks"
3572
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
@@ -98,6 +135,13 @@ def _use_spark_session(self) -> bool:
98135
def is_spark_session_connection(self) -> bool:
99136
return isinstance(self.connection, SparkSessionConnection)
100137

138+
@property
def _is_databricks_sql_connector_connection(self) -> bool:
    """Whether statements are currently executed without any Spark path.

    True only when the connection is not a Spark session connection and the
    `use_spark_engine_adapter` connection pool attribute is not set to a truthy value.
    """
    if self.is_spark_session_connection:
        return False

    spark_adapter_in_use = self._connection_pool.get_attribute("use_spark_engine_adapter")
    return not spark_adapter_in_use
144+
101145
def _set_spark_engine_adapter_if_needed(self) -> None:
102146
self._spark_engine_adapter = None
103147

@@ -181,10 +225,25 @@ def _begin_session(self, properties: SessionProperties) -> t.Any:
181225
"""Begin a new session."""
182226
# Align the different possible connectors to a single catalog
183227
self.set_current_catalog(self.default_catalog) # type: ignore
228+
self._connection_pool.set_attribute(
229+
"query_tags", _query_tags(properties.get("query_tags"))
230+
)
184231

185232
def _end_session(self) -> None:
    """Reset the session-scoped attributes stored on the connection pool."""
    pool = self._connection_pool
    # Clear the session's query tags first, then the Spark engine adapter flag.
    pool.set_attribute("query_tags", None)
    pool.set_attribute("use_spark_engine_adapter", False)
187235

236+
def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None:
    """Execute `sql`, forwarding the session's query tags when applicable.

    The tags stored under the `query_tags` connection pool attribute are added to `kwargs`
    only when they are non-empty, the caller did not pass `query_tags` explicitly, and the
    statement goes through the Databricks SQL connector rather than a Spark path.
    """
    session_tags = self._connection_pool.get_attribute("query_tags")

    # Explicit `query_tags` from the caller always win over the session-level ones.
    if session_tags and "query_tags" not in kwargs:
        if self._is_databricks_sql_connector_connection:
            kwargs["query_tags"] = session_tags

    return super()._execute(sql, track_rows_processed, **kwargs)
246+
188247
def _df_to_source_queries(
189248
self,
190249
df: DF,

tests/core/engine_adapter/test_databricks.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,23 @@
1010
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
1111
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
1212
from sqlmesh.core.node import IntervalUnit
13+
from sqlmesh.utils.errors import SQLMeshError
1314
from tests.core.engine_adapter import to_sql_calls
1415

1516
pytestmark = [pytest.mark.databricks, pytest.mark.engine]
1617

1718

19+
def _query_tags_map(*items: t.Optional[str]) -> exp.Map:
    """Build a map expression from alternating key / value arguments.

    Arguments at even positions become string literal keys. Arguments at odd positions become
    the values, where None is turned into a NULL expression and anything else into a string
    literal.
    """
    key_nodes = [exp.Literal.string(name) for name in items[::2]]

    value_nodes = []
    for raw_value in items[1::2]:
        if raw_value is None:
            value_nodes.append(exp.Null())
        else:
            value_nodes.append(exp.Literal.string(raw_value))

    return exp.Map(
        keys=exp.Array(expressions=key_nodes),
        values=exp.Array(expressions=value_nodes),
    )
28+
29+
1830
def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
1931
mocker.patch(
2032
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists",
@@ -117,6 +129,112 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.
117129
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"]
118130

119131

132+
def test_session_query_tags(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
    """Session query tags are passed to the cursor during the session and not afterwards."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    session_properties = {"query_tags": _query_tags_map("team", "data-eng", "app", "sqlmesh")}

    with engine_adapter.session(session_properties):
        engine_adapter.execute("SELECT 1")

    expected_tags = {"team": "data-eng", "app": "sqlmesh"}
    engine_adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=expected_tags)

    # Once the session has ended the tags must no longer be forwarded.
    engine_adapter.execute("SELECT 2")

    engine_adapter.cursor.execute.assert_called_with("SELECT 2")
148+
149+
150+
def test_session_query_tags_allow_none_values(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """A NULL tag value in the map is forwarded to the cursor as None."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    session_properties = {"query_tags": _query_tags_map("team", "data-eng", "feature", None)}

    with engine_adapter.session(session_properties):
        engine_adapter.execute("SELECT 1")

    expected_tags = {"team": "data-eng", "feature": None}
    engine_adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=expected_tags)
162+
163+
164+
def test_session_query_tags_do_not_override_explicit_query_tags(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Query tags passed directly to `execute` take precedence over the session's tags."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}

    with engine_adapter.session(session_properties):
        engine_adapter.execute("SELECT 1", query_tags={"team": "analytics"})

    engine_adapter.cursor.execute.assert_called_with(
        "SELECT 1", query_tags={"team": "analytics"}
    )
176+
177+
178+
def test_session_query_tags_not_applied_to_spark_session_connection(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Session query tags are not forwarded when the connection is a Spark session."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    # Make the adapter report a Spark session connection.
    mocker.patch.object(
        DatabricksEngineAdapter,
        "is_spark_session_connection",
        new_callable=mocker.PropertyMock,
        return_value=True,
    )
    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}

    with engine_adapter.session(session_properties):
        engine_adapter.execute("SELECT 1")

    engine_adapter.cursor.execute.assert_called_with("SELECT 1")
196+
197+
198+
def test_session_query_tags_not_applied_to_spark_engine_adapter(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Session query tags are not forwarded when the Spark engine adapter is in use."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    spark_cursor = mocker.Mock()
    engine_adapter._spark_engine_adapter = mocker.Mock(cursor=spark_cursor)
    pool = engine_adapter._connection_pool
    pool.set_attribute("use_spark_engine_adapter", True)
    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}

    with engine_adapter.session(session_properties):
        # The flag is set again inside the session, right before executing.
        engine_adapter._connection_pool.set_attribute("use_spark_engine_adapter", True)
        engine_adapter.execute("SELECT 1")

    spark_cursor.execute.assert_called_with("SELECT 1")
214+
215+
216+
@pytest.mark.parametrize(
    "query_tags",
    [
        # Not a map expression at all.
        "team:data-eng",
        # Key is a numeric literal rather than a string literal.
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.number(1)]),
            values=exp.Array(expressions=[exp.Literal.string("data-eng")]),
        ),
        # Value is a numeric literal rather than a string literal or NULL.
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.string("team")]),
            values=exp.Array(expressions=[exp.Literal.number(1)]),
        ),
    ],
)
def test_session_query_tags_invalid(query_tags, make_mocked_engine_adapter: t.Callable):
    """Starting a session with malformed query tags raises a SQLMeshError."""
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    session_properties = {"query_tags": query_tags}

    with pytest.raises(SQLMeshError, match="session_properties.query_tags"):
        session = engine_adapter.session(session_properties)
        with session:
            pass
236+
237+
120238
def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
121239
mocker.patch(
122240
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"

0 commit comments

Comments
 (0)