Skip to content

Commit 94b9da9

Browse files
authored
Wait for library installation on python submission (#1031)
1 parent dab5794 commit 94b9da9

5 files changed

Lines changed: 171 additions & 16 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
## dbt-databricks 1.10.3 (TBD)
22

3+
### Fixes
4+
5+
- Fix bug where python model run starts before all libraries are installed on the cluster ([1028](https://github.com/databricks/dbt-databricks/issues/1028))
6+
37
## dbt-databricks 1.10.2 (May 21, 2025)
48

59
### Features

dbt/adapters/databricks/api_client.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DEFAULT_POLLING_INTERVAL = 10
2020
SUBMISSION_LANGUAGE = "python"
2121
USER_AGENT = f"dbt-databricks/{version}"
22+
LIBRARY_VALID_STATUSES = {"INSTALLED", "RESTORED", "SKIPPED"}
2223

2324

2425
class PrefixSession:
@@ -47,10 +48,45 @@ def __init__(self, session: Session, host: str, api: str):
4748
self.session = PrefixSession(session, host, api)
4849

4950

51+
class LibraryApi(DatabricksApi):
52+
def __init__(self, session: Session, host: str):
53+
super().__init__(session, host, "/api/2.0/libraries")
54+
55+
def all_libraries_installed(self, cluster_id: str) -> bool:
56+
status = self.get_cluster_libraries_status(cluster_id)
57+
if "library_statuses" in status:
58+
return all(
59+
library["status"] in LIBRARY_VALID_STATUSES
60+
for library in status["library_statuses"]
61+
)
62+
else:
63+
return True
64+
65+
def get_cluster_libraries_status(self, cluster_id: str) -> dict[str, Any]:
66+
response = self.session.get(
67+
"/cluster-status",
68+
json={"cluster_id": cluster_id},
69+
)
70+
if response.status_code != 200:
71+
raise DbtRuntimeError(
72+
f"Error getting status of libraries of a cluster.\n {response.content!r}"
73+
)
74+
75+
json_response = response.json()
76+
return json_response
77+
78+
5079
class ClusterApi(DatabricksApi):
51-
def __init__(self, session: Session, host: str, max_cluster_start_time: int = 900):
80+
def __init__(
81+
self,
82+
session: Session,
83+
host: str,
84+
libraries: LibraryApi,
85+
max_cluster_start_time: int = 900,
86+
):
5287
super().__init__(session, host, "/api/2.0/clusters")
5388
self.max_cluster_start_time = max_cluster_start_time
89+
self.libraries = libraries
5490

5591
def status(self, cluster_id: str) -> str:
5692
# https://docs.databricks.com/dev-tools/api/latest/clusters.html#get
@@ -68,10 +104,19 @@ def wait_for_cluster(self, cluster_id: str) -> None:
68104

69105
while time.time() - start_time < self.max_cluster_start_time:
70106
status_response = self.status(cluster_id)
71-
if status_response == "RUNNING":
72-
return
73-
else:
107+
108+
if status_response != "RUNNING":
109+
logger.debug("Waiting for cluster to start")
110+
time.sleep(5)
111+
continue
112+
113+
libraries_status = self.libraries.get_cluster_libraries_status(cluster_id)
114+
if not self.libraries.all_libraries_installed(cluster_id):
115+
logger.debug(f"Waiting for all libraries to be installed: {libraries_status}")
74116
time.sleep(5)
117+
continue
118+
119+
return
75120

76121
raise DbtRuntimeError(
77122
f"Cluster {cluster_id} restart timed out after {self.max_cluster_start_time} seconds"
@@ -96,9 +141,12 @@ def start(self, cluster_id: str) -> None:
96141

97142

98143
class CommandContextApi(DatabricksApi):
99-
def __init__(self, session: Session, host: str, cluster_api: ClusterApi):
144+
def __init__(
145+
self, session: Session, host: str, cluster_api: ClusterApi, library_api: LibraryApi
146+
):
100147
super().__init__(session, host, "/api/1.2/contexts")
101148
self.cluster_api = cluster_api
149+
self.library_api = library_api
102150

103151
def create(self, cluster_id: str) -> str:
104152
current_status = self.cluster_api.status(cluster_id)
@@ -107,7 +155,9 @@ def create(self, cluster_id: str) -> str:
107155
logger.debug(f"Cluster {cluster_id} is not running. Attempting to restart.")
108156
self.cluster_api.start(cluster_id)
109157
logger.debug(f"Cluster {cluster_id} is now running.")
110-
elif current_status != "RUNNING":
158+
elif current_status != "RUNNING" or not self.library_api.all_libraries_installed(
159+
cluster_id
160+
):
111161
self.cluster_api.wait_for_cluster(cluster_id)
112162

113163
response = self.session.post(
@@ -532,8 +582,9 @@ def __init__(
532582
timeout: int,
533583
use_user_folder: bool,
534584
):
535-
self.clusters = ClusterApi(session, host)
536-
self.command_contexts = CommandContextApi(session, host, self.clusters)
585+
self.libraries = LibraryApi(session, host)
586+
self.clusters = ClusterApi(session, host, self.libraries)
587+
self.command_contexts = CommandContextApi(session, host, self.clusters, self.libraries)
537588
self.curr_user = CurrUserApi(session, host)
538589
if use_user_folder:
539590
self.folders: FolderApi = UserFolderApi(session, host, self.curr_user)

tests/unit/api_client/test_cluster_api.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import patch
1+
from unittest.mock import Mock, patch
22

33
import freezegun
44
import pytest
@@ -10,8 +10,12 @@
1010

1111
class TestClusterApi(ApiTestBase):
1212
@pytest.fixture
13-
def api(self, session, host):
14-
return ClusterApi(session, host)
13+
def library_api(self):
14+
return Mock()
15+
16+
@pytest.fixture
17+
def api(self, session, host, library_api):
18+
return ClusterApi(session, host, library_api)
1519

1620
def test_status__non_200(self, api, session):
1721
self.assert_non_200_raises_error(lambda: api.status("cluster_id"), session)
@@ -26,10 +30,33 @@ def test_status__200(self, api, session, host):
2630
)
2731

2832
@patch("dbt.adapters.databricks.api_client.time.sleep")
29-
def test_wait_for_cluster__success(self, _, api, session):
33+
def test_wait_for_cluster__success(self, _, api, session, library_api):
3034
session.get.return_value.status_code = 200
3135
session.get.return_value.json.side_effect = [{"state": "pending"}, {"state": "running"}]
36+
library_api.get_cluster_libraries_status.return_value = {"library_statuses": []}
37+
api.wait_for_cluster("cluster_id")
38+
39+
@patch("dbt.adapters.databricks.api_client.time.sleep")
40+
def test_wait_for_cluster_with_installed_library__success(
41+
self, mock_sleep, api, session, library_api
42+
):
43+
session.get.return_value.status_code = 200
44+
session.get.return_value.json.return_value = {"state": "running"}
45+
library_api.get_cluster_libraries_status.return_value = {
46+
"library_statuses": [{"status": "INSTALLED"}]
47+
}
48+
api.wait_for_cluster("cluster_id")
49+
mock_sleep.assert_not_called()
50+
51+
@patch("dbt.adapters.databricks.api_client.time.sleep")
52+
def test_wait_for_cluster_with_pending_library__success(
53+
self, mock_sleep, api, session, library_api
54+
):
55+
session.get.return_value.status_code = 200
56+
session.get.return_value.json.side_effect = [{"state": "running"}, {"state": "running"}]
57+
library_api.all_libraries_installed.side_effect = [False, True]
3258
api.wait_for_cluster("cluster_id")
59+
mock_sleep.assert_called_with(5)
3360

3461
@freezegun.freeze_time("2020-01-01", auto_tick_seconds=900)
3562
@patch("dbt.adapters.databricks.api_client.time.sleep")
@@ -42,10 +69,11 @@ def test_wait_for_cluster__timeout(self, _, api, session):
4269
def test_start__non_200(self, api, session):
4370
self.assert_non_200_raises_error(lambda: api.start("cluster_id"), session)
4471

45-
def test_start__200(self, api, session, host):
72+
def test_start__200(self, api, session, host, library_api):
4673
session.post.return_value.status_code = 200
4774
session.get.return_value.status_code = 200
4875
session.get.return_value.json.return_value = {"state": "running"}
76+
library_api.get_cluster_libraries_status.return_value = {"library_statuses": []}
4977
api.start("cluster_id")
5078
session.post.assert_called_once_with(
5179
f"https://{host}/api/2.0/clusters/start", json={"cluster_id": "cluster_id"}, params=None

tests/unit/api_client/test_command_context_api.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,20 @@ def cluster_api(self):
1212
return Mock()
1313

1414
@pytest.fixture
15-
def api(self, session, host, cluster_api):
16-
return CommandContextApi(session, host, cluster_api)
15+
def library_api(self):
16+
return Mock()
17+
18+
@pytest.fixture
19+
def api(self, session, host, cluster_api, library_api):
20+
return CommandContextApi(session, host, cluster_api, library_api)
1721

1822
def test_create__non_200(self, api, cluster_api, session):
1923
cluster_api.status.return_value = "RUNNING"
2024
self.assert_non_200_raises_error(lambda: api.create("cluster_id"), session)
2125

22-
def test_create__cluster_running(self, api, cluster_api, session):
26+
def test_create__cluster_running(self, api, cluster_api, library_api, session):
2327
cluster_api.status.return_value = "RUNNING"
28+
library_api.all_libraries_installed.return_value = True
2429
session.post.return_value.status_code = 200
2530
session.post.return_value.json.return_value = {"id": "context_id"}
2631
id = api.create("cluster_id")
@@ -29,8 +34,25 @@ def test_create__cluster_running(self, api, cluster_api, session):
2934
json={"clusterId": "cluster_id", "language": "python"},
3035
params=None,
3136
)
37+
cluster_api.wait_for_cluster.assert_not_called()
3238
assert id == "context_id"
3339

40+
def test_create__cluster_running_with_pending_libraries(
41+
self, api, cluster_api, library_api, session
42+
):
43+
cluster_api.status.return_value = "RUNNING"
44+
library_api.all_libraries_installed.return_value = False
45+
session.post.return_value.status_code = 200
46+
session.post.return_value.json.return_value = {"id": "context_id"}
47+
id = api.create("cluster_id")
48+
session.post.assert_called_once_with(
49+
"https://host/api/1.2/contexts/create",
50+
json={"clusterId": "cluster_id", "language": "python"},
51+
params=None,
52+
)
53+
assert id == "context_id"
54+
cluster_api.wait_for_cluster.assert_called_once_with("cluster_id")
55+
3456
def test_create__cluster_terminated(self, api, cluster_api, session):
3557
cluster_api.status.return_value = "TERMINATED"
3658
session.post.return_value.status_code = 200
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
3+
from dbt.adapters.databricks.api_client import LibraryApi
4+
from tests.unit.api_client.api_test_base import ApiTestBase
5+
6+
7+
class TestLibraryApi(ApiTestBase):
8+
@pytest.fixture
9+
def api(self, session, host):
10+
return LibraryApi(session, host)
11+
12+
def test_get_cluster_libraries_status__non_200(self, api, session):
13+
self.assert_non_200_raises_error(
14+
lambda: api.get_cluster_libraries_status("cluster_id"), session
15+
)
16+
17+
def test_get_cluster_libraries_status__200(self, api, session, host):
18+
cluster_id = "cluster_id"
19+
expected_response = {"library_statuses": [{"status": "INSTALLED"}, {"status": "PENDING"}]}
20+
session.get.return_value.status_code = 200
21+
session.get.return_value.json.return_value = expected_response
22+
23+
result = api.get_cluster_libraries_status(cluster_id)
24+
assert result == expected_response
25+
session.get.assert_called_once_with(
26+
f"https://{host}/api/2.0/libraries/cluster-status",
27+
json={"cluster_id": cluster_id},
28+
params=None,
29+
)
30+
31+
def test_all_libraries_installed__true(self, api, session, host):
32+
session.get.return_value.status_code = 200
33+
session.get.return_value.json.return_value = {"library_statuses": [{"status": "INSTALLED"}]}
34+
35+
result = api.all_libraries_installed("cluster_id")
36+
assert result is True
37+
38+
def test_all_libraries_installed__false(self, api, session, host):
39+
session.get.return_value.status_code = 200
40+
session.get.return_value.json.return_value = {"library_statuses": [{"status": "PENDING"}]}
41+
42+
result = api.all_libraries_installed("cluster_id")
43+
assert result is False
44+
45+
def test_library_statuses_not_present(self, api, session, host):
46+
session.get.return_value.status_code = 200
47+
session.get.return_value.json.return_value = {"cluster-id": "abc-123"}
48+
49+
result = api.all_libraries_installed("abc-123")
50+
assert result is True

0 commit comments

Comments
 (0)