Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 261 additions & 13 deletions areal/infra/rpc/ray_rpc_server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import abc
import os
import shlex
import subprocess
import sys
import time
import traceback
from concurrent.futures import Future
from typing import Any

import ray
import requests

from areal.api import InferenceEngine, TrainEngine
from areal.api.cli_args import BaseExperimentConfig
from areal.infra.rpc.rtensor import RTensor
from areal.infra.rpc.serialization import deserialize_value, serialize_value
from areal.infra.utils.proc import kill_process_tree
from areal.utils import logging, name_resolve, seeding
from areal.utils.data import (
broadcast_tensor_container,
Expand All @@ -17,25 +25,19 @@
from areal.utils.network import find_free_ports


@ray.remote
class RayRPCServer:
class RayServer(abc.ABC):
"""
Ray engine container. Represents either:
- one training world rank, or
- one rollout instance

Supports multiple named engines per worker for colocation scenarios.

Placement group scheduling is controlled by the scheduler.
The actor is only responsible for the engine lifecycle and method calls
within this process.
Ray actor base class that all Ray actors under RayScheduler should inherit from
"""

def __init__(self):
def __init__(self, config: BaseExperimentConfig, **kwargs):
self._engines: dict[str, TrainEngine | InferenceEngine] = {}
self._default_engine_name: str | None = None # For backward compatibility
self._allocated_port = set()
self.logger = logging.getLogger("RayRPCServer")
self.config: BaseExperimentConfig = config
ctx = ray.get_runtime_context()
self.actor_name = ctx.get_actor_name()
self.logger = logging.getLogger(self.__class__.__name__)

def _get_device(self):
# lazy resolve the device inside worker process
Expand Down Expand Up @@ -83,6 +85,53 @@ def set_env(self, env: dict[str, str]) -> None:
for k, v in env.items():
os.environ[str(k)] = str(v)

def post_init(self, **kwargs) -> Any:
    """Hook launched after the actor has been deployed; a no-op by default.

    The HTTP launcher is the actor that needs this step. It is declared here
    as well so that all actors expose the same interface.

    Args:
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        None.
    """
    return None

@abc.abstractmethod
def create_engine(
self,
engine: str,
*init_args,
engine_name: str | None = None,
**init_kwargs,
) -> None:
raise NotImplementedError()

@abc.abstractmethod
def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> Any:
raise NotImplementedError()

@abc.abstractmethod
def destroy(self) -> None:
    """Tear down this actor; must be overridden by subclasses.

    Raises:
        NotImplementedError: Always, when the base implementation is invoked.
    """
    raise NotImplementedError

def __ray_shutdown__(self):
    """Shutdown hook: delegate cleanup to ``destroy``."""
    self.destroy()

def __repr__(self):
    """Return ``"<ClassName> [<actor name>]"``."""
    cls_name = self.__class__.__name__
    return f"{cls_name} [{self.actor_name}]"


@ray.remote
class RayRPCServer(RayServer):
"""
Ray engine container. Represents either:
- one training world rank, or
- one rollout instance

Supports multiple named engines per worker for colocation scenarios.

Placement group scheduling is controlled by the scheduler.
The actor is only responsible for the engine lifecycle and method calls
within this process.
"""

def __init__(self, config: BaseExperimentConfig, **kwargs):
    """Initialize the actor by delegating entirely to the base class.

    Args:
        config: Experiment configuration, forwarded to the base initializer.
        **kwargs: Extra keyword arguments, forwarded unchanged.
    """
    super().__init__(config, **kwargs)

def create_engine(
self,
engine: str,
Expand Down Expand Up @@ -213,3 +262,202 @@ def destroy(self) -> None:
self._engines.clear()
self._default_engine_name = None
ray.actor.exit_actor()


@ray.remote
class RayHTTPLauncher(RayServer):
    """
    Ray implementation of a launcher to launch proxy servers and any HTTP servers.

    The actor starts the server as a ``python -m <command>`` subprocess (see
    ``launch_server``) and forwards ``create_engine`` / ``call`` to it as JSON
    POST requests against ``self.url``.
    """

    # Keyword arguments that must be passed to ``__init__``.
    REQUIRED_ARGS = ("command", "worker_index", "role")

    def __init__(self, config: BaseExperimentConfig, **kwargs):
        """Validate and store the launch parameters; the server is not started here.

        Args:
            config: Experiment configuration, forwarded to the base class and
                later used to build the server command line.
            **kwargs: Must contain every key in ``REQUIRED_ARGS``
                (``command``, ``worker_index``, ``role``).

        Raises:
            TypeError: If any required keyword argument is missing.
        """
        super().__init__(config, **kwargs)

        missing = [k for k in self.REQUIRED_ARGS if k not in kwargs]
        if missing:
            raise TypeError(f"Missing required kwargs: {missing}")

        self.command = kwargs["command"]
        self.worker_index = kwargs["worker_index"]
        self.role = kwargs["role"]
        self.worker_ip = ray.util.get_node_ip_address()
        # Both are populated once the server is launched (post_init / launch_server).
        self.worker_port = None
        self.worker_process: subprocess.Popen | None = None

    def post_init(self, **kwargs):
        """Pick the server port and launch the server subprocess.

        Args:
            **kwargs: May contain ``port``. When it is absent (or ``None``) a
                port is obtained from ``self.alloc_ports``.
        """
        port = kwargs.get("port")
        if port is None:
            # Allocate lazily: passing the allocation as the ``dict.get`` default
            # would reserve a port even when the caller already supplied one.
            port = self.alloc_ports(1)[0]
        self.worker_port = port
        self.worker_process = self.launch_server(port=self.worker_port)

    def create_engine(
        self,
        engine: str,
        *init_args,
        engine_name: str | None = None,
        **init_kwargs,
    ) -> None:
        """Ask the launched server to create an engine via ``POST /create_engine``.

        Args:
            engine: Engine identifier, sent as the ``engine`` payload field.
            *init_args: Positional init arguments, serialized into ``init_args``.
            engine_name: Optional engine name, sent as ``engine_name``.
            **init_kwargs: Keyword init arguments, serialized into ``init_kwargs``.

        Raises:
            Exception: Whatever ``_post_request`` raises, after being logged.
        """
        self.logger.debug(f"Initializing engine {engine}")
        payload = {
            "engine": engine,
            "engine_name": engine_name,
            "init_args": serialize_value(list(init_args)),
            "init_kwargs": serialize_value(init_kwargs),
        }
        try:
            self._post_request("create_engine", payload)
        except Exception as e:
            self.logger.error(
                f"RayHTTPLauncher failed to create engine '{engine}' : {e}\n"
                f"{traceback.format_exc()}"
            )
            raise

    def call(
        self,
        method: str,
        *args,
        engine_name: str | None = None,
        rpc_meta: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """Forward a method call to the launched server via ``POST /call``.

        Args:
            method: Method name, sent as the ``method`` payload field.
            *args: Positional arguments, serialized into ``args``.
            engine_name: Optional engine name, sent as ``engine_name``.
            rpc_meta: Optional metadata dict, sent as-is as ``rpc_meta``.
            **kwargs: Keyword arguments, serialized into ``kwargs``.

        Returns:
            The deserialized ``result`` field of the server's response.

        Raises:
            Exception: Whatever ``_post_request`` raises, after being logged.
        """
        # Lazy %-formatting: the (possibly large) argument reprs are only built
        # when debug logging is actually enabled.
        self.logger.debug(
            "Calling %s on engine '%s' with arguments args=%r kwargs=%r",
            method,
            engine_name,
            args,
            kwargs,
        )

        payload = {
            "method": method,
            "engine_name": engine_name,
            "rpc_meta": rpc_meta,
            "args": serialize_value(list(args)),
            "kwargs": serialize_value(kwargs),
        }
        try:
            return self._post_request("call", payload)
        except Exception as e:
            self.logger.error(
                f"RayHTTPLauncher failed for '{method}': {e}\n{traceback.format_exc()}"
            )
            raise

    def destroy(self) -> None:
        """Kill the server subprocess if it is still running, then exit the actor."""
        if self.worker_process and self.worker_process.poll() is None:
            kill_process_tree(self.worker_process.pid, timeout=3, graceful=True)
        self._default_engine_name = None
        ray.actor.exit_actor()

    def launch_server(self, port):
        """Start the server subprocess on ``port`` and wait until it is healthy.

        The command line is ``<python> -m <command> --port <port>`` followed by
        the experiment, role, worker-index, name-resolve and fileroot options
        taken from ``self.config``.

        Args:
            port: Port passed to the server through ``--port``.

        Returns:
            The ``subprocess.Popen`` handle of the launched server (also stored
            in ``self.worker_process``).

        Raises:
            RuntimeError: If no command was configured, or if the server does
                not pass the health check (the process is killed first).
        """
        # keeping this as a separate function to support Awex server launches later
        if not self.command:
            raise RuntimeError(
                f"Command was not given to {self.__class__.__name__}.launch_server. Cannot launch without command."
            )

        cluster_config = self.config.cluster
        # Named to avoid shadowing the imported ``name_resolve`` module.
        name_resolve_config = cluster_config.name_resolve

        cmd = [sys.executable, "-m"]
        cmd.extend(shlex.split(self.command))
        cmd.extend(["--port", str(port)])

        cmd.extend(["--experiment-name", self.config.experiment_name])
        cmd.extend(["--trial-name", self.config.trial_name])
        cmd.extend(["--role", self.role])
        cmd.extend(["--worker-index", str(self.worker_index)])

        cmd.extend(["--name-resolve-type", name_resolve_config.type])
        cmd.extend(["--nfs-record-root", name_resolve_config.nfs_record_root])
        cmd.extend(["--etcd3-addr", name_resolve_config.etcd3_addr])
        cmd.extend(["--fileroot", str(cluster_config.fileroot)])

        # ``self.url`` (used by the health check and by every later request) is
        # derived from ``self.worker_port``; keep it in sync with the port the
        # server is actually started on.
        self.worker_port = port

        _env = os.environ.copy()
        self.worker_process = subprocess.Popen(
            cmd, env=_env, stdout=sys.stdout, stderr=subprocess.STDOUT
        )

        try:
            self._check_health()
        except Exception as e:
            self.logger.error(e)
            kill_process_tree(self.worker_process.pid, timeout=3, graceful=True)
            # Chain the cause so the health-check failure stays in the traceback.
            raise RuntimeError(f"Could not launch server with command {cmd}") from e

        return self.worker_process

    def _post_request(
        self,
        endpoint,
        payload,
        http_timeout: float = 7200.0,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """POST ``payload`` as JSON to ``<url>/<endpoint>`` with retries.

        Args:
            endpoint: Path segment appended to ``self.url``.
            payload: JSON-serializable request body.
            http_timeout: Per-request timeout in seconds.
            max_retries: Maximum number of attempts.
            retry_delay: Base delay in seconds; attempt ``k`` waits
                ``retry_delay * 2 ** (k - 1)`` before the next attempt.

        Returns:
            The deserialized ``result`` field of the JSON response.

        Raises:
            RuntimeError: If the worker process has exited, if the server
                answers with HTTP 400/404/500 (carrying the response's
                ``detail``), or if all attempts fail.
        """
        url = f"{self.url}/{endpoint}"
        last_error = ""
        # adapted from local scheduler
        for attempt in range(1, max_retries + 1):
            # No point in sending requests to a server that already exited.
            if self.worker_process and self.worker_process.poll() is not None:
                raise RuntimeError("Worker has terminated")

            try:
                response = requests.post(url, json=payload, timeout=http_timeout)
                response.raise_for_status()
                result = response.json().get("result")
                return deserialize_value(result)

            except requests.exceptions.HTTPError as e:
                resp = e.response

                # These status codes are reported immediately instead of retried.
                if resp is not None and resp.status_code in (400, 404, 500):
                    try:
                        error_detail = resp.json().get("detail", "unknown error")
                    except Exception:
                        error_detail = resp.text or "unknown error"
                    raise RuntimeError(error_detail) from e

                last_error = (
                    f"HTTP {resp.status_code}: {resp.text}"
                    if resp is not None
                    else str(e)
                )
            except Exception as e:
                last_error = str(e)
                self.logger.warning(
                    f"Post failed when calling url {url} on actor '{self.actor_name}': {e}"
                )

            # otherwise retry
            if attempt < max_retries:
                delay = retry_delay * (2 ** (attempt - 1))
                self.logger.warning(
                    f"Calling url {url} failed on actor '{self.actor_name}' "
                    f"(attempt {attempt}/{max_retries}): {last_error}. "
                    f"Retrying in {delay:.1f}s..."
                )
                time.sleep(delay)
        raise RuntimeError(
            f"Max retries exceeded trying to call url {url}: {last_error or 'unknown error'}"
        )

    @property
    def url(self):
        """Base URL of the launched server, built from ``worker_ip`` and ``worker_port``."""
        return f"http://{self.worker_ip}:{self.worker_port}"

    def _check_health(self, timeout: float = 60.0):
        """Poll ``GET <url>/health`` until it returns 200 or ``timeout`` elapses.

        Args:
            timeout: Maximum time in seconds to wait for a healthy response.

        Raises:
            RuntimeError: If the server process exits while waiting, or if no
                200 response arrives before the deadline.
        """
        url = f"{self.url}/health"
        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.worker_process and self.worker_process.poll() is not None:
                raise RuntimeError("Server process exited before becoming healthy")

            try:
                r = requests.get(url, timeout=2.0)
                if r.status_code == 200:
                    return
            except requests.RequestException:
                # expected during startup
                pass
            time.sleep(1)

        raise RuntimeError(f"Health check timed out for {url}")
Loading
Loading