
Commit 2b1139f

Authored by GeorgeTheo99 (George Theodosopoulos), claude, and calreynolds
Consolidate compute tools: 19 → 4 MCP tools for execution & compute management (#278)
* Add serverless code execution via Jobs API (no cluster required)

  Adds a new `run_code_on_serverless()` function that executes Python or SQL code on Databricks serverless compute using the Jobs API `runs/submit` endpoint. No interactive cluster is required. The implementation:
  - Uploads code as a temporary notebook to the workspace
  - Submits a one-time run with serverless compute (environments + environment_key pattern)
  - Waits for completion and retrieves output via get_run_output
  - Cleans up temporary workspace files after execution
  - Returns a typed ServerlessRunResult with output, error, run_id, run_url, and timing

  New files and changes:
  - databricks-tools-core: compute/serverless.py (core module)
  - databricks-tools-core: compute/__init__.py (exports)
  - databricks-mcp-server: tools/compute.py (MCP tool wrapper)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>

* Fix serverless code runner: capture error tracebacks, add skill file, improve docs

  - Retrieve the actual Python traceback on failure instead of a generic "Workload failed" message by fetching run output in the exception handler
  - Fix Optional[str] type annotation for run_name in the MCP wrapper
  - Document the SQL SELECT output limitation in all docstrings
  - Reframe the tool as Python-first; clarify that SQL is niche (DDL/DML only, use execute_sql for queries — works with serverless SQL warehouses)
  - Add databricks-serverless-compute skill file with decision matrix, output capture behavior, limitations, and examples

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>

* Add .ipynb (Jupyter notebook) support to serverless code runner

  Auto-detects .ipynb JSON content and uploads it via Databricks native Jupyter import (ImportFormat.JUPYTER), enabling users to run local Jupyter notebooks on serverless compute without conversion.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>

* Consolidate into databricks-execution-compute skill with multi-language and workspace_path support

  - Rename databricks-serverless-compute → databricks-execution-compute, covering all three execution tools (execute_databricks_command, run_file_on_databricks, run_code_on_serverless)
  - Rename run_python_file_on_databricks → run_file_on_databricks with language auto-detection from file extension (.py, .scala, .sql, .r); old name kept as an alias
  - Add a workspace_path param to run_file_on_databricks and run_code_on_serverless for persistent mode (saves the notebook to the workspace) vs. ephemeral mode (default, temp cleanup)
  - Add comprehensive integration tests (34 tests) covering classic cluster execution, serverless Python/SQL, ephemeral vs. persistent modes, input validation, and error handling
  - Update the MCP tool layer with new params and empty-string-to-None coercion
  - Update install_skills.sh with the new skill name and description

  Co-authored-by: Isaac

* Add databricks-manage-compute skill for cluster and warehouse lifecycle management

  Core functions (manage.py): create/modify/terminate/delete clusters with opinionated defaults (auto-pick LTS DBR, reasonable node type, SINGLE_USER mode, 120 min auto-termination), plus create/modify/delete SQL warehouses. List helpers for node types and Spark versions. MCP tool wrappers with destructive-action warnings in docstrings. SKILL.md with a decision matrix, tool reference tables, and examples. Integration tests validated against e2-demo-field-eng.

  Co-authored-by: Isaac

* Consolidate 19 compute MCP tools into 4 per PR feedback

  Addresses Cal's feedback to reduce tool count and LLM parsing bloat. Uses action/resource params to determine behavior (like SDP's create_and_update pattern). New tools:
  - execute_code (replaces 3 execution tools)
  - manage_cluster (replaces 5 cluster tools)
  - manage_sql_warehouse (replaces 3 warehouse tools)
  - list_compute (replaces 5 listing/inspection tools)

  Also removes the separate databricks-manage-compute skill (merged into databricks-execution-compute) and adds 41 unit tests for routing logic.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: George Theodosopoulos <george.theodosopoulos@databricks.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Cal Reynolds <49540501+calreynolds@users.noreply.github.com>
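For orientation, here is a minimal sketch of the serverless runs/submit flow the first commit above describes (upload a notebook, submit a one-time run on serverless compute, then fetch output). The endpoint path follows the Jobs API 2.1; the helper name and the exact payload defaults are illustrative assumptions, not the implementation in compute/serverless.py.

# Illustrative sketch only — helper name and payload defaults are assumptions,
# not the code in databricks-tools-core/compute/serverless.py.
import os
import requests

HOST = os.environ["DATABRICKS_HOST"]
HEADERS = {"Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}"}

def submit_serverless_notebook_run(notebook_path: str, run_name: str = "serverless-code-run") -> int:
    """Submit a one-time serverless run for a notebook already uploaded to the workspace."""
    payload = {
        "run_name": run_name,
        # The "environments + environment_key" pattern named in the commit message:
        # declare a serverless environment and reference it from the task.
        "environments": [{"environment_key": "default", "spec": {"client": "1"}}],
        "tasks": [
            {
                "task_key": "main",
                "environment_key": "default",
                "notebook_task": {"notebook_path": notebook_path},
            }
        ],
    }
    resp = requests.post(f"{HOST}/api/2.1/jobs/runs/submit", headers=HEADERS, json=payload)
    resp.raise_for_status()
    return resp.json()["run_id"]

# Per the commit message, the real tool then waits for the run to finish and
# retrieves the notebook output via the runs get-output endpoint before cleaning
# up the temporary workspace notebook.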
1 parent 5416e4d commit 2b1139f

11 files changed

Lines changed: 2842 additions & 218 deletions

File tree

databricks-mcp-server/databricks_mcp_server/tools/compute.py

Lines changed: 415 additions & 183 deletions
Large diffs are not rendered by default.
Lines changed: 392 additions & 0 deletions
@@ -0,0 +1,392 @@
"""
Unit tests for consolidated compute tools.

Tests the MCP tool wrapper routing logic without hitting Databricks APIs.
"""

import pytest
from unittest.mock import patch, MagicMock
from databricks_mcp_server.tools.compute import (
    execute_code,
    manage_cluster,
    manage_sql_warehouse,
    list_compute,
)


# ---------------------------------------------------------------------------
# execute_code routing tests
# ---------------------------------------------------------------------------


class TestExecuteCodeRouting:
    """Test that execute_code routes to the correct backend."""

    def test_requires_code_or_file_path(self):
        result = execute_code()
        assert result["success"] is False
        assert "code" in result["error"].lower() or "file_path" in result["error"].lower()

    def test_empty_strings_treated_as_none(self):
        result = execute_code(code="", file_path="")
        assert result["success"] is False
        assert "code" in result["error"].lower() or "file_path" in result["error"].lower()

    @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
    def test_auto_routes_to_serverless_for_python(self, mock_serverless):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True, "output": "hello"}
        mock_serverless.return_value = mock_result

        execute_code(code="print('hi')", compute_type="auto")

        mock_serverless.assert_called_once()
        call_kwargs = mock_serverless.call_args[1]
        assert call_kwargs["code"] == "print('hi')"
        assert call_kwargs["language"] == "python"

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_auto_routes_to_cluster_with_cluster_id(self, mock_cluster):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_cluster.return_value = mock_result

        execute_code(code="print('hi')", cluster_id="abc-123")

        mock_cluster.assert_called_once()
        assert mock_cluster.call_args[1]["cluster_id"] == "abc-123"

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_auto_routes_to_cluster_with_context_id(self, mock_cluster):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_cluster.return_value = mock_result

        execute_code(code="print('hi')", context_id="ctx-456")

        mock_cluster.assert_called_once()
        assert mock_cluster.call_args[1]["context_id"] == "ctx-456"

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_auto_routes_to_cluster_for_scala(self, mock_cluster):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_cluster.return_value = mock_result

        execute_code(code="println(42)", language="scala")

        mock_cluster.assert_called_once()
        assert mock_cluster.call_args[1]["language"] == "scala"

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_auto_routes_to_cluster_for_r(self, mock_cluster):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_cluster.return_value = mock_result

        execute_code(code="print(42)", language="r")

        mock_cluster.assert_called_once()

    @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
    def test_explicit_serverless(self, mock_serverless):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_serverless.return_value = mock_result

        execute_code(code="print('hi')", compute_type="serverless")

        mock_serverless.assert_called_once()

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_explicit_cluster(self, mock_cluster):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_cluster.return_value = mock_result

        execute_code(code="print('hi')", compute_type="cluster")

        mock_cluster.assert_called_once()

    @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
    def test_serverless_default_timeout(self, mock_serverless):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_serverless.return_value = mock_result

        execute_code(code="x", compute_type="serverless")

        assert mock_serverless.call_args[1]["timeout"] == 1800

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_cluster_default_timeout(self, mock_cluster):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_cluster.return_value = mock_result

        execute_code(code="x", compute_type="cluster")

        assert mock_cluster.call_args[1]["timeout"] == 120

    @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
    def test_workspace_path_passed_to_serverless(self, mock_serverless):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_serverless.return_value = mock_result

        execute_code(code="x", compute_type="serverless", workspace_path="/Workspace/Users/a/b")

        call_kwargs = mock_serverless.call_args[1]
        assert call_kwargs["workspace_path"] == "/Workspace/Users/a/b"
        assert call_kwargs["cleanup"] is False

    @patch("databricks_mcp_server.tools.compute._run_file_on_databricks")
    def test_file_path_on_cluster(self, mock_run_file):
        mock_result = MagicMock()
        mock_result.to_dict.return_value = {"success": True}
        mock_run_file.return_value = mock_result

        execute_code(file_path="/tmp/test.py", compute_type="cluster")

        mock_run_file.assert_called_once()
        assert mock_run_file.call_args[1]["file_path"] == "/tmp/test.py"

    def test_file_path_not_found_serverless(self):
        result = execute_code(file_path="/nonexistent/file.py", compute_type="serverless")
        assert result["success"] is False
        assert "not found" in result["error"].lower()

    @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
    def test_no_running_cluster_error(self, mock_cluster):
        from databricks_tools_core.compute import NoRunningClusterError
        mock_cluster.side_effect = NoRunningClusterError(
            available_clusters=[],
            skipped_clusters=[],
            startable_clusters=[{"cluster_id": "abc", "cluster_name": "test", "state": "TERMINATED"}],
        )

        result = execute_code(code="x", compute_type="cluster")

        assert result["success"] is False
        assert "startable_clusters" in result
        assert len(result["startable_clusters"]) == 1


# ---------------------------------------------------------------------------
# manage_cluster routing tests
# ---------------------------------------------------------------------------


class TestManageCluster:
    """Test manage_cluster action routing."""

    def test_invalid_action(self):
        result = manage_cluster(action="explode")
        assert result["success"] is False
        assert "unknown action" in result["error"].lower()

    def test_create_requires_name(self):
        result = manage_cluster(action="create")
        assert result["success"] is False
        assert "name" in result["error"].lower()

    def test_modify_requires_cluster_id(self):
        result = manage_cluster(action="modify")
        assert result["success"] is False
        assert "cluster_id" in result["error"].lower()

    def test_start_requires_cluster_id(self):
        result = manage_cluster(action="start")
        assert result["success"] is False
        assert "cluster_id" in result["error"].lower()

    def test_terminate_requires_cluster_id(self):
        result = manage_cluster(action="terminate")
        assert result["success"] is False
        assert "cluster_id" in result["error"].lower()

    def test_delete_requires_cluster_id(self):
        result = manage_cluster(action="delete")
        assert result["success"] is False
        assert "cluster_id" in result["error"].lower()

    @patch("databricks_mcp_server.tools.compute._create_cluster")
    def test_create_routes_correctly(self, mock_create):
        mock_create.return_value = {"cluster_id": "abc", "state": "PENDING"}

        result = manage_cluster(action="create", name="test-cluster", num_workers=2)

        mock_create.assert_called_once()
        assert mock_create.call_args[1]["name"] == "test-cluster"
        assert mock_create.call_args[1]["num_workers"] == 2

    @patch("databricks_mcp_server.tools.compute._modify_cluster")
    def test_modify_routes_correctly(self, mock_modify):
        mock_modify.return_value = {"cluster_id": "abc"}

        manage_cluster(action="modify", cluster_id="abc", num_workers=4)

        mock_modify.assert_called_once()
        assert mock_modify.call_args[1]["cluster_id"] == "abc"
        assert mock_modify.call_args[1]["num_workers"] == 4

    @patch("databricks_mcp_server.tools.compute._start_cluster")
    def test_start_routes_correctly(self, mock_start):
        mock_start.return_value = {"cluster_id": "abc", "state": "PENDING"}

        manage_cluster(action="start", cluster_id="abc")

        mock_start.assert_called_once_with("abc")

    @patch("databricks_mcp_server.tools.compute._terminate_cluster")
    def test_terminate_routes_correctly(self, mock_terminate):
        mock_terminate.return_value = {"cluster_id": "abc", "state": "TERMINATING"}

        manage_cluster(action="terminate", cluster_id="abc")

        mock_terminate.assert_called_once_with("abc")

    @patch("databricks_mcp_server.tools.compute._delete_cluster")
    def test_delete_routes_correctly(self, mock_delete):
        mock_delete.return_value = {"cluster_id": "abc", "state": "DELETED"}

        manage_cluster(action="delete", cluster_id="abc")

        mock_delete.assert_called_once_with("abc")

    @patch("databricks_mcp_server.tools.compute._create_cluster")
    def test_create_defaults(self, mock_create):
        mock_create.return_value = {"cluster_id": "abc"}

        manage_cluster(action="create", name="test")

        call_kwargs = mock_create.call_args[1]
        assert call_kwargs["num_workers"] == 1
        assert call_kwargs["autotermination_minutes"] == 120

    @patch("databricks_mcp_server.tools.compute._create_cluster")
    def test_create_with_spark_conf_json(self, mock_create):
        mock_create.return_value = {"cluster_id": "abc"}

        manage_cluster(
            action="create",
            name="test",
            spark_conf='{"spark.sql.shuffle.partitions": "8"}',
        )

        call_kwargs = mock_create.call_args[1]
        assert call_kwargs["spark_conf"] == {"spark.sql.shuffle.partitions": "8"}


# ---------------------------------------------------------------------------
# manage_sql_warehouse routing tests
# ---------------------------------------------------------------------------


class TestManageSqlWarehouse:
    """Test manage_sql_warehouse action routing."""

    def test_invalid_action(self):
        result = manage_sql_warehouse(action="explode")
        assert result["success"] is False

    def test_create_requires_name(self):
        result = manage_sql_warehouse(action="create")
        assert result["success"] is False
        assert "name" in result["error"].lower()

    def test_modify_requires_warehouse_id(self):
        result = manage_sql_warehouse(action="modify")
        assert result["success"] is False
        assert "warehouse_id" in result["error"].lower()

    def test_delete_requires_warehouse_id(self):
        result = manage_sql_warehouse(action="delete")
        assert result["success"] is False
        assert "warehouse_id" in result["error"].lower()

    @patch("databricks_mcp_server.tools.compute._create_sql_warehouse")
    def test_create_routes_correctly(self, mock_create):
        mock_create.return_value = {"warehouse_id": "abc"}

        manage_sql_warehouse(action="create", name="test-wh", size="Medium")

        mock_create.assert_called_once()
        assert mock_create.call_args[1]["name"] == "test-wh"
        assert mock_create.call_args[1]["size"] == "Medium"

    @patch("databricks_mcp_server.tools.compute._modify_sql_warehouse")
    def test_modify_routes_correctly(self, mock_modify):
        mock_modify.return_value = {"warehouse_id": "abc"}

        manage_sql_warehouse(action="modify", warehouse_id="abc", size="Large")

        mock_modify.assert_called_once()
        assert mock_modify.call_args[1]["size"] == "Large"

    @patch("databricks_mcp_server.tools.compute._delete_sql_warehouse")
    def test_delete_routes_correctly(self, mock_delete):
        mock_delete.return_value = {"warehouse_id": "abc"}

        manage_sql_warehouse(action="delete", warehouse_id="abc")

        mock_delete.assert_called_once_with("abc")


# ---------------------------------------------------------------------------
# list_compute routing tests
# ---------------------------------------------------------------------------


class TestListCompute:
    """Test list_compute resource routing."""

    @patch("databricks_mcp_server.tools.compute._list_clusters")
    def test_default_lists_clusters(self, mock_list):
        mock_list.return_value = [{"cluster_id": "abc", "cluster_name": "test"}]

        result = list_compute()

        mock_list.assert_called_once()
        assert "clusters" in result

    @patch("databricks_mcp_server.tools.compute._get_cluster_status")
    def test_cluster_id_gets_status(self, mock_status):
        mock_status.return_value = {"cluster_id": "abc", "state": "RUNNING"}

        result = list_compute(cluster_id="abc")

        mock_status.assert_called_once_with("abc")
        assert result["state"] == "RUNNING"

    @patch("databricks_mcp_server.tools.compute._get_best_cluster")
    def test_auto_select(self, mock_best):
        mock_best.return_value = "best-cluster-id"

        result = list_compute(auto_select=True)

        mock_best.assert_called_once()
        assert result["cluster_id"] == "best-cluster-id"

    @patch("databricks_mcp_server.tools.compute._list_node_types")
    def test_node_types(self, mock_nodes):
        mock_nodes.return_value = [{"node_type_id": "i3.xlarge"}]

        result = list_compute(resource="node_types")

        mock_nodes.assert_called_once()
        assert "node_types" in result

    @patch("databricks_mcp_server.tools.compute._list_spark_versions")
    def test_spark_versions(self, mock_versions):
        mock_versions.return_value = [{"key": "15.4.x-scala2.12"}]

        result = list_compute(resource="spark_versions")

        mock_versions.assert_called_once()
        assert "spark_versions" in result

    def test_invalid_resource(self):
        result = list_compute(resource="invalid")
        assert result["success"] is False
        assert "unknown resource" in result["error"].lower()
