Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions databricks-mcp-server/databricks_mcp_server/tools/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
update_job as _update_job,
delete_job as _delete_job,
run_job_now as _run_job_now,
repair_run as _repair_run,
get_run as _get_run,
get_run_output as _get_run_output,
cancel_run as _cancel_run,
Expand Down Expand Up @@ -193,6 +194,10 @@ def manage_job_runs(
sql_params: Dict[str, str] = None,
dbt_commands: List[str] = None,
queue: Dict[str, Any] = None,
rerun_all_failed_tasks: bool = None,
rerun_dependent_tasks: bool = None,
rerun_tasks: List[str] = None,
latest_repair_id: int = None,
active_only: bool = False,
completed_only: bool = False,
limit: int = 25,
Expand All @@ -202,11 +207,11 @@ def manage_job_runs(
timeout: int = 3600,
poll_interval: int = 10,
) -> Dict[str, Any]:
"""Manage job runs: run_now, get, get_output, cancel, list, wait.
"""Manage job runs: run_now, repair, get, get_output, cancel, list, wait.

run_now: requires job_id, returns {run_id}. get/get_output/cancel/wait: require run_id.
list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
Returns: run_now={run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
run_now: requires job_id, returns {run_id}. repair: requires run_id, reruns failed tasks (rerun_all_failed_tasks=True) or specific tasks (rerun_tasks=["task_key"]).
get/get_output/cancel/wait: require run_id. list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
Returns: run_now={run_id}, repair={repair_id, run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
act = action.lower()

if act == "run_now":
Expand All @@ -225,6 +230,24 @@ def manage_job_runs(
)
return {"run_id": run_id_result}

elif act == "repair":
repair_id_result = _repair_run(
run_id=run_id,
rerun_all_failed_tasks=rerun_all_failed_tasks,
rerun_dependent_tasks=rerun_dependent_tasks,
rerun_tasks=rerun_tasks,
latest_repair_id=latest_repair_id,
jar_params=jar_params,
notebook_params=notebook_params,
python_params=python_params,
spark_submit_params=spark_submit_params,
python_named_params=python_named_params,
pipeline_params=pipeline_params,
sql_params=sql_params,
dbt_commands=dbt_commands,
)
return {"repair_id": repair_id_result, "run_id": run_id}

elif act == "get":
return _get_run(run_id=run_id)

Expand Down Expand Up @@ -252,4 +275,4 @@ def manage_job_runs(
result = _wait_for_run(run_id=run_id, timeout=timeout, poll_interval=poll_interval)
return result.to_dict()

raise ValueError(f"Invalid action: '{action}'. Valid: run_now, get, get_output, cancel, list, wait")
raise ValueError(f"Invalid action: '{action}'. Valid: run_now, repair, get, get_output, cancel, list, wait")
2 changes: 2 additions & 0 deletions databricks-tools-core/databricks_tools_core/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

from .runs import (
run_job_now,
repair_run,
get_run,
get_run_output,
cancel_run,
Expand All @@ -79,6 +80,7 @@
"delete_job",
# Run Operations
"run_job_now",
"repair_run",
"get_run",
"get_run_output",
"cancel_run",
Expand Down
107 changes: 107 additions & 0 deletions databricks-tools-core/databricks_tools_core/jobs/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,113 @@ def run_job_now(
raise JobError(f"Failed to start run for job {job_id}: {str(e)}", job_id=job_id)


def repair_run(
    run_id: int,
    rerun_all_failed_tasks: Optional[bool] = None,
    rerun_dependent_tasks: Optional[bool] = None,
    rerun_tasks: Optional[List[str]] = None,
    latest_repair_id: Optional[int] = None,
    jar_params: Optional[List[str]] = None,
    notebook_params: Optional[Dict[str, str]] = None,
    python_params: Optional[List[str]] = None,
    spark_submit_params: Optional[List[str]] = None,
    python_named_params: Optional[Dict[str, str]] = None,
    pipeline_params: Optional[Dict[str, Any]] = None,
    sql_params: Optional[Dict[str, str]] = None,
    dbt_commands: Optional[List[str]] = None,
) -> int:
    """
    Re-run the failed (or explicitly listed) tasks of an existing job run.

    Unlike run_job_now, the repaired tasks execute inside the original run
    using current job and task settings, so tasks that already succeeded are
    not executed again. The target run must not be in progress.

    Args:
        run_id: ID of the job run to repair.
        rerun_all_failed_tasks: When True, rerun every failed task.
        rerun_dependent_tasks: When True, also rerun tasks that depend on the
            failed ones.
        rerun_tasks: Explicit task keys to rerun.
        latest_repair_id: ID of the latest repair, to keep repairs sequential.
        jar_params: Parameters for JAR tasks.
        notebook_params: Parameters for notebook tasks.
        python_params: Parameters for Python tasks.
        spark_submit_params: Parameters for spark-submit tasks.
        python_named_params: Named parameters for Python tasks.
        pipeline_params: Parameters for pipeline tasks.
        sql_params: Parameters for SQL tasks.
        dbt_commands: Commands for dbt tasks.

    Returns:
        The repair ID (integer) assigned to this repair attempt.

    Raises:
        JobError: If the repair fails to start or no repair_id is returned.

    Example:
        >>> repair_id = repair_run(run_id=456, rerun_all_failed_tasks=True)
        >>> print(f"Started repair {repair_id}")
    """
    client = get_workspace_client()

    try:
        call_args: Dict[str, Any] = {"run_id": run_id}

        # Flags and ids are forwarded whenever explicitly set — False and 0
        # are meaningful values here, so test against None.
        for name, value in (
            ("rerun_all_failed_tasks", rerun_all_failed_tasks),
            ("rerun_dependent_tasks", rerun_dependent_tasks),
            ("latest_repair_id", latest_repair_id),
        ):
            if value is not None:
                call_args[name] = value

        # Collections are forwarded only when non-empty.
        for name, value in (
            ("rerun_tasks", rerun_tasks),
            ("jar_params", jar_params),
            ("notebook_params", notebook_params),
            ("python_params", python_params),
            ("spark_submit_params", spark_submit_params),
            ("python_named_params", python_named_params),
            ("pipeline_params", pipeline_params),
            ("sql_params", sql_params),
            ("dbt_commands", dbt_commands),
        ):
            if value:
                call_args[name] = value

        # The SDK returns a Wait[Run] handle; its .response should be a
        # RepairRunResponse carrying the repair_id field.
        wait_handle = client.jobs.repair_run(**call_args)

        if hasattr(wait_handle, "response") and hasattr(wait_handle.response, "repair_id"):
            repair_id = wait_handle.response.repair_id
        elif hasattr(wait_handle, "repair_id"):
            repair_id = wait_handle.repair_id
        else:
            # Last resort: look for the field in the dict representation.
            raw = wait_handle.as_dict() if hasattr(wait_handle, "as_dict") else {}
            repair_id = raw.get("repair_id")

        if repair_id is None:
            raise JobError(f"Failed to extract repair_id from response for run {run_id}", run_id=run_id)

        return repair_id

    except JobError:
        raise
    except Exception as e:
        raise JobError(f"Failed to repair run {run_id}: {str(e)}", run_id=run_id)


def get_run(run_id: int) -> Dict[str, Any]:
"""
Get detailed run status and information.
Expand Down
45 changes: 45 additions & 0 deletions databricks-tools-core/tests/integration/jobs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,51 @@ def test_notebook_path() -> str:
logger.warning(f"Failed to cleanup test notebook: {e}")


@pytest.fixture(scope="module")
def failing_notebook_path() -> str:
    """
    Provision a workspace notebook that always raises, for repair_run tests.

    Yields the workspace path of the notebook; the notebook is deleted again
    once the module's tests have finished.
    """
    client = get_workspace_client()
    me = client.current_user.me()
    notebook_path = f"/Users/{me.user_name}/test_jobs/test_failing_notebook"

    source = """# Databricks notebook source
# Test notebook that deliberately fails for repair_run tests
raise Exception("Deliberate failure for repair_run test")
"""

    logger.info(f"Creating failing test notebook: {notebook_path}")

    try:
        # Ensure the parent folder exists before importing the notebook.
        client.workspace.mkdirs(notebook_path.rsplit("/", 1)[0])

        encoded = base64.b64encode(source.encode("utf-8")).decode("utf-8")
        client.workspace.import_(
            path=notebook_path,
            format=ImportFormat.SOURCE,
            language=Language.PYTHON,
            content=encoded,
            overwrite=True,
        )
        logger.info(f"Failing test notebook created: {notebook_path}")
    except Exception as e:
        logger.error(f"Failed to create failing test notebook: {e}")
        raise

    yield notebook_path

    # Teardown: best-effort removal; a leftover notebook only logs a warning.
    try:
        logger.info(f"Cleaning up failing test notebook: {notebook_path}")
        client.workspace.delete(notebook_path)
    except Exception as e:
        logger.warning(f"Failed to cleanup failing test notebook: {e}")


@pytest.fixture(scope="function")
def cleanup_job():
"""
Expand Down
66 changes: 66 additions & 0 deletions databricks-tools-core/tests/integration/jobs/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from databricks_tools_core.jobs import (
create_job,
run_job_now,
repair_run,
get_run,
cancel_run,
list_runs,
Expand Down Expand Up @@ -410,3 +411,68 @@ def test_wait_for_run_result_object(
assert "lifecycle_state" in result_dict

logger.info(f"JobRunResult.to_dict(): {result_dict}")


@pytest.mark.integration
class TestRepairRun:
    """Tests for repairing failed job runs via repair_run."""

    def test_repair_run_rerun_all_failed(
        self,
        failing_notebook_path: str,
        cleanup_job,
    ):
        """Should repair a failed run by rerunning all failed tasks.

        Flow: run a single-task job built on a notebook that always raises,
        wait for the failure, overwrite the notebook with a passing version,
        then repair the run and verify the repaired run succeeds.

        Note: the unused ``test_notebook_path`` fixture was dropped — it
        provisioned an unrelated workspace notebook this test never touched.
        """
        # Local imports keep these test-only dependencies out of module scope.
        from databricks_tools_core.auth import get_workspace_client
        import base64
        from databricks.sdk.service.workspace import ImportFormat, Language

        w = get_workspace_client()

        # Create a job whose only task runs the always-failing notebook.
        tasks = [
            {
                "task_key": "repairable_task",
                "notebook_task": {
                    "notebook_path": failing_notebook_path,
                    "source": "WORKSPACE",
                },
            }
        ]
        job = create_job(name="test_repair_run", tasks=tasks)
        cleanup_job(job["job_id"])

        # Run the job (it will fail).
        run_id = run_job_now(job_id=job["job_id"])
        logger.info(f"Started run {run_id}, waiting for failure...")

        result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
        assert not result.success, "Run should have failed"
        logger.info(f"Run failed as expected: {result.result_state}")

        # Replace the failing notebook with a passing one so the repaired
        # task can succeed.
        # NOTE(review): failing_notebook_path is module-scoped, so this
        # overwrite leaves the notebook passing for any later test in the
        # module that reuses it — confirm no such test exists.
        passing_content = (
            "# Databricks notebook source\n"
            "print('Repaired successfully')\n"
            "dbutils.notebook.exit('success')\n"
        )
        content_b64 = base64.b64encode(passing_content.encode("utf-8")).decode("utf-8")
        w.workspace.import_(
            path=failing_notebook_path,
            format=ImportFormat.SOURCE,
            language=Language.PYTHON,
            content=content_b64,
            overwrite=True,
        )

        # Repair the run: only the failed task reruns, inside the same run.
        repair_id = repair_run(run_id=run_id, rerun_all_failed_tasks=True)
        assert repair_id is not None, "Should return a repair_id"
        assert isinstance(repair_id, int), "repair_id should be an integer"
        logger.info(f"Repair started with repair_id={repair_id}")

        # Wait for the repair to finish; the overall run should now pass.
        repaired_result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
        assert repaired_result.success, "Repaired run should succeed"
        logger.info(f"Repair completed successfully: {repaired_result.result_state}")