|
26 | 26 |
|
27 | 27 | from nemo_run.core.execution.slurm import SlurmBatchRequest, SlurmExecutor |
28 | 28 | from nemo_run.core.tunnel.client import LocalTunnel |
| 29 | +from nemo_run.exceptions import PersistentSacctFailure |
29 | 30 | from nemo_run.run.torchx_backend.schedulers.slurm import ( |
| 31 | + MAX_CONSECUTIVE_SACCT_FAILURES, |
30 | 32 | SlurmTunnelScheduler, |
31 | 33 | TunnelLogIterator, |
32 | 34 | _get_job_dirs, |
@@ -380,6 +382,83 @@ def test_describe_returns_unknown_on_persistent_permission_error(slurm_scheduler |
380 | 382 | assert result.state == AppState.UNKNOWN |
381 | 383 |
|
382 | 384 |
|
| 385 | +def test_describe_returns_unknown_on_sacct_exception(slurm_scheduler, mocker): |
| 386 | + """Regression: transient sacct failure (e.g. after hours of polling) must not |
| 387 | + propagate an exception and kill the wait loop. describe() should return UNKNOWN |
| 388 | + (non-terminal) so polling continues until the job completes.""" |
| 389 | + from torchx.specs import AppState |
| 390 | + |
| 391 | + job_dirs = {"12345": ("/path/to/job", LocalTunnel(job_dir="/path/to/tunnel"), "log*")} |
| 392 | + mocker.patch( |
| 393 | + "nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs", |
| 394 | + return_value=job_dirs, |
| 395 | + ) |
| 396 | + mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel") |
| 397 | + |
| 398 | + slurm_scheduler.tunnel = mock.MagicMock() |
| 399 | + slurm_scheduler.tunnel.run.side_effect = Exception("sacct: command failed") |
| 400 | + |
| 401 | + result = slurm_scheduler.describe("12345") |
| 402 | + assert result is not None |
| 403 | + assert result.state == AppState.UNKNOWN |
| 404 | + |
| 405 | + |
| 406 | +def test_describe_raises_persistent_sacct_failure_after_threshold(slurm_scheduler, mocker): |
| 407 | + """After MAX_CONSECUTIVE_SACCT_FAILURES consecutive sacct exceptions, describe() must |
| 408 | + raise PersistentSacctFailure so the caller can cancel the job instead of spinning forever.""" |
| 409 | + job_dirs = {"12345": ("/path/to/job", LocalTunnel(job_dir="/path/to/tunnel"), "log*")} |
| 410 | + mocker.patch( |
| 411 | + "nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs", |
| 412 | + return_value=job_dirs, |
| 413 | + ) |
| 414 | + mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel") |
| 415 | + |
| 416 | + slurm_scheduler.tunnel = mock.MagicMock() |
| 417 | + slurm_scheduler.tunnel.run.side_effect = Exception("sacct: command failed") |
| 418 | + |
| 419 | + for _ in range(MAX_CONSECUTIVE_SACCT_FAILURES - 1): |
| 420 | + result = slurm_scheduler.describe("12345") |
| 421 | + assert result.state == AppState.UNKNOWN |
| 422 | + |
| 423 | + with pytest.raises(PersistentSacctFailure, match="12345"): |
| 424 | + slurm_scheduler.describe("12345") |
| 425 | + |
| 426 | + |
| 427 | +def test_describe_resets_sacct_failure_counter_on_success(slurm_scheduler, mocker): |
| 428 | + """A successful sacct call must reset the consecutive failure counter so that |
| 429 | + subsequent transient failures start fresh.""" |
| 430 | + job_dirs = {"12345": ("/path/to/job", LocalTunnel(job_dir="/path/to/tunnel"), "log*")} |
| 431 | + mocker.patch( |
| 432 | + "nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs", |
| 433 | + return_value=job_dirs, |
| 434 | + ) |
| 435 | + mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel") |
| 436 | + |
| 437 | + slurm_scheduler.tunnel = mock.MagicMock() |
| 438 | + |
| 439 | + # Fail just below the threshold |
| 440 | + slurm_scheduler.tunnel.run.side_effect = Exception("sacct: command failed") |
| 441 | + for _ in range(MAX_CONSECUTIVE_SACCT_FAILURES - 1): |
| 442 | + slurm_scheduler.describe("12345") |
| 443 | + |
| 444 | + # Recover — sacct returns valid output |
| 445 | + header = "JobID|JobName|State|ExitCode" |
| 446 | + row = "12345|exp.master|RUNNING|0:0" |
| 447 | + success_result = mock.MagicMock() |
| 448 | + success_result.stdout = f"{header}\n{row}" |
| 449 | + slurm_scheduler.tunnel.run.side_effect = None |
| 450 | + slurm_scheduler.tunnel.run.return_value = success_result |
| 451 | + slurm_scheduler.describe("12345") |
| 452 | + |
| 453 | + assert slurm_scheduler._consecutive_sacct_failures.get("12345", 0) == 0 |
| 454 | + |
| 455 | + # Fail again — counter should restart from 1, not trigger threshold immediately |
| 456 | + slurm_scheduler.tunnel.run.side_effect = Exception("sacct: command failed") |
| 457 | + result = slurm_scheduler.describe("12345") |
| 458 | + assert result.state == AppState.UNKNOWN |
| 459 | + assert slurm_scheduler._consecutive_sacct_failures["12345"] == 1 |
| 460 | + |
| 461 | + |
383 | 462 | def test_schedule_with_dependencies(slurm_scheduler, slurm_executor): |
384 | 463 | mock_request = mock.MagicMock() |
385 | 464 | mock_request.cmd = ["sbatch", "--requeue", "--parsable"] |
|
0 commit comments