Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions nemo_run/core/tunnel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import shutil
import socket
import subprocess
import sys
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand Down Expand Up @@ -237,17 +236,61 @@ def run(self, command: str, hide: bool = True, warn: bool = False, **kwargs) ->
if self.pre_command:
command = f"{self.pre_command} && {command}"

return self.session.run(command, hide=hide, warn=warn, **kwargs)
delay = 4
last_exc: Exception | None = None
for attempt in range(4):
try:
return self.session.run(command, hide=hide, warn=warn, **kwargs)
except (RuntimeError, EOFError, OSError, ConnectionError) as e:
last_exc = e
logger.warning(
f"SSH command failed (attempt {attempt + 1}/4): {e}, retrying in {delay}s..."
)
time.sleep(delay)
delay = min(delay * 2, 60)
self.connect()
assert last_exc is not None
raise last_exc

def put(self, local_path: str, remote_path: str) -> None:
self._check_connect()
assert self.session, "session is not yet established."
self.session.put(local_path, remote_path)
delay = 4
last_exc: Exception | None = None
for attempt in range(4):
try:
self.session.put(local_path, remote_path)
return
except (RuntimeError, EOFError, OSError, ConnectionError) as e:
last_exc = e
logger.warning(
f"SSH put failed (attempt {attempt + 1}/4): {e}, retrying in {delay}s..."
)
time.sleep(delay)
delay = min(delay * 2, 60)
self.connect()
assert last_exc is not None
raise last_exc

def get(self, remote_path: str, local_path: str) -> None:
self._check_connect()
assert self.session, "session is not yet established."
self.session.get(remote_path, local_path)
delay = 4
last_exc: Exception | None = None
for attempt in range(4):
try:
self.session.get(remote_path, local_path)
return
except (RuntimeError, EOFError, OSError, ConnectionError) as e:
last_exc = e
logger.warning(
f"SSH get failed (attempt {attempt + 1}/4): {e}, retrying in {delay}s..."
)
time.sleep(delay)
delay = min(delay * 2, 60)
self.connect()
assert last_exc is not None
raise last_exc

def cleanup(self):
if self.session:
Expand Down Expand Up @@ -302,7 +345,10 @@ def _authenticate(self):
except Exception:
logger.debug("[bold red]:x: Failed to Authenticate your connection")
if not self.session.is_connected:
sys.exit(1)
raise ConnectionError(
f"Failed to connect to {self.user}@{self.host}. "
"Check your SSH credentials and network connectivity."
)
logger.debug(":white_check_mark: The client is authenticated successfully")


Expand Down
24 changes: 19 additions & 5 deletions nemo_run/core/tunnel/rsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
import time
from typing import Iterable

from fabric import Connection
Expand Down Expand Up @@ -88,8 +89,21 @@ def rsync(
cmd = "rsync {} {} {}@{}:{}"
cmd = cmd.format(options, source, user, host, target)
c.run(f"mkdir -p {target}", hide=hide_output)
result = c.local(cmd, hide=hide_output)
if result:
logger.info(f"Successfully ran `{result.command}`")
else:
raise RuntimeError("rsync failed")
delay = 4
last_exc: Exception | None = None
for attempt in range(4):
try:
result = c.local(cmd, hide=hide_output)
except Exception as e:
last_exc = e
logger.warning(f"rsync attempt {attempt + 1}/4 failed: {e}, retrying in {delay}s...")
time.sleep(delay)
delay = min(delay * 2, 60)
continue
if result:
logger.info(f"Successfully ran `{result.command}`")
return
else:
raise RuntimeError("rsync failed")
assert last_exc is not None
raise last_exc
20 changes: 18 additions & 2 deletions nemo_run/run/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,8 +885,24 @@ def _initialize_tunnels(self, extract_from_executors: bool = False):

for tunnel in self.tunnels.values():
if isinstance(tunnel, SSHTunnel):
tunnel.connect()
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
delay = 4
last_exc: ConnectionError | None = None
for attempt in range(4):
try:
tunnel.connect()
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
last_exc = None
break
except ConnectionError as e:
last_exc = e
self.console.log(
f"SSH tunnel {tunnel.key} connection failed "
f"(attempt {attempt + 1}/4): {e}, retrying in {delay}s..."
)
time.sleep(delay)
delay = min(delay * 2, 60)
if last_exc is not None:
raise last_exc

def status(self, return_dict: bool = False) -> Optional[dict[str, dict[str, str]]]:
"""
Expand Down
51 changes: 35 additions & 16 deletions nemo_run/run/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,41 @@ def get_logs(
exceptions = Queue()
threads = []
for role_name, replica_id in replica_ids:
thread = threading.Thread(
target=print_log_lines,
args=(
file,
runner,
app_handle,
role_name,
replica_id,
regex,
should_tail,
exceptions,
streams,
),
)
thread.daemon = True
thread.start()
delay = 2
last_exc: RuntimeError | None = None
for attempt in range(4):
thread = threading.Thread(
target=print_log_lines,
args=(
file,
runner,
app_handle,
role_name,
replica_id,
regex,
should_tail,
exceptions,
streams,
),
)
thread.daemon = True
try:
thread.start()
last_exc = None
break
except RuntimeError as e:
if "can't start new thread" in str(e):
last_exc = e
logger.warning(
f"Thread limit reached for {role_name}/{replica_id} "
f"(attempt {attempt + 1}/4), retrying in {delay}s..."
)
time.sleep(delay)
delay = min(delay * 2, 60)
else:
raise
if last_exc is not None:
raise last_exc
threads.append(thread)

for thread in threads:
Expand Down
14 changes: 13 additions & 1 deletion nemo_run/run/torchx_backend/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,20 @@ def wait_and_exit(

tries = 0
status = None
thread_retry_delay = 2
while tries < timeout:
status = runner.wait(app_handle, wait_interval=2)
try:
status = runner.wait(app_handle, wait_interval=2)
except RuntimeError as e:
if "can't start new thread" in str(e):
logger.warning(
f"Thread limit reached while waiting for job {app_id}, "
f"retrying in {thread_retry_delay}s..."
)
time.sleep(thread_retry_delay)
thread_retry_delay = min(thread_retry_delay * 2, 60)
continue
raise
if status:
break
tries += 1
Expand Down
33 changes: 26 additions & 7 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayReques

# Run sbatch script
req.launch_cmd += [dst_path]
job_id = self.tunnel.run(" ".join(req.launch_cmd)).stdout.strip()
job_id = _run_tunnel_cmd(self.tunnel, " ".join(req.launch_cmd)).stdout.strip()

# Save metadata
_save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)
Expand Down Expand Up @@ -240,9 +240,7 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
return None

assert self.tunnel, "Tunnel is None."
p = self.tunnel.run(
f"sacct --parsable2 -j {app_id}",
)
p = _run_tunnel_cmd(self.tunnel, f"sacct --parsable2 -j {app_id}")
output = p.stdout.strip().split("\n")

if len(output) <= 1:
Expand Down Expand Up @@ -299,7 +297,7 @@ def list(self) -> list[ListAppResponse]:
# To return all jobs launched, set starttime to one second past unix epoch time
# Starttime will be modified when listing jobs by timeframe is supported
assert self.tunnel, "Tunnel is None."
p = self.tunnel.run("sacct --json -S1970-01-01-00:00:01")
p = _run_tunnel_cmd(self.tunnel, "sacct --json -S1970-01-01-00:00:01")
output_json = json.loads(p.stdout.strip())
return [
ListAppResponse(app_id=str(job["job_id"]), state=SLURM_STATES[job["state"]["current"]])
Expand Down Expand Up @@ -428,7 +426,28 @@ def _save_job_dir(
)


def _get_job_dirs(retries: int = 5) -> dict[str, tuple[str, SSHTunnel | LocalTunnel, str]]:
def _run_tunnel_cmd(tunnel, cmd: str, retries: int = 4, initial_delay: float = 4, **kwargs):
"""Run a tunnel command with exponential-backoff retries on transient failures."""
delay = initial_delay
last_exc: Exception | None = None
for attempt in range(retries):
try:
return tunnel.run(cmd, **kwargs)
except Exception as e:
last_exc = e
log.warning(
f"Tunnel command failed (attempt {attempt + 1}/{retries}): {e}, "
f"retrying in {delay}s..."
)
time.sleep(delay)
delay = min(delay * 2, 60)
assert last_exc is not None
raise last_exc


def _get_job_dirs(
retries: int = 5, initial_delay: float = 1
) -> dict[str, tuple[str, SSHTunnel | LocalTunnel, str]]:
last_exc: OSError | None = None
for attempt in range(retries):
try:
Expand All @@ -439,7 +458,7 @@ def _get_job_dirs(retries: int = 5) -> dict[str, tuple[str, SSHTunnel | LocalTun
return {}
except OSError as e:
last_exc = e
delay = min(2**attempt, 30)
delay = min(initial_delay * 2**attempt, 60)
log.warning(
f"OSError reading {SLURM_JOB_DIRS} (attempt {attempt + 1}/{retries}): {e}. "
f"Retrying in {delay}s..."
Expand Down
Loading
Loading