Skip to content

Commit ee45a01

Browse files
authored
Add repair_run action to manage_job_runs MCP tool (#444)
* Add repair_run action to manage_job_runs MCP tool
* Remove try/except wrapper
1 parent 2feb7e0 commit ee45a01

File tree

5 files changed

+248
-5
lines changed

5 files changed

+248
-5
lines changed

databricks-mcp-server/databricks_mcp_server/tools/jobs.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
update_job as _update_job,
1717
delete_job as _delete_job,
1818
run_job_now as _run_job_now,
19+
repair_run as _repair_run,
1920
get_run as _get_run,
2021
get_run_output as _get_run_output,
2122
cancel_run as _cancel_run,
@@ -193,6 +194,10 @@ def manage_job_runs(
193194
sql_params: Dict[str, str] = None,
194195
dbt_commands: List[str] = None,
195196
queue: Dict[str, Any] = None,
197+
rerun_all_failed_tasks: bool = None,
198+
rerun_dependent_tasks: bool = None,
199+
rerun_tasks: List[str] = None,
200+
latest_repair_id: int = None,
196201
active_only: bool = False,
197202
completed_only: bool = False,
198203
limit: int = 25,
@@ -202,11 +207,11 @@ def manage_job_runs(
202207
timeout: int = 3600,
203208
poll_interval: int = 10,
204209
) -> Dict[str, Any]:
205-
"""Manage job runs: run_now, get, get_output, cancel, list, wait.
210+
"""Manage job runs: run_now, repair, get, get_output, cancel, list, wait.
206211
207-
run_now: requires job_id, returns {run_id}. get/get_output/cancel/wait: require run_id.
208-
list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
209-
Returns: run_now={run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
212+
run_now: requires job_id, returns {run_id}. repair: requires run_id, reruns failed tasks (rerun_all_failed_tasks=True) or specific tasks (rerun_tasks=["task_key"]).
213+
get/get_output/cancel/wait: require run_id. list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
214+
Returns: run_now={run_id}, repair={repair_id, run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
210215
act = action.lower()
211216

212217
if act == "run_now":
@@ -225,6 +230,24 @@ def manage_job_runs(
225230
)
226231
return {"run_id": run_id_result}
227232

233+
elif act == "repair":
234+
repair_id_result = _repair_run(
235+
run_id=run_id,
236+
rerun_all_failed_tasks=rerun_all_failed_tasks,
237+
rerun_dependent_tasks=rerun_dependent_tasks,
238+
rerun_tasks=rerun_tasks,
239+
latest_repair_id=latest_repair_id,
240+
jar_params=jar_params,
241+
notebook_params=notebook_params,
242+
python_params=python_params,
243+
spark_submit_params=spark_submit_params,
244+
python_named_params=python_named_params,
245+
pipeline_params=pipeline_params,
246+
sql_params=sql_params,
247+
dbt_commands=dbt_commands,
248+
)
249+
return {"repair_id": repair_id_result, "run_id": run_id}
250+
228251
elif act == "get":
229252
return _get_run(run_id=run_id)
230253

@@ -252,4 +275,4 @@ def manage_job_runs(
252275
result = _wait_for_run(run_id=run_id, timeout=timeout, poll_interval=poll_interval)
253276
return result.to_dict()
254277

255-
raise ValueError(f"Invalid action: '{action}'. Valid: run_now, get, get_output, cancel, list, wait")
278+
raise ValueError(f"Invalid action: '{action}'. Valid: run_now, repair, get, get_output, cancel, list, wait")

databricks-tools-core/databricks_tools_core/jobs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656

5757
from .runs import (
5858
run_job_now,
59+
repair_run,
5960
get_run,
6061
get_run_output,
6162
cancel_run,
@@ -79,6 +80,7 @@
7980
"delete_job",
8081
# Run Operations
8182
"run_job_now",
83+
"repair_run",
8284
"get_run",
8385
"get_run_output",
8486
"cancel_run",

databricks-tools-core/databricks_tools_core/jobs/runs.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,113 @@ def run_job_now(
121121
raise JobError(f"Failed to start run for job {job_id}: {str(e)}", job_id=job_id)
122122

123123

124+
def repair_run(
    run_id: int,
    rerun_all_failed_tasks: Optional[bool] = None,
    rerun_dependent_tasks: Optional[bool] = None,
    rerun_tasks: Optional[List[str]] = None,
    latest_repair_id: Optional[int] = None,
    jar_params: Optional[List[str]] = None,
    notebook_params: Optional[Dict[str, str]] = None,
    python_params: Optional[List[str]] = None,
    spark_submit_params: Optional[List[str]] = None,
    python_named_params: Optional[Dict[str, str]] = None,
    pipeline_params: Optional[Dict[str, Any]] = None,
    sql_params: Optional[Dict[str, str]] = None,
    dbt_commands: Optional[List[str]] = None,
) -> int:
    """
    Repair a failed job run by re-running only failed or specified tasks.

    Repaired tasks execute as part of the original run, using the job's
    current job and task settings, so tasks that already succeeded are not
    re-run (unlike run_job_now, which starts the whole job over).

    Args:
        run_id: The job run ID to repair (must not be in progress)
        rerun_all_failed_tasks: If True, rerun all tasks that failed
        rerun_dependent_tasks: If True, also rerun tasks that depend on failed tasks
        rerun_tasks: List of specific task keys to rerun
        latest_repair_id: ID of the latest repair to ensure sequential repairs
        jar_params: Parameters for JAR tasks
        notebook_params: Parameters for notebook tasks
        python_params: Parameters for Python tasks
        spark_submit_params: Parameters for spark-submit tasks
        python_named_params: Named parameters for Python tasks
        pipeline_params: Parameters for pipeline tasks
        sql_params: Parameters for SQL tasks
        dbt_commands: Commands for dbt tasks

    Returns:
        Repair ID (integer) for tracking the repair

    Raises:
        JobError: If repair fails to start

    Example:
        >>> repair_id = repair_run(run_id=456, rerun_all_failed_tasks=True)
        >>> print(f"Started repair {repair_id}")
    """
    client = get_workspace_client()

    try:
        payload: Dict[str, Any] = {"run_id": run_id}

        # Repair-control flags are compared against None so that explicit
        # False / 0 values are still forwarded to the API.
        for key, value in (
            ("rerun_all_failed_tasks", rerun_all_failed_tasks),
            ("rerun_dependent_tasks", rerun_dependent_tasks),
            ("latest_repair_id", latest_repair_id),
        ):
            if value is not None:
                payload[key] = value

        # Collection-valued arguments are forwarded only when non-empty.
        for key, value in (
            ("rerun_tasks", rerun_tasks),
            ("jar_params", jar_params),
            ("notebook_params", notebook_params),
            ("python_params", python_params),
            ("spark_submit_params", spark_submit_params),
            ("python_named_params", python_named_params),
            ("pipeline_params", pipeline_params),
            ("sql_params", sql_params),
            ("dbt_commands", dbt_commands),
        ):
            if value:
                payload[key] = value

        # The SDK returns a Wait[Run]; when present, its .response attribute
        # is a RepairRunResponse carrying the repair_id.
        wait_handle = client.jobs.repair_run(**payload)

        new_repair_id = None
        if hasattr(wait_handle, "response") and hasattr(wait_handle.response, "repair_id"):
            new_repair_id = wait_handle.response.repair_id
        elif hasattr(wait_handle, "repair_id"):
            new_repair_id = wait_handle.repair_id
        else:
            # Last resort: some SDK objects expose their fields via as_dict().
            as_dict = wait_handle.as_dict() if hasattr(wait_handle, "as_dict") else {}
            new_repair_id = as_dict.get("repair_id")

        if new_repair_id is None:
            raise JobError(f"Failed to extract repair_id from response for run {run_id}", run_id=run_id)

        return new_repair_id

    except JobError:
        raise
    except Exception as e:
        raise JobError(f"Failed to repair run {run_id}: {str(e)}", run_id=run_id)
229+
230+
124231
def get_run(run_id: int) -> Dict[str, Any]:
125232
"""
126233
Get detailed run status and information.

databricks-tools-core/tests/integration/jobs/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,51 @@ def test_notebook_path() -> str:
8888
logger.warning(f"Failed to cleanup test notebook: {e}")
8989

9090

91+
@pytest.fixture(scope="module")
def failing_notebook_path() -> str:
    """
    Module-scoped fixture: create a workspace notebook that always raises.

    Used by the repair_run integration tests, which need a run that fails.
    Yields the workspace path of the notebook and deletes it on teardown.
    """
    client = get_workspace_client()
    me = client.current_user.me()
    notebook_path = f"/Users/{me.user_name}/test_jobs/test_failing_notebook"

    source = """# Databricks notebook source
# Test notebook that deliberately fails for repair_run tests
raise Exception("Deliberate failure for repair_run test")
"""

    logger.info(f"Creating failing test notebook: {notebook_path}")

    try:
        # Ensure the parent folder exists before importing into it.
        client.workspace.mkdirs(notebook_path.rsplit("/", 1)[0])

        encoded = base64.b64encode(source.encode("utf-8")).decode("utf-8")
        client.workspace.import_(
            path=notebook_path,
            format=ImportFormat.SOURCE,
            language=Language.PYTHON,
            content=encoded,
            overwrite=True,
        )
        logger.info(f"Failing test notebook created: {notebook_path}")
    except Exception as e:
        logger.error(f"Failed to create failing test notebook: {e}")
        raise

    yield notebook_path

    # Best-effort teardown: a leftover notebook should not fail the suite.
    try:
        logger.info(f"Cleaning up failing test notebook: {notebook_path}")
        client.workspace.delete(notebook_path)
    except Exception as e:
        logger.warning(f"Failed to cleanup failing test notebook: {e}")
134+
135+
91136
@pytest.fixture(scope="function")
92137
def cleanup_job():
93138
"""

databricks-tools-core/tests/integration/jobs/test_runs.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from databricks_tools_core.jobs import (
1919
create_job,
2020
run_job_now,
21+
repair_run,
2122
get_run,
2223
cancel_run,
2324
list_runs,
@@ -410,3 +411,68 @@ def test_wait_for_run_result_object(
410411
assert "lifecycle_state" in result_dict
411412

412413
logger.info(f"JobRunResult.to_dict(): {result_dict}")
414+
415+
416+
@pytest.mark.integration
class TestRepairRun:
    """Tests for repairing failed job runs."""

    def test_repair_run_rerun_all_failed(
        self,
        failing_notebook_path: str,
        cleanup_job,
    ):
        """Should repair a failed run by rerunning all failed tasks.

        Flow: run a job whose only task uses a deliberately failing
        notebook, replace that notebook with a passing one, repair the
        run with rerun_all_failed_tasks=True, and assert the repaired
        run succeeds.

        Note: the previously requested ``test_notebook_path`` fixture was
        removed — it was never used by this test and forced an unnecessary
        notebook create/cleanup cycle.
        """
        # Imported locally: only this test needs workspace-import machinery.
        from databricks_tools_core.auth import get_workspace_client
        import base64
        from databricks.sdk.service.workspace import ImportFormat, Language

        w = get_workspace_client()

        # Create a single-task job pointing at the failing notebook.
        tasks = [
            {
                "task_key": "repairable_task",
                "notebook_task": {
                    "notebook_path": failing_notebook_path,
                    "source": "WORKSPACE",
                },
            }
        ]
        job = create_job(name="test_repair_run", tasks=tasks)
        cleanup_job(job["job_id"])

        # Run the job; the notebook raises, so the run must fail.
        run_id = run_job_now(job_id=job["job_id"])
        logger.info(f"Started run {run_id}, waiting for failure...")

        result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
        assert not result.success, "Run should have failed"
        logger.info(f"Run failed as expected: {result.result_state}")

        # Overwrite the failing notebook with a passing one so the repair
        # (which re-runs with current job settings) can succeed.
        passing_content = (
            "# Databricks notebook source\n"
            "print('Repaired successfully')\n"
            "dbutils.notebook.exit('success')\n"
        )
        content_b64 = base64.b64encode(passing_content.encode("utf-8")).decode("utf-8")
        w.workspace.import_(
            path=failing_notebook_path,
            format=ImportFormat.SOURCE,
            language=Language.PYTHON,
            content=content_b64,
            overwrite=True,
        )

        # Repair the run: rerun every failed task.
        repair_id = repair_run(run_id=run_id, rerun_all_failed_tasks=True)
        assert repair_id is not None, "Should return a repair_id"
        assert isinstance(repair_id, int), "repair_id should be an integer"
        logger.info(f"Repair started with repair_id={repair_id}")

        # The repair runs under the same run_id; wait for it to finish.
        repaired_result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
        assert repaired_result.success, "Repaired run should succeed"
        logger.info(f"Repair completed successfully: {repaired_result.result_state}")

0 commit comments

Comments
 (0)