diff --git a/databricks-mcp-server/databricks_mcp_server/tools/jobs.py b/databricks-mcp-server/databricks_mcp_server/tools/jobs.py
index f1bce352..1b319bdc 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/jobs.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/jobs.py
@@ -16,6 +16,7 @@
     update_job as _update_job,
     delete_job as _delete_job,
     run_job_now as _run_job_now,
+    repair_run as _repair_run,
     get_run as _get_run,
     get_run_output as _get_run_output,
     cancel_run as _cancel_run,
@@ -193,6 +194,10 @@ def manage_job_runs(
     sql_params: Dict[str, str] = None,
     dbt_commands: List[str] = None,
     queue: Dict[str, Any] = None,
+    rerun_all_failed_tasks: bool = None,
+    rerun_dependent_tasks: bool = None,
+    rerun_tasks: List[str] = None,
+    latest_repair_id: int = None,
     active_only: bool = False,
     completed_only: bool = False,
     limit: int = 25,
@@ -202,11 +207,11 @@
     timeout: int = 3600,
     poll_interval: int = 10,
 ) -> Dict[str, Any]:
-    """Manage job runs: run_now, get, get_output, cancel, list, wait.
+    """Manage job runs: run_now, repair, get, get_output, cancel, list, wait.
 
-    run_now: requires job_id, returns {run_id}. get/get_output/cancel/wait: require run_id.
-    list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
-    Returns: run_now={run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
+    run_now: requires job_id, returns {run_id}. repair: requires run_id, reruns failed tasks (rerun_all_failed_tasks=True) or specific tasks (rerun_tasks=["task_key"]).
+    get/get_output/cancel/wait: require run_id. list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
+    Returns: run_now={run_id}, repair={repair_id, run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
     act = action.lower()
 
     if act == "run_now":
@@ -225,6 +230,24 @@
         )
         return {"run_id": run_id_result}
 
+    elif act == "repair":
+        repair_id_result = _repair_run(
+            run_id=run_id,
+            rerun_all_failed_tasks=rerun_all_failed_tasks,
+            rerun_dependent_tasks=rerun_dependent_tasks,
+            rerun_tasks=rerun_tasks,
+            latest_repair_id=latest_repair_id,
+            jar_params=jar_params,
+            notebook_params=notebook_params,
+            python_params=python_params,
+            spark_submit_params=spark_submit_params,
+            python_named_params=python_named_params,
+            pipeline_params=pipeline_params,
+            sql_params=sql_params,
+            dbt_commands=dbt_commands,
+        )
+        return {"repair_id": repair_id_result, "run_id": run_id}
+
     elif act == "get":
         return _get_run(run_id=run_id)
 
@@ -252,4 +275,4 @@
         result = _wait_for_run(run_id=run_id, timeout=timeout, poll_interval=poll_interval)
         return result.to_dict()
 
-    raise ValueError(f"Invalid action: '{action}'. Valid: run_now, get, get_output, cancel, list, wait")
+    raise ValueError(f"Invalid action: '{action}'. Valid: run_now, repair, get, get_output, cancel, list, wait")
diff --git a/databricks-tools-core/databricks_tools_core/jobs/__init__.py b/databricks-tools-core/databricks_tools_core/jobs/__init__.py
index a0ff1c40..522454c1 100644
--- a/databricks-tools-core/databricks_tools_core/jobs/__init__.py
+++ b/databricks-tools-core/databricks_tools_core/jobs/__init__.py
@@ -56,6 +56,7 @@
 
 from .runs import (
     run_job_now,
+    repair_run,
     get_run,
     get_run_output,
     cancel_run,
@@ -79,6 +80,7 @@
     "delete_job",
     # Run Operations
     "run_job_now",
+    "repair_run",
     "get_run",
     "get_run_output",
     "cancel_run",
diff --git a/databricks-tools-core/databricks_tools_core/jobs/runs.py b/databricks-tools-core/databricks_tools_core/jobs/runs.py
index 6b74c055..d18faa60 100644
--- a/databricks-tools-core/databricks_tools_core/jobs/runs.py
+++ b/databricks-tools-core/databricks_tools_core/jobs/runs.py
@@ -121,6 +121,113 @@
         raise JobError(f"Failed to start run for job {job_id}: {str(e)}", job_id=job_id)
 
+
+def repair_run(
+    run_id: int,
+    rerun_all_failed_tasks: Optional[bool] = None,
+    rerun_dependent_tasks: Optional[bool] = None,
+    rerun_tasks: Optional[List[str]] = None,
+    latest_repair_id: Optional[int] = None,
+    jar_params: Optional[List[str]] = None,
+    notebook_params: Optional[Dict[str, str]] = None,
+    python_params: Optional[List[str]] = None,
+    spark_submit_params: Optional[List[str]] = None,
+    python_named_params: Optional[Dict[str, str]] = None,
+    pipeline_params: Optional[Dict[str, Any]] = None,
+    sql_params: Optional[Dict[str, str]] = None,
+    dbt_commands: Optional[List[str]] = None,
+) -> int:
+    """
+    Repair a failed job run by re-running only failed or specified tasks.
+
+    Tasks are re-run as part of the original job run using current job and
+    task settings. Use this instead of run_job_now to avoid re-running
+    tasks that already succeeded.
+
+    Args:
+        run_id: The job run ID to repair (must not be in progress)
+        rerun_all_failed_tasks: If True, rerun all tasks that failed
+        rerun_dependent_tasks: If True, also rerun tasks that depend on failed tasks
+        rerun_tasks: List of specific task keys to rerun
+        latest_repair_id: ID of the latest repair to ensure sequential repairs
+        jar_params: Parameters for JAR tasks
+        notebook_params: Parameters for notebook tasks
+        python_params: Parameters for Python tasks
+        spark_submit_params: Parameters for spark-submit tasks
+        python_named_params: Named parameters for Python tasks
+        pipeline_params: Parameters for pipeline tasks
+        sql_params: Parameters for SQL tasks
+        dbt_commands: Commands for dbt tasks
+
+    Returns:
+        Repair ID (integer) for tracking the repair
+
+    Raises:
+        JobError: If repair fails to start
+
+    Example:
+        >>> repair_id = repair_run(run_id=456, rerun_all_failed_tasks=True)
+        >>> print(f"Started repair {repair_id}")
+    """
+    w = get_workspace_client()
+
+    try:
+        # Build kwargs for SDK call
+        kwargs: Dict[str, Any] = {"run_id": run_id}
+
+        # Add repair-specific parameters
+        if rerun_all_failed_tasks is not None:
+            kwargs["rerun_all_failed_tasks"] = rerun_all_failed_tasks
+        if rerun_dependent_tasks is not None:
+            kwargs["rerun_dependent_tasks"] = rerun_dependent_tasks
+        if rerun_tasks:
+            kwargs["rerun_tasks"] = rerun_tasks
+        if latest_repair_id is not None:
+            kwargs["latest_repair_id"] = latest_repair_id
+
+        # Add optional task parameters
+        if jar_params:
+            kwargs["jar_params"] = jar_params
+        if notebook_params:
+            kwargs["notebook_params"] = notebook_params
+        if python_params:
+            kwargs["python_params"] = python_params
+        if spark_submit_params:
+            kwargs["spark_submit_params"] = spark_submit_params
+        if python_named_params:
+            kwargs["python_named_params"] = python_named_params
+        if pipeline_params:
+            kwargs["pipeline_params"] = pipeline_params
+        if sql_params:
+            kwargs["sql_params"] = sql_params
+        if dbt_commands:
+            kwargs["dbt_commands"] = dbt_commands
+
+        # Trigger repair - SDK returns Wait[Run] object
+        # Wait.response is a RepairRunResponse with repair_id field
+        response = w.jobs.repair_run(**kwargs)
+
+        # Extract repair_id from response
+        repair_id = None
+        if hasattr(response, "response") and hasattr(response.response, "repair_id"):
+            repair_id = response.response.repair_id
+        elif hasattr(response, "repair_id"):
+            repair_id = response.repair_id
+        else:
+            # Fallback: try to get it from as_dict()
+            response_dict = response.as_dict() if hasattr(response, "as_dict") else {}
+            repair_id = response_dict.get("repair_id")
+
+        if repair_id is None:
+            raise JobError(f"Failed to extract repair_id from response for run {run_id}", run_id=run_id)
+
+        return repair_id
+
+    except JobError:
+        raise
+    except Exception as e:
+        raise JobError(f"Failed to repair run {run_id}: {str(e)}", run_id=run_id)
+
 
 def get_run(run_id: int) -> Dict[str, Any]:
     """
     Get detailed run status and information.
diff --git a/databricks-tools-core/tests/integration/jobs/conftest.py b/databricks-tools-core/tests/integration/jobs/conftest.py
index 263aba77..a76b6c10 100644
--- a/databricks-tools-core/tests/integration/jobs/conftest.py
+++ b/databricks-tools-core/tests/integration/jobs/conftest.py
@@ -88,6 +88,51 @@ def test_notebook_path() -> str:
         logger.warning(f"Failed to cleanup test notebook: {e}")
 
+
+@pytest.fixture(scope="module")
+def failing_notebook_path() -> str:
+    """
+    Create a notebook that deliberately fails.
+
+    Used for testing repair_run functionality.
+    Returns the workspace path to the notebook.
+    """
+    w = get_workspace_client()
+    user = w.current_user.me()
+    notebook_path = f"/Users/{user.user_name}/test_jobs/test_failing_notebook"
+
+    notebook_content = """# Databricks notebook source
+# Test notebook that deliberately fails for repair_run tests
+raise Exception("Deliberate failure for repair_run test")
+"""
+
+    logger.info(f"Creating failing test notebook: {notebook_path}")
+
+    try:
+        parent_folder = "/".join(notebook_path.split("/")[:-1])
+        w.workspace.mkdirs(parent_folder)
+
+        content_b64 = base64.b64encode(notebook_content.encode("utf-8")).decode("utf-8")
+        w.workspace.import_(
+            path=notebook_path,
+            format=ImportFormat.SOURCE,
+            language=Language.PYTHON,
+            content=content_b64,
+            overwrite=True,
+        )
+        logger.info(f"Failing test notebook created: {notebook_path}")
+    except Exception as e:
+        logger.error(f"Failed to create failing test notebook: {e}")
+        raise
+
+    yield notebook_path
+
+    try:
+        logger.info(f"Cleaning up failing test notebook: {notebook_path}")
+        w.workspace.delete(notebook_path)
+    except Exception as e:
+        logger.warning(f"Failed to cleanup failing test notebook: {e}")
+
 
 @pytest.fixture(scope="function")
 def cleanup_job():
     """
diff --git a/databricks-tools-core/tests/integration/jobs/test_runs.py b/databricks-tools-core/tests/integration/jobs/test_runs.py
index b435fb82..1cdc1215 100644
--- a/databricks-tools-core/tests/integration/jobs/test_runs.py
+++ b/databricks-tools-core/tests/integration/jobs/test_runs.py
@@ -18,6 +18,7 @@
 from databricks_tools_core.jobs import (
     create_job,
     run_job_now,
+    repair_run,
     get_run,
     cancel_run,
     list_runs,
@@ -410,3 +411,68 @@
         assert "lifecycle_state" in result_dict
 
         logger.info(f"JobRunResult.to_dict(): {result_dict}")
+
+
+@pytest.mark.integration
+class TestRepairRun:
+    """Tests for repairing failed job runs."""
+
+    def test_repair_run_rerun_all_failed(
+        self,
+        failing_notebook_path: str,
+        test_notebook_path: str,
+        cleanup_job,
+    ):
+        """Should repair a failed run by rerunning all failed tasks."""
+        from databricks_tools_core.auth import get_workspace_client
+        import base64
+        from databricks.sdk.service.workspace import ImportFormat, Language
+
+        w = get_workspace_client()
+
+        # Create job with the failing notebook
+        tasks = [
+            {
+                "task_key": "repairable_task",
+                "notebook_task": {
+                    "notebook_path": failing_notebook_path,
+                    "source": "WORKSPACE",
+                },
+            }
+        ]
+        job = create_job(name="test_repair_run", tasks=tasks)
+        cleanup_job(job["job_id"])
+
+        # Run the job (it will fail)
+        run_id = run_job_now(job_id=job["job_id"])
+        logger.info(f"Started run {run_id}, waiting for failure...")
+
+        result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
+        assert not result.success, "Run should have failed"
+        logger.info(f"Run failed as expected: {result.result_state}")
+
+        # Replace the failing notebook with a passing one
+        passing_content = (
+            "# Databricks notebook source\n"
+            "print('Repaired successfully')\n"
+            "dbutils.notebook.exit('success')\n"
+        )
+        content_b64 = base64.b64encode(passing_content.encode("utf-8")).decode("utf-8")
+        w.workspace.import_(
+            path=failing_notebook_path,
+            format=ImportFormat.SOURCE,
+            language=Language.PYTHON,
+            content=content_b64,
+            overwrite=True,
+        )
+
+        # Repair the run
+        repair_id = repair_run(run_id=run_id, rerun_all_failed_tasks=True)
+        assert repair_id is not None, "Should return a repair_id"
+        assert isinstance(repair_id, int), "repair_id should be an integer"
+        logger.info(f"Repair started with repair_id={repair_id}")
+
+        # Wait for the repair to complete
+        repaired_result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
+        assert repaired_result.success, "Repaired run should succeed"
+        logger.info(f"Repair completed successfully: {repaired_result.result_state}")