Skip to content

Commit 7c70c0f

Browse files
authored
fix: Set correct cluster type on databricks connection (#1158)
<!-- Please review our pull request review process in CONTRIBUTING.md before you proceed. --> Resolves # <!--- Include the number of the issue addressed by this PR above if applicable. Example: resolves #1234 Please review our pull request review process in CONTRIBUTING.md before you proceed. --> ### Description We're currently setting the wrong connection type in `DatabricksHandle.from_connection_args(..)` when establishing connections with a compute override at the model level. ### Checklist - [ ] I have run this code in development and it appears to resolve the stated issue - [ ] This PR includes tests, or tests are not required/relevant for this PR - [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-databricks next" section.
1 parent 0e233c4 commit 7c70c0f

4 files changed

Lines changed: 118 additions & 10 deletions

File tree

dbt/adapters/databricks/connections.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
3737
from dbt.adapters.databricks.logging import logger
3838
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
39-
from dbt.adapters.databricks.utils import redact_credentials
39+
from dbt.adapters.databricks.utils import is_cluster_http_path, redact_credentials
4040
from dbt.adapters.events.types import (
4141
ConnectionClosedInCleanup,
4242
ConnectionReused,
@@ -130,12 +130,8 @@ def api_client(self) -> DatabricksApiClient:
130130

131131
def is_cluster(self) -> bool:
132132
conn = self.get_thread_connection()
133-
return (
134-
conn.credentials.cluster_id is not None
135-
# Credentials field is not updated when overriding the compute at model level.
136-
# This secondary check is a workaround for that case
137-
or "/warehouses/" not in cast(DatabricksDBTConnection, conn).http_path
138-
)
133+
databricks_conn = cast(DatabricksDBTConnection, conn)
134+
return is_cluster_http_path(databricks_conn.http_path, conn.credentials.cluster_id)
139135

140136
def cancel_open(self) -> list[str]:
141137
cancelled = super().cancel_open()
@@ -402,7 +398,8 @@ def connect() -> DatabricksHandle:
402398
try:
403399
# TODO: what is the error when a user specifies a catalog they don't have access to
404400
conn = DatabricksHandle.from_connection_args(
405-
conn_args, creds.cluster_id is not None
401+
conn_args,
402+
is_cluster_http_path(databricks_connection.http_path, creds.cluster_id),
406403
)
407404
if conn:
408405
databricks_connection.session_id = conn.session_id

dbt/adapters/databricks/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from collections.abc import Callable
3-
from typing import TYPE_CHECKING, Any, TypeVar
3+
from typing import TYPE_CHECKING, Any, Optional, TypeVar
44

55
from dbt_common.exceptions import DbtRuntimeError
66
from jinja2 import Undefined
@@ -88,3 +88,12 @@ def handle_exceptions_as_warning(op: Callable[[], None], log_gen: ExceptionToStr
8888
op()
8989
except Exception as e:
9090
logger.warning(log_gen(e))
91+
92+
93+
def is_cluster_http_path(http_path: str, cluster_id: Optional[str]) -> bool:
94+
return (
95+
cluster_id is not None
96+
# Credentials field is not updated when overriding the compute at model level.
97+
# This secondary check is a workaround for that case
98+
or "/warehouses/" not in http_path
99+
)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from multiprocessing import get_context
2+
from unittest.mock import Mock, patch
3+
4+
from dbt.adapters.databricks.connections import DatabricksConnectionManager, DatabricksDBTConnection
5+
from dbt.adapters.databricks.credentials import DatabricksCredentials
6+
from dbt.adapters.databricks.utils import is_cluster_http_path
7+
8+
9+
class TestDatabricksConnectionManager:
10+
def test_is_cluster_with_warehouse_path_no_cluster_id(self):
11+
"""Test is_cluster() returns False for warehouse path with no cluster_id"""
12+
# Create a minimal connection manager with mock config
13+
mock_config = Mock()
14+
connection_manager = DatabricksConnectionManager(mock_config, get_context("spawn"))
15+
16+
# Mock the connection
17+
mock_connection = Mock(spec=DatabricksDBTConnection)
18+
mock_connection.credentials = Mock(spec=DatabricksCredentials)
19+
mock_connection.credentials.cluster_id = None
20+
mock_connection.http_path = "sql/1.0/warehouses/abc123def456"
21+
22+
with patch.object(
23+
connection_manager, "get_thread_connection", return_value=mock_connection
24+
):
25+
assert connection_manager.is_cluster() is False
26+
27+
def test_is_cluster_with_cluster_id_overrides_path(self):
28+
"""Test is_cluster() returns True when cluster_id is provided, regardless of path"""
29+
# Create a minimal connection manager with mock config
30+
mock_config = Mock()
31+
connection_manager = DatabricksConnectionManager(mock_config, get_context("spawn"))
32+
33+
# Mock the connection with cluster_id set (overriding warehouse path)
34+
mock_connection = Mock(spec=DatabricksDBTConnection)
35+
mock_connection.credentials = Mock(spec=DatabricksCredentials)
36+
mock_connection.credentials.cluster_id = "cluster-123"
37+
mock_connection.http_path = "sql/1.0/warehouses/abc123def456"
38+
39+
with patch.object(
40+
connection_manager, "get_thread_connection", return_value=mock_connection
41+
):
42+
assert connection_manager.is_cluster() is True
43+
44+
def test_is_cluster_http_path_function_warehouse_path(self):
45+
assert is_cluster_http_path("sql/1.0/warehouses/abc123def456", None) is False
46+
47+
def test_is_cluster_http_path_function_cluster_path(self):
48+
assert is_cluster_http_path("sql/protocolv1/o/1234567890123456/", None) is True
49+
50+
def test_is_cluster_http_path_function_cluster_id_overrides(self):
51+
assert is_cluster_http_path("sql/1.0/warehouses/abc123def456", "cluster-123") is True
52+
53+
@patch("dbt.adapters.databricks.connections.DatabricksHandle.from_connection_args")
54+
@patch("dbt.adapters.databricks.connections.SqlUtils.prepare_connection_arguments")
55+
def test_open_calls_is_cluster_http_path_for_warehouse(
56+
self, mock_prepare_args, mock_from_connection_args
57+
):
58+
"""
59+
Test that open() passes is_cluster=True to DatabricksHandle.from_connection_args when http_path is not a warehouse path and cluster_id is None
60+
"""
61+
# Create a minimal connection manager with mock config
62+
mock_config = Mock()
63+
connection_manager = DatabricksConnectionManager(mock_config, get_context("spawn"))
64+
65+
# Mock the connection with proper timeout values
66+
mock_connection = Mock(spec=DatabricksDBTConnection)
67+
mock_connection.credentials = Mock(spec=DatabricksCredentials)
68+
mock_connection.credentials.cluster_id = None
69+
mock_connection.credentials.connect_retries = 1
70+
mock_connection.credentials.connect_timeout = 10
71+
mock_connection.http_path = "sql/protocolv1/o/abc123def456"
72+
mock_connection.credentials.authenticate.return_value = Mock()
73+
74+
# Mock the handle creation
75+
mock_handle = Mock()
76+
mock_handle.session_id = "test_session"
77+
mock_from_connection_args.return_value = mock_handle
78+
79+
mock_prepare_args.return_value = {}
80+
81+
# Call open method
82+
connection_manager.open(mock_connection)
83+
84+
# Verify that from_connection_args was called exactly once
85+
mock_from_connection_args.assert_called_once()
86+
args, kwargs = mock_from_connection_args.call_args
87+
# Second argument (is_cluster) should be True: cluster_id is None, but http_path is not a "/warehouses/" path
88+
assert args[1] is True

tests/unit/test_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from dbt.adapters.databricks.utils import quote, redact_credentials, remove_ansi
1+
from dbt.adapters.databricks.utils import (
2+
is_cluster_http_path,
3+
quote,
4+
redact_credentials,
5+
remove_ansi,
6+
)
27

38

49
class TestDatabricksUtils:
@@ -67,3 +72,12 @@ def test_remove_ansi(self):
6772

6873
def test_quote(self):
6974
assert quote("table") == "`table`"
75+
76+
def test_is_cluster_http_path_with_cluster_id(self):
77+
assert is_cluster_http_path("/sql/1.0/warehouses/abc", "cluster-123") is True
78+
79+
def test_is_cluster_http_path_without_cluster_id_and_warehouses(self):
80+
assert is_cluster_http_path("/sql/1.0/endpoints/abc", None) is True
81+
82+
def test_is_cluster_http_path_without_cluster_id_and_with_warehouses(self):
83+
assert is_cluster_http_path("/sql/1.0/warehouses/abc", None) is False

0 commit comments

Comments
 (0)