-
Notifications
You must be signed in to change notification settings - Fork 198
Expand file tree
/
Copy pathtest_connection_manager.py
More file actions
163 lines (128 loc) · 7.44 KB
/
test_connection_manager.py
File metadata and controls
163 lines (128 loc) · 7.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from multiprocessing import get_context
from unittest.mock import Mock, patch
import pytest
from dbt.adapters.databricks.connections import (
DatabricksConnectionManager,
DatabricksDBTConnection,
)
from dbt.adapters.databricks.credentials import DatabricksCredentials
from dbt.adapters.databricks.dbr_capabilities import DBRCapabilities, DBRCapability
from dbt.adapters.databricks.utils import is_cluster_http_path
class TestDatabricksConnectionManager:
    """Tests for how DatabricksConnectionManager decides whether a connection targets a cluster."""

    @staticmethod
    def _make_manager():
        """Return a connection manager built from a Mock config and a spawn context."""
        return DatabricksConnectionManager(Mock(), get_context("spawn"))

    @staticmethod
    def _make_connection(http_path, cluster_id):
        """Return a Mock connection with the given http_path and credentials.cluster_id."""
        connection = Mock(spec=DatabricksDBTConnection)
        connection.credentials = Mock(spec=DatabricksCredentials)
        connection.credentials.cluster_id = cluster_id
        connection.http_path = http_path
        return connection

    def test_is_cluster_with_warehouse_path_no_cluster_id(self):
        """is_cluster() is False for a warehouse http_path with no cluster_id."""
        connection_manager = self._make_manager()
        mock_connection = self._make_connection("sql/1.0/warehouses/abc123def456", None)
        with patch.object(
            connection_manager, "get_thread_connection", return_value=mock_connection
        ):
            assert connection_manager.is_cluster() is False

    def test_is_cluster_with_cluster_id_overrides_path(self):
        """is_cluster() stays False for a warehouse http_path even when cluster_id is set.

        Despite the test name, a cluster_id does NOT override a warehouse path: the
        warehouse path wins, matching
        test_is_cluster_http_path_function_cluster_id_overrides.
        """
        connection_manager = self._make_manager()
        mock_connection = self._make_connection("sql/1.0/warehouses/abc123def456", "cluster-123")
        with patch.object(
            connection_manager, "get_thread_connection", return_value=mock_connection
        ):
            assert connection_manager.is_cluster() is False

    def test_is_cluster_http_path_function_warehouse_path(self):
        """A warehouse path without a cluster_id is not a cluster."""
        assert is_cluster_http_path("sql/1.0/warehouses/abc123def456", None) is False

    def test_is_cluster_http_path_function_cluster_path(self):
        """A protocolv1 path without a cluster_id is a cluster."""
        assert is_cluster_http_path("sql/protocolv1/o/1234567890123456/", None) is True

    def test_is_cluster_http_path_function_cluster_id_overrides(self):
        """A warehouse path is not a cluster even when a cluster_id is supplied."""
        assert is_cluster_http_path("sql/1.0/warehouses/abc123def456", "cluster-123") is False

    @pytest.mark.parametrize(
        "http_path, cluster_id, expected_is_cluster",
        [
            # Warehouse paths: never a cluster, with or without a cluster_id
            # (same expectations as the is_cluster_http_path tests above).
            ("sql/1.0/warehouses/abc123def456", None, False),
            ("sql/1.0/warehouses/abc123def456", "cluster-123", False),
            # protocolv1 path: a cluster (the case this test originally covered).
            ("sql/protocolv1/o/abc123def456", None, True),
        ],
    )
    @patch("dbt.adapters.databricks.connections.DatabricksHandle.from_connection_args")
    @patch("dbt.adapters.databricks.connections.SqlUtils.prepare_connection_arguments")
    def test_open_calls_is_cluster_http_path_for_warehouse(
        self,
        mock_prepare_args,
        mock_from_connection_args,
        http_path,
        cluster_id,
        expected_is_cluster,
    ):
        """open() passes the is_cluster flag for the connection's http_path to the handle.

        The flag is the second positional argument of
        DatabricksHandle.from_connection_args. Despite the historical name, this test
        covers warehouse paths and a protocolv1 cluster path.
        """
        connection_manager = self._make_manager()

        mock_connection = self._make_connection(http_path, cluster_id)
        # Credential/connection fields populated so open() can run against Mock objects.
        mock_connection.credentials.connect_retries = 1
        mock_connection.credentials.connect_timeout = 10
        mock_connection.credentials.query_tags = None
        mock_connection.credentials.authenticate.return_value = Mock()
        mock_connection._query_header_context = None

        # Replace the real handle creation with a stub handle.
        mock_handle = Mock()
        mock_handle.session_id = "test_session"
        mock_from_connection_args.return_value = mock_handle
        mock_prepare_args.return_value = {}

        connection_manager.open(mock_connection)

        mock_from_connection_args.assert_called_once()
        args = mock_from_connection_args.call_args.args
        # args[1] is the is_cluster flag derived from http_path / cluster_id.
        assert args[1] is expected_is_cluster
class TestCacheDbr:
    """Tests for DatabricksConnectionManager._cache_dbr_capabilities."""

    HTTP_PATH = "sql/protocolv1/o/1234567890123456/cluster-abc"

    @pytest.fixture(autouse=True)
    def clear_cache(self):
        """Give every test an empty class-level capabilities cache, and empty it afterwards."""
        DatabricksConnectionManager._dbr_capabilities_cache = {}
        yield
        DatabricksConnectionManager._dbr_capabilities_cache = {}

    @staticmethod
    def _no_cluster_creds():
        """Build mock credentials whose cluster_id is None."""
        credentials = Mock(spec=DatabricksCredentials)
        credentials.cluster_id = None
        return credentials

    @patch.object(DatabricksConnectionManager, "_query_dbr_version", return_value=None)
    def test_does_not_write_to_cache_when_version_is_none(self, mock_query):
        """Regression for #1398: a None version must leave the cache without an entry."""
        credentials = self._no_cluster_creds()
        path = self.HTTP_PATH

        DatabricksConnectionManager._cache_dbr_capabilities(credentials, path)

        mock_query.assert_called_once_with(credentials, path)
        assert path not in DatabricksConnectionManager._dbr_capabilities_cache

    @patch.object(DatabricksConnectionManager, "_query_dbr_version", return_value=(15, 4))
    def test_writes_to_cache_when_version_is_known(self, mock_query):
        """A successful version query stores capabilities for the path."""
        credentials = self._no_cluster_creds()
        path = self.HTTP_PATH

        DatabricksConnectionManager._cache_dbr_capabilities(credentials, path)
        mock_query.assert_called_once_with(credentials, path)

        cached = DatabricksConnectionManager._dbr_capabilities_cache.get(path)
        assert cached is not None
        assert cached.dbr_version == (15, 4)
        assert not cached.is_sql_warehouse
        assert cached.has_capability(DBRCapability.ICEBERG)

    @patch.object(DatabricksConnectionManager, "_query_dbr_version", return_value=(15, 4))
    def test_skips_write_when_already_cached(self, mock_query):
        """An existing cache entry short-circuits the version query and is kept as-is."""
        credentials = self._no_cluster_creds()
        path = self.HTTP_PATH
        preexisting = DBRCapabilities(dbr_version=(14, 3), is_sql_warehouse=False)
        DatabricksConnectionManager._dbr_capabilities_cache[path] = preexisting

        DatabricksConnectionManager._cache_dbr_capabilities(credentials, path)

        mock_query.assert_not_called()
        assert DatabricksConnectionManager._dbr_capabilities_cache[path] is preexisting

    def test_retry_succeeds_after_transient_failure(self):
        """Regression for #1398: a None result must not block a later successful query."""
        credentials = self._no_cluster_creds()
        path = self.HTTP_PATH
        target = DatabricksConnectionManager

        with patch.object(target, "_query_dbr_version", return_value=None):
            target._cache_dbr_capabilities(credentials, path)
        with patch.object(target, "_query_dbr_version", return_value=(16, 2)):
            target._cache_dbr_capabilities(credentials, path)

        cached = target._dbr_capabilities_cache.get(path)
        assert cached is not None
        assert cached.dbr_version == (16, 2)
        assert cached.has_capability(DBRCapability.ICEBERG)