From 3b9db5a8870c72fa7b35ad53f4b9b7e4b70e4509 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Tue, 31 Mar 2026 15:43:27 +0800 Subject: [PATCH 1/9] add elastic scheduling and make the system more stable --- configs/disagg/wan_t2v_disagg_controller.json | 1 + lightx2v/disagg/conn.py | 9 + lightx2v/disagg/examples/run_service.py | 10 + lightx2v/disagg/monitor.py | 13 + lightx2v/disagg/rdma_client.py | 91 ++++- lightx2v/disagg/rdma_server.py | 9 +- lightx2v/disagg/services/controller.py | 369 +++++++++++++++++- lightx2v/disagg/services/decoder.py | 55 ++- lightx2v/disagg/services/encoder.py | 49 ++- lightx2v/disagg/services/transformer.py | 66 +++- scripts/disagg/kill_service.sh | 14 +- scripts/disagg/run_wan_t2v_service.sh | 102 ++--- 12 files changed, 695 insertions(+), 93 deletions(-) diff --git a/configs/disagg/wan_t2v_disagg_controller.json b/configs/disagg/wan_t2v_disagg_controller.json index d55badf24..233318e7c 100644 --- a/configs/disagg/wan_t2v_disagg_controller.json +++ b/configs/disagg/wan_t2v_disagg_controller.json @@ -16,6 +16,7 @@ "disagg_config": { "bootstrap_addr": "127.0.0.1", "bootstrap_room": 0, + "ranks" : 8, "encoder_engine_rank": 0, "transformer_engine_rank": 1, "decoder_engine_rank": 2, diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py index cd592de7d..90ee69909 100644 --- a/lightx2v/disagg/conn.py +++ b/lightx2v/disagg/conn.py @@ -480,6 +480,15 @@ def enqueue_request( if self.transfer_event is not None: self.transfer_event.set() + def get_backlog_counts(self) -> Dict[str, int]: + with self.pool_lock: + waiting_pool_size = len(self.waiting_pool) if hasattr(self, "waiting_pool") else 0 + return { + "request_pool": len(self.request_pool), + "waiting_pool": waiting_pool_size, + "request_status": len(self.request_status), + } + def check_status(self, bootstrap_room: int): with self.pool_lock: if ( diff --git a/lightx2v/disagg/examples/run_service.py b/lightx2v/disagg/examples/run_service.py index 3c1659284..515265214 
100644 --- a/lightx2v/disagg/examples/run_service.py +++ b/lightx2v/disagg/examples/run_service.py @@ -49,6 +49,12 @@ def _build_parser() -> argparse.ArgumentParser: default="auto", help="Service role. auto = infer from config_json.disagg_mode", ) + parser.add_argument( + "--engine_rank", + type=int, + default=None, + help="Override engine rank for encoder/transformer/decoder service.", + ) return parser @@ -110,6 +116,10 @@ def main(): config, raw_cfg = _build_runtime_config(args) service_mode = _resolve_service_mode(args, raw_cfg) + if args.engine_rank is not None and service_mode in {"encoder", "transformer", "decoder"}: + rank_key = f"{service_mode}_engine_rank" + config[rank_key] = int(args.engine_rank) + seed_all(args.seed) logger.info("Starting disagg service mode={}", service_mode) diff --git a/lightx2v/disagg/monitor.py b/lightx2v/disagg/monitor.py index d29c7136b..f9b940a59 100644 --- a/lightx2v/disagg/monitor.py +++ b/lightx2v/disagg/monitor.py @@ -26,6 +26,12 @@ def __init__(self, service_type: str, gpu_id: int, bind_address: str): ) self._context = zmq.Context.instance() self._stop_event = threading.Event() + self._metrics_lock = threading.Lock() + self._extra_metrics_provider: Optional[Callable[[], Dict[str, Any]]] = None + + def set_extra_metrics_provider(self, provider: Optional[Callable[[], Dict[str, Any]]]): + with self._metrics_lock: + self._extra_metrics_provider = provider def _query_gpu_metrics(self) -> Dict[str, Any]: cmd = [ @@ -55,6 +61,13 @@ def get_metrics(self) -> Dict[str, Any]: try: metrics.update(self._query_gpu_metrics()) metrics["status"] = "ok" + + with self._metrics_lock: + provider = self._extra_metrics_provider + if provider is not None: + extra_metrics = provider() + if isinstance(extra_metrics, dict): + metrics.update(extra_metrics) except Exception as exc: metrics["status"] = "error" metrics["error"] = str(exc) diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index 0c615eb49..1cd26b65e 100644 --- 
a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -29,12 +29,15 @@ class QPType: class WROpcode: RDMA_WRITE = e.IBV_WR_RDMA_WRITE RDMA_READ = e.IBV_WR_RDMA_READ + ATOMIC_FETCH_AND_ADD = e.IBV_WR_ATOMIC_FETCH_AND_ADD + ATOMIC_CMP_AND_SWP = e.IBV_WR_ATOMIC_CMP_AND_SWP class AccessFlag: LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE REMOTE_READ = e.IBV_ACCESS_REMOTE_READ + REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC class RDMAClient: @@ -100,7 +103,7 @@ def connect_to_server(self, server_ip="127.0.0.1", port=5566): def _modify_qp_to_rts(self): # Follow the standard RC flow: INIT -> RTR -> RTS. init_attr = QPAttr(port_num=self.port_num) - init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC self.qp.to_init(init_attr) rtr_attr = QPAttr(port_num=self.port_num) @@ -208,16 +211,60 @@ def rdma_read_from(self, remote_addr, length, rkey=None): self.remote_info["rkey"] = old_rkey def rdma_faa(self, remote_addr, add_value, rkey=None): - """Best-effort FAA semantics via read-modify-write. + """Execute true remote atomic fetch-and-add and return previous value.""" + with self._io_lock: + self._ensure_local_mr_capacity(8) - NOTE: This is not a true remote atomic verb; it is a compatibility shim - until atomic WR support is implemented. - """ + # The original remote value will be written into this local buffer. 
+ self.local_mr.write(b"\x00" * 8, 8, 0) + + sge = SGE(self.local_mr.buf, 8, self.local_mr.lkey) + wr = WR( + wr_id=125, + opcode=WROpcode.ATOMIC_FETCH_AND_ADD, + num_sge=1, + sg=[sge], + send_flags=e.IBV_SEND_SIGNALED, + ) + + target_rkey = int(self.remote_info["rkey"] if rkey is None else rkey) + add_u64 = int(add_value) & ((1 << 64) - 1) + wr.set_wr_atomic(target_rkey, int(remote_addr), add_u64, 0) + + self.qp.post_send(wr) + self._poll_cq() + + old = self.local_mr.read(8, 0) + old_v = int.from_bytes(old, byteorder="little", signed=False) + return old_v + + def rdma_cas(self, remote_addr, compare_value, swap_value, rkey=None): + """Execute true remote atomic compare-and-swap and return previous value.""" with self._io_lock: - old = self.rdma_read_from(int(remote_addr), 8, rkey=rkey) + self._ensure_local_mr_capacity(8) + + # The original remote value will be written into this local buffer. + self.local_mr.write(b"\x00" * 8, 8, 0) + + sge = SGE(self.local_mr.buf, 8, self.local_mr.lkey) + wr = WR( + wr_id=126, + opcode=WROpcode.ATOMIC_CMP_AND_SWP, + num_sge=1, + sg=[sge], + send_flags=e.IBV_SEND_SIGNALED, + ) + + target_rkey = int(self.remote_info["rkey"] if rkey is None else rkey) + compare_u64 = int(compare_value) & ((1 << 64) - 1) + swap_u64 = int(swap_value) & ((1 << 64) - 1) + wr.set_wr_atomic(target_rkey, int(remote_addr), compare_u64, swap_u64) + + self.qp.post_send(wr) + self._poll_cq() + + old = self.local_mr.read(8, 0) old_v = int.from_bytes(old, byteorder="little", signed=False) - new_v = (old_v + int(add_value)) & ((1 << 64) - 1) - self.rdma_write_to(int(remote_addr), new_v.to_bytes(8, byteorder="little", signed=False), rkey=rkey) return old_v def _poll_cq(self): @@ -245,10 +292,28 @@ def _poll_cq(self): # cli.connect_to_server('127.0.0.1') # 替换为服务器 IP # # 执行单边写 -# msg = b"Hello RDMA!" -# cli.rdma_write(msg) -# print("Write done.") +# # msg = b"Hello RDMA!" 
+# # cli.rdma_write(msg) +# # print("Write done.") + +# # # 执行单边读 +# # data = cli.rdma_read(len(msg)) +# # print("Read data:", data) + +# # 执行单边写(rdma_write 需要 bytes-like 数据) +# value = 123 +# payload = int(value).to_bytes(8, byteorder="little", signed=False) +# cli.rdma_write(payload) +# print(f"Write done. value={value}") # # 执行单边读 -# data = cli.rdma_read(len(msg)) -# print("Read data:", data) +# data = cli.rdma_read(8) +# read_value = int.from_bytes(data, byteorder="little", signed=False) +# print(f"Read data: raw={data} parsed={read_value}") + +# old_value = cli.rdma_faa(remote_addr=cli.remote_info["addr"], add_value=10) +# print(f"FAA old value: {old_value}") + +# data = cli.rdma_read(8) +# faa_read_value = int.from_bytes(data, byteorder="little", signed=False) +# print(f"Read data after FAA: raw={data} parsed={faa_read_value}") diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index 0111a4447..2b2658d33 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -31,6 +31,7 @@ class AccessFlag: LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE REMOTE_READ = e.IBV_ACCESS_REMOTE_READ + REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC class RDMAServer: @@ -74,7 +75,11 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): # 关键:注册一块内存用于被远程访问 # buffer_size 可配置,允许远程写入 (REMOTE_WRITE) 和远程读取 (REMOTE_READ) - self.mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ) + self.mr = MR( + self.pd, + self.buffer_size, + AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC, + ) # 初始化缓冲区数据 (例如全为 0) zeros = b"\x00" * self.buffer_size @@ -185,7 +190,7 @@ def handshake(self, host="0.0.0.0", port=5566, serve_forever=True): def _modify_qp_to_rts(self, qp, remote_info, local_psn): # Follow the standard RC flow: INIT -> RTR -> RTS. 
init_attr = QPAttr(port_num=self.port_num) - init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC qp.to_init(init_attr) rtr_attr = QPAttr(port_num=self.port_num) diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 8fdfc1696..f92c491df 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,5 +1,12 @@ +import os +import socket +import subprocess +import sys +import time +from collections.abc import Mapping from pathlib import Path from threading import Event, Lock, Thread +from typing import Any from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, ReqManager from lightx2v.disagg.monitor import Monitor @@ -27,6 +34,226 @@ def __init__(self): self._rdma_handshake_thread_request: Thread | None = None self._rdma_handshake_thread_phase1: Thread | None = None self._rdma_handshake_thread_phase2: Thread | None = None + self._instance_lock = Lock() + self._free_gpus: set[int] = set() + self._managed_instances: dict[str, dict[str, Any]] = {} + self.started_instances: list[tuple[str, str]] = [] + self._runtime_config: dict[str, Any] | None = None + self._bootstrap_addr: str = "127.0.0.1" + self._gpu_reuse_block_until: dict[int, float] = {} + self._gpu_reuse_grace_seconds: float = 5.0 + self._shutting_down: bool = False + + def _is_tcp_port_open(self, host: str, port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + return sock.connect_ex((host, port)) == 0 + + def _wait_for_tcp_port_state(self, host: str, port: int, should_be_open: bool, timeout_seconds: float) -> bool: + deadline = time.time() + timeout_seconds + while time.time() < deadline: + is_open = self._is_tcp_port_open(host, port) + if is_open == should_be_open: + return 
True + time.sleep(0.1) + return self._is_tcp_port_open(host, port) == should_be_open + + def _to_plain(self, value: Any) -> Any: + """Recursively convert config containers (e.g. LockableDict) to built-in Python types.""" + if isinstance(value, Mapping): + return {k: self._to_plain(v) for k, v in value.items()} + if isinstance(value, list): + return [self._to_plain(v) for v in value] + if isinstance(value, tuple): + return tuple(self._to_plain(v) for v in value) + if isinstance(value, set): + return {self._to_plain(v) for v in value} + return value + + def _monitor_node_from_instance_address(self, instance_address: str) -> str: + host, port_str = instance_address.rsplit(":", 1) + rank = int(port_str) - REQUEST_POLLING_PORT + return f"tcp://{host}:{MONITOR_POLLING_PORT + rank}" + + def _instance_address_from_monitor_node(self, monitor_node: str) -> str: + host_port = monitor_node + if host_port.startswith("tcp://"): + host_port = host_port[len("tcp://") :] + host, port_str = host_port.rsplit(":", 1) + rank = int(port_str) - MONITOR_POLLING_PORT + return f"{host}:{REQUEST_POLLING_PORT + rank}" + + def _init_gpu_pool(self, config: dict): + disagg_cfg = config.get("disagg_config") if isinstance(config.get("disagg_config"), dict) else {} + total_ranks = int(config.get("ranks", disagg_cfg.get("ranks", 8))) + if total_ranks <= 0: + raise ValueError("ranks must be positive") + + self._free_gpus = set(range(total_ranks)) + + def create_instance(self, instance_type: str) -> str: + """Create one service instance on an idle GPU and add it to scheduling pool.""" + if instance_type not in {"encoder", "transformer", "decoder"}: + raise ValueError("instance_type must be one of: encoder, transformer, decoder") + if self._runtime_config is None: + raise RuntimeError("controller runtime config is not initialized") + + with self._instance_lock: + if not self._free_gpus: + raise RuntimeError("no idle GPU available") + + now = time.time() + gpu_id: int | None = None + for candidate_gpu 
in sorted(self._free_gpus): + if now < self._gpu_reuse_block_until.get(candidate_gpu, 0.0): + continue + + monitor_port = MONITOR_POLLING_PORT + candidate_gpu + if self._is_tcp_port_open(self._bootstrap_addr, monitor_port): + self.logger.warning( + "Skip gpu=%s for %s creation because monitor port %s is still in use", + candidate_gpu, + instance_type, + monitor_port, + ) + continue + + gpu_id = candidate_gpu + break + + if gpu_id is None: + raise RuntimeError(f"no idle GPU available for {instance_type}: all candidates cooling down or port is in use") + + instance_cfg = self._to_plain(self._runtime_config) + instance_cfg["disagg_mode"] = instance_type + if instance_type == "encoder": + instance_cfg["encoder_engine_rank"] = gpu_id + elif instance_type == "transformer": + instance_cfg["transformer_engine_rank"] = gpu_id + else: + instance_cfg["decoder_engine_rank"] = gpu_id + + model_path = instance_cfg.get("model_path") + config_json = instance_cfg.get("config_json") + if not model_path or not config_json: + raise RuntimeError("model_path and config_json are required to launch service subprocess") + + cmd = [ + sys.executable, + "-m", + "lightx2v.disagg.examples.run_service", + "--service", + instance_type, + "--engine_rank", + str(gpu_id), + "--model_cls", + str(instance_cfg.get("model_cls", "wan2.1")), + "--task", + str(instance_cfg.get("task", "t2v")), + "--model_path", + str(model_path), + "--config_json", + str(config_json), + "--seed", + str(instance_cfg.get("seed", 42)), + "--prompt", + str(instance_cfg.get("prompt", "")), + "--negative_prompt", + str(instance_cfg.get("negative_prompt", "")), + "--save_result_path", + str(instance_cfg.get("save_path", "")), + ] + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + process = subprocess.Popen(cmd, env=env) + + monitor_port = MONITOR_POLLING_PORT + gpu_id + if not self._wait_for_tcp_port_state(self._bootstrap_addr, monitor_port, should_be_open=True, timeout_seconds=8.0): + if process.poll() is 
None: + process.terminate() + try: + process.wait(timeout=3.0) + except subprocess.TimeoutExpired: + process.kill() + raise RuntimeError(f"service {instance_type} on gpu={gpu_id} failed to expose monitor port {monitor_port}") + + instance_address = f"{self._bootstrap_addr}:{REQUEST_POLLING_PORT + gpu_id}" + self._free_gpus.remove(gpu_id) + self.add_instance(instance_type, instance_address) + monitor_node = f"tcp://{self._bootstrap_addr}:{MONITOR_POLLING_PORT + gpu_id}" + if monitor_node not in self.monitor.nodes: + self.monitor.nodes.append(monitor_node) + self._managed_instances[instance_address] = { + "instance_type": instance_type, + "gpu_id": gpu_id, + "process": process, + } + self.started_instances.append((instance_type, instance_address)) + self.logger.info( + "Created %s instance on gpu=%s pid=%s address=%s", + instance_type, + gpu_id, + process.pid, + instance_address, + ) + return instance_address + + def reclaim_instance(self, instance_type: str, instance_address: str | None = None) -> str: + """Reclaim one managed instance and return its GPU back to idle pool.""" + if instance_type not in {"encoder", "transformer", "decoder"}: + raise ValueError("instance_type must be one of: encoder, transformer, decoder") + + with self._instance_lock: + target_address = instance_address + if target_address is None: + candidates = [addr for addr, meta in self._managed_instances.items() if meta.get("instance_type") == instance_type] + if not candidates: + raise RuntimeError(f"no managed {instance_type} instance to reclaim") + target_address = candidates[-1] + + meta = self._managed_instances.get(target_address) + if meta is None: + raise RuntimeError(f"instance not managed by controller: {target_address}") + if meta.get("instance_type") != instance_type: + raise RuntimeError(f"instance type mismatch for {target_address}: expected={instance_type} got={meta.get('instance_type')}") + + process = meta.get("process") + gpu_id = int(meta.get("gpu_id")) + + 
self.remove_instance(instance_type, target_address) + monitor_node = self._monitor_node_from_instance_address(target_address) + if monitor_node in self.monitor.nodes: + self.monitor.nodes.remove(monitor_node) + + if process is not None and process.poll() is None: + process.terminate() + try: + process.wait(timeout=5.0) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=1.0) + + monitor_port = MONITOR_POLLING_PORT + gpu_id + if not self._wait_for_tcp_port_state(self._bootstrap_addr, monitor_port, should_be_open=False, timeout_seconds=5.0): + self.logger.warning( + "Monitor port still open after reclaim: service=%s gpu=%s port=%s", + instance_type, + gpu_id, + monitor_port, + ) + + self._free_gpus.add(gpu_id) + self._gpu_reuse_block_until[gpu_id] = time.time() + self._gpu_reuse_grace_seconds + self._managed_instances.pop(target_address, None) + if (instance_type, target_address) in self.started_instances: + self.started_instances.remove((instance_type, target_address)) + self.logger.info( + "Reclaimed %s instance from gpu=%s address=%s", + instance_type, + gpu_id, + target_address, + ) + return target_address def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): slots = int(config.get("rdma_buffer_slots", "128")) @@ -149,16 +376,21 @@ def send_request(self, config): self.logger.info("Request enqueued to encoder request RDMA buffer") def run(self, config): - """Initialize instances, send requests, wait for decoder save_path callbacks, then exit.""" + """Initialize controller buffers, send requests, wait for decoder save_path callbacks, then exit.""" if config is None: raise ValueError("config cannot be None") + self._shutting_down = False + bootstrap_addr = config.get("data_bootstrap_addr", "127.0.0.1") encoder_engine_rank = config.get("encoder_engine_rank", 0) transformer_engine_rank = config.get("transformer_engine_rank", 1) decoder_engine_rank = config.get("decoder_engine_rank", 2) - request_count = 
int(config.get("request_count", 2)) + request_count = int(config.get("request_count", 10)) result_port = int(config.get("controller_result_port", REQUEST_POLLING_PORT - 1)) + self._bootstrap_addr = str(bootstrap_addr) + self._runtime_config = self._to_plain(config) + self._init_gpu_pool(config) self.encoder_policy = RoundRobinPolicy() self.transformer_policy = RoundRobinPolicy() @@ -166,12 +398,8 @@ def run(self, config): self._init_request_rdma_buffer(bootstrap_addr, config) - self.add_instance("encoder", f"{bootstrap_addr}:{REQUEST_POLLING_PORT + encoder_engine_rank}") - self.add_instance( - "transformer", - f"{bootstrap_addr}:{REQUEST_POLLING_PORT + transformer_engine_rank}", - ) - self.add_instance("decoder", f"{bootstrap_addr}:{REQUEST_POLLING_PORT + decoder_engine_rank}") + for instance_type in ("encoder", "transformer", "decoder"): + address = self.create_instance(instance_type) monitor_nodes = [ f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + encoder_engine_rank}", @@ -181,10 +409,114 @@ def run(self, config): self.monitor.nodes = monitor_nodes monitor_stop_event = Event() + scale_out_threshold = 80.0 + scale_in_threshold = 20.0 + scale_cooldown_seconds = 30.0 + last_scale_ts: dict[str, float] = { + "encoder": 0.0, + "transformer": 0.0, + "decoder": 0.0, + } def _monitor_callback(results): + if self._shutting_down: + return + + service_metrics: dict[str, list[dict[str, Any]]] = { + "encoder": [], + "transformer": [], + "decoder": [], + } + for item in results: self.logger.info("monitor: %s", item) + if not isinstance(item, dict): + continue + + service_type = str(item.get("service_type", "")) + if service_type not in {"encoder", "transformer", "decoder"}: + continue + + if item.get("status") != "ok": + continue + + try: + gpu_utilization = float(item.get("gpu_utilization", 0.0)) + except (TypeError, ValueError): + continue + + monitor_address = str(item.get("address", "")) + if not monitor_address: + continue + + queue_total_pending = 
item.get("queue_total_pending", None) + try: + queue_total_pending_int = int(queue_total_pending) if queue_total_pending is not None else -1 + except (TypeError, ValueError): + queue_total_pending_int = -1 + + all_queues_empty = bool(item.get("all_queues_empty", False)) + + service_metrics[service_type].append( + { + "gpu_utilization": gpu_utilization, + "monitor_address": monitor_address, + "queue_total_pending": queue_total_pending_int, + "all_queues_empty": all_queues_empty, + } + ) + + for service_type, metrics in service_metrics.items(): + if not metrics: + continue + + now = time.time() + avg_gpu_utilization = sum(float(metric["gpu_utilization"]) for metric in metrics) / len(metrics) + + if avg_gpu_utilization > scale_out_threshold and now - last_scale_ts[service_type] >= scale_cooldown_seconds: + try: + new_address = self.create_instance(service_type) + last_scale_ts[service_type] = now + self.logger.info( + "Auto-scale out triggered: service=%s avg_gpu_utilization=%.2f new_instance=%s", + service_type, + avg_gpu_utilization, + new_address, + ) + except Exception as exc: + self.logger.warning( + "Auto-scale out skipped for service=%s avg_gpu_utilization=%.2f reason=%s", + service_type, + avg_gpu_utilization, + exc, + ) + + low_metric = min(metrics, key=lambda metric: float(metric["gpu_utilization"])) + low_utilization = float(low_metric["gpu_utilization"]) + low_monitor_address = str(low_metric["monitor_address"]) + with self._instance_lock: + service_instance_count = sum(1 for meta in self._managed_instances.values() if meta.get("instance_type") == service_type) + + queues_empty_for_service = bool(low_metric.get("all_queues_empty", False)) and int(low_metric.get("queue_total_pending", -1)) == 0 + + if low_utilization < scale_in_threshold and service_instance_count > 1 and queues_empty_for_service and now - last_scale_ts[service_type] >= scale_cooldown_seconds: + try: + target_instance_address = self._instance_address_from_monitor_node(low_monitor_address) + 
self.reclaim_instance(service_type, target_instance_address) + last_scale_ts[service_type] = now + self.logger.info( + "Auto-scale in triggered: service=%s low_gpu_utilization=%.2f reclaimed_instance=%s", + service_type, + low_utilization, + target_instance_address, + ) + except Exception as exc: + self.logger.warning( + "Auto-scale in skipped for service=%s low_gpu_utilization=%.2f reason=%s", + service_type, + low_utilization, + exc, + ) monitor_thread = Thread( target=self.monitor.run_forever, @@ -196,7 +528,7 @@ def _monitor_callback(results): name="controller-monitor", daemon=True, ) - # monitor_thread.start() + monitor_thread.start() base_save_path = config.get("save_path") expected_rooms: set[int] = set() @@ -210,11 +542,16 @@ def _monitor_callback(results): request_config["controller_result_port"] = result_port if base_save_path: save_path = Path(base_save_path) - request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{i + 1}{save_path.suffix}")) + request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{i}{save_path.suffix}")) # TODO: use queue to receive request from client and dispatch, currently we just send the same request multiple times for testing with self._lock: current_request = request_config self.send_request(current_request) + self.logger.info( + "Dispatched request room=%s save_path=%s", + i, + request_config.get("save_path"), + ) expected_rooms.add(i) @@ -262,6 +599,12 @@ def _monitor_callback(results): self.logger.info("All decoder results received. 
Controller exiting.") finally: - pass - # monitor_stop_event.set() - # monitor_thread.join(timeout=1.0) + self._shutting_down = True + monitor_stop_event.set() + monitor_thread.join(timeout=2.0) + + for instance_type, address in reversed(list(self.started_instances)): + try: + self.reclaim_instance(instance_type, address) + except Exception: + self.logger.exception("Failed to reclaim %s instance address=%s", instance_type, address) diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index 7513f1528..b558470fa 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -3,7 +3,7 @@ import threading import time from collections import deque -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch @@ -48,6 +48,13 @@ def __init__(self, config: dict): gpu_id=self.decoder_engine_rank, bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.decoder_engine_rank}", ) + self._queue_metrics_lock = threading.Lock() + self._queue_metrics: dict[str, Any] = { + "queue_sizes": {}, + "queue_total_pending": 0, + "all_queues_empty": True, + } + self.reporter.set_extra_metrics_provider(self._get_queue_metrics) self._reporter_thread: Optional[threading.Thread] = threading.Thread( target=self.reporter.serve_forever, name="decoder-reporter", @@ -56,6 +63,28 @@ def __init__(self, config: dict): self._reporter_thread.start() self.load_models() + def _get_queue_metrics(self) -> dict[str, Any]: + with self._queue_metrics_lock: + queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) + return { + "queue_sizes": queue_sizes, + "queue_total_pending": int(self._queue_metrics.get("queue_total_pending", 0)), + "all_queues_empty": bool(self._queue_metrics.get("all_queues_empty", True)), + } + + def _update_queue_metrics(self, queue_sizes: dict[str, int], transfer_sizes: Optional[dict[str, int]] = None): + merged_sizes = {k: int(v) 
for k, v in queue_sizes.items()} + if transfer_sizes is not None: + for key, value in transfer_sizes.items(): + merged_sizes[f"transfer_{key}"] = int(value) + total_pending = sum(max(v, 0) for v in merged_sizes.values()) + with self._queue_metrics_lock: + self._queue_metrics = { + "queue_sizes": merged_sizes, + "queue_total_pending": total_pending, + "all_queues_empty": total_pending == 0, + } + def _ensure_phase2_request_buffer(self) -> bool: if self._phase2_rdma_buffer is not None: return True @@ -93,10 +122,6 @@ def init(self, config): self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size - self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) - self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) - self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) - if "seed" in self.config: seed_all(self.config["seed"]) @@ -257,6 +282,19 @@ def run(self, stop_event=None): exec_queue = deque() while True: + transfer_sizes = self.data_mgr.get_backlog_counts() if self.data_mgr is not None else {"request_pool": 0, "waiting_pool": 0} + self._update_queue_metrics( + { + "req_queue": len(req_queue), + "waiting_queue": len(waiting_queue), + "exec_queue": len(exec_queue), + }, + { + "request_pool": int(transfer_sizes.get("request_pool", 0)), + "waiting_pool": int(transfer_sizes.get("waiting_pool", 0)), + }, + ) + if self._phase2_rdma_buffer is None: try: self._ensure_phase2_request_buffer() @@ -268,9 +306,10 @@ def run(self, stop_event=None): if packet is not None: if isinstance(packet, dict) and "request_config" in packet: config = dict(packet.get("request_config") or {}) - config["transformer_node_address"] = packet.get("transformer_node_address", config.get("transformer_node_address", "127.0.0.1")) + config["transformer_node_address"] = packet.get("transformer_node_address", "127.0.0.1") else: config = packet + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v 
in config.items()}) req_queue.append(config) if req_queue: @@ -310,7 +349,7 @@ def run(self, stop_event=None): room, config = exec_queue.popleft() try: save_path = self.process(config) - callback_host = str(config.get("controller_result_host", config.get("data_bootstrap_addr", "127.0.0.1"))) + callback_host = str(config.get("controller_result_host", "127.0.0.1")) callback_port = int(config.get("controller_result_port")) if config.get("controller_result_port") is not None else None if callback_port is not None: self.req_mgr.send( @@ -324,7 +363,7 @@ def run(self, stop_event=None): ) except Exception: self.logger.exception("Failed to process request for room=%s", room) - callback_host = str(config.get("controller_result_host", config.get("data_bootstrap_addr", "127.0.0.1"))) + callback_host = str(config.get("controller_result_host", "127.0.0.1")) callback_port = int(config.get("controller_result_port")) if config.get("controller_result_port") is not None else None if callback_port is not None: self.req_mgr.send( diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 396cfcab0..057dde5cc 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -3,7 +3,7 @@ import threading import time from collections import deque -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -63,6 +63,13 @@ def __init__(self, config: dict): gpu_id=self.encoder_engine_rank, bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.encoder_engine_rank}", ) + self._queue_metrics_lock = threading.Lock() + self._queue_metrics: dict[str, Any] = { + "queue_sizes": {}, + "queue_total_pending": 0, + "all_queues_empty": True, + } + self.reporter.set_extra_metrics_provider(self._get_queue_metrics) self._reporter_thread: Optional[threading.Thread] = threading.Thread( target=self.reporter.serve_forever, 
name="encoder-reporter", @@ -71,6 +78,28 @@ def __init__(self, config: dict): self._reporter_thread.start() self.load_models() + def _get_queue_metrics(self) -> dict[str, Any]: + with self._queue_metrics_lock: + queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) + return { + "queue_sizes": queue_sizes, + "queue_total_pending": int(self._queue_metrics.get("queue_total_pending", 0)), + "all_queues_empty": bool(self._queue_metrics.get("all_queues_empty", True)), + } + + def _update_queue_metrics(self, queue_sizes: dict[str, int], transfer_sizes: Optional[dict[str, int]] = None): + merged_sizes = {k: int(v) for k, v in queue_sizes.items()} + if transfer_sizes is not None: + for key, value in transfer_sizes.items(): + merged_sizes[f"transfer_{key}"] = int(value) + total_pending = sum(max(v, 0) for v in merged_sizes.values()) + with self._queue_metrics_lock: + self._queue_metrics = { + "queue_sizes": merged_sizes, + "queue_total_pending": total_pending, + "all_queues_empty": total_pending == 0, + } + def _ensure_request_buffer(self) -> bool: if self._request_rdma_buffer is not None: return True @@ -167,9 +196,6 @@ def init(self, config): self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", self._phase1_handshake_port)) self._phase1_slots = shared_slots self._phase1_slot_size = shared_slot_size - self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) - self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) - self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) # Seed everything if seed is in config if "seed" in self.config: @@ -513,6 +539,19 @@ def run(self, stop_event=None): complete_queue: Dict[int, dict] = {} while True: + transfer_sizes = self.data_mgr.get_backlog_counts() if self.data_mgr is not None else {"request_pool": 0, "waiting_pool": 0} + self._update_queue_metrics( + { + "req_queue": len(req_queue), + "exec_queue": len(exec_queue), + "complete_queue": 
len(complete_queue), + }, + { + "request_pool": int(transfer_sizes.get("request_pool", 0)), + "waiting_pool": int(transfer_sizes.get("waiting_pool", 0)), + }, + ) + if self._request_rdma_buffer is None: try: self._ensure_request_buffer() @@ -522,7 +561,7 @@ def run(self, stop_event=None): if self._request_rdma_buffer is not None: config = self._request_rdma_buffer.consume() if config is not None: - self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items() if not k.endswith("_path")}) + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) req_queue.append(config) if req_queue: diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 0446199b0..33e98e5f1 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -3,7 +3,7 @@ import threading import time from collections import deque -from typing import List, Optional +from typing import Any, List, Optional import numpy as np import torch @@ -61,6 +61,13 @@ def __init__(self, config: dict): gpu_id=self.transformer_engine_rank, bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.transformer_engine_rank}", ) + self._queue_metrics_lock = threading.Lock() + self._queue_metrics: dict[str, Any] = { + "queue_sizes": {}, + "queue_total_pending": 0, + "all_queues_empty": True, + } + self.reporter.set_extra_metrics_provider(self._get_queue_metrics) self._reporter_thread: Optional[threading.Thread] = threading.Thread( target=self.reporter.serve_forever, name="transformer-reporter", @@ -69,6 +76,36 @@ def __init__(self, config: dict): self._reporter_thread.start() self.load_models() + def _get_queue_metrics(self) -> dict[str, Any]: + with self._queue_metrics_lock: + queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) + return { + "queue_sizes": queue_sizes, + "queue_total_pending": 
int(self._queue_metrics.get("queue_total_pending", 0)), + "all_queues_empty": bool(self._queue_metrics.get("all_queues_empty", True)), + } + + def _update_queue_metrics( + self, + queue_sizes: dict[str, int], + phase1_transfer_sizes: Optional[dict[str, int]] = None, + phase2_transfer_sizes: Optional[dict[str, int]] = None, + ): + merged_sizes = {k: int(v) for k, v in queue_sizes.items()} + if phase1_transfer_sizes is not None: + for key, value in phase1_transfer_sizes.items(): + merged_sizes[f"phase1_{key}"] = int(value) + if phase2_transfer_sizes is not None: + for key, value in phase2_transfer_sizes.items(): + merged_sizes[f"phase2_{key}"] = int(value) + total_pending = sum(max(v, 0) for v in merged_sizes.values()) + with self._queue_metrics_lock: + self._queue_metrics = { + "queue_sizes": merged_sizes, + "queue_total_pending": total_pending, + "all_queues_empty": total_pending == 0, + } + def _ensure_phase1_request_buffer(self) -> bool: if self._phase1_rdma_buffer is not None: return True @@ -137,9 +174,6 @@ def init(self, config): self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", self._phase2_handshake_port)) self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size - self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) - self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) - self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) # Set global seed if present in config, though specific process calls might reuse it if "seed" in self.config: @@ -516,20 +550,40 @@ def run(self, stop_event=None): complete_queue: dict[int, dict] = {} while True: + phase1_transfer_sizes = self.data_mgr1.get_backlog_counts() if self.data_mgr1 is not None else {"request_pool": 0, "waiting_pool": 0} + phase2_transfer_sizes = self.data_mgr2.get_backlog_counts() if self.data_mgr2 is not None else {"request_pool": 0, "waiting_pool": 0} + self._update_queue_metrics( + { + 
"req_queue": len(req_queue), + "waiting_queue": len(waiting_queue), + "exec_queue": len(exec_queue), + "complete_queue": len(complete_queue), + }, + { + "request_pool": int(phase1_transfer_sizes.get("request_pool", 0)), + "waiting_pool": int(phase1_transfer_sizes.get("waiting_pool", 0)), + }, + { + "request_pool": int(phase2_transfer_sizes.get("request_pool", 0)), + "waiting_pool": int(phase2_transfer_sizes.get("waiting_pool", 0)), + }, + ) + if self._phase1_rdma_buffer is None: try: self._ensure_phase1_request_buffer() except Exception: self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") - if self._phase1_rdma_buffer is not None: + if self._phase1_rdma_buffer is not None and len(req_queue) + len(waiting_queue) < 2: packet = self._phase1_rdma_buffer.consume() if packet is not None: if isinstance(packet, dict) and "request_config" in packet: config = dict(packet.get("request_config") or {}) - config["encoder_node_address"] = packet.get("encoder_node_address", config.get("encoder_node_address", "127.0.0.1")) + config["encoder_node_address"] = packet.get("encoder_node_address", "127.0.0.1") else: config = packet + self.logger.info("%s Received request config from RDMA buffer: %s", self.transformer_engine_rank, {k: v for k, v in config.items()}) req_queue.append(config) if req_queue: diff --git a/scripts/disagg/kill_service.sh b/scripts/disagg/kill_service.sh index 63e3601fc..3b6315890 100755 --- a/scripts/disagg/kill_service.sh +++ b/scripts/disagg/kill_service.sh @@ -3,7 +3,19 @@ set -euo pipefail SCRIPT_NAME="run_wan_t2v_service.sh" -PORTS=(7788 7789 7790 12788 12789 12790 17788 17789 17790 27788 27789 27790) + +list_port=(5566 7788 12788 17788 27788) + +n=10 +list_n=($(seq 0 $((n-1)))) + +PORTS=(5555 12787) + +for a in "${list_port[@]}"; do + for b in "${list_n[@]}"; do + PORTS+=($((a + b))) + done +done kill_pid_gracefully() { local pid="$1" diff --git a/scripts/disagg/run_wan_t2v_service.sh b/scripts/disagg/run_wan_t2v_service.sh 
index f7dd1ce8e..8facbc665 100755 --- a/scripts/disagg/run_wan_t2v_service.sh +++ b/scripts/disagg/run_wan_t2v_service.sh @@ -23,14 +23,17 @@ transformer_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_transformer.json decoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_decoder.json seed=42 +request_count=10 prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" save_result_path=${lightx2v_path}/save_results/test_disagg.mp4 -save_result_path_1=${save_result_path%.mp4}1.mp4 -save_result_path_2=${save_result_path%.mp4}2.mp4 +output_files=() +for ((i=1; i<=request_count; i++)); do + output_files+=("${save_result_path%.mp4}${i}.mp4") +done # Remove old outputs so wait loop reflects current run status. -rm -f "${save_result_path_1}" "${save_result_path_2}" +rm -f "${output_files[@]}" cleanup() { local pids=("${encoder_pid:-}" "${transformer_pid:-}" "${decoder_pid:-}" "${controller_pid:-}") @@ -86,54 +89,63 @@ wait_for_port 127.0.0.1 ${rdma_request_port} 60 wait_for_port 127.0.0.1 ${rdma_phase1_port} 60 wait_for_port 127.0.0.1 ${rdma_phase2_port} 60 -CUDA_VISIBLE_DEVICES=0 python -m lightx2v.disagg.examples.run_service \ - --service encoder \ - --model_cls wan2.1 \ - --task t2v \ - --model_path ${model_path} \ - --config_json ${encoder_cfg} \ - --seed ${seed} \ - --prompt "${prompt}" \ - --negative_prompt "${negative_prompt}" \ - --save_result_path ${save_result_path} \ - > ${lightx2v_path}/save_results/disagg_encoder.log 2>&1 & -encoder_pid=$! 
- -CUDA_VISIBLE_DEVICES=1 python -m lightx2v.disagg.examples.run_service \ - --service transformer \ - --model_cls wan2.1 \ - --task t2v \ - --model_path ${model_path} \ - --config_json ${transformer_cfg} \ - --seed ${seed} \ - --prompt "${prompt}" \ - --negative_prompt "${negative_prompt}" \ - --save_result_path ${save_result_path} \ - > ${lightx2v_path}/save_results/disagg_transformer.log 2>&1 & -transformer_pid=$! - -CUDA_VISIBLE_DEVICES=2 python -m lightx2v.disagg.examples.run_service \ - --service decoder \ - --model_cls wan2.1 \ - --task t2v \ - --model_path ${model_path} \ - --config_json ${decoder_cfg} \ - --seed ${seed} \ - --prompt "${prompt}" \ - --negative_prompt "${negative_prompt}" \ - --save_result_path ${save_result_path} \ - > ${lightx2v_path}/save_results/disagg_decoder.log 2>&1 & -decoder_pid=$! +# NOTE: Kept for rollback. Controller now creates encoder/transformer/decoder internally. +# CUDA_VISIBLE_DEVICES=0 python -m lightx2v.disagg.examples.run_service \ +# --service encoder \ +# --model_cls wan2.1 \ +# --task t2v \ +# --model_path ${model_path} \ +# --config_json ${encoder_cfg} \ +# --seed ${seed} \ +# --prompt "${prompt}" \ +# --negative_prompt "${negative_prompt}" \ +# --save_result_path ${save_result_path} \ +# > ${lightx2v_path}/save_results/disagg_encoder.log 2>&1 & +# encoder_pid=$! + +# CUDA_VISIBLE_DEVICES=1 python -m lightx2v.disagg.examples.run_service \ +# --service transformer \ +# --model_cls wan2.1 \ +# --task t2v \ +# --model_path ${model_path} \ +# --config_json ${transformer_cfg} \ +# --seed ${seed} \ +# --prompt "${prompt}" \ +# --negative_prompt "${negative_prompt}" \ +# --save_result_path ${save_result_path} \ +# > ${lightx2v_path}/save_results/disagg_transformer.log 2>&1 & +# transformer_pid=$! 
+ +# CUDA_VISIBLE_DEVICES=2 python -m lightx2v.disagg.examples.run_service \ +# --service decoder \ +# --model_cls wan2.1 \ +# --task t2v \ +# --model_path ${model_path} \ +# --config_json ${decoder_cfg} \ +# --seed ${seed} \ +# --prompt "${prompt}" \ +# --negative_prompt "${negative_prompt}" \ +# --save_result_path ${save_result_path} \ +# > ${lightx2v_path}/save_results/disagg_decoder.log 2>&1 & +# decoder_pid=$! # Give background services time to flush and finish queued requests. -echo "Waiting for output videos: ${save_result_path_1}, ${save_result_path_2}" +echo "Waiting for output videos: ${output_files[*]}" wait_seconds=0 -max_wait_seconds=1200 +max_wait_seconds=$((600 * request_count)) while true; do - if [[ -f "${save_result_path_1}" && -f "${save_result_path_2}" ]]; then - echo "Both output videos are generated." + all_generated=1 + for file in "${output_files[@]}"; do + if [[ ! -f "${file}" ]]; then + all_generated=0 + break + fi + done + + if (( all_generated )); then + echo "All ${request_count} output videos are generated." 
break fi From 40557eb1f22da97165ea31804d94750295e6c31c Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Tue, 14 Apr 2026 11:11:03 +0800 Subject: [PATCH 2/9] add test --- .../disagg/wan22_i2v_distill_controller.json | 58 ++++ configs/disagg/wan22_i2v_distill_decoder.json | 58 ++++ configs/disagg/wan22_i2v_distill_encoder.json | 58 ++++ .../disagg/wan22_i2v_distill_transformer.json | 58 ++++ configs/disagg/wan22_i2v_workload_stages.json | 28 ++ lightx2v/disagg/conn.py | 78 +++-- lightx2v/disagg/examples/run_user.py | 66 ++++ lightx2v/disagg/mooncake.py | 105 ++++++- lightx2v/disagg/rdma_buffer.py | 49 ++- lightx2v/disagg/rdma_client.py | 169 ++++++++-- lightx2v/disagg/rdma_server.py | 27 +- lightx2v/disagg/services/base.py | 13 + lightx2v/disagg/services/controller.py | 276 ++++++++++++---- lightx2v/disagg/services/decoder.py | 86 ++++- lightx2v/disagg/services/encoder.py | 44 ++- lightx2v/disagg/services/transformer.py | 150 +++++++-- lightx2v/disagg/workload.py | 294 ++++++++++++++++++ scripts/disagg/extract_dynamic_latency.py | 100 ++++++ scripts/disagg/kill_service.sh | 30 +- scripts/disagg/run_dynamic.sh | 84 +++++ scripts/disagg/run_wan22_i2v_distill.sh | 116 +++++++ scripts/disagg/run_wan_t2v_service.sh | 6 +- 22 files changed, 1782 insertions(+), 171 deletions(-) create mode 100644 configs/disagg/wan22_i2v_distill_controller.json create mode 100644 configs/disagg/wan22_i2v_distill_decoder.json create mode 100644 configs/disagg/wan22_i2v_distill_encoder.json create mode 100644 configs/disagg/wan22_i2v_distill_transformer.json create mode 100644 configs/disagg/wan22_i2v_workload_stages.json create mode 100644 lightx2v/disagg/examples/run_user.py create mode 100644 lightx2v/disagg/workload.py create mode 100644 scripts/disagg/extract_dynamic_latency.py create mode 100644 scripts/disagg/run_dynamic.sh create mode 100755 scripts/disagg/run_wan22_i2v_distill.sh diff --git a/configs/disagg/wan22_i2v_distill_controller.json 
b/configs/disagg/wan22_i2v_distill_controller.json new file mode 100644 index 000000000..a7b124e87 --- /dev/null +++ b/configs/disagg/wan22_i2v_distill_controller.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "disagg_mode": "controller", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} \ 
No newline at end of file diff --git a/configs/disagg/wan22_i2v_distill_decoder.json b/configs/disagg/wan22_i2v_distill_decoder.json new file mode 100644 index 000000000..37f99e25d --- /dev/null +++ b/configs/disagg/wan22_i2v_distill_decoder.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "disagg_mode": "decoder", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + 
"local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} \ No newline at end of file diff --git a/configs/disagg/wan22_i2v_distill_encoder.json b/configs/disagg/wan22_i2v_distill_encoder.json new file mode 100644 index 000000000..d90f17120 --- /dev/null +++ b/configs/disagg/wan22_i2v_distill_encoder.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "disagg_mode": "encoder", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + 
"transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} \ No newline at end of file diff --git a/configs/disagg/wan22_i2v_distill_transformer.json b/configs/disagg/wan22_i2v_distill_transformer.json new file mode 100644 index 000000000..f6e467a72 --- /dev/null +++ b/configs/disagg/wan22_i2v_distill_transformer.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "disagg_mode": "transformer", + "disagg_config": { + 
"bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} \ No newline at end of file diff --git a/configs/disagg/wan22_i2v_workload_stages.json b/configs/disagg/wan22_i2v_workload_stages.json new file mode 100644 index 000000000..19787956e --- /dev/null +++ b/configs/disagg/wan22_i2v_workload_stages.json @@ -0,0 +1,28 @@ +[ + { + "name": "warmup", + "duration_s": 120, + "user_count": 1, + "spawn_rate": 0.1, + "wait_time_s": 0.0, + "config_variants": [ + { + "infer_steps": 1, + "sample_shift": 5.0 + } + ] + }, + { + "name": "change", + "duration_s": 180, + "user_count": 1, + "spawn_rate": 0.1, + "wait_time_s": 0.0, + "config_variants": [ + { + "infer_steps": 1, + "sample_shift": 5.0 + } + ] + } +] diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py index 90ee69909..ca63cf593 100644 --- a/lightx2v/disagg/conn.py +++ b/lightx2v/disagg/conn.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import struct import threading from collections.abc import Mapping @@ -81,6 +82,14 @@ class DataPoll: DATARECEIVER_POLLING_PORT = 27788 +def _normalize_loopback_host(host: str) -> str: + normalized = (host or "").strip() + if os.getenv("DISAGG_FORCE_IPV4_LOOPBACK", "1") not in ("0", "false", "False"): + if normalized in ("localhost", "::1", ""): + return "127.0.0.1" + return normalized or "127.0.0.1" + + class DataManager: # TODO: make it general and support multiple transfer backend before merging def __init__(self, disaggregation_phase: DisaggregationPhase, disaggregation_mode: DisaggregationMode): @@ -146,18 +155,25 @@ def transfer_loop(): sender_data_ptrs = self.request_pool.pop(pending_room) self.sync_status_to_transformer_endpoint(endpoint, pending_room) - ret = self.send_data( - pending_room, - mooncake_session_id, - sender_data_ptrs, - 
receiver_ptrs, - ) + try: + ret = self.send_data( + pending_room, + mooncake_session_id, + sender_data_ptrs, + receiver_ptrs, + ) + except Exception: + logger.exception("Transfer loop exception room=%s session=%s", pending_room, mooncake_session_id) + ret = -1 with self.pool_lock: if ret != 0: self.request_status[pending_room] = DataPoll.Failed else: self.request_status[pending_room] = DataPoll.Success - self.sync_status_to_transformer_endpoint(endpoint, pending_room) + try: + self.sync_status_to_transformer_endpoint(endpoint, pending_room) + except Exception: + logger.exception("Failed to sync final status room=%s endpoint=%s", pending_room, endpoint) self.transfer_thread = threading.Thread(target=transfer_loop, name="data-transfer-thread") self.transfer_thread.start() @@ -295,25 +311,35 @@ def send_data( # TODO: transfer data in batch if there are many tensors or large tensors, instead of sending one by one. args = self.data_args[room] tensor_num = int(len(args.data_ptrs)) + chunk_bytes = int(os.getenv("MOONCAKE_TRANSFER_CHUNK_BYTES", str(1024 * 1024))) + if chunk_bytes <= 0: + chunk_bytes = 1024 * 1024 for tensor_id in range(tensor_num): sender_addr = sender_data_ptrs[tensor_id] item_len = args.data_item_lens[tensor_id] receiver_addr = receiver_ptrs[tensor_id] - # TODO: mooncake transfer engine can do async transfer. Do async later - status = self.engine.transfer_sync( - mooncake_session_id, - sender_addr, - receiver_addr, - item_len, - ) - if status != 0: - return status + offset = 0 + remaining = int(item_len) + while remaining > 0: + transfer_len = min(chunk_bytes, remaining) + # TODO: mooncake transfer engine can do async transfer. 
Do async later + status = self.engine.transfer_sync( + mooncake_session_id, + sender_addr + offset, + receiver_addr + offset, + transfer_len, + ) + if status != 0: + return status + offset += transfer_len + remaining -= transfer_len return 0 def sync_status_to_transformer_endpoint(self, remote: str, room: int): if ":" in remote: remote = remote.split(":")[0] + remote = _normalize_loopback_host(remote) receiver_rank = self.data_args[room].receiver_engine_rank receiver_rank_port = DATARECEIVER_POLLING_PORT + receiver_rank + room * 10 self._connect("tcp://" + remote + ":" + str(receiver_rank_port)).send_multipart( @@ -335,12 +361,14 @@ def encode_thread(): try: ( endpoint, + receiver_engine_rank_raw, mooncake_session_id, bootstrap_room, transformer_ptrs, ) = room_socket.recv_multipart() except zmq.Again: continue + receiver_engine_rank = int.from_bytes(receiver_engine_rank_raw, byteorder="big") if bootstrap_room.decode("ascii") == "None": continue endpoint = endpoint.decode("ascii") @@ -348,10 +376,11 @@ def encode_thread(): bootstrap_room = int(bootstrap_room.decode("ascii")) transformer_ptrs = list(struct.unpack(f"{len(transformer_ptrs) // 8}Q", transformer_ptrs)) logger.info( - "Encoder received ZMQ: endpoint=%s session_id=%s room=%s transformer_ptrs=%s", + "Encoder received ZMQ: endpoint=%s session_id=%s room=%s receiver_engine_rank=%s transformer_ptrs=%s", endpoint, mooncake_session_id, bootstrap_room, + receiver_engine_rank, transformer_ptrs, ) with self.pool_lock: @@ -360,6 +389,8 @@ def encode_thread(): mooncake_session_id, transformer_ptrs, ) + if bootstrap_room in self.data_args: + self.data_args[bootstrap_room].receiver_engine_rank = receiver_engine_rank if self.transfer_event is not None: self.transfer_event.set() @@ -405,12 +436,14 @@ def transformer_thread(): try: ( endpoint, + receiver_engine_rank_raw, mooncake_session_id, bootstrap_room, decode_ptrs, ) = room_socket.recv_multipart() except zmq.Again: continue + receiver_engine_rank = 
int.from_bytes(receiver_engine_rank_raw, byteorder="big") if bootstrap_room.decode("ascii") == "None": continue endpoint = endpoint.decode("ascii") @@ -418,10 +451,11 @@ def transformer_thread(): bootstrap_room = int(bootstrap_room.decode("ascii")) decode_ptrs = list(struct.unpack(f"{len(decode_ptrs) // 8}Q", decode_ptrs)) logger.info( - "Transformer received ZMQ: endpoint=%s session_id=%s room=%s decode_ptrs=%s", + "Transformer received ZMQ: endpoint=%s session_id=%s room=%s receiver_engine_rank=%s decode_ptrs=%s", endpoint, mooncake_session_id, bootstrap_room, + receiver_engine_rank, decode_ptrs, ) with self.pool_lock: @@ -430,6 +464,8 @@ def transformer_thread(): mooncake_session_id, decode_ptrs, ) + if bootstrap_room in self.data_args: + self.data_args[bootstrap_room].receiver_engine_rank = receiver_engine_rank if self.transfer_event is not None: self.transfer_event.set() @@ -541,9 +577,10 @@ def __init__(self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: Option raise ValueError("bootstrap_room is required for DataReceiver") args = self.data_mgr.data_args[self.bootstrap_room] sender_rank_port = DATASENDER_POLLING_PORT + args.sender_engine_rank + self.bootstrap_room * 10 - self.sender_server_url = bootstrap_addr.split(":")[0] + ":" + str(sender_rank_port) + sender_host = _normalize_loopback_host(bootstrap_addr.split(":")[0]) + self.sender_server_url = sender_host + ":" + str(sender_rank_port) logger.info("DataReceiver sender_server_url=%s", self.sender_server_url) - self.receiver_ip = self.data_mgr.get_localhost() + self.receiver_ip = _normalize_loopback_host(self.data_mgr.get_localhost()) self.session_id = self.data_mgr.get_session_id() self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) @@ -560,6 +597,7 @@ def init(self): self._connect("tcp://" + self.sender_server_url).send_multipart( [ self.receiver_ip.encode("ascii"), + args.receiver_engine_rank.to_bytes(4, byteorder="big"), self.session_id.encode("ascii"), 
str(self.bootstrap_room).encode("ascii"), packed_data_ptrs, diff --git a/lightx2v/disagg/examples/run_user.py b/lightx2v/disagg/examples/run_user.py new file mode 100644 index 000000000..da912b5b2 --- /dev/null +++ b/lightx2v/disagg/examples/run_user.py @@ -0,0 +1,66 @@ +import argparse +import time + +from lightx2v.disagg.conn import ReqManager, REQUEST_POLLING_PORT +from lightx2v.disagg.workload import ( + DisaggLoadShape, + build_payload, + current_stage, + load_base_config, + load_stage_specs, + send_workload_end_signal, + start_workload_clock, +) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run dynamic disagg workload user and push configs to Controller") + parser.add_argument("--controller_host", type=str, default="127.0.0.1") + parser.add_argument("--controller_request_port", type=int, default=REQUEST_POLLING_PORT - 2) + parser.add_argument("--max_requests", type=int, default=0, help="0 means no hard cap") + parser.add_argument("--sleep_min_ms", type=float, default=5.0, help="minimum loop sleep in ms") + return parser + + +def main(): + args = _build_parser().parse_args() + + req_mgr = ReqManager() + stages = load_stage_specs() + base_config = load_base_config() + shape = DisaggLoadShape() + + start_workload_clock() + + sent = 0 + last_tick_ts = 0.0 + + while True: + tick = shape.tick() + if tick is None: + break + + _, spawn_rate = tick + spawn_rate = max(float(spawn_rate), 0.1) + + stage = current_stage(stages) + payload = build_payload(base_config, stage, sent) + req_mgr.send(args.controller_host, args.controller_request_port, payload) + sent += 1 + + now = time.time() + if now - last_tick_ts >= 1.0: + print(f"stage={stage.name} spawn_rate={spawn_rate:.3f} req/s sent={sent}") + last_tick_ts = now + + if args.max_requests > 0 and sent >= args.max_requests: + break + + time.sleep(max(1.0 / spawn_rate, args.sleep_min_ms / 1000.0)) + + send_workload_end_signal() + print(f"workload finished: sent={sent}, 
end signal sent") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py index 756574d5a..70cc27166 100644 --- a/lightx2v/disagg/mooncake.py +++ b/lightx2v/disagg/mooncake.py @@ -1,11 +1,36 @@ import json import logging import os +import random +import socket +import time from dataclasses import dataclass logger = logging.getLogger(__name__) +def _detect_non_loopback_ipv4() -> str | None: + # Use a UDP connect trick to discover the outbound interface IP without sending traffic. + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.connect(("8.8.8.8", 80)) + ip = sock.getsockname()[0] + sock.close() + if ip and not ip.startswith("127."): + return ip + except Exception: + pass + + try: + host_ip = socket.gethostbyname(socket.gethostname()) + if host_ip and not host_ip.startswith("127."): + return host_ip + except Exception: + pass + + return None + + @dataclass class MooncakeTransferEngineConfig: local_hostname: str @@ -29,7 +54,33 @@ def load_from_env() -> "MooncakeTransferEngineConfig": config_file_path = os.getenv("MOONCAKE_CONFIG_PATH", "/root/zht/LightX2V/configs/mooncake_config.json") if config_file_path is None: raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeTransferEngineConfig.from_file(config_file_path) + cfg = MooncakeTransferEngineConfig.from_file(config_file_path) + + env_metadata_server = os.getenv("MOONCAKE_METADATA_SERVER", "").strip() + if env_metadata_server: + cfg.metadata_server = env_metadata_server + + env_protocol = os.getenv("MOONCAKE_PROTOCOL", "").strip() + if env_protocol: + cfg.protocol = env_protocol + + env_device_name = os.getenv("MOONCAKE_DEVICE_NAME", "").strip() + if env_device_name: + cfg.device_name = env_device_name + + # Keep session IDs and metadata endpoints stable on single-node runs. + # localhost may resolve to IPv6 on some hosts while peers use IPv4. 
+ force_ipv4 = os.getenv("MOONCAKE_FORCE_IPV4_LOOPBACK", "1") not in ("0", "false", "False") + env_host = os.getenv("MOONCAKE_LOCAL_HOSTNAME", "").strip() + if env_host: + cfg.local_hostname = env_host + elif force_ipv4 and cfg.local_hostname in ("localhost", "::1", "127.0.0.1"): + detected = _detect_non_loopback_ipv4() + if detected is not None: + cfg.local_hostname = detected + else: + cfg.local_hostname = "127.0.0.1" + return cfg class MooncakeTransferEngine: @@ -89,11 +140,53 @@ def initialize( def transfer_sync(self, session_id: str, buffer: int, peer_buffer_address: int, length: int) -> int: """Synchronously transfer data to the specified address.""" if self.engine: - ret = self.engine.transfer_sync_write(session_id, buffer, peer_buffer_address, length) - if ret < 0: - logger.error("Transfer Return Error") - raise Exception("Transfer Return Error") - return ret + if os.getenv("NETWORK_LATENCY"): + latency_prob_raw = os.getenv("NETWORK_LATENCY_PROB", "0.02") + latency_sec_raw = os.getenv("NETWORK_LATENCY_SEC", "5") + try: + latency_prob = float(latency_prob_raw) + except ValueError: + latency_prob = 0.02 + # Accept either ratio (0.02) or percentage (2 / 5). 
+ if latency_prob > 1.0: + latency_prob = latency_prob / 100.0 + latency_prob = max(0.0, min(1.0, latency_prob)) + + try: + latency_sec = float(latency_sec_raw) + except ValueError: + latency_sec = 5.0 + latency_sec = max(0.0, latency_sec) + + if random.random() < latency_prob: + logger.warning( + "Simulated network latency: sleeping %.3fs before transfer_sync_write (prob=%.4f)", + latency_sec, + latency_prob, + ) + time.sleep(latency_sec) + + retry_count = int(os.getenv("MOONCAKE_TRANSFER_RETRY", "5")) + retry_backoff_s = float(os.getenv("MOONCAKE_TRANSFER_RETRY_BACKOFF_S", "0.05")) + for attempt in range(retry_count + 1): + ret = self.engine.transfer_sync_write(session_id, buffer, peer_buffer_address, length) + if ret >= 0: + return ret + + logger.warning( + "Transfer Return Error attempt=%s/%s session=%s src=0x%x dst=0x%x len=%s", + attempt + 1, + retry_count + 1, + session_id, + int(buffer), + int(peer_buffer_address), + int(length), + ) + if attempt < retry_count: + time.sleep(retry_backoff_s) + + logger.error("Transfer Return Error after retries") + return -1 return -1 def get_localhost(self): diff --git a/lightx2v/disagg/rdma_buffer.py b/lightx2v/disagg/rdma_buffer.py index 2a26a69e0..1c2b8927a 100644 --- a/lightx2v/disagg/rdma_buffer.py +++ b/lightx2v/disagg/rdma_buffer.py @@ -4,6 +4,7 @@ import json import logging import threading +import time from dataclasses import dataclass from typing import Any, Dict, Optional @@ -214,31 +215,40 @@ def _deserialize_config(self, raw_slot: bytes) -> Dict[str, Any]: raise ValueError("invalid slot payload") plen = int.from_bytes(raw_slot[:4], byteorder="little", signed=False) if plen == 0: - return {} + raise ValueError("slot payload is not committed yet") + if plen > self.slot_size - 4: + raise ValueError(f"invalid slot payload length: {plen}") data = raw_slot[4 : 4 + plen] - return json.loads(data.decode("utf-8")) + try: + return json.loads(data.decode("utf-8")) + except Exception as exc: + raise ValueError("slot 
payload is incomplete or corrupted") from exc def produce(self, config: Dict[str, Any]) -> int: """Produce one config into ring buffer and advance tail by rdma_faa.""" if self.rdma_server is None and self.rdma_client is None: raise RuntimeError("produce requires rdma_server or rdma_client") - # Reserve one slot by atomically incrementing tail. - old_tail = self._rdma_faa(self.descriptor.tail_addr, 1) + # Read current indices first, write the slot fully, then publish by advancing tail. + old_tail = self._read_remote_u64(self.descriptor.tail_addr) cur_head = self._read_remote_u64(self.descriptor.head_addr) - if (old_tail + 1) - cur_head > self.buffer_size: - # Ring full, rollback reservation. - self._rdma_faa(self.descriptor.tail_addr, -1) + if old_tail - cur_head >= self.buffer_size: raise BufferError("ring buffer is full") slot_idx = old_tail % self.buffer_size offset = self._slot_offset(slot_idx) payload = self._serialize_config(config) + payload_len_header = payload[:4] + payload_body = payload[4:] # Write payload to the selected slot (works for both server-local and client-remote paths). slot_addr = self.descriptor.slot_addr + offset self._rdma_write_bytes(slot_addr, b"\x00" * self.slot_size) - self._rdma_write_bytes(slot_addr, payload) + if payload_body: + self._rdma_write_bytes(slot_addr + 4, payload_body) + # Write length header last so consumers never parse a half-written payload. 
+ self._rdma_write_bytes(slot_addr, payload_len_header) + self._rdma_faa(self.descriptor.tail_addr, 1) logger.info("Produced config to RDMA buffer slot %d", slot_idx) return slot_idx @@ -273,10 +283,23 @@ def consume(self) -> Optional[Dict[str, Any]]: slot_idx = old_head % self.buffer_size slot_addr = self.descriptor.slot_addr + self._slot_offset(slot_idx) + max_read_retries = 5 + retry_sleep_seconds = 0.002 + last_error: Optional[Exception] = None + for _ in range(max_read_retries): + try: + raw = self._rdma_read_bytes(slot_addr, self.slot_size) + config = self._deserialize_config(raw) + logger.info("Consumed config from RDMA buffer slot %d", slot_idx) + return config + except Exception as exc: + last_error = exc + time.sleep(retry_sleep_seconds) + + # Slot still not readable after retries, rollback head so the slot can be retried later. + logger.warning("RDMA buffer slot %d read incomplete after retries, rolling back head: %s", slot_idx, last_error) try: - raw = self._rdma_read_bytes(slot_addr, self.slot_size) + self._rdma_faa(self.descriptor.head_addr, -1) except Exception as exc: - logger.warning("RDMA buffer slot read failed for slot %d: %s", slot_idx, exc) - return None - logger.info("Consumed config from RDMA buffer slot %d", slot_idx) - return self._deserialize_config(raw) + logger.warning("RDMA buffer rollback failed after incomplete slot read: %s", exc) + return None diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index 1cd26b65e..615e6e3f8 100644 --- a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -1,4 +1,6 @@ import json +import os +import random import socket import threading import time @@ -43,8 +45,13 @@ class AccessFlag: class RDMAClient: def __init__(self, iface_name=None, local_buffer_size=4096): self.local_psn = 654321 + self._next_psn = (int(time.time() * 1000000) & 0xFFFFFF) or 1 self.port_num = 1 - self.gid_index = 1 + self.gid_index = 0 + if iface_name is None: + env_iface = 
os.getenv("RDMA_IFACE", "").strip() + if env_iface: + iface_name = env_iface if iface_name is None: devices = get_device_list() if not devices: @@ -55,11 +62,12 @@ def __init__(self, iface_name=None, local_buffer_size=4096): self.ctx = IBDevice(iface_name).open() self.pd = PD(self.ctx) self.cq = CQ(self.ctx, 10) + self.gid_index = self._resolve_gid_index() qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) - qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) - qa = QPAttr(port_num=self.port_num) - self.qp = QP(self.pd, qia, qa) + self._qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) + self._qa = QPAttr(port_num=self.port_num) + self.qp = QP(self.pd, self._qia, self._qa) # 客户端也需要注册内存,用于发送数据的源 (Write) 或接收数据的目标 (Read) self.buffer_size = int(local_buffer_size) @@ -68,6 +76,69 @@ def __init__(self, iface_name=None, local_buffer_size=4096): self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) self._io_lock = threading.RLock() + def _resolve_gid_index(self): + env_gid = os.getenv("RDMA_GID_INDEX", "").strip() + if env_gid: + idx = int(env_gid) + self.ctx.query_gid(port_num=self.port_num, index=idx) + return idx + + # Prefer IPv4-mapped RoCE entries for Ethernet-based RDMA devices. + preferred = [2, 0, 1, 3, 4, 5, 6, 7] + for idx in preferred: + try: + gid = str(self.ctx.query_gid(port_num=self.port_num, index=idx)) + except Exception: + continue + if gid and gid != "::": + return idx + + # Last resort: let query_gid raise a descriptive error for index 0. 
+ self.ctx.query_gid(port_num=self.port_num, index=0) + return 0 + + def _alloc_local_psn(self): + self._next_psn = (self._next_psn + 1) & 0xFFFFFF + if self._next_psn == 0: + self._next_psn = 1 + self.local_psn = self._next_psn + return self.local_psn + + def _reset_qp(self): + old_qp = getattr(self, "qp", None) + self.qp = QP(self.pd, self._qia, self._qa) + if old_qp is not None: + close_fn = getattr(old_qp, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception: + pass + + def _recv_json(self, sock, timeout_sec): + decoder = json.JSONDecoder() + chunks = [] + deadline = time.time() + timeout_sec + while time.time() < deadline: + try: + chunk = sock.recv(4096) + except socket.timeout: + continue + + if not chunk: + break + + chunks.append(chunk) + payload = b"".join(chunks).decode("utf-8", errors="strict") + try: + obj, _ = decoder.raw_decode(payload) + return obj + except json.JSONDecodeError: + continue + + msg = b"".join(chunks).decode("utf-8", errors="ignore") + raise RuntimeError(f"Timed out waiting for complete handshake JSON. payload={msg!r}") + def _ensure_local_mr_capacity(self, required_size: int): required = int(required_size) if required <= self.buffer_size: @@ -76,29 +147,73 @@ def _ensure_local_mr_capacity(self, required_size: int): self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) def connect_to_server(self, server_ip="127.0.0.1", port=5566): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect((server_ip, port)) - - # 1. 接收 Server 信息 (包含 rkey 和 addr) - data = sock.recv(4096) - self.remote_info = json.loads(data.decode()) - print(f"[Client] Got Server Info: Addr={hex(self.remote_info['addr'])}, RKey={self.remote_info['rkey']}") - - # 2. 
发送我的信息给 Server - gid = self.ctx.query_gid(port_num=self.port_num, index=self.gid_index) - my_info = { - "lid": self.ctx.query_port(port_num=self.port_num).lid, - "qpn": self.qp.qp_num, - "psn": self.local_psn, - "gid": str(gid), - "gid_index": self.gid_index, - } - sock.sendall(json.dumps(my_info).encode()) - - # 3. 修改 QP 状态 - self._modify_qp_to_rts() - self.sock = sock - print("[Client] Connection established (RTS)") + max_retries = max(1, int(os.getenv("RDMA_CLIENT_CONNECT_RETRIES", "30"))) + connect_timeout_sec = float(os.getenv("RDMA_CLIENT_CONNECT_TIMEOUT_SEC", "2.0")) + backoff_base_sec = float(os.getenv("RDMA_CLIENT_BACKOFF_BASE_SEC", "0.1")) + backoff_max_sec = float(os.getenv("RDMA_CLIENT_BACKOFF_MAX_SEC", "2.0")) + jitter_ratio = float(os.getenv("RDMA_CLIENT_BACKOFF_JITTER", "0.2")) + + last_exc = None + for attempt in range(1, max_retries + 1): + sock = None + try: + self._reset_qp() + self._alloc_local_psn() + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(connect_timeout_sec) + sock.connect((server_ip, port)) + + # 1. 接收 Server 信息 (包含 rkey 和 addr) + remote_info = self._recv_json(sock, timeout_sec=connect_timeout_sec) + if not isinstance(remote_info, dict): + raise RuntimeError(f"Invalid handshake payload type: {type(remote_info)}") + required_keys = {"addr", "rkey", "qpn", "psn", "gid"} + missing = required_keys.difference(remote_info.keys()) + if missing: + raise RuntimeError(f"Handshake missing keys: {sorted(missing)}") + self.remote_info = remote_info + print(f"[Client] Got Server Info: Addr={hex(int(self.remote_info['addr']))}, RKey={self.remote_info['rkey']}") + + # 2. 发送我的信息给 Server + gid = self.ctx.query_gid(port_num=self.port_num, index=self.gid_index) + my_info = { + "lid": self.ctx.query_port(port_num=self.port_num).lid, + "qpn": self.qp.qp_num, + "psn": self.local_psn, + "gid": str(gid), + "gid_index": self.gid_index, + } + sock.sendall(json.dumps(my_info).encode()) + + # 3. 
修改 QP 状态 + self._modify_qp_to_rts() + sock.settimeout(None) + self.sock = sock + print(f"[Client] Connection established (RTS) to {server_ip}:{port} at attempt {attempt}/{max_retries}") + return + except Exception as exc: + last_exc = exc + if sock is not None: + try: + sock.close() + except Exception: + pass + + if attempt < max_retries: + backoff = min(backoff_max_sec, backoff_base_sec * (2 ** (attempt - 1))) + if jitter_ratio > 0: + jitter = random.uniform(1.0 - jitter_ratio, 1.0 + jitter_ratio) + backoff = max(0.01, backoff * jitter) + print( + f"[Client] Handshake attempt {attempt}/{max_retries} failed to {server_ip}:{port}: {exc}. " + f"Retrying in {backoff:.2f}s" + ) + time.sleep(backoff) + + raise RuntimeError( + f"RDMA client failed to connect to {server_ip}:{port} after {max_retries} attempts" + ) from last_exc def _modify_qp_to_rts(self): # Follow the standard RC flow: INIT -> RTR -> RTS. diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index 2b2658d33..bfe0d0c7c 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -1,4 +1,5 @@ import json +import os import socket import threading @@ -39,10 +40,14 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): self.local_psn = 123456 self._next_psn = int(self.local_psn) self.port_num = port_num - self.gid_index = 1 + self.gid_index = 0 self.buffer_size = int(buffer_size) if self.buffer_size <= 0: raise ValueError("buffer_size must be positive") + if iface_name is None: + env_iface = os.getenv("RDMA_IFACE", "").strip() + if env_iface: + iface_name = env_iface if iface_name is None: devices = get_device_list() if not devices: @@ -60,6 +65,7 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): self.pd = PD(self.ctx) self.cq = CQ(self.ctx, 10) + self.gid_index = self._resolve_gid_index() # 创建 QP (Queue Pair) qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) @@ -91,6 +97,25 @@ def 
__init__(self, iface_name=None, port_num=1, buffer_size=4096): self._mr_addr = int(mr_addr) print(f"[Server] MR Registered. Addr: {mr_addr}, RKey: {self.mr.rkey}") + def _resolve_gid_index(self): + env_gid = os.getenv("RDMA_GID_INDEX", "").strip() + if env_gid: + idx = int(env_gid) + self.ctx.query_gid(port_num=self.port_num, index=idx) + return idx + + preferred = [2, 0, 1, 3, 4, 5, 6, 7] + for idx in preferred: + try: + gid = str(self.ctx.query_gid(port_num=self.port_num, index=idx)) + except Exception: + continue + if gid and gid != "::": + return idx + + self.ctx.query_gid(port_num=self.port_num, index=0) + return 0 + def register_memory(self, addr: int, length: int): """Validate a requested sub-region against server MR and return registration metadata. diff --git a/lightx2v/disagg/services/base.py b/lightx2v/disagg/services/base.py index ca2df2b0f..7aa807ff3 100644 --- a/lightx2v/disagg/services/base.py +++ b/lightx2v/disagg/services/base.py @@ -13,3 +13,16 @@ def __init__(self): """ self.logger = logger self.logger.info(f"Initializing {self.__class__.__name__}") + + def _sync_runtime_config(self, config): + current_config = getattr(self, "config", None) + if current_config is None: + self.config = dict(config) + return self.config + + if current_config is not config: + current_config.clear() + current_config.update(config) + + self.config = current_config + return self.config diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index f92c491df..1e39b74c0 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -70,6 +70,19 @@ def _to_plain(self, value: Any) -> Any: return {self._to_plain(v) for v in value} return value + def _resolve_service_config_json(self, config_json: str, instance_type: str) -> str: + config_path = Path(config_json) + if config_path.is_file(): + if config_path.name.endswith("_controller.json"): + candidate = 
config_path.with_name(config_path.name.replace("_controller.json", f"_{instance_type}.json")) + if candidate.is_file(): + return str(candidate) + if config_path.name.endswith("_distill_controller.json"): + candidate = config_path.with_name(config_path.name.replace("_distill_controller.json", f"_distill_{instance_type}.json")) + if candidate.is_file(): + return str(candidate) + return config_json + def _monitor_node_from_instance_address(self, instance_address: str) -> str: host, port_str = instance_address.rsplit(":", 1) rank = int(port_str) - REQUEST_POLLING_PORT @@ -137,6 +150,7 @@ def create_instance(self, instance_type: str) -> str: config_json = instance_cfg.get("config_json") if not model_path or not config_json: raise RuntimeError("model_path and config_json are required to launch service subprocess") + service_config_json = self._resolve_service_config_json(str(config_json), instance_type) cmd = [ sys.executable, @@ -153,7 +167,7 @@ def create_instance(self, instance_type: str) -> str: "--model_path", str(model_path), "--config_json", - str(config_json), + service_config_json, "--seed", str(instance_cfg.get("seed", 42)), "--prompt", @@ -337,6 +351,79 @@ def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): need_bytes_phase2, ) + def _build_latency_summary(self, result: dict[str, Any], controller_recv_ts: float) -> dict[str, float] | None: + request_metrics = result.get("request_metrics") + if not isinstance(request_metrics, dict): + return None + + stages = request_metrics.get("stages") + if not isinstance(stages, dict): + return None + + def _as_float(value: Any) -> float | None: + try: + return float(value) + except (TypeError, ValueError): + return None + + def _stage(name: str) -> dict[str, Any]: + stage_metrics = stages.get(name) + return stage_metrics if isinstance(stage_metrics, dict) else {} + + controller_send_ts = _as_float(request_metrics.get("controller_send_ts")) + encoder = _stage("encoder") + transformer = 
_stage("transformer") + decoder = _stage("decoder") + + encoder_recv_ts = _as_float(encoder.get("request_received_ts")) + encoder_compute_start_ts = _as_float(encoder.get("compute_start_ts")) + encoder_compute_end_ts = _as_float(encoder.get("compute_end_ts")) + encoder_output_enqueued_ts = _as_float(encoder.get("output_enqueued_ts")) + + transformer_recv_ts = _as_float(transformer.get("request_received_ts")) + transformer_compute_start_ts = _as_float(transformer.get("compute_start_ts")) + transformer_compute_end_ts = _as_float(transformer.get("compute_end_ts")) + transformer_output_enqueued_ts = _as_float(transformer.get("output_enqueued_ts")) + + decoder_recv_ts = _as_float(decoder.get("request_received_ts")) + decoder_compute_start_ts = _as_float(decoder.get("compute_start_ts")) + decoder_compute_end_ts = _as_float(decoder.get("compute_end_ts")) + decoder_output_enqueued_ts = _as_float(decoder.get("output_enqueued_ts")) + + required_values = [ + controller_send_ts, + encoder_recv_ts, + encoder_compute_start_ts, + encoder_compute_end_ts, + encoder_output_enqueued_ts, + transformer_recv_ts, + transformer_compute_start_ts, + transformer_compute_end_ts, + transformer_output_enqueued_ts, + decoder_recv_ts, + decoder_compute_start_ts, + decoder_compute_end_ts, + decoder_output_enqueued_ts, + ] + if any(value is None for value in required_values): + return None + + summary: dict[str, float] = { + "controller_to_encoder_comm_delay_s": encoder_recv_ts - controller_send_ts, + "encoder_scheduling_delay_s": encoder_compute_start_ts - encoder_recv_ts, + "encoder_compute_delay_s": encoder_compute_end_ts - encoder_compute_start_ts, + "encoder_communication_delay_s": transformer_recv_ts - encoder_output_enqueued_ts, + "transformer_scheduling_delay_s": transformer_compute_start_ts - transformer_recv_ts, + "transformer_compute_delay_s": transformer_compute_end_ts - transformer_compute_start_ts, + "transformer_communication_delay_s": decoder_recv_ts - 
transformer_output_enqueued_ts, + "decoder_scheduling_delay_s": decoder_compute_start_ts - decoder_recv_ts, + "decoder_compute_delay_s": decoder_compute_end_ts - decoder_compute_start_ts, + "decoder_communication_delay_s": controller_recv_ts - decoder_output_enqueued_ts, + "end_to_end_delay_s": controller_recv_ts - controller_send_ts, + } + summary["sum_of_components_s"] = sum(value for key, value in summary.items() if key != "end_to_end_delay_s" and key != "sum_of_components_s") + return summary + def add_instance(self, instance_type: str, instance_address: str): """Add instance address to the matching scheduling policy by type.""" if not instance_address: @@ -376,17 +463,14 @@ def send_request(self, config): self.logger.info("Request enqueued to encoder request RDMA buffer") def run(self, config): - """Initialize controller buffers, send requests, wait for decoder save_path callbacks, then exit.""" + """Initialize controller buffers, stream request configs from workload, then wait for all callbacks.""" if config is None: raise ValueError("config cannot be None") self._shutting_down = False bootstrap_addr = config.get("data_bootstrap_addr", "127.0.0.1") - encoder_engine_rank = config.get("encoder_engine_rank", 0) - transformer_engine_rank = config.get("transformer_engine_rank", 1) - decoder_engine_rank = config.get("decoder_engine_rank", 2) - request_count = int(config.get("request_count", 10)) + request_ingress_port = int(config.get("controller_request_port", os.getenv("DISAGG_CONTROLLER_REQUEST_PORT", REQUEST_POLLING_PORT - 2))) result_port = int(config.get("controller_result_port", REQUEST_POLLING_PORT - 1)) self._bootstrap_addr = str(bootstrap_addr) self._runtime_config = self._to_plain(config) @@ -397,16 +481,13 @@ def run(self, config): self.decoder_policy = RoundRobinPolicy() self._init_request_rdma_buffer(bootstrap_addr, config) + + time.sleep(5.0) for instance_type in ("encoder", "transformer", "decoder"): address = self.create_instance(instance_type) - - 
monitor_nodes = [ - f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + encoder_engine_rank}", - f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + transformer_engine_rank}", - f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + decoder_engine_rank}", - ] - self.monitor.nodes = monitor_nodes + for _ in range(5): + self.create_instance("transformer") monitor_stop_event = Event() scale_out_threshold = 80.0 @@ -419,6 +500,7 @@ def run(self, config): } def _monitor_callback(results): + return if self._shutting_down: return @@ -429,13 +511,16 @@ def _monitor_callback(results): } for item in results: - self.logger.info("monitor: %s", item) + # self.logger.info("monitor: %s", item) if not isinstance(item, dict): continue service_type = str(item.get("service_type", "")) if service_type not in {"encoder", "transformer", "decoder"}: continue + + # if service_type not in {"transformer"}: + # continue if item.get("status") != "ok": continue @@ -484,12 +569,13 @@ def _monitor_callback(results): new_address, ) except Exception as exc: - self.logger.warning( - "Auto-scale out skipped for service=%s avg_gpu_utilization=%.2f reason=%s", - service_type, - avg_gpu_utilization, - exc, - ) + pass + # self.logger.warning( + # "Auto-scale out skipped for service=%s avg_gpu_utilization=%.2f reason=%s", + # service_type, + # avg_gpu_utilization, + # exc, + # ) low_metric = min(metrics, key=lambda metric: float(metric["gpu_utilization"])) low_utilization = float(low_metric["gpu_utilization"]) @@ -529,31 +615,123 @@ def _monitor_callback(results): daemon=True, ) monitor_thread.start() + + time.sleep(5.0) base_save_path = config.get("save_path") expected_rooms: set[int] = set() received_rooms: set[int] = set() received_results: list[dict] = [] + next_room = 0 + batch_request_start_ts: float | None = None + + def _handle_decoder_result(result: Any): + if not isinstance(result, dict): + self.logger.warning("Ignored non-dict decoder result: %s", result) + return + room = 
result.get("data_bootstrap_room") + if room is None: + self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result) + return + room = int(room) + if room not in expected_rooms: + self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) + return + if room in received_rooms: + self.logger.info("Duplicate decoder result for room=%s ignored", room) + return + + controller_recv_ts = time.time() + latency_summary = self._build_latency_summary(result, controller_recv_ts) + if latency_summary is not None: + result["latency_summary"] = latency_summary + self.logger.info("Latency summary room=%s metrics=%s", room, latency_summary) + + received_rooms.add(room) + received_results.append(result) + + if result.get("ok", False): + self.logger.info( + "Decoder result received room=%s save_path=%s (%s/%s)", + room, + result.get("save_path"), + len(received_rooms), + len(expected_rooms), + ) + else: + self.logger.error( + "Decoder result failed room=%s error=%s (%s/%s)", + room, + result.get("error"), + len(received_rooms), + len(expected_rooms), + ) + + def _drain_decoder_results_non_block(): + while True: + result = self.req_mgr.receive_non_block(result_port) + if result is None: + break + _handle_decoder_result(result) + try: - for i in range(request_count): + self.logger.info("Waiting workload configs on port=%s", request_ingress_port) + while True: + workload_config = self.req_mgr.receive(request_ingress_port) + if not isinstance(workload_config, dict): + self.logger.warning("Ignored invalid workload config packet: %s", workload_config) + continue + + if workload_config.get("workload_end") or workload_config.get("end") or workload_config.get("stop"): + self.logger.info("Received workload end signal, stop accepting new configs.") + break + request_config = dict(config) - request_config["data_bootstrap_room"] = i + request_config.update(self._to_plain(workload_config)) + + room = request_config.get("data_bootstrap_room", 
next_room) + try: + room = int(room) + except (TypeError, ValueError): + room = next_room + if room in expected_rooms: + while next_room in expected_rooms: + next_room += 1 + room = next_room + next_room = max(next_room, room + 1) + + request_config["data_bootstrap_room"] = room request_config["controller_result_host"] = bootstrap_addr request_config["controller_result_port"] = result_port - if base_save_path: + + metrics = request_config.get("request_metrics") + if not isinstance(metrics, dict): + metrics = {} + metrics["request_id"] = int(metrics.get("request_id", room)) + metrics["controller_send_ts"] = time.time() + if not isinstance(metrics.get("stages"), dict): + metrics["stages"] = {} + request_config["request_metrics"] = metrics + + if base_save_path and not request_config.get("save_path"): save_path = Path(base_save_path) - request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{i}{save_path.suffix}")) - # TODO: use queue to receive request from client and dispatch, currently we just send the same request multiple times for testing + request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{room}{save_path.suffix}")) + with self._lock: current_request = request_config + + if batch_request_start_ts is None: + batch_request_start_ts = time.time() + self.send_request(current_request) self.logger.info( "Dispatched request room=%s save_path=%s", - i, + room, request_config.get("save_path"), ) + expected_rooms.add(room) - expected_rooms.add(i) + _drain_decoder_results_non_block() self.logger.info( "Waiting for decoder results: expected=%s on port=%s", @@ -562,42 +740,18 @@ def _monitor_callback(results): ) while len(received_rooms) < len(expected_rooms): result = self.req_mgr.receive(result_port) - if not isinstance(result, dict): - self.logger.warning("Ignored non-dict decoder result: %s", result) - continue - room = result.get("data_bootstrap_room") - if room is None: - self.logger.warning("Ignored decoder result without 
data_bootstrap_room: %s", result) - continue - room = int(room) - if room not in expected_rooms: - self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) - continue - if room in received_rooms: - self.logger.info("Duplicate decoder result for room=%s ignored", room) - continue - - received_rooms.add(room) - received_results.append(result) - - if result.get("ok", False): - self.logger.info( - "Decoder result received room=%s save_path=%s (%s/%s)", - room, - result.get("save_path"), - len(received_rooms), - len(expected_rooms), - ) - else: - self.logger.error( - "Decoder result failed room=%s error=%s (%s/%s)", - room, - result.get("error"), - len(received_rooms), - len(expected_rooms), - ) + _handle_decoder_result(result) self.logger.info("All decoder results received. Controller exiting.") + if batch_request_start_ts is None: + batch_request_start_ts = time.time() + batch_total_time_s = time.time() - batch_request_start_ts + self.logger.info( + "Batch total elapsed time: requests=%s completed=%s total_time_s=%.3f", + len(expected_rooms), + len(received_rooms), + batch_total_time_s, + ) finally: self._shutting_down = True monitor_stop_event.set() diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index b558470fa..516986ed5 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -1,5 +1,7 @@ import hashlib import json +import math +import os import threading import time from collections import deque @@ -61,6 +63,7 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() def _get_queue_metrics(self) -> dict[str, Any]: @@ -114,7 +117,10 @@ def _ensure_phase2_request_buffer(self) -> bool: return True def init(self, config): - self.config = config + self._sync_runtime_config(config) + self.encoder_engine_rank = 
int(self.config.get("encoder_engine_rank", self.encoder_engine_rank)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", self.transformer_engine_rank)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", self.decoder_engine_rank)) shared_slots = int(self.config.get("rdma_buffer_slots", self._phase2_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._phase2_server_ip = str(self.config.get("rdma_phase2_host", self._phase2_server_ip)) @@ -181,6 +187,9 @@ def alloc_memory(self, request: AllocationRequest) -> MemoryHandle: def process(self, config): self.logger.info("Starting processing in DecoderService...") room = config.get("data_bootstrap_room", 0) + decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) + decoder_metrics["compute_start_ts"] = time.time() + strict_meta_hash_check = str(os.getenv("LIGHTX2V_STRICT_META_HASH", "0")).strip().lower() in {"1", "true", "yes", "on"} room_buffers = self._rdma_buffers.get(room) receiver = self.data_receiver.get(room) @@ -207,15 +216,58 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: raise RuntimeError("Phase2 RDMA buffers require [latents, meta] entries.") meta_buf = room_buffers[1] - meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() - meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8") if meta_bytes else "" - if not meta_str: - raise ValueError("missing latents metadata from transformer") - meta = json.loads(meta_str) + def _read_phase2_meta() -> tuple[dict, str]: + meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() + meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8", errors="ignore") if meta_bytes else "" + if not meta_str: + raise ValueError("missing latents metadata from transformer") + parsed = json.loads(meta_str) + if 
not isinstance(parsed, dict): + raise ValueError(f"phase2 metadata type mismatch: {type(parsed)}") + return parsed, meta_str + + def _infer_latents_shape_from_config() -> tuple[int, int, int, int]: + z_dim = int(config.get("vae_z_dim", 16)) + vae_stride = config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + target_video_length = int(config.get("target_video_length", 81)) + target_height = int(config.get("target_height", 480)) + target_width = int(config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(math.ceil(target_height / stride_h)) + w_prime = int(math.ceil(target_width / stride_w)) + return (z_dim, t_prime, h_prime, w_prime) + + meta = None + meta_str = "" + for attempt in range(3): + try: + meta, meta_str = _read_phase2_meta() + break + except Exception as exc: + if attempt < 2: + # Guard against rare stale/partial metadata visibility. + time.sleep(0.02) + continue + self.logger.warning( + "Invalid phase2 metadata for room=%s, fallback to config-derived shape. 
err=%s raw_prefix=%r", + room, + exc, + meta_str[:128], + ) + meta = { + "latents_shape": list(_infer_latents_shape_from_config()), + "latents_dtype": str(GET_DTYPE()), + "latents_hash": None, + } latents_shape_val = meta.get("latents_shape") if not isinstance(latents_shape_val, list) or len(latents_shape_val) != 4: - raise ValueError("invalid latents_shape in phase2 metadata") + latents_shape_val = list(_infer_latents_shape_from_config()) + self.logger.warning("phase2 metadata missing/invalid latents_shape for room=%s, using fallback shape=%s", room, latents_shape_val) latent_shape = tuple(int(value) for value in latents_shape_val) dtype_map = { @@ -229,7 +281,10 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: if list(latents.shape) != meta.get("latents_shape"): raise ValueError("latents shape mismatch between transformer and decoder") if meta.get("latents_hash") is not None and _sha256_tensor(latents) != meta.get("latents_hash"): - raise ValueError("latents hash mismatch between transformer and decoder") + msg = "latents hash mismatch between transformer and decoder" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) latents = latents.to(torch.device(AI_DEVICE)).contiguous() if self.vae_decoder is None: @@ -238,6 +293,7 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: self.logger.info("Decoding latents in DecoderService...") gen_video = self.vae_decoder.decode(latents.to(GET_DTYPE())) gen_video_final = wan_vae_to_comfy(gen_video) + decoder_metrics["compute_end_ts"] = time.time() save_path = config.get("save_path") if save_path is None: @@ -245,6 +301,7 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: self.logger.info(f"Saving video to {save_path}...") save_to_video(gen_video_final, save_path, fps=config.get("fps", 16), method="ffmpeg") + decoder_metrics["output_enqueued_ts"] = time.time() 
self.logger.info("Done!") return save_path @@ -309,6 +366,11 @@ def run(self, stop_event=None): config["transformer_node_address"] = packet.get("transformer_node_address", "127.0.0.1") else: config = packet + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete phase2 packet from RDMA buffer: %s", packet) + continue + decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) + decoder_metrics["request_received_ts"] = time.time() self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) req_queue.append(config) @@ -324,7 +386,11 @@ def run(self, stop_event=None): ready_rooms: List[int] = [] failed_rooms: List[int] = [] - for room in list(waiting_queue.keys()): + waiting_rooms = list(waiting_queue.keys()) + if self.sync_comm and waiting_rooms: + waiting_rooms = [waiting_rooms[0]] + + for room in waiting_rooms: receiver = self.data_receiver.get(room) if receiver is None: failed_rooms.append(room) @@ -359,6 +425,7 @@ def run(self, stop_event=None): "ok": True, "data_bootstrap_room": int(room), "save_path": save_path, + "request_metrics": config.get("request_metrics"), }, ) except Exception: @@ -374,6 +441,7 @@ def run(self, stop_event=None): "data_bootstrap_room": int(room), "save_path": None, "error": "decoder process failed", + "request_metrics": config.get("request_metrics"), }, ) finally: diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 057dde5cc..c75bac55a 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -1,5 +1,6 @@ import hashlib import json +import os import threading import time from collections import deque @@ -76,8 +77,18 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") 
self.load_models() + def _wait_sender_success(self, room: int, sender: DataSender): + while True: + status = sender.poll() + if status == DataPoll.Success: + return + if status == DataPoll.Failed: + raise RuntimeError(f"DataSender transfer failed for room={room}") + time.sleep(0.001) + def _get_queue_metrics(self) -> dict[str, Any]: with self._queue_metrics_lock: queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) @@ -185,7 +196,7 @@ def _ensure_phase1_meta_buffer(self) -> bool: return True def init(self, config): - self.config = config + self._sync_runtime_config(config) shared_slots = int(self.config.get("rdma_buffer_slots", self._request_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._request_server_ip = str(self.config.get("rdma_request_host", self._request_server_ip)) @@ -238,13 +249,6 @@ def init(self, config): self.data_mgr.init(data_args, data_bootstrap_room) self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr, data_bootstrap_addr, data_bootstrap_room) - phase1_meta = { - "request_config": dict(self.config), - "encoder_node_address": self.data_mgr.get_localhost(), - "encoder_session_id": self.data_mgr.get_session_id(), - } - self._phase1_rdma_buffer.produce(phase1_meta) - def load_models(self): self.logger.info("Loading Encoder Models...") @@ -345,6 +349,8 @@ def process(self, config): """ self.logger.info("Starting processing in EncoderService...") room = int(config.get("data_bootstrap_room", 0)) + encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) + encoder_metrics["compute_start_ts"] = time.time() room_buffers = self._rdma_buffers.get(room) sender = self.data_sender.get(room) @@ -411,6 +417,7 @@ def process(self, config): else: raise ValueError(f"Unsupported task: {task}") + encoder_metrics["compute_end_ts"] = time.time() self.logger.info("Encode processing completed. 
Preparing to send data...") if self.data_mgr is not None and sender is not None: @@ -499,7 +506,19 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: meta_buf[: len(meta_bytes)].copy_(torch.from_numpy(np.frombuffer(meta_bytes, dtype=np.uint8))) buffer_ptrs = [buf.data_ptr() for buf in room_buffers] + # Publish phase1 request metadata after compute so downstream can see latest metrics. + encoder_metrics["output_enqueued_ts"] = time.time() + phase1_meta = { + "request_config": dict(config), + "encoder_node_address": self.data_mgr.get_localhost(), + "encoder_session_id": self.data_mgr.get_session_id(), + } + if self._phase1_rdma_buffer is None: + raise RuntimeError("phase1 RDMA buffer is not ready") + self._phase1_rdma_buffer.produce(phase1_meta) sender.send(buffer_ptrs) + if self.sync_comm: + self._wait_sender_success(room, sender) def release_memory(self, room: int): """ @@ -561,6 +580,11 @@ def run(self, stop_event=None): if self._request_rdma_buffer is not None: config = self._request_rdma_buffer.consume() if config is not None: + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete request packet from RDMA buffer: %s", config) + continue + encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) + encoder_metrics["request_received_ts"] = time.time() self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) req_queue.append(config) @@ -591,6 +615,10 @@ def run(self, stop_event=None): completed_rooms.append(room) continue + if self.sync_comm: + completed_rooms.append(room) + continue + status = sender.poll() if status == DataPoll.Success: completed_rooms.append(room) diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 33e98e5f1..338a0d285 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -1,5 
+1,6 @@ import hashlib import json +import os import threading import time from collections import deque @@ -74,8 +75,18 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() + def _wait_sender_success(self, room: int, sender: DataSender): + while True: + status = sender.poll() + if status == DataPoll.Success: + return + if status == DataPoll.Failed: + raise RuntimeError(f"DataSender transfer failed for room={room}") + time.sleep(0.001) + def _get_queue_metrics(self) -> dict[str, Any]: with self._queue_metrics_lock: queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) @@ -163,7 +174,7 @@ def _ensure_phase2_meta_buffer(self) -> bool: return True def init(self, config): - self.config = config + self._sync_runtime_config(config) shared_slots = int(self.config.get("rdma_buffer_slots", self._phase1_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._phase1_server_ip = str(self.config.get("rdma_phase1_host", self._phase1_server_ip)) @@ -175,6 +186,9 @@ def init(self, config): self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size + if self.scheduler is not None: + self.scheduler.refresh_from_config(self.config) + # Set global seed if present in config, though specific process calls might reuse it if "seed" in self.config: seed_all(self.config["seed"]) @@ -239,14 +253,6 @@ def init(self, config): self.data_mgr2.init(data_args, data_bootstrap_room) self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr2, data_bootstrap_addr, data_bootstrap_room) - self._phase2_rdma_buffer.produce( - { - "request_config": dict(self.config), - "transformer_node_address": self.data_mgr2.get_localhost(), - "transformer_session_id": self.data_mgr2.get_session_id(), - } - ) - def load_models(self): self.logger.info("Loading Transformer Models...") @@ -301,6 +307,8 @@ def 
process(self, config): """ self.logger.info("Starting processing in TransformerService...") room = config.get("data_bootstrap_room", 0) + transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) + transformer_metrics["compute_start_ts"] = time.time() phase1_buffers = self.rdma_buffer1.get(room) phase2_buffers = self.rdma_buffer2.get(room) @@ -359,11 +367,73 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: buffer_index += 1 meta_buf = phase1_buffers[buffer_index] - meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() - meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8") if meta_bytes else "" - if not meta_str: - raise ValueError("missing metadata from encoder") - meta = json.loads(meta_str) + strict_meta_hash_check = str(os.getenv("LIGHTX2V_STRICT_META_HASH", "0")).strip().lower() in {"1", "true", "yes", "on"} + + def _load_phase1_meta(max_retries: int = 3, retry_sleep_s: float = 0.02) -> dict: + last_error: Optional[Exception] = None + last_preview = "" + for attempt in range(1, max_retries + 1): + meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() + raw_payload = meta_bytes.split(b"\x00", 1)[0] if meta_bytes else b"" + if not raw_payload: + last_error = ValueError("missing metadata from encoder") + if attempt < max_retries: + time.sleep(retry_sleep_s) + continue + break + try: + meta_str = raw_payload.decode("utf-8") + except UnicodeDecodeError as err: + last_error = err + last_preview = raw_payload[:32].hex() + if attempt < max_retries: + self.logger.warning( + "Invalid phase1 metadata UTF-8 for room=%s (attempt %s/%s), retrying...", + room, + attempt, + max_retries, + ) + time.sleep(retry_sleep_s) + continue + break + + if not meta_str.strip(): + last_error = ValueError("empty metadata payload from encoder") + if attempt < max_retries: + 
time.sleep(retry_sleep_s) + continue + break + + try: + parsed = json.loads(meta_str) + except json.JSONDecodeError as err: + last_error = err + last_preview = meta_str[:120] + if attempt < max_retries: + self.logger.warning( + "Invalid phase1 metadata JSON for room=%s (attempt %s/%s), retrying...", + room, + attempt, + max_retries, + ) + time.sleep(retry_sleep_s) + continue + break + + if not isinstance(parsed, dict): + last_error = TypeError(f"phase1 metadata must be a dict, got {type(parsed).__name__}") + last_preview = str(parsed)[:120] + if attempt < max_retries: + time.sleep(retry_sleep_s) + continue + break + + return parsed + + preview_suffix = f", preview={last_preview}" if last_preview else "" + raise ValueError(f"failed to load phase1 metadata for room={room}: {last_error}{preview_suffix}") + + meta = _load_phase1_meta() meta_shapes = {k: v for k, v in meta.items() if k.endswith("_shape")} meta_dtypes = {k: v for k, v in meta.items() if k.endswith("_dtype")} self.logger.info("Transformer meta shapes: %s", meta_shapes) @@ -413,33 +483,48 @@ def _get_shape(key: str) -> tuple[int, ...]: if list(context.shape) != meta.get("context_shape"): raise ValueError("context shape mismatch between encoder and transformer") if meta.get("context_hash") is not None and _sha256_tensor(context) != meta.get("context_hash"): - raise ValueError("context hash mismatch between encoder and transformer") + msg = "context hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if enable_cfg: if context_null is not None: if list(context_null.shape) != meta.get("context_null_shape"): raise ValueError("context_null shape mismatch between encoder and transformer") if meta.get("context_null_hash") is not None: if _sha256_tensor(context_null) != meta.get("context_null_hash"): - raise ValueError("context_null hash mismatch between encoder and transformer") + msg = 
"context_null hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if task == "i2v": if clip_encoder_out is not None: if list(clip_encoder_out.shape) != meta.get("clip_shape"): raise ValueError("clip shape mismatch between encoder and transformer") if meta.get("clip_hash") is not None: if _sha256_tensor(clip_encoder_out) != meta.get("clip_hash"): - raise ValueError("clip hash mismatch between encoder and transformer") + msg = "clip hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if vae_encoder_out is not None: if list(vae_encoder_out.shape) != meta.get("vae_shape"): raise ValueError("vae shape mismatch between encoder and transformer") if meta.get("vae_hash") is not None: if _sha256_tensor(vae_encoder_out) != meta.get("vae_hash"): - raise ValueError("vae hash mismatch between encoder and transformer") + msg = "vae hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if meta.get("latent_shape") is None or list(latent_shape) != meta.get("latent_shape"): raise ValueError("latent_shape mismatch between encoder and transformer") if meta.get("latent_hash") is not None: latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) if _sha256_tensor(latent_tensor) != meta.get("latent_hash"): - raise ValueError("latent_shape hash mismatch between encoder and transformer") + msg = "latent_shape hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) inputs = { "text_encoder_output": text_encoder_output, @@ -470,6 +555,7 @@ def _get_shape(key: str) 
-> tuple[int, ...]: self.scheduler.step_post() latents = self.scheduler.latents + transformer_metrics["compute_end_ts"] = time.time() # Send latents to DecoderService if len(phase2_buffers) < 2: @@ -501,7 +587,22 @@ def _get_shape(key: str) -> tuple[int, ...]: meta_view[: len(meta_bytes)].copy_(torch.from_numpy(np.frombuffer(meta_bytes, dtype=np.uint8))) buffer_ptrs = [buf.data_ptr() for buf in phase2_buffers] + # Publish phase2 request metadata after compute so downstream can see latest metrics. + transformer_metrics["output_enqueued_ts"] = time.time() + phase2_request_config = dict(config) + phase2_request_config["transformer_engine_rank"] = self.transformer_engine_rank + if self._phase2_rdma_buffer is None: + raise RuntimeError("phase2 RDMA buffer is not ready") + self._phase2_rdma_buffer.produce( + { + "request_config": phase2_request_config, + "transformer_node_address": self.data_mgr2.get_localhost(), + "transformer_session_id": self.data_mgr2.get_session_id(), + } + ) sender.send(buffer_ptrs) + if self.sync_comm: + self._wait_sender_success(room, sender) def release_memory(self, room: int): """ @@ -583,6 +684,11 @@ def run(self, stop_event=None): config["encoder_node_address"] = packet.get("encoder_node_address", "127.0.0.1") else: config = packet + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete phase1 packet from RDMA buffer: %s", packet) + continue + transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) + transformer_metrics["request_received_ts"] = time.time() self.logger.info("%s Received request config from RDMA buffer: %s", self.transformer_engine_rank, {k: v for k, v in config.items()}) req_queue.append(config) @@ -598,7 +704,11 @@ def run(self, stop_event=None): ready_rooms: List[int] = [] failed_rooms: List[int] = [] - for room, config in list(waiting_queue.items()): + waiting_items = list(waiting_queue.items()) + if 
self.sync_comm and waiting_items: + waiting_items = [waiting_items[0]] + + for room, config in waiting_items: receiver = self.data_receiver.get(room) if receiver is None: failed_rooms.append(room) diff --git a/lightx2v/disagg/workload.py b/lightx2v/disagg/workload.py new file mode 100644 index 000000000..82b5484fe --- /dev/null +++ b/lightx2v/disagg/workload.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import copy +import json +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +try: + from locust import LoadTestShape, User, events, task +except ModuleNotFoundError: + class _EventHook: + def add_listener(self, fn): + return fn + + def fire(self, **kwargs): + return None + + class _Events: + def __init__(self): + self.test_start = _EventHook() + self.test_stop = _EventHook() + self.request = _EventHook() + + class LoadTestShape: # type: ignore[no-redef] + pass + + class User: # type: ignore[no-redef] + pass + + def task(fn): # type: ignore[no-redef] + return fn + + events = _Events() # type: ignore[no-redef] + +from lightx2v.disagg.conn import REQUEST_POLLING_PORT, ReqManager + + +REPO_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_BASE_CONFIG_JSON = REPO_ROOT / "configs" / "disagg" / "wan22_i2v_distill_controller.json" +DEFAULT_STAGE_DEFINITIONS_JSON = REPO_ROOT / "configs" / "disagg" / "wan22_i2v_workload_stages.json" + +_TEST_START_MONOTONIC: Optional[float] = None + + +def _deep_merge(base: dict[str, Any], overlay: dict[str, Any]) -> dict[str, Any]: + merged = copy.deepcopy(base) + for key, value in overlay.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = _deep_merge(merged[key], value) + else: + merged[key] = copy.deepcopy(value) + return merged + + +def _load_base_config() -> dict[str, Any]: + config_path = os.getenv("DISAGG_BASE_CONFIG_JSON") + if config_path: + path = Path(config_path) + else: + path = 
DEFAULT_BASE_CONFIG_JSON + + if path.is_file(): + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + return { + "task": "i2v", + "model_cls": "wan2.2_moe", + "seed": 42, + "prompt": "A cinematic cat scene with detailed lighting and motion.", + "negative_prompt": "blurry, low quality, artifacts", + "save_path": str(REPO_ROOT / "save_results" / "locust_disagg.mp4"), + } + + +def _load_stage_definitions() -> list[dict[str, Any]]: + stage_file = Path(os.getenv("DISAGG_WORKLOAD_STAGES_JSON", str(DEFAULT_STAGE_DEFINITIONS_JSON))) + if not stage_file.is_file(): + raise FileNotFoundError(f"workload stage config not found: {stage_file}") + + with stage_file.open("r", encoding="utf-8") as handle: + loaded = json.load(handle) + + if not isinstance(loaded, list) or not loaded: + raise ValueError(f"{stage_file} must contain a non-empty JSON list") + + return loaded + + +@dataclass(frozen=True) +class StageSpec: + name: str + duration_s: float + user_count: int + spawn_rate: float + wait_time_s: float = 0.0 + config_variants: list[dict[str, Any]] = field(default_factory=list) + + @staticmethod + def from_dict(raw: dict[str, Any]) -> "StageSpec": + name = str(raw.get("name", "stage")) + duration_s = float(raw.get("duration_s", 0.0)) + user_count = int(raw.get("user_count", 1)) + spawn_rate = float(raw.get("spawn_rate", max(1, user_count))) + wait_time_s = float(raw.get("wait_time_s", 0.0)) + config_variants = raw.get("config_variants", []) or [] + if not isinstance(config_variants, list): + raise ValueError(f"stage {name}: config_variants must be a list") + return StageSpec( + name=name, + duration_s=max(duration_s, 0.0), + user_count=max(user_count, 1), + spawn_rate=max(spawn_rate, 0.1), + wait_time_s=max(wait_time_s, 0.0), + config_variants=[variant for variant in config_variants if isinstance(variant, dict)], + ) + + +def _load_stage_specs() -> list[StageSpec]: + return [StageSpec.from_dict(stage) for stage in _load_stage_definitions()] + + +def 
load_base_config() -> dict[str, Any]: + return _load_base_config() + + +def load_stage_specs() -> list[StageSpec]: + return _load_stage_specs() + + +def _elapsed_since_start() -> float: + if _TEST_START_MONOTONIC is None: + return 0.0 + return max(0.0, time.monotonic() - _TEST_START_MONOTONIC) + + +def _stage_index_for_elapsed(stages: list[StageSpec], elapsed_s: float) -> int: + if not stages: + return 0 + + accumulated = 0.0 + for index, stage in enumerate(stages): + accumulated += stage.duration_s + if elapsed_s < accumulated: + return index + return len(stages) - 1 + + +def _current_stage(stages: list[StageSpec]) -> StageSpec: + return stages[_stage_index_for_elapsed(stages, _elapsed_since_start())] + + +def _build_request_payload(base_config: dict[str, Any], stage: StageSpec, request_index: int) -> dict[str, Any]: + payload = copy.deepcopy(base_config) + variant = stage.config_variants[request_index % len(stage.config_variants)] if stage.config_variants else {} + payload = _deep_merge(payload, variant) + + payload.setdefault("request_metrics", {}) + payload["request_metrics"]["request_id"] = request_index + payload["request_metrics"]["client_send_ts"] = time.time() + payload["request_metrics"]["stage_name"] = stage.name + payload["request_metrics"]["load_stage"] = stage.name + + if "data_bootstrap_room" not in payload: + payload["data_bootstrap_room"] = request_index + + save_path_prefix = os.getenv("DISAGG_WORKLOAD_SAVE_PREFIX") + if save_path_prefix: + save_root = Path(save_path_prefix) + save_root.parent.mkdir(parents=True, exist_ok=True) + payload["save_path"] = str(save_root.with_name(f"{save_root.stem}_{stage.name}_{request_index}{save_root.suffix}")) + + return payload + + +def _get_controller_target() -> tuple[str, int]: + host = os.getenv("DISAGG_CONTROLLER_HOST", "127.0.0.1") + port = int(os.getenv("DISAGG_CONTROLLER_REQUEST_PORT", str(REQUEST_POLLING_PORT - 2))) + return host, port + + +def _send_to_controller(payload: dict[str, Any]) -> None: + 
host, port = _get_controller_target() + ReqManager().send(host, port, payload) + + +def start_workload_clock() -> None: + global _TEST_START_MONOTONIC + _TEST_START_MONOTONIC = time.monotonic() + + +def current_stage(stages: Optional[list[StageSpec]] = None) -> StageSpec: + loaded_stages = stages or _load_stage_specs() + return _current_stage(loaded_stages) + + +def build_payload(base_config: dict[str, Any], stage: StageSpec, request_index: int) -> dict[str, Any]: + return _build_request_payload(base_config, stage, request_index) + + +def send_workload_end_signal() -> None: + _send_to_controller( + { + "workload_end": True, + "request_metrics": { + "load_stage": "end", + "client_send_ts": time.time(), + }, + } + ) + + +@events.test_start.add_listener +def _on_test_start(environment, **kwargs): # type: ignore[override] + start_workload_clock() + + +@events.test_stop.add_listener +def _on_test_stop(environment, **kwargs): # type: ignore[override] + send_workload_end_signal() + + +class DisaggLoadShape(LoadTestShape): + """Time-based load shape for disaggregated LightX2V scenarios. + + Configure stages with DISAGG_WORKLOAD_STAGES_JSON as a JSON file path. 
Each stage supports: + - duration_s + - user_count + - spawn_rate + - wait_time_s + - config_variants + """ + + stages = _load_stage_specs() + + def tick(self): + elapsed_s = _elapsed_since_start() + total_duration_s = sum(stage.duration_s for stage in self.stages) + if total_duration_s > 0 and elapsed_s >= total_duration_s: + return None + + stage = _current_stage(self.stages) + return stage.user_count, stage.spawn_rate + + +class DisaggUser(User): + base_config = _load_base_config() + stages = _load_stage_specs() + req_mgr = ReqManager() + + def wait_time(self): # type: ignore[override] + stage = _current_stage(self.stages) + return stage.wait_time_s + + @task + def submit_request(self): + stage = _current_stage(self.stages) + request_index = int(time.time() * 1000) % 1_000_000 + payload = _build_request_payload(self.base_config, stage, request_index) + send_started = time.perf_counter() + try: + host, port = _get_controller_target() + self.req_mgr.send(host, port, payload) + events.request.fire( + request_type="zmq", + name=f"{stage.name}:config_push", + response_time=(time.perf_counter() - send_started) * 1000.0, + response_length=len(str(payload)), + exception=None, + ) + except Exception as exc: + events.request.fire( + request_type="zmq", + name=f"{stage.name}:config_push", + response_time=(time.perf_counter() - send_started) * 1000.0, + response_length=0, + exception=exc, + ) + + +__all__ = [ + "DisaggLoadShape", + "DisaggUser", + "StageSpec", + "start_workload_clock", + "current_stage", + "build_payload", + "send_workload_end_signal", + "load_base_config", + "load_stage_specs", +] \ No newline at end of file diff --git a/scripts/disagg/extract_dynamic_latency.py b/scripts/disagg/extract_dynamic_latency.py new file mode 100644 index 000000000..b9fded883 --- /dev/null +++ b/scripts/disagg/extract_dynamic_latency.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import ast +import csv +import re +from 
datetime import datetime +from pathlib import Path + +WAIT_RE = re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Waiting workload configs on port=") +LAT_RE = re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Latency summary room=(\d+) metrics=(\{.*\})") +TS_FMT = "%d %b %Y %H:%M:%S" + + +def _fmt_float3(value): + try: + return f"{float(value):.3f}" + except (TypeError, ValueError): + return value + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Extract latency summary rows relative to waiting workload log time") + parser.add_argument( + "--log", + default="/root/zht/LightX2V/save_results/disagg_wan22_i2v_dynamic_controller.log", + help="Controller log path", + ) + parser.add_argument( + "--output", + default="/root/zht/LightX2V/save_results/disagg_wan22_i2v_dynamic_results.csv", + help="Output table path", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + log_path = Path(args.log) + out_path = Path(args.output) + + if not log_path.is_file(): + raise FileNotFoundError(f"log file not found: {log_path}") + + wait_ts = None + rows = [] + metric_keys = [] + + with log_path.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + if wait_ts is None: + m_wait = WAIT_RE.match(line) + if m_wait: + wait_ts = datetime.strptime(m_wait.group(1), TS_FMT) + continue + + m_lat = LAT_RE.match(line) + if not m_lat: + continue + + ts = datetime.strptime(m_lat.group(1), TS_FMT) + room = int(m_lat.group(2)) + metrics = ast.literal_eval(m_lat.group(3)) + if not isinstance(metrics, dict): + continue + + if wait_ts is None: + rel_s = "NA" + else: + rel_s = f"{int((ts - wait_ts).total_seconds())}s" + + if not metric_keys: + metric_keys = list(metrics.keys()) + + row = { + "room": room, + "latency_summary_ts": ts.strftime("%Y-%m-%d %H:%M:%S"), + "relative_to_waiting_s": rel_s, + } + for key in metric_keys: + value = metrics.get(key) + row[key] = "" if value is 
None else _fmt_float3(value) + rows.append(row) + + out_path.parent.mkdir(parents=True, exist_ok=True) + + header = ["room", "latency_summary_ts", "relative_to_waiting_s", *metric_keys] + with out_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=header) + writer.writeheader() + for row in rows: + writer.writerow(row) + + print(f"wrote {len(rows)} rows to {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/disagg/kill_service.sh b/scripts/disagg/kill_service.sh index 3b6315890..57f79bc4d 100755 --- a/scripts/disagg/kill_service.sh +++ b/scripts/disagg/kill_service.sh @@ -2,14 +2,14 @@ set -euo pipefail -SCRIPT_NAME="run_wan_t2v_service.sh" +SCRIPT_NAME="run_wan22_i2v_distill.sh" -list_port=(5566 7788 12788 17788 27788) +list_port=(5566 12788 17788 27788) -n=10 +n=30 list_n=($(seq 0 $((n-1)))) -PORTS=(5555 12787) +PORTS=(5555 7788 7789 7790 12787) for a in "${list_port[@]}"; do for b in "${list_n[@]}"; do @@ -71,6 +71,28 @@ else echo "No running process found for ${SCRIPT_NAME}" fi +# Fallback cleanup for orphaned disagg service processes. 
+cleanup_patterns=( + "lightx2v.disagg.examples.run_service" + "python -m lightx2v.disagg" + "conda run -n lightx2v bash scripts/disagg/run_wan22_i2v_distill.sh" +) + +for pattern in "${cleanup_patterns[@]}"; do + echo "Stopping processes matching pattern: ${pattern}" + matched_pids=$(pgrep -f "$pattern" || true) + if [[ -z "${matched_pids}" ]]; then + echo "No process matched: ${pattern}" + continue + fi + + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing matched pid=$pid" + kill_pid_gracefully "$pid" + done <<< "$matched_pids" +done + for port in "${PORTS[@]}"; do echo "Stopping listeners on port ${port}" port_pids=$(find_listen_pids_by_port "$port") diff --git a/scripts/disagg/run_dynamic.sh b/scripts/disagg/run_dynamic.sh new file mode 100644 index 000000000..80120bdd2 --- /dev/null +++ b/scripts/disagg/run_dynamic.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +set -euo pipefail + +lightx2v_path=/root/zht/LightX2V +model_path=${lightx2v_path}/models/lightx2v/Wan2.2-Distill-Models + +# base.sh expects PYTHONPATH to be defined under `set -u`. 
+export PYTHONPATH=${PYTHONPATH:-} + +source ${lightx2v_path}/scripts/base/base.sh + +export CC=/usr/bin/gcc-13 +export CXX=/usr/bin/g++-13 +export CUDAHOSTCXX=/usr/bin/g++-13 +if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then + export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" +else + export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" +fi + +export RDMA_IFACE=${RDMA_IFACE:-erdma_0} +export MOONCAKE_DEVICE_NAME=${MOONCAKE_DEVICE_NAME:-eth0} +if [[ -z "${MOONCAKE_LOCAL_HOSTNAME:-}" ]]; then + _mc_ip=$(ip -4 -o addr show dev "${MOONCAKE_DEVICE_NAME}" 2>/dev/null | awk '{print $4}' | cut -d/ -f1 | head -n 1) + if [[ -n "${_mc_ip}" ]]; then + export MOONCAKE_LOCAL_HOSTNAME="${_mc_ip}" + fi +fi + +export DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST:-127.0.0.1} +export DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT:-12786} + +controller_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_controller.json +seed=${SEED:-42} +prompt=${PROMPT:-"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."} +negative_prompt=${NEGATIVE_PROMPT:-"镜头晃动,色调艳丽,过曝,静态"} +save_result_path=${SAVE_RESULT_PATH:-${lightx2v_path}/save_results/wan22_i2v_dynamic.mp4} + +controller_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_controller.log +user_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_user.log + +cleanup() { + local pids=("${user_pid:-}" "${controller_pid:-}") + for pid in "${pids[@]}"; do + if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then + kill "${pid}" 2>/dev/null || true + fi + done +} + +trap cleanup EXIT INT TERM + +python -m lightx2v.disagg.examples.run_service \ + --service controller \ + --model_cls wan2.2_moe \ + --task i2v \ + --model_path ${model_path} \ + --config_json ${controller_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${controller_log} 2>&1 & +controller_pid=$! 
+ +echo "controller started pid=${controller_pid}" +sleep 8 + +python -m lightx2v.disagg.examples.run_user \ + --controller_host "${DISAGG_CONTROLLER_HOST}" \ + --controller_request_port "${DISAGG_CONTROLLER_REQUEST_PORT}" \ + > ${user_log} 2>&1 & +user_pid=$! + +echo "run_user started pid=${user_pid}" + +wait ${user_pid} +echo "run_user finished" + +wait ${controller_pid} +echo "controller finished" + +echo "logs: ${controller_log} ${user_log}" diff --git a/scripts/disagg/run_wan22_i2v_distill.sh b/scripts/disagg/run_wan22_i2v_distill.sh new file mode 100755 index 000000000..5133f5fd6 --- /dev/null +++ b/scripts/disagg/run_wan22_i2v_distill.sh @@ -0,0 +1,116 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/root/zht/LightX2V +model_path=${lightx2v_path}/models/lightx2v/Wan2.2-Distill-Models + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# Keep flashinfer enabled while ensuring nvcc uses a supported host compiler. +export CC=/usr/bin/gcc-13 +export CXX=/usr/bin/g++-13 +export CUDAHOSTCXX=/usr/bin/g++-13 +if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then + export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" +else + export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" +fi + +# Pin disagg RDMA and Mooncake to one NIC to avoid cross-NIC session mismatch. +export RDMA_IFACE=${RDMA_IFACE:-erdma_0} +export MOONCAKE_DEVICE_NAME=${MOONCAKE_DEVICE_NAME:-eth0} +if [[ -z "${MOONCAKE_LOCAL_HOSTNAME:-}" ]]; then + _mc_ip=$(ip -4 -o addr show dev "${MOONCAKE_DEVICE_NAME}" 2>/dev/null | awk '{print $4}' | cut -d/ -f1 | head -n 1) + if [[ -n "${_mc_ip}" ]]; then + export MOONCAKE_LOCAL_HOSTNAME="${_mc_ip}" + fi +fi +echo "RDMA_IFACE=${RDMA_IFACE} MOONCAKE_DEVICE_NAME=${MOONCAKE_DEVICE_NAME} MOONCAKE_LOCAL_HOSTNAME=${MOONCAKE_LOCAL_HOSTNAME:-unset}" + +# Enable simulated network jitter by default for this test script. +# Set NETWORK_LATENCY=0 before running to disable it. 
+# export NETWORK_LATENCY=${NETWORK_LATENCY:-1} +# echo "NETWORK_LATENCY=${NETWORK_LATENCY}" + +controller_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_controller.json +encoder_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_encoder.json +transformer_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_transformer.json +decoder_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_decoder.json + +seed=42 +request_count=30 +prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds." +negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path=${lightx2v_path}/save_results/wan22_i2v_distill_disagg.mp4 +output_files=() +for ((i=0; i/dev/null; then + kill "${pid}" 2>/dev/null || true + fi + done +} + +trap cleanup EXIT INT TERM + +if [[ ! -f "${controller_cfg}" ]]; then + echo "Controller config not found: ${controller_cfg}" + exit 1 +fi + +# These are kept for manual split-service debug if needed. +if [[ ! -f "${encoder_cfg}" || ! -f "${transformer_cfg}" || ! -f "${decoder_cfg}" ]]; then + echo "One or more disagg stage configs are missing under configs/disagg" + exit 1 +fi + +python -m lightx2v.disagg.examples.run_service \ + --service controller \ + --model_cls wan2.2_moe \ + --task i2v \ + --model_path ${model_path} \ + --config_json ${controller_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${lightx2v_path}/save_results/disagg_wan22_i2v_distill_controller.log 2>&1 & +controller_pid=$! 
+ +echo "Waiting for output videos: ${output_files[*]}" +wait_seconds=0 +max_wait_seconds=$((200 * request_count)) + +while true; do + all_generated=1 + for file in "${output_files[@]}"; do + if [[ ! -f "${file}" ]]; then + all_generated=0 + break + fi + done + + if (( all_generated )); then + echo "All ${request_count} output videos are generated." + break + fi + + if (( wait_seconds >= max_wait_seconds )); then + echo "Timeout waiting for output videos after ${max_wait_seconds}s" + exit 1 + fi + + sleep 5 + wait_seconds=$((wait_seconds + 5)) +done + +sleep 60 diff --git a/scripts/disagg/run_wan_t2v_service.sh b/scripts/disagg/run_wan_t2v_service.sh index 8facbc665..640954c99 100755 --- a/scripts/disagg/run_wan_t2v_service.sh +++ b/scripts/disagg/run_wan_t2v_service.sh @@ -23,12 +23,12 @@ transformer_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_transformer.json decoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_decoder.json seed=42 -request_count=10 +request_count=30 prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" save_result_path=${lightx2v_path}/save_results/test_disagg.mp4 output_files=() -for ((i=1; i<=request_count; i++)); do +for ((i=0; i Date: Mon, 20 Apr 2026 11:18:47 +0800 Subject: [PATCH 3/9] fix scheduler bugs --- lightx2v/disagg/rdma_buffer.py | 150 ++- lightx2v/disagg/rdma_client.py | 88 +- lightx2v/disagg/rdma_server.py | 4 +- lightx2v/disagg/services/controller.py | 840 ++++++++++++---- lightx2v/disagg/services/data_mgr_sidecar.py | 975 +++++++++++++++++++ lightx2v/disagg/services/decoder.py | 33 +- lightx2v/disagg/services/encoder.py | 133 ++- lightx2v/disagg/services/transformer.py | 376 +++++-- lightx2v/models/schedulers/wan/scheduler.py | 10 + scripts/disagg/kill_service.sh | 65 +- scripts/disagg/run_dynamic.sh | 135 ++- 11 files changed, 2449 insertions(+), 360 deletions(-) create mode 100644 lightx2v/disagg/services/data_mgr_sidecar.py diff --git a/lightx2v/disagg/rdma_buffer.py b/lightx2v/disagg/rdma_buffer.py index 1c2b8927a..378e6585a 100644 --- a/lightx2v/disagg/rdma_buffer.py +++ b/lightx2v/disagg/rdma_buffer.py @@ -13,6 +13,8 @@ logger = logging.getLogger(__name__) +_U64_MASK = (1 << 64) - 1 + @dataclass class RDMABufferDescriptor: @@ -117,11 +119,15 @@ def descriptor(self) -> RDMABufferDescriptor: return self._descriptor def _write_local_u64(self, buf: bytearray, value: int): - buf[:8] = int(value).to_bytes(8, byteorder="little", signed=False) + buf[:8] = (int(value) & _U64_MASK).to_bytes(8, byteorder="little", signed=False) def _read_local_u64(self, buf: bytearray) -> int: return int.from_bytes(bytes(buf[:8]), byteorder="little", signed=False) + def _u64_distance(self, newer: int, older: int) -> int: + """Return unsigned circular distance on a 64-bit counter space.""" + return (int(newer) - int(older)) & _U64_MASK + def _rdma_faa(self, ptr_addr: int, add_value: int) -> int: if 
self.rdma_client is not None: return self.rdma_client.rdma_faa(ptr_addr, int(add_value), rkey=self.descriptor.rkey) @@ -137,14 +143,45 @@ def _rdma_faa(self, ptr_addr: int, add_value: int) -> int: with self._lock: if ptr_addr == self.descriptor.head_addr: old = self._read_local_u64(self._head_mem) - self._write_local_u64(self._head_mem, old + int(add_value)) + self._write_local_u64(self._head_mem, (old + int(add_value)) & _U64_MASK) return old if ptr_addr == self.descriptor.tail_addr: old = self._read_local_u64(self._tail_mem) - self._write_local_u64(self._tail_mem, old + int(add_value)) + self._write_local_u64(self._tail_mem, (old + int(add_value)) & _U64_MASK) return old raise RuntimeError("rdma_faa failed and no local fallback for ptr") + def _rdma_cas(self, ptr_addr: int, compare_value: int, swap_value: int) -> int: + if self.rdma_client is not None: + return self.rdma_client.rdma_cas( + ptr_addr, + int(compare_value), + int(swap_value), + rkey=self.descriptor.rkey, + ) + + if self.rdma_server is not None: + with self._lock: + old = self._read_remote_u64(ptr_addr) + if old == (int(compare_value) & _U64_MASK): + new = int(swap_value) & _U64_MASK + self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="little", signed=False)) + return old + + # Local fallback for single-process testing. 
+ with self._lock: + if ptr_addr == self.descriptor.head_addr: + old = self._read_local_u64(self._head_mem) + if old == (int(compare_value) & _U64_MASK): + self._write_local_u64(self._head_mem, int(swap_value) & _U64_MASK) + return old + if ptr_addr == self.descriptor.tail_addr: + old = self._read_local_u64(self._tail_mem) + if old == (int(compare_value) & _U64_MASK): + self._write_local_u64(self._tail_mem, int(swap_value) & _U64_MASK) + return old + raise RuntimeError("rdma_cas failed and no local fallback for ptr") + def _rdma_read_bytes(self, remote_addr: int, length: int) -> bytes: if self.rdma_server is not None and self._descriptor is not None: base = self._descriptor.head_addr @@ -232,7 +269,7 @@ def produce(self, config: Dict[str, Any]) -> int: # Read current indices first, write the slot fully, then publish by advancing tail. old_tail = self._read_remote_u64(self.descriptor.tail_addr) cur_head = self._read_remote_u64(self.descriptor.head_addr) - if old_tail - cur_head >= self.buffer_size: + if self._u64_distance(old_tail, cur_head) >= self.buffer_size: raise BufferError("ring buffer is full") slot_idx = old_tail % self.buffer_size @@ -257,49 +294,74 @@ def consume(self) -> Optional[Dict[str, Any]]: if self.role != "client": raise RuntimeError("consume is only allowed in client role") - try: - cur_head = self._read_remote_u64(self.descriptor.head_addr) - cur_tail = self._read_remote_u64(self.descriptor.tail_addr) - except Exception as exc: - return None + max_claim_retries = max(8, self.buffer_size * 2) + claim_retry_sleep_seconds = 0.001 - # Fast path: empty queue, do not touch head. - if cur_head >= cur_tail: - return None - - # Try to reserve one slot by advancing head atomically. - try: - old_head = self._rdma_faa(self.descriptor.head_addr, 1) - except Exception as exc: - return None - - if old_head >= cur_tail: - # Lost the race: rollback reservation. 
+ for _ in range(max_claim_retries): try: - self._rdma_faa(self.descriptor.head_addr, -1) - except Exception as exc: - logger.warning("RDMA buffer rollback failed on empty consume: %s", exc) - return None - - slot_idx = old_head % self.buffer_size - slot_addr = self.descriptor.slot_addr + self._slot_offset(slot_idx) - max_read_retries = 5 - retry_sleep_seconds = 0.002 - last_error: Optional[Exception] = None - for _ in range(max_read_retries): + cur_head = self._read_remote_u64(self.descriptor.head_addr) + cur_tail = self._read_remote_u64(self.descriptor.tail_addr) + except Exception: + return None + + # Fast path: empty queue, do not touch head. + if self._u64_distance(cur_tail, cur_head) == 0: + return None + + slot_idx = cur_head % self.buffer_size + slot_addr = self.descriptor.slot_addr + self._slot_offset(slot_idx) + max_read_retries = 5 + retry_sleep_seconds = 0.002 + last_error: Optional[Exception] = None + config: Optional[Dict[str, Any]] = None + + for _ in range(max_read_retries): + try: + raw = self._rdma_read_bytes(slot_addr, self.slot_size) + config = self._deserialize_config(raw) + last_error = None + break + except Exception as exc: + last_error = exc + time.sleep(retry_sleep_seconds) + + if config is None: + # Keep head unchanged so this slot can be retried later. 
+ logger.warning( + "RDMA buffer slot %d read incomplete after retries, keeping head unchanged: %s", + slot_idx, + last_error, + ) + return None + try: - raw = self._rdma_read_bytes(slot_addr, self.slot_size) - config = self._deserialize_config(raw) - logger.info("Consumed config from RDMA buffer slot %d", slot_idx) - return config + old_head = self._rdma_cas( + self.descriptor.head_addr, + cur_head, + (cur_head + 1) & _U64_MASK, + ) except Exception as exc: - last_error = exc - time.sleep(retry_sleep_seconds) + logger.warning("RDMA buffer head CAS failed for slot %d: %s", slot_idx, exc) + return None - # Slot still not readable after retries, rollback head so the slot can be retried later. - logger.warning("RDMA buffer slot %d read incomplete after retries, rolling back head: %s", slot_idx, last_error) - try: - self._rdma_faa(self.descriptor.head_addr, -1) - except Exception as exc: - logger.warning("RDMA buffer rollback failed after incomplete slot read: %s", exc) + if old_head != cur_head: + # Another consumer advanced head first; retry from latest head. 
+ time.sleep(claim_retry_sleep_seconds) + continue + + logger.info("Consumed config from RDMA buffer slot %d", slot_idx) + return config + + logger.warning("RDMA buffer consume contention is too high, skip this round") return None + + def pending_count(self) -> int: + """Return current queue length inferred from ring tail/head counters.""" + cur_head = self._read_remote_u64(self.descriptor.head_addr) + cur_tail = self._read_remote_u64(self.descriptor.tail_addr) + pending = int(self._u64_distance(cur_tail, cur_head)) + if pending <= 0: + return 0 + if pending > self.buffer_size: + return self.buffer_size + return pending diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index 615e6e3f8..fa3edeb79 100644 --- a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -1,4 +1,5 @@ import json +import logging import os import random import socket @@ -16,6 +17,9 @@ from pyverbs.wr import SendWR as WR +logger = logging.getLogger(__name__) + + class IBDevice: def __init__(self, name: str): self.name = name @@ -61,10 +65,10 @@ def __init__(self, iface_name=None, local_buffer_size=4096): self.ctx = IBDevice(iface_name).open() self.pd = PD(self.ctx) - self.cq = CQ(self.ctx, 10) + self.cq = CQ(self.ctx, 64) self.gid_index = self._resolve_gid_index() - qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) + qp_init_attr = QPCap(max_send_wr=64, max_recv_wr=64, max_send_sge=1, max_recv_sge=1) self._qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) self._qa = QPAttr(port_num=self.port_num) self.qp = QP(self.pd, self._qia, self._qa) @@ -75,6 +79,37 @@ def __init__(self, iface_name=None, local_buffer_size=4096): raise ValueError("local_buffer_size must be positive") self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) self._io_lock = threading.RLock() + self._connected_server_ip: str | None = None + self._connected_server_port: int | None = None + 
self._qp_error_state: bool = False + self._last_wc_error_message: str = "" + + def has_qp_error(self) -> bool: + return self._qp_error_state + + def last_wc_error_message(self) -> str: + return self._last_wc_error_message + + def _wc_status_name(self, status: int | None) -> str: + if status is None: + return "UNKNOWN" + status_map = { + getattr(e, "IBV_WC_SUCCESS", -1): "IBV_WC_SUCCESS", + getattr(e, "IBV_WC_LOC_LEN_ERR", -2): "IBV_WC_LOC_LEN_ERR", + getattr(e, "IBV_WC_LOC_QP_OP_ERR", -3): "IBV_WC_LOC_QP_OP_ERR", + getattr(e, "IBV_WC_LOC_PROT_ERR", -4): "IBV_WC_LOC_PROT_ERR", + getattr(e, "IBV_WC_WR_FLUSH_ERR", -5): "IBV_WC_WR_FLUSH_ERR", + getattr(e, "IBV_WC_MW_BIND_ERR", -6): "IBV_WC_MW_BIND_ERR", + getattr(e, "IBV_WC_BAD_RESP_ERR", -7): "IBV_WC_BAD_RESP_ERR", + getattr(e, "IBV_WC_LOC_ACCESS_ERR", -8): "IBV_WC_LOC_ACCESS_ERR", + getattr(e, "IBV_WC_REM_INV_REQ_ERR", -9): "IBV_WC_REM_INV_REQ_ERR", + getattr(e, "IBV_WC_REM_ACCESS_ERR", -10): "IBV_WC_REM_ACCESS_ERR", + getattr(e, "IBV_WC_REM_OP_ERR", -11): "IBV_WC_REM_OP_ERR", + getattr(e, "IBV_WC_RETRY_EXC_ERR", -12): "IBV_WC_RETRY_EXC_ERR", + getattr(e, "IBV_WC_RNR_RETRY_EXC_ERR", -13): "IBV_WC_RNR_RETRY_EXC_ERR", + getattr(e, "IBV_WC_REM_ABORT_ERR", -14): "IBV_WC_REM_ABORT_ERR", + } + return status_map.get(status, f"IBV_WC_STATUS_{status}") def _resolve_gid_index(self): env_gid = os.getenv("RDMA_GID_INDEX", "").strip() @@ -157,6 +192,14 @@ def connect_to_server(self, server_ip="127.0.0.1", port=5566): for attempt in range(1, max_retries + 1): sock = None try: + old_sock = getattr(self, "sock", None) + if old_sock is not None: + try: + old_sock.close() + except Exception: + pass + self.sock = None + self._reset_qp() self._alloc_local_psn() @@ -190,6 +233,10 @@ def connect_to_server(self, server_ip="127.0.0.1", port=5566): self._modify_qp_to_rts() sock.settimeout(None) self.sock = sock + self._connected_server_ip = str(server_ip) + self._connected_server_port = int(port) + self._qp_error_state = False + 
self._last_wc_error_message = "" print(f"[Client] Connection established (RTS) to {server_ip}:{port} at attempt {attempt}/{max_retries}") return except Exception as exc: @@ -307,6 +354,13 @@ def rdma_write_to(self, remote_addr, data_bytes, rkey=None): self.remote_info["rkey"] = int(rkey) try: self.rdma_write(data_bytes, notify_server=False) + except Exception as exc: + raise RuntimeError( + "rdma_write_to failed " + f"server={self._connected_server_ip}:{self._connected_server_port} " + f"remote_addr={int(remote_addr)} length={len(data_bytes)} " + f"rkey={self.remote_info.get('rkey')}" + ) from exc finally: self.remote_info["addr"] = old_addr self.remote_info["rkey"] = old_rkey @@ -321,6 +375,13 @@ def rdma_read_from(self, remote_addr, length, rkey=None): self.remote_info["rkey"] = int(rkey) try: return self.rdma_read(int(length)) + except Exception as exc: + raise RuntimeError( + "rdma_read_from failed " + f"server={self._connected_server_ip}:{self._connected_server_port} " + f"remote_addr={int(remote_addr)} length={int(length)} " + f"rkey={self.remote_info.get('rkey')}" + ) from exc finally: self.remote_info["addr"] = old_addr self.remote_info["rkey"] = old_rkey @@ -396,7 +457,28 @@ def _poll_cq(self): raise RuntimeError(f"Unexpected WC object: {wc}") if status != e.IBV_WC_SUCCESS: vendor_err = getattr(wc, "vendor_err", None) - raise Exception(f"WC Error: {status}, vendor_err: {vendor_err}") + wr_id = getattr(wc, "wr_id", None) + opcode = getattr(wc, "opcode", None) + status_name = self._wc_status_name(status) + self._qp_error_state = True + self._last_wc_error_message = ( + f"status={status}({status_name}) vendor_err={vendor_err} wr_id={wr_id} opcode={opcode} " + f"server={self._connected_server_ip}:{self._connected_server_port}" + ) + logger.error( + "RDMA CQ failure: status=%s(%s) vendor_err=%s wr_id=%s opcode=%s server=%s:%s", + status, + status_name, + vendor_err, + wr_id, + opcode, + self._connected_server_ip, + self._connected_server_port, + ) + raise 
RuntimeError( + "WC Error: " + f"{status}({status_name}), vendor_err: {vendor_err}, wr_id: {wr_id}, opcode: {opcode}" + ) break time.sleep(0.0001) diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index bfe0d0c7c..005ae9650 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -64,11 +64,11 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): raise RuntimeError(f"Failed to open RDMA device '{iface_name}'. Available devices: {available}") self.pd = PD(self.ctx) - self.cq = CQ(self.ctx, 10) + self.cq = CQ(self.ctx, 64) self.gid_index = self._resolve_gid_index() # 创建 QP (Queue Pair) - qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) + qp_init_attr = QPCap(max_send_wr=64, max_recv_wr=64, max_send_sge=1, max_recv_sge=1) qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) qa = QPAttr(port_num=self.port_num) self.qp = QP(self.pd, qia, qa) # RC: Reliable Connected diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 1e39b74c0..7b1366ffd 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,4 +1,6 @@ import os +import json +import signal import socket import subprocess import sys @@ -8,6 +10,8 @@ from threading import Event, Lock, Thread from typing import Any +import zmq + from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, ReqManager from lightx2v.disagg.monitor import Monitor from lightx2v.disagg.rdma_buffer import RDMABuffer @@ -42,6 +46,13 @@ def __init__(self): self._bootstrap_addr: str = "127.0.0.1" self._gpu_reuse_block_until: dict[int, float] = {} self._gpu_reuse_grace_seconds: float = 5.0 + self._graceful_reclaim_timeout_seconds: float = float(os.getenv("DISAGG_RECLAIM_GRACEFUL_TIMEOUT_SECONDS", "30.0")) + self._force_kill_wait_seconds: float = float(os.getenv("DISAGG_RECLAIM_FORCE_KILL_WAIT_SECONDS", "1.0")) + 
self._sidecar_start_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_START_TIMEOUT_SECONDS", "15.0")) + self._sidecar_drain_idle_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_IDLE_SECONDS", "1.0")) + # <= 0 means wait indefinitely until sidecar pending queues are drained. + self._sidecar_drain_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_TIMEOUT_SECONDS", "0")) + self._sidecar_reclaim_threads: list[Thread] = [] self._shutting_down: bool = False def _is_tcp_port_open(self, host: str, port: int) -> bool: @@ -58,6 +69,148 @@ def _wait_for_tcp_port_state(self, host: str, port: int, should_be_open: bool, t time.sleep(0.1) return self._is_tcp_port_open(host, port) == should_be_open + def _allocate_free_tcp_port(self) -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((self._bootstrap_addr, 0)) + return int(sock.getsockname()[1]) + + def _query_sidecar(self, req_addr: str, cmd: str) -> dict[str, Any] | None: + context = zmq.Context() + req = context.socket(zmq.REQ) + req.setsockopt(zmq.RCVTIMEO, 1000) + req.setsockopt(zmq.SNDTIMEO, 1000) + req.connect(req_addr) + try: + req.send_pyobj({"cmd": str(cmd)}) + reply = req.recv_pyobj() + if isinstance(reply, dict): + return reply + return None + except Exception: + return None + finally: + req.close(0) + context.term() + + def _start_sidecar_process(self, instance_type: str, gpu_id: int) -> dict[str, Any]: + push_port = self._allocate_free_tcp_port() + req_port = self._allocate_free_tcp_port() + push_addr = f"tcp://{self._bootstrap_addr}:{push_port}" + req_addr = f"tcp://{self._bootstrap_addr}:{req_port}" + + cmd = [ + sys.executable, + "-m", + "lightx2v.disagg.services.data_mgr_sidecar", + "--push-addr", + push_addr, + "--req-addr", + req_addr, + ] + sidecar_env = os.environ.copy() + sidecar_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + process = subprocess.Popen( + cmd, + env=sidecar_env, + start_new_session=True, + ) + + deadline = time.time() + 
self._sidecar_start_timeout_seconds + ready = False + while time.time() < deadline: + reply = self._query_sidecar(req_addr, "ping") + if isinstance(reply, dict) and reply.get("ok", False): + ready = True + break + time.sleep(0.1) + + if not ready: + if process.poll() is None: + process.terminate() + try: + process.wait(timeout=2.0) + except subprocess.TimeoutExpired: + process.kill() + raise RuntimeError(f"sidecar server failed to start for {instance_type} gpu={gpu_id}") + + self.logger.info( + "Started sidecar for %s gpu=%s pid=%s push=%s req=%s", + instance_type, + gpu_id, + process.pid, + push_addr, + req_addr, + ) + return { + "process": process, + "push_addr": push_addr, + "req_addr": req_addr, + } + + def _reclaim_sidecar_when_drained(self, instance_type: str, target_address: str, sidecar_meta: dict[str, Any]): + req_addr = str(sidecar_meta.get("req_addr", "")) + process = sidecar_meta.get("process") + if not req_addr or process is None: + return + + deadline = None + if self._sidecar_drain_timeout_seconds > 0: + deadline = time.time() + self._sidecar_drain_timeout_seconds + + while True: + if process.poll() is not None: + # Sidecar already exited. 
+ break + + reply = self._query_sidecar(req_addr, "get_stats") + if isinstance(reply, dict) and reply.get("ok", False): + data = reply.get("data") if isinstance(reply.get("data"), dict) else {} + last_message_ts = float(data.get("last_message_ts", 0.0)) + idle_seconds = max(0.0, time.time() - last_message_ts) + pending_input_watch = int(data.get("input_watch", 0)) + pending_output_watch = int(data.get("output_watch", 0)) + pending_transformer_request = int(data.get("transformer_request_pool", 0)) + pending_transformer_waiting = int(data.get("transformer_waiting_pool", 0)) + pending_transformer_active = int(data.get("transformer_active_rooms", 0)) + pending_active = ( + pending_input_watch + + pending_output_watch + + pending_transformer_request + + pending_transformer_waiting + + pending_transformer_active + ) + + if pending_active == 0 and idle_seconds >= self._sidecar_drain_idle_seconds: + break + + if deadline is not None and time.time() >= deadline: + self.logger.warning( + "Sidecar drain timeout reached for %s address=%s, forcing shutdown", + instance_type, + target_address, + ) + break + + time.sleep(0.2) + + try: + self._query_sidecar(req_addr, "shutdown") + except Exception: + pass + + if process.poll() is None: + process.terminate() + try: + process.wait(timeout=2.0) + except subprocess.TimeoutExpired: + process.kill() + + self.logger.info( + "Reclaimed sidecar for %s address=%s", + instance_type, + target_address, + ) + def _to_plain(self, value: Any) -> Any: """Recursively convert config containers (e.g. 
LockableDict) to built-in Python types.""" if isinstance(value, Mapping): @@ -83,6 +236,388 @@ def _resolve_service_config_json(self, config_json: str, instance_type: str) -> return str(candidate) return config_json + def _load_warmup_duration_seconds(self, config: Mapping[str, Any]) -> float: + stage_json = os.getenv("DISAGG_WORKLOAD_STAGES_JSON", "") + if not stage_json: + stage_json = str(config.get("workload_stages_json", "") or "").strip() + + if stage_json: + stage_file = Path(stage_json) + else: + repo_root = Path(__file__).resolve().parents[3] + stage_file = repo_root / "configs" / "disagg" / "wan22_i2v_workload_stages.json" + + if not stage_file.is_file(): + self.logger.warning("workload stages config not found, skip warmup scale guard: %s", stage_file) + return 0.0 + + try: + with stage_file.open("r", encoding="utf-8") as handle: + loaded = json.load(handle) + except Exception as exc: + self.logger.warning("failed to load workload stages config %s: %s", stage_file, exc) + return 0.0 + + if not isinstance(loaded, list): + self.logger.warning("invalid workload stages config format (expect list): %s", stage_file) + return 0.0 + + warmup_duration_s = 0.0 + for raw_stage in loaded: + if not isinstance(raw_stage, Mapping): + continue + + stage_name = str(raw_stage.get("name", "")).strip().lower() + if stage_name != "warmup": + if warmup_duration_s > 0.0: + break + continue + + try: + duration_s = float(raw_stage.get("duration_s", 0.0)) + except (TypeError, ValueError): + duration_s = 0.0 + warmup_duration_s += max(duration_s, 0.0) + + self.logger.info( + "Loaded workload warmup duration: file=%s warmup_duration_s=%.3f", + stage_file, + warmup_duration_s, + ) + return warmup_duration_s + + def _sample_rdma_queue_pending(self) -> dict[str, int]: + pending_by_service: dict[str, int] = { + "encoder": 0, + "transformer": 0, + "decoder": 0, + } + buffer_by_service = { + "encoder": self.rdma_buffer_request, + "transformer": self.rdma_buffer_phase1, + "decoder": 
self.rdma_buffer_phase2, + } + for service_type, rdma_buffer in buffer_by_service.items(): + if rdma_buffer is None: + continue + try: + pending_by_service[service_type] = int(rdma_buffer.pending_count()) + except Exception as exc: + self.logger.warning("Failed to sample RDMA pending count for %s: %s", service_type, exc) + return pending_by_service + + def _calc_precompute_pending(self, service_type: str, queue_sizes: Any) -> int: + if not isinstance(queue_sizes, dict): + return -1 + + normalized: dict[str, int] = {} + for key, value in queue_sizes.items(): + try: + normalized[str(key)] = int(value) + except (TypeError, ValueError): + continue + + if service_type == "encoder": + keys = ("req_queue", "exec_queue") + return sum(max(normalized.get(key, 0), 0) for key in keys) + + if service_type == "transformer": + direct_keys = ("req_queue", "waiting_queue", "exec_queue") + pending = sum(max(normalized.get(key, 0), 0) for key in direct_keys) + # phase1_* are pre-compute ingress queues; phase2_* are post-compute egress queues. + pending += sum(max(value, 0) for key, value in normalized.items() if key.startswith("phase1_")) + return pending + + if service_type == "decoder": + direct_keys = ("req_queue", "waiting_queue", "exec_queue") + pending = sum(max(normalized.get(key, 0), 0) for key in direct_keys) + # Decoder transfer_* represent ingress from transformer, still before decode compute. 
+ pending += sum(max(value, 0) for key, value in normalized.items() if key.startswith("transfer_")) + return pending + + return -1 + + def _monitor_callback(self, results): + monitor_runtime = getattr(self, "_monitor_runtime", None) + if self._shutting_down or not isinstance(monitor_runtime, dict): + return + + warmup_duration_s = float(monitor_runtime.get("warmup_duration_s", 0.0)) + autoscale_start_mono = float(monitor_runtime.get("autoscale_start_mono", time.monotonic())) + warmup_skip_logged = bool(monitor_runtime.get("warmup_skip_logged", False)) + warmup_end_logged = bool(monitor_runtime.get("warmup_end_logged", False)) + scale_out_threshold = float(monitor_runtime.get("scale_out_threshold", 80.0)) + scale_out_max_queue_threshold = int(monitor_runtime.get("scale_out_max_queue_threshold", 2)) + scale_in_threshold = float(monitor_runtime.get("scale_in_threshold", 20.0)) + scale_cooldown_seconds = float(monitor_runtime.get("scale_cooldown_seconds", 30.0)) + last_scale_ts = monitor_runtime.get("last_scale_ts") + if not isinstance(last_scale_ts, dict): + return + + if warmup_duration_s > 0.0: + elapsed_s = max(0.0, time.monotonic() - autoscale_start_mono) + if elapsed_s < warmup_duration_s: + if not warmup_skip_logged: + self.logger.info( + "Skip autoscaling during warmup: elapsed_s=%.3f warmup_duration_s=%.3f", + elapsed_s, + warmup_duration_s, + ) + warmup_skip_logged = True + monitor_runtime["warmup_skip_logged"] = True + return + if warmup_skip_logged and not warmup_end_logged: + self.logger.info( + "Warmup finished, autoscaling enabled: elapsed_s=%.3f warmup_duration_s=%.3f", + elapsed_s, + warmup_duration_s, + ) + warmup_end_logged = True + monitor_runtime["warmup_end_logged"] = True + + service_metrics: dict[str, list[dict[str, Any]]] = { + "encoder": [], + "transformer": [], + "decoder": [], + } + + for item in results: + self.logger.info("monitor: %s", item) + if not isinstance(item, dict): + continue + + service_type = str(item.get("service_type", "")) + 
if service_type not in {"encoder", "transformer", "decoder"}: + continue + + if service_type not in {"transformer", "decoder"}: + continue + + if item.get("status") != "ok": + continue + + try: + gpu_utilization = float(item.get("gpu_utilization", 0.0)) + except (TypeError, ValueError): + continue + + monitor_address = str(item.get("address", "")) + if not monitor_address: + continue + + queue_total_pending = item.get("queue_total_pending", None) + try: + queue_total_pending_int = int(queue_total_pending) if queue_total_pending is not None else -1 + except (TypeError, ValueError): + queue_total_pending_int = -1 + + all_queues_empty = bool(item.get("all_queues_empty", False)) + queue_sizes = item.get("queue_sizes") + precompute_pending = self._calc_precompute_pending(service_type, queue_sizes) + + service_metrics[service_type].append( + { + "gpu_utilization": gpu_utilization, + "monitor_address": monitor_address, + "queue_total_pending": queue_total_pending_int, + "all_queues_empty": all_queues_empty, + "precompute_pending": precompute_pending, + } + ) + + rdma_pending_by_service = self._sample_rdma_queue_pending() + scale_out_candidates: list[dict[str, Any]] = [] + service_queue_scores: dict[str, float] = {} + service_precompute_scores: dict[str, float] = {} + + for service_type, metrics in service_metrics.items(): + if not metrics: + continue + avg_queue_total_pending = sum(int(metric.get("queue_total_pending", 0)) for metric in metrics) / len(metrics) + rdma_queue_pending = int(rdma_pending_by_service.get(service_type, 0)) + service_queue_scores[service_type] = float(rdma_queue_pending) + float(avg_queue_total_pending) + + precompute_values = [int(metric.get("precompute_pending", -1)) for metric in metrics if int(metric.get("precompute_pending", -1)) >= 0] + if precompute_values: + avg_precompute_pending = sum(precompute_values) / len(precompute_values) + service_precompute_scores[service_type] = float(rdma_queue_pending) + float(avg_precompute_pending) + else: + 
service_precompute_scores[service_type] = float(rdma_queue_pending) + + max_precompute_score = max(service_precompute_scores.values(), default=0.0) + + for service_type, metrics in service_metrics.items(): + if not metrics: + continue + + now = time.time() + avg_gpu_utilization = sum(float(metric["gpu_utilization"]) for metric in metrics) / len(metrics) + avg_queue_total_pending = sum(int(metric.get("queue_total_pending", 0)) for metric in metrics) / len(metrics) + max_queue_total_pending = max(int(metric.get("queue_total_pending", -1)) for metric in metrics) + rdma_queue_pending = int(rdma_pending_by_service.get(service_type, 0)) + current_queue_score = float(service_queue_scores.get(service_type, 0.0)) + current_precompute_score = float(service_precompute_scores.get(service_type, 0.0)) + + scale_out_triggered = ( + avg_gpu_utilization > scale_out_threshold + or max_queue_total_pending > scale_out_max_queue_threshold + ) + + if scale_out_triggered and now - float(last_scale_ts.get(service_type, 0.0)) >= scale_cooldown_seconds: + scale_out_candidates.append( + { + "service_type": service_type, + "score": current_queue_score, + "avg_gpu_utilization": avg_gpu_utilization, + "avg_queue_total_pending": avg_queue_total_pending, + "max_queue_total_pending": max_queue_total_pending, + "rdma_queue_pending": rdma_queue_pending, + "now": now, + } + ) + + low_metric = min(metrics, key=lambda metric: float(metric["gpu_utilization"])) + low_utilization = float(low_metric["gpu_utilization"]) + low_monitor_address = str(low_metric["monitor_address"]) + with self._instance_lock: + service_instance_count = sum(1 for meta in self._managed_instances.values() if meta.get("instance_type") == service_type) + + low_precompute_pending = int(low_metric.get("precompute_pending", -1)) + if low_precompute_pending >= 0: + queues_empty_for_service = low_precompute_pending == 0 + else: + queues_empty_for_service = bool(low_metric.get("all_queues_empty", False)) and 
int(low_metric.get("queue_total_pending", -1)) == 0 + + blocked_by_queue_score = current_precompute_score > 0.0 and current_precompute_score >= max_precompute_score + + scale_in_triggered = ( + low_utilization < scale_in_threshold + and service_instance_count > 1 + and queues_empty_for_service + and now - float(last_scale_ts.get(service_type, 0.0)) >= scale_cooldown_seconds + ) + + if scale_in_triggered and blocked_by_queue_score: + self.logger.info( + "Skip scale in for highest precompute-score service: service=%s precompute_score=%.2f max_precompute_score=%.2f total_score=%.2f", + service_type, + current_precompute_score, + max_precompute_score, + current_queue_score, + ) + continue + + if scale_in_triggered: + try: + target_instance_address = self._instance_address_from_monitor_node(low_monitor_address) + self.reclaim_instance(service_type, target_instance_address) + last_scale_ts[service_type] = now + self.logger.info( + "Auto-scale in triggered: service=%s low_gpu_utilization=%.2f reclaimed_instance=%s", + service_type, + low_utilization, + target_instance_address, + ) + except Exception as exc: + self.logger.warning( + "Auto-scale in skipped for service=%s low_gpu_utilization=%.2f reason=%s", + service_type, + low_utilization, + exc, + ) + + if scale_out_candidates: + target = max( + scale_out_candidates, + key=lambda item: (item["score"], item["max_queue_total_pending"], item["avg_gpu_utilization"]), + ) + target_service = str(target["service_type"]) + if float(target["now"]) - float(last_scale_ts.get(target_service, 0.0)) < scale_cooldown_seconds: + return + try: + new_address = self.create_instance(target_service) + last_scale_ts[target_service] = float(target["now"]) + self.logger.info( + "Auto-scale out triggered: service=%s score=%.2f rdma_queue_pending=%s avg_queue_total_pending=%.2f max_queue_total_pending=%s avg_gpu_utilization=%.2f new_instance=%s", + target_service, + float(target["score"]), + int(target["rdma_queue_pending"]), + 
float(target["avg_queue_total_pending"]), + int(target["max_queue_total_pending"]), + float(target["avg_gpu_utilization"]), + new_address, + ) + except Exception: + pass + + def _handle_decoder_result( + self, + result: Any, + *, + expected_rooms: set[int], + received_rooms: set[int], + received_results: list[dict], + ): + if not isinstance(result, dict): + self.logger.warning("Ignored non-dict decoder result: %s", result) + return + room = result.get("data_bootstrap_room") + if room is None: + self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result) + return + room = int(room) + if room not in expected_rooms: + self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) + return + if room in received_rooms: + self.logger.info("Duplicate decoder result for room=%s ignored", room) + return + + controller_recv_ts = time.time() + latency_summary = self._build_latency_summary(result, controller_recv_ts) + if latency_summary is not None: + result["latency_summary"] = latency_summary + self.logger.info("Latency summary room=%s metrics=%s", room, latency_summary) + + received_rooms.add(room) + received_results.append(result) + + if result.get("ok", False): + self.logger.info( + "Decoder result received room=%s save_path=%s (%s/%s)", + room, + result.get("save_path"), + len(received_rooms), + len(expected_rooms), + ) + else: + self.logger.error( + "Decoder result failed room=%s error=%s (%s/%s)", + room, + result.get("error"), + len(received_rooms), + len(expected_rooms), + ) + + def _drain_decoder_results_non_block( + self, + *, + result_port: int, + expected_rooms: set[int], + received_rooms: set[int], + received_results: list[dict], + ): + while True: + result = self.req_mgr.receive_non_block(result_port) + if result is None: + break + self._handle_decoder_result( + result, + expected_rooms=expected_rooms, + received_rooms=received_rooms, + received_results=received_results, + ) + def 
_monitor_node_from_instance_address(self, instance_address: str) -> str: host, port_str = instance_address.rsplit(":", 1) rank = int(port_str) - REQUEST_POLLING_PORT @@ -177,9 +712,16 @@ def create_instance(self, instance_type: str) -> str: "--save_result_path", str(instance_cfg.get("save_path", "")), ] + sidecar_meta = self._start_sidecar_process(instance_type, gpu_id) env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - process = subprocess.Popen(cmd, env=env) + env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = str(sidecar_meta["push_addr"]) + env["LIGHTX2V_SIDECAR_REQ_ADDR"] = str(sidecar_meta["req_addr"]) + process = subprocess.Popen( + cmd, + env=env, + start_new_session=True, + ) monitor_port = MONITOR_POLLING_PORT + gpu_id if not self._wait_for_tcp_port_state(self._bootstrap_addr, monitor_port, should_be_open=True, timeout_seconds=8.0): @@ -189,11 +731,18 @@ def create_instance(self, instance_type: str) -> str: process.wait(timeout=3.0) except subprocess.TimeoutExpired: process.kill() + sidecar_process = sidecar_meta.get("process") + if sidecar_process is not None and sidecar_process.poll() is None: + sidecar_process.terminate() + try: + sidecar_process.wait(timeout=2.0) + except subprocess.TimeoutExpired: + sidecar_process.kill() raise RuntimeError(f"service {instance_type} on gpu={gpu_id} failed to expose monitor port {monitor_port}") instance_address = f"{self._bootstrap_addr}:{REQUEST_POLLING_PORT + gpu_id}" self._free_gpus.remove(gpu_id) - self.add_instance(instance_type, instance_address) + # self.add_instance(instance_type, instance_address) monitor_node = f"tcp://{self._bootstrap_addr}:{MONITOR_POLLING_PORT + gpu_id}" if monitor_node not in self.monitor.nodes: self.monitor.nodes.append(monitor_node) @@ -201,6 +750,7 @@ def create_instance(self, instance_type: str) -> str: "instance_type": instance_type, "gpu_id": gpu_id, "process": process, + "sidecar": sidecar_meta, } self.started_instances.append((instance_type, instance_address)) 
self.logger.info( @@ -227,25 +777,45 @@ def reclaim_instance(self, instance_type: str, instance_address: str | None = No meta = self._managed_instances.get(target_address) if meta is None: - raise RuntimeError(f"instance not managed by controller: {target_address}") + if (instance_type, target_address) in self.started_instances: + self.started_instances.remove((instance_type, target_address)) + self.logger.warning( + "Skip reclaim for already-removed %s instance address=%s", + instance_type, + target_address, + ) + return target_address if meta.get("instance_type") != instance_type: raise RuntimeError(f"instance type mismatch for {target_address}: expected={instance_type} got={meta.get('instance_type')}") process = meta.get("process") gpu_id = int(meta.get("gpu_id")) + sidecar_meta = meta.get("sidecar") if isinstance(meta.get("sidecar"), dict) else None - self.remove_instance(instance_type, target_address) + # self.remove_instance(instance_type, target_address) monitor_node = self._monitor_node_from_instance_address(target_address) - if monitor_node in self.monitor.nodes: - self.monitor.nodes.remove(monitor_node) if process is not None and process.poll() is None: - process.terminate() try: - process.wait(timeout=5.0) + os.killpg(process.pid, signal.SIGTERM) + except Exception: + process.terminate() + try: + process.wait(timeout=self._graceful_reclaim_timeout_seconds) except subprocess.TimeoutExpired: - process.kill() - process.wait(timeout=1.0) + try: + os.killpg(process.pid, signal.SIGKILL) + except Exception: + process.kill() + try: + process.wait(timeout=self._force_kill_wait_seconds) + except subprocess.TimeoutExpired as exc: + raise RuntimeError( + f"process did not exit after kill for {instance_type} instance {target_address}" + ) from exc + + if monitor_node in self.monitor.nodes: + self.monitor.nodes.remove(monitor_node) monitor_port = MONITOR_POLLING_PORT + gpu_id if not self._wait_for_tcp_port_state(self._bootstrap_addr, monitor_port, 
should_be_open=False, timeout_seconds=5.0): @@ -261,6 +831,17 @@ def reclaim_instance(self, instance_type: str, instance_address: str | None = No self._managed_instances.pop(target_address, None) if (instance_type, target_address) in self.started_instances: self.started_instances.remove((instance_type, target_address)) + + if sidecar_meta is not None: + reclaim_thread = Thread( + target=self._reclaim_sidecar_when_drained, + args=(instance_type, target_address, sidecar_meta), + name=f"sidecar-reclaim-{instance_type}-{gpu_id}", + daemon=True, + ) + reclaim_thread.start() + self._sidecar_reclaim_threads.append(reclaim_thread) + self.logger.info( "Reclaimed %s instance from gpu=%s address=%s", instance_type, @@ -476,9 +1057,9 @@ def run(self, config): self._runtime_config = self._to_plain(config) self._init_gpu_pool(config) - self.encoder_policy = RoundRobinPolicy() - self.transformer_policy = RoundRobinPolicy() - self.decoder_policy = RoundRobinPolicy() + # self.encoder_policy = RoundRobinPolicy() + # self.transformer_policy = RoundRobinPolicy() + # self.decoder_policy = RoundRobinPolicy() self._init_request_rdma_buffer(bootstrap_addr, config) @@ -490,7 +1071,12 @@ def run(self, config): self.create_instance("transformer") monitor_stop_event = Event() + warmup_duration_s = self._load_warmup_duration_seconds(config) + autoscale_start_mono = time.monotonic() + warmup_skip_logged = False + warmup_end_logged = False scale_out_threshold = 80.0 + scale_out_max_queue_threshold = 2 scale_in_threshold = 20.0 scale_cooldown_seconds = 30.0 last_scale_ts: dict[str, float] = { @@ -499,116 +1085,23 @@ def run(self, config): "decoder": 0.0, } - def _monitor_callback(results): - return - if self._shutting_down: - return - - service_metrics: dict[str, list[dict[str, Any]]] = { - "encoder": [], - "transformer": [], - "decoder": [], - } - - for item in results: - # self.logger.info("monitor: %s", item) - if not isinstance(item, dict): - continue - - service_type = 
str(item.get("service_type", "")) - if service_type not in {"encoder", "transformer", "decoder"}: - continue - - # if service_type not in {"transformer"}: - # continue - - if item.get("status") != "ok": - continue - - try: - gpu_utilization = float(item.get("gpu_utilization", 0.0)) - except (TypeError, ValueError): - continue - - monitor_address = str(item.get("address", "")) - if not monitor_address: - continue - - queue_total_pending = item.get("queue_total_pending", None) - try: - queue_total_pending_int = int(queue_total_pending) if queue_total_pending is not None else -1 - except (TypeError, ValueError): - queue_total_pending_int = -1 - - all_queues_empty = bool(item.get("all_queues_empty", False)) - - service_metrics[service_type].append( - { - "gpu_utilization": gpu_utilization, - "monitor_address": monitor_address, - "queue_total_pending": queue_total_pending_int, - "all_queues_empty": all_queues_empty, - } - ) - - for service_type, metrics in service_metrics.items(): - if not metrics: - continue - - now = time.time() - avg_gpu_utilization = sum(float(metric["gpu_utilization"]) for metric in metrics) / len(metrics) - - if avg_gpu_utilization > scale_out_threshold and now - last_scale_ts[service_type] >= scale_cooldown_seconds: - try: - new_address = self.create_instance(service_type) - last_scale_ts[service_type] = now - self.logger.info( - "Auto-scale out triggered: service=%s avg_gpu_utilization=%.2f new_instance=%s", - service_type, - avg_gpu_utilization, - new_address, - ) - except Exception as exc: - pass - # self.logger.warning( - # "Auto-scale out skipped for service=%s avg_gpu_utilization=%.2f reason=%s", - # service_type, - # avg_gpu_utilization, - # exc, - # ) - - low_metric = min(metrics, key=lambda metric: float(metric["gpu_utilization"])) - low_utilization = float(low_metric["gpu_utilization"]) - low_monitor_address = str(low_metric["monitor_address"]) - with self._instance_lock: - service_instance_count = sum(1 for meta in 
self._managed_instances.values() if meta.get("instance_type") == service_type) - - queues_empty_for_service = bool(low_metric.get("all_queues_empty", False)) and int(low_metric.get("queue_total_pending", -1)) == 0 - - if low_utilization < scale_in_threshold and service_instance_count > 1 and queues_empty_for_service and now - last_scale_ts[service_type] >= scale_cooldown_seconds: - try: - target_instance_address = self._instance_address_from_monitor_node(low_monitor_address) - self.reclaim_instance(service_type, target_instance_address) - last_scale_ts[service_type] = now - self.logger.info( - "Auto-scale in triggered: service=%s low_gpu_utilization=%.2f reclaimed_instance=%s", - service_type, - low_utilization, - target_instance_address, - ) - except Exception as exc: - self.logger.warning( - "Auto-scale in skipped for service=%s low_gpu_utilization=%.2f reason=%s", - service_type, - low_utilization, - exc, - ) + self._monitor_runtime = { + "warmup_duration_s": warmup_duration_s, + "autoscale_start_mono": autoscale_start_mono, + "warmup_skip_logged": warmup_skip_logged, + "warmup_end_logged": warmup_end_logged, + "scale_out_threshold": scale_out_threshold, + "scale_out_max_queue_threshold": scale_out_max_queue_threshold, + "scale_in_threshold": scale_in_threshold, + "scale_cooldown_seconds": scale_cooldown_seconds, + "last_scale_ts": last_scale_ts, + } monitor_thread = Thread( target=self.monitor.run_forever, kwargs={ - "interval_seconds": 5.0, - "callback": _monitor_callback, + "interval_seconds": 2.0, + "callback": self._monitor_callback, "stop_event": monitor_stop_event, }, name="controller-monitor", @@ -624,67 +1117,45 @@ def _monitor_callback(results): received_results: list[dict] = [] next_room = 0 batch_request_start_ts: float | None = None + load_from_user = str(os.getenv("LOAD_FROM_USER", "0")).strip().lower() in {"1", "true", "yes", "on"} + auto_request_count_raw = config.get("request_count", os.getenv("DISAGG_AUTO_REQUEST_COUNT", "30")) + try: + 
auto_request_count = int(auto_request_count_raw) + except (TypeError, ValueError): + self.logger.warning( + "Invalid request_count=%s, fallback to 30", + auto_request_count_raw, + ) + auto_request_count = 30 + if auto_request_count <= 0: + self.logger.warning("request_count must be positive, fallback to 30") + auto_request_count = 30 - def _handle_decoder_result(result: Any): - if not isinstance(result, dict): - self.logger.warning("Ignored non-dict decoder result: %s", result) - return - room = result.get("data_bootstrap_room") - if room is None: - self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result) - return - room = int(room) - if room not in expected_rooms: - self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) - return - if room in received_rooms: - self.logger.info("Duplicate decoder result for room=%s ignored", room) - return - - controller_recv_ts = time.time() - latency_summary = self._build_latency_summary(result, controller_recv_ts) - if latency_summary is not None: - result["latency_summary"] = latency_summary - self.logger.info("Latency summary room=%s metrics=%s", room, latency_summary) - - received_rooms.add(room) - received_results.append(result) - - if result.get("ok", False): - self.logger.info( - "Decoder result received room=%s save_path=%s (%s/%s)", - room, - result.get("save_path"), - len(received_rooms), - len(expected_rooms), - ) + try: + generated_request_count = 0 + if load_from_user: + self.logger.info("LOAD_FROM_USER enabled, waiting workload configs on port=%s", request_ingress_port) else: - self.logger.error( - "Decoder result failed room=%s error=%s (%s/%s)", - room, - result.get("error"), - len(received_rooms), - len(expected_rooms), + self.logger.info( + "LOAD_FROM_USER disabled, generating requests from config: count=%s", + auto_request_count, ) - def _drain_decoder_results_non_block(): while True: - result = self.req_mgr.receive_non_block(result_port) - if result 
is None: - break - _handle_decoder_result(result) - - try: - self.logger.info("Waiting workload configs on port=%s", request_ingress_port) - while True: - workload_config = self.req_mgr.receive(request_ingress_port) - if not isinstance(workload_config, dict): - self.logger.warning("Ignored invalid workload config packet: %s", workload_config) - continue - - if workload_config.get("workload_end") or workload_config.get("end") or workload_config.get("stop"): - self.logger.info("Received workload end signal, stop accepting new configs.") - break + if load_from_user: + workload_config = self.req_mgr.receive(request_ingress_port) + if not isinstance(workload_config, dict): + self.logger.warning("Ignored invalid workload config packet: %s", workload_config) + continue + + if workload_config.get("workload_end") or workload_config.get("end") or workload_config.get("stop"): + self.logger.info("Received workload end signal, stop accepting new configs.") + break + else: + if generated_request_count >= auto_request_count: + break + workload_config = {} + generated_request_count += 1 request_config = dict(config) request_config.update(self._to_plain(workload_config)) @@ -731,7 +1202,12 @@ def _drain_decoder_results_non_block(): ) expected_rooms.add(room) - _drain_decoder_results_non_block() + self._drain_decoder_results_non_block( + result_port=result_port, + expected_rooms=expected_rooms, + received_rooms=received_rooms, + received_results=received_results, + ) self.logger.info( "Waiting for decoder results: expected=%s on port=%s", @@ -740,7 +1216,12 @@ def _drain_decoder_results_non_block(): ) while len(received_rooms) < len(expected_rooms): result = self.req_mgr.receive(result_port) - _handle_decoder_result(result) + self._handle_decoder_result( + result, + expected_rooms=expected_rooms, + received_rooms=received_rooms, + received_results=received_results, + ) self.logger.info("All decoder results received. 
Controller exiting.") if batch_request_start_ts is None: @@ -756,9 +1237,14 @@ def _drain_decoder_results_non_block(): self._shutting_down = True monitor_stop_event.set() monitor_thread.join(timeout=2.0) + self._monitor_runtime = None for instance_type, address in reversed(list(self.started_instances)): try: self.reclaim_instance(instance_type, address) except Exception: self.logger.exception("Failed to reclaim %s instance address=%s", instance_type, address) + + for thread in list(self._sidecar_reclaim_threads): + if thread.is_alive(): + thread.join(timeout=3.0) diff --git a/lightx2v/disagg/services/data_mgr_sidecar.py b/lightx2v/disagg/services/data_mgr_sidecar.py new file mode 100644 index 000000000..2378d924a --- /dev/null +++ b/lightx2v/disagg/services/data_mgr_sidecar.py @@ -0,0 +1,975 @@ +from __future__ import annotations + +import argparse +import os +import queue +import threading +import time +from collections import deque +from multiprocessing import resource_tracker, shared_memory +from typing import TYPE_CHECKING, Any, Deque + +import zmq + +if TYPE_CHECKING: + from lightx2v.disagg.conn import DataReceiver, DataSender + + +STATUS_FAILED = 0 +STATUS_SUCCESS = 4 +_SHM_TRACKING_PATCHED = False + + +def _disable_shared_memory_tracking_for_process(): + """Disable multiprocessing resource_tracker registration for shared_memory. + + Python 3.12 does not expose SharedMemory(track=False). In fail-fast paths where + processes are terminated quickly, tracker warnings/noise can dominate logs even + when manual cleanup is performed by sidecar ownership logic. 
+ """ + + global _SHM_TRACKING_PATCHED + if _SHM_TRACKING_PATCHED: + return + + original_register = resource_tracker.register + original_unregister = resource_tracker.unregister + + def _register(name, rtype): + if rtype == "shared_memory": + return + return original_register(name, rtype) + + def _unregister(name, rtype): + if rtype == "shared_memory": + return + return original_unregister(name, rtype) + + resource_tracker.register = _register + resource_tracker.unregister = _unregister + _SHM_TRACKING_PATCHED = True + + +class DataMgrSidecarServer: + """Controller-managed sidecar server process. + + Services push transfer-state events to this process and pop aggregated events + through request/reply calls. + """ + + def __init__(self, push_addr: str, req_addr: str): + _disable_shared_memory_tracking_for_process() + self.push_addr = str(push_addr) + self.req_addr = str(req_addr) + + self._input_watch: set[int] = set() + self._output_watch: set[int] = set() + self._ready_inputs: Deque[int] = deque() + self._failed_inputs: Deque[int] = deque() + self._completed_outputs: Deque[tuple[int, int]] = deque() + + self._total_messages = 0 + self._last_message_ts = time.time() + self._running = True + + self._transformer_phase2_mgr: Any | None = None + self._transformer_phase2_rooms: dict[int, dict[str, Any]] = {} + self._transformer_phase2_output_watch: set[int] = set() + self._transformer_phase2_last_status: dict[int, int] = {} + + def _mark_activity(self): + self._total_messages += 1 + self._last_message_ts = time.time() + + def _handle_push(self, msg: dict): + cmd = str(msg.get("cmd", "")) + room = int(msg.get("room", -1)) + + if cmd == "watch_input" and room >= 0: + self._input_watch.add(room) + self._mark_activity() + return + if cmd == "unwatch_input" and room >= 0: + self._input_watch.discard(room) + self._mark_activity() + return + if cmd == "watch_output" and room >= 0: + self._output_watch.add(room) + self._mark_activity() + return + if cmd == "unwatch_output" and 
room >= 0: + self._output_watch.discard(room) + self._mark_activity() + return + + if cmd == "input_status" and room >= 0: + status = int(msg.get("status", STATUS_FAILED)) + self._input_watch.discard(room) + if status == STATUS_SUCCESS: + self._ready_inputs.append(room) + else: + self._failed_inputs.append(room) + self._mark_activity() + return + + if cmd == "output_status" and room >= 0: + status = int(msg.get("status", STATUS_FAILED)) + self._output_watch.discard(room) + self._completed_outputs.append((room, status)) + self._mark_activity() + return + + if cmd == "shutdown": + self._running = False + self._mark_activity() + + def _ensure_transformer_phase2_mgr(self): + if self._transformer_phase2_mgr is not None: + return self._transformer_phase2_mgr + + from lightx2v.disagg.conn import DataManager, DisaggregationMode, DisaggregationPhase + + self._transformer_phase2_mgr = DataManager(DisaggregationPhase.PHASE2, DisaggregationMode.TRANSFORMER) + return self._transformer_phase2_mgr + + def _create_shared_memory(self, size: int) -> shared_memory.SharedMemory: + # Keep lifecycle in this process and avoid resource_tracker duplicate cleanup at shutdown. 
+ try: + return shared_memory.SharedMemory(create=True, size=int(size), track=False) + except TypeError: + return shared_memory.SharedMemory(create=True, size=int(size)) + + def _close_unlink_shared_memory(self, shm: shared_memory.SharedMemory): + try: + shm.close() + except Exception: + pass + try: + shm.unlink() + except FileNotFoundError: + pass + except Exception: + pass + + def _cleanup_transformer_phase2_room(self, room: int): + room = int(room) + info = self._transformer_phase2_rooms.pop(room, None) + self._transformer_phase2_output_watch.discard(room) + + mgr = self._transformer_phase2_mgr + if mgr is not None: + try: + mgr.remove(room) + except Exception: + pass + + if not isinstance(info, dict): + return + + shms = info.get("shms") + if isinstance(shms, list): + for shm in shms: + self._close_unlink_shared_memory(shm) + + def _init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any]: + room = int(room) + sender_engine_rank = int(sender_engine_rank) + receiver_engine_rank = int(receiver_engine_rank) + normalized_lens = [int(v) for v in list(data_lens)] + if not normalized_lens or any(v <= 0 for v in normalized_lens): + raise ValueError(f"invalid data_lens for room={room}: {normalized_lens}") + + self._cleanup_transformer_phase2_room(room) + self._transformer_phase2_last_status.pop(room, None) + mgr = self._ensure_transformer_phase2_mgr() + + from lightx2v.disagg.conn import DataArgs, DataSender + import numpy as np + import torch + + shms: list[shared_memory.SharedMemory] = [] + arrays: list[Any] = [] + tensors: list[Any] = [] + data_ptrs: list[int] = [] + shm_names: list[str] = [] + + try: + for nbytes in normalized_lens: + shm = self._create_shared_memory(int(nbytes)) + arr = np.ndarray((int(nbytes),), dtype=np.uint8, buffer=shm.buf) + tensor = torch.from_numpy(arr) + tensor.zero_() + + shms.append(shm) + arrays.append(arr) + 
tensors.append(tensor) + data_ptrs.append(int(tensor.data_ptr())) + shm_names.append(str(shm.name)) + + data_args = DataArgs( + sender_engine_rank=sender_engine_rank, + receiver_engine_rank=receiver_engine_rank, + data_ptrs=data_ptrs, + data_lens=normalized_lens, + data_item_lens=normalized_lens, + ib_device=None, + ) + mgr.init(data_args, room) + sender = DataSender(mgr, bootstrap_addr, room) + + self._transformer_phase2_rooms[room] = { + "sender": sender, + "data_ptrs": data_ptrs, + "shms": shms, + "arrays": arrays, + "tensors": tensors, + } + + self._mark_activity() + return { + "room": room, + "shm_names": shm_names, + "data_lens": normalized_lens, + "host": str(mgr.get_localhost()), + "session_id": str(mgr.get_session_id()), + } + except Exception: + for shm in shms: + self._close_unlink_shared_memory(shm) + try: + mgr.remove(room) + except Exception: + pass + raise + + def _send_transformer_output_room(self, room: int): + room = int(room) + info = self._transformer_phase2_rooms.get(room) + if not isinstance(info, dict): + raise KeyError(f"transformer output room not initialized: {room}") + + sender = info.get("sender") + data_ptrs = info.get("data_ptrs") + if sender is None or not isinstance(data_ptrs, list): + raise RuntimeError(f"transformer output room metadata invalid: {room}") + + sender.send(list(data_ptrs)) + self._transformer_phase2_output_watch.add(room) + self._mark_activity() + + def _get_transformer_output_status(self, room: int) -> int: + room = int(room) + info = self._transformer_phase2_rooms.get(room) + if not isinstance(info, dict): + return int(self._transformer_phase2_last_status.get(room, STATUS_FAILED)) + sender = info.get("sender") + if sender is None: + return int(self._transformer_phase2_last_status.get(room, STATUS_FAILED)) + try: + return int(sender.poll()) + except Exception: + return int(self._transformer_phase2_last_status.get(room, STATUS_FAILED)) + + def _get_transformer_output_backlog(self) -> dict[str, int]: + mgr = 
self._transformer_phase2_mgr + if mgr is None: + return { + "request_pool": 0, + "waiting_pool": 0, + "request_status": 0, + } + try: + data = mgr.get_backlog_counts() + except Exception: + data = {} + return { + "request_pool": int(data.get("request_pool", 0)), + "waiting_pool": int(data.get("waiting_pool", 0)), + "request_status": int(data.get("request_status", 0)), + } + + def _poll_transformer_output_watch(self): + for room in list(self._transformer_phase2_output_watch): + status_val = self._get_transformer_output_status(room) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + self._transformer_phase2_output_watch.discard(room) + self._transformer_phase2_last_status[int(room)] = int(status_val) + self._completed_outputs.append((int(room), int(status_val))) + self._cleanup_transformer_phase2_room(room) + self._mark_activity() + + def _release_transformer_phase2_mgr(self): + for room in list(self._transformer_phase2_rooms.keys()): + self._cleanup_transformer_phase2_room(room) + mgr = self._transformer_phase2_mgr + self._transformer_phase2_mgr = None + if mgr is not None: + try: + mgr.release() + except Exception: + pass + + def _get_pending_counts(self) -> dict[str, int]: + transformer_backlog = self._get_transformer_output_backlog() + output_watch = len(self._output_watch) + len(self._transformer_phase2_output_watch) + return { + "input_watch": len(self._input_watch), + "output_watch": output_watch, + "ready_inputs": len(self._ready_inputs), + "failed_inputs": len(self._failed_inputs), + "completed_outputs": len(self._completed_outputs), + "transformer_request_pool": int(transformer_backlog.get("request_pool", 0)), + "transformer_waiting_pool": int(transformer_backlog.get("waiting_pool", 0)), + "transformer_active_rooms": len(self._transformer_phase2_rooms), + } + + def _handle_req(self, req: dict) -> dict: + cmd = str(req.get("cmd", "")) + + if cmd == "ping": + return {"ok": True} + + if cmd == "get_pending_counts": + return {"ok": True, "data": 
self._get_pending_counts()} + + if cmd == "get_stats": + counts = self._get_pending_counts() + return { + "ok": True, + "data": { + **counts, + "total_messages": int(self._total_messages), + "last_message_ts": float(self._last_message_ts), + }, + } + + if cmd == "init_transformer_output_room": + try: + room = int(req.get("room", -1)) + sender_engine_rank = int(req.get("sender_engine_rank", -1)) + receiver_engine_rank = int(req.get("receiver_engine_rank", -1)) + data_lens_raw = req.get("data_lens") + bootstrap_addr = str(req.get("bootstrap_addr", "127.0.0.1")) + if room < 0 or sender_engine_rank < 0 or receiver_engine_rank < 0: + raise ValueError("room/sender_engine_rank/receiver_engine_rank must be non-negative") + if not isinstance(data_lens_raw, list): + raise ValueError("data_lens must be a list") + data = self._init_transformer_output_room( + room=room, + sender_engine_rank=sender_engine_rank, + receiver_engine_rank=receiver_engine_rank, + data_lens=[int(v) for v in data_lens_raw], + bootstrap_addr=bootstrap_addr, + ) + return {"ok": True, "data": data} + except Exception as exc: + return {"ok": False, "error": str(exc)} + + if cmd == "send_transformer_output_room": + try: + room = int(req.get("room", -1)) + if room < 0: + raise ValueError("room must be non-negative") + self._send_transformer_output_room(room) + return {"ok": True, "data": True} + except Exception as exc: + return {"ok": False, "error": str(exc)} + + if cmd == "get_transformer_output_status": + room = int(req.get("room", -1)) + if room < 0: + return {"ok": False, "error": "room must be non-negative"} + return {"ok": True, "data": self._get_transformer_output_status(room)} + + if cmd == "remove_transformer_output_room": + room = int(req.get("room", -1)) + if room < 0: + return {"ok": False, "error": "room must be non-negative"} + self._cleanup_transformer_phase2_room(room) + self._mark_activity() + return {"ok": True, "data": True} + + if cmd == "get_transformer_output_backlog": + return {"ok": 
True, "data": self._get_transformer_output_backlog()} + + if cmd == "get_transformer_output_identity": + mgr = self._transformer_phase2_mgr + if mgr is None: + return {"ok": False, "error": "transformer phase2 manager not initialized"} + return { + "ok": True, + "data": { + "host": str(mgr.get_localhost()), + "session_id": str(mgr.get_session_id()), + }, + } + + if cmd == "pop_ready_inputs": + items = list(self._ready_inputs) + self._ready_inputs.clear() + return {"ok": True, "data": items} + + if cmd == "pop_failed_inputs": + items = list(self._failed_inputs) + self._failed_inputs.clear() + return {"ok": True, "data": items} + + if cmd == "pop_completed_outputs": + items = list(self._completed_outputs) + self._completed_outputs.clear() + return {"ok": True, "data": items} + + if cmd == "shutdown": + self._running = False + self._mark_activity() + return {"ok": True} + + return {"ok": False, "error": f"unknown command: {cmd}"} + + def run_forever(self): + context = zmq.Context() + pull = context.socket(zmq.PULL) + rep = context.socket(zmq.REP) + + pull.bind(self.push_addr) + rep.bind(self.req_addr) + + poller = zmq.Poller() + poller.register(pull, zmq.POLLIN) + poller.register(rep, zmq.POLLIN) + + try: + while self._running: + events = dict(poller.poll(timeout=100)) + if pull in events: + try: + self._handle_push(pull.recv_pyobj()) + except Exception: + pass + + if rep in events: + try: + reply = self._handle_req(rep.recv_pyobj()) + except Exception as exc: + reply = {"ok": False, "error": str(exc)} + rep.send_pyobj(reply) + + self._poll_transformer_output_watch() + finally: + self._release_transformer_phase2_mgr() + pull.close(0) + rep.close(0) + context.term() + + +class _LocalDataMgrSidecar: + """Fallback local sidecar used when controller-managed endpoints are absent.""" + + def __init__(self, poll_interval_s: float = 0.01): + self.poll_interval_s = max(float(poll_interval_s), 0.001) + + self._lock = threading.Lock() + self._stop_event = threading.Event() + 
self._thread: threading.Thread | None = None + self._started = False + + self._input_watch: dict[int, DataReceiver] = {} + self._output_watch: dict[int, DataSender] = {} + + self._ready_inputs: Deque[int] = deque() + self._failed_inputs: Deque[int] = deque() + self._completed_outputs: Deque[tuple[int, int]] = deque() + + def start(self): + if self._thread is not None and self._thread.is_alive(): + self._started = True + return + + self._stop_event.clear() + self._thread = threading.Thread( + target=self._run, + name="data-mgr-sidecar-local", + daemon=True, + ) + self._thread.start() + self._started = True + + def stop(self): + self._stop_event.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=1.0) + self._thread = None + self._started = False + + def watch_input(self, room: int, receiver: DataReceiver): + if not self._started: + self.start() + with self._lock: + self._input_watch[int(room)] = receiver + + def unwatch_input(self, room: int): + with self._lock: + self._input_watch.pop(int(room), None) + + def watch_output(self, room: int, sender: DataSender): + if not self._started: + self.start() + with self._lock: + self._output_watch[int(room)] = sender + + def unwatch_output(self, room: int): + with self._lock: + self._output_watch.pop(int(room), None) + + def pop_ready_inputs(self) -> list[int]: + with self._lock: + items = list(self._ready_inputs) + self._ready_inputs.clear() + return items + + def pop_failed_inputs(self) -> list[int]: + with self._lock: + items = list(self._failed_inputs) + self._failed_inputs.clear() + return items + + def pop_completed_outputs(self) -> list[tuple[int, int]]: + with self._lock: + items = list(self._completed_outputs) + self._completed_outputs.clear() + return items + + def get_pending_counts(self) -> dict[str, int]: + with self._lock: + return { + "input_watch": len(self._input_watch), + "output_watch": len(self._output_watch), + "ready_inputs": len(self._ready_inputs), + 
"failed_inputs": len(self._failed_inputs), + "completed_outputs": len(self._completed_outputs), + } + + def init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any] | None: + return None + + def send_transformer_output_room(self, room: int) -> bool: + return False + + def get_transformer_output_status(self, room: int) -> int: + return STATUS_FAILED + + def remove_transformer_output_room(self, room: int) -> bool: + return False + + def get_transformer_output_backlog(self) -> dict[str, int]: + return { + "request_pool": 0, + "waiting_pool": 0, + "request_status": 0, + } + + def get_transformer_output_identity(self, room: int | None = None) -> dict[str, Any] | None: + return None + + def _run(self): + while not self._stop_event.is_set(): + with self._lock: + input_items = list(self._input_watch.items()) + output_items = list(self._output_watch.items()) + + if not input_items and not output_items: + time.sleep(self.poll_interval_s) + continue + + for room, receiver in input_items: + try: + status = receiver.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val == STATUS_SUCCESS: + with self._lock: + self._input_watch.pop(room, None) + self._ready_inputs.append(room) + elif status_val == STATUS_FAILED: + with self._lock: + self._input_watch.pop(room, None) + self._failed_inputs.append(room) + + for room, sender in output_items: + try: + status = sender.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + with self._lock: + self._output_watch.pop(room, None) + self._completed_outputs.append((room, status_val)) + + time.sleep(self.poll_interval_s) + + +class _RemoteDataMgrSidecarClient: + """Service-side client for controller-managed sidecar process.""" + + def __init__(self, push_addr: str, req_addr: str, poll_interval_s: float = 
0.01): + self.push_addr = str(push_addr) + self.req_addr = str(req_addr) + self.poll_interval_s = max(float(poll_interval_s), 0.001) + + self._context = zmq.Context.instance() + self._push = self._context.socket(zmq.PUSH) + self._push.connect(self.push_addr) + + self._req = self._context.socket(zmq.REQ) + self._req.connect(self.req_addr) + self._req.setsockopt(zmq.RCVTIMEO, 1500) + self._req.setsockopt(zmq.SNDTIMEO, 1500) + + self._req_lock = threading.Lock() + self._watch_lock = threading.Lock() + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._started = False + + self._input_watch: dict[int, DataReceiver] = {} + self._output_watch: dict[int, DataSender] = {} + + def start(self): + if self._thread is not None and self._thread.is_alive(): + self._started = True + return + self._stop_event.clear() + self._thread = threading.Thread(target=self._run, name="data-mgr-sidecar-remote-client", daemon=True) + self._thread.start() + self._started = True + + def stop(self): + self._stop_event.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=1.0) + self._thread = None + self._started = False + + try: + with self._watch_lock: + rooms_in = list(self._input_watch.keys()) + rooms_out = list(self._output_watch.keys()) + self._input_watch.clear() + self._output_watch.clear() + for room in rooms_in: + self._push_cmd({"cmd": "unwatch_input", "room": int(room)}) + for room in rooms_out: + self._push_cmd({"cmd": "unwatch_output", "room": int(room)}) + except Exception: + pass + + def watch_input(self, room: int, receiver: DataReceiver): + if not self._started: + self.start() + room = int(room) + with self._watch_lock: + self._input_watch[room] = receiver + self._push_cmd({"cmd": "watch_input", "room": room}) + + def unwatch_input(self, room: int): + room = int(room) + with self._watch_lock: + self._input_watch.pop(room, None) + self._push_cmd({"cmd": "unwatch_input", "room": room}) + + def 
watch_output(self, room: int, sender: DataSender): + if not self._started: + self.start() + room = int(room) + with self._watch_lock: + self._output_watch[room] = sender + self._push_cmd({"cmd": "watch_output", "room": room}) + + def unwatch_output(self, room: int): + room = int(room) + with self._watch_lock: + self._output_watch.pop(room, None) + self._push_cmd({"cmd": "unwatch_output", "room": room}) + + def pop_ready_inputs(self) -> list[int]: + data = self._req_cmd("pop_ready_inputs") + if isinstance(data, list): + return [int(v) for v in data] + return [] + + def pop_failed_inputs(self) -> list[int]: + data = self._req_cmd("pop_failed_inputs") + if isinstance(data, list): + return [int(v) for v in data] + return [] + + def pop_completed_outputs(self) -> list[tuple[int, int]]: + data = self._req_cmd("pop_completed_outputs") + if not isinstance(data, list): + return [] + items: list[tuple[int, int]] = [] + for item in data: + if isinstance(item, (list, tuple)) and len(item) == 2: + items.append((int(item[0]), int(item[1]))) + return items + + def get_pending_counts(self) -> dict[str, int]: + data = self._req_cmd("get_pending_counts") + if not isinstance(data, dict): + return { + "input_watch": 0, + "output_watch": 0, + "ready_inputs": 0, + "failed_inputs": 0, + "completed_outputs": 0, + } + return { + "input_watch": int(data.get("input_watch", 0)), + "output_watch": int(data.get("output_watch", 0)), + "ready_inputs": int(data.get("ready_inputs", 0)), + "failed_inputs": int(data.get("failed_inputs", 0)), + "completed_outputs": int(data.get("completed_outputs", 0)), + } + + def _push_cmd(self, cmd: dict): + try: + self._push.send_pyobj(cmd) + except Exception: + pass + + def _req_cmd(self, cmd: str, payload: dict[str, Any] | None = None): + try: + req_payload: dict[str, Any] = {"cmd": str(cmd)} + if isinstance(payload, dict): + req_payload.update(payload) + with self._req_lock: + self._req.send_pyobj(req_payload) + reply = self._req.recv_pyobj() + if 
isinstance(reply, dict) and reply.get("ok", False): + return reply.get("data") + return None + except Exception: + return None + + def init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any] | None: + data = self._req_cmd( + "init_transformer_output_room", + { + "room": int(room), + "sender_engine_rank": int(sender_engine_rank), + "receiver_engine_rank": int(receiver_engine_rank), + "data_lens": [int(v) for v in data_lens], + "bootstrap_addr": str(bootstrap_addr), + }, + ) + if isinstance(data, dict): + return data + return None + + def send_transformer_output_room(self, room: int) -> bool: + data = self._req_cmd("send_transformer_output_room", {"room": int(room)}) + return bool(data) + + def get_transformer_output_status(self, room: int) -> int: + data = self._req_cmd("get_transformer_output_status", {"room": int(room)}) + if data is None: + return STATUS_FAILED + try: + return int(data) + except Exception: + return STATUS_FAILED + + def remove_transformer_output_room(self, room: int) -> bool: + data = self._req_cmd("remove_transformer_output_room", {"room": int(room)}) + return bool(data) + + def get_transformer_output_backlog(self) -> dict[str, int]: + data = self._req_cmd("get_transformer_output_backlog") + if not isinstance(data, dict): + return { + "request_pool": 0, + "waiting_pool": 0, + "request_status": 0, + } + return { + "request_pool": int(data.get("request_pool", 0)), + "waiting_pool": int(data.get("waiting_pool", 0)), + "request_status": int(data.get("request_status", 0)), + } + + def get_transformer_output_identity(self, room: int | None = None) -> dict[str, Any] | None: + payload: dict[str, Any] = {} + if room is not None: + payload["room"] = int(room) + data = self._req_cmd("get_transformer_output_identity", payload) + if isinstance(data, dict): + return { + "host": str(data.get("host", "")), + "session_id": 
str(data.get("session_id", "")), + } + return None + + def _run(self): + while not self._stop_event.is_set(): + with self._watch_lock: + input_items = list(self._input_watch.items()) + output_items = list(self._output_watch.items()) + + if not input_items and not output_items: + time.sleep(self.poll_interval_s) + continue + + for room, receiver in input_items: + try: + status = receiver.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + with self._watch_lock: + self._input_watch.pop(room, None) + self._push_cmd( + { + "cmd": "input_status", + "room": int(room), + "status": status_val, + } + ) + + for room, sender in output_items: + try: + status = sender.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + with self._watch_lock: + self._output_watch.pop(room, None) + self._push_cmd( + { + "cmd": "output_status", + "room": int(room), + "status": status_val, + } + ) + + time.sleep(self.poll_interval_s) + + +class DataMgrSidecar: + """Service-facing sidecar facade. + + If controller-side endpoints exist, use controller-managed remote sidecar. + Otherwise fallback to in-process local sidecar for standalone runs. 
+ """ + + def __init__(self, poll_interval_s: float = 0.01): + push_addr = str(os.getenv("LIGHTX2V_SIDECAR_PUSH_ADDR", "")).strip() + req_addr = str(os.getenv("LIGHTX2V_SIDECAR_REQ_ADDR", "")).strip() + + if push_addr and req_addr: + self._impl = _RemoteDataMgrSidecarClient(push_addr=push_addr, req_addr=req_addr, poll_interval_s=poll_interval_s) + else: + self._impl = _LocalDataMgrSidecar(poll_interval_s=poll_interval_s) + + def start(self): + self._impl.start() + + def stop(self): + self._impl.stop() + + def watch_input(self, room: int, receiver: DataReceiver): + self._impl.watch_input(room, receiver) + + def unwatch_input(self, room: int): + self._impl.unwatch_input(room) + + def watch_output(self, room: int, sender: DataSender): + self._impl.watch_output(room, sender) + + def unwatch_output(self, room: int): + self._impl.unwatch_output(room) + + def pop_ready_inputs(self) -> list[int]: + return self._impl.pop_ready_inputs() + + def pop_failed_inputs(self) -> list[int]: + return self._impl.pop_failed_inputs() + + def pop_completed_outputs(self) -> list[tuple[int, int]]: + return self._impl.pop_completed_outputs() + + def get_pending_counts(self) -> dict[str, int]: + return self._impl.get_pending_counts() + + def init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any] | None: + return self._impl.init_transformer_output_room( + room=room, + sender_engine_rank=sender_engine_rank, + receiver_engine_rank=receiver_engine_rank, + data_lens=data_lens, + bootstrap_addr=bootstrap_addr, + ) + + def send_transformer_output_room(self, room: int) -> bool: + return self._impl.send_transformer_output_room(room) + + def get_transformer_output_status(self, room: int) -> int: + return self._impl.get_transformer_output_status(room) + + def remove_transformer_output_room(self, room: int) -> bool: + return self._impl.remove_transformer_output_room(room) + + def 
get_transformer_output_backlog(self) -> dict[str, int]: + return self._impl.get_transformer_output_backlog() + + def get_transformer_output_identity(self, room: int | None = None) -> dict[str, Any] | None: + return self._impl.get_transformer_output_identity(room) + + +def _run_server_from_cli(): + parser = argparse.ArgumentParser(description="Run DataMgr sidecar server process") + parser.add_argument("--push-addr", type=str, required=True) + parser.add_argument("--req-addr", type=str, required=True) + args = parser.parse_args() + + server = DataMgrSidecarServer(push_addr=args.push_addr, req_addr=args.req_addr) + server.run_forever() + + +if __name__ == "__main__": + _run_server_from_cli() diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index 516986ed5..a42bed24c 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -15,6 +15,7 @@ from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.services.data_mgr_sidecar import DataMgrSidecar from lightx2v.disagg.utils import estimate_transformer_buffer_sizes, load_wan_vae_decoder from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.utils import save_to_video, seed_all, wan_vae_to_comfy @@ -63,6 +64,7 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self._data_mgr_sidecar = DataMgrSidecar() self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() @@ -315,6 +317,7 @@ def remove(self, room: int): self.release_memory(room) self.data_receiver.pop(room, None) + self._data_mgr_sidecar.unwatch_input(room) if self.data_mgr is None: return @@ -340,6 +343,7 @@ def run(self, stop_event=None): while True: transfer_sizes = self.data_mgr.get_backlog_counts() if self.data_mgr is not None else {"request_pool": 0, 
"waiting_pool": 0} + sidecar_sizes = self._data_mgr_sidecar.get_pending_counts() self._update_queue_metrics( { "req_queue": len(req_queue), @@ -349,6 +353,7 @@ def run(self, stop_event=None): { "request_pool": int(transfer_sizes.get("request_pool", 0)), "waiting_pool": int(transfer_sizes.get("waiting_pool", 0)), + "sidecar_input_watch": int(sidecar_sizes.get("input_watch", 0)), }, ) @@ -380,31 +385,23 @@ def run(self, stop_event=None): try: self.init(config) waiting_queue[room] = config + receiver = self.data_receiver.get(room) + if receiver is None: + raise RuntimeError(f"DataReceiver is not initialized for room={room}") + self._data_mgr_sidecar.watch_input(room, receiver) except Exception: self.logger.exception("Failed to initialize request for room=%s", room) self.remove(room) - ready_rooms: List[int] = [] - failed_rooms: List[int] = [] - waiting_rooms = list(waiting_queue.keys()) - if self.sync_comm and waiting_rooms: - waiting_rooms = [waiting_rooms[0]] - - for room in waiting_rooms: - receiver = self.data_receiver.get(room) - if receiver is None: - failed_rooms.append(room) - continue - - status = receiver.poll() - if status == DataPoll.Success: - ready_rooms.append(room) - elif status == DataPoll.Failed: - failed_rooms.append(room) + ready_rooms = self._data_mgr_sidecar.pop_ready_inputs() + failed_rooms = self._data_mgr_sidecar.pop_failed_inputs() for room in ready_rooms: + config = waiting_queue.pop(room, None) + if config is None: + continue self.logger.info("Latents received successfully in DecoderService for room=%s.", room) - exec_queue.append((room, waiting_queue.pop(room))) + exec_queue.append((room, config)) for room in failed_rooms: waiting_queue.pop(room, None) diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index c75bac55a..2e686f4c1 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -15,6 +15,7 @@ from lightx2v.disagg.rdma_buffer import RDMABuffer, 
RDMABufferDescriptor from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.services.data_mgr_sidecar import DataMgrSidecar from lightx2v.disagg.utils import ( estimate_encoder_buffer_sizes, load_wan_image_encoder, @@ -77,6 +78,7 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self._data_mgr_sidecar = DataMgrSidecar() self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() @@ -153,6 +155,21 @@ def _ensure_request_buffer(self) -> bool: ) return True + def _reconnect_request_buffer(self): + self._request_rdma_buffer = None + self._last_request_connect_retry_ts = 0.0 + + if self._request_rdma_client is not None: + sock = getattr(self._request_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._request_rdma_client.sock = None + + self._ensure_request_buffer() + def _ensure_phase1_meta_buffer(self) -> bool: if self._phase1_rdma_buffer is not None: return True @@ -195,8 +212,68 @@ def _ensure_phase1_meta_buffer(self) -> bool: ) return True + def _reconnect_phase1_meta_buffer(self): + self._phase1_rdma_buffer = None + self._last_phase1_connect_retry_ts = 0.0 + + if self._phase1_rdma_client is not None: + sock = getattr(self._phase1_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._phase1_rdma_client.sock = None + + self._ensure_phase1_meta_buffer() + + def _produce_phase1_request_with_retry(self, room: int, payload: dict[str, Any]): + retries = max(1, int(os.getenv("RDMA_PHASE1_PRODUCE_RETRIES", "3"))) + retry_delay_s = max(0.01, float(os.getenv("RDMA_PHASE1_PRODUCE_RETRY_DELAY_S", "0.2"))) + last_exc: Optional[Exception] = None + + for attempt in range(1, retries + 1): + try: + if self._phase1_rdma_buffer is None: + self._ensure_phase1_meta_buffer() + if self._phase1_rdma_buffer is None: + 
raise RuntimeError("phase1 RDMA buffer is not ready") + self._phase1_rdma_buffer.produce(payload) + return + except Exception as exc: + last_exc = exc + self.logger.warning( + "Phase1 RDMA produce failed for room=%s attempt=%s/%s host=%s port=%s: %s", + room, + attempt, + retries, + self._phase1_server_ip, + self._phase1_handshake_port, + exc, + ) + if attempt >= retries: + break + try: + self._reconnect_phase1_meta_buffer() + except Exception as reconnect_exc: + self.logger.warning( + "Phase1 RDMA reconnect failed for room=%s attempt=%s/%s host=%s port=%s: %s", + room, + attempt, + retries, + self._phase1_server_ip, + self._phase1_handshake_port, + reconnect_exc, + ) + time.sleep(retry_delay_s) + + raise RuntimeError(f"Failed to produce phase1 RDMA request for room={room} after {retries} attempts") from last_exc + def init(self, config): self._sync_runtime_config(config) + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", self.encoder_engine_rank)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", self.transformer_engine_rank)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", self.decoder_engine_rank)) shared_slots = int(self.config.get("rdma_buffer_slots", self._request_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._request_server_ip = str(self.config.get("rdma_request_host", self._request_server_ip)) @@ -513,9 +590,7 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: "encoder_node_address": self.data_mgr.get_localhost(), "encoder_session_id": self.data_mgr.get_session_id(), } - if self._phase1_rdma_buffer is None: - raise RuntimeError("phase1 RDMA buffer is not ready") - self._phase1_rdma_buffer.produce(phase1_meta) + self._produce_phase1_request_with_retry(room, phase1_meta) sender.send(buffer_ptrs) if self.sync_comm: self._wait_sender_success(room, sender) @@ -532,6 +607,7 @@ def remove(self, room: int): self.release_memory(room) 
self.data_sender.pop(room, None) + self._data_mgr_sidecar.unwatch_output(room) if self.data_mgr is None: return @@ -555,19 +631,21 @@ def release(self): def run(self, stop_event=None): req_queue = deque() exec_queue = deque() - complete_queue: Dict[int, dict] = {} + complete_queue: set[int] = set() while True: transfer_sizes = self.data_mgr.get_backlog_counts() if self.data_mgr is not None else {"request_pool": 0, "waiting_pool": 0} + sidecar_sizes = self._data_mgr_sidecar.get_pending_counts() self._update_queue_metrics( { "req_queue": len(req_queue), "exec_queue": len(exec_queue), - "complete_queue": len(complete_queue), }, { + "complete_queue": len(complete_queue), "request_pool": int(transfer_sizes.get("request_pool", 0)), "waiting_pool": int(transfer_sizes.get("waiting_pool", 0)), + "sidecar_output_watch": int(sidecar_sizes.get("output_watch", 0)), }, ) @@ -577,6 +655,16 @@ def run(self, stop_event=None): except Exception: self.logger.exception("Failed to connect request RDMA buffer, will retry") + if self._request_rdma_client is not None and self._request_rdma_client.has_qp_error(): + self.logger.warning( + "Request RDMA client entered error state, reconnecting: %s", + self._request_rdma_client.last_wc_error_message(), + ) + try: + self._reconnect_request_buffer() + except Exception: + self.logger.exception("Failed to reconnect request RDMA buffer after QP error") + if self._request_rdma_buffer is not None: config = self._request_rdma_buffer.consume() if config is not None: @@ -602,32 +690,25 @@ def run(self, stop_event=None): room, config = exec_queue.popleft() try: self.process(config) - complete_queue[room] = config + if self.sync_comm: + self.remove(room) + else: + sender = self.data_sender.get(room) + if sender is None: + self.logger.error("DataSender is missing for room=%s", room) + self.remove(room) + else: + self._data_mgr_sidecar.watch_output(room, sender) + complete_queue.add(room) except Exception: self.logger.exception("Failed to process request 
for room=%s", room) - complete_queue.pop(room, None) self.remove(room) - completed_rooms: List[int] = [] - for room in list(complete_queue.keys()): - sender = self.data_sender.get(room) - if sender is None: - completed_rooms.append(room) - continue - - if self.sync_comm: - completed_rooms.append(room) - continue - - status = sender.poll() - if status == DataPoll.Success: - completed_rooms.append(room) - elif status == DataPoll.Failed: + completed_outputs = self._data_mgr_sidecar.pop_completed_outputs() + for room, status in completed_outputs: + if status == DataPoll.Failed: self.logger.error("DataSender transfer failed for room=%s", room) - completed_rooms.append(room) - - for room in completed_rooms: - complete_queue.pop(room, None) + complete_queue.discard(room) self.remove(room) if stop_event is not None and stop_event.is_set() and not req_queue and not exec_queue and not complete_queue: diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 338a0d285..8d1b69555 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -4,6 +4,7 @@ import threading import time from collections import deque +from multiprocessing import resource_tracker, shared_memory from typing import Any, List, Optional import numpy as np @@ -15,6 +16,7 @@ from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.services.data_mgr_sidecar import DataMgrSidecar from lightx2v.disagg.utils import ( estimate_encoder_buffer_sizes, estimate_transformer_buffer_sizes, @@ -26,9 +28,36 @@ from lightx2v_platform.base.global_var import AI_DEVICE +_SHM_TRACKING_PATCHED = False + + +def _disable_shared_memory_tracking_for_process(): + global _SHM_TRACKING_PATCHED + if _SHM_TRACKING_PATCHED: + return + + original_register = resource_tracker.register + original_unregister = 
resource_tracker.unregister + + def _register(name, rtype): + if rtype == "shared_memory": + return + return original_register(name, rtype) + + def _unregister(name, rtype): + if rtype == "shared_memory": + return + return original_unregister(name, rtype) + + resource_tracker.register = _register + resource_tracker.unregister = _unregister + _SHM_TRACKING_PATCHED = True + + class TransformerService(BaseService): def __init__(self, config: dict): super().__init__() + _disable_shared_memory_tracking_for_process() self.config = config self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) @@ -56,7 +85,9 @@ def __init__(self, config: dict): self.data_mgr1 = DataManager(DisaggregationPhase.PHASE1, DisaggregationMode.TRANSFORMER) self.data_mgr2 = DataManager(DisaggregationPhase.PHASE2, DisaggregationMode.TRANSFORMER) self.data_receiver: dict[int, DataReceiver] = {} - self.data_sender: dict[int, DataSender] = {} + self.data_sender: dict[int, Optional[DataSender]] = {} + self._phase2_remote_rooms: set[int] = set() + self._phase2_remote_shared_memory: dict[int, list[shared_memory.SharedMemory]] = {} self.reporter = Reporter( service_type="transformer", gpu_id=self.transformer_engine_rank, @@ -75,6 +106,7 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self._data_mgr_sidecar = DataMgrSidecar() self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() @@ -87,6 +119,13 @@ def _wait_sender_success(self, room: int, sender: DataSender): raise RuntimeError(f"DataSender transfer failed for room={room}") time.sleep(0.001) + def _attach_remote_shared_memory(self, shm_name: str) -> shared_memory.SharedMemory: + # Python 3.12 supports `track=False`, which avoids duplicate cleanup from non-owner processes. 
+ try: + return shared_memory.SharedMemory(name=shm_name, create=False, track=False) + except TypeError: + return shared_memory.SharedMemory(name=shm_name, create=False) + def _get_queue_metrics(self) -> dict[str, Any]: with self._queue_metrics_lock: queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) @@ -145,6 +184,21 @@ def _ensure_phase1_request_buffer(self) -> bool: ) return True + def _reconnect_phase1_request_buffer(self): + self._phase1_rdma_buffer = None + self._last_phase1_connect_retry_ts = 0.0 + + if self._phase1_rdma_client is not None: + sock = getattr(self._phase1_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._phase1_rdma_client.sock = None + + self._ensure_phase1_request_buffer() + def _ensure_phase2_meta_buffer(self) -> bool: if self._phase2_rdma_buffer is not None: return True @@ -173,8 +227,64 @@ def _ensure_phase2_meta_buffer(self) -> bool: ) return True + def _reconnect_phase2_meta_buffer(self): + self._phase2_rdma_buffer = None + self._last_phase2_connect_retry_ts = 0.0 + + if self._phase2_rdma_client is not None: + sock = getattr(self._phase2_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._phase2_rdma_client.sock = None + + self._ensure_phase2_meta_buffer() + + def _produce_phase2_request_with_retry(self, room: int, payload: dict[str, Any]): + retries = max(1, int(os.getenv("RDMA_PHASE2_PRODUCE_RETRIES", "3"))) + retry_delay_s = max(0.01, float(os.getenv("RDMA_PHASE2_PRODUCE_RETRY_DELAY_S", "0.2"))) + last_exc: Optional[Exception] = None + + for attempt in range(1, retries + 1): + try: + if self._phase2_rdma_buffer is None: + self._ensure_phase2_meta_buffer() + if self._phase2_rdma_buffer is None: + raise RuntimeError("phase2 RDMA buffer is not ready") + self._phase2_rdma_buffer.produce(payload) + return + except Exception as exc: + last_exc = exc + self.logger.warning( + "Phase2 RDMA produce failed for room=%s 
attempt=%s/%s: %s", + room, + attempt, + retries, + exc, + ) + if attempt >= retries: + break + try: + self._reconnect_phase2_meta_buffer() + except Exception as reconnect_exc: + self.logger.warning( + "Phase2 RDMA reconnect failed for room=%s attempt=%s/%s: %s", + room, + attempt, + retries, + reconnect_exc, + ) + time.sleep(retry_delay_s) + + raise RuntimeError(f"Failed to produce phase2 RDMA request for room={room} after {retries} attempts") from last_exc + def init(self, config): self._sync_runtime_config(config) + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", self.encoder_engine_rank)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", self.transformer_engine_rank)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", self.decoder_engine_rank)) shared_slots = int(self.config.get("rdma_buffer_slots", self._phase1_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._phase1_server_ip = str(self.config.get("rdma_phase1_host", self._phase1_server_ip)) @@ -234,24 +344,69 @@ def init(self, config): self.data_receiver[data_bootstrap_room] = DataReceiver(self.data_mgr1, phase1_bootstrap_addr, data_bootstrap_room) self.data_receiver[data_bootstrap_room].init() - buffer_sizes = estimate_transformer_buffer_sizes(self.config) - request = AllocationRequest( - bootstrap_room=data_bootstrap_room, - buffer_sizes=buffer_sizes, - ) - handle = self.alloc_memory(DisaggregationPhase.PHASE2, request) - data_ptrs = [buf.addr for buf in handle.buffers] - data_lens = [buf.nbytes for buf in handle.buffers] - data_args = DataArgs( - sender_engine_rank=self.transformer_engine_rank, - receiver_engine_rank=self.decoder_engine_rank, - data_ptrs=data_ptrs, - data_lens=data_lens, - data_item_lens=data_lens, - ib_device=None, - ) - self.data_mgr2.init(data_args, data_bootstrap_room) - self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr2, data_bootstrap_addr, data_bootstrap_room) 
+ buffer_sizes = [int(v) for v in estimate_transformer_buffer_sizes(self.config)] + remote_room: dict[str, Any] | None = None + room_init_retries = max(1, int(os.getenv("DISAGG_TRANSFORMER_REMOTE_OUTPUT_INIT_RETRIES", "3"))) + room_init_retry_sleep_s = max(0.01, float(os.getenv("DISAGG_TRANSFORMER_REMOTE_OUTPUT_INIT_RETRY_SLEEP_S", "0.2"))) + + for attempt in range(1, room_init_retries + 1): + try: + remote_room = self._data_mgr_sidecar.init_transformer_output_room( + room=data_bootstrap_room, + sender_engine_rank=self.transformer_engine_rank, + receiver_engine_rank=self.decoder_engine_rank, + data_lens=buffer_sizes, + bootstrap_addr=data_bootstrap_addr, + ) + except Exception: + self.logger.exception( + "Failed to initialize remote transformer output room=%s attempt=%s/%s", + data_bootstrap_room, + attempt, + room_init_retries, + ) + remote_room = None + + if isinstance(remote_room, dict): + break + + if attempt < room_init_retries: + time.sleep(room_init_retry_sleep_s) + + if not isinstance(remote_room, dict): + raise RuntimeError( + f"remote transformer output room init failed for room={data_bootstrap_room}; " + "sidecar ownership is required to keep transfers alive during service reclaim" + ) + + shm_names_raw = remote_room.get("shm_names") + data_lens_raw = remote_room.get("data_lens", buffer_sizes) + if not isinstance(shm_names_raw, list) or not isinstance(data_lens_raw, list) or len(shm_names_raw) != len(data_lens_raw): + raise RuntimeError(f"invalid remote output room metadata for room={data_bootstrap_room}: {remote_room}") + + shm_handles: list[shared_memory.SharedMemory] = [] + phase2_buffers: list[torch.Tensor] = [] + try: + for shm_name, nbytes in zip(shm_names_raw, data_lens_raw): + shm = self._attach_remote_shared_memory(str(shm_name)) + np_view = np.ndarray((int(nbytes),), dtype=np.uint8, buffer=shm.buf) + tensor = torch.from_numpy(np_view) + tensor.zero_() + shm_handles.append(shm) + phase2_buffers.append(tensor) + except Exception: + for shm in 
shm_handles: + try: + shm.close() + except Exception: + pass + self._data_mgr_sidecar.remove_transformer_output_room(data_bootstrap_room) + raise + + self._phase2_remote_rooms.add(int(data_bootstrap_room)) + self._phase2_remote_shared_memory[int(data_bootstrap_room)] = shm_handles + self.rdma_buffer2[int(data_bootstrap_room)] = phase2_buffers + self.data_sender[int(data_bootstrap_room)] = None def load_models(self): self.logger.info("Loading Transformer Models...") @@ -306,6 +461,9 @@ def process(self, config): Executes the diffusion process and video decoding. """ self.logger.info("Starting processing in TransformerService...") + # Re-sync scheduler with the current request to avoid cross-request config bleed. + if self.scheduler is not None: + self.scheduler.refresh_from_config(config) room = config.get("data_bootstrap_room", 0) transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) transformer_metrics["compute_start_ts"] = time.time() @@ -314,6 +472,7 @@ def process(self, config): phase2_buffers = self.rdma_buffer2.get(room) receiver = self.data_receiver.get(room) sender = self.data_sender.get(room) + use_remote_phase2 = room in self._phase2_remote_rooms if phase1_buffers is None: raise RuntimeError(f"phase1 RDMA buffers are not initialized for room={room}.") @@ -321,7 +480,7 @@ def process(self, config): raise RuntimeError(f"phase2 RDMA buffers are not initialized for room={room}.") if receiver is None: raise RuntimeError(f"DataReceiver is not initialized for room={room}.") - if sender is None: + if sender is None and not use_remote_phase2: raise RuntimeError(f"DataSender is not initialized for room={room}.") def _buffer_view(buf: torch.Tensor, dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor: @@ -369,9 +528,18 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: meta_buf = phase1_buffers[buffer_index] strict_meta_hash_check = str(os.getenv("LIGHTX2V_STRICT_META_HASH", 
"0")).strip().lower() in {"1", "true", "yes", "on"} - def _load_phase1_meta(max_retries: int = 3, retry_sleep_s: float = 0.02) -> dict: + def _load_phase1_meta(max_retries: int = 20, retry_sleep_s: float = 0.05) -> dict: last_error: Optional[Exception] = None last_preview = "" + + required_shape_keys = ["context_shape", "latent_shape"] + if enable_cfg: + required_shape_keys.append("context_null_shape") + if task == "i2v": + required_shape_keys.append("vae_shape") + if use_image_encoder: + required_shape_keys.append("clip_shape") + for attempt in range(1, max_retries + 1): meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() raw_payload = meta_bytes.split(b"\x00", 1)[0] if meta_bytes else b"" @@ -428,6 +596,26 @@ def _load_phase1_meta(max_retries: int = 3, retry_sleep_s: float = 0.02) -> dict continue break + missing_shape_keys = [ + key + for key in required_shape_keys + if not isinstance(parsed.get(key), (list, tuple)) or len(parsed.get(key)) == 0 + ] + if missing_shape_keys: + last_error = ValueError(f"incomplete metadata, missing keys: {missing_shape_keys}") + last_preview = str({k: parsed.get(k) for k in required_shape_keys})[:180] + if attempt < max_retries: + self.logger.warning( + "Incomplete phase1 metadata for room=%s (attempt %s/%s), missing=%s, retrying...", + room, + attempt, + max_retries, + missing_shape_keys, + ) + time.sleep(retry_sleep_s) + continue + break + return parsed preview_suffix = f", preview={last_preview}" if last_preview else "" @@ -591,18 +779,45 @@ def _get_shape(key: str) -> tuple[int, ...]: transformer_metrics["output_enqueued_ts"] = time.time() phase2_request_config = dict(config) phase2_request_config["transformer_engine_rank"] = self.transformer_engine_rank - if self._phase2_rdma_buffer is None: - raise RuntimeError("phase2 RDMA buffer is not ready") - self._phase2_rdma_buffer.produce( + transformer_node_address = "" + transformer_session_id = "" + if room in 
self._phase2_remote_rooms: + identity = self._data_mgr_sidecar.get_transformer_output_identity(room) + if not isinstance(identity, dict): + raise RuntimeError(f"remote transformer output identity unavailable for room={room}") + transformer_node_address = str(identity.get("host", "")).strip() + transformer_session_id = str(identity.get("session_id", "")).strip() + if not transformer_node_address or not transformer_session_id: + raise RuntimeError(f"remote transformer output identity invalid for room={room}: {identity}") + else: + transformer_node_address = self.data_mgr2.get_localhost() + transformer_session_id = self.data_mgr2.get_session_id() + + self._produce_phase2_request_with_retry( + room, { "request_config": phase2_request_config, - "transformer_node_address": self.data_mgr2.get_localhost(), - "transformer_session_id": self.data_mgr2.get_session_id(), - } + "transformer_node_address": transformer_node_address, + "transformer_session_id": transformer_session_id, + }, ) - sender.send(buffer_ptrs) - if self.sync_comm: - self._wait_sender_success(room, sender) + if use_remote_phase2: + if not self._data_mgr_sidecar.send_transformer_output_room(room): + raise RuntimeError(f"Failed to enqueue remote transformer output transfer for room={room}") + if self.sync_comm: + while True: + status = int(self._data_mgr_sidecar.get_transformer_output_status(room)) + if status == DataPoll.Success: + break + if status == DataPoll.Failed: + raise RuntimeError(f"DataSender transfer failed for room={room}") + time.sleep(0.001) + else: + if sender is None: + raise RuntimeError(f"DataSender is not initialized for room={room}") + sender.send(buffer_ptrs) + if self.sync_comm: + self._wait_sender_success(room, sender) def release_memory(self, room: int): """ @@ -614,17 +829,31 @@ def release_memory(self, room: int): if room in self.rdma_buffer2: self.rdma_buffer2.pop(room, None) + shm_handles = self._phase2_remote_shared_memory.pop(room, None) + if isinstance(shm_handles, list): + for 
shm in shm_handles: + try: + shm.close() + except Exception: + pass + self._phase2_remote_rooms.discard(room) + torch.cuda.empty_cache() def remove(self, room: int): + use_remote_phase2 = room in self._phase2_remote_rooms self.release_memory(room) self.data_receiver.pop(room, None) self.data_sender.pop(room, None) + self._data_mgr_sidecar.unwatch_input(room) + self._data_mgr_sidecar.unwatch_output(room) if self.data_mgr1 is not None: self.data_mgr1.remove(room) - if self.data_mgr2 is not None: + if use_remote_phase2: + self._data_mgr_sidecar.remove_transformer_output_room(room) + elif self.data_mgr2 is not None: self.data_mgr2.remove(room) def release(self): @@ -648,25 +877,31 @@ def run(self, stop_event=None): req_queue = deque() waiting_queue: dict[int, dict] = {} exec_queue = deque() - complete_queue: dict[int, dict] = {} + complete_queue: set[int] = set() while True: phase1_transfer_sizes = self.data_mgr1.get_backlog_counts() if self.data_mgr1 is not None else {"request_pool": 0, "waiting_pool": 0} phase2_transfer_sizes = self.data_mgr2.get_backlog_counts() if self.data_mgr2 is not None else {"request_pool": 0, "waiting_pool": 0} + remote_phase2_transfer_sizes = self._data_mgr_sidecar.get_transformer_output_backlog() + for key in ("request_pool", "waiting_pool", "request_status"): + phase2_transfer_sizes[key] = int(phase2_transfer_sizes.get(key, 0)) + int(remote_phase2_transfer_sizes.get(key, 0)) + sidecar_sizes = self._data_mgr_sidecar.get_pending_counts() self._update_queue_metrics( { "req_queue": len(req_queue), "waiting_queue": len(waiting_queue), "exec_queue": len(exec_queue), - "complete_queue": len(complete_queue), }, { "request_pool": int(phase1_transfer_sizes.get("request_pool", 0)), "waiting_pool": int(phase1_transfer_sizes.get("waiting_pool", 0)), + "sidecar_input_watch": int(sidecar_sizes.get("input_watch", 0)), }, { + "complete_queue": len(complete_queue), "request_pool": int(phase2_transfer_sizes.get("request_pool", 0)), "waiting_pool": 
int(phase2_transfer_sizes.get("waiting_pool", 0)), + "sidecar_output_watch": int(sidecar_sizes.get("output_watch", 0)), }, ) @@ -676,6 +911,16 @@ def run(self, stop_event=None): except Exception: self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") + if self._phase1_rdma_client is not None and self._phase1_rdma_client.has_qp_error(): + self.logger.warning( + "Phase1 request RDMA client entered error state, reconnecting: %s", + self._phase1_rdma_client.last_wc_error_message(), + ) + try: + self._reconnect_phase1_request_buffer() + except Exception: + self.logger.exception("Failed to reconnect phase1 request RDMA buffer after QP error") + if self._phase1_rdma_buffer is not None and len(req_queue) + len(waiting_queue) < 2: packet = self._phase1_rdma_buffer.consume() if packet is not None: @@ -698,30 +943,21 @@ def run(self, stop_event=None): try: self.init(config) waiting_queue[room] = config + receiver = self.data_receiver.get(room) + if receiver is None: + raise RuntimeError(f"DataReceiver is not initialized for room={room}") + self._data_mgr_sidecar.watch_input(room, receiver) except Exception: self.logger.exception("Failed to initialize request for room=%s", room) self.remove(room) - ready_rooms: List[int] = [] - failed_rooms: List[int] = [] - waiting_items = list(waiting_queue.items()) - if self.sync_comm and waiting_items: - waiting_items = [waiting_items[0]] - - for room, config in waiting_items: - receiver = self.data_receiver.get(room) - if receiver is None: - failed_rooms.append(room) - continue - - status = receiver.poll() - if status == DataPoll.Success: - ready_rooms.append(room) - elif status == DataPoll.Failed: - failed_rooms.append(room) + ready_rooms = self._data_mgr_sidecar.pop_ready_inputs() + failed_rooms = self._data_mgr_sidecar.pop_failed_inputs() for room in ready_rooms: - exec_queue.append((room, waiting_queue.pop(room))) + config = waiting_queue.pop(room, None) + if config is not None: + exec_queue.append((room, 
config)) for room in failed_rooms: waiting_queue.pop(room, None) @@ -729,31 +965,33 @@ def run(self, stop_event=None): self.remove(room) if exec_queue: - room, config = exec_queue.popleft() + room, config = exec_queue[0] try: self.process(config) - complete_queue[room] = config + if self.sync_comm: + self.remove(room) + else: + if room in self._phase2_remote_rooms: + complete_queue.add(room) + else: + sender = self.data_sender.get(room) + if sender is None: + self.logger.error("DataSender is not initialized for room=%s", room) + self.remove(room) + else: + self._data_mgr_sidecar.watch_output(room, sender) + complete_queue.add(room) except Exception: self.logger.exception("Failed to process request for room=%s", room) - complete_queue.pop(room, None) self.remove(room) + finally: + exec_queue.popleft() - completed_rooms: List[int] = [] - for room in list(complete_queue.keys()): - sender = self.data_sender.get(room) - if sender is None: - completed_rooms.append(room) - continue - - status = sender.poll() - if status == DataPoll.Success: - completed_rooms.append(room) - elif status == DataPoll.Failed: + completed_outputs = self._data_mgr_sidecar.pop_completed_outputs() + for room, status in completed_outputs: + if status == DataPoll.Failed: self.logger.error("DataSender transfer failed for room=%s", room) - completed_rooms.append(room) - - for room in completed_rooms: - complete_queue.pop(room, None) + complete_queue.discard(room) self.remove(room) if stop_event is not None and stop_event.is_set() and not req_queue and not waiting_queue and not exec_queue and not complete_queue: diff --git a/lightx2v/models/schedulers/wan/scheduler.py b/lightx2v/models/schedulers/wan/scheduler.py index 19cc3f844..047118630 100755 --- a/lightx2v/models/schedulers/wan/scheduler.py +++ b/lightx2v/models/schedulers/wan/scheduler.py @@ -28,6 +28,16 @@ def __init__(self, config): self.caching_records_2 = [True] * self.config["infer_steps"] self.head_size = self.config["dim"] // 
self.config["num_heads"] + def refresh_from_config(self, config): + self.config = config + self.infer_steps = int(self.config["infer_steps"]) + self.target_video_length = int(self.config["target_video_length"]) + self.sample_shift = float(self.config["sample_shift"]) + self.sample_guide_scale = self.config["sample_guide_scale"] + self.caching_records = [True] * self.infer_steps + self.caching_records_2 = [True] * self.infer_steps + self.step_index = 0 + def prepare(self, seed, latent_shape, image_encoder_output=None): if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.vae_encoder_out = image_encoder_output["vae_encoder_out"] diff --git a/scripts/disagg/kill_service.sh b/scripts/disagg/kill_service.sh index 57f79bc4d..4003e93cc 100755 --- a/scripts/disagg/kill_service.sh +++ b/scripts/disagg/kill_service.sh @@ -2,14 +2,19 @@ set -euo pipefail -SCRIPT_NAME="run_wan22_i2v_distill.sh" +SCRIPT_NAMES=("run_wan22_i2v_distill.sh" "run_dynamic.sh") list_port=(5566 12788 17788 27788) n=30 list_n=($(seq 0 $((n-1)))) -PORTS=(5555 7788 7789 7790 12787) +PORTS=(5555 12787) + +# Monitor ports for autoscaled services are contiguous from 7788. 
+for p in $(seq 7788 7803); do + PORTS+=($p) +done for a in "${list_port[@]}"; do for b in "${list_n[@]}"; do @@ -22,6 +27,10 @@ kill_pid_gracefully() { if [[ -z "$pid" ]]; then return fi + if is_protected_pid "$pid"; then + echo "Skip protected pid=$pid" + return + fi if kill -0 "$pid" 2>/dev/null; then kill "$pid" 2>/dev/null || true sleep 1 @@ -31,6 +40,32 @@ kill_pid_gracefully() { fi } +declare -a PROTECTED_PIDS=() +collect_protected_pids() { + local cur="$$" + while [[ -n "$cur" ]] && [[ "$cur" != "0" ]]; do + PROTECTED_PIDS+=("$cur") + local parent + parent=$(ps -o ppid= -p "$cur" 2>/dev/null | tr -d ' ' || true) + if [[ -z "$parent" ]] || [[ "$parent" == "$cur" ]]; then + break + fi + cur="$parent" + done +} + +is_protected_pid() { + local target="$1" + for p in "${PROTECTED_PIDS[@]}"; do + if [[ "$p" == "$target" ]]; then + return 0 + fi + done + return 1 +} + +collect_protected_pids + find_listen_pids_by_port() { local port="$1" @@ -59,22 +94,26 @@ find_listen_pids_by_port() { echo "No supported tool found to query listening ports (need one of: lsof, ss, fuser)." >&2 } -echo "Stopping script process: ${SCRIPT_NAME}" -script_pids=$(pgrep -f "$SCRIPT_NAME" || true) -if [[ -n "${script_pids}" ]]; then - while read -r pid; do - [[ -z "$pid" ]] && continue - echo "Killing script pid=$pid" - kill_pid_gracefully "$pid" - done <<< "$script_pids" -else - echo "No running process found for ${SCRIPT_NAME}" -fi +for script_name in "${SCRIPT_NAMES[@]}"; do + echo "Stopping script process: ${script_name}" + script_pids=$(pgrep -f "$script_name" || true) + if [[ -n "${script_pids}" ]]; then + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing script pid=$pid" + kill_pid_gracefully "$pid" + done <<< "$script_pids" + else + echo "No running process found for ${script_name}" + fi +done # Fallback cleanup for orphaned disagg service processes. 
cleanup_patterns=( "lightx2v.disagg.examples.run_service" + "lightx2v.disagg.examples.run_user" "python -m lightx2v.disagg" + "conda run -n lightx2v bash scripts/disagg/run_dynamic.sh" "conda run -n lightx2v bash scripts/disagg/run_wan22_i2v_distill.sh" ) diff --git a/scripts/disagg/run_dynamic.sh b/scripts/disagg/run_dynamic.sh index 80120bdd2..087f82682 100644 --- a/scripts/disagg/run_dynamic.sh +++ b/scripts/disagg/run_dynamic.sh @@ -10,6 +10,9 @@ export PYTHONPATH=${PYTHONPATH:-} source ${lightx2v_path}/scripts/base/base.sh +# Ensure stale disagg services/ports from previous runs do not block bootstrap. +bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + export CC=/usr/bin/gcc-13 export CXX=/usr/bin/g++-13 export CUDAHOSTCXX=/usr/bin/g++-13 @@ -30,6 +33,7 @@ fi export DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST:-127.0.0.1} export DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT:-12786} +export LOAD_FROM_USER=${LOAD_FROM_USER:-0} controller_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_controller.json seed=${SEED:-42} @@ -40,6 +44,75 @@ save_result_path=${SAVE_RESULT_PATH:-${lightx2v_path}/save_results/wan22_i2v_dyn controller_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_controller.log user_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_user.log +controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-3000} +controller_poll_interval_s=${CONTROLLER_POLL_INTERVAL_S:-5} +fatal_watch_interval_s=${FATAL_WATCH_INTERVAL_S:-2} +fatal_flag_file=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_fatal.flag + +rm -f "${fatal_flag_file}" + +has_fatal_log_error() { + local log_path="$1" + [[ -f "${log_path}" ]] || return 1 + + # Fail-fast on known fatal patterns so we do not wait for full run completion. 
+ rg -q "KeyError: '/psm_|resource_tracker: There appear to be [0-9]+ leaked shared_memory objects|Failed to process request for room=|Data(Sender|Receiver) transfer failed for room=" "${log_path}" +} + +start_fatal_watchdog() { + ( + while true; do + if [[ -f "${fatal_flag_file}" ]]; then + exit 0 + fi + if [[ -n "${controller_pid:-}" ]] && ! kill -0 "${controller_pid}" 2>/dev/null; then + exit 0 + fi + if has_fatal_log_error "${controller_log}" || has_fatal_log_error "${user_log}"; then + echo "fatal error detected in logs, stopping services immediately" + : > "${fatal_flag_file}" + [[ -n "${user_pid:-}" ]] && kill -TERM "${user_pid}" 2>/dev/null || true + [[ -n "${controller_pid:-}" ]] && kill -TERM "${controller_pid}" 2>/dev/null || true + # Give controller/sidecars a short grace window to release rooms. + for _ in $(seq 1 10); do + local_alive=0 + if [[ -n "${user_pid:-}" ]] && kill -0 "${user_pid}" 2>/dev/null; then + local_alive=1 + fi + if [[ -n "${controller_pid:-}" ]] && kill -0 "${controller_pid}" 2>/dev/null; then + local_alive=1 + fi + if [[ "${local_alive}" -eq 0 ]]; then + break + fi + sleep 0.5 + done + bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + exit 0 + fi + sleep "${fatal_watch_interval_s}" + done + ) & + watchdog_pid=$! +} + +is_controller_stuck() { + local log_path="$1" + [[ -f "${log_path}" ]] || return 1 + + local tail_block + tail_block=$(tail -n 240 "${log_path}" 2>/dev/null || true) + [[ -n "${tail_block}" ]] || return 1 + + # Waiting for decoder results, all GPUs idle, and queues still pending => hard-stuck. + if echo "${tail_block}" | rg -q "Waiting for decoder results" \ + && echo "${tail_block}" | rg -q "queue_total_pending': [1-9]" \ + && ! 
echo "${tail_block}" | rg -q "gpu_utilization': ([1-9][0-9]*|0\\.[1-9])"; then + return 0 + fi + return 1 +} + cleanup() { local pids=("${user_pid:-}" "${controller_pid:-}") for pid in "${pids[@]}"; do @@ -47,6 +120,10 @@ cleanup() { kill "${pid}" 2>/dev/null || true fi done + if [[ -n "${watchdog_pid:-}" ]] && kill -0 "${watchdog_pid}" 2>/dev/null; then + kill "${watchdog_pid}" 2>/dev/null || true + fi + bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true } trap cleanup EXIT INT TERM @@ -67,18 +144,60 @@ controller_pid=$! echo "controller started pid=${controller_pid}" sleep 8 -python -m lightx2v.disagg.examples.run_user \ - --controller_host "${DISAGG_CONTROLLER_HOST}" \ - --controller_request_port "${DISAGG_CONTROLLER_REQUEST_PORT}" \ - > ${user_log} 2>&1 & -user_pid=$! +if [[ "${LOAD_FROM_USER}" != "0" ]]; then + python -m lightx2v.disagg.examples.run_user \ + --controller_host "${DISAGG_CONTROLLER_HOST}" \ + --controller_request_port "${DISAGG_CONTROLLER_REQUEST_PORT}" \ + > ${user_log} 2>&1 & + user_pid=$! 
+ echo "run_user started pid=${user_pid}" +else + echo "LOAD_FROM_USER=${LOAD_FROM_USER}, skip starting run_user" +fi -echo "run_user started pid=${user_pid}" +start_fatal_watchdog -wait ${user_pid} -echo "run_user finished" +if [[ -n "${user_pid:-}" ]]; then + wait ${user_pid} || true + echo "run_user finished" +fi + +if [[ -f "${fatal_flag_file}" ]]; then + echo "fatal error handled by watchdog, exiting early" + wait "${controller_pid}" 2>/dev/null || true + exit 125 +fi + +controller_wait_start=$(date +%s) +while kill -0 "${controller_pid}" 2>/dev/null; do + now_ts=$(date +%s) + elapsed=$((now_ts - controller_wait_start)) + + if (( elapsed >= controller_wait_timeout_s )); then + if is_controller_stuck "${controller_log}"; then + echo "controller stuck detected (all GPUs idle with pending queues), force killing services" + else + echo "controller wait timeout (${controller_wait_timeout_s}s), force killing services" + fi + kill "${controller_pid}" 2>/dev/null || true + bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + wait "${controller_pid}" 2>/dev/null || true + exit 124 + fi + + if [[ -f "${fatal_flag_file}" ]]; then + echo "fatal error handled by watchdog, exiting early" + wait "${controller_pid}" 2>/dev/null || true + exit 125 + fi + + sleep "${controller_poll_interval_s}" +done wait ${controller_pid} +if [[ -n "${watchdog_pid:-}" ]] && kill -0 "${watchdog_pid}" 2>/dev/null; then + kill "${watchdog_pid}" 2>/dev/null || true +fi echo "controller finished" echo "logs: ${controller_log} ${user_log}" From b24f32b4b9b2376e543b04a2a26b6194144abc31 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Mon, 20 Apr 2026 15:36:27 +0800 Subject: [PATCH 4/9] fix merge bugs --- lightx2v/disagg/rdma_buffer.py | 28 ++----- lightx2v/disagg/services/controller.py | 103 +++++++++++++++---------- scripts/disagg/run_wan_t2v_service.sh | 2 +- 3 files changed, 69 insertions(+), 64 deletions(-) diff --git a/lightx2v/disagg/rdma_buffer.py 
b/lightx2v/disagg/rdma_buffer.py index 15c4419ed..bb03ad4ed 100644 --- a/lightx2v/disagg/rdma_buffer.py +++ b/lightx2v/disagg/rdma_buffer.py @@ -41,11 +41,6 @@ class RDMABuffer: - client: consumer side, reads slots remotely and updates head by rdma_faa. The ring stores serialized JSON configs in fixed-size slots. - - Multi-consumer note: multiple client processes calling ``consume()`` compete on the - same head pointer. Unless the backend implements a true remote atomic fetch-add - (see ``RDMAClient.rdma_faa``), correctness under heavy parallel consumption is not - guaranteed. Prefer one consumer per ring or low parallelism for production. """ def __init__( @@ -94,8 +89,8 @@ def __init__( base_addr = int(info["addr"]) need_bytes = 16 + self.buffer_size * self.slot_size self.rdma_server.register_memory(base_addr, need_bytes) - self.rdma_server.write_memory(base_addr, (0).to_bytes(8, byteorder="big", signed=False)) - self.rdma_server.write_memory(base_addr + 8, (0).to_bytes(8, byteorder="big", signed=False)) + self.rdma_server.write_memory(base_addr, (0).to_bytes(8, byteorder="little", signed=False)) + self.rdma_server.write_memory(base_addr + 8, (0).to_bytes(8, byteorder="little", signed=False)) self._descriptor = RDMABufferDescriptor( slot_addr=base_addr + 16, slot_bytes=self.buffer_size * self.slot_size, @@ -128,10 +123,10 @@ def descriptor(self) -> RDMABufferDescriptor: return self._descriptor def _write_local_u64(self, buf: bytearray, value: int): - buf[:8] = int(value & _U64_MASK).to_bytes(8, byteorder="big", signed=False) + buf[:8] = (int(value) & _U64_MASK).to_bytes(8, byteorder="little", signed=False) def _read_local_u64(self, buf: bytearray) -> int: - return int.from_bytes(bytes(buf[:8]), byteorder="big", signed=False) + return int.from_bytes(bytes(buf[:8]), byteorder="little", signed=False) def _u64_distance(self, newer: int, older: int) -> int: """Return unsigned circular distance on a 64-bit counter space.""" @@ -145,7 +140,7 @@ def _rdma_faa(self, 
ptr_addr: int, add_value: int) -> int: with self._lock: old = self._read_remote_u64(ptr_addr) new = (old + int(add_value)) & ((1 << 64) - 1) - self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="big", signed=False)) + self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="little", signed=False)) return old # Fallback: local atomic emulation (useful for single-process validation). @@ -245,7 +240,7 @@ def _rdma_write_bytes(self, remote_addr: int, payload: bytes): def _read_remote_u64(self, remote_addr: int) -> int: raw = self._rdma_read_bytes(remote_addr, 8) - return int.from_bytes(raw, byteorder="big", signed=False) + return int.from_bytes(raw, byteorder="little", signed=False) def _slot_offset(self, index: int) -> int: return (index % self.buffer_size) * self.slot_size @@ -278,16 +273,7 @@ def produce(self, config: Dict[str, Any]) -> int: # Read current indices first, write the slot fully, then publish by advancing tail. old_tail = self._read_remote_u64(self.descriptor.tail_addr) cur_head = self._read_remote_u64(self.descriptor.head_addr) - used = (old_tail + 1) - cur_head - if used > self.buffer_size: - self._rdma_faa(self.descriptor.tail_addr, -1) - logger.error( - "Ring buffer full: old_tail=%d cur_head=%d used=%d buffer_size=%d", - old_tail, - cur_head, - used, - self.buffer_size, - ) + if self._u64_distance(old_tail, cur_head) >= self.buffer_size: raise BufferError("ring buffer is full") slot_idx = old_tail % self.buffer_size diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 380a01286..9273bbab4 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -54,6 +54,13 @@ def __init__(self): self._sidecar_drain_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_TIMEOUT_SECONDS", "0")) self._sidecar_reclaim_threads: list[Thread] = [] self._shutting_down: bool = False + self._enable_monitor: bool = False + + def _is_monitor_enabled(self) -> bool: + raw 
= os.getenv("ENABLE_MONITOR") + if raw is None: + return False + return str(raw).strip().lower() in {"1", "true", "yes", "on"} def _is_tcp_port_open(self, host: str, port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: @@ -731,9 +738,10 @@ def create_instance(self, instance_type: str) -> str: instance_address = f"{self._bootstrap_addr}:{REQUEST_POLLING_PORT + gpu_id}" self._free_gpus.remove(gpu_id) # self.add_instance(instance_type, instance_address) - monitor_node = f"tcp://{self._bootstrap_addr}:{MONITOR_POLLING_PORT + gpu_id}" - if monitor_node not in self.monitor.nodes: - self.monitor.nodes.append(monitor_node) + if self._enable_monitor: + monitor_node = f"tcp://{self._bootstrap_addr}:{MONITOR_POLLING_PORT + gpu_id}" + if monitor_node not in self.monitor.nodes: + self.monitor.nodes.append(monitor_node) self._managed_instances[instance_address] = { "instance_type": instance_type, "gpu_id": gpu_id, @@ -800,7 +808,7 @@ def reclaim_instance(self, instance_type: str, instance_address: str | None = No except subprocess.TimeoutExpired as exc: raise RuntimeError(f"process did not exit after kill for {instance_type} instance {target_address}") from exc - if monitor_node in self.monitor.nodes: + if self._enable_monitor and monitor_node in self.monitor.nodes: self.monitor.nodes.remove(monitor_node) monitor_port = MONITOR_POLLING_PORT + gpu_id @@ -1060,6 +1068,7 @@ def run(self, config): self._bootstrap_addr = str(bootstrap_addr) self._runtime_config = self._to_plain(config) self._init_gpu_pool(config) + self._enable_monitor = self._is_monitor_enabled() # self.encoder_policy = RoundRobinPolicy() # self.transformer_policy = RoundRobinPolicy() @@ -1074,44 +1083,52 @@ def run(self, config): for _ in range(5): self.create_instance("transformer") - monitor_stop_event = Event() - warmup_duration_s = self._load_warmup_duration_seconds(config) - autoscale_start_mono = time.monotonic() - warmup_skip_logged = False - warmup_end_logged = False - 
scale_out_threshold = 80.0 - scale_out_max_queue_threshold = 2 - scale_in_threshold = 20.0 - scale_cooldown_seconds = 30.0 - last_scale_ts: dict[str, float] = { - "encoder": 0.0, - "transformer": 0.0, - "decoder": 0.0, - } + monitor_stop_event: Event | None = None + monitor_thread: Thread | None = None + self._monitor_runtime = None + + if self._enable_monitor: + monitor_stop_event = Event() + warmup_duration_s = self._load_warmup_duration_seconds(config) + autoscale_start_mono = time.monotonic() + warmup_skip_logged = False + warmup_end_logged = False + scale_out_threshold = 80.0 + scale_out_max_queue_threshold = 2 + scale_in_threshold = 20.0 + scale_cooldown_seconds = 30.0 + last_scale_ts: dict[str, float] = { + "encoder": 0.0, + "transformer": 0.0, + "decoder": 0.0, + } - self._monitor_runtime = { - "warmup_duration_s": warmup_duration_s, - "autoscale_start_mono": autoscale_start_mono, - "warmup_skip_logged": warmup_skip_logged, - "warmup_end_logged": warmup_end_logged, - "scale_out_threshold": scale_out_threshold, - "scale_out_max_queue_threshold": scale_out_max_queue_threshold, - "scale_in_threshold": scale_in_threshold, - "scale_cooldown_seconds": scale_cooldown_seconds, - "last_scale_ts": last_scale_ts, - } + self._monitor_runtime = { + "warmup_duration_s": warmup_duration_s, + "autoscale_start_mono": autoscale_start_mono, + "warmup_skip_logged": warmup_skip_logged, + "warmup_end_logged": warmup_end_logged, + "scale_out_threshold": scale_out_threshold, + "scale_out_max_queue_threshold": scale_out_max_queue_threshold, + "scale_in_threshold": scale_in_threshold, + "scale_cooldown_seconds": scale_cooldown_seconds, + "last_scale_ts": last_scale_ts, + } - monitor_thread = Thread( - target=self.monitor.run_forever, - kwargs={ - "interval_seconds": 2.0, - "callback": self._monitor_callback, - "stop_event": monitor_stop_event, - }, - name="controller-monitor", - daemon=True, - ) - monitor_thread.start() + monitor_thread = Thread( + target=self.monitor.run_forever, + 
kwargs={ + "interval_seconds": 2.0, + "callback": self._monitor_callback, + "stop_event": monitor_stop_event, + }, + name="controller-monitor", + daemon=True, + ) + monitor_thread.start() + self.logger.info("ENABLE_MONITOR enabled, monitor thread started") + else: + self.logger.info("ENABLE_MONITOR is not set, skip monitor logic") time.sleep(5.0) @@ -1239,8 +1256,10 @@ def run(self, config): ) finally: self._shutting_down = True - monitor_stop_event.set() - monitor_thread.join(timeout=2.0) + if monitor_stop_event is not None: + monitor_stop_event.set() + if monitor_thread is not None: + monitor_thread.join(timeout=2.0) self._monitor_runtime = None for instance_type, address in reversed(list(self.started_instances)): diff --git a/scripts/disagg/run_wan_t2v_service.sh b/scripts/disagg/run_wan_t2v_service.sh index e60d54343..2e1e41842 100755 --- a/scripts/disagg/run_wan_t2v_service.sh +++ b/scripts/disagg/run_wan_t2v_service.sh @@ -158,4 +158,4 @@ while true; do wait_seconds=$((wait_seconds + 5)) done -sleep 30 \ No newline at end of file +sleep 30 From a7bf5078144b4dbe7dc76e548f7e479f005f4a25 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Mon, 20 Apr 2026 19:32:04 +0800 Subject: [PATCH 5/9] test multi machine --- .../disagg/wan22_i2v_distill_controller.json | 69 ++++++++++++++++++- configs/disagg/wan22_i2v_distill_decoder.json | 4 +- configs/disagg/wan22_i2v_distill_encoder.json | 4 +- .../disagg/wan22_i2v_distill_transformer.json | 4 +- lightx2v/disagg/examples/run_service.py | 12 ++-- 5 files changed, 80 insertions(+), 13 deletions(-) diff --git a/configs/disagg/wan22_i2v_distill_controller.json b/configs/disagg/wan22_i2v_distill_controller.json index 99483393e..a8d678fe6 100644 --- a/configs/disagg/wan22_i2v_distill_controller.json +++ b/configs/disagg/wan22_i2v_distill_controller.json @@ -45,14 +45,77 @@ "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", "disagg_mode": "controller", "disagg_config": { - 
"bootstrap_addr": "127.0.0.1", + "bootstrap_addr": "192.168.0.166", "bootstrap_room": 0, "ranks": 8, "encoder_engine_rank": 0, "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "localhost", - "metadata_server": "P2PHANDSHAKE" + "local_hostname": "192.168.0.166", + "metadata_server": "P2PHANDSHAKE", + "remote_workdir": "/root/zht/LightX2V", + "remote_python_executable": "python", + "remote_activate_cmd": "source /root/miniconda3/etc/profile.d/conda.sh && conda activate lightx2v", + "remote_log_dir": "/root/zht/LightX2V/save_results", + "ssh_user": "root", + "ssh_options": [ + "-i", + "/root/.ssh/id_ed25519_zht", + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=no" + ], + "static_instance_slots": [ + { + "instance_type": "encoder", + "host": "192.168.0.166", + "engine_rank": 0, + "cuda_device": 0 + }, + { + "instance_type": "transformer", + "host": "192.168.0.139", + "engine_rank": 1, + "cuda_device": 0 + }, + { + "instance_type": "transformer", + "host": "192.168.0.139", + "engine_rank": 2, + "cuda_device": 1 + }, + { + "instance_type": "transformer", + "host": "192.168.0.139", + "engine_rank": 3, + "cuda_device": 2 + }, + { + "instance_type": "transformer", + "host": "192.168.0.139", + "engine_rank": 4, + "cuda_device": 3 + }, + { + "instance_type": "transformer", + "host": "192.168.0.139", + "engine_rank": 5, + "cuda_device": 4 + }, + { + "instance_type": "transformer", + "host": "192.168.0.139", + "engine_rank": 6, + "cuda_device": 5 + }, + { + "instance_type": "decoder", + "host": "192.168.0.139", + "engine_rank": 7, + "cuda_device": 6 + } + ] } } diff --git a/configs/disagg/wan22_i2v_distill_decoder.json b/configs/disagg/wan22_i2v_distill_decoder.json index 943d42f73..46e6d668c 100644 --- a/configs/disagg/wan22_i2v_distill_decoder.json +++ b/configs/disagg/wan22_i2v_distill_decoder.json @@ -45,14 +45,14 @@ "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", 
"disagg_mode": "decoder", "disagg_config": { - "bootstrap_addr": "127.0.0.1", + "bootstrap_addr": "192.168.0.166", "bootstrap_room": 0, "ranks": 8, "encoder_engine_rank": 0, "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "localhost", + "local_hostname": "192.168.0.139", "metadata_server": "P2PHANDSHAKE" } } diff --git a/configs/disagg/wan22_i2v_distill_encoder.json b/configs/disagg/wan22_i2v_distill_encoder.json index 1e96bfe2c..27cadaaf0 100644 --- a/configs/disagg/wan22_i2v_distill_encoder.json +++ b/configs/disagg/wan22_i2v_distill_encoder.json @@ -45,14 +45,14 @@ "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", "disagg_mode": "encoder", "disagg_config": { - "bootstrap_addr": "127.0.0.1", + "bootstrap_addr": "192.168.0.166", "bootstrap_room": 0, "ranks": 8, "encoder_engine_rank": 0, "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "localhost", + "local_hostname": "192.168.0.166", "metadata_server": "P2PHANDSHAKE" } } diff --git a/configs/disagg/wan22_i2v_distill_transformer.json b/configs/disagg/wan22_i2v_distill_transformer.json index e28fdfb6b..a629ff2f4 100644 --- a/configs/disagg/wan22_i2v_distill_transformer.json +++ b/configs/disagg/wan22_i2v_distill_transformer.json @@ -45,14 +45,14 @@ "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", "disagg_mode": "transformer", "disagg_config": { - "bootstrap_addr": "127.0.0.1", + "bootstrap_addr": "192.168.0.166", "bootstrap_room": 0, "ranks": 8, "encoder_engine_rank": 0, "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "localhost", + "local_hostname": "192.168.0.139", "metadata_server": "P2PHANDSHAKE" } } diff --git a/lightx2v/disagg/examples/run_service.py b/lightx2v/disagg/examples/run_service.py index 515265214..f24140dfb 100644 --- a/lightx2v/disagg/examples/run_service.py +++ 
b/lightx2v/disagg/examples/run_service.py @@ -4,10 +4,6 @@ from loguru import logger -from lightx2v.disagg.services.controller import ControllerService -from lightx2v.disagg.services.decoder import DecoderService -from lightx2v.disagg.services.encoder import EncoderService -from lightx2v.disagg.services.transformer import TransformerService from lightx2v.disagg.utils import set_config from lightx2v.utils.utils import seed_all @@ -124,12 +120,20 @@ def main(): logger.info("Starting disagg service mode={}", service_mode) if service_mode == "encoder": + from lightx2v.disagg.services.encoder import EncoderService + EncoderService(config).run() elif service_mode == "transformer": + from lightx2v.disagg.services.transformer import TransformerService + TransformerService(config).run() elif service_mode == "decoder": + from lightx2v.disagg.services.decoder import DecoderService + DecoderService(config).run() elif service_mode == "controller": + from lightx2v.disagg.services.controller import ControllerService + ControllerService().run(config) else: raise ValueError(f"Unsupported service mode: {service_mode}") From fe2522105fb9036a3a249688911971a9942e5cc9 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Thu, 23 Apr 2026 10:55:54 +0800 Subject: [PATCH 6/9] multi machine --- .../wan22_i2v_distill_controller.json | 8 +- .../wan22_i2v_distill_decoder.json | 2 +- .../wan22_i2v_distill_encoder.json | 0 .../wan22_i2v_distill_transformer.json | 0 .../wan_t2v_disagg_controller.json | 0 .../wan_t2v_disagg_decoder.json | 0 .../wan_t2v_disagg_encoder.json | 0 .../wan_t2v_disagg_transformer.json | 0 .../wan22_i2v_distill_controller.json | 108 +++ configs/disagg/wan22_i2v_workload_stages.json | 6 +- lightx2v/disagg/rdma_client.py | 22 +- lightx2v/disagg/rdma_server.py | 20 +- lightx2v/disagg/rdma_utils.py | 84 ++ lightx2v/disagg/services/base.py | 63 +- lightx2v/disagg/services/controller.py | 785 +++++++++++++++--- lightx2v/disagg/services/decoder.py | 6 +- 
lightx2v/disagg/services/encoder.py | 8 +- lightx2v/disagg/services/instance_proxy.py | 251 ++++++ lightx2v/disagg/services/transformer.py | 8 +- lightx2v/disagg/workload.py | 2 +- scripts/disagg/extract_dynamic_latency.py | 36 +- scripts/disagg/run_dynamic.sh | 256 +++++- scripts/disagg/run_wan22_i2v_distill.sh | 116 --- scripts/disagg/run_wan_t2v_service.sh | 161 ---- 24 files changed, 1490 insertions(+), 452 deletions(-) rename configs/disagg/{ => multi_node}/wan22_i2v_distill_controller.json (93%) rename configs/disagg/{ => multi_node}/wan22_i2v_distill_decoder.json (97%) rename configs/disagg/{ => multi_node}/wan22_i2v_distill_encoder.json (100%) rename configs/disagg/{ => multi_node}/wan22_i2v_distill_transformer.json (100%) rename configs/disagg/{ => multi_node}/wan_t2v_disagg_controller.json (100%) rename configs/disagg/{ => multi_node}/wan_t2v_disagg_decoder.json (100%) rename configs/disagg/{ => multi_node}/wan_t2v_disagg_encoder.json (100%) rename configs/disagg/{ => multi_node}/wan_t2v_disagg_transformer.json (100%) create mode 100644 configs/disagg/single_node/wan22_i2v_distill_controller.json create mode 100644 lightx2v/disagg/rdma_utils.py create mode 100644 lightx2v/disagg/services/instance_proxy.py delete mode 100755 scripts/disagg/run_wan22_i2v_distill.sh delete mode 100755 scripts/disagg/run_wan_t2v_service.sh diff --git a/configs/disagg/wan22_i2v_distill_controller.json b/configs/disagg/multi_node/wan22_i2v_distill_controller.json similarity index 93% rename from configs/disagg/wan22_i2v_distill_controller.json rename to configs/disagg/multi_node/wan22_i2v_distill_controller.json index a8d678fe6..c4c691ec7 100644 --- a/configs/disagg/wan22_i2v_distill_controller.json +++ b/configs/disagg/multi_node/wan22_i2v_distill_controller.json @@ -56,8 +56,10 @@ "metadata_server": "P2PHANDSHAKE", "remote_workdir": "/root/zht/LightX2V", "remote_python_executable": "python", - "remote_activate_cmd": "source /root/miniconda3/etc/profile.d/conda.sh && conda 
activate lightx2v", + "remote_activate_cmd": "source /root/miniconda3/etc/profile.d/conda.sh && conda activate lightx2v && export LD_LIBRARY_PATH=/root/miniconda3/envs/lightx2v/lib:${LD_LIBRARY_PATH:-}", "remote_log_dir": "/root/zht/LightX2V/save_results", + "use_remote_proxy": true, + "remote_proxy_req_base_port": 28000, "ssh_user": "root", "ssh_options": [ "-i", @@ -112,9 +114,9 @@ }, { "instance_type": "decoder", - "host": "192.168.0.139", + "host": "192.168.0.166", "engine_rank": 7, - "cuda_device": 6 + "cuda_device": 7 } ] } diff --git a/configs/disagg/wan22_i2v_distill_decoder.json b/configs/disagg/multi_node/wan22_i2v_distill_decoder.json similarity index 97% rename from configs/disagg/wan22_i2v_distill_decoder.json rename to configs/disagg/multi_node/wan22_i2v_distill_decoder.json index 46e6d668c..31b7fa418 100644 --- a/configs/disagg/wan22_i2v_distill_decoder.json +++ b/configs/disagg/multi_node/wan22_i2v_distill_decoder.json @@ -52,7 +52,7 @@ "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "192.168.0.139", + "local_hostname": "192.168.0.166", "metadata_server": "P2PHANDSHAKE" } } diff --git a/configs/disagg/wan22_i2v_distill_encoder.json b/configs/disagg/multi_node/wan22_i2v_distill_encoder.json similarity index 100% rename from configs/disagg/wan22_i2v_distill_encoder.json rename to configs/disagg/multi_node/wan22_i2v_distill_encoder.json diff --git a/configs/disagg/wan22_i2v_distill_transformer.json b/configs/disagg/multi_node/wan22_i2v_distill_transformer.json similarity index 100% rename from configs/disagg/wan22_i2v_distill_transformer.json rename to configs/disagg/multi_node/wan22_i2v_distill_transformer.json diff --git a/configs/disagg/wan_t2v_disagg_controller.json b/configs/disagg/multi_node/wan_t2v_disagg_controller.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_controller.json rename to configs/disagg/multi_node/wan_t2v_disagg_controller.json diff --git 
a/configs/disagg/wan_t2v_disagg_decoder.json b/configs/disagg/multi_node/wan_t2v_disagg_decoder.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_decoder.json rename to configs/disagg/multi_node/wan_t2v_disagg_decoder.json diff --git a/configs/disagg/wan_t2v_disagg_encoder.json b/configs/disagg/multi_node/wan_t2v_disagg_encoder.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_encoder.json rename to configs/disagg/multi_node/wan_t2v_disagg_encoder.json diff --git a/configs/disagg/wan_t2v_disagg_transformer.json b/configs/disagg/multi_node/wan_t2v_disagg_transformer.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_transformer.json rename to configs/disagg/multi_node/wan_t2v_disagg_transformer.json diff --git a/configs/disagg/single_node/wan22_i2v_distill_controller.json b/configs/disagg/single_node/wan22_i2v_distill_controller.json new file mode 100644 index 000000000..e8e4f577b --- /dev/null +++ b/configs/disagg/single_node/wan22_i2v_distill_controller.json @@ -0,0 +1,108 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": 
"/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "disagg_mode": "controller", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "127.0.0.1", + "metadata_server": "P2PHANDSHAKE", + "static_instance_slots": [ + { + "instance_type": "encoder", + "host": "127.0.0.1", + "engine_rank": 0, + "cuda_device": 0 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 1, + "cuda_device": 1 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 2, + "cuda_device": 2 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 3, + "cuda_device": 3 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 4, + "cuda_device": 4 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 5, + "cuda_device": 5 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 6, + "cuda_device": 6 + }, + { + "instance_type": "decoder", + "host": "127.0.0.1", + "engine_rank": 7, + "cuda_device": 7 + } + ] + } +} diff --git a/configs/disagg/wan22_i2v_workload_stages.json b/configs/disagg/wan22_i2v_workload_stages.json index 19787956e..edf619f89 100644 --- 
a/configs/disagg/wan22_i2v_workload_stages.json +++ b/configs/disagg/wan22_i2v_workload_stages.json @@ -7,20 +7,20 @@ "wait_time_s": 0.0, "config_variants": [ { - "infer_steps": 1, + "infer_steps": 4, "sample_shift": 5.0 } ] }, { "name": "change", - "duration_s": 180, + "duration_s": 1000, "user_count": 1, "spawn_rate": 0.1, "wait_time_s": 0.0, "config_variants": [ { - "infer_steps": 1, + "infer_steps": 4, "sample_shift": 5.0 } ] diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index 5a62cb343..fd88dc615 100644 --- a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -16,6 +16,8 @@ from pyverbs.wr import SGE from pyverbs.wr import SendWR as WR +from lightx2v.disagg.rdma_utils import resolve_gid_index + logger = logging.getLogger(__name__) @@ -111,25 +113,7 @@ def _wc_status_name(self, status: int | None) -> str: return status_map.get(status, f"IBV_WC_STATUS_{status}") def _resolve_gid_index(self): - env_gid = os.getenv("RDMA_GID_INDEX", "").strip() - if env_gid: - idx = int(env_gid) - self.ctx.query_gid(port_num=self.port_num, index=idx) - return idx - - # Prefer IPv4-mapped RoCE entries for Ethernet-based RDMA devices. - preferred = [2, 0, 1, 3, 4, 5, 6, 7] - for idx in preferred: - try: - gid = str(self.ctx.query_gid(port_num=self.port_num, index=idx)) - except Exception: - continue - if gid and gid != "::": - return idx - - # Last resort: let query_gid raise a descriptive error for index 0. 
- self.ctx.query_gid(port_num=self.port_num, index=0) - return 0 + return resolve_gid_index(self.ctx, self.port_num) def _alloc_local_psn(self): self._next_psn = (self._next_psn + 1) & 0xFFFFFF diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index 0b161761d..b8a1357a5 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -11,6 +11,8 @@ from pyverbs.pd import PD from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr +from lightx2v.disagg.rdma_utils import resolve_gid_index + class IBDevice: def __init__(self, name: str): @@ -98,23 +100,7 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): print(f"[Server] MR Registered. Addr: {mr_addr}, RKey: {self.mr.rkey}") def _resolve_gid_index(self): - env_gid = os.getenv("RDMA_GID_INDEX", "").strip() - if env_gid: - idx = int(env_gid) - self.ctx.query_gid(port_num=self.port_num, index=idx) - return idx - - preferred = [2, 0, 1, 3, 4, 5, 6, 7] - for idx in preferred: - try: - gid = str(self.ctx.query_gid(port_num=self.port_num, index=idx)) - except Exception: - continue - if gid and gid != "::": - return idx - - self.ctx.query_gid(port_num=self.port_num, index=0) - return 0 + return resolve_gid_index(self.ctx, self.port_num) def register_memory(self, addr: int, length: int): """Validate a requested sub-region against server MR and return registration metadata. 
diff --git a/lightx2v/disagg/rdma_utils.py b/lightx2v/disagg/rdma_utils.py new file mode 100644 index 000000000..fcdba75e5 --- /dev/null +++ b/lightx2v/disagg/rdma_utils.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import os +import socket + + +def _collect_local_ipv4_addresses() -> list[str]: + candidates: list[str] = [] + + try: + hostname = socket.gethostname() + for info in socket.getaddrinfo(hostname, None, socket.AF_INET): + address = info[4][0] + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + except Exception: + pass + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + sock.connect(("8.8.8.8", 80)) + address = sock.getsockname()[0] + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + finally: + sock.close() + except Exception: + pass + + try: + address = socket.gethostbyname(socket.gethostname()) + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + except Exception: + pass + + return candidates + + +def _gid_to_ipv4(gid_text: str) -> str | None: + if gid_text.startswith("::ffff:"): + return gid_text.removeprefix("::ffff:") + return None + + +def resolve_gid_index(ctx, port_num: int, env_var_name: str = "RDMA_GID_INDEX") -> int: + env_gid = os.getenv(env_var_name, "").strip() + if env_gid: + idx = int(env_gid) + ctx.query_gid(port_num=port_num, index=idx) + return idx + + local_ipv4s = _collect_local_ipv4_addresses() + + mapped_candidates: list[tuple[int, str]] = [] + first_non_empty_idx: int | None = None + + for idx in range(16): + try: + gid_text = str(ctx.query_gid(port_num=port_num, index=idx)) + except Exception: + continue + + if not gid_text or gid_text == "::": + continue + + if first_non_empty_idx is None: + first_non_empty_idx = idx + + ipv4 = _gid_to_ipv4(gid_text) + if ipv4 is not None: + mapped_candidates.append((idx, ipv4)) + if ipv4 in 
local_ipv4s: + return idx + + if mapped_candidates: + return mapped_candidates[0][0] + + if first_non_empty_idx is not None: + return first_non_empty_idx + + ctx.query_gid(port_num=port_num, index=0) + return 0 \ No newline at end of file diff --git a/lightx2v/disagg/services/base.py b/lightx2v/disagg/services/base.py index 7aa807ff3..f35c7ad1e 100644 --- a/lightx2v/disagg/services/base.py +++ b/lightx2v/disagg/services/base.py @@ -1,9 +1,62 @@ -import logging from abc import ABC +import sys -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +from loguru import logger as loguru_logger + + +loguru_logger.remove() +loguru_logger.add( + sys.stderr, + level="INFO", + format="[{level}] {time:DD MMM YYYY HH:mm:ss} | {name}:{function}:{line} - {message}", +) + + +class _LoguruLoggerAdapter: + def __init__(self, logger): + self._logger = logger + + @staticmethod + def _format_message(message, args): + if not args: + return message + try: + return message % args + except Exception: + return f"{message} {' '.join(map(str, args))}" + + def debug(self, message, *args, **kwargs): + self._logger.opt(depth=1).debug(self._format_message(message, args)) + + def info(self, message, *args, **kwargs): + self._logger.opt(depth=1).info(self._format_message(message, args)) + + def warning(self, message, *args, **kwargs): + self._logger.opt(depth=1).warning(self._format_message(message, args)) + + def error(self, message, *args, **kwargs): + self._logger.opt(depth=1).error(self._format_message(message, args)) + + def critical(self, message, *args, **kwargs): + self._logger.opt(depth=1).critical(self._format_message(message, args)) + + def exception(self, message, *args, **kwargs): + self._logger.opt(depth=1, exception=True).error(self._format_message(message, args)) + + def log(self, level, message, *args, **kwargs): + self._logger.opt(depth=1).log(level, self._format_message(message, args)) + + def bind(self, **kwargs): + return 
_LoguruLoggerAdapter(self._logger.bind(**kwargs)) + + def opt(self, *args, **kwargs): + return self._logger.opt(*args, **kwargs) + + def __getattr__(self, item): + return getattr(self._logger, item) + + +logger = _LoguruLoggerAdapter(loguru_logger) class BaseService(ABC): @@ -12,7 +65,7 @@ def __init__(self): Base initialization for all services. """ self.logger = logger - self.logger.info(f"Initializing {self.__class__.__name__}") + self.logger.info("Initializing %s", self.__class__.__name__) def _sync_runtime_config(self, config): current_config = getattr(self, "config", None) diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 9273bbab4..e9f76dbe2 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,5 +1,6 @@ import json import os +import shlex import signal import socket import subprocess @@ -48,13 +49,19 @@ def __init__(self): self._gpu_reuse_grace_seconds: float = 5.0 self._graceful_reclaim_timeout_seconds: float = float(os.getenv("DISAGG_RECLAIM_GRACEFUL_TIMEOUT_SECONDS", "30.0")) self._force_kill_wait_seconds: float = float(os.getenv("DISAGG_RECLAIM_FORCE_KILL_WAIT_SECONDS", "1.0")) + self._instance_start_timeout_seconds: float = float(os.getenv("DISAGG_INSTANCE_START_TIMEOUT_SECONDS", "90.0")) self._sidecar_start_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_START_TIMEOUT_SECONDS", "15.0")) self._sidecar_drain_idle_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_IDLE_SECONDS", "1.0")) # <= 0 means wait indefinitely until sidecar pending queues are drained. 
self._sidecar_drain_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_TIMEOUT_SECONDS", "0")) + self._remote_proxy_start_timeout_seconds: float = float(os.getenv("DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS", "20.0")) self._sidecar_reclaim_threads: list[Thread] = [] self._shutting_down: bool = False self._enable_monitor: bool = False + self._static_instance_slots: list[dict[str, Any]] = [] + self._free_slot_ids: set[int] = set() + self._slot_reuse_block_until: dict[int, float] = {} + self._local_host_aliases: set[str] = set() def _is_monitor_enabled(self) -> bool: raw = os.getenv("ENABLE_MONITOR") @@ -76,19 +83,75 @@ def _wait_for_tcp_port_state(self, host: str, port: int, should_be_open: bool, t time.sleep(0.1) return self._is_tcp_port_open(host, port) == should_be_open - def _allocate_free_tcp_port(self) -> int: + def _refresh_local_host_aliases(self): + aliases: set[str] = { + "127.0.0.1", + "localhost", + str(self._bootstrap_addr), + } + try: + hostname = socket.gethostname() + aliases.add(hostname) + aliases.add(socket.getfqdn()) + host_info = socket.gethostbyname_ex(hostname) + aliases.update(host_info[1]) + aliases.update(host_info[2]) + except Exception: + pass + self._local_host_aliases = {item.strip() for item in aliases if isinstance(item, str) and item.strip()} + + def _is_local_host(self, host: str) -> bool: + normalized = str(host).strip() + if not normalized: + return False + if normalized in self._local_host_aliases: + return True + try: + return socket.gethostbyname(normalized) in self._local_host_aliases + except Exception: + return False + + def _allocate_free_tcp_port(self, bind_host: str | None = None) -> int: + host = str(bind_host or self._bootstrap_addr) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind((self._bootstrap_addr, 0)) + sock.bind((host, 0)) return int(sock.getsockname()[1]) - def _query_sidecar(self, req_addr: str, cmd: str) -> dict[str, Any] | None: + def _build_service_command(self, 
instance_type: str, engine_rank: int, instance_cfg: dict[str, Any], service_config_json: str) -> list[str]: + return [ + sys.executable, + "-m", + "lightx2v.disagg.examples.run_service", + "--service", + instance_type, + "--engine_rank", + str(engine_rank), + "--model_cls", + str(instance_cfg.get("model_cls", "wan2.1")), + "--task", + str(instance_cfg.get("task", "t2v")), + "--model_path", + str(instance_cfg.get("model_path")), + "--config_json", + service_config_json, + "--seed", + str(instance_cfg.get("seed", 42)), + "--prompt", + str(instance_cfg.get("prompt", "")), + "--negative_prompt", + str(instance_cfg.get("negative_prompt", "")), + "--save_result_path", + str(instance_cfg.get("save_path", "")), + ] + + def _query_zmq(self, req_addr: str, payload: dict[str, Any], timeout_ms: int = 1000) -> dict[str, Any] | None: context = zmq.Context() req = context.socket(zmq.REQ) - req.setsockopt(zmq.RCVTIMEO, 1000) - req.setsockopt(zmq.SNDTIMEO, 1000) + req.setsockopt(zmq.RCVTIMEO, int(timeout_ms)) + req.setsockopt(zmq.SNDTIMEO, int(timeout_ms)) req.connect(req_addr) try: - req.send_pyobj({"cmd": str(cmd)}) + req.send_pyobj(payload) reply = req.recv_pyobj() if isinstance(reply, dict): return reply @@ -99,11 +162,76 @@ def _query_sidecar(self, req_addr: str, cmd: str) -> dict[str, Any] | None: req.close(0) context.term() - def _start_sidecar_process(self, instance_type: str, gpu_id: int) -> dict[str, Any]: - push_port = self._allocate_free_tcp_port() - req_port = self._allocate_free_tcp_port() - push_addr = f"tcp://{self._bootstrap_addr}:{push_port}" - req_addr = f"tcp://{self._bootstrap_addr}:{req_port}" + def _query_sidecar(self, req_addr: str, cmd: str) -> dict[str, Any] | None: + return self._query_zmq(req_addr, {"cmd": str(cmd)}, timeout_ms=1000) + + def _is_truthy(self, value: Any, default: bool = False) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + def 
_remote_proxy_req_addr(self, slot: dict[str, Any]) -> str: + host = str(slot["host"]) + proxy_req_port = int(slot["proxy_req_port"]) + return f"tcp://{host}:{proxy_req_port}" + + def _ensure_remote_instance_proxy(self, slot: dict[str, Any]): + if not self._is_truthy(slot.get("use_remote_proxy", False)): + return + + req_addr = self._remote_proxy_req_addr(slot) + reply = self._query_zmq(req_addr, {"cmd": "ping"}, timeout_ms=800) + if isinstance(reply, dict) and reply.get("ok", False): + return + + python_executable = str(slot.get("python_executable", sys.executable)) + workdir = str(slot.get("workdir", Path(__file__).resolve().parents[3])) + log_dir = str(slot.get("log_dir", "/tmp/lightx2v_disagg")) + activate_cmd = str(slot.get("activate_cmd", "")).strip() + proxy_req_port = int(slot["proxy_req_port"]) + proxy_log_path = str(slot.get("proxy_log_path", f"{log_dir}/instance_proxy.log")) + + script_lines = [ + "set -e", + f"mkdir -p {shlex.quote(log_dir)}", + f"cd {shlex.quote(workdir)}", + ] + if activate_cmd: + script_lines.append(activate_cmd) + script_lines.extend( + [ + ( + "nohup env PYTHONUNBUFFERED=1 " + f"{shlex.quote(python_executable)} -m lightx2v.disagg.services.instance_proxy " + f"--bind-addr {shlex.quote(f'tcp://0.0.0.0:{proxy_req_port}')} " + f"--workdir {shlex.quote(workdir)} --log-dir {shlex.quote(log_dir)} " + f"> {shlex.quote(proxy_log_path)} 2>&1 &" + ), + "echo PROXY_PID=$!", + ] + ) + script = "\n".join(script_lines) + + self._run_ssh_script(slot, script, timeout_seconds=30.0, check=True) + + deadline = time.time() + self._remote_proxy_start_timeout_seconds + while time.time() < deadline: + probe = self._query_zmq(req_addr, {"cmd": "ping"}, timeout_ms=800) + if isinstance(probe, dict) and probe.get("ok", False): + self.logger.info("Remote instance proxy is ready on host=%s req_addr=%s", slot.get("host"), req_addr) + return + time.sleep(0.2) + + raise RuntimeError(f"remote instance proxy failed to start on host={slot.get('host')} 
def _run_ssh_script(self, slot: dict[str, Any], script: str, timeout_seconds: float = 30.0, check: bool = True) -> subprocess.CompletedProcess:
    """Execute ``script`` with ``bash -lc`` on the slot's host over ssh.

    Args:
        slot: Static slot description. Reads ``ssh_bin`` (default ``"ssh"``),
            ``ssh_target`` (falling back to ``host``) and ``ssh_options``
            (honoured only when it is a list; blank entries are dropped).
        script: Shell script text; may span several lines.
        timeout_seconds: Forwarded to ``subprocess.run`` as ``timeout``.
        check: Forwarded to ``subprocess.run``; when true a non-zero exit
            status raises ``subprocess.CalledProcessError``.

    Returns:
        The completed process with ``stdout``/``stderr`` captured as text.

    Raises:
        RuntimeError: If the slot yields an empty ssh target.
        subprocess.TimeoutExpired: If the command outlives ``timeout_seconds``.
    """
    ssh_bin = str(slot.get("ssh_bin", "ssh"))
    ssh_target = str(slot.get("ssh_target", slot.get("host", ""))).strip()
    if not ssh_target:
        raise RuntimeError("remote slot missing ssh target")

    ssh_cmd = [ssh_bin]
    ssh_options = slot.get("ssh_options")
    if isinstance(ssh_options, list):
        ssh_cmd.extend(str(opt) for opt in ssh_options if str(opt).strip())

    # ssh joins everything after the target with single spaces and lets the
    # remote user's shell re-split the result.  Without quoting, only the first
    # word of the script would reach ``bash -lc`` and the remaining lines would
    # run in whatever the remote default shell is, so quote the script to keep
    # it a single argument.
    ssh_cmd.extend([ssh_target, "bash", "-lc", shlex.quote(script)])
    return subprocess.run(
        ssh_cmd,
        check=check,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        timeout=timeout_seconds,
    )


def _remote_launch_layout(self, slot: dict[str, Any], instance_type: str) -> dict[str, Any]:
    """Derive the addresses, log paths and environment shared by both remote launch paths."""
    host = str(slot["host"])
    engine_rank = int(slot["engine_rank"])
    python_executable = str(slot.get("python_executable", sys.executable))
    workdir = str(slot.get("workdir", Path(__file__).resolve().parents[3]))
    log_dir = str(slot.get("log_dir", "/tmp/lightx2v_disagg"))
    push_port = int(slot["sidecar_push_port"])
    req_port = int(slot["sidecar_req_port"])

    extra_env = slot.get("env")
    normalized_env: dict[str, str] = {}
    if isinstance(extra_env, dict):
        normalized_env = {str(key): str(value) for key, value in extra_env.items()}

    return {
        "host": host,
        "engine_rank": engine_rank,
        "python_executable": python_executable,
        "workdir": workdir,
        "log_dir": log_dir,
        "push_addr": f"tcp://{host}:{push_port}",
        "req_addr": f"tcp://{host}:{req_port}",
        "service_log": f"{log_dir}/{instance_type}_{engine_rank}_service.log",
        "sidecar_log": f"{log_dir}/{instance_type}_{engine_rank}_sidecar.log",
        "env": normalized_env,
    }


def _launch_remote_instance(self, slot: dict[str, Any], instance_type: str, cmd: list[str], cuda_device: str) -> tuple[dict[str, Any], dict[str, Any]]:
    """Start a sidecar and a service process on a remote slot.

    Delegates to ``_launch_remote_instance_via_proxy`` when the slot has
    ``use_remote_proxy`` set; otherwise assembles a shell script that
    backgrounds both processes with ``nohup`` and runs it over ssh.

    Args:
        slot: Static slot description (host, rank, ports, ssh settings, env).
        instance_type: Service role; used in log file names and error text.
        cmd: Local service command line. ``cmd[0]`` is replaced by the slot's
            ``python_executable``.
        cuda_device: Value exported as ``CUDA_VISIBLE_DEVICES`` remotely.

    Returns:
        ``(process_meta, sidecar_meta)`` dictionaries holding the remote pids,
        log paths and, for the sidecar, its push/req addresses.

    Raises:
        RuntimeError: If either pid cannot be parsed from the script output.
        subprocess.CalledProcessError: If the remote script exits non-zero.
    """
    if self._is_truthy(slot.get("use_remote_proxy", False)):
        return self._launch_remote_instance_via_proxy(slot, instance_type, cmd, cuda_device)

    layout = self._remote_launch_layout(slot, instance_type)
    host = layout["host"]
    engine_rank = layout["engine_rank"]
    python_executable = layout["python_executable"]
    push_addr = layout["push_addr"]
    req_addr = layout["req_addr"]
    service_log = layout["service_log"]
    sidecar_log = layout["sidecar_log"]
    activate_cmd = str(slot.get("activate_cmd", "")).strip()

    sidecar_env_vars = {
        **layout["env"],
        "CUDA_VISIBLE_DEVICES": str(cuda_device),
        "PYTHONUNBUFFERED": "1",
    }
    service_env_vars = {
        **layout["env"],
        "CUDA_VISIBLE_DEVICES": str(cuda_device),
        "LIGHTX2V_SIDECAR_PUSH_ADDR": push_addr,
        "LIGHTX2V_SIDECAR_REQ_ADDR": req_addr,
        "PYTHONUNBUFFERED": "1",
    }

    def _with_env(base_cmd: str, env_map: dict[str, str]) -> str:
        # Prefix the command with ``env KEY=value ...``; values are shell-quoted.
        env_prefix = " ".join(f"{key}={shlex.quote(value)}" for key, value in env_map.items())
        if not env_prefix:
            return base_cmd
        return f"env {env_prefix} {base_cmd}"

    sidecar_cmd = _with_env(
        (
            f"{shlex.quote(python_executable)} "
            "-m lightx2v.disagg.services.data_mgr_sidecar "
            f"--push-addr {shlex.quote(push_addr)} --req-addr {shlex.quote(req_addr)}"
        ),
        sidecar_env_vars,
    )
    cmd_with_python = [python_executable, *cmd[1:]]
    service_cmd = _with_env(" ".join(shlex.quote(str(part)) for part in cmd_with_python), service_env_vars)

    script_lines = [
        "set -e",
        f"mkdir -p {shlex.quote(layout['log_dir'])}",
        f"cd {shlex.quote(layout['workdir'])}",
    ]
    if activate_cmd:
        # Inserted verbatim: this is operator-supplied shell (e.g. a venv activation).
        script_lines.append(activate_cmd)
    script_lines.extend(
        [
            f"nohup {sidecar_cmd} > {shlex.quote(sidecar_log)} 2>&1 &",
            "sidecar_pid=$!",
            "sleep 0.5",
            f"nohup {service_cmd} > {shlex.quote(service_log)} 2>&1 &",
            "service_pid=$!",
            "echo SIDECAR_PID=$sidecar_pid",
            "echo SERVICE_PID=$service_pid",
        ]
    )
    completed = self._run_ssh_script(slot, "\n".join(script_lines), timeout_seconds=45.0, check=True)

    # The script reports both pids as KEY=value lines; the last occurrence wins
    # and an unparsable value resets the entry.
    pids: dict[str, int | None] = {"SIDECAR_PID": None, "SERVICE_PID": None}
    for line in completed.stdout.splitlines():
        key, sep, value = line.partition("=")
        if not sep or key not in pids:
            continue
        try:
            pids[key] = int(value.strip())
        except ValueError:
            pids[key] = None

    sidecar_pid = pids["SIDECAR_PID"]
    service_pid = pids["SERVICE_PID"]
    if sidecar_pid is None or service_pid is None:
        raise RuntimeError(
            f"failed to parse remote pids for {instance_type} rank={engine_rank} host={host}: stdout={completed.stdout!r} stderr={completed.stderr!r}"
        )

    sidecar_meta = {
        "mode": "remote",
        "host": host,
        "req_addr": req_addr,
        "push_addr": push_addr,
        "pid": sidecar_pid,
        "log_path": sidecar_log,
    }
    process_meta = {
        "mode": "remote",
        "host": host,
        "pid": service_pid,
        "log_path": service_log,
    }
    return process_meta, sidecar_meta


def _launch_remote_instance_via_proxy(self, slot: dict[str, Any], instance_type: str, cmd: list[str], cuda_device: str) -> tuple[dict[str, Any], dict[str, Any]]:
    """Start a sidecar and a service on a remote slot through its instance proxy.

    Sends a ``start_instance`` request to the proxy address of the slot and
    validates the pids in the reply.

    Args:
        slot: Static slot description (host, rank, ports, env, proxy settings).
        instance_type: Service role; used in log file names and the request.
        cmd: Local service command line; everything after ``cmd[0]`` is sent
            as ``service_argv``.
        cuda_device: Device string forwarded to the proxy.

    Returns:
        ``(process_meta, sidecar_meta)``; both carry ``proxy_req_addr``.

    Raises:
        RuntimeError: If the proxy reply is not an ``ok`` dict or either pid is
            missing or non-positive.
    """
    self._ensure_remote_instance_proxy(slot)

    layout = self._remote_launch_layout(slot, instance_type)
    host = layout["host"]
    push_addr = layout["push_addr"]
    req_addr = layout["req_addr"]
    service_log = layout["service_log"]
    sidecar_log = layout["sidecar_log"]

    proxy_req_addr = self._remote_proxy_req_addr(slot)
    payload = {
        "cmd": "start_instance",
        "instance_type": str(instance_type),
        "engine_rank": int(layout["engine_rank"]),
        "cuda_device": str(cuda_device),
        "python_executable": layout["python_executable"],
        "service_argv": [str(part) for part in cmd[1:]],
        "sidecar_push_addr": push_addr,
        "sidecar_req_addr": req_addr,
        "service_log_path": service_log,
        "sidecar_log_path": sidecar_log,
        "workdir": layout["workdir"],
        "log_dir": layout["log_dir"],
        "env": layout["env"],
    }
    reply = self._query_zmq(proxy_req_addr, payload, timeout_ms=10000)
    if not isinstance(reply, dict) or not reply.get("ok", False):
        raise RuntimeError(f"remote proxy failed to start instance on host={host}: {reply}")

    data = reply.get("data") if isinstance(reply.get("data"), dict) else {}
    sidecar_pid = int(data.get("sidecar_pid", 0) or 0)
    service_pid = int(data.get("service_pid", 0) or 0)
    if sidecar_pid <= 0 or service_pid <= 0:
        raise RuntimeError(f"remote proxy returned invalid pids for host={host}: {reply}")

    sidecar_meta = {
        "mode": "remote",
        "host": host,
        "req_addr": req_addr,
        "push_addr": push_addr,
        "pid": sidecar_pid,
        "log_path": sidecar_log,
        "proxy_req_addr": proxy_req_addr,
    }
    process_meta = {
        "mode": "remote",
        "host": host,
        "pid": service_pid,
        "log_path": service_log,
        "proxy_req_addr": proxy_req_addr,
    }
    return process_meta, sidecar_meta


def _stop_remote_pid(self, slot: dict[str, Any], pid: int, graceful_timeout_seconds: float):
    """Best-effort stop of ``pid`` on the slot's host.

    Tries the instance proxy first when the slot uses one, then falls back to
    an ssh script that sends SIGTERM, waits up to the timeout and finally sends
    SIGKILL. Failures of the ssh step are logged, never raised.

    Args:
        slot: Static slot description (proxy and ssh settings).
        pid: Remote process id to stop.
        graceful_timeout_seconds: Grace period; truncated to whole seconds and
            clamped to at least one second.
    """
    timeout_seconds = max(1, int(graceful_timeout_seconds))

    if self._is_truthy(slot.get("use_remote_proxy", False)):
        req_addr = self._remote_proxy_req_addr(slot)
        payload = {
            "cmd": "stop_pid",
            "pid": int(pid),
            "timeout_seconds": timeout_seconds,
        }
        try:
            reply = self._query_zmq(req_addr, payload, timeout_ms=(timeout_seconds + 3) * 1000)
        except Exception as exc:
            # NOTE(review): whether _query_zmq raises or returns a non-ok reply on
            # failure is not visible here; either way the ssh fallback must run.
            reply = {"ok": False, "error": str(exc)}
        if isinstance(reply, dict) and reply.get("ok", False):
            return
        self.logger.warning(
            "Remote proxy stop_pid failed, falling back to ssh kill: host=%s pid=%s reply=%s",
            slot.get("host"),
            pid,
            reply,
        )

    script = "\n".join(
        [
            "set +e",
            f"pid={int(pid)}",
            "if kill -0 ${pid} >/dev/null 2>&1; then",
            "  kill -TERM ${pid} >/dev/null 2>&1 || true",
            f"  deadline=$((SECONDS+{timeout_seconds}))",
            "  while kill -0 ${pid} >/dev/null 2>&1; do",
            "    if (( SECONDS >= deadline )); then",
            "      kill -KILL ${pid} >/dev/null 2>&1 || true",
            "      break",
            "    fi",
            "    sleep 0.2",
            "  done",
            "fi",
        ]
    )
    try:
        # Give ssh some headroom beyond the remote grace period.
        self._run_ssh_script(slot, script, timeout_seconds=float(timeout_seconds + 10), check=False)
    except Exception as exc:
        self.logger.warning("Failed to stop remote pid=%s on host=%s: %s", pid, slot.get("host"), exc)
False)) + default_proxy_req_base_port = int(disagg_cfg.get("remote_proxy_req_base_port", 28000)) + + default_ssh_options_raw = disagg_cfg.get("ssh_options", []) + if isinstance(default_ssh_options_raw, str): + default_ssh_options = shlex.split(default_ssh_options_raw) + elif isinstance(default_ssh_options_raw, list): + default_ssh_options = [str(opt) for opt in default_ssh_options_raw if str(opt).strip()] + else: + default_ssh_options = [] + + default_slot_env = disagg_cfg.get("service_env", {}) + normalized_default_slot_env: dict[str, str] = {} + if isinstance(default_slot_env, dict): + for key, value in default_slot_env.items(): + normalized_default_slot_env[str(key)] = str(value) + + sidecar_base_port = int(disagg_cfg.get("sidecar_base_port", 26000)) + seen_slot_keys: set[tuple[str, int]] = set() + + for index, raw_slot in enumerate(static_slots_raw): + if not isinstance(raw_slot, dict): + raise ValueError(f"invalid static_instance_slots[{index}] (expect object)") + + instance_type = str(raw_slot.get("instance_type", "")).strip().lower() + if instance_type not in {"encoder", "transformer", "decoder"}: + raise ValueError(f"invalid static_instance_slots[{index}].instance_type={instance_type!r}") + + host = str(raw_slot.get("host", "")).strip() + if not host: + raise ValueError(f"static_instance_slots[{index}].host cannot be empty") + + engine_rank = int(raw_slot.get("engine_rank")) + cuda_device = str(raw_slot.get("cuda_device", engine_rank)) + slot_key = (host, engine_rank) + if slot_key in seen_slot_keys: + raise ValueError(f"duplicate static slot host/rank: {slot_key}") + seen_slot_keys.add(slot_key) + + ssh_user = str(raw_slot.get("ssh_user", default_ssh_user)).strip() + ssh_target = f"{ssh_user}@{host}" if ssh_user else host + ssh_bin = str(raw_slot.get("ssh_bin", default_ssh_bin)) + + ssh_options_raw = raw_slot.get("ssh_options", default_ssh_options) + if isinstance(ssh_options_raw, str): + ssh_options = shlex.split(ssh_options_raw) + elif 
isinstance(ssh_options_raw, list): + ssh_options = [str(opt) for opt in ssh_options_raw if str(opt).strip()] + else: + ssh_options = list(default_ssh_options) + + slot_env = dict(normalized_default_slot_env) + raw_slot_env = raw_slot.get("env", {}) + if isinstance(raw_slot_env, dict): + for key, value in raw_slot_env.items(): + slot_env[str(key)] = str(value) + + push_port = int(raw_slot.get("sidecar_push_port", sidecar_base_port + engine_rank * 2)) + req_port = int(raw_slot.get("sidecar_req_port", sidecar_base_port + engine_rank * 2 + 1)) + use_remote_proxy = self._is_truthy(raw_slot.get("use_remote_proxy"), default=default_use_remote_proxy) + proxy_req_port = int(raw_slot.get("proxy_req_port", default_proxy_req_base_port + engine_rank)) + proxy_log_path = str(raw_slot.get("proxy_log_path", f"{default_log_dir}/instance_proxy_{engine_rank}.log")) + + self._static_instance_slots.append( + { + "slot_id": index, + "instance_type": instance_type, + "host": host, + "engine_rank": engine_rank, + "cuda_device": cuda_device, + "workdir": str(raw_slot.get("workdir", default_workdir)), + "python_executable": str(raw_slot.get("python_executable", default_python)), + "log_dir": str(raw_slot.get("log_dir", default_log_dir)), + "activate_cmd": str(raw_slot.get("activate_cmd", default_activate_cmd)).strip(), + "ssh_target": ssh_target, + "ssh_bin": ssh_bin, + "ssh_options": ssh_options, + "sidecar_push_port": push_port, + "sidecar_req_port": req_port, + "use_remote_proxy": use_remote_proxy, + "proxy_req_port": proxy_req_port, + "proxy_log_path": proxy_log_path, + "env": slot_env, + } + ) + + self._free_slot_ids = {int(slot["slot_id"]) for slot in self._static_instance_slots} + self.logger.info("Static multi-node mode enabled with %s slots", len(self._static_instance_slots)) + self._free_gpus = set() + return + total_ranks = int(config.get("ranks", disagg_cfg.get("ranks", 8))) if total_ranks <= 0: raise ValueError("ranks must be positive") @@ -642,39 +1114,80 @@ def 
create_instance(self, instance_type: str) -> str: raise RuntimeError("controller runtime config is not initialized") with self._instance_lock: - if not self._free_gpus: - raise RuntimeError("no idle GPU available") + use_static_slots = bool(self._static_instance_slots) + selected_slot: dict[str, Any] | None = None - now = time.time() - gpu_id: int | None = None - for candidate_gpu in sorted(self._free_gpus): - if now < self._gpu_reuse_block_until.get(candidate_gpu, 0.0): - continue + if use_static_slots: + if not self._free_slot_ids: + raise RuntimeError("no idle static slot available") - monitor_port = MONITOR_POLLING_PORT + candidate_gpu - if self._is_tcp_port_open(self._bootstrap_addr, monitor_port): - self.logger.warning( - "Skip gpu=%s for %s creation because monitor port %s is still in use", - candidate_gpu, - instance_type, - monitor_port, - ) - continue + now = time.time() + for slot_id in sorted(self._free_slot_ids): + slot = self._static_instance_slots[slot_id] + if slot.get("instance_type") != instance_type: + continue + if now < self._slot_reuse_block_until.get(slot_id, 0.0): + continue - gpu_id = candidate_gpu - break + host = str(slot["host"]) + engine_rank = int(slot["engine_rank"]) + monitor_port = MONITOR_POLLING_PORT + engine_rank + if self._is_tcp_port_open(host, monitor_port): + self.logger.warning( + "Skip static slot=%s host=%s rank=%s for %s creation because monitor port %s is still in use", + slot_id, + host, + engine_rank, + instance_type, + monitor_port, + ) + continue + + selected_slot = slot + break + + if selected_slot is None: + raise RuntimeError(f"no idle static slot available for {instance_type}: all candidates cooling down or port is in use") + + engine_rank = int(selected_slot["engine_rank"]) + host = str(selected_slot["host"]) + cuda_device = str(selected_slot["cuda_device"]) + else: + if not self._free_gpus: + raise RuntimeError("no idle GPU available") + + now = time.time() + engine_rank: int | None = None + host = 
self._bootstrap_addr + for candidate_gpu in sorted(self._free_gpus): + if now < self._gpu_reuse_block_until.get(candidate_gpu, 0.0): + continue + + monitor_port = MONITOR_POLLING_PORT + candidate_gpu + if self._is_tcp_port_open(self._bootstrap_addr, monitor_port): + self.logger.warning( + "Skip gpu=%s for %s creation because monitor port %s is still in use", + candidate_gpu, + instance_type, + monitor_port, + ) + continue + + engine_rank = candidate_gpu + break - if gpu_id is None: - raise RuntimeError(f"no idle GPU available for {instance_type}: all candidates cooling down or port is in use") + if engine_rank is None: + raise RuntimeError(f"no idle GPU available for {instance_type}: all candidates cooling down or port is in use") + cuda_device = str(engine_rank) instance_cfg = self._to_plain(self._runtime_config) instance_cfg["disagg_mode"] = instance_type if instance_type == "encoder": - instance_cfg["encoder_engine_rank"] = gpu_id + instance_cfg["encoder_engine_rank"] = engine_rank elif instance_type == "transformer": - instance_cfg["transformer_engine_rank"] = gpu_id + instance_cfg["transformer_engine_rank"] = engine_rank else: - instance_cfg["decoder_engine_rank"] = gpu_id + instance_cfg["decoder_engine_rank"] = engine_rank model_path = instance_cfg.get("model_path") config_json = instance_cfg.get("config_json") @@ -682,78 +1195,89 @@ def create_instance(self, instance_type: str) -> str: raise RuntimeError("model_path and config_json are required to launch service subprocess") service_config_json = self._resolve_service_config_json(str(config_json), instance_type) - cmd = [ - sys.executable, - "-m", - "lightx2v.disagg.examples.run_service", - "--service", - instance_type, - "--engine_rank", - str(gpu_id), - "--model_cls", - str(instance_cfg.get("model_cls", "wan2.1")), - "--task", - str(instance_cfg.get("task", "t2v")), - "--model_path", - str(model_path), - "--config_json", - service_config_json, - "--seed", - str(instance_cfg.get("seed", 42)), - "--prompt", 
- str(instance_cfg.get("prompt", "")), - "--negative_prompt", - str(instance_cfg.get("negative_prompt", "")), - "--save_result_path", - str(instance_cfg.get("save_path", "")), - ] - sidecar_meta = self._start_sidecar_process(instance_type, gpu_id) - env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = str(sidecar_meta["push_addr"]) - env["LIGHTX2V_SIDECAR_REQ_ADDR"] = str(sidecar_meta["req_addr"]) - process = subprocess.Popen( - cmd, - env=env, - start_new_session=True, - ) + cmd = self._build_service_command(instance_type, engine_rank, instance_cfg, service_config_json) - monitor_port = MONITOR_POLLING_PORT + gpu_id - if not self._wait_for_tcp_port_state(self._bootstrap_addr, monitor_port, should_be_open=True, timeout_seconds=8.0): - if process.poll() is None: + process: subprocess.Popen | None = None + process_meta: dict[str, Any] | None = None + sidecar_meta: dict[str, Any] + launch_mode = "local" + + try: + if use_static_slots and selected_slot is not None and not self._is_local_host(host): + launch_mode = "remote" + process_meta, sidecar_meta = self._launch_remote_instance(selected_slot, instance_type, cmd, cuda_device) + else: + sidecar_meta = self._start_sidecar_process(instance_type, cuda_device, bind_host=host) + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(cuda_device) + env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = str(sidecar_meta["push_addr"]) + env["LIGHTX2V_SIDECAR_REQ_ADDR"] = str(sidecar_meta["req_addr"]) + process = subprocess.Popen( + cmd, + env=env, + start_new_session=True, + ) + + monitor_port = MONITOR_POLLING_PORT + engine_rank + if not self._wait_for_tcp_port_state(host, monitor_port, should_be_open=True, timeout_seconds=self._instance_start_timeout_seconds): + raise RuntimeError(f"service {instance_type} rank={engine_rank} host={host} failed to expose monitor port {monitor_port}") + except Exception: + if process is not None and process.poll() is None: process.terminate() try: 
process.wait(timeout=3.0) except subprocess.TimeoutExpired: process.kill() - sidecar_process = sidecar_meta.get("process") - if sidecar_process is not None and sidecar_process.poll() is None: - sidecar_process.terminate() - try: - sidecar_process.wait(timeout=2.0) - except subprocess.TimeoutExpired: - sidecar_process.kill() - raise RuntimeError(f"service {instance_type} on gpu={gpu_id} failed to expose monitor port {monitor_port}") - instance_address = f"{self._bootstrap_addr}:{REQUEST_POLLING_PORT + gpu_id}" - self._free_gpus.remove(gpu_id) + if launch_mode == "remote" and selected_slot is not None and process_meta is not None: + remote_pid = process_meta.get("pid") + if isinstance(remote_pid, int) and remote_pid > 0: + self._stop_remote_pid(selected_slot, remote_pid, self._graceful_reclaim_timeout_seconds) + + if "sidecar_meta" in locals(): + if launch_mode == "remote" and selected_slot is not None: + sidecar_pid = sidecar_meta.get("pid") + if isinstance(sidecar_pid, int) and sidecar_pid > 0: + self._stop_remote_pid(selected_slot, sidecar_pid, self._force_kill_wait_seconds) + else: + sidecar_process = sidecar_meta.get("process") + if sidecar_process is not None and sidecar_process.poll() is None: + sidecar_process.terminate() + try: + sidecar_process.wait(timeout=2.0) + except subprocess.TimeoutExpired: + sidecar_process.kill() + raise + + instance_address = f"{host}:{REQUEST_POLLING_PORT + engine_rank}" + if use_static_slots and selected_slot is not None: + self._free_slot_ids.remove(int(selected_slot["slot_id"])) + else: + self._free_gpus.remove(engine_rank) # self.add_instance(instance_type, instance_address) if self._enable_monitor: - monitor_node = f"tcp://{self._bootstrap_addr}:{MONITOR_POLLING_PORT + gpu_id}" + monitor_node = f"tcp://{host}:{MONITOR_POLLING_PORT + engine_rank}" if monitor_node not in self.monitor.nodes: self.monitor.nodes.append(monitor_node) self._managed_instances[instance_address] = { "instance_type": instance_type, - "gpu_id": gpu_id, 
+ "gpu_id": engine_rank, + "host": host, + "launch_mode": launch_mode, + "cuda_device": cuda_device, "process": process, + "process_meta": process_meta, "sidecar": sidecar_meta, + "slot_id": int(selected_slot["slot_id"]) if selected_slot is not None else None, + "static_slot": self._to_plain(selected_slot) if selected_slot is not None else None, } self.started_instances.append((instance_type, instance_address)) self.logger.info( - "Created %s instance on gpu=%s pid=%s address=%s", + "Created %s instance host=%s rank=%s mode=%s address=%s", instance_type, - gpu_id, - process.pid, + host, + engine_rank, + launch_mode, instance_address, ) return instance_address @@ -785,48 +1309,72 @@ def reclaim_instance(self, instance_type: str, instance_address: str | None = No raise RuntimeError(f"instance type mismatch for {target_address}: expected={instance_type} got={meta.get('instance_type')}") process = meta.get("process") + process_meta = meta.get("process_meta") if isinstance(meta.get("process_meta"), dict) else None gpu_id = int(meta.get("gpu_id")) sidecar_meta = meta.get("sidecar") if isinstance(meta.get("sidecar"), dict) else None + host = str(meta.get("host", self._bootstrap_addr)) + launch_mode = str(meta.get("launch_mode", "local")) + static_slot = meta.get("static_slot") if isinstance(meta.get("static_slot"), dict) else None + slot_id_raw = meta.get("slot_id") + slot_id = int(slot_id_raw) if slot_id_raw is not None else None # self.remove_instance(instance_type, target_address) monitor_node = self._monitor_node_from_instance_address(target_address) - if process is not None and process.poll() is None: - try: - os.killpg(process.pid, signal.SIGTERM) - except Exception: - process.terminate() - try: - process.wait(timeout=self._graceful_reclaim_timeout_seconds) - except subprocess.TimeoutExpired: + if launch_mode == "remote": + if static_slot is None: + raise RuntimeError(f"remote instance metadata missing static slot for {target_address}") + + remote_service_pid = None 
+ if process_meta is not None and isinstance(process_meta.get("pid"), int): + remote_service_pid = int(process_meta["pid"]) + if remote_service_pid is not None and remote_service_pid > 0: + self._stop_remote_pid(static_slot, remote_service_pid, self._graceful_reclaim_timeout_seconds) + + if sidecar_meta is not None and isinstance(sidecar_meta.get("pid"), int): + self._stop_remote_pid(static_slot, int(sidecar_meta["pid"]), self._force_kill_wait_seconds) + else: + if process is not None and process.poll() is None: try: - os.killpg(process.pid, signal.SIGKILL) + os.killpg(process.pid, signal.SIGTERM) except Exception: - process.kill() + process.terminate() try: - process.wait(timeout=self._force_kill_wait_seconds) - except subprocess.TimeoutExpired as exc: - raise RuntimeError(f"process did not exit after kill for {instance_type} instance {target_address}") from exc + process.wait(timeout=self._graceful_reclaim_timeout_seconds) + except subprocess.TimeoutExpired: + try: + os.killpg(process.pid, signal.SIGKILL) + except Exception: + process.kill() + try: + process.wait(timeout=self._force_kill_wait_seconds) + except subprocess.TimeoutExpired as exc: + raise RuntimeError(f"process did not exit after kill for {instance_type} instance {target_address}") from exc if self._enable_monitor and monitor_node in self.monitor.nodes: self.monitor.nodes.remove(monitor_node) monitor_port = MONITOR_POLLING_PORT + gpu_id - if not self._wait_for_tcp_port_state(self._bootstrap_addr, monitor_port, should_be_open=False, timeout_seconds=5.0): + if not self._wait_for_tcp_port_state(host, monitor_port, should_be_open=False, timeout_seconds=5.0): self.logger.warning( - "Monitor port still open after reclaim: service=%s gpu=%s port=%s", + "Monitor port still open after reclaim: service=%s host=%s rank=%s port=%s", instance_type, + host, gpu_id, monitor_port, ) - self._free_gpus.add(gpu_id) - self._gpu_reuse_block_until[gpu_id] = time.time() + self._gpu_reuse_grace_seconds + if slot_id is not 
None and slot_id in range(len(self._static_instance_slots)): + self._free_slot_ids.add(slot_id) + self._slot_reuse_block_until[slot_id] = time.time() + self._gpu_reuse_grace_seconds + else: + self._free_gpus.add(gpu_id) + self._gpu_reuse_block_until[gpu_id] = time.time() + self._gpu_reuse_grace_seconds self._managed_instances.pop(target_address, None) if (instance_type, target_address) in self.started_instances: self.started_instances.remove((instance_type, target_address)) - if sidecar_meta is not None: + if sidecar_meta is not None and launch_mode != "remote": reclaim_thread = Thread( target=self._reclaim_sidecar_when_drained, args=(instance_type, target_address, sidecar_meta), @@ -837,8 +1385,9 @@ def reclaim_instance(self, instance_type: str, instance_address: str | None = No self._sidecar_reclaim_threads.append(reclaim_thread) self.logger.info( - "Reclaimed %s instance from gpu=%s address=%s", + "Reclaimed %s instance from host=%s rank=%s address=%s", instance_type, + host, gpu_id, target_address, ) @@ -1078,10 +1627,26 @@ def run(self, config): time.sleep(5.0) - for instance_type in ("encoder", "transformer", "decoder"): - address = self.create_instance(instance_type) - for _ in range(5): - self.create_instance("transformer") + if self._static_instance_slots: + self.logger.info( + "Starting managed instances from static_instance_slots: %s", + [slot["instance_type"] for slot in self._static_instance_slots], + ) + for slot in self._static_instance_slots: + self.create_instance(str(slot["instance_type"])) + else: + for instance_type in ("encoder", "transformer", "decoder"): + self.create_instance(instance_type) + for _ in range(5): + self.create_instance("transformer") + + instance_warmup_wait_s = int(os.getenv("DISAGG_INSTANCE_WARMUP_WAIT_S", "30")) + if instance_warmup_wait_s > 0: + self.logger.info( + "Managed instances created, waiting %ss before accepting requests", + instance_warmup_wait_s, + ) + time.sleep(instance_warmup_wait_s) monitor_stop_event: Event 
| None = None monitor_thread: Thread | None = None diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index f0f349605..1052b9e50 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -31,9 +31,11 @@ def __init__(self, config: dict): self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) self._phase2_rdma_client: Optional[RDMAClient] = None self._phase2_rdma_buffer: Optional[RDMABuffer] = None + data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) + monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) - self._phase2_server_ip = str(self.config.get("rdma_phase2_host", "127.0.0.1")) + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", data_bootstrap_addr)) self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", "5568")) self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size @@ -49,7 +51,7 @@ def __init__(self, config: dict): self.reporter = Reporter( service_type="decoder", gpu_id=self.decoder_engine_rank, - bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.decoder_engine_rank}", + bind_address=f"tcp://{monitor_bind_host}:{MONITOR_POLLING_PORT + self.decoder_engine_rank}", ) self._queue_metrics_lock = threading.Lock() self._queue_metrics: dict[str, Any] = { diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 2e686f4c1..117b3484c 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -39,13 +39,15 @@ def __init__(self, config: dict): self._request_rdma_buffer: Optional[RDMABuffer] = None self._phase1_rdma_client: Optional[RDMAClient] = None self._phase1_rdma_buffer: Optional[RDMABuffer] = None + 
class InstanceProxyServer:
    """Remote process proxy that creates/stops local disagg service processes.

    This server is intended to run on remote nodes where the local runtime
    environment is trusted. The controller sends simple commands to this proxy
    instead of assembling remote launch scripts for every instance operation.
    """

    def __init__(self, bind_addr: str, workdir: str, log_dir: str):
        """Store the bind address and the default working/log directories."""
        self.bind_addr = str(bind_addr)
        self.workdir = str(workdir)
        self.log_dir = str(log_dir)
        self._running = True
        # pid -> Popen for every child started by this proxy and not yet stopped.
        self._managed: dict[int, subprocess.Popen] = {}

    def _normalize_env(self, extra_env: Any, cuda_device: str) -> dict[str, str]:
        """Return this process's environment overlaid with ``extra_env``.

        ``CUDA_VISIBLE_DEVICES`` and ``PYTHONUNBUFFERED`` are always set last,
        so they override any value supplied through ``extra_env``.
        """
        env = os.environ.copy()
        if isinstance(extra_env, Mapping):
            env.update({str(key): str(value) for key, value in extra_env.items()})
        env["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
        env["PYTHONUNBUFFERED"] = "1"
        return env

    @staticmethod
    def _signal_group(process: subprocess.Popen, sig: int) -> None:
        """Signal the child's process group, falling back to the child alone."""
        try:
            os.killpg(process.pid, sig)
        except (OSError, AttributeError):
            if sig == signal.SIGTERM:
                process.terminate()
            else:
                process.kill()

    def _terminate_managed(self, pid: int, process: subprocess.Popen, timeout_seconds: float) -> bool:
        """Stop a child started by this proxy: SIGTERM, wait, then SIGKILL."""
        if process.poll() is None:
            self._signal_group(process, signal.SIGTERM)
            try:
                process.wait(timeout=timeout_seconds)
            except subprocess.TimeoutExpired:
                self._signal_group(process, signal.SIGKILL)
                try:
                    process.wait(timeout=2.0)
                except subprocess.TimeoutExpired:
                    pass  # still alive; reported through the return value
        # The entry is dropped even when the child survived SIGKILL.
        self._managed.pop(pid, None)
        return process.poll() is not None

    @staticmethod
    def _terminate_foreign(pid: int, timeout_seconds: float) -> bool:
        """Stop a pid that is not tracked in ``_managed`` using plain signals."""
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            return True
        except (OSError, OverflowError):
            return False

        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                os.kill(pid, 0)  # signal 0 only probes for existence
            except ProcessLookupError:
                return True
            except OSError:
                break
            time.sleep(0.1)

        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            return True
        except OSError:
            return False
        return True

    def _terminate_pid(self, pid: int, timeout_seconds: float) -> bool:
        """Stop ``pid`` gracefully, escalating to SIGKILL after the timeout.

        Args:
            pid: Process id to stop.
            timeout_seconds: Grace period before SIGKILL; clamped to >= 1 s.

        Returns:
            True when the process is gone (or SIGKILL was delivered to an
            unmanaged pid), False when it could not be signalled or a managed
            child is still alive after SIGKILL.
        """
        timeout_seconds = max(1.0, float(timeout_seconds))
        process = self._managed.get(pid)
        if process is not None:
            return self._terminate_managed(pid, process, timeout_seconds)
        # Fallback for pids created before current proxy process lifetime.
        return self._terminate_foreign(pid, timeout_seconds)

    def _start_instance(self, msg: dict[str, Any]) -> dict[str, Any]:
        """Launch a data-manager sidecar and a service process from ``msg``.

        Args:
            msg: ``start_instance`` request. Requires ``instance_type``, a
                non-negative ``engine_rank``, a non-empty ``service_argv`` list
                and both sidecar addresses; log paths, ``workdir``, ``log_dir``
                and ``env`` are optional.

        Returns:
            Dict with the instance type, rank, both pids and both log paths.

        Raises:
            ValueError: If a required field is missing or malformed.
            RuntimeError: If the sidecar or the service exits right after start.
            OSError: If a log file cannot be opened or a process cannot spawn.
        """
        instance_type = str(msg.get("instance_type", ""))
        engine_rank = int(msg.get("engine_rank", -1))
        cuda_device = str(msg.get("cuda_device", "0"))
        python_executable = str(msg.get("python_executable", "python"))
        service_argv = msg.get("service_argv", [])
        sidecar_push_addr = str(msg.get("sidecar_push_addr", "")).strip()
        sidecar_req_addr = str(msg.get("sidecar_req_addr", "")).strip()
        service_log_path = str(msg.get("service_log_path", "")).strip()
        sidecar_log_path = str(msg.get("sidecar_log_path", "")).strip()
        workdir = str(msg.get("workdir", self.workdir))
        log_dir = str(msg.get("log_dir", self.log_dir))
        extra_env = msg.get("env", {})

        if not instance_type:
            raise ValueError("instance_type is required")
        if engine_rank < 0:
            raise ValueError("engine_rank must be non-negative")
        if not isinstance(service_argv, list) or not service_argv:
            raise ValueError("service_argv must be a non-empty list")
        if not sidecar_push_addr or not sidecar_req_addr:
            raise ValueError("sidecar_push_addr and sidecar_req_addr are required")

        if not service_log_path:
            service_log_path = f"{log_dir}/{instance_type}_{engine_rank}_service.log"
        if not sidecar_log_path:
            sidecar_log_path = f"{log_dir}/{instance_type}_{engine_rank}_sidecar.log"

        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(Path(service_log_path).parent, exist_ok=True)
        os.makedirs(Path(sidecar_log_path).parent, exist_ok=True)

        sidecar_env = self._normalize_env(extra_env, cuda_device)
        service_env = self._normalize_env(extra_env, cuda_device)
        service_env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = sidecar_push_addr
        service_env["LIGHTX2V_SIDECAR_REQ_ADDR"] = sidecar_req_addr

        sidecar_cmd = [
            python_executable,
            "-m",
            "lightx2v.disagg.services.data_mgr_sidecar",
            "--push-addr",
            sidecar_push_addr,
            "--req-addr",
            sidecar_req_addr,
        ]
        service_cmd = [python_executable, *[str(part) for part in service_argv]]

        with open(sidecar_log_path, "a", encoding="utf-8") as sidecar_log:
            sidecar_proc = subprocess.Popen(
                sidecar_cmd,
                cwd=workdir,
                env=sidecar_env,
                stdout=sidecar_log,
                stderr=subprocess.STDOUT,
                start_new_session=True,
            )
        # Track the sidecar right away so that every failure below stops it via
        # the managed path (process-group signal plus wait) instead of leaving
        # an orphaned or unreaped child behind.
        self._managed[sidecar_proc.pid] = sidecar_proc

        try:
            time.sleep(0.3)  # brief settle time before checking the sidecar
            if sidecar_proc.poll() is not None:
                raise RuntimeError(f"failed to start sidecar process, exited with code={sidecar_proc.returncode}")

            with open(service_log_path, "a", encoding="utf-8") as service_log:
                service_proc = subprocess.Popen(
                    service_cmd,
                    cwd=workdir,
                    env=service_env,
                    stdout=service_log,
                    stderr=subprocess.STDOUT,
                    start_new_session=True,
                )

            if service_proc.poll() is not None:
                raise RuntimeError(f"failed to start service process, exited with code={service_proc.returncode}")
        except Exception:
            self._terminate_pid(sidecar_proc.pid, timeout_seconds=2.0)
            raise

        self._managed[service_proc.pid] = service_proc

        return {
            "instance_type": instance_type,
            "engine_rank": engine_rank,
            "sidecar_pid": sidecar_proc.pid,
            "service_pid": service_proc.pid,
            "sidecar_log_path": sidecar_log_path,
            "service_log_path": service_log_path,
        }

    def handle(self, msg: dict[str, Any]) -> dict[str, Any]:
        """Dispatch one request dict and return the reply dict.

        Supported ``cmd`` values: ``ping``, ``start_instance``, ``stop_pid``,
        ``shutdown`` and ``stats``. Replies always carry ``ok`` plus either
        ``data`` or ``error``. Exceptions from ``start_instance`` propagate to
        the caller.
        """
        cmd = str(msg.get("cmd", "")).strip()

        if cmd == "ping":
            return {"ok": True, "data": {"alive": True, "managed": len(self._managed)}}

        if cmd == "start_instance":
            return {"ok": True, "data": self._start_instance(msg)}

        if cmd == "stop_pid":
            pid = int(msg.get("pid", -1))
            timeout_seconds = float(msg.get("timeout_seconds", 10.0))
            if pid <= 0:
                return {"ok": False, "error": "invalid pid"}
            stopped = bool(self._terminate_pid(pid, timeout_seconds=timeout_seconds))
            return {"ok": stopped, "data": {"pid": pid, "stopped": stopped}}

        if cmd == "shutdown":
            # serve() notices the flag after replying to this request.
            self._running = False
            return {"ok": True, "data": {"shutting_down": True}}

        if cmd == "stats":
            managed_alive = sum(1 for process in self._managed.values() if process.poll() is None)
            return {"ok": True, "data": {"managed_alive": managed_alive}}

        return {"ok": False, "error": f"unsupported command: {cmd}"}

    def serve(self):
        """Bind a REP socket and answer requests until ``shutdown`` is received.

        On exit the socket and context are closed and every still-managed
        child is terminated.
        """
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(self.bind_addr)
        try:
            while self._running:
                try:
                    # SECURITY: recv_pyobj unpickles whatever arrives on this
                    # socket, so any peer that can reach bind_addr can execute
                    # code here.  Bind only to trusted interfaces.
                    msg = socket.recv_pyobj()
                    if isinstance(msg, dict):
                        reply = self.handle(msg)
                    else:
                        reply = {"ok": False, "error": "request must be a dict"}
                except Exception as exc:
                    # A REP socket must answer every request, so report the failure.
                    reply = {"ok": False, "error": str(exc)}
                socket.send_pyobj(reply)
        finally:
            socket.close(0)
            context.term()
            for pid in list(self._managed):
                self._terminate_pid(pid, timeout_seconds=2.0)


def main():
    """Parse command-line arguments and run the proxy until it is shut down."""
    parser = argparse.ArgumentParser(description="Remote instance proxy for disagg services")
    parser.add_argument("--bind-addr", type=str, required=True)
    parser.add_argument("--workdir", type=str, default=str(Path(__file__).resolve().parents[3]))
    parser.add_argument("--log-dir", type=str, default="/tmp/lightx2v_disagg")
    args = parser.parse_args()

    server = InstanceProxyServer(bind_addr=args.bind_addr, workdir=args.workdir, log_dir=args.log_dir)
    server.serve()


if __name__ == "__main__":
    main()
self._phase1_rdma_buffer: Optional[RDMABuffer] = None self._phase2_rdma_client: Optional[RDMAClient] = None self._phase2_rdma_buffer: Optional[RDMABuffer] = None + data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) + monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) - self._phase1_server_ip = str(self.config.get("rdma_phase1_host", "127.0.0.1")) + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", data_bootstrap_addr)) self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", "5567")) self._phase1_slots = shared_slots self._phase1_slot_size = shared_slot_size - self._phase2_server_ip = str(self.config.get("rdma_phase2_host", "127.0.0.1")) + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", data_bootstrap_addr)) self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", "5568")) self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size @@ -90,7 +92,7 @@ def __init__(self, config: dict): self.reporter = Reporter( service_type="transformer", gpu_id=self.transformer_engine_rank, - bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.transformer_engine_rank}", + bind_address=f"tcp://{monitor_bind_host}:{MONITOR_POLLING_PORT + self.transformer_engine_rank}", ) self._queue_metrics_lock = threading.Lock() self._queue_metrics: dict[str, Any] = { diff --git a/lightx2v/disagg/workload.py b/lightx2v/disagg/workload.py index 7c2262971..c1efec1ff 100644 --- a/lightx2v/disagg/workload.py +++ b/lightx2v/disagg/workload.py @@ -39,7 +39,7 @@ def task(fn): # type: ignore[no-redef] from lightx2v.disagg.conn import REQUEST_POLLING_PORT, ReqManager REPO_ROOT = Path(__file__).resolve().parents[2] -DEFAULT_BASE_CONFIG_JSON = REPO_ROOT / "configs" / "disagg" / 
"wan22_i2v_distill_controller.json" +DEFAULT_BASE_CONFIG_JSON = REPO_ROOT / "configs" / "disagg" / "single_node" / "wan22_i2v_distill_controller.json" DEFAULT_STAGE_DEFINITIONS_JSON = REPO_ROOT / "configs" / "disagg" / "wan22_i2v_workload_stages.json" _TEST_START_MONOTONIC: Optional[float] = None diff --git a/scripts/disagg/extract_dynamic_latency.py b/scripts/disagg/extract_dynamic_latency.py index b9fded883..589b83f74 100644 --- a/scripts/disagg/extract_dynamic_latency.py +++ b/scripts/disagg/extract_dynamic_latency.py @@ -8,9 +8,16 @@ from datetime import datetime from pathlib import Path -WAIT_RE = re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Waiting workload configs on port=") -LAT_RE = re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Latency summary room=(\d+) metrics=(\{.*\})") +WAIT_PATTERNS = [ + re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*waiting workload configs on port=", re.IGNORECASE), + re.compile(r"^\[(?:INFO|WARNING|ERROR|DEBUG|CRITICAL)\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*waiting workload configs on port=", re.IGNORECASE), +] +LAT_PATTERNS = [ + re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Latency summary room=(\d+) metrics=(\{.*\})"), + re.compile(r"^\[(?:INFO|WARNING|ERROR|DEBUG|CRITICAL)\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Latency summary room=(\d+) metrics=(\{.*\})"), +] TS_FMT = "%d %b %Y %H:%M:%S" +LOGURU_TS_FMT = "%Y-%m-%d %H:%M:%S" def _fmt_float3(value): @@ -20,6 +27,23 @@ def _fmt_float3(value): return value +def _match_any(patterns, line): + for pattern in patterns: + match = pattern.match(line) + if match: + return match + return None + + +def _parse_timestamp(raw_ts: str): + for fmt in (TS_FMT, LOGURU_TS_FMT): + try: + return datetime.strptime(raw_ts, fmt) + except ValueError: + pass + raise ValueError(f"unsupported timestamp format: {raw_ts}") + + def parse_args() -> argparse.Namespace: parser = 
argparse.ArgumentParser(description="Extract latency summary rows relative to waiting workload log time") parser.add_argument( @@ -50,16 +74,16 @@ def main() -> int: with log_path.open("r", encoding="utf-8", errors="ignore") as f: for line in f: if wait_ts is None: - m_wait = WAIT_RE.match(line) + m_wait = _match_any(WAIT_PATTERNS, line) if m_wait: - wait_ts = datetime.strptime(m_wait.group(1), TS_FMT) + wait_ts = _parse_timestamp(m_wait.group(1)) continue - m_lat = LAT_RE.match(line) + m_lat = _match_any(LAT_PATTERNS, line) if not m_lat: continue - ts = datetime.strptime(m_lat.group(1), TS_FMT) + ts = _parse_timestamp(m_lat.group(1)) room = int(m_lat.group(2)) metrics = ast.literal_eval(m_lat.group(3)) if not isinstance(metrics, dict): diff --git a/scripts/disagg/run_dynamic.sh b/scripts/disagg/run_dynamic.sh index 087f82682..2586865b0 100644 --- a/scripts/disagg/run_dynamic.sh +++ b/scripts/disagg/run_dynamic.sh @@ -10,6 +10,21 @@ export PYTHONPATH=${PYTHONPATH:-} source ${lightx2v_path}/scripts/base/base.sh +disagg_conda_env=${DISAGG_CONDA_ENV:-lightx2v} +if [[ "${DISAGG_SKIP_CONDA_ACTIVATE:-0}" != "1" ]]; then + if [[ "${CONDA_DEFAULT_ENV:-}" != "${disagg_conda_env}" ]]; then + if ! command -v conda >/dev/null 2>&1; then + echo "ERROR: conda is not available, cannot activate env ${disagg_conda_env}" >&2 + exit 2 + fi + set +u + eval "$(conda shell.bash hook)" + conda activate "${disagg_conda_env}" + set -u + echo "activated conda env: ${disagg_conda_env}" + fi +fi + # Ensure stale disagg services/ports from previous runs do not block bootstrap. 
bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true @@ -31,11 +46,38 @@ if [[ -z "${MOONCAKE_LOCAL_HOSTNAME:-}" ]]; then fi fi -export DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST:-127.0.0.1} +topology=${DISAGG_TOPOLOGY:-multi_node} +default_controller_cfg=${lightx2v_path}/configs/disagg/multi_node/wan22_i2v_distill_controller.json +if [[ "${topology}" == "single_node" ]]; then + default_controller_cfg=${lightx2v_path}/configs/disagg/single_node/wan22_i2v_distill_controller.json +fi +controller_cfg=${DISAGG_CONTROLLER_CFG:-${default_controller_cfg}} +if [[ ! -f "${controller_cfg}" ]]; then + echo "ERROR: controller config not found: ${controller_cfg}" >&2 + exit 2 +fi + +derived_controller_host="" +if command -v jq >/dev/null 2>&1; then + derived_controller_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") +fi +export DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST:-${derived_controller_host:-127.0.0.1}} export DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT:-12786} export LOAD_FROM_USER=${LOAD_FROM_USER:-0} +export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-90} +# Dynamic debug defaults to a smaller request batch; override for stress runs. +export DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT:-1} +user_start_delay_s=${USER_START_DELAY_S:-0} +if [[ -n "${USER_MAX_REQUESTS:-}" ]]; then + user_max_requests=${USER_MAX_REQUESTS} +elif [[ "${LOAD_FROM_USER}" != "0" ]]; then + # When the workload is driven from the user process, keep sending until the stage ends + # unless the caller explicitly sets a hard cap. 
+ user_max_requests=0 +else + user_max_requests=${DISAGG_AUTO_REQUEST_COUNT} +fi -controller_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_controller.json seed=${SEED:-42} prompt=${PROMPT:-"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."} negative_prompt=${NEGATIVE_PROMPT:-"镜头晃动,色调艳丽,过曝,静态"} @@ -48,9 +90,210 @@ controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-3000} controller_poll_interval_s=${CONTROLLER_POLL_INTERVAL_S:-5} fatal_watch_interval_s=${FATAL_WATCH_INTERVAL_S:-2} fatal_flag_file=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_fatal.flag +remote_log_collect=${REMOTE_LOG_COLLECT:-1} +remote_log_collect_dir=${REMOTE_LOG_COLLECT_DIR:-${lightx2v_path}/save_results/remote_logs} +remote_logs_collected=0 +remote_pre_clean=${DISAGG_REMOTE_PRE_CLEAN:-1} +is_single_node=0 +if [[ "${topology}" == "single_node" ]]; then + is_single_node=1 +fi + +echo "disagg topology=${topology}" +echo "controller_cfg=${controller_cfg}" +echo "DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST} DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT}" +echo "DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT}" +echo "LOAD_FROM_USER=${LOAD_FROM_USER} USER_START_DELAY_S=${user_start_delay_s} USER_MAX_REQUESTS=${user_max_requests}" rm -f "${fatal_flag_file}" +pre_clean_remote_hosts_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote pre-clean: single_node topology" + return 0 + fi + if [[ "${remote_pre_clean}" == "0" || "${remote_pre_clean}" == "false" ]]; then + return 0 + fi + if ! 
command -v jq >/dev/null 2>&1; then + echo "skip remote pre-clean: jq not found" + return 0 + fi + + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + local remote_workdir + remote_workdir=$(jq -r '.disagg_config.remote_workdir // empty' "${controller_cfg}") + if [[ -z "${remote_workdir}" ]]; then + remote_workdir="${lightx2v_path}" + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts discovered for pre-clean" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? // empty' "${controller_cfg}") + + local remote_workdir_q + remote_workdir_q=$(printf '%q' "${remote_workdir}") + local remote_cmd + remote_cmd="set -e; cd ${remote_workdir_q}; bash scripts/disagg/kill_service.sh || true" + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + + echo "remote pre-clean on ${host}" + if ssh "${ssh_opts[@]}" "${target}" "bash -lc '${remote_cmd}'"; then + echo "remote pre-clean succeeded on ${host}" + else + echo "warning: remote pre-clean failed on ${host}" + fi + done +} + +sync_remote_configs_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote config sync: single_node topology" + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + echo "skip remote config sync: jq not found" + return 0 + fi + if ! 
command -v scp >/dev/null 2>&1; then + echo "skip remote config sync: scp not found" + return 0 + fi + + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + local remote_workdir + remote_workdir=$(jq -r '.disagg_config.remote_workdir // empty' "${controller_cfg}") + if [[ -z "${remote_workdir}" ]]; then + remote_workdir="${lightx2v_path}" + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts discovered for config sync" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? // empty' "${controller_cfg}") + + local config_files=("${controller_cfg}") + for role in encoder transformer decoder; do + local cfg_candidate="${controller_cfg/_controller.json/_${role}.json}" + if [[ -f "${cfg_candidate}" ]]; then + config_files+=("${cfg_candidate}") + fi + done + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + + for src_cfg in "${config_files[@]}"; do + local rel_cfg="${src_cfg#${lightx2v_path}/}" + local dst_cfg="${src_cfg}" + if [[ "${src_cfg}" == "${lightx2v_path}/"* ]]; then + dst_cfg="${remote_workdir}/${rel_cfg}" + fi + + local dst_dir + dst_dir=$(dirname "${dst_cfg}") + ssh "${ssh_opts[@]}" "${target}" "mkdir -p '${dst_dir}'" || true + if scp "${ssh_opts[@]}" "${src_cfg}" "${target}:${dst_cfg}" >/dev/null 2>&1; then + echo "synced config to ${host}:${dst_cfg}" + else + echo "warning: failed to sync config to ${host}:${dst_cfg}" + fi + done + done +} + +collect_remote_logs_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote log collection: single_node topology" + return 0 + 
fi + if [[ "${remote_log_collect}" == "0" || "${remote_log_collect}" == "false" ]]; then + return 0 + fi + if [[ "${remote_logs_collected}" == "1" ]]; then + return 0 + fi + remote_logs_collected=1 + + if [[ ! -f "${controller_cfg}" ]]; then + echo "skip remote log collection: controller config not found: ${controller_cfg}" + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + echo "skip remote log collection: jq not found" + return 0 + fi + + local remote_log_dir + remote_log_dir=$(jq -r '.disagg_config.remote_log_dir // empty' "${controller_cfg}") + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + if [[ -z "${remote_log_dir}" ]]; then + echo "skip remote log collection: disagg_config.remote_log_dir is empty" + return 0 + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts discovered from static_instance_slots, skip remote log collection" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? 
// empty' "${controller_cfg}") + + local ts + ts=$(date +%Y%m%d_%H%M%S) + mkdir -p "${remote_log_collect_dir}" + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + + local dest_dir="${remote_log_collect_dir}/${host}_${ts}" + local archive_path="${dest_dir}/remote_logs.tgz" + mkdir -p "${dest_dir}" + + local remote_log_dir_q + remote_log_dir_q=$(printf '%q' "${remote_log_dir}") + local remote_cmd + remote_cmd="set -e; shopt -s nullglob; cd ${remote_log_dir_q}; files=(*_service.log *_sidecar.log); if (( \${#files[@]} == 0 )); then exit 3; fi; tar -czf - -- \"\${files[@]}\"" + + if ssh "${ssh_opts[@]}" "${target}" "bash -lc '${remote_cmd}'" > "${archive_path}" 2>/dev/null; then + tar -xzf "${archive_path}" -C "${dest_dir}" >/dev/null 2>&1 || true + rm -f "${archive_path}" + echo "remote logs collected from ${host} -> ${dest_dir}" + else + rm -f "${archive_path}" + echo "warning: failed to collect remote logs from ${host}:${remote_log_dir}" + fi + done +} + has_fatal_log_error() { local log_path="$1" [[ -f "${log_path}" ]] || return 1 @@ -124,10 +367,14 @@ cleanup() { kill "${watchdog_pid}" 2>/dev/null || true fi bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + collect_remote_logs_once || true } trap cleanup EXIT INT TERM +pre_clean_remote_hosts_once +sync_remote_configs_once + python -m lightx2v.disagg.examples.run_service \ --service controller \ --model_cls wan2.2_moe \ @@ -145,9 +392,14 @@ echo "controller started pid=${controller_pid}" sleep 8 if [[ "${LOAD_FROM_USER}" != "0" ]]; then + if [[ "${user_start_delay_s}" != "0" ]]; then + echo "waiting ${user_start_delay_s}s before run_user to let remote services warm up" + sleep "${user_start_delay_s}" + fi python -m lightx2v.disagg.examples.run_user \ --controller_host "${DISAGG_CONTROLLER_HOST}" \ --controller_request_port "${DISAGG_CONTROLLER_REQUEST_PORT}" \ + --max_requests "${user_max_requests}" \ > ${user_log} 
2>&1 & user_pid=$! echo "run_user started pid=${user_pid}" diff --git a/scripts/disagg/run_wan22_i2v_distill.sh b/scripts/disagg/run_wan22_i2v_distill.sh deleted file mode 100755 index 5133f5fd6..000000000 --- a/scripts/disagg/run_wan22_i2v_distill.sh +++ /dev/null @@ -1,116 +0,0 @@ -#!/bin/bash - -# set path firstly -lightx2v_path=/root/zht/LightX2V -model_path=${lightx2v_path}/models/lightx2v/Wan2.2-Distill-Models - -# set environment variables -source ${lightx2v_path}/scripts/base/base.sh - -# Keep flashinfer enabled while ensuring nvcc uses a supported host compiler. -export CC=/usr/bin/gcc-13 -export CXX=/usr/bin/g++-13 -export CUDAHOSTCXX=/usr/bin/g++-13 -if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then - export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" -else - export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" -fi - -# Pin disagg RDMA and Mooncake to one NIC to avoid cross-NIC session mismatch. -export RDMA_IFACE=${RDMA_IFACE:-erdma_0} -export MOONCAKE_DEVICE_NAME=${MOONCAKE_DEVICE_NAME:-eth0} -if [[ -z "${MOONCAKE_LOCAL_HOSTNAME:-}" ]]; then - _mc_ip=$(ip -4 -o addr show dev "${MOONCAKE_DEVICE_NAME}" 2>/dev/null | awk '{print $4}' | cut -d/ -f1 | head -n 1) - if [[ -n "${_mc_ip}" ]]; then - export MOONCAKE_LOCAL_HOSTNAME="${_mc_ip}" - fi -fi -echo "RDMA_IFACE=${RDMA_IFACE} MOONCAKE_DEVICE_NAME=${MOONCAKE_DEVICE_NAME} MOONCAKE_LOCAL_HOSTNAME=${MOONCAKE_LOCAL_HOSTNAME:-unset}" - -# Enable simulated network jitter by default for this test script. -# Set NETWORK_LATENCY=0 before running to disable it. 
-# export NETWORK_LATENCY=${NETWORK_LATENCY:-1} -# echo "NETWORK_LATENCY=${NETWORK_LATENCY}" - -controller_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_controller.json -encoder_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_encoder.json -transformer_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_transformer.json -decoder_cfg=${lightx2v_path}/configs/disagg/wan22_i2v_distill_decoder.json - -seed=42 -request_count=30 -prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds." -negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -save_result_path=${lightx2v_path}/save_results/wan22_i2v_distill_disagg.mp4 -output_files=() -for ((i=0; i/dev/null; then - kill "${pid}" 2>/dev/null || true - fi - done -} - -trap cleanup EXIT INT TERM - -if [[ ! -f "${controller_cfg}" ]]; then - echo "Controller config not found: ${controller_cfg}" - exit 1 -fi - -# These are kept for manual split-service debug if needed. -if [[ ! -f "${encoder_cfg}" || ! -f "${transformer_cfg}" || ! -f "${decoder_cfg}" ]]; then - echo "One or more disagg stage configs are missing under configs/disagg" - exit 1 -fi - -python -m lightx2v.disagg.examples.run_service \ - --service controller \ - --model_cls wan2.2_moe \ - --task i2v \ - --model_path ${model_path} \ - --config_json ${controller_cfg} \ - --seed ${seed} \ - --prompt "${prompt}" \ - --negative_prompt "${negative_prompt}" \ - --save_result_path ${save_result_path} \ - > ${lightx2v_path}/save_results/disagg_wan22_i2v_distill_controller.log 2>&1 & -controller_pid=$! 
- -echo "Waiting for output videos: ${output_files[*]}" -wait_seconds=0 -max_wait_seconds=$((200 * request_count)) - -while true; do - all_generated=1 - for file in "${output_files[@]}"; do - if [[ ! -f "${file}" ]]; then - all_generated=0 - break - fi - done - - if (( all_generated )); then - echo "All ${request_count} output videos are generated." - break - fi - - if (( wait_seconds >= max_wait_seconds )); then - echo "Timeout waiting for output videos after ${max_wait_seconds}s" - exit 1 - fi - - sleep 5 - wait_seconds=$((wait_seconds + 5)) -done - -sleep 60 diff --git a/scripts/disagg/run_wan_t2v_service.sh b/scripts/disagg/run_wan_t2v_service.sh deleted file mode 100755 index 2e1e41842..000000000 --- a/scripts/disagg/run_wan_t2v_service.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/bin/bash - -# set path firstly -lightx2v_path=/root/zht/LightX2V -model_path=/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B - -# set environment variables -source ${lightx2v_path}/scripts/base/base.sh - -# Keep flashinfer enabled while ensuring nvcc uses a supported host compiler. -export CC=/usr/bin/gcc-13 -export CXX=/usr/bin/g++-13 -export CUDAHOSTCXX=/usr/bin/g++-13 -if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then - export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" -else - export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" -fi - -controller_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_controller.json -encoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_encoder.json -transformer_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_transformer.json -decoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_decoder.json - -seed=42 -request_count=10 -prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
-negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -save_result_path=${lightx2v_path}/save_results/test_disagg.mp4 -output_files=() -for ((i=1; i<=request_count; i++)); do - output_files+=("${save_result_path%.mp4}${i}.mp4") -done - -# Remove old outputs so wait loop reflects current run status. -rm -f "${output_files[@]}" - -cleanup() { - local pids=("${encoder_pid:-}" "${transformer_pid:-}" "${decoder_pid:-}" "${controller_pid:-}") - for pid in "${pids[@]}"; do - if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then - kill "${pid}" 2>/dev/null || true - fi - done -} - -trap cleanup EXIT INT TERM - -wait_for_port() { - local host="$1" - local port="$2" - local timeout_secs="${3:-30}" - local waited=0 - - while true; do - if (echo > /dev/tcp/${host}/${port}) >/dev/null 2>&1; then - echo "Port ready: ${host}:${port}" - return 0 - fi - - if (( waited >= timeout_secs )); then - echo "Timeout waiting for port ${host}:${port} after ${timeout_secs}s" - return 1 - fi - - sleep 1 - waited=$((waited + 1)) - done -} - -rdma_request_port=5566 -rdma_phase1_port=5567 -rdma_phase2_port=5568 - -python -m lightx2v.disagg.examples.run_service \ - --service controller \ - --model_cls wan2.1 \ - --task t2v \ - --model_path ${model_path} \ - --config_json ${controller_cfg} \ - --seed ${seed} \ - --prompt "${prompt}" \ - --negative_prompt "${negative_prompt}" \ - --save_result_path ${save_result_path} \ - > ${lightx2v_path}/save_results/disagg_controller.log 2>&1 & -controller_pid=$! - -wait_for_port 127.0.0.1 ${rdma_request_port} 60 -wait_for_port 127.0.0.1 ${rdma_phase1_port} 60 -wait_for_port 127.0.0.1 ${rdma_phase2_port} 60 - -# NOTE: Kept for rollback. Controller now creates encoder/transformer/decoder internally. 
-# CUDA_VISIBLE_DEVICES=0 python -m lightx2v.disagg.examples.run_service \ -# --service encoder \ -# --model_cls wan2.1 \ -# --task t2v \ -# --model_path ${model_path} \ -# --config_json ${encoder_cfg} \ -# --seed ${seed} \ -# --prompt "${prompt}" \ -# --negative_prompt "${negative_prompt}" \ -# --save_result_path ${save_result_path} \ -# > ${lightx2v_path}/save_results/disagg_encoder.log 2>&1 & -# encoder_pid=$! - -# CUDA_VISIBLE_DEVICES=1 python -m lightx2v.disagg.examples.run_service \ -# --service transformer \ -# --model_cls wan2.1 \ -# --task t2v \ -# --model_path ${model_path} \ -# --config_json ${transformer_cfg} \ -# --seed ${seed} \ -# --prompt "${prompt}" \ -# --negative_prompt "${negative_prompt}" \ -# --save_result_path ${save_result_path} \ -# > ${lightx2v_path}/save_results/disagg_transformer.log 2>&1 & -# transformer_pid=$! - -# CUDA_VISIBLE_DEVICES=2 python -m lightx2v.disagg.examples.run_service \ -# --service decoder \ -# --model_cls wan2.1 \ -# --task t2v \ -# --model_path ${model_path} \ -# --config_json ${decoder_cfg} \ -# --seed ${seed} \ -# --prompt "${prompt}" \ -# --negative_prompt "${negative_prompt}" \ -# --save_result_path ${save_result_path} \ -# > ${lightx2v_path}/save_results/disagg_decoder.log 2>&1 & -# decoder_pid=$! - -# Give background services time to flush and finish queued requests. - -echo "Waiting for output videos: ${output_files[*]}" -wait_seconds=0 -max_wait_seconds=$((600 * request_count)) - -while true; do - all_generated=1 - for file in "${output_files[@]}"; do - if [[ ! -f "${file}" ]]; then - all_generated=0 - break - fi - done - - if (( all_generated )); then - echo "All ${request_count} output videos are generated." 
- break - fi - - if (( wait_seconds >= max_wait_seconds )); then - echo "Timeout waiting for output videos after ${max_wait_seconds}s" - exit 1 - fi - - sleep 5 - wait_seconds=$((wait_seconds + 5)) -done - -sleep 30 From 5a90f98145f68cf7c8182d514801a6a553d48274 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Fri, 24 Apr 2026 16:41:00 +0800 Subject: [PATCH 7/9] enable multi machine test --- .../wan22_i2v_distill_controller.json | 69 ++- .../multi_node/wan22_i2v_distill_decoder.json | 4 +- .../multi_node/wan22_i2v_distill_encoder.json | 4 +- .../wan22_i2v_distill_transformer.json | 4 +- .../wan22_i2v_distill_controller.json | 2 +- lightx2v/disagg/README.md | 143 ++++++ lightx2v/disagg/mooncake.py | 45 +- lightx2v/disagg/rdma_base.py | 75 ++++ lightx2v/disagg/rdma_client.py | 155 +++---- lightx2v/disagg/rdma_server.py | 170 ++++--- lightx2v/disagg/rdma_utils.py | 167 ++++++- lightx2v/disagg/services/controller.py | 419 +++++++++++++++--- lightx2v/disagg/services/decoder.py | 57 ++- lightx2v/disagg/services/encoder.py | 148 ++++++- lightx2v/disagg/services/instance_proxy.py | 4 +- lightx2v/disagg/services/transformer.py | 185 ++++++-- lightx2v/disagg/utils.py | 46 +- scripts/disagg/extract_dynamic_latency.py | 2 +- scripts/disagg/kill_service.sh | 41 ++ scripts/disagg/run_dynamic.sh | 102 ++++- 20 files changed, 1523 insertions(+), 319 deletions(-) create mode 100644 lightx2v/disagg/README.md create mode 100644 lightx2v/disagg/rdma_base.py diff --git a/configs/disagg/multi_node/wan22_i2v_distill_controller.json b/configs/disagg/multi_node/wan22_i2v_distill_controller.json index c4c691ec7..f6162108e 100644 --- a/configs/disagg/multi_node/wan22_i2v_distill_controller.json +++ b/configs/disagg/multi_node/wan22_i2v_distill_controller.json @@ -42,7 +42,7 @@ "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", "high_noise_original_ckpt": 
"/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", - "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", "disagg_mode": "controller", "disagg_config": { "bootstrap_addr": "192.168.0.166", @@ -54,6 +54,9 @@ "protocol": "rdma", "local_hostname": "192.168.0.166", "metadata_server": "P2PHANDSHAKE", + "service_env": { + "RDMA_IFACE": "erdma_0" + }, "remote_workdir": "/root/zht/LightX2V", "remote_python_executable": "python", "remote_activate_cmd": "source /root/miniconda3/etc/profile.d/conda.sh && conda activate lightx2v && export LD_LIBRARY_PATH=/root/miniconda3/envs/lightx2v/lib:${LD_LIBRARY_PATH:-}", @@ -72,51 +75,83 @@ "static_instance_slots": [ { "instance_type": "encoder", - "host": "192.168.0.166", + "host": "192.168.0.139", "engine_rank": 0, - "cuda_device": 0 + "cuda_device": 0, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.139" + } }, { "instance_type": "transformer", - "host": "192.168.0.139", + "host": "192.168.0.166", "engine_rank": 1, - "cuda_device": 0 + "cuda_device": 0, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } }, { "instance_type": "transformer", - "host": "192.168.0.139", + "host": "192.168.0.166", "engine_rank": 2, - "cuda_device": 1 + "cuda_device": 1, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } }, { "instance_type": "transformer", - "host": "192.168.0.139", + "host": "192.168.0.166", "engine_rank": 3, - "cuda_device": 2 + "cuda_device": 2, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } }, { "instance_type": "transformer", - "host": "192.168.0.139", + 
"host": "192.168.0.166", "engine_rank": 4, - "cuda_device": 3 + "cuda_device": 3, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } }, { "instance_type": "transformer", - "host": "192.168.0.139", + "host": "192.168.0.166", "engine_rank": 5, - "cuda_device": 4 + "cuda_device": 4, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } }, { "instance_type": "transformer", - "host": "192.168.0.139", + "host": "192.168.0.166", "engine_rank": 6, - "cuda_device": 5 + "cuda_device": 5, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } }, { "instance_type": "decoder", - "host": "192.168.0.166", + "host": "192.168.0.139", "engine_rank": 7, - "cuda_device": 7 + "cuda_device": 1, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.139" + } } ] } diff --git a/configs/disagg/multi_node/wan22_i2v_distill_decoder.json b/configs/disagg/multi_node/wan22_i2v_distill_decoder.json index 31b7fa418..8549d4165 100644 --- a/configs/disagg/multi_node/wan22_i2v_distill_decoder.json +++ b/configs/disagg/multi_node/wan22_i2v_distill_decoder.json @@ -42,7 +42,7 @@ "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", - "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", "disagg_mode": "decoder", "disagg_config": { "bootstrap_addr": "192.168.0.166", @@ -52,7 +52,7 @@ "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - 
"local_hostname": "192.168.0.166", + "local_hostname": "192.168.0.139", "metadata_server": "P2PHANDSHAKE" } } diff --git a/configs/disagg/multi_node/wan22_i2v_distill_encoder.json b/configs/disagg/multi_node/wan22_i2v_distill_encoder.json index 27cadaaf0..7b126a729 100644 --- a/configs/disagg/multi_node/wan22_i2v_distill_encoder.json +++ b/configs/disagg/multi_node/wan22_i2v_distill_encoder.json @@ -42,7 +42,7 @@ "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", - "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", "disagg_mode": "encoder", "disagg_config": { "bootstrap_addr": "192.168.0.166", @@ -52,7 +52,7 @@ "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "192.168.0.166", + "local_hostname": "192.168.0.139", "metadata_server": "P2PHANDSHAKE" } } diff --git a/configs/disagg/multi_node/wan22_i2v_distill_transformer.json b/configs/disagg/multi_node/wan22_i2v_distill_transformer.json index a629ff2f4..99572301f 100644 --- a/configs/disagg/multi_node/wan22_i2v_distill_transformer.json +++ b/configs/disagg/multi_node/wan22_i2v_distill_transformer.json @@ -42,7 +42,7 @@ "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", "low_noise_original_ckpt": 
"/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", - "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", "disagg_mode": "transformer", "disagg_config": { "bootstrap_addr": "192.168.0.166", @@ -52,7 +52,7 @@ "transformer_engine_rank": 1, "decoder_engine_rank": 2, "protocol": "rdma", - "local_hostname": "192.168.0.139", + "local_hostname": "192.168.0.166", "metadata_server": "P2PHANDSHAKE" } } diff --git a/configs/disagg/single_node/wan22_i2v_distill_controller.json b/configs/disagg/single_node/wan22_i2v_distill_controller.json index e8e4f577b..12c9c6730 100644 --- a/configs/disagg/single_node/wan22_i2v_distill_controller.json +++ b/configs/disagg/single_node/wan22_i2v_distill_controller.json @@ -42,7 +42,7 @@ "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", - "image_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B/examples/i2v_input.JPG", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", "disagg_mode": "controller", "disagg_config": { "bootstrap_addr": "127.0.0.1", diff --git a/lightx2v/disagg/README.md b/lightx2v/disagg/README.md new file mode 100644 index 000000000..13d064464 --- /dev/null +++ b/lightx2v/disagg/README.md @@ -0,0 +1,143 @@ +# disagg / `run_dynamic.sh` 使用说明 + +`scripts/disagg/run_dynamic.sh` 是 LightX2V 的动态多机/单机离线调度启动脚本。它会自动完成以下工作: + +1. 激活 `lightx2v` conda 环境,除非显式关闭。 +2. 先执行 `scripts/disagg/kill_service.sh` 清理残留进程和端口。 +3. 
读取 controller 配置,准备 `multi_node` 或 `single_node` 启动参数。 +4. 为 Mooncake / RDMA / ZMQ / 日志收集设置默认环境变量。 +5. 启动 controller,并按配置拉起 encoder / transformer / decoder。 + +## 基本用法 + +最常见的方式是直接运行脚本: + +```bash +bash scripts/disagg/run_dynamic.sh +``` + +如果要切换拓扑或覆盖默认配置,可以在命令前追加环境变量: + +```bash +DISAGG_TOPOLOGY=multi_node \ +DISAGG_CONTROLLER_CFG=/root/zht/LightX2V/configs/disagg/multi_node/wan22_i2v_distill_controller.json \ +bash scripts/disagg/run_dynamic.sh +``` + +单机调试可以改成: + +```bash +DISAGG_TOPOLOGY=single_node \ +bash scripts/disagg/run_dynamic.sh +``` + +## 脚本会自动处理的事情 + +脚本会自动: + +1. 如果当前没有激活到 `DISAGG_CONDA_ENV`,就尝试 `conda activate`。 +2. 设置编译器和 `NVCC_PREPEND_FLAGS`,便于本地编译或运行扩展。 +3. 默认将 `RDMA_IFACE` 设为 `erdma_0`,将 `MOONCAKE_DEVICE_NAME` 设为 `eth0`。 +4. 如果没有显式设置 `MOONCAKE_LOCAL_HOSTNAME`,就从 `MOONCAKE_DEVICE_NAME` 对应网卡自动解析本机 IPv4。 +5. 根据 controller 配置里的 `bootstrap_addr` 自动推导 `DISAGG_CONTROLLER_HOST`。 +6. 先执行 `kill_service.sh` 清理旧服务,避免端口冲突。 + +## 环境变量说明 + +下面按功能分组说明常用变量。未特别说明时,都是脚本默认值。 + +### 运行模式与配置 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `DISAGG_TOPOLOGY` | 运行拓扑,`multi_node` 表示多机,`single_node` 表示单机。 | `multi_node` | +| `DISAGG_CONTROLLER_CFG` | controller 配置文件路径。脚本会根据拓扑自动选择默认配置。 | `configs/disagg/multi_node/wan22_i2v_distill_controller.json` 或 single_node 对应文件 | +| `DISAGG_CONDA_ENV` | 启动时要激活的 conda 环境名。 | `lightx2v` | +| `DISAGG_SKIP_CONDA_ACTIVATE` | 设为 `1` 时跳过 conda 激活。 | `0` | +| `DISAGG_CONTROLLER_HOST` | controller 对外使用的主机地址。若未设置,脚本会尝试从配置文件 `bootstrap_addr` 推导。 | 配置里的 `bootstrap_addr`,否则 `127.0.0.1` | + +### RDMA / Mooncake + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `RDMA_IFACE` | 本机 RDMA / eRDMA 网卡名。 | `erdma_0` | +| `MOONCAKE_DEVICE_NAME` | Mooncake 用来解析本机 IPv4 的网卡名。 | `eth0` | +| `MOONCAKE_LOCAL_HOSTNAME` | Mooncake 认为的本机地址。若未设置,脚本会自动从 `MOONCAKE_DEVICE_NAME` 对应网卡提取 IPv4。 | 自动推导 | +| `RDMA_PREFERRED_IPV4` | 优先选择的 RDMA 数据平面 IPv4,通常用于多网卡环境下稳定选择 gid_index。 | 自动推导为 `DISAGG_CONTROLLER_HOST`(当其是 IPv4 且不是 `127.0.0.1`) | + +### 控制面端口与启动等待 + +| 变量 | 含义 | 默认值 | +| --- 
| --- | --- | +| `DISAGG_CONTROLLER_REQUEST_PORT` | controller 请求入口端口。 | `12786` | +| `DISAGG_INSTANCE_START_TIMEOUT_SECONDS` | 等待实例启动完成的超时时间。 | `single_node: 90`,`multi_node: 300` | +| `DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS` | 等待远端 proxy 启动的超时时间。 | `120` | +| `DISAGG_SIDECAR_START_TIMEOUT_SECONDS` | 等待 sidecar 启动的超时时间。 | `60` | +| `CONTROLLER_WAIT_TIMEOUT_S` | 等待 controller 完成整轮任务的超时时间。 | `single_node: 3000`,`multi_node: 7200` | +| `CONTROLLER_POLL_INTERVAL_S` | controller 状态轮询间隔。 | `5` | + +### 请求数量、调试与通信方式 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `LOAD_FROM_USER` | 设为非 `0` 时,由 user 侧持续发请求,直到阶段结束。 | `0` | +| `DISAGG_AUTO_REQUEST_COUNT` | 自动请求的默认数量。`LOAD_FROM_USER=0` 时会使用这个值。 | `30` | +| `USER_MAX_REQUESTS` | 手动限制 user 进程最多发多少个请求,优先级高于 `DISAGG_AUTO_REQUEST_COUNT`。 | 未设置 | +| `USER_START_DELAY_S` | user 进程启动后的延迟时间。 | `0` | +| `SYNC_COMM` | 是否启用同步通信模式。 | `0` | + +### Nsight 采集 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `DISAGG_ENABLE_NSYS` | 是否启用 `nsys profile` 包裹实例进程。 | `0` | +| `DISAGG_NSYS_BIN` | `nsys` 可执行文件路径或命令名。 | `nsys` | +| `DISAGG_NSYS_OUTPUT_DIR` | nsys trace 输出目录。 | `save_results/nsys` | +| `DISAGG_NSYS_TRACE` | `nsys profile` 的 trace 类型。 | `cuda,nvtx,osrt` | +| `DISAGG_NSYS_EXTRA_ARGS` | 额外传给 `nsys profile` 的参数。 | 空 | + +### 日志与清理 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `REMOTE_LOG_COLLECT` | 是否在结束后拉取远端日志。 | `1` | +| `REMOTE_LOG_COLLECT_DIR` | 远端日志收集到本地的目录。 | `save_results/remote_logs` | +| `DISAGG_REMOTE_PRE_CLEAN` | 是否在启动前先远端执行 `kill_service.sh`。 | `1` | +| `SEED` | 随机种子。 | `42` | +| `PROMPT` | 文本提示词。 | 脚本内置示例 prompt | +| `NEGATIVE_PROMPT` | 负向提示词。 | 脚本内置示例 negative prompt | +| `SAVE_RESULT_PATH` | 最终视频保存路径。 | `save_results/wan22_i2v_dynamic.mp4` | + +## 推荐的常见组合 + +### 本地单机调试 + +```bash +DISAGG_TOPOLOGY=single_node \ +LOAD_FROM_USER=0 \ +DISAGG_AUTO_REQUEST_COUNT=1 \ +bash scripts/disagg/run_dynamic.sh +``` + +### 多机标准跑法 + +```bash +DISAGG_TOPOLOGY=multi_node \ 
+DISAGG_CONTROLLER_CFG=/root/zht/LightX2V/configs/disagg/multi_node/wan22_i2v_distill_controller.json \ +DISAGG_AUTO_REQUEST_COUNT=30 \ +bash scripts/disagg/run_dynamic.sh +``` + +### 开启 Nsight + +```bash +DISAGG_ENABLE_NSYS=1 \ +DISAGG_NSYS_TRACE=cuda,nvtx,osrt \ +bash scripts/disagg/run_dynamic.sh +``` + +## 备注 + +1. 多机运行时,`DISAGG_CONTROLLER_CFG` 里的 `bootstrap_addr`、`static_instance_slots` 和各 slot 的 `env` 会直接影响远端实例如何绑定网络与 Mooncake 地址。 +2. 如果遇到端口占用,优先检查 `scripts/disagg/kill_service.sh` 是否已经把旧实例和 proxy 清理干净。 +3. 如果需要了解 controller 配置文件本身的字段含义,可以继续查看 `configs/disagg/` 下对应 JSON。 \ No newline at end of file diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py index 70cc27166..e26ca3caf 100644 --- a/lightx2v/disagg/mooncake.py +++ b/lightx2v/disagg/mooncake.py @@ -31,6 +31,25 @@ def _detect_non_loopback_ipv4() -> str | None: return None +def _collect_local_ipv4_addresses() -> list[str]: + candidates: list[str] = [] + + try: + hostname = socket.gethostname() + for info in socket.getaddrinfo(hostname, None, socket.AF_INET): + address = info[4][0] + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + except Exception: + pass + + detected = _detect_non_loopback_ipv4() + if detected is not None and detected not in candidates: + candidates.append(detected) + + return candidates + + @dataclass class MooncakeTransferEngineConfig: local_hostname: str @@ -55,6 +74,7 @@ def load_from_env() -> "MooncakeTransferEngineConfig": if config_file_path is None: raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") cfg = MooncakeTransferEngineConfig.from_file(config_file_path) + local_ipv4s = _collect_local_ipv4_addresses() env_metadata_server = os.getenv("MOONCAKE_METADATA_SERVER", "").strip() if env_metadata_server: @@ -73,13 +93,36 @@ def load_from_env() -> "MooncakeTransferEngineConfig": force_ipv4 = os.getenv("MOONCAKE_FORCE_IPV4_LOOPBACK", "1") not in ("0", "false", "False") env_host = 
os.getenv("MOONCAKE_LOCAL_HOSTNAME", "").strip() if env_host: - cfg.local_hostname = env_host + if env_host in ("localhost", "::1", "127.0.0.1") or env_host in local_ipv4s: + cfg.local_hostname = env_host + else: + detected = _detect_non_loopback_ipv4() + if detected is not None: + logger.warning( + "Ignoring MOONCAKE_LOCAL_HOSTNAME=%s because it does not match this host (local_ipv4s=%s); using %s", + env_host, + local_ipv4s, + detected, + ) + cfg.local_hostname = detected + elif force_ipv4: + cfg.local_hostname = "127.0.0.1" elif force_ipv4 and cfg.local_hostname in ("localhost", "::1", "127.0.0.1"): detected = _detect_non_loopback_ipv4() if detected is not None: cfg.local_hostname = detected else: cfg.local_hostname = "127.0.0.1" + elif cfg.local_hostname not in local_ipv4s and cfg.local_hostname not in ("localhost", "::1", "127.0.0.1"): + detected = _detect_non_loopback_ipv4() + if detected is not None: + logger.warning( + "Auto-correcting Mooncake local_hostname from %s to %s on this host (local_ipv4s=%s)", + cfg.local_hostname, + detected, + local_ipv4s, + ) + cfg.local_hostname = detected return cfg diff --git a/lightx2v/disagg/rdma_base.py b/lightx2v/disagg/rdma_base.py new file mode 100644 index 000000000..90a7fb6e2 --- /dev/null +++ b/lightx2v/disagg/rdma_base.py @@ -0,0 +1,75 @@ +"""Shared pyverbs imports and thin RDMA types for client/server.""" + +from __future__ import annotations + +import pyverbs.enums as e +from pyverbs.addr import GID, AHAttr, GlobalRoute +from pyverbs.cq import CQ +from pyverbs.device import Context, get_device_list +from pyverbs.mr import MR +from pyverbs.pd import PD +from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr +from pyverbs.wr import SGE +from pyverbs.wr import SendWR as WR + +from lightx2v.disagg.rdma_utils import ( + recv_json_from_stream, + resolve_gid_index, + rtr_ah_dest_dlid, + rtr_path_mtu, + rtr_path_mtu_negotiated, +) + + +class IBDevice: + def __init__(self, name: str): + self.name = name + + def open(self): 
+ return Context(name=self.name) + + +class QPType: + RC = e.IBV_QPT_RC + + +class WROpcode: + RDMA_WRITE = e.IBV_WR_RDMA_WRITE + RDMA_READ = e.IBV_WR_RDMA_READ + ATOMIC_FETCH_AND_ADD = e.IBV_WR_ATOMIC_FETCH_AND_ADD + ATOMIC_CMP_AND_SWP = e.IBV_WR_ATOMIC_CMP_AND_SWP + + +class AccessFlag: + LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE + REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE + REMOTE_READ = e.IBV_ACCESS_REMOTE_READ + REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC + + +__all__ = [ + "AccessFlag", + "AHAttr", + "CQ", + "Context", + "GID", + "GlobalRoute", + "IBDevice", + "MR", + "PD", + "QP", + "QPAttr", + "QPCap", + "QPInitAttr", + "QPType", + "SGE", + "WR", + "WROpcode", + "e", + "get_device_list", + "recv_json_from_stream", + "resolve_gid_index", + "rtr_ah_dest_dlid", + "rtr_path_mtu", + "rtr_path_mtu_negotiated", +] diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index fd88dc615..2c7214bbc 100644 --- a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -6,47 +6,35 @@ import threading import time -import pyverbs.enums as e -from pyverbs.addr import GID, AHAttr, GlobalRoute -from pyverbs.cq import CQ -from pyverbs.device import Context, get_device_list -from pyverbs.mr import MR -from pyverbs.pd import PD -from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr -from pyverbs.wr import SGE -from pyverbs.wr import SendWR as WR - -from lightx2v.disagg.rdma_utils import resolve_gid_index +from lightx2v.disagg.rdma_base import ( + AccessFlag, + AHAttr, + CQ, + GID, + GlobalRoute, + IBDevice, + MR, + PD, + QP, + QPAttr, + QPCap, + QPInitAttr, + QPType, + SGE, + WR, + WROpcode, + e, + get_device_list, + recv_json_from_stream, + resolve_gid_index, + rtr_ah_dest_dlid, + rtr_path_mtu, + rtr_path_mtu_negotiated, +) logger = logging.getLogger(__name__) -class IBDevice: - def __init__(self, name: str): - self.name = name - - def open(self): - return Context(name=self.name) - - -class QPType: - RC = e.IBV_QPT_RC - - -class WROpcode: - 
RDMA_WRITE = e.IBV_WR_RDMA_WRITE - RDMA_READ = e.IBV_WR_RDMA_READ - ATOMIC_FETCH_AND_ADD = e.IBV_WR_ATOMIC_FETCH_AND_ADD - ATOMIC_CMP_AND_SWP = e.IBV_WR_ATOMIC_CMP_AND_SWP - - -class AccessFlag: - LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE - REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE - REMOTE_READ = e.IBV_ACCESS_REMOTE_READ - REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC - - class RDMAClient: def __init__(self, iface_name=None, local_buffer_size=4096): self.local_psn = 654321 @@ -133,30 +121,6 @@ def _reset_qp(self): except Exception: pass - def _recv_json(self, sock, timeout_sec): - decoder = json.JSONDecoder() - chunks = [] - deadline = time.time() + timeout_sec - while time.time() < deadline: - try: - chunk = sock.recv(4096) - except socket.timeout: - continue - - if not chunk: - break - - chunks.append(chunk) - payload = b"".join(chunks).decode("utf-8", errors="strict") - try: - obj, _ = decoder.raw_decode(payload) - return obj - except json.JSONDecodeError: - continue - - msg = b"".join(chunks).decode("utf-8", errors="ignore") - raise RuntimeError(f"Timed out waiting for complete handshake JSON. payload={msg!r}") - def _ensure_local_mr_capacity(self, required_size: int): required = int(required_size) if required <= self.buffer_size: @@ -191,7 +155,7 @@ def connect_to_server(self, server_ip="127.0.0.1", port=5566): sock.connect((server_ip, port)) # 1. 
接收 Server 信息 (包含 rkey 和 addr) - remote_info = self._recv_json(sock, timeout_sec=connect_timeout_sec) + remote_info = recv_json_from_stream(sock, timeout_sec=connect_timeout_sec) if not isinstance(remote_info, dict): raise RuntimeError(f"Invalid handshake payload type: {type(remote_info)}") required_keys = {"addr", "rkey", "qpn", "psn", "gid"} @@ -209,6 +173,7 @@ def connect_to_server(self, server_ip="127.0.0.1", port=5566): "psn": self.local_psn, "gid": str(gid), "gid_index": self.gid_index, + "active_mtu": int(rtr_path_mtu(self.ctx, self.port_num)), } sock.sendall(json.dumps(my_info).encode()) @@ -242,22 +207,58 @@ def connect_to_server(self, server_ip="127.0.0.1", port=5566): def _modify_qp_to_rts(self): # Follow the standard RC flow: INIT -> RTR -> RTS. - init_attr = QPAttr(port_num=self.port_num) - init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC - self.qp.to_init(init_attr) - - rtr_attr = QPAttr(port_num=self.port_num) - rtr_attr.path_mtu = e.IBV_MTU_1024 - rtr_attr.max_dest_rd_atomic = 1 - rtr_attr.min_rnr_timer = 12 - rtr_attr.dest_qp_num = int(self.remote_info["qpn"]) - rtr_attr.rq_psn = int(self.remote_info["psn"]) - remote_lid = int(self.remote_info.get("lid", 0)) - remote_gid_index = int(self.remote_info.get("gid_index", self.gid_index)) - gr = GlobalRoute(dgid=GID(self.remote_info["gid"]), sgid_index=remote_gid_index) - rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=remote_lid) - self.qp.to_rtr(rtr_attr) + heuristic_dlid = rtr_ah_dest_dlid(self.ctx, self.port_num, remote_lid) + negotiated_mtu = int(rtr_path_mtu_negotiated(self.ctx, self.port_num, self.remote_info.get("active_mtu"))) + local_mtu = int(rtr_path_mtu(self.ctx, self.port_num)) + default_mtu = int(e.IBV_MTU_1024) + + # Some eRDMA/RoCE stacks are strict about dlid/mtu combinations; try safe fallbacks. 
+ mtu_candidates = [] + for v in (negotiated_mtu, local_mtu, default_mtu): + if v not in mtu_candidates: + mtu_candidates.append(v) + dlid_candidates = [] + for v in (heuristic_dlid, 0, remote_lid): + if v not in dlid_candidates: + dlid_candidates.append(v) + + gr = GlobalRoute(dgid=GID(self.remote_info["gid"]), sgid_index=self.gid_index, hop_limit=1) + last_exc = None + for rd_atomic in (1, 0): + for mtu in mtu_candidates: + for dlid in dlid_candidates: + for is_global in (1, 0): + try: + init_attr = QPAttr(port_num=self.port_num) + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC + self.qp.to_init(init_attr) + + rtr_attr = QPAttr(port_num=self.port_num) + rtr_attr.path_mtu = int(mtu) + rtr_attr.max_dest_rd_atomic = int(rd_atomic) + rtr_attr.min_rnr_timer = 12 + rtr_attr.dest_qp_num = int(self.remote_info["qpn"]) + rtr_attr.rq_psn = int(self.remote_info["psn"]) + # Some drivers require GRH(is_global=1), others only accept non-GRH. 
+ if is_global == 1: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=int(dlid)) + else: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=0, dlid=int(dlid)) + self.qp.to_rtr(rtr_attr) + last_exc = None + break + except Exception as exc: + last_exc = exc + continue + if last_exc is None: + break + if last_exc is None: + break + if last_exc is None: + break + if last_exc is not None: + raise last_exc rts_attr = QPAttr(port_num=self.port_num) rts_attr.timeout = 14 diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index b8a1357a5..63e272cf1 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -3,38 +3,28 @@ import socket import threading -import pyverbs.enums as e -from pyverbs.addr import GID, AHAttr, GlobalRoute -from pyverbs.cq import CQ -from pyverbs.device import Context, get_device_list -from pyverbs.mr import MR -from pyverbs.pd import PD -from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr - -from lightx2v.disagg.rdma_utils import resolve_gid_index - - -class IBDevice: - def __init__(self, name: str): - self.name = name - - def open(self): - return Context(name=self.name) - - -class QPType: - RC = e.IBV_QPT_RC - - -class WROpcode: - RDMA_WRITE = e.IBV_WR_RDMA_WRITE - - -class AccessFlag: - LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE - REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE - REMOTE_READ = e.IBV_ACCESS_REMOTE_READ - REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC +from lightx2v.disagg.rdma_base import ( + AccessFlag, + AHAttr, + CQ, + GID, + GlobalRoute, + IBDevice, + MR, + PD, + QP, + QPAttr, + QPCap, + QPInitAttr, + QPType, + e, + get_device_list, + recv_json_from_stream, + resolve_gid_index, + rtr_ah_dest_dlid, + rtr_path_mtu, + rtr_path_mtu_negotiated, +) class RDMAServer: @@ -142,7 +132,27 @@ def get_local_info(self, qp=None, psn=None): qp = self.qp if qp is None else qp psn = self.local_psn if psn is None else int(psn) gid = self.ctx.query_gid(self.port_num, 
self.gid_index) - return {"lid": self.ctx.query_port(self.port_num).lid, "qpn": qp.qp_num, "psn": psn, "gid": str(gid), "gid_index": self.gid_index, "rkey": self.mr.rkey, "addr": mr_addr} + return { + "lid": self.ctx.query_port(self.port_num).lid, + "qpn": qp.qp_num, + "psn": psn, + "gid": str(gid), + "gid_index": self.gid_index, + "rkey": self.mr.rkey, + "addr": mr_addr, + "active_mtu": int(rtr_path_mtu(self.ctx, self.port_num)), + } + + @staticmethod + def _safe_destroy_qp(qp): + if qp is None: + return + close_fn = getattr(qp, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception: + pass def _alloc_qp_with_psn(self): with self._conn_lock: @@ -158,18 +168,24 @@ def _accept_one_client(self, listen_sock): print(f"[Server] Connected to {addr}") qp, local_psn = self._alloc_qp_with_psn() - - # 1. 发送我的信息给 Client - my_info = self.get_local_info(qp=qp, psn=local_psn) - conn.sendall(json.dumps(my_info).encode()) - - # 2. 接收 Client 的信息 - data = conn.recv(4096) - remote_info = json.loads(data.decode()) - print(f"[Server] Received remote info: QPN={remote_info['qpn']}") - - # 3. 修改 QP 状态到 RTS - self._modify_qp_to_rts(qp, remote_info, local_psn) + try: + # 1. 发送我的信息给 Client + my_info = self.get_local_info(qp=qp, psn=local_psn) + conn.sendall(json.dumps(my_info).encode()) + + # 2. 接收 Client 的信息(可能分片,勿单次 recv) + remote_info = recv_json_from_stream(conn, timeout_sec=30.0) + print(f"[Server] Received remote info: QPN={remote_info['qpn']}") + + # 3. 修改 QP 状态到 RTS + self._modify_qp_to_rts(qp, remote_info, local_psn) + except BaseException: + self._safe_destroy_qp(qp) + try: + conn.close() + except Exception: + pass + raise with self._conn_lock: self._active_qps.append(qp) @@ -200,22 +216,56 @@ def handshake(self, host="0.0.0.0", port=5566, serve_forever=True): def _modify_qp_to_rts(self, qp, remote_info, local_psn): # Follow the standard RC flow: INIT -> RTR -> RTS. 
- init_attr = QPAttr(port_num=self.port_num) - init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC - qp.to_init(init_attr) - - rtr_attr = QPAttr(port_num=self.port_num) - rtr_attr.path_mtu = e.IBV_MTU_1024 - rtr_attr.max_dest_rd_atomic = 1 - rtr_attr.min_rnr_timer = 12 - rtr_attr.dest_qp_num = int(remote_info["qpn"]) - rtr_attr.rq_psn = int(remote_info["psn"]) - remote_lid = int(remote_info.get("lid", 0)) - remote_gid_index = int(remote_info.get("gid_index", self.gid_index)) - gr = GlobalRoute(dgid=GID(remote_info["gid"]), sgid_index=remote_gid_index) - rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=remote_lid) - qp.to_rtr(rtr_attr) + heuristic_dlid = rtr_ah_dest_dlid(self.ctx, self.port_num, remote_lid) + negotiated_mtu = int(rtr_path_mtu_negotiated(self.ctx, self.port_num, remote_info.get("active_mtu"))) + local_mtu = int(rtr_path_mtu(self.ctx, self.port_num)) + default_mtu = int(e.IBV_MTU_1024) + + mtu_candidates = [] + for v in (negotiated_mtu, local_mtu, default_mtu): + if v not in mtu_candidates: + mtu_candidates.append(v) + dlid_candidates = [] + for v in (heuristic_dlid, 0, remote_lid): + if v not in dlid_candidates: + dlid_candidates.append(v) + + gr = GlobalRoute(dgid=GID(remote_info["gid"]), sgid_index=self.gid_index, hop_limit=1) + last_exc = None + for rd_atomic in (1, 0): + for mtu in mtu_candidates: + for dlid in dlid_candidates: + for is_global in (1, 0): + try: + init_attr = QPAttr(port_num=self.port_num) + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC + qp.to_init(init_attr) + + rtr_attr = QPAttr(port_num=self.port_num) + rtr_attr.path_mtu = int(mtu) + rtr_attr.max_dest_rd_atomic = int(rd_atomic) + rtr_attr.min_rnr_timer = 12 + rtr_attr.dest_qp_num = int(remote_info["qpn"]) + rtr_attr.rq_psn = int(remote_info["psn"]) + if is_global == 1: + rtr_attr.ah_attr = 
AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=int(dlid)) + else: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=0, dlid=int(dlid)) + qp.to_rtr(rtr_attr) + last_exc = None + break + except Exception as exc: + last_exc = exc + continue + if last_exc is None: + break + if last_exc is None: + break + if last_exc is None: + break + if last_exc is not None: + raise last_exc rts_attr = QPAttr(port_num=self.port_num) rts_attr.timeout = 14 diff --git a/lightx2v/disagg/rdma_utils.py b/lightx2v/disagg/rdma_utils.py index fcdba75e5..629340ef9 100644 --- a/lightx2v/disagg/rdma_utils.py +++ b/lightx2v/disagg/rdma_utils.py @@ -1,7 +1,14 @@ from __future__ import annotations +import ipaddress +import json +import logging import os import socket +import time + + +logger = logging.getLogger(__name__) def _collect_local_ipv4_addresses() -> list[str]: @@ -39,19 +46,82 @@ def _collect_local_ipv4_addresses() -> list[str]: def _gid_to_ipv4(gid_text: str) -> str | None: - if gid_text.startswith("::ffff:"): - return gid_text.removeprefix("::ffff:") + """Map an IPv6 GID string to IPv4 when it is IPv4-mapped.""" + text = str(gid_text).strip() + if not text or text == "::": + return None + lower = text.lower() + if lower.startswith("::ffff:"): + return _canonical_ipv4(text[7:]) + try: + mapped = ipaddress.ip_address(text).ipv4_mapped + if mapped is not None: + return str(mapped) + except ValueError: + pass return None +def _canonical_ipv4(text: str) -> str | None: + text = str(text).strip() + if not text: + return None + try: + return str(ipaddress.IPv4Address(text)) + except Exception: + return None + + +def _preferred_rdma_ipv4() -> str | None: + """RoCE GID row to prefer when auto-picking gid_index (multi-node / multi-homing).""" + v = _canonical_ipv4(os.getenv("RDMA_PREFERRED_IPV4", "")) + if v: + return v + return _canonical_ipv4(os.getenv("MOONCAKE_LOCAL_HOSTNAME", "")) + + def resolve_gid_index(ctx, port_num: int, env_var_name: str = "RDMA_GID_INDEX") -> int: + 
local_ipv4s = _collect_local_ipv4_addresses() + preferred = _preferred_rdma_ipv4() + env_gid = os.getenv(env_var_name, "").strip() if env_gid: - idx = int(env_gid) - ctx.query_gid(port_num=port_num, index=idx) - return idx - - local_ipv4s = _collect_local_ipv4_addresses() + try: + idx = int(env_gid) + except ValueError: + idx = -1 + else: + try: + gid_text = str(ctx.query_gid(port_num=port_num, index=idx)) + except Exception: + gid_text = "" + else: + if gid_text and gid_text != "::": + ipv4 = _gid_to_ipv4(gid_text) + if ipv4 is not None and (ipv4 in local_ipv4s or ipv4 == preferred): + return idx + + try: + logger.warning( + "Ignoring RDMA_GID_INDEX=%s because it does not map to a local IPv4 on this host (local_ipv4s=%s preferred=%s)", + env_gid, + local_ipv4s, + preferred, + ) + except Exception: + pass + + if preferred: + for idx in range(16): + try: + gid_text = str(ctx.query_gid(port_num=port_num, index=idx)) + except Exception: + continue + if not gid_text or gid_text == "::": + continue + ipv4 = _gid_to_ipv4(gid_text) + if ipv4 == preferred: + return idx mapped_candidates: list[tuple[int, str]] = [] first_non_empty_idx: int | None = None @@ -81,4 +151,85 @@ def resolve_gid_index(ctx, port_num: int, env_var_name: str = "RDMA_GID_INDEX") return first_non_empty_idx ctx.query_gid(port_num=port_num, index=0) - return 0 \ No newline at end of file + return 0 + + +def recv_json_from_stream(sock: socket.socket, timeout_sec: float = 10.0) -> dict: + """Read one JSON object from a TCP stream (handles split packets).""" + decoder = json.JSONDecoder() + chunks: list[bytes] = [] + deadline = time.time() + float(timeout_sec) + while time.time() < deadline: + try: + sock.settimeout(max(0.01, deadline - time.time())) + chunk = sock.recv(65536) + except socket.timeout: + continue + if not chunk: + break + chunks.append(chunk) + payload = b"".join(chunks).decode("utf-8", errors="strict") + try: + obj, _ = decoder.raw_decode(payload) + if isinstance(obj, dict): + return obj + 
except json.JSONDecodeError: + continue + msg = b"".join(chunks).decode("utf-8", errors="ignore") + raise RuntimeError(f"Incomplete handshake JSON from peer: {msg!r}") + + +def rtr_ah_dest_dlid(ctx, port_num: int, remote_lid: int) -> int: + """Destination LID for RC QP RTR when using GRH. + + RoCE (Ethernet link layer) expects dlid 0 with a valid dgid in the GRH; some + drivers still report a non-zero port LID, and using that in AHAttr triggers + ibv_modify_qp RTR EINVAL on many setups. + """ + rl = int(remote_lid) + raw = os.getenv("RDMA_RTR_DLID", "").strip().lower() + if raw in ("0", "zero", "roce", "eth"): + return 0 + if raw in ("peer", "remote", "ib", "infiniband"): + return rl + try: + port = ctx.query_port(port_num) + ll = int(getattr(port, "link_layer", -1)) + local_lid = int(getattr(port, "lid", -1)) + # rdma-core ibv_port_attr.link_layer: 0 unspecified, 1 InfiniBand, 2 Ethernet (RoCE). + if ll == 2: + return 0 + if ll == 1: + return rl + # Unspecified / unknown link_layer: eRDMA and some stacks omit or mis-report; + # RoCE uses LID 0 — using a non-zero dlid here causes RTR EINVAL. 
+ if local_lid == 0: + return 0 + except Exception: + pass + if rl == 0: + return 0 + return rl + + +def rtr_path_mtu(ctx, port_num: int) -> int: + """Use port active MTU for RTR path_mtu (avoids hard-coded 1024 vs link mismatch).""" + try: + port = ctx.query_port(port_num) + return int(port.active_mtu) + except Exception: + import pyverbs.enums as e + + return int(e.IBV_MTU_1024) + + +def rtr_path_mtu_negotiated(ctx, port_num: int, peer_active_mtu: int | None) -> int: + """path_mtu for RTR must not exceed either peer's active MTU (IB enum ordering).""" + local = rtr_path_mtu(ctx, port_num) + if peer_active_mtu is None: + return local + try: + peer = int(peer_active_mtu) + except (TypeError, ValueError): + return local + return min(local, peer) \ No newline at end of file diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index e9f76dbe2..fea94d668 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,7 +1,10 @@ +import ipaddress import json +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer import os import shlex import signal +import shutil import socket import subprocess import sys @@ -62,6 +65,7 @@ def __init__(self): self._free_slot_ids: set[int] = set() self._slot_reuse_block_until: dict[int, float] = {} self._local_host_aliases: set[str] = set() + self._request_metrics_by_room: dict[int, dict[str, Any]] = {} def _is_monitor_enabled(self) -> bool: raw = os.getenv("ENABLE_MONITOR") @@ -69,6 +73,12 @@ def _is_monitor_enabled(self) -> bool: return False return str(raw).strip().lower() in {"1", "true", "yes", "on"} + def _is_centralized_enabled(self) -> bool: + raw = os.getenv("IS_CENTRALIZED") + if raw is None: + return False + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + def _is_tcp_port_open(self, host: str, port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.settimeout(0.2) @@ -111,6 +121,18 @@ 
def _is_local_host(self, host: str) -> bool: except Exception: return False + def _ensure_rdma_preferred_ipv4_env(self, host: str, env: dict[str, str]) -> None: + """So RoCE gid_index matches the data-plane IP on each worker (multi-node).""" + if env.get("RDMA_PREFERRED_IPV4"): + return + h = str(host).strip() + if not h: + return + try: + env["RDMA_PREFERRED_IPV4"] = str(ipaddress.IPv4Address(h)) + except Exception: + pass + def _allocate_free_tcp_port(self, bind_host: str | None = None) -> int: host = str(bind_host or self._bootstrap_addr) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: @@ -144,6 +166,90 @@ def _build_service_command(self, instance_type: str, engine_rank: int, instance_ str(instance_cfg.get("save_path", "")), ] + def _maybe_wrap_service_command_with_nsys( + self, + *, + host: str, + instance_type: str, + engine_rank: int, + instance_cfg: dict[str, Any], + command: list[str], + ) -> list[str]: + if not self._is_truthy(os.getenv("DISAGG_ENABLE_NSYS"), default=False): + return command + + if not self._is_local_host(host): + self.logger.info( + "Skip nsys profiling for remote %s instance host=%s rank=%s", + instance_type, + host, + engine_rank, + ) + return command + + nsys_bin = shutil.which(os.getenv("DISAGG_NSYS_BIN", "nsys")) + if nsys_bin is None: + self.logger.warning("DISAGG_ENABLE_NSYS is set but nsys is not available, skip profiling for %s rank=%s", instance_type, engine_rank) + return command + + output_dir_raw = os.getenv("DISAGG_NSYS_OUTPUT_DIR") + if output_dir_raw: + output_dir = Path(output_dir_raw) + else: + base_save_path = instance_cfg.get("save_path") or (self._runtime_config or {}).get("save_path") or str(Path(__file__).resolve().parents[3] / "save_results" / "wan22_i2v_dynamic.mp4") + output_dir = Path(str(base_save_path)).parent / "nsys" + output_dir.mkdir(parents=True, exist_ok=True) + + output_name = f"{instance_type}_rank{engine_rank}" + trace = os.getenv("DISAGG_NSYS_TRACE", "cuda,nvtx,osrt") + extra_args = 
shlex.split(os.getenv("DISAGG_NSYS_EXTRA_ARGS", "")) + + profiled_command = [ + nsys_bin, + "profile", + "--force-overwrite=true", + "--trace", + trace, + "-o", + str(output_dir / output_name), + ] + profiled_command.extend(extra_args) + profiled_command.extend(command) + return profiled_command + + def _merge_request_metrics(self, existing: dict[str, Any] | None, update: dict[str, Any] | None) -> dict[str, Any]: + merged: dict[str, Any] = {} + if isinstance(existing, dict): + merged.update(existing) + if not isinstance(update, dict): + return merged + + for key, value in update.items(): + if key != "stages" or not isinstance(value, dict): + merged[key] = value + continue + + merged_stages: dict[str, Any] = {} + existing_stages = merged.get("stages") + if isinstance(existing_stages, dict): + for stage_name, stage_metrics in existing_stages.items(): + merged_stages[stage_name] = dict(stage_metrics) if isinstance(stage_metrics, dict) else stage_metrics + + for stage_name, stage_metrics in value.items(): + if not isinstance(stage_metrics, dict): + continue + base_stage_metrics = merged_stages.get(stage_name) + if isinstance(base_stage_metrics, dict): + combined_stage_metrics = dict(base_stage_metrics) + combined_stage_metrics.update(stage_metrics) + else: + combined_stage_metrics = dict(stage_metrics) + merged_stages[stage_name] = combined_stage_metrics + + merged["stages"] = merged_stages + + return merged + def _query_zmq(self, req_addr: str, payload: dict[str, Any], timeout_ms: int = 1000) -> dict[str, Any] | None: context = zmq.Context() req = context.socket(zmq.REQ) @@ -165,6 +271,48 @@ def _query_zmq(self, req_addr: str, payload: dict[str, Any], timeout_ms: int = 1 def _query_sidecar(self, req_addr: str, cmd: str) -> dict[str, Any] | None: return self._query_zmq(req_addr, {"cmd": str(cmd)}, timeout_ms=1000) + def _run_centralized_ok_server(self, stop_event: Event, bind_host: str, bind_port: int): + controller = self + + class _Handler(BaseHTTPRequestHandler): + 
def do_POST(self): + if self.path != "/ok": + self.send_response(404) + self.end_headers() + return + + content_length = int(self.headers.get("Content-Length", "0") or "0") + raw_body = self.rfile.read(content_length) if content_length > 0 else b"" + try: + message = json.loads(raw_body.decode("utf-8")) if raw_body else {} + except Exception: + self.send_response(400) + self.end_headers() + return + + controller.logger.info( + "Received centralized OK control message: stage=%s room=%s", + message.get("stage_name"), + message.get("data_bootstrap_room"), + ) + response = json.dumps({"ok": True, "control": "OK"}).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(response))) + self.end_headers() + self.wfile.write(response) + + def log_message(self, format, *args): + return + + server = ThreadingHTTPServer((bind_host, bind_port), _Handler) + server.timeout = 0.2 + try: + while not stop_event.is_set(): + server.handle_request() + finally: + server.server_close() + def _is_truthy(self, value: Any, default: bool = False) -> bool: if value is None: return default @@ -324,6 +472,11 @@ def _launch_remote_instance(self, slot: dict[str, Any], instance_type: str, cmd: if isinstance(extra_env, dict): for key, value in extra_env.items(): normalized_env[str(key)] = str(value) + if self._is_centralized_enabled(): + normalized_env["IS_CENTRALIZED"] = "1" + if os.getenv("SYNC_COMM") is not None: + normalized_env["SYNC_COMM"] = str(os.getenv("SYNC_COMM", "0")) + self._ensure_rdma_preferred_ipv4_env(host, normalized_env) sidecar_env_vars = { **normalized_env, @@ -434,6 +587,7 @@ def _launch_remote_instance_via_proxy(self, slot: dict[str, Any], instance_type: if isinstance(extra_env, dict): for key, value in extra_env.items(): normalized_env[str(key)] = str(value) + self._ensure_rdma_preferred_ipv4_env(host, normalized_env) proxy_req_addr = self._remote_proxy_req_addr(slot) payload = { @@ -923,11 
+1077,29 @@ def _handle_decoder_result( if not isinstance(result, dict): self.logger.warning("Ignored non-dict decoder result: %s", result) return + + message_type = str(result.get("message_type", "decoder_result")) room = result.get("data_bootstrap_room") if room is None: self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result) return room = int(room) + + if message_type == "stage_metrics": + request_metrics = result.get("request_metrics") + if not isinstance(request_metrics, dict): + self.logger.warning("Ignored stage metrics update without request_metrics: %s", result) + return + merged_metrics = self._merge_request_metrics(self._request_metrics_by_room.get(room), request_metrics) + self._request_metrics_by_room[room] = merged_metrics + self.logger.info( + "Stage metrics updated room=%s stage=%s metrics=%s", + room, + result.get("stage_name"), + request_metrics.get("stages", {}), + ) + return + if room not in expected_rooms: self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) return @@ -935,6 +1107,15 @@ def _handle_decoder_result( self.logger.info("Duplicate decoder result for room=%s ignored", room) return + stored_metrics = self._request_metrics_by_room.get(room) + request_metrics = result.get("request_metrics") + if isinstance(request_metrics, dict): + merged_metrics = self._merge_request_metrics(stored_metrics, request_metrics) + self._request_metrics_by_room[room] = merged_metrics + result["request_metrics"] = merged_metrics + elif isinstance(stored_metrics, dict): + result["request_metrics"] = stored_metrics + controller_recv_ts = time.time() latency_summary = self._build_latency_summary(result, controller_recv_ts) if latency_summary is not None: @@ -1196,6 +1377,13 @@ def create_instance(self, instance_type: str) -> str: service_config_json = self._resolve_service_config_json(str(config_json), instance_type) cmd = self._build_service_command(instance_type, engine_rank, instance_cfg, 
service_config_json) + cmd = self._maybe_wrap_service_command_with_nsys( + host=host, + instance_type=instance_type, + engine_rank=engine_rank, + instance_cfg=instance_cfg, + command=cmd, + ) process: subprocess.Popen | None = None process_meta: dict[str, Any] | None = None @@ -1209,6 +1397,12 @@ def create_instance(self, instance_type: str) -> str: else: sidecar_meta = self._start_sidecar_process(instance_type, cuda_device, bind_host=host) env = os.environ.copy() + if use_static_slots and selected_slot is not None: + slot_env = selected_slot.get("env") + if isinstance(slot_env, dict): + for key, value in slot_env.items(): + env[str(key)] = str(value) + self._ensure_rdma_preferred_ipv4_env(host, env) env["CUDA_VISIBLE_DEVICES"] = str(cuda_device) env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = str(sidecar_meta["push_addr"]) env["LIGHTX2V_SIDECAR_REQ_ADDR"] = str(sidecar_meta["req_addr"]) @@ -1254,7 +1448,6 @@ def create_instance(self, instance_type: str) -> str: self._free_slot_ids.remove(int(selected_slot["slot_id"])) else: self._free_gpus.remove(engine_rank) - # self.add_instance(instance_type, instance_address) if self._enable_monitor: monitor_node = f"tcp://{host}:{MONITOR_POLLING_PORT + engine_rank}" if monitor_node not in self.monitor.nodes: @@ -1272,6 +1465,7 @@ def create_instance(self, instance_type: str) -> str: "static_slot": self._to_plain(selected_slot) if selected_slot is not None else None, } self.started_instances.append((instance_type, instance_address)) + self.add_instance(instance_type, instance_address) self.logger.info( "Created %s instance host=%s rank=%s mode=%s address=%s", instance_type, @@ -1318,7 +1512,7 @@ def reclaim_instance(self, instance_type: str, instance_address: str | None = No slot_id_raw = meta.get("slot_id") slot_id = int(slot_id_raw) if slot_id_raw is not None else None - # self.remove_instance(instance_type, target_address) + self.remove_instance(instance_type, target_address) monitor_node = 
self._monitor_node_from_instance_address(target_address) if launch_mode == "remote": @@ -1415,21 +1609,24 @@ def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): config["rdma_phase2_handshake_port"] = phase2_handshake_port need_bytes = 16 + slots * slot_size - self._rdma_server_request = RDMAServer(buffer_size=need_bytes) - self.rdma_buffer_request = RDMABuffer( - role="server", - buffer_size=slots, - slot_size=slot_size, - rdma_server=self._rdma_server_request, - ) + if not self._is_centralized_enabled(): + self._rdma_server_request = RDMAServer(buffer_size=need_bytes) + self.rdma_buffer_request = RDMABuffer( + role="server", + buffer_size=slots, + slot_size=slot_size, + rdma_server=self._rdma_server_request, + ) - self._rdma_handshake_thread_request = Thread( - target=self._rdma_server_request.handshake, - kwargs={"host": bootstrap_addr, "port": handshake_port}, - name="controller-rdma-handshake", - daemon=True, - ) - self._rdma_handshake_thread_request.start() + self._rdma_handshake_thread_request = Thread( + target=self._rdma_server_request.handshake, + kwargs={"host": bootstrap_addr, "port": handshake_port}, + name="controller-rdma-handshake", + daemon=True, + ) + self._rdma_handshake_thread_request.start() + else: + self.logger.info("IS_CENTRALIZED enabled, skip controller request RDMA ring initialization") need_bytes_phase1 = 16 + phase1_slots * phase1_slot_size self._rdma_server_phase1 = RDMAServer(buffer_size=need_bytes_phase1) @@ -1464,9 +1661,9 @@ def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): self._rdma_handshake_thread_phase2.start() self.logger.info( "Initialized RDMA buffers: request=(%s,%s,%s) phase1=(%s,%s,%s) phase2=(%s,%s,%s)", - slots, - slot_size, - need_bytes, + slots if self.rdma_buffer_request is not None else 0, + slot_size if self.rdma_buffer_request is not None else 0, + need_bytes if self.rdma_buffer_request is not None else 0, phase1_slots, phase1_slot_size, need_bytes_phase1, @@ -1498,10 
+1695,6 @@ def _build_latency_summary(self, result: dict[str, Any], controller_recv_ts: flo if not isinstance(request_metrics, dict): return None - stages = request_metrics.get("stages") - if not isinstance(stages, dict): - return None - def _as_float(value: Any) -> float | None: try: return float(value) @@ -1509,10 +1702,21 @@ def _as_float(value: Any) -> float | None: return None def _stage(name: str) -> dict[str, Any]: + stages = request_metrics.get("stages") + if not isinstance(stages, dict): + return {} stage_metrics = stages.get(name) return stage_metrics if isinstance(stage_metrics, dict) else {} controller_send_ts = _as_float(request_metrics.get("controller_send_ts")) + if controller_send_ts is None: + return None + + centralized_mode = self._is_centralized_enabled() + summary: dict[str, float] = { + "end_to_end_delay_s": controller_recv_ts - controller_send_ts, + } + encoder = _stage("encoder") transformer = _stage("transformer") decoder = _stage("decoder") @@ -1532,38 +1736,78 @@ def _stage(name: str) -> dict[str, Any]: decoder_compute_end_ts = _as_float(decoder.get("compute_end_ts")) decoder_output_enqueued_ts = _as_float(decoder.get("output_enqueued_ts")) - required_values = [ - controller_send_ts, - encoder_recv_ts, - encoder_compute_start_ts, - encoder_compute_end_ts, - encoder_output_enqueued_ts, - transformer_recv_ts, - transformer_compute_start_ts, - transformer_compute_end_ts, - transformer_output_enqueued_ts, - decoder_recv_ts, - decoder_compute_start_ts, - decoder_compute_end_ts, - decoder_output_enqueued_ts, - ] - if any(value is None for value in required_values): - return None - - summary: dict[str, float] = { - "controller_to_encoder_comm_delay_s": encoder_recv_ts - controller_send_ts, - "encoder_scheduling_delay_s": encoder_compute_start_ts - encoder_recv_ts, - "encoder_compute_delay_s": encoder_compute_end_ts - encoder_compute_start_ts, - "encoder_communication_delay_s": transformer_recv_ts - encoder_output_enqueued_ts, - 
"transformer_scheduling_delay_s": transformer_compute_start_ts - transformer_recv_ts, - "transformer_compute_delay_s": transformer_compute_end_ts - transformer_compute_start_ts, - "transformer_communication_delay_s": decoder_recv_ts - transformer_output_enqueued_ts, - "decoder_scheduling_delay_s": decoder_compute_start_ts - decoder_recv_ts, - "decoder_compute_delay_s": decoder_compute_end_ts - decoder_compute_start_ts, - "decoder_communication_delay_s": controller_recv_ts - decoder_output_enqueued_ts, - "end_to_end_delay_s": controller_recv_ts - controller_send_ts, - } - summary["sum_of_components_s"] = sum(value for key, value in summary.items() if key != "end_to_end_delay_s" and key != "sum_of_components_s") + if centralized_mode: + if encoder_recv_ts is not None: + summary["controller_to_encoder_comm_delay_s"] = encoder_recv_ts - controller_send_ts + if encoder_recv_ts is not None and encoder_compute_start_ts is not None: + summary["encoder_scheduling_delay_s"] = encoder_compute_start_ts - encoder_recv_ts + if encoder_compute_start_ts is not None and encoder_compute_end_ts is not None: + summary["encoder_compute_delay_s"] = encoder_compute_end_ts - encoder_compute_start_ts + if encoder_output_enqueued_ts is not None and transformer_recv_ts is not None: + summary["encoder_communication_delay_s"] = transformer_recv_ts - controller_send_ts + if transformer_recv_ts is not None and transformer_compute_start_ts is not None: + summary["transformer_scheduling_delay_s"] = transformer_compute_start_ts - transformer_recv_ts + if transformer_compute_start_ts is not None and transformer_compute_end_ts is not None: + summary["transformer_compute_delay_s"] = transformer_compute_end_ts - transformer_compute_start_ts + if transformer_recv_ts is not None: + summary["transformer_communication_delay_s"] = transformer_recv_ts - controller_send_ts + if decoder_recv_ts is not None and decoder_compute_start_ts is not None: + summary["decoder_scheduling_delay_s"] = 
decoder_compute_start_ts - decoder_recv_ts + if decoder_compute_start_ts is not None and decoder_compute_end_ts is not None: + summary["decoder_compute_delay_s"] = decoder_compute_end_ts - decoder_compute_start_ts + if decoder_recv_ts is not None: + summary["decoder_communication_delay_s"] = decoder_recv_ts - controller_send_ts + + component_keys = [ + "controller_to_encoder_comm_delay_s", + "encoder_scheduling_delay_s", + "encoder_compute_delay_s", + "encoder_communication_delay_s", + "transformer_scheduling_delay_s", + "transformer_compute_delay_s", + "transformer_communication_delay_s", + "decoder_scheduling_delay_s", + "decoder_compute_delay_s", + "decoder_communication_delay_s", + ] + if all(key in summary for key in component_keys): + summary["sum_of_components_s"] = sum(summary[key] for key in component_keys) + else: + if encoder_recv_ts is not None: + summary["controller_to_encoder_comm_delay_s"] = encoder_recv_ts - controller_send_ts + if encoder_recv_ts is not None and encoder_compute_start_ts is not None: + summary["encoder_scheduling_delay_s"] = encoder_compute_start_ts - encoder_recv_ts + if encoder_compute_start_ts is not None and encoder_compute_end_ts is not None: + summary["encoder_compute_delay_s"] = encoder_compute_end_ts - encoder_compute_start_ts + if encoder_output_enqueued_ts is not None and transformer_recv_ts is not None: + summary["encoder_communication_delay_s"] = transformer_recv_ts - encoder_output_enqueued_ts + if transformer_recv_ts is not None and transformer_compute_start_ts is not None: + summary["transformer_scheduling_delay_s"] = transformer_compute_start_ts - transformer_recv_ts + if transformer_compute_start_ts is not None and transformer_compute_end_ts is not None: + summary["transformer_compute_delay_s"] = transformer_compute_end_ts - transformer_compute_start_ts + if transformer_output_enqueued_ts is not None and decoder_recv_ts is not None: + summary["transformer_communication_delay_s"] = decoder_recv_ts - 
transformer_output_enqueued_ts + if decoder_recv_ts is not None and decoder_compute_start_ts is not None: + summary["decoder_scheduling_delay_s"] = decoder_compute_start_ts - decoder_recv_ts + if decoder_compute_start_ts is not None and decoder_compute_end_ts is not None: + summary["decoder_compute_delay_s"] = decoder_compute_end_ts - decoder_compute_start_ts + if decoder_output_enqueued_ts is not None: + summary["decoder_communication_delay_s"] = controller_recv_ts - decoder_output_enqueued_ts + + component_keys = [ + "controller_to_encoder_comm_delay_s", + "encoder_scheduling_delay_s", + "encoder_compute_delay_s", + "encoder_communication_delay_s", + "transformer_scheduling_delay_s", + "transformer_compute_delay_s", + "transformer_communication_delay_s", + "decoder_scheduling_delay_s", + "decoder_compute_delay_s", + "decoder_communication_delay_s", + ] + if all(key in summary for key in component_keys): + summary["sum_of_components_s"] = sum(summary[key] for key in component_keys) return summary def add_instance(self, instance_type: str, instance_address: str): @@ -1599,6 +1843,50 @@ def send_request(self, config): if config is None: raise ValueError("config cannot be None") + room_raw = config.get("data_bootstrap_room") + try: + room = int(room_raw) + except (TypeError, ValueError): + room = None + + request_metrics = config.get("request_metrics") + if room is not None and isinstance(request_metrics, dict): + self._request_metrics_by_room[room] = self._merge_request_metrics(None, request_metrics) + + if self._is_centralized_enabled(): + request_config = self._to_plain(config) + + encoder_address = self.encoder_policy.schedule() + transformer_address = self.transformer_policy.schedule() + decoder_address = self.decoder_policy.schedule() + + def _address_to_rank(instance_address: str) -> int: + _, port_str = instance_address.rsplit(":", 1) + return int(port_str) - REQUEST_POLLING_PORT + + encoder_rank = _address_to_rank(encoder_address) + transformer_rank = 
_address_to_rank(transformer_address) + decoder_rank = _address_to_rank(decoder_address) + + request_config["encoder_engine_rank"] = encoder_rank + request_config["transformer_engine_rank"] = transformer_rank + request_config["decoder_engine_rank"] = decoder_rank + request_config["encoder_node_address"] = encoder_address + request_config["transformer_node_address"] = transformer_address + request_config["decoder_node_address"] = decoder_address + request_config["controller_control_host"] = request_config.get("controller_result_host", self._bootstrap_addr) + request_config["controller_control_port"] = int(request_config.get("controller_control_port", REQUEST_POLLING_PORT - 3)) + + for instance_type, target_address in ( + ("encoder", encoder_address), + ("transformer", transformer_address), + ("decoder", decoder_address), + ): + host, port_str = target_address.rsplit(":", 1) + self.req_mgr.send(host, int(port_str), request_config) + self.logger.info("Request dispatched to %s via ZMQ: target=%s", instance_type, target_address) + return + if self.rdma_buffer_request is None: raise RuntimeError("RDMA request buffer is not initialized") self.rdma_buffer_request.produce(config) @@ -1614,16 +1902,20 @@ def run(self, config): bootstrap_addr = config.get("data_bootstrap_addr", "127.0.0.1") request_ingress_port = int(config.get("controller_request_port", os.getenv("DISAGG_CONTROLLER_REQUEST_PORT", REQUEST_POLLING_PORT - 2))) result_port = int(config.get("controller_result_port", REQUEST_POLLING_PORT - 1)) + control_port = int(config.get("controller_control_port", REQUEST_POLLING_PORT - 3)) self._bootstrap_addr = str(bootstrap_addr) self._runtime_config = self._to_plain(config) self._init_gpu_pool(config) self._enable_monitor = self._is_monitor_enabled() + centralized_mode = self._is_centralized_enabled() # self.encoder_policy = RoundRobinPolicy() # self.transformer_policy = RoundRobinPolicy() # self.decoder_policy = RoundRobinPolicy() 
self._init_request_rdma_buffer(bootstrap_addr, config) + if centralized_mode: + self.logger.info("IS_CENTRALIZED enabled, controller will dispatch requests via ZMQ") time.sleep(5.0) @@ -1650,6 +1942,8 @@ def run(self, config): monitor_stop_event: Event | None = None monitor_thread: Thread | None = None + ok_gate_stop_event: Event | None = None + ok_gate_thread: Thread | None = None self._monitor_runtime = None if self._enable_monitor: @@ -1695,6 +1989,17 @@ def run(self, config): else: self.logger.info("ENABLE_MONITOR is not set, skip monitor logic") + if centralized_mode: + ok_gate_stop_event = Event() + ok_gate_thread = Thread( + target=self._run_centralized_ok_server, + args=(ok_gate_stop_event, self._bootstrap_addr, control_port), + name="controller-ok-gate", + daemon=True, + ) + ok_gate_thread.start() + self.logger.info("Centralized OK gate server started on %s:%s", self._bootstrap_addr, control_port) + time.sleep(5.0) base_save_path = config.get("save_path") @@ -1825,6 +2130,10 @@ def run(self, config): monitor_stop_event.set() if monitor_thread is not None: monitor_thread.join(timeout=2.0) + if ok_gate_stop_event is not None: + ok_gate_stop_event.set() + if ok_gate_thread is not None: + ok_gate_thread.join(timeout=2.0) self._monitor_runtime = None for instance_type, address in reversed(list(self.started_instances)): diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index 1052b9e50..707c8d936 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -9,7 +9,7 @@ import torch -from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataReceiver, DisaggregationMode, DisaggregationPhase, ReqManager +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataReceiver, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, 
MemoryHandle, RemoteBuffer from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor @@ -31,6 +31,8 @@ def __init__(self, config: dict): self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) self._phase2_rdma_client: Optional[RDMAClient] = None self._phase2_rdma_buffer: Optional[RDMABuffer] = None + self._centralized_request_mgr = ReqManager() + self._centralized_request_port = REQUEST_POLLING_PORT + self.decoder_engine_rank data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) @@ -141,10 +143,11 @@ def init(self, config): if data_bootstrap_addr is None or data_bootstrap_room is None: return - try: - self._ensure_phase2_request_buffer() - except Exception: - self.logger.exception("Failed to connect phase2 RDMA buffer, will retry") + if not str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"}: + try: + self._ensure_phase2_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase2 RDMA buffer, will retry") buffer_sizes = estimate_transformer_buffer_sizes(self.config) request = AllocationRequest( @@ -360,27 +363,39 @@ def run(self, stop_event=None): }, ) - if self._phase2_rdma_buffer is None: - try: - self._ensure_phase2_request_buffer() - except Exception: - self.logger.exception("Failed to connect phase2 request RDMA buffer, will retry") - - if self._phase2_rdma_buffer is not None: - packet = self._phase2_rdma_buffer.consume() - if packet is not None: - if isinstance(packet, dict) and "request_config" in packet: - config = dict(packet.get("request_config") or {}) - config["transformer_node_address"] = packet.get("transformer_node_address", "127.0.0.1") - else: - config = packet + centralized_request_mode = str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"} + if 
centralized_request_mode: + config = self._centralized_request_mgr.receive_non_block(self._centralized_request_port) + if config is not None: if not isinstance(config, dict) or "data_bootstrap_room" not in config: - self.logger.warning("Ignored incomplete phase2 packet from RDMA buffer: %s", packet) + self.logger.warning("Ignored incomplete request packet from ZMQ: %s", config) continue decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) decoder_metrics["request_received_ts"] = time.time() - self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + self.logger.info("Received request config from ZMQ: %s", {k: v for k, v in config.items()}) req_queue.append(config) + else: + if self._phase2_rdma_buffer is None: + try: + self._ensure_phase2_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase2 request RDMA buffer, will retry") + + if self._phase2_rdma_buffer is not None: + packet = self._phase2_rdma_buffer.consume() + if packet is not None: + if isinstance(packet, dict) and "request_config" in packet: + config = dict(packet.get("request_config") or {}) + config["transformer_node_address"] = packet.get("transformer_node_address", "127.0.0.1") + else: + config = packet + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete phase2 packet from RDMA buffer: %s", packet) + continue + decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) + decoder_metrics["request_received_ts"] = time.time() + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + req_queue.append(config) if req_queue: config = req_queue.popleft() diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 117b3484c..f434641cb 100644 --- a/lightx2v/disagg/services/encoder.py +++ 
b/lightx2v/disagg/services/encoder.py @@ -5,11 +5,14 @@ import time from collections import deque from typing import Any, Dict, List, Optional +from urllib.error import URLError +from urllib.request import Request, urlopen import numpy as np import torch +import zmq -from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor @@ -37,6 +40,8 @@ def __init__(self, config: dict): self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", "2")) self._request_rdma_client: Optional[RDMAClient] = None self._request_rdma_buffer: Optional[RDMABuffer] = None + self._centralized_request_mgr = ReqManager() + self._centralized_request_port = REQUEST_POLLING_PORT + self.encoder_engine_rank self._phase1_rdma_client: Optional[RDMAClient] = None self._phase1_rdma_buffer: Optional[RDMABuffer] = None data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) @@ -53,6 +58,7 @@ def __init__(self, config: dict): self._phase1_slot_size = shared_slot_size self._last_request_connect_retry_ts = 0.0 self._last_phase1_connect_retry_ts = 0.0 + self._centralized_request_mode = str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"} self.text_encoder = None self.image_encoder = None self.vae_encoder = None @@ -93,6 +99,92 @@ def _wait_sender_success(self, room: int, sender: DataSender): raise RuntimeError(f"DataSender transfer failed for room={room}") time.sleep(0.001) + def _report_stage_metrics_to_controller(self, stage_name: str, config: dict[str, Any]): + if not 
self._centralized_request_mode: + return + + controller_host = str(config.get("controller_result_host", "127.0.0.1")) + controller_port_raw = config.get("controller_result_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_metrics = config.get("request_metrics") + if not isinstance(request_metrics, dict): + return + + stage_metrics = request_metrics.get("stages", {}).get(stage_name) + if not isinstance(stage_metrics, dict): + return + + payload_request_metrics: dict[str, Any] = { + "request_id": request_metrics.get("request_id", config.get("data_bootstrap_room")), + "stages": {stage_name: stage_metrics}, + } + if request_metrics.get("controller_send_ts") is not None: + payload_request_metrics["controller_send_ts"] = request_metrics.get("controller_send_ts") + + self._centralized_request_mgr.send( + controller_host, + controller_port, + { + "message_type": "stage_metrics", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + "request_metrics": payload_request_metrics, + }, + ) + self.logger.info( + "Reported %s stage metrics to controller: room=%s target=%s:%s", + stage_name, + config.get("data_bootstrap_room"), + controller_host, + controller_port, + ) + + def _wait_for_controller_ok(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_control_host", config.get("controller_result_host", "127.0.0.1"))) + controller_port_raw = config.get("controller_control_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_body = json.dumps( + { + "control": "OK", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + } + ).encode("utf-8") + request = Request( + 
f"http://{controller_host}:{controller_port}/ok", + data=request_body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urlopen(request, timeout=10) as response: + reply = json.loads(response.read().decode("utf-8")) + if not isinstance(reply, dict) or not reply.get("ok", False): + raise RuntimeError(f"unexpected controller OK reply: {reply}") + except URLError: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + except Exception: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + def _get_queue_metrics(self) -> dict[str, Any]: with self._queue_metrics_lock: queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) @@ -587,6 +679,9 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: buffer_ptrs = [buf.data_ptr() for buf in room_buffers] # Publish phase1 request metadata after compute so downstream can see latest metrics. 
encoder_metrics["output_enqueued_ts"] = time.time() + if self._centralized_request_mode: + self._report_stage_metrics_to_controller("encoder", config) + self._wait_for_controller_ok("encoder", config) phase1_meta = { "request_config": dict(config), "encoder_node_address": self.data_mgr.get_localhost(), @@ -651,32 +746,43 @@ def run(self, stop_event=None): }, ) - if self._request_rdma_buffer is None: - try: - self._ensure_request_buffer() - except Exception: - self.logger.exception("Failed to connect request RDMA buffer, will retry") - - if self._request_rdma_client is not None and self._request_rdma_client.has_qp_error(): - self.logger.warning( - "Request RDMA client entered error state, reconnecting: %s", - self._request_rdma_client.last_wc_error_message(), - ) - try: - self._reconnect_request_buffer() - except Exception: - self.logger.exception("Failed to reconnect request RDMA buffer after QP error") - - if self._request_rdma_buffer is not None: - config = self._request_rdma_buffer.consume() + if self._centralized_request_mode: + config = self._centralized_request_mgr.receive_non_block(self._centralized_request_port) if config is not None: if not isinstance(config, dict) or "data_bootstrap_room" not in config: - self.logger.warning("Ignored incomplete request packet from RDMA buffer: %s", config) + self.logger.warning("Ignored incomplete request packet from ZMQ: %s", config) continue encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) encoder_metrics["request_received_ts"] = time.time() - self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + self.logger.info("Received request config from ZMQ: %s", {k: v for k, v in config.items()}) req_queue.append(config) + else: + if self._request_rdma_buffer is None: + try: + self._ensure_request_buffer() + except Exception: + self.logger.exception("Failed to connect request RDMA buffer, will retry") + + if 
self._request_rdma_client is not None and self._request_rdma_client.has_qp_error(): + self.logger.warning( + "Request RDMA client entered error state, reconnecting: %s", + self._request_rdma_client.last_wc_error_message(), + ) + try: + self._reconnect_request_buffer() + except Exception: + self.logger.exception("Failed to reconnect request RDMA buffer after QP error") + + if self._request_rdma_buffer is not None: + config = self._request_rdma_buffer.consume() + if config is not None: + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete request packet from RDMA buffer: %s", config) + continue + encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) + encoder_metrics["request_received_ts"] = time.time() + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + req_queue.append(config) if req_queue: config = req_queue.popleft() diff --git a/lightx2v/disagg/services/instance_proxy.py b/lightx2v/disagg/services/instance_proxy.py index f7e2e8f83..c2a319809 100644 --- a/lightx2v/disagg/services/instance_proxy.py +++ b/lightx2v/disagg/services/instance_proxy.py @@ -143,7 +143,7 @@ def _start_instance(self, msg: dict[str, Any]) -> dict[str, Any]: ] service_cmd = [python_executable, *[str(part) for part in service_argv]] - with open(sidecar_log_path, "a", encoding="utf-8") as sidecar_log: + with open(sidecar_log_path, "w", encoding="utf-8") as sidecar_log: sidecar_proc = subprocess.Popen( sidecar_cmd, cwd=workdir, @@ -157,7 +157,7 @@ def _start_instance(self, msg: dict[str, Any]) -> dict[str, Any]: if sidecar_proc.poll() is not None: raise RuntimeError(f"failed to start sidecar process, exited with code={sidecar_proc.returncode}") - with open(service_log_path, "a", encoding="utf-8") as service_log: + with open(service_log_path, "w", encoding="utf-8") as service_log: service_proc = subprocess.Popen( service_cmd, 
cwd=workdir, diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 0641bc817..ddd660cb2 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -9,8 +9,11 @@ import numpy as np import torch +import zmq +from urllib.error import URLError +from urllib.request import Request, urlopen -from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor @@ -65,10 +68,13 @@ def __init__(self, config: dict): self._phase1_rdma_buffer: Optional[RDMABuffer] = None self._phase2_rdma_client: Optional[RDMAClient] = None self._phase2_rdma_buffer: Optional[RDMABuffer] = None + self._centralized_request_mgr = ReqManager() + self._centralized_request_port = REQUEST_POLLING_PORT + self.transformer_engine_rank data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) + self._centralized_request_mode = str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"} self._phase1_server_ip = str(self.config.get("rdma_phase1_host", data_bootstrap_addr)) self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", "5567")) self._phase1_slots = shared_slots @@ -120,6 +126,92 @@ def _wait_sender_success(self, room: int, sender: DataSender): raise 
RuntimeError(f"DataSender transfer failed for room={room}") time.sleep(0.001) + def _report_stage_metrics_to_controller(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_result_host", "127.0.0.1")) + controller_port_raw = config.get("controller_result_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_metrics = config.get("request_metrics") + if not isinstance(request_metrics, dict): + return + + stage_metrics = request_metrics.get("stages", {}).get(stage_name) + if not isinstance(stage_metrics, dict): + return + + payload_request_metrics: dict[str, Any] = { + "request_id": request_metrics.get("request_id", config.get("data_bootstrap_room")), + "stages": {stage_name: stage_metrics}, + } + if request_metrics.get("controller_send_ts") is not None: + payload_request_metrics["controller_send_ts"] = request_metrics.get("controller_send_ts") + + self._centralized_request_mgr.send( + controller_host, + controller_port, + { + "message_type": "stage_metrics", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + "request_metrics": payload_request_metrics, + }, + ) + self.logger.info( + "Reported %s stage metrics to controller: room=%s target=%s:%s", + stage_name, + config.get("data_bootstrap_room"), + controller_host, + controller_port, + ) + + def _wait_for_controller_ok(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_control_host", config.get("controller_result_host", "127.0.0.1"))) + controller_port_raw = config.get("controller_control_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_body = json.dumps( + { + "control": 
"OK", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + } + ).encode("utf-8") + request = Request( + f"http://{controller_host}:{controller_port}/ok", + data=request_body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urlopen(request, timeout=10) as response: + reply = json.loads(response.read().decode("utf-8")) + if not isinstance(reply, dict) or not reply.get("ok", False): + raise RuntimeError(f"unexpected controller OK reply: {reply}") + except URLError: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + except Exception: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + def _attach_remote_shared_memory(self, shm_name: str) -> shared_memory.SharedMemory: # Python 3.12 supports `track=False`, which avoids duplicate cleanup from non-owner processes. 
try: @@ -310,19 +402,20 @@ def init(self, config): if data_bootstrap_addr is None or data_bootstrap_room is None: return - phase_deadline = time.time() + 30.0 - while time.time() < phase_deadline: - try: - self._ensure_phase1_request_buffer() - self._ensure_phase2_meta_buffer() - except Exception: - self.logger.exception("Failed to connect phase RDMA buffers, will retry") - if self._phase1_rdma_buffer is not None and self._phase2_rdma_buffer is not None: - break - time.sleep(0.1) + if not self._centralized_request_mode: + phase_deadline = time.time() + 30.0 + while time.time() < phase_deadline: + try: + self._ensure_phase1_request_buffer() + self._ensure_phase2_meta_buffer() + except Exception: + self.logger.exception("Failed to connect phase RDMA buffers, will retry") + if self._phase1_rdma_buffer is not None and self._phase2_rdma_buffer is not None: + break + time.sleep(0.1) - if self._phase1_rdma_buffer is None or self._phase2_rdma_buffer is None: - raise RuntimeError("phase RDMA buffers are not ready") + if self._phase1_rdma_buffer is None or self._phase2_rdma_buffer is None: + raise RuntimeError("phase RDMA buffers are not ready") buffer_sizes = estimate_encoder_buffer_sizes(self.config) request = AllocationRequest( @@ -899,37 +992,48 @@ def run(self, stop_event=None): }, ) - if self._phase1_rdma_buffer is None: - try: - self._ensure_phase1_request_buffer() - except Exception: - self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") - - if self._phase1_rdma_client is not None and self._phase1_rdma_client.has_qp_error(): - self.logger.warning( - "Phase1 request RDMA client entered error state, reconnecting: %s", - self._phase1_rdma_client.last_wc_error_message(), - ) - try: - self._reconnect_phase1_request_buffer() - except Exception: - self.logger.exception("Failed to reconnect phase1 request RDMA buffer after QP error") - - if self._phase1_rdma_buffer is not None and len(req_queue) + len(waiting_queue) < 2: - packet = 
self._phase1_rdma_buffer.consume() - if packet is not None: - if isinstance(packet, dict) and "request_config" in packet: - config = dict(packet.get("request_config") or {}) - config["encoder_node_address"] = packet.get("encoder_node_address", "127.0.0.1") - else: - config = packet + if self._centralized_request_mode: + config = self._centralized_request_mgr.receive_non_block(self._centralized_request_port) + if config is not None: if not isinstance(config, dict) or "data_bootstrap_room" not in config: - self.logger.warning("Ignored incomplete phase1 packet from RDMA buffer: %s", packet) + self.logger.warning("Ignored incomplete request packet from ZMQ: %s", config) continue transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) transformer_metrics["request_received_ts"] = time.time() - self.logger.info("%s Received request config from RDMA buffer: %s", self.transformer_engine_rank, {k: v for k, v in config.items()}) + self.logger.info("Received request config from ZMQ: %s", {k: v for k, v in config.items()}) req_queue.append(config) + else: + if self._phase1_rdma_buffer is None: + try: + self._ensure_phase1_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") + + if self._phase1_rdma_client is not None and self._phase1_rdma_client.has_qp_error(): + self.logger.warning( + "Phase1 request RDMA client entered error state, reconnecting: %s", + self._phase1_rdma_client.last_wc_error_message(), + ) + try: + self._reconnect_phase1_request_buffer() + except Exception: + self.logger.exception("Failed to reconnect phase1 request RDMA buffer after QP error") + + if self._phase1_rdma_buffer is not None and len(req_queue) + len(waiting_queue) < 2: + packet = self._phase1_rdma_buffer.consume() + if packet is not None: + if isinstance(packet, dict) and "request_config" in packet: + config = dict(packet.get("request_config") or {}) + 
config["encoder_node_address"] = packet.get("encoder_node_address", "127.0.0.1") + else: + config = packet + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete phase1 packet from RDMA buffer: %s", packet) + continue + transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) + transformer_metrics["request_received_ts"] = time.time() + self.logger.info("%s Received request config from RDMA buffer: %s", self.transformer_engine_rank, {k: v for k, v in config.items()}) + req_queue.append(config) if req_queue: config = req_queue.popleft() @@ -962,6 +1066,9 @@ def run(self, stop_event=None): room, config = exec_queue[0] try: self.process(config) + if self._centralized_request_mode: + self._report_stage_metrics_to_controller("transformer", config) + self._wait_for_controller_ok("transformer", config) if self.sync_comm: self.remove(room) else: diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index 3ade1618b..992f2a313 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -8,13 +8,7 @@ import torchvision.transforms.functional as TF from PIL import Image -from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel -from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel from lightx2v.models.networks.lora_adapter import LoraAdapter -from lightx2v.models.networks.wan.model import WanModel -from lightx2v.models.video_encoders.hf.wan.vae import WanVAE -from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE -from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.set_config import set_config as set_config_base from lightx2v.utils.utils import find_torch_model_path @@ -170,6 +164,8 @@ def build_wan_model_with_lora(wan_module, config, model_kwargs, lora_configs, mo def 
load_wan_text_encoder(config: Dict[str, Any]): + from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel + # offload config t5_offload = config.get("t5_cpu_offload", config.get("cpu_offload")) if t5_offload: @@ -212,6 +208,8 @@ def load_wan_text_encoder(config: Dict[str, Any]): def load_wan_image_encoder(config: Dict[str, Any]): + from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel + image_encoder = None if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True): # offload config @@ -259,6 +257,9 @@ def get_vae_parallel(config: Dict[str, Any]): def load_wan_vae_encoder(config: Dict[str, Any]): + from lightx2v.models.video_encoders.hf.wan.vae import WanVAE + from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") if config.get("model_cls", "") == "wan2.2": vae_cls = Wan2_2_VAE @@ -289,6 +290,10 @@ def load_wan_vae_encoder(config: Dict[str, Any]): def load_wan_vae_decoder(config: Dict[str, Any]): + from lightx2v.models.video_encoders.hf.wan.vae import WanVAE + from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE + from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") tiny_vae_name = "taew2_1.pth" @@ -327,6 +332,31 @@ def load_wan_vae_decoder(config: Dict[str, Any]): def load_wan_transformer(config: Dict[str, Any]): + print( + "Loading WanModel module: model_cls=%s model_path=%s device=%s dit_quantized=%s lazy_load=%s" + % ( + config.get("model_cls"), + config.get("model_path"), + AI_DEVICE if not config.get("cpu_offload") else "cpu", + config.get("dit_quantized", False), + config.get("lazy_load", False), + ), + flush=True, + ) + from lightx2v.models.networks.wan.model import WanModel + + print( + "Constructing WanModel: model_cls=%s model_path=%s device=%s dit_quantized=%s lazy_load=%s" + % ( + 
config.get("model_cls"), + config.get("model_path"), + AI_DEVICE if not config.get("cpu_offload") else "cpu", + config.get("dit_quantized", False), + config.get("lazy_load", False), + ), + flush=True, + ) + if config["cpu_offload"]: init_device = torch.device("cpu") else: @@ -339,9 +369,12 @@ def load_wan_transformer(config: Dict[str, Any]): model = WanModel(**wan_model_kwargs) else: model = build_wan_model_with_lora(WanModel, config, wan_model_kwargs, lora_configs, model_type="wan2.1") + logger.info("WanModel construction finished") return model elif config.get("model_cls") == "wan2.2_moe": + print("Loading MultiModelStruct module start", flush=True) from lightx2v.models.runners.wan.wan_runner import MultiModelStruct + print("Loading MultiModelStruct module done", flush=True) high_noise_model_path = os.path.join(config["model_path"], "high_noise_model") if config.get("dit_quantized", False) and config.get("high_noise_quantized_ckpt", None): @@ -376,6 +409,7 @@ def load_wan_transformer(config: Dict[str, Any]): high_noise_model = build_wan_model_with_lora(WanModel, config, high_model_kwargs, lora_configs, model_type="high_noise_model") low_noise_model = build_wan_model_with_lora(WanModel, config, low_model_kwargs, lora_configs, model_type="low_noise_model") + logger.info("WanModel construction finished for wan2.2_moe") return MultiModelStruct([high_noise_model, low_noise_model], config, config.get("boundary", 0.875)) else: model_struct = MultiModelStruct([None, None], config, config.get("boundary", 0.875)) diff --git a/scripts/disagg/extract_dynamic_latency.py b/scripts/disagg/extract_dynamic_latency.py index 589b83f74..0fdec8e40 100644 --- a/scripts/disagg/extract_dynamic_latency.py +++ b/scripts/disagg/extract_dynamic_latency.py @@ -9,7 +9,7 @@ from pathlib import Path WAIT_PATTERNS = [ - re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*waiting workload configs on port=", re.IGNORECASE), + 
re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Waiting for decoder results", re.IGNORECASE), re.compile(r"^\[(?:INFO|WARNING|ERROR|DEBUG|CRITICAL)\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*waiting workload configs on port=", re.IGNORECASE), ] LAT_PATTERNS = [ diff --git a/scripts/disagg/kill_service.sh b/scripts/disagg/kill_service.sh index 4003e93cc..8dd3f6005 100755 --- a/scripts/disagg/kill_service.sh +++ b/scripts/disagg/kill_service.sh @@ -6,6 +6,30 @@ SCRIPT_NAMES=("run_wan22_i2v_distill.sh" "run_dynamic.sh") list_port=(5566 12788 17788 27788) +collect_proxy_ports_from_config() { + local config_path="$1" + + if [[ -z "$config_path" || ! -f "$config_path" ]]; then + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + return 0 + fi + + local base_port + base_port=$(jq -r '.disagg_config.remote_proxy_req_base_port // empty' "$config_path" 2>/dev/null || true) + if [[ -z "$base_port" || ! "$base_port" =~ ^[0-9]+$ ]]; then + return 0 + fi + + jq -r '.disagg_config.static_instance_slots[]?.engine_rank // empty' "$config_path" 2>/dev/null | while read -r engine_rank; do + [[ -z "$engine_rank" ]] && continue + if [[ "$engine_rank" =~ ^[0-9]+$ ]]; then + echo $((base_port + engine_rank)) + fi + done +} + n=30 list_n=($(seq 0 $((n-1)))) @@ -22,6 +46,23 @@ for a in "${list_port[@]}"; do done done +proxy_config_candidates=( + "${DISAGG_CONTROLLER_CFG:-}" + "/root/zht/LightX2V/configs/disagg/multi_node/wan22_i2v_distill_controller.json" + "/root/zht/LightX2V/configs/disagg/single_node/wan22_i2v_distill_controller.json" +) +for config_path in "${proxy_config_candidates[@]}"; do + while read -r proxy_port; do + [[ -z "$proxy_port" ]] && continue + PORTS+=($proxy_port) + done < <(collect_proxy_ports_from_config "$config_path") +done + +# Fallback for environments without jq or without a readable config file. 
+PORTS+=(28000) + +mapfile -t PORTS < <(printf '%s\n' "${PORTS[@]}" | awk 'NF && !seen[$0]++ { print $0 }' | sort -n) + kill_pid_gracefully() { local pid="$1" if [[ -z "$pid" ]]; then diff --git a/scripts/disagg/run_dynamic.sh b/scripts/disagg/run_dynamic.sh index 2586865b0..971ae14f8 100644 --- a/scripts/disagg/run_dynamic.sh +++ b/scripts/disagg/run_dynamic.sh @@ -62,11 +62,30 @@ if command -v jq >/dev/null 2>&1; then derived_controller_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") fi export DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST:-${derived_controller_host:-127.0.0.1}} +# RoCE gid_index: align with cluster data-plane IP (multi-homed / wrong default route breaks cross-node QP RTR). +if [[ -z "${RDMA_PREFERRED_IPV4:-}" && -n "${derived_controller_host}" ]]; then + if [[ "${derived_controller_host}" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ && "${derived_controller_host}" != "127.0.0.1" ]]; then + export RDMA_PREFERRED_IPV4="${derived_controller_host}" + fi +fi export DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT:-12786} export LOAD_FROM_USER=${LOAD_FROM_USER:-0} -export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-90} +# multi_node: remote ranks (e.g. slow encoder/decoder host) may need longer TCP/ready waits. +if [[ "${topology}" == "single_node" ]]; then + export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-90} +else + export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-300} + export DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS=${DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS:-120} + export DISAGG_SIDECAR_START_TIMEOUT_SECONDS=${DISAGG_SIDECAR_START_TIMEOUT_SECONDS:-60} +fi # Dynamic debug defaults to a smaller request batch; override for stress runs. 
-export DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT:-1} +export DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT:-30} +export DISAGG_ENABLE_NSYS=${DISAGG_ENABLE_NSYS:-0} +export SYNC_COMM=${SYNC_COMM:-0} +export DISAGG_NSYS_BIN=${DISAGG_NSYS_BIN:-nsys} +export DISAGG_NSYS_OUTPUT_DIR=${DISAGG_NSYS_OUTPUT_DIR:-${lightx2v_path}/save_results/nsys} +export DISAGG_NSYS_TRACE=${DISAGG_NSYS_TRACE:-cuda,nvtx,osrt} +export DISAGG_NSYS_EXTRA_ARGS=${DISAGG_NSYS_EXTRA_ARGS:-} user_start_delay_s=${USER_START_DELAY_S:-0} if [[ -n "${USER_MAX_REQUESTS:-}" ]]; then user_max_requests=${USER_MAX_REQUESTS} @@ -86,7 +105,11 @@ save_result_path=${SAVE_RESULT_PATH:-${lightx2v_path}/save_results/wan22_i2v_dyn controller_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_controller.log user_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_user.log -controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-3000} +if [[ "${topology}" == "single_node" ]]; then + controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-3000} +else + controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-7200} +fi controller_poll_interval_s=${CONTROLLER_POLL_INTERVAL_S:-5} fatal_watch_interval_s=${FATAL_WATCH_INTERVAL_S:-2} fatal_flag_file=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_fatal.flag @@ -102,7 +125,10 @@ fi echo "disagg topology=${topology}" echo "controller_cfg=${controller_cfg}" echo "DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST} DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT}" +echo "RDMA_PREFERRED_IPV4=${RDMA_PREFERRED_IPV4:-}" echo "DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT}" +echo "DISAGG_ENABLE_NSYS=${DISAGG_ENABLE_NSYS} DISAGG_NSYS_OUTPUT_DIR=${DISAGG_NSYS_OUTPUT_DIR} DISAGG_NSYS_TRACE=${DISAGG_NSYS_TRACE}" +echo "SYNC_COMM=${SYNC_COMM}" echo "LOAD_FROM_USER=${LOAD_FROM_USER} USER_START_DELAY_S=${user_start_delay_s} USER_MAX_REQUESTS=${user_max_requests}" rm -f "${fatal_flag_file}" @@ -223,6 +249,73 @@ 
sync_remote_configs_once() { done } +# Remote workers import lightx2v from remote_workdir; without syncing, fixes on the controller host never run on peers. +sync_remote_disagg_sources_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote disagg source sync: single_node topology" + return 0 + fi + if [[ "${DISAGG_SYNC_REMOTE_SOURCES:-1}" == "0" || "${DISAGG_SYNC_REMOTE_SOURCES:-}" == "false" ]]; then + echo "skip remote disagg source sync: DISAGG_SYNC_REMOTE_SOURCES=${DISAGG_SYNC_REMOTE_SOURCES:-}" + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + echo "skip remote disagg source sync: jq not found" + return 0 + fi + + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + local remote_workdir + remote_workdir=$(jq -r '.disagg_config.remote_workdir // empty' "${controller_cfg}") + if [[ -z "${remote_workdir}" ]]; then + remote_workdir="${lightx2v_path}" + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts for disagg source sync" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? // empty' "${controller_cfg}") + local rsync_rsh="ssh" + for opt in "${ssh_opts[@]}"; do + rsync_rsh+=" $(printf '%q' "${opt}")" + done + + local rel_disagg="lightx2v/disagg" + local src_dir="${lightx2v_path}/${rel_disagg}/" + # Do not overwrite rdma_base.py on peers: pyverbs/rdma-core versions may differ per host. 
+ local sync_excludes=( + --exclude=rdma_base.py + ) + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + local dst_dir="${remote_workdir}/${rel_disagg}" + ssh "${ssh_opts[@]}" "${target}" "mkdir -p '${dst_dir}'" || true + if command -v rsync >/dev/null 2>&1; then + if rsync -az -e "${rsync_rsh}" "${sync_excludes[@]}" "${src_dir}" "${target}:${dst_dir}/"; then + echo "synced ${rel_disagg}/ to ${host}:${dst_dir}/ (excludes rdma_base.py)" + else + echo "warning: rsync ${rel_disagg} to ${host} failed" + fi + else + if ( cd "${lightx2v_path}" && tar cf - "${sync_excludes[@]}" "${rel_disagg}" ) | ssh "${ssh_opts[@]}" "${target}" "cd '${remote_workdir}' && tar xf -"; then + echo "synced ${rel_disagg}/ to ${host} (tar, excludes rdma_base.py)" + else + echo "warning: tar-sync ${rel_disagg} to ${host} failed" + fi + fi + done +} + collect_remote_logs_once() { if [[ "${is_single_node}" == "1" ]]; then echo "skip remote log collection: single_node topology" @@ -373,7 +466,8 @@ cleanup() { trap cleanup EXIT INT TERM pre_clean_remote_hosts_once -sync_remote_configs_once +# sync_remote_configs_once +sync_remote_disagg_sources_once python -m lightx2v.disagg.examples.run_service \ --service controller \ From 04f98e7cfd6f2a47fd46f52cb378cb61692b29e9 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Fri, 24 Apr 2026 16:50:34 +0800 Subject: [PATCH 8/9] run precommit --- lightx2v/disagg/README.md | 2 +- lightx2v/disagg/rdma_client.py | 12 ++++++------ lightx2v/disagg/rdma_server.py | 8 ++++---- lightx2v/disagg/rdma_utils.py | 3 +-- lightx2v/disagg/services/base.py | 3 +-- lightx2v/disagg/services/controller.py | 14 ++++---------- lightx2v/disagg/services/decoder.py | 2 +- lightx2v/disagg/services/encoder.py | 1 - lightx2v/disagg/services/transformer.py | 5 ++--- lightx2v/disagg/utils.py | 1 + 10 files changed, 21 insertions(+), 30 deletions(-) diff --git a/lightx2v/disagg/README.md 
b/lightx2v/disagg/README.md index 13d064464..c7efd0862 100644 --- a/lightx2v/disagg/README.md +++ b/lightx2v/disagg/README.md @@ -140,4 +140,4 @@ bash scripts/disagg/run_dynamic.sh 1. 多机运行时,`DISAGG_CONTROLLER_CFG` 里的 `bootstrap_addr`、`static_instance_slots` 和各 slot 的 `env` 会直接影响远端实例如何绑定网络与 Mooncake 地址。 2. 如果遇到端口占用,优先检查 `scripts/disagg/kill_service.sh` 是否已经把旧实例和 proxy 清理干净。 -3. 如果需要了解 controller 配置文件本身的字段含义,可以继续查看 `configs/disagg/` 下对应 JSON。 \ No newline at end of file +3. 如果需要了解 controller 配置文件本身的字段含义,可以继续查看 `configs/disagg/` 下对应 JSON。 diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index 2c7214bbc..b80d8fc03 100644 --- a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -7,21 +7,21 @@ import time from lightx2v.disagg.rdma_base import ( - AccessFlag, - AHAttr, CQ, GID, - GlobalRoute, - IBDevice, MR, PD, QP, + SGE, + WR, + AHAttr, + AccessFlag, + GlobalRoute, + IBDevice, QPAttr, QPCap, QPInitAttr, QPType, - SGE, - WR, WROpcode, e, get_device_list, diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index 63e272cf1..b431665eb 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -4,15 +4,15 @@ import threading from lightx2v.disagg.rdma_base import ( - AccessFlag, - AHAttr, CQ, GID, - GlobalRoute, - IBDevice, MR, PD, QP, + AHAttr, + AccessFlag, + GlobalRoute, + IBDevice, QPAttr, QPCap, QPInitAttr, diff --git a/lightx2v/disagg/rdma_utils.py b/lightx2v/disagg/rdma_utils.py index 629340ef9..bf0df1631 100644 --- a/lightx2v/disagg/rdma_utils.py +++ b/lightx2v/disagg/rdma_utils.py @@ -7,7 +7,6 @@ import socket import time - logger = logging.getLogger(__name__) @@ -232,4 +231,4 @@ def rtr_path_mtu_negotiated(ctx, port_num: int, peer_active_mtu: int | None) -> peer = int(peer_active_mtu) except (TypeError, ValueError): return local - return min(local, peer) \ No newline at end of file + return min(local, peer) diff --git a/lightx2v/disagg/services/base.py 
b/lightx2v/disagg/services/base.py index f35c7ad1e..3fc66b220 100644 --- a/lightx2v/disagg/services/base.py +++ b/lightx2v/disagg/services/base.py @@ -1,9 +1,8 @@ -from abc import ABC import sys +from abc import ABC from loguru import logger as loguru_logger - loguru_logger.remove() loguru_logger.add( sys.stderr, diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index fea94d668..403a51e42 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,15 +1,15 @@ import ipaddress import json -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer import os import shlex -import signal import shutil +import signal import socket import subprocess import sys import time from collections.abc import Mapping +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from threading import Event, Lock, Thread from typing import Any @@ -501,11 +501,7 @@ def _with_env(base_cmd: str, env_map: dict[str, str]) -> str: return f"env {env_prefix} {base_cmd}" sidecar_cmd = _with_env( - ( - f"{shlex.quote(python_executable)} " - "-m lightx2v.disagg.services.data_mgr_sidecar " - f"--push-addr {shlex.quote(push_addr)} --req-addr {shlex.quote(req_addr)}" - ), + (f"{shlex.quote(python_executable)} -m lightx2v.disagg.services.data_mgr_sidecar --push-addr {shlex.quote(push_addr)} --req-addr {shlex.quote(req_addr)}"), sidecar_env_vars, ) cmd_with_python = [python_executable, *cmd[1:]] @@ -547,9 +543,7 @@ def _with_env(base_cmd: str, env_map: dict[str, str]) -> str: service_pid = None if sidecar_pid is None or service_pid is None: - raise RuntimeError( - f"failed to parse remote pids for {instance_type} rank={engine_rank} host={host}: stdout={completed.stdout!r} stderr={completed.stderr!r}" - ) + raise RuntimeError(f"failed to parse remote pids for {instance_type} rank={engine_rank} host={host}: stdout={completed.stdout!r} stderr={completed.stderr!r}") 
sidecar_meta = { "mode": "remote", diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index 707c8d936..a3706db74 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -143,7 +143,7 @@ def init(self, config): if data_bootstrap_addr is None or data_bootstrap_room is None: return - if not str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"}: + if str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() not in {"1", "true", "yes", "on"}: try: self._ensure_phase2_request_buffer() except Exception: diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index f434641cb..5b86b81df 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -10,7 +10,6 @@ import numpy as np import torch -import zmq from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index ddd660cb2..f3917d88d 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -6,12 +6,11 @@ from collections import deque from multiprocessing import resource_tracker, shared_memory from typing import Any, List, Optional +from urllib.error import URLError +from urllib.request import Request, urlopen import numpy as np import torch -import zmq -from urllib.error import URLError -from urllib.request import Request, urlopen from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index 992f2a313..0d5a40da3 100644 --- 
a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -374,6 +374,7 @@ def load_wan_transformer(config: Dict[str, Any]): elif config.get("model_cls") == "wan2.2_moe": print("Loading MultiModelStruct module start", flush=True) from lightx2v.models.runners.wan.wan_runner import MultiModelStruct + print("Loading MultiModelStruct module done", flush=True) high_noise_model_path = os.path.join(config["model_path"], "high_noise_model") From ee4ed2d93721b98c1ba739164511fde8598ea2e4 Mon Sep 17 00:00:00 2001 From: zhtshr <547553598@qq.com> Date: Fri, 24 Apr 2026 16:55:04 +0800 Subject: [PATCH 9/9] run precommit --- lightx2v/models/schedulers/wan/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightx2v/models/schedulers/wan/scheduler.py b/lightx2v/models/schedulers/wan/scheduler.py index ff1e0a45a..076c1b24c 100755 --- a/lightx2v/models/schedulers/wan/scheduler.py +++ b/lightx2v/models/schedulers/wan/scheduler.py @@ -37,7 +37,7 @@ def refresh_from_config(self, config): self.caching_records = [True] * self.infer_steps self.caching_records_2 = [True] * self.infer_steps self.step_index = 0 - + def _uses_conditioned_latent_prefix(self): """Whether this Wan variant keeps a fixed latent prefix during diffusion.