diff --git a/configs/disagg/multi_node/wan22_i2v_distill_controller.json b/configs/disagg/multi_node/wan22_i2v_distill_controller.json new file mode 100644 index 000000000..f6162108e --- /dev/null +++ b/configs/disagg/multi_node/wan22_i2v_distill_controller.json @@ -0,0 +1,158 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", + "disagg_mode": "controller", + "disagg_config": { + "bootstrap_addr": "192.168.0.166", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + 
"local_hostname": "192.168.0.166", + "metadata_server": "P2PHANDSHAKE", + "service_env": { + "RDMA_IFACE": "erdma_0" + }, + "remote_workdir": "/root/zht/LightX2V", + "remote_python_executable": "python", + "remote_activate_cmd": "source /root/miniconda3/etc/profile.d/conda.sh && conda activate lightx2v && export LD_LIBRARY_PATH=/root/miniconda3/envs/lightx2v/lib:${LD_LIBRARY_PATH:-}", + "remote_log_dir": "/root/zht/LightX2V/save_results", + "use_remote_proxy": true, + "remote_proxy_req_base_port": 28000, + "ssh_user": "root", + "ssh_options": [ + "-i", + "/root/.ssh/id_ed25519_zht", + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=no" + ], + "static_instance_slots": [ + { + "instance_type": "encoder", + "host": "192.168.0.139", + "engine_rank": 0, + "cuda_device": 0, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.139" + } + }, + { + "instance_type": "transformer", + "host": "192.168.0.166", + "engine_rank": 1, + "cuda_device": 0, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } + }, + { + "instance_type": "transformer", + "host": "192.168.0.166", + "engine_rank": 2, + "cuda_device": 1, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } + }, + { + "instance_type": "transformer", + "host": "192.168.0.166", + "engine_rank": 3, + "cuda_device": 2, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } + }, + { + "instance_type": "transformer", + "host": "192.168.0.166", + "engine_rank": 4, + "cuda_device": 3, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } + }, + { + "instance_type": "transformer", + "host": "192.168.0.166", + "engine_rank": 5, + "cuda_device": 4, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } + }, + { + "instance_type": "transformer", + "host": "192.168.0.166", + "engine_rank": 6, + 
"cuda_device": 5, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.166" + } + }, + { + "instance_type": "decoder", + "host": "192.168.0.139", + "engine_rank": 7, + "cuda_device": 1, + "env": { + "MOONCAKE_DEVICE_NAME": "eth0", + "MOONCAKE_LOCAL_HOSTNAME": "192.168.0.139" + } + } + ] + } +} diff --git a/configs/disagg/multi_node/wan22_i2v_distill_decoder.json b/configs/disagg/multi_node/wan22_i2v_distill_decoder.json new file mode 100644 index 000000000..8549d4165 --- /dev/null +++ b/configs/disagg/multi_node/wan22_i2v_distill_decoder.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": 
"/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", + "disagg_mode": "decoder", + "disagg_config": { + "bootstrap_addr": "192.168.0.166", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "192.168.0.139", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/configs/disagg/multi_node/wan22_i2v_distill_encoder.json b/configs/disagg/multi_node/wan22_i2v_distill_encoder.json new file mode 100644 index 000000000..7b126a729 --- /dev/null +++ b/configs/disagg/multi_node/wan22_i2v_distill_encoder.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": 
"/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", + "disagg_mode": "encoder", + "disagg_config": { + "bootstrap_addr": "192.168.0.166", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "192.168.0.139", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/configs/disagg/multi_node/wan22_i2v_distill_transformer.json b/configs/disagg/multi_node/wan22_i2v_distill_transformer.json new file mode 100644 index 000000000..99572301f --- /dev/null +++ b/configs/disagg/multi_node/wan22_i2v_distill_transformer.json @@ -0,0 +1,58 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": 
"/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", + "disagg_mode": "transformer", + "disagg_config": { + "bootstrap_addr": "192.168.0.166", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "192.168.0.166", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/configs/disagg/wan_t2v_disagg_controller.json b/configs/disagg/multi_node/wan_t2v_disagg_controller.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_controller.json rename to configs/disagg/multi_node/wan_t2v_disagg_controller.json diff --git a/configs/disagg/wan_t2v_disagg_decoder.json b/configs/disagg/multi_node/wan_t2v_disagg_decoder.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_decoder.json rename to configs/disagg/multi_node/wan_t2v_disagg_decoder.json diff --git a/configs/disagg/wan_t2v_disagg_encoder.json b/configs/disagg/multi_node/wan_t2v_disagg_encoder.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_encoder.json rename to configs/disagg/multi_node/wan_t2v_disagg_encoder.json diff --git a/configs/disagg/wan_t2v_disagg_transformer.json b/configs/disagg/multi_node/wan_t2v_disagg_transformer.json similarity index 100% rename from configs/disagg/wan_t2v_disagg_transformer.json rename to configs/disagg/multi_node/wan_t2v_disagg_transformer.json diff --git a/configs/disagg/single_node/wan22_i2v_distill_controller.json b/configs/disagg/single_node/wan22_i2v_distill_controller.json new file mode 100644 
index 000000000..12c9c6730 --- /dev/null +++ b/configs/disagg/single_node/wan22_i2v_distill_controller.json @@ -0,0 +1,108 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "rdma_buffer_slot_size": 8192, + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "image_path": "/root/zht/LightX2V/assets/inputs/imgs/img_0.jpg", + "disagg_mode": "controller", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "ranks": 8, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "127.0.0.1", + "metadata_server": "P2PHANDSHAKE", + "static_instance_slots": [ + { + "instance_type": "encoder", + "host": "127.0.0.1", + 
"engine_rank": 0, + "cuda_device": 0 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 1, + "cuda_device": 1 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 2, + "cuda_device": 2 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 3, + "cuda_device": 3 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 4, + "cuda_device": 4 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 5, + "cuda_device": 5 + }, + { + "instance_type": "transformer", + "host": "127.0.0.1", + "engine_rank": 6, + "cuda_device": 6 + }, + { + "instance_type": "decoder", + "host": "127.0.0.1", + "engine_rank": 7, + "cuda_device": 7 + } + ] + } +} diff --git a/configs/disagg/wan22_i2v_workload_stages.json b/configs/disagg/wan22_i2v_workload_stages.json new file mode 100644 index 000000000..edf619f89 --- /dev/null +++ b/configs/disagg/wan22_i2v_workload_stages.json @@ -0,0 +1,28 @@ +[ + { + "name": "warmup", + "duration_s": 120, + "user_count": 1, + "spawn_rate": 0.1, + "wait_time_s": 0.0, + "config_variants": [ + { + "infer_steps": 4, + "sample_shift": 5.0 + } + ] + }, + { + "name": "change", + "duration_s": 1000, + "user_count": 1, + "spawn_rate": 0.1, + "wait_time_s": 0.0, + "config_variants": [ + { + "infer_steps": 4, + "sample_shift": 5.0 + } + ] + } +] diff --git a/lightx2v/disagg/README.md b/lightx2v/disagg/README.md new file mode 100644 index 000000000..c7efd0862 --- /dev/null +++ b/lightx2v/disagg/README.md @@ -0,0 +1,143 @@ +# disagg / `run_dynamic.sh` 使用说明 + +`scripts/disagg/run_dynamic.sh` 是 LightX2V 的动态多机/单机离线调度启动脚本。它会自动完成以下工作: + +1. 激活 `lightx2v` conda 环境,除非显式关闭。 +2. 先执行 `scripts/disagg/kill_service.sh` 清理残留进程和端口。 +3. 读取 controller 配置,准备 `multi_node` 或 `single_node` 启动参数。 +4. 为 Mooncake / RDMA / ZMQ / 日志收集设置默认环境变量。 +5. 
启动 controller,并按配置拉起 encoder / transformer / decoder。 + +## 基本用法 + +最常见的方式是直接运行脚本: + +```bash +bash scripts/disagg/run_dynamic.sh +``` + +如果要切换拓扑或覆盖默认配置,可以在命令前追加环境变量: + +```bash +DISAGG_TOPOLOGY=multi_node \ +DISAGG_CONTROLLER_CFG=/root/zht/LightX2V/configs/disagg/multi_node/wan22_i2v_distill_controller.json \ +bash scripts/disagg/run_dynamic.sh +``` + +单机调试可以改成: + +```bash +DISAGG_TOPOLOGY=single_node \ +bash scripts/disagg/run_dynamic.sh +``` + +## 脚本会自动处理的事情 + +脚本会自动: + +1. 如果当前没有激活到 `DISAGG_CONDA_ENV`,就尝试 `conda activate`。 +2. 设置编译器和 `NVCC_PREPEND_FLAGS`,便于本地编译或运行扩展。 +3. 默认将 `RDMA_IFACE` 设为 `erdma_0`,将 `MOONCAKE_DEVICE_NAME` 设为 `eth0`。 +4. 如果没有显式设置 `MOONCAKE_LOCAL_HOSTNAME`,就从 `MOONCAKE_DEVICE_NAME` 对应网卡自动解析本机 IPv4。 +5. 根据 controller 配置里的 `bootstrap_addr` 自动推导 `DISAGG_CONTROLLER_HOST`。 +6. 先执行 `kill_service.sh` 清理旧服务,避免端口冲突。 + +## 环境变量说明 + +下面按功能分组说明常用变量。未特别说明时,都是脚本默认值。 + +### 运行模式与配置 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `DISAGG_TOPOLOGY` | 运行拓扑,`multi_node` 表示多机,`single_node` 表示单机。 | `multi_node` | +| `DISAGG_CONTROLLER_CFG` | controller 配置文件路径。脚本会根据拓扑自动选择默认配置。 | `configs/disagg/multi_node/wan22_i2v_distill_controller.json` 或 single_node 对应文件 | +| `DISAGG_CONDA_ENV` | 启动时要激活的 conda 环境名。 | `lightx2v` | +| `DISAGG_SKIP_CONDA_ACTIVATE` | 设为 `1` 时跳过 conda 激活。 | `0` | +| `DISAGG_CONTROLLER_HOST` | controller 对外使用的主机地址。若未设置,脚本会尝试从配置文件 `bootstrap_addr` 推导。 | 配置里的 `bootstrap_addr`,否则 `127.0.0.1` | + +### RDMA / Mooncake + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `RDMA_IFACE` | 本机 RDMA / eRDMA 网卡名。 | `erdma_0` | +| `MOONCAKE_DEVICE_NAME` | Mooncake 用来解析本机 IPv4 的网卡名。 | `eth0` | +| `MOONCAKE_LOCAL_HOSTNAME` | Mooncake 认为的本机地址。若未设置,脚本会自动从 `MOONCAKE_DEVICE_NAME` 对应网卡提取 IPv4。 | 自动推导 | +| `RDMA_PREFERRED_IPV4` | 优先选择的 RDMA 数据平面 IPv4,通常用于多网卡环境下稳定选择 gid_index。 | 自动推导为 `DISAGG_CONTROLLER_HOST`(当其是 IPv4 且不是 `127.0.0.1`) | + +### 控制面端口与启动等待 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `DISAGG_CONTROLLER_REQUEST_PORT` | controller 请求入口端口。 | `12786` | +| 
`DISAGG_INSTANCE_START_TIMEOUT_SECONDS` | 等待实例启动完成的超时时间。 | `single_node: 90`,`multi_node: 300` | +| `DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS` | 等待远端 proxy 启动的超时时间。 | `120` | +| `DISAGG_SIDECAR_START_TIMEOUT_SECONDS` | 等待 sidecar 启动的超时时间。 | `60` | +| `CONTROLLER_WAIT_TIMEOUT_S` | 等待 controller 完成整轮任务的超时时间。 | `single_node: 3000`,`multi_node: 7200` | +| `CONTROLLER_POLL_INTERVAL_S` | controller 状态轮询间隔。 | `5` | + +### 请求数量、调试与通信方式 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `LOAD_FROM_USER` | 设为非 `0` 时,由 user 侧持续发请求,直到阶段结束。 | `0` | +| `DISAGG_AUTO_REQUEST_COUNT` | 自动请求的默认数量。`LOAD_FROM_USER=0` 时会使用这个值。 | `30` | +| `USER_MAX_REQUESTS` | 手动限制 user 进程最多发多少个请求,优先级高于 `DISAGG_AUTO_REQUEST_COUNT`。 | 未设置 | +| `USER_START_DELAY_S` | user 进程启动后的延迟时间。 | `0` | +| `SYNC_COMM` | 是否启用同步通信模式。 | `0` | + +### Nsight 采集 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `DISAGG_ENABLE_NSYS` | 是否启用 `nsys profile` 包裹实例进程。 | `0` | +| `DISAGG_NSYS_BIN` | `nsys` 可执行文件路径或命令名。 | `nsys` | +| `DISAGG_NSYS_OUTPUT_DIR` | nsys trace 输出目录。 | `save_results/nsys` | +| `DISAGG_NSYS_TRACE` | `nsys profile` 的 trace 类型。 | `cuda,nvtx,osrt` | +| `DISAGG_NSYS_EXTRA_ARGS` | 额外传给 `nsys profile` 的参数。 | 空 | + +### 日志与清理 + +| 变量 | 含义 | 默认值 | +| --- | --- | --- | +| `REMOTE_LOG_COLLECT` | 是否在结束后拉取远端日志。 | `1` | +| `REMOTE_LOG_COLLECT_DIR` | 远端日志收集到本地的目录。 | `save_results/remote_logs` | +| `DISAGG_REMOTE_PRE_CLEAN` | 是否在启动前先远端执行 `kill_service.sh`。 | `1` | +| `SEED` | 随机种子。 | `42` | +| `PROMPT` | 文本提示词。 | 脚本内置示例 prompt | +| `NEGATIVE_PROMPT` | 负向提示词。 | 脚本内置示例 negative prompt | +| `SAVE_RESULT_PATH` | 最终视频保存路径。 | `save_results/wan22_i2v_dynamic.mp4` | + +## 推荐的常见组合 + +### 本地单机调试 + +```bash +DISAGG_TOPOLOGY=single_node \ +LOAD_FROM_USER=0 \ +DISAGG_AUTO_REQUEST_COUNT=1 \ +bash scripts/disagg/run_dynamic.sh +``` + +### 多机标准跑法 + +```bash +DISAGG_TOPOLOGY=multi_node \ +DISAGG_CONTROLLER_CFG=/root/zht/LightX2V/configs/disagg/multi_node/wan22_i2v_distill_controller.json \ +DISAGG_AUTO_REQUEST_COUNT=30 \ +bash 
scripts/disagg/run_dynamic.sh +``` + +### 开启 Nsight + +```bash +DISAGG_ENABLE_NSYS=1 \ +DISAGG_NSYS_TRACE=cuda,nvtx,osrt \ +bash scripts/disagg/run_dynamic.sh +``` + +## 备注 + +1. 多机运行时,`DISAGG_CONTROLLER_CFG` 里的 `bootstrap_addr`、`static_instance_slots` 和各 slot 的 `env` 会直接影响远端实例如何绑定网络与 Mooncake 地址。 +2. 如果遇到端口占用,优先检查 `scripts/disagg/kill_service.sh` 是否已经把旧实例和 proxy 清理干净。 +3. 如果需要了解 controller 配置文件本身的字段含义,可以继续查看 `configs/disagg/` 下对应 JSON。 diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py index 90ee69909..ca63cf593 100644 --- a/lightx2v/disagg/conn.py +++ b/lightx2v/disagg/conn.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import struct import threading from collections.abc import Mapping @@ -81,6 +82,14 @@ class DataPoll: DATARECEIVER_POLLING_PORT = 27788 +def _normalize_loopback_host(host: str) -> str: + normalized = (host or "").strip() + if os.getenv("DISAGG_FORCE_IPV4_LOOPBACK", "1") not in ("0", "false", "False"): + if normalized in ("localhost", "::1", ""): + return "127.0.0.1" + return normalized or "127.0.0.1" + + class DataManager: # TODO: make it general and support multiple transfer backend before merging def __init__(self, disaggregation_phase: DisaggregationPhase, disaggregation_mode: DisaggregationMode): @@ -146,18 +155,25 @@ def transfer_loop(): sender_data_ptrs = self.request_pool.pop(pending_room) self.sync_status_to_transformer_endpoint(endpoint, pending_room) - ret = self.send_data( - pending_room, - mooncake_session_id, - sender_data_ptrs, - receiver_ptrs, - ) + try: + ret = self.send_data( + pending_room, + mooncake_session_id, + sender_data_ptrs, + receiver_ptrs, + ) + except Exception: + logger.exception("Transfer loop exception room=%s session=%s", pending_room, mooncake_session_id) + ret = -1 with self.pool_lock: if ret != 0: self.request_status[pending_room] = DataPoll.Failed else: self.request_status[pending_room] = DataPoll.Success - self.sync_status_to_transformer_endpoint(endpoint, 
pending_room) + try: + self.sync_status_to_transformer_endpoint(endpoint, pending_room) + except Exception: + logger.exception("Failed to sync final status room=%s endpoint=%s", pending_room, endpoint) self.transfer_thread = threading.Thread(target=transfer_loop, name="data-transfer-thread") self.transfer_thread.start() @@ -295,25 +311,35 @@ def send_data( # TODO: transfer data in batch if there are many tensors or large tensors, instead of sending one by one. args = self.data_args[room] tensor_num = int(len(args.data_ptrs)) + chunk_bytes = int(os.getenv("MOONCAKE_TRANSFER_CHUNK_BYTES", str(1024 * 1024))) + if chunk_bytes <= 0: + chunk_bytes = 1024 * 1024 for tensor_id in range(tensor_num): sender_addr = sender_data_ptrs[tensor_id] item_len = args.data_item_lens[tensor_id] receiver_addr = receiver_ptrs[tensor_id] - # TODO: mooncake transfer engine can do async transfer. Do async later - status = self.engine.transfer_sync( - mooncake_session_id, - sender_addr, - receiver_addr, - item_len, - ) - if status != 0: - return status + offset = 0 + remaining = int(item_len) + while remaining > 0: + transfer_len = min(chunk_bytes, remaining) + # TODO: mooncake transfer engine can do async transfer. 
Do async later + status = self.engine.transfer_sync( + mooncake_session_id, + sender_addr + offset, + receiver_addr + offset, + transfer_len, + ) + if status != 0: + return status + offset += transfer_len + remaining -= transfer_len return 0 def sync_status_to_transformer_endpoint(self, remote: str, room: int): if ":" in remote: remote = remote.split(":")[0] + remote = _normalize_loopback_host(remote) receiver_rank = self.data_args[room].receiver_engine_rank receiver_rank_port = DATARECEIVER_POLLING_PORT + receiver_rank + room * 10 self._connect("tcp://" + remote + ":" + str(receiver_rank_port)).send_multipart( @@ -335,12 +361,14 @@ def encode_thread(): try: ( endpoint, + receiver_engine_rank_raw, mooncake_session_id, bootstrap_room, transformer_ptrs, ) = room_socket.recv_multipart() except zmq.Again: continue + receiver_engine_rank = int.from_bytes(receiver_engine_rank_raw, byteorder="big") if bootstrap_room.decode("ascii") == "None": continue endpoint = endpoint.decode("ascii") @@ -348,10 +376,11 @@ def encode_thread(): bootstrap_room = int(bootstrap_room.decode("ascii")) transformer_ptrs = list(struct.unpack(f"{len(transformer_ptrs) // 8}Q", transformer_ptrs)) logger.info( - "Encoder received ZMQ: endpoint=%s session_id=%s room=%s transformer_ptrs=%s", + "Encoder received ZMQ: endpoint=%s session_id=%s room=%s receiver_engine_rank=%s transformer_ptrs=%s", endpoint, mooncake_session_id, bootstrap_room, + receiver_engine_rank, transformer_ptrs, ) with self.pool_lock: @@ -360,6 +389,8 @@ def encode_thread(): mooncake_session_id, transformer_ptrs, ) + if bootstrap_room in self.data_args: + self.data_args[bootstrap_room].receiver_engine_rank = receiver_engine_rank if self.transfer_event is not None: self.transfer_event.set() @@ -405,12 +436,14 @@ def transformer_thread(): try: ( endpoint, + receiver_engine_rank_raw, mooncake_session_id, bootstrap_room, decode_ptrs, ) = room_socket.recv_multipart() except zmq.Again: continue + receiver_engine_rank = 
int.from_bytes(receiver_engine_rank_raw, byteorder="big") if bootstrap_room.decode("ascii") == "None": continue endpoint = endpoint.decode("ascii") @@ -418,10 +451,11 @@ def transformer_thread(): bootstrap_room = int(bootstrap_room.decode("ascii")) decode_ptrs = list(struct.unpack(f"{len(decode_ptrs) // 8}Q", decode_ptrs)) logger.info( - "Transformer received ZMQ: endpoint=%s session_id=%s room=%s decode_ptrs=%s", + "Transformer received ZMQ: endpoint=%s session_id=%s room=%s receiver_engine_rank=%s decode_ptrs=%s", endpoint, mooncake_session_id, bootstrap_room, + receiver_engine_rank, decode_ptrs, ) with self.pool_lock: @@ -430,6 +464,8 @@ def transformer_thread(): mooncake_session_id, decode_ptrs, ) + if bootstrap_room in self.data_args: + self.data_args[bootstrap_room].receiver_engine_rank = receiver_engine_rank if self.transfer_event is not None: self.transfer_event.set() @@ -541,9 +577,10 @@ def __init__(self, mgr: DataManager, bootstrap_addr: str, bootstrap_room: Option raise ValueError("bootstrap_room is required for DataReceiver") args = self.data_mgr.data_args[self.bootstrap_room] sender_rank_port = DATASENDER_POLLING_PORT + args.sender_engine_rank + self.bootstrap_room * 10 - self.sender_server_url = bootstrap_addr.split(":")[0] + ":" + str(sender_rank_port) + sender_host = _normalize_loopback_host(bootstrap_addr.split(":")[0]) + self.sender_server_url = sender_host + ":" + str(sender_rank_port) logger.info("DataReceiver sender_server_url=%s", self.sender_server_url) - self.receiver_ip = self.data_mgr.get_localhost() + self.receiver_ip = _normalize_loopback_host(self.data_mgr.get_localhost()) self.session_id = self.data_mgr.get_session_id() self.data_mgr.set_status(bootstrap_room, DataPoll.WaitingForInput) @@ -560,6 +597,7 @@ def init(self): self._connect("tcp://" + self.sender_server_url).send_multipart( [ self.receiver_ip.encode("ascii"), + args.receiver_engine_rank.to_bytes(4, byteorder="big"), self.session_id.encode("ascii"), 
str(self.bootstrap_room).encode("ascii"), packed_data_ptrs, diff --git a/lightx2v/disagg/examples/run_service.py b/lightx2v/disagg/examples/run_service.py index 515265214..f24140dfb 100644 --- a/lightx2v/disagg/examples/run_service.py +++ b/lightx2v/disagg/examples/run_service.py @@ -4,10 +4,6 @@ from loguru import logger -from lightx2v.disagg.services.controller import ControllerService -from lightx2v.disagg.services.decoder import DecoderService -from lightx2v.disagg.services.encoder import EncoderService -from lightx2v.disagg.services.transformer import TransformerService from lightx2v.disagg.utils import set_config from lightx2v.utils.utils import seed_all @@ -124,12 +120,20 @@ def main(): logger.info("Starting disagg service mode={}", service_mode) if service_mode == "encoder": + from lightx2v.disagg.services.encoder import EncoderService + EncoderService(config).run() elif service_mode == "transformer": + from lightx2v.disagg.services.transformer import TransformerService + TransformerService(config).run() elif service_mode == "decoder": + from lightx2v.disagg.services.decoder import DecoderService + DecoderService(config).run() elif service_mode == "controller": + from lightx2v.disagg.services.controller import ControllerService + ControllerService().run(config) else: raise ValueError(f"Unsupported service mode: {service_mode}") diff --git a/lightx2v/disagg/examples/run_user.py b/lightx2v/disagg/examples/run_user.py new file mode 100644 index 000000000..f360bcd44 --- /dev/null +++ b/lightx2v/disagg/examples/run_user.py @@ -0,0 +1,66 @@ +import argparse +import time + +from lightx2v.disagg.conn import REQUEST_POLLING_PORT, ReqManager +from lightx2v.disagg.workload import ( + DisaggLoadShape, + build_payload, + current_stage, + load_base_config, + load_stage_specs, + send_workload_end_signal, + start_workload_clock, +) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run dynamic disagg workload user and push 
configs to Controller") + parser.add_argument("--controller_host", type=str, default="127.0.0.1") + parser.add_argument("--controller_request_port", type=int, default=REQUEST_POLLING_PORT - 2) + parser.add_argument("--max_requests", type=int, default=0, help="0 means no hard cap") + parser.add_argument("--sleep_min_ms", type=float, default=5.0, help="minimum loop sleep in ms") + return parser + + +def main(): + args = _build_parser().parse_args() + + req_mgr = ReqManager() + stages = load_stage_specs() + base_config = load_base_config() + shape = DisaggLoadShape() + + start_workload_clock() + + sent = 0 + last_tick_ts = 0.0 + + while True: + tick = shape.tick() + if tick is None: + break + + _, spawn_rate = tick + spawn_rate = max(float(spawn_rate), 0.1) + + stage = current_stage(stages) + payload = build_payload(base_config, stage, sent) + req_mgr.send(args.controller_host, args.controller_request_port, payload) + sent += 1 + + now = time.time() + if now - last_tick_ts >= 1.0: + print(f"stage={stage.name} spawn_rate={spawn_rate:.3f} req/s sent={sent}") + last_tick_ts = now + + if args.max_requests > 0 and sent >= args.max_requests: + break + + time.sleep(max(1.0 / spawn_rate, args.sleep_min_ms / 1000.0)) + + send_workload_end_signal() + print(f"workload finished: sent={sent}, end signal sent") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/mooncake.py b/lightx2v/disagg/mooncake.py index 756574d5a..e26ca3caf 100644 --- a/lightx2v/disagg/mooncake.py +++ b/lightx2v/disagg/mooncake.py @@ -1,11 +1,55 @@ import json import logging import os +import random +import socket +import time from dataclasses import dataclass logger = logging.getLogger(__name__) +def _detect_non_loopback_ipv4() -> str | None: + # Use a UDP connect trick to discover the outbound interface IP without sending traffic. 
+ try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.connect(("8.8.8.8", 80)) + ip = sock.getsockname()[0] + sock.close() + if ip and not ip.startswith("127."): + return ip + except Exception: + pass + + try: + host_ip = socket.gethostbyname(socket.gethostname()) + if host_ip and not host_ip.startswith("127."): + return host_ip + except Exception: + pass + + return None + + +def _collect_local_ipv4_addresses() -> list[str]: + candidates: list[str] = [] + + try: + hostname = socket.gethostname() + for info in socket.getaddrinfo(hostname, None, socket.AF_INET): + address = info[4][0] + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + except Exception: + pass + + detected = _detect_non_loopback_ipv4() + if detected is not None and detected not in candidates: + candidates.append(detected) + + return candidates + + @dataclass class MooncakeTransferEngineConfig: local_hostname: str @@ -29,7 +73,57 @@ def load_from_env() -> "MooncakeTransferEngineConfig": config_file_path = os.getenv("MOONCAKE_CONFIG_PATH", "/root/zht/LightX2V/configs/mooncake_config.json") if config_file_path is None: raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeTransferEngineConfig.from_file(config_file_path) + cfg = MooncakeTransferEngineConfig.from_file(config_file_path) + local_ipv4s = _collect_local_ipv4_addresses() + + env_metadata_server = os.getenv("MOONCAKE_METADATA_SERVER", "").strip() + if env_metadata_server: + cfg.metadata_server = env_metadata_server + + env_protocol = os.getenv("MOONCAKE_PROTOCOL", "").strip() + if env_protocol: + cfg.protocol = env_protocol + + env_device_name = os.getenv("MOONCAKE_DEVICE_NAME", "").strip() + if env_device_name: + cfg.device_name = env_device_name + + # Keep session IDs and metadata endpoints stable on single-node runs. + # localhost may resolve to IPv6 on some hosts while peers use IPv4. 
+ force_ipv4 = os.getenv("MOONCAKE_FORCE_IPV4_LOOPBACK", "1") not in ("0", "false", "False") + env_host = os.getenv("MOONCAKE_LOCAL_HOSTNAME", "").strip() + if env_host: + if env_host in ("localhost", "::1", "127.0.0.1") or env_host in local_ipv4s: + cfg.local_hostname = env_host + else: + detected = _detect_non_loopback_ipv4() + if detected is not None: + logger.warning( + "Ignoring MOONCAKE_LOCAL_HOSTNAME=%s because it does not match this host (local_ipv4s=%s); using %s", + env_host, + local_ipv4s, + detected, + ) + cfg.local_hostname = detected + elif force_ipv4: + cfg.local_hostname = "127.0.0.1" + elif force_ipv4 and cfg.local_hostname in ("localhost", "::1", "127.0.0.1"): + detected = _detect_non_loopback_ipv4() + if detected is not None: + cfg.local_hostname = detected + else: + cfg.local_hostname = "127.0.0.1" + elif cfg.local_hostname not in local_ipv4s and cfg.local_hostname not in ("localhost", "::1", "127.0.0.1"): + detected = _detect_non_loopback_ipv4() + if detected is not None: + logger.warning( + "Auto-correcting Mooncake local_hostname from %s to %s on this host (local_ipv4s=%s)", + cfg.local_hostname, + detected, + local_ipv4s, + ) + cfg.local_hostname = detected + return cfg class MooncakeTransferEngine: @@ -89,11 +183,53 @@ def initialize( def transfer_sync(self, session_id: str, buffer: int, peer_buffer_address: int, length: int) -> int: """Synchronously transfer data to the specified address.""" if self.engine: - ret = self.engine.transfer_sync_write(session_id, buffer, peer_buffer_address, length) - if ret < 0: - logger.error("Transfer Return Error") - raise Exception("Transfer Return Error") - return ret + if os.getenv("NETWORK_LATENCY"): + latency_prob_raw = os.getenv("NETWORK_LATENCY_PROB", "0.02") + latency_sec_raw = os.getenv("NETWORK_LATENCY_SEC", "5") + try: + latency_prob = float(latency_prob_raw) + except ValueError: + latency_prob = 0.02 + # Accept either ratio (0.02) or percentage (2 / 5). 
+ if latency_prob > 1.0: + latency_prob = latency_prob / 100.0 + latency_prob = max(0.0, min(1.0, latency_prob)) + + try: + latency_sec = float(latency_sec_raw) + except ValueError: + latency_sec = 5.0 + latency_sec = max(0.0, latency_sec) + + if random.random() < latency_prob: + logger.warning( + "Simulated network latency: sleeping %.3fs before transfer_sync_write (prob=%.4f)", + latency_sec, + latency_prob, + ) + time.sleep(latency_sec) + + retry_count = int(os.getenv("MOONCAKE_TRANSFER_RETRY", "5")) + retry_backoff_s = float(os.getenv("MOONCAKE_TRANSFER_RETRY_BACKOFF_S", "0.05")) + for attempt in range(retry_count + 1): + ret = self.engine.transfer_sync_write(session_id, buffer, peer_buffer_address, length) + if ret >= 0: + return ret + + logger.warning( + "Transfer Return Error attempt=%s/%s session=%s src=0x%x dst=0x%x len=%s", + attempt + 1, + retry_count + 1, + session_id, + int(buffer), + int(peer_buffer_address), + int(length), + ) + if attempt < retry_count: + time.sleep(retry_backoff_s) + + logger.error("Transfer Return Error after retries") + return -1 return -1 def get_localhost(self): diff --git a/lightx2v/disagg/rdma_base.py b/lightx2v/disagg/rdma_base.py new file mode 100644 index 000000000..90a7fb6e2 --- /dev/null +++ b/lightx2v/disagg/rdma_base.py @@ -0,0 +1,75 @@ +"""Shared pyverbs imports and thin RDMA types for client/server.""" + +from __future__ import annotations + +import pyverbs.enums as e +from pyverbs.addr import GID, AHAttr, GlobalRoute +from pyverbs.cq import CQ +from pyverbs.device import Context, get_device_list +from pyverbs.mr import MR +from pyverbs.pd import PD +from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr +from pyverbs.wr import SGE +from pyverbs.wr import SendWR as WR + +from lightx2v.disagg.rdma_utils import ( + recv_json_from_stream, + resolve_gid_index, + rtr_ah_dest_dlid, + rtr_path_mtu, + rtr_path_mtu_negotiated, +) + + +class IBDevice: + def __init__(self, name: str): + self.name = name + + def open(self): + 
return Context(name=self.name) + + +class QPType: + RC = e.IBV_QPT_RC + + +class WROpcode: + RDMA_WRITE = e.IBV_WR_RDMA_WRITE + RDMA_READ = e.IBV_WR_RDMA_READ + ATOMIC_FETCH_AND_ADD = e.IBV_WR_ATOMIC_FETCH_AND_ADD + ATOMIC_CMP_AND_SWP = e.IBV_WR_ATOMIC_CMP_AND_SWP + + +class AccessFlag: + LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE + REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE + REMOTE_READ = e.IBV_ACCESS_REMOTE_READ + REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC + + +__all__ = [ + "AccessFlag", + "AHAttr", + "CQ", + "Context", + "GID", + "GlobalRoute", + "IBDevice", + "MR", + "PD", + "QP", + "QPAttr", + "QPCap", + "QPInitAttr", + "QPType", + "SGE", + "WR", + "WROpcode", + "e", + "get_device_list", + "recv_json_from_stream", + "resolve_gid_index", + "rtr_ah_dest_dlid", + "rtr_path_mtu", + "rtr_path_mtu_negotiated", +] diff --git a/lightx2v/disagg/rdma_buffer.py b/lightx2v/disagg/rdma_buffer.py index 896efbbcf..bb03ad4ed 100644 --- a/lightx2v/disagg/rdma_buffer.py +++ b/lightx2v/disagg/rdma_buffer.py @@ -4,6 +4,7 @@ import json import logging import threading +import time from dataclasses import dataclass from typing import Any, Dict, Optional @@ -16,6 +17,8 @@ logger = logging.getLogger(__name__) +_U64_MASK = (1 << 64) - 1 + @dataclass class RDMABufferDescriptor: @@ -38,11 +41,6 @@ class RDMABuffer: - client: consumer side, reads slots remotely and updates head by rdma_faa. The ring stores serialized JSON configs in fixed-size slots. - - Multi-consumer note: multiple client processes calling ``consume()`` compete on the - same head pointer. Unless the backend implements a true remote atomic fetch-add - (see ``RDMAClient.rdma_faa``), correctness under heavy parallel consumption is not - guaranteed. Prefer one consumer per ring or low parallelism for production. 
""" def __init__( @@ -91,8 +89,8 @@ def __init__( base_addr = int(info["addr"]) need_bytes = 16 + self.buffer_size * self.slot_size self.rdma_server.register_memory(base_addr, need_bytes) - self.rdma_server.write_memory(base_addr, (0).to_bytes(8, byteorder="big", signed=False)) - self.rdma_server.write_memory(base_addr + 8, (0).to_bytes(8, byteorder="big", signed=False)) + self.rdma_server.write_memory(base_addr, (0).to_bytes(8, byteorder="little", signed=False)) + self.rdma_server.write_memory(base_addr + 8, (0).to_bytes(8, byteorder="little", signed=False)) self._descriptor = RDMABufferDescriptor( slot_addr=base_addr + 16, slot_bytes=self.buffer_size * self.slot_size, @@ -125,10 +123,14 @@ def descriptor(self) -> RDMABufferDescriptor: return self._descriptor def _write_local_u64(self, buf: bytearray, value: int): - buf[:8] = int(value).to_bytes(8, byteorder="big", signed=False) + buf[:8] = (int(value) & _U64_MASK).to_bytes(8, byteorder="little", signed=False) def _read_local_u64(self, buf: bytearray) -> int: - return int.from_bytes(bytes(buf[:8]), byteorder="big", signed=False) + return int.from_bytes(bytes(buf[:8]), byteorder="little", signed=False) + + def _u64_distance(self, newer: int, older: int) -> int: + """Return unsigned circular distance on a 64-bit counter space.""" + return (int(newer) - int(older)) & _U64_MASK def _rdma_faa(self, ptr_addr: int, add_value: int) -> int: if self.rdma_client is not None: @@ -138,21 +140,52 @@ def _rdma_faa(self, ptr_addr: int, add_value: int) -> int: with self._lock: old = self._read_remote_u64(ptr_addr) new = (old + int(add_value)) & ((1 << 64) - 1) - self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="big", signed=False)) + self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="little", signed=False)) return old # Fallback: local atomic emulation (useful for single-process validation). 
with self._lock: if ptr_addr == self.descriptor.head_addr: old = self._read_local_u64(self._head_mem) - self._write_local_u64(self._head_mem, old + int(add_value)) + self._write_local_u64(self._head_mem, (old + int(add_value)) & _U64_MASK) return old if ptr_addr == self.descriptor.tail_addr: old = self._read_local_u64(self._tail_mem) - self._write_local_u64(self._tail_mem, old + int(add_value)) + self._write_local_u64(self._tail_mem, (old + int(add_value)) & _U64_MASK) return old raise RuntimeError("rdma_faa failed and no local fallback for ptr") + def _rdma_cas(self, ptr_addr: int, compare_value: int, swap_value: int) -> int: + if self.rdma_client is not None: + return self.rdma_client.rdma_cas( + ptr_addr, + int(compare_value), + int(swap_value), + rkey=self.descriptor.rkey, + ) + + if self.rdma_server is not None: + with self._lock: + old = self._read_remote_u64(ptr_addr) + if old == (int(compare_value) & _U64_MASK): + new = int(swap_value) & _U64_MASK + self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="little", signed=False)) + return old + + # Local fallback for single-process testing. 
+ with self._lock: + if ptr_addr == self.descriptor.head_addr: + old = self._read_local_u64(self._head_mem) + if old == (int(compare_value) & _U64_MASK): + self._write_local_u64(self._head_mem, int(swap_value) & _U64_MASK) + return old + if ptr_addr == self.descriptor.tail_addr: + old = self._read_local_u64(self._tail_mem) + if old == (int(compare_value) & _U64_MASK): + self._write_local_u64(self._tail_mem, int(swap_value) & _U64_MASK) + return old + raise RuntimeError("rdma_cas failed and no local fallback for ptr") + def _rdma_read_bytes(self, remote_addr: int, length: int) -> bytes: if self.rdma_server is not None and self._descriptor is not None: base = self._descriptor.head_addr @@ -207,7 +240,7 @@ def _rdma_write_bytes(self, remote_addr: int, payload: bytes): def _read_remote_u64(self, remote_addr: int) -> int: raw = self._rdma_read_bytes(remote_addr, 8) - return int.from_bytes(raw, byteorder="big", signed=False) + return int.from_bytes(raw, byteorder="little", signed=False) def _slot_offset(self, index: int) -> int: return (index % self.buffer_size) * self.slot_size @@ -223,38 +256,40 @@ def _deserialize_config(self, raw_slot: bytes) -> Dict[str, Any]: raise ValueError("invalid slot payload") plen = int.from_bytes(raw_slot[:4], byteorder="little", signed=False) if plen == 0: - return {} + raise ValueError("slot payload is not committed yet") + if plen > self.slot_size - 4: + raise ValueError(f"invalid slot payload length: {plen}") data = raw_slot[4 : 4 + plen] - return json.loads(data.decode("utf-8")) + try: + return json.loads(data.decode("utf-8")) + except Exception as exc: + raise ValueError("slot payload is incomplete or corrupted") from exc def produce(self, config: Dict[str, Any]) -> int: """Produce one config into ring buffer and advance tail by rdma_faa.""" if self.rdma_server is None and self.rdma_client is None: raise RuntimeError("produce requires rdma_server or rdma_client") - # Reserve one slot by atomically incrementing tail. 
- old_tail = self._rdma_faa(self.descriptor.tail_addr, 1) + # Read current indices first, write the slot fully, then publish by advancing tail. + old_tail = self._read_remote_u64(self.descriptor.tail_addr) cur_head = self._read_remote_u64(self.descriptor.head_addr) - used = (old_tail + 1) - cur_head - if used > self.buffer_size: - self._rdma_faa(self.descriptor.tail_addr, -1) - logger.error( - "Ring buffer full: old_tail=%d cur_head=%d used=%d buffer_size=%d", - old_tail, - cur_head, - used, - self.buffer_size, - ) + if self._u64_distance(old_tail, cur_head) >= self.buffer_size: raise BufferError("ring buffer is full") slot_idx = old_tail % self.buffer_size offset = self._slot_offset(slot_idx) payload = self._serialize_config(config) + payload_len_header = payload[:4] + payload_body = payload[4:] # Write payload to the selected slot (works for both server-local and client-remote paths). slot_addr = self.descriptor.slot_addr + offset self._rdma_write_bytes(slot_addr, b"\x00" * self.slot_size) - self._rdma_write_bytes(slot_addr, payload) + if payload_body: + self._rdma_write_bytes(slot_addr + 4, payload_body) + # Write length header last so consumers never parse a half-written payload. 
+ self._rdma_write_bytes(slot_addr, payload_len_header) + self._rdma_faa(self.descriptor.tail_addr, 1) logger.info("Produced config to RDMA buffer slot %d", slot_idx) return slot_idx @@ -263,38 +298,74 @@ def consume(self) -> Optional[Dict[str, Any]]: if self.role != "client": raise RuntimeError("consume is only allowed in client role") - try: - cur_head = self._read_remote_u64(self.descriptor.head_addr) - cur_tail = self._read_remote_u64(self.descriptor.tail_addr) - except Exception as exc: - return None - - if cur_head >= cur_tail: - return None + max_claim_retries = max(8, self.buffer_size * 2) + claim_retry_sleep_seconds = 0.001 - try: - old_head = self._rdma_faa(self.descriptor.head_addr, 1) - except Exception as exc: - return None + for _ in range(max_claim_retries): + try: + cur_head = self._read_remote_u64(self.descriptor.head_addr) + cur_tail = self._read_remote_u64(self.descriptor.tail_addr) + except Exception: + return None + + # Fast path: empty queue, do not touch head. + if self._u64_distance(cur_tail, cur_head) == 0: + return None + + slot_idx = cur_head % self.buffer_size + slot_addr = self.descriptor.slot_addr + self._slot_offset(slot_idx) + max_read_retries = 5 + retry_sleep_seconds = 0.002 + last_error: Optional[Exception] = None + config: Optional[Dict[str, Any]] = None + + for _ in range(max_read_retries): + try: + raw = self._rdma_read_bytes(slot_addr, self.slot_size) + config = self._deserialize_config(raw) + last_error = None + break + except Exception as exc: + last_error = exc + time.sleep(retry_sleep_seconds) + + if config is None: + # Keep head unchanged so this slot can be retried later. 
+ logger.warning( + "RDMA buffer slot %d read incomplete after retries, keeping head unchanged: %s", + slot_idx, + last_error, + ) + return None - if old_head >= cur_tail: try: - self._rdma_faa(self.descriptor.head_addr, -1) + old_head = self._rdma_cas( + self.descriptor.head_addr, + cur_head, + (cur_head + 1) & _U64_MASK, + ) except Exception as exc: - logger.warning("RDMA buffer rollback failed on empty consume: %s", exc) - logger.debug( - "Consume race lost: old_head=%d cur_tail=%d (rolled back)", - old_head, - cur_tail, - ) - return None + logger.warning("RDMA buffer head CAS failed for slot %d: %s", slot_idx, exc) + return None - slot_idx = old_head % self.buffer_size - slot_addr = self.descriptor.slot_addr + self._slot_offset(slot_idx) - try: - raw = self._rdma_read_bytes(slot_addr, self.slot_size) - except Exception as exc: - logger.warning("RDMA buffer slot read failed for slot %d: %s", slot_idx, exc) - return None - logger.info("Consumed config from RDMA buffer slot %d", slot_idx) - return self._deserialize_config(raw) + if old_head != cur_head: + # Another consumer advanced head first; retry from latest head. 
+ time.sleep(claim_retry_sleep_seconds) + continue + + logger.info("Consumed config from RDMA buffer slot %d", slot_idx) + return config + + logger.warning("RDMA buffer consume contention is too high, skip this round") + return None + + def pending_count(self) -> int: + """Return current queue length inferred from ring tail/head counters.""" + cur_head = self._read_remote_u64(self.descriptor.head_addr) + cur_tail = self._read_remote_u64(self.descriptor.tail_addr) + pending = int(self._u64_distance(cur_tail, cur_head)) + if pending <= 0: + return 0 + if pending > self.buffer_size: + return self.buffer_size + return pending diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py index d8a08362c..b80d8fc03 100644 --- a/lightx2v/disagg/rdma_client.py +++ b/lightx2v/disagg/rdma_client.py @@ -1,50 +1,50 @@ import json +import logging +import os +import random import socket import threading import time -import pyverbs.enums as e -from pyverbs.addr import GID, AHAttr, GlobalRoute -from pyverbs.cq import CQ -from pyverbs.device import Context, get_device_list -from pyverbs.mr import MR -from pyverbs.pd import PD -from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr -from pyverbs.wr import SGE -from pyverbs.wr import SendWR as WR - - -class IBDevice: - def __init__(self, name: str): - self.name = name - - def open(self): - return Context(name=self.name) - - -class QPType: - RC = e.IBV_QPT_RC - - -class WROpcode: - RDMA_WRITE = e.IBV_WR_RDMA_WRITE - RDMA_READ = e.IBV_WR_RDMA_READ - ATOMIC_FETCH_AND_ADD = e.IBV_WR_ATOMIC_FETCH_AND_ADD - ATOMIC_CMP_AND_SWP = e.IBV_WR_ATOMIC_CMP_AND_SWP - - -class AccessFlag: - LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE - REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE - REMOTE_READ = e.IBV_ACCESS_REMOTE_READ - REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC +from lightx2v.disagg.rdma_base import ( + CQ, + GID, + MR, + PD, + QP, + SGE, + WR, + AHAttr, + AccessFlag, + GlobalRoute, + IBDevice, + QPAttr, + QPCap, + QPInitAttr, + QPType, + 
WROpcode, + e, + get_device_list, + recv_json_from_stream, + resolve_gid_index, + rtr_ah_dest_dlid, + rtr_path_mtu, + rtr_path_mtu_negotiated, +) + +logger = logging.getLogger(__name__) class RDMAClient: def __init__(self, iface_name=None, local_buffer_size=4096): self.local_psn = 654321 + self._next_psn = (int(time.time() * 1000000) & 0xFFFFFF) or 1 self.port_num = 1 - self.gid_index = 1 + self.gid_index = 0 + if iface_name is None: + env_iface = os.getenv("RDMA_IFACE", "").strip() + if env_iface: + iface_name = env_iface if iface_name is None: devices = get_device_list() if not devices: @@ -54,12 +54,13 @@ def __init__(self, iface_name=None, local_buffer_size=4096): self.ctx = IBDevice(iface_name).open() self.pd = PD(self.ctx) - self.cq = CQ(self.ctx, 10) + self.cq = CQ(self.ctx, 64) + self.gid_index = self._resolve_gid_index() - qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) - qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) - qa = QPAttr(port_num=self.port_num) - self.qp = QP(self.pd, qia, qa) + qp_init_attr = QPCap(max_send_wr=64, max_recv_wr=64, max_send_sge=1, max_recv_sge=1) + self._qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) + self._qa = QPAttr(port_num=self.port_num) + self.qp = QP(self.pd, self._qia, self._qa) # 客户端也需要注册内存,用于发送数据的源 (Write) 或接收数据的目标 (Read) self.buffer_size = int(local_buffer_size) @@ -67,6 +68,58 @@ def __init__(self, iface_name=None, local_buffer_size=4096): raise ValueError("local_buffer_size must be positive") self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) self._io_lock = threading.RLock() + self._connected_server_ip: str | None = None + self._connected_server_port: int | None = None + self._qp_error_state: bool = False + self._last_wc_error_message: str = "" + + def has_qp_error(self) -> bool: + return self._qp_error_state + + def last_wc_error_message(self) -> str: + return self._last_wc_error_message + + 
def _wc_status_name(self, status: int | None) -> str: + if status is None: + return "UNKNOWN" + status_map = { + getattr(e, "IBV_WC_SUCCESS", -1): "IBV_WC_SUCCESS", + getattr(e, "IBV_WC_LOC_LEN_ERR", -2): "IBV_WC_LOC_LEN_ERR", + getattr(e, "IBV_WC_LOC_QP_OP_ERR", -3): "IBV_WC_LOC_QP_OP_ERR", + getattr(e, "IBV_WC_LOC_PROT_ERR", -4): "IBV_WC_LOC_PROT_ERR", + getattr(e, "IBV_WC_WR_FLUSH_ERR", -5): "IBV_WC_WR_FLUSH_ERR", + getattr(e, "IBV_WC_MW_BIND_ERR", -6): "IBV_WC_MW_BIND_ERR", + getattr(e, "IBV_WC_BAD_RESP_ERR", -7): "IBV_WC_BAD_RESP_ERR", + getattr(e, "IBV_WC_LOC_ACCESS_ERR", -8): "IBV_WC_LOC_ACCESS_ERR", + getattr(e, "IBV_WC_REM_INV_REQ_ERR", -9): "IBV_WC_REM_INV_REQ_ERR", + getattr(e, "IBV_WC_REM_ACCESS_ERR", -10): "IBV_WC_REM_ACCESS_ERR", + getattr(e, "IBV_WC_REM_OP_ERR", -11): "IBV_WC_REM_OP_ERR", + getattr(e, "IBV_WC_RETRY_EXC_ERR", -12): "IBV_WC_RETRY_EXC_ERR", + getattr(e, "IBV_WC_RNR_RETRY_EXC_ERR", -13): "IBV_WC_RNR_RETRY_EXC_ERR", + getattr(e, "IBV_WC_REM_ABORT_ERR", -14): "IBV_WC_REM_ABORT_ERR", + } + return status_map.get(status, f"IBV_WC_STATUS_{status}") + + def _resolve_gid_index(self): + return resolve_gid_index(self.ctx, self.port_num) + + def _alloc_local_psn(self): + self._next_psn = (self._next_psn + 1) & 0xFFFFFF + if self._next_psn == 0: + self._next_psn = 1 + self.local_psn = self._next_psn + return self.local_psn + + def _reset_qp(self): + old_qp = getattr(self, "qp", None) + self.qp = QP(self.pd, self._qia, self._qa) + if old_qp is not None: + close_fn = getattr(old_qp, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception: + pass def _ensure_local_mr_capacity(self, required_size: int): required = int(required_size) @@ -76,52 +129,136 @@ def _ensure_local_mr_capacity(self, required_size: int): self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) def connect_to_server(self, server_ip="127.0.0.1", port=5566): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - 
sock.connect((server_ip, port)) - except Exception: - sock.close() - raise - - # 1. 接收 Server 信息 (包含 rkey 和 addr) - data = sock.recv(4096) - self.remote_info = json.loads(data.decode()) - print(f"[Client] Got Server Info: Addr={hex(self.remote_info['addr'])}, RKey={self.remote_info['rkey']}") - - # 2. 发送我的信息给 Server - gid = self.ctx.query_gid(self.port_num, self.gid_index) - my_info = { - "lid": self.ctx.query_port(self.port_num).lid, - "qpn": self.qp.qp_num, - "psn": self.local_psn, - "gid": str(gid), - "gid_index": self.gid_index, - } - sock.sendall(json.dumps(my_info).encode()) - - # 3. 修改 QP 状态 - self._modify_qp_to_rts() - self.sock = sock - print("[Client] Connection established (RTS)") + max_retries = max(1, int(os.getenv("RDMA_CLIENT_CONNECT_RETRIES", "30"))) + connect_timeout_sec = float(os.getenv("RDMA_CLIENT_CONNECT_TIMEOUT_SEC", "2.0")) + backoff_base_sec = float(os.getenv("RDMA_CLIENT_BACKOFF_BASE_SEC", "0.1")) + backoff_max_sec = float(os.getenv("RDMA_CLIENT_BACKOFF_MAX_SEC", "2.0")) + jitter_ratio = float(os.getenv("RDMA_CLIENT_BACKOFF_JITTER", "0.2")) + + last_exc = None + for attempt in range(1, max_retries + 1): + sock = None + try: + old_sock = getattr(self, "sock", None) + if old_sock is not None: + try: + old_sock.close() + except Exception: + pass + self.sock = None + + self._reset_qp() + self._alloc_local_psn() + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(connect_timeout_sec) + sock.connect((server_ip, port)) + + # 1. 
接收 Server 信息 (包含 rkey 和 addr) + remote_info = recv_json_from_stream(sock, timeout_sec=connect_timeout_sec) + if not isinstance(remote_info, dict): + raise RuntimeError(f"Invalid handshake payload type: {type(remote_info)}") + required_keys = {"addr", "rkey", "qpn", "psn", "gid"} + missing = required_keys.difference(remote_info.keys()) + if missing: + raise RuntimeError(f"Handshake missing keys: {sorted(missing)}") + self.remote_info = remote_info + print(f"[Client] Got Server Info: Addr={hex(int(self.remote_info['addr']))}, RKey={self.remote_info['rkey']}") + + # 2. 发送我的信息给 Server + gid = self.ctx.query_gid(port_num=self.port_num, index=self.gid_index) + my_info = { + "lid": self.ctx.query_port(port_num=self.port_num).lid, + "qpn": self.qp.qp_num, + "psn": self.local_psn, + "gid": str(gid), + "gid_index": self.gid_index, + "active_mtu": int(rtr_path_mtu(self.ctx, self.port_num)), + } + sock.sendall(json.dumps(my_info).encode()) + + # 3. 修改 QP 状态 + self._modify_qp_to_rts() + sock.settimeout(None) + self.sock = sock + self._connected_server_ip = str(server_ip) + self._connected_server_port = int(port) + self._qp_error_state = False + self._last_wc_error_message = "" + print(f"[Client] Connection established (RTS) to {server_ip}:{port} at attempt {attempt}/{max_retries}") + return + except Exception as exc: + last_exc = exc + if sock is not None: + try: + sock.close() + except Exception: + pass + + if attempt < max_retries: + backoff = min(backoff_max_sec, backoff_base_sec * (2 ** (attempt - 1))) + if jitter_ratio > 0: + jitter = random.uniform(1.0 - jitter_ratio, 1.0 + jitter_ratio) + backoff = max(0.01, backoff * jitter) + print(f"[Client] Handshake attempt {attempt}/{max_retries} failed to {server_ip}:{port}: {exc}. Retrying in {backoff:.2f}s") + time.sleep(backoff) + + raise RuntimeError(f"RDMA client failed to connect to {server_ip}:{port} after {max_retries} attempts") from last_exc def _modify_qp_to_rts(self): # Follow the standard RC flow: INIT -> RTR -> RTS. 
- init_attr = QPAttr(port_num=self.port_num) - init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC - self.qp.to_init(init_attr) - - rtr_attr = QPAttr(port_num=self.port_num) - rtr_attr.path_mtu = e.IBV_MTU_1024 - rtr_attr.max_dest_rd_atomic = 1 - rtr_attr.min_rnr_timer = 12 - rtr_attr.dest_qp_num = int(self.remote_info["qpn"]) - rtr_attr.rq_psn = int(self.remote_info["psn"]) - remote_lid = int(self.remote_info.get("lid", 0)) - remote_gid_index = int(self.remote_info.get("gid_index", self.gid_index)) - gr = GlobalRoute(dgid=GID(self.remote_info["gid"]), sgid_index=remote_gid_index) - rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=remote_lid) - self.qp.to_rtr(rtr_attr) + heuristic_dlid = rtr_ah_dest_dlid(self.ctx, self.port_num, remote_lid) + negotiated_mtu = int(rtr_path_mtu_negotiated(self.ctx, self.port_num, self.remote_info.get("active_mtu"))) + local_mtu = int(rtr_path_mtu(self.ctx, self.port_num)) + default_mtu = int(e.IBV_MTU_1024) + + # Some eRDMA/RoCE stacks are strict about dlid/mtu combinations; try safe fallbacks. 
+ mtu_candidates = [] + for v in (negotiated_mtu, local_mtu, default_mtu): + if v not in mtu_candidates: + mtu_candidates.append(v) + dlid_candidates = [] + for v in (heuristic_dlid, 0, remote_lid): + if v not in dlid_candidates: + dlid_candidates.append(v) + + gr = GlobalRoute(dgid=GID(self.remote_info["gid"]), sgid_index=self.gid_index, hop_limit=1) + last_exc = None + for rd_atomic in (1, 0): + for mtu in mtu_candidates: + for dlid in dlid_candidates: + for is_global in (1, 0): + try: + init_attr = QPAttr(port_num=self.port_num) + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC + self.qp.to_init(init_attr) + + rtr_attr = QPAttr(port_num=self.port_num) + rtr_attr.path_mtu = int(mtu) + rtr_attr.max_dest_rd_atomic = int(rd_atomic) + rtr_attr.min_rnr_timer = 12 + rtr_attr.dest_qp_num = int(self.remote_info["qpn"]) + rtr_attr.rq_psn = int(self.remote_info["psn"]) + # Some drivers require GRH(is_global=1), others only accept non-GRH. 
+ if is_global == 1: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=int(dlid)) + else: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=0, dlid=int(dlid)) + self.qp.to_rtr(rtr_attr) + last_exc = None + break + except Exception as exc: + last_exc = exc + continue + if last_exc is None: + break + if last_exc is None: + break + if last_exc is None: + break + if last_exc is not None: + raise last_exc rts_attr = QPAttr(port_num=self.port_num) rts_attr.timeout = 14 @@ -196,6 +333,10 @@ def rdma_write_to(self, remote_addr, data_bytes, rkey=None): self.remote_info["rkey"] = int(rkey) try: self.rdma_write(data_bytes, notify_server=False) + except Exception as exc: + raise RuntimeError( + f"rdma_write_to failed server={self._connected_server_ip}:{self._connected_server_port} remote_addr={int(remote_addr)} length={len(data_bytes)} rkey={self.remote_info.get('rkey')}" + ) from exc finally: self.remote_info["addr"] = old_addr self.remote_info["rkey"] = old_rkey @@ -210,6 +351,10 @@ def rdma_read_from(self, remote_addr, length, rkey=None): self.remote_info["rkey"] = int(rkey) try: return self.rdma_read(int(length)) + except Exception as exc: + raise RuntimeError( + f"rdma_read_from failed server={self._connected_server_ip}:{self._connected_server_port} remote_addr={int(remote_addr)} length={int(length)} rkey={self.remote_info.get('rkey')}" + ) from exc finally: self.remote_info["addr"] = old_addr self.remote_info["rkey"] = old_rkey @@ -239,7 +384,7 @@ def rdma_faa(self, remote_addr, add_value, rkey=None): self._poll_cq() old = self.local_mr.read(8, 0) - old_v = int.from_bytes(old, byteorder="big", signed=False) + old_v = int.from_bytes(old, byteorder="little", signed=False) return old_v def rdma_cas(self, remote_addr, compare_value, swap_value, rkey=None): @@ -247,6 +392,7 @@ def rdma_cas(self, remote_addr, compare_value, swap_value, rkey=None): with self._io_lock: self._ensure_local_mr_capacity(8) + # The original remote value will be 
written into this local buffer. self.local_mr.write(b"\x00" * 8, 8, 0) sge = SGE(self.local_mr.buf, 8, self.local_mr.lkey) @@ -267,7 +413,7 @@ def rdma_cas(self, remote_addr, compare_value, swap_value, rkey=None): self._poll_cq() old = self.local_mr.read(8, 0) - old_v = int.from_bytes(old, byteorder="big", signed=False) + old_v = int.from_bytes(old, byteorder="little", signed=False) return old_v def _poll_cq(self): @@ -284,7 +430,24 @@ def _poll_cq(self): raise RuntimeError(f"Unexpected WC object: {wc}") if status != e.IBV_WC_SUCCESS: vendor_err = getattr(wc, "vendor_err", None) - raise Exception(f"WC Error: {status}, vendor_err: {vendor_err}") + wr_id = getattr(wc, "wr_id", None) + opcode = getattr(wc, "opcode", None) + status_name = self._wc_status_name(status) + self._qp_error_state = True + self._last_wc_error_message = ( + f"status={status}({status_name}) vendor_err={vendor_err} wr_id={wr_id} opcode={opcode} server={self._connected_server_ip}:{self._connected_server_port}" + ) + logger.error( + "RDMA CQ failure: status=%s(%s) vendor_err=%s wr_id=%s opcode=%s server=%s:%s", + status, + status_name, + vendor_err, + wr_id, + opcode, + self._connected_server_ip, + self._connected_server_port, + ) + raise RuntimeError(f"WC Error: {status}({status_name}), vendor_err: {vendor_err}, wr_id: {wr_id}, opcode: {opcode}") break time.sleep(0.0001) diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py index 8f71889a6..b431665eb 100644 --- a/lightx2v/disagg/rdma_server.py +++ b/lightx2v/disagg/rdma_server.py @@ -1,37 +1,30 @@ import json +import os import socket import threading -import pyverbs.enums as e -from pyverbs.addr import GID, AHAttr, GlobalRoute -from pyverbs.cq import CQ -from pyverbs.device import Context, get_device_list -from pyverbs.mr import MR -from pyverbs.pd import PD -from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr - - -class IBDevice: - def __init__(self, name: str): - self.name = name - - def open(self): - return 
Context(name=self.name) - - -class QPType: - RC = e.IBV_QPT_RC - - -class WROpcode: - RDMA_WRITE = e.IBV_WR_RDMA_WRITE - - -class AccessFlag: - LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE - REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE - REMOTE_READ = e.IBV_ACCESS_REMOTE_READ - REMOTE_ATOMIC = e.IBV_ACCESS_REMOTE_ATOMIC +from lightx2v.disagg.rdma_base import ( + CQ, + GID, + MR, + PD, + QP, + AHAttr, + AccessFlag, + GlobalRoute, + IBDevice, + QPAttr, + QPCap, + QPInitAttr, + QPType, + e, + get_device_list, + recv_json_from_stream, + resolve_gid_index, + rtr_ah_dest_dlid, + rtr_path_mtu, + rtr_path_mtu_negotiated, +) class RDMAServer: @@ -39,10 +32,14 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): self.local_psn = 123456 self._next_psn = int(self.local_psn) self.port_num = port_num - self.gid_index = 1 + self.gid_index = 0 self.buffer_size = int(buffer_size) if self.buffer_size <= 0: raise ValueError("buffer_size must be positive") + if iface_name is None: + env_iface = os.getenv("RDMA_IFACE", "").strip() + if env_iface: + iface_name = env_iface if iface_name is None: devices = get_device_list() if not devices: @@ -59,10 +56,11 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): raise RuntimeError(f"Failed to open RDMA device '{iface_name}'. Available devices: {available}") self.pd = PD(self.ctx) - self.cq = CQ(self.ctx, 10) + self.cq = CQ(self.ctx, 64) + self.gid_index = self._resolve_gid_index() # 创建 QP (Queue Pair) - qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) + qp_init_attr = QPCap(max_send_wr=64, max_recv_wr=64, max_send_sge=1, max_recv_sge=1) qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) qa = QPAttr(port_num=self.port_num) self.qp = QP(self.pd, qia, qa) # RC: Reliable Connected @@ -91,6 +89,9 @@ def __init__(self, iface_name=None, port_num=1, buffer_size=4096): self._mr_addr = int(mr_addr) print(f"[Server] MR Registered. 
Addr: {mr_addr}, RKey: {self.mr.rkey}") + def _resolve_gid_index(self): + return resolve_gid_index(self.ctx, self.port_num) + def register_memory(self, addr: int, length: int): """Validate a requested sub-region against server MR and return registration metadata. @@ -131,7 +132,27 @@ def get_local_info(self, qp=None, psn=None): qp = self.qp if qp is None else qp psn = self.local_psn if psn is None else int(psn) gid = self.ctx.query_gid(self.port_num, self.gid_index) - return {"lid": self.ctx.query_port(self.port_num).lid, "qpn": qp.qp_num, "psn": psn, "gid": str(gid), "gid_index": self.gid_index, "rkey": self.mr.rkey, "addr": mr_addr} + return { + "lid": self.ctx.query_port(self.port_num).lid, + "qpn": qp.qp_num, + "psn": psn, + "gid": str(gid), + "gid_index": self.gid_index, + "rkey": self.mr.rkey, + "addr": mr_addr, + "active_mtu": int(rtr_path_mtu(self.ctx, self.port_num)), + } + + @staticmethod + def _safe_destroy_qp(qp): + if qp is None: + return + close_fn = getattr(qp, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception: + pass def _alloc_qp_with_psn(self): with self._conn_lock: @@ -147,18 +168,24 @@ def _accept_one_client(self, listen_sock): print(f"[Server] Connected to {addr}") qp, local_psn = self._alloc_qp_with_psn() - - # 1. 发送我的信息给 Client - my_info = self.get_local_info(qp=qp, psn=local_psn) - conn.sendall(json.dumps(my_info).encode()) - - # 2. 接收 Client 的信息 - data = conn.recv(4096) - remote_info = json.loads(data.decode()) - print(f"[Server] Received remote info: QPN={remote_info['qpn']}") - - # 3. 修改 QP 状态到 RTS - self._modify_qp_to_rts(qp, remote_info, local_psn) + try: + # 1. 发送我的信息给 Client + my_info = self.get_local_info(qp=qp, psn=local_psn) + conn.sendall(json.dumps(my_info).encode()) + + # 2. 接收 Client 的信息(可能分片,勿单次 recv) + remote_info = recv_json_from_stream(conn, timeout_sec=30.0) + print(f"[Server] Received remote info: QPN={remote_info['qpn']}") + + # 3. 
修改 QP 状态到 RTS + self._modify_qp_to_rts(qp, remote_info, local_psn) + except BaseException: + self._safe_destroy_qp(qp) + try: + conn.close() + except Exception: + pass + raise with self._conn_lock: self._active_qps.append(qp) @@ -189,22 +216,56 @@ def handshake(self, host="0.0.0.0", port=5566, serve_forever=True): def _modify_qp_to_rts(self, qp, remote_info, local_psn): # Follow the standard RC flow: INIT -> RTR -> RTS. - init_attr = QPAttr(port_num=self.port_num) - init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC - qp.to_init(init_attr) - - rtr_attr = QPAttr(port_num=self.port_num) - rtr_attr.path_mtu = e.IBV_MTU_1024 - rtr_attr.max_dest_rd_atomic = 1 - rtr_attr.min_rnr_timer = 12 - rtr_attr.dest_qp_num = int(remote_info["qpn"]) - rtr_attr.rq_psn = int(remote_info["psn"]) - remote_lid = int(remote_info.get("lid", 0)) - remote_gid_index = int(remote_info.get("gid_index", self.gid_index)) - gr = GlobalRoute(dgid=GID(remote_info["gid"]), sgid_index=remote_gid_index) - rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=remote_lid) - qp.to_rtr(rtr_attr) + heuristic_dlid = rtr_ah_dest_dlid(self.ctx, self.port_num, remote_lid) + negotiated_mtu = int(rtr_path_mtu_negotiated(self.ctx, self.port_num, remote_info.get("active_mtu"))) + local_mtu = int(rtr_path_mtu(self.ctx, self.port_num)) + default_mtu = int(e.IBV_MTU_1024) + + mtu_candidates = [] + for v in (negotiated_mtu, local_mtu, default_mtu): + if v not in mtu_candidates: + mtu_candidates.append(v) + dlid_candidates = [] + for v in (heuristic_dlid, 0, remote_lid): + if v not in dlid_candidates: + dlid_candidates.append(v) + + gr = GlobalRoute(dgid=GID(remote_info["gid"]), sgid_index=self.gid_index, hop_limit=1) + last_exc = None + for rd_atomic in (1, 0): + for mtu in mtu_candidates: + for dlid in dlid_candidates: + for is_global in (1, 0): + try: + init_attr = QPAttr(port_num=self.port_num) + 
init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ | AccessFlag.REMOTE_ATOMIC + qp.to_init(init_attr) + + rtr_attr = QPAttr(port_num=self.port_num) + rtr_attr.path_mtu = int(mtu) + rtr_attr.max_dest_rd_atomic = int(rd_atomic) + rtr_attr.min_rnr_timer = 12 + rtr_attr.dest_qp_num = int(remote_info["qpn"]) + rtr_attr.rq_psn = int(remote_info["psn"]) + if is_global == 1: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=int(dlid)) + else: + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=0, dlid=int(dlid)) + qp.to_rtr(rtr_attr) + last_exc = None + break + except Exception as exc: + last_exc = exc + continue + if last_exc is None: + break + if last_exc is None: + break + if last_exc is None: + break + if last_exc is not None: + raise last_exc rts_attr = QPAttr(port_num=self.port_num) rts_attr.timeout = 14 diff --git a/lightx2v/disagg/rdma_utils.py b/lightx2v/disagg/rdma_utils.py new file mode 100644 index 000000000..bf0df1631 --- /dev/null +++ b/lightx2v/disagg/rdma_utils.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import ipaddress +import json +import logging +import os +import socket +import time + +logger = logging.getLogger(__name__) + + +def _collect_local_ipv4_addresses() -> list[str]: + candidates: list[str] = [] + + try: + hostname = socket.gethostname() + for info in socket.getaddrinfo(hostname, None, socket.AF_INET): + address = info[4][0] + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + except Exception: + pass + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + sock.connect(("8.8.8.8", 80)) + address = sock.getsockname()[0] + if address and not address.startswith("127.") and address not in candidates: + candidates.append(address) + finally: + sock.close() + except Exception: + pass + + try: + address = socket.gethostbyname(socket.gethostname()) + if address and not 
address.startswith("127.") and address not in candidates: + candidates.append(address) + except Exception: + pass + + return candidates + + +def _gid_to_ipv4(gid_text: str) -> str | None: + """Map an IPv6 GID string to IPv4 when it is IPv4-mapped.""" + text = str(gid_text).strip() + if not text or text == "::": + return None + lower = text.lower() + if lower.startswith("::ffff:"): + return _canonical_ipv4(text[7:]) + try: + mapped = ipaddress.ip_address(text).ipv4_mapped + if mapped is not None: + return str(mapped) + except ValueError: + pass + return None + + +def _canonical_ipv4(text: str) -> str | None: + text = str(text).strip() + if not text: + return None + try: + return str(ipaddress.IPv4Address(text)) + except Exception: + return None + + +def _preferred_rdma_ipv4() -> str | None: + """RoCE GID row to prefer when auto-picking gid_index (multi-node / multi-homing).""" + v = _canonical_ipv4(os.getenv("RDMA_PREFERRED_IPV4", "")) + if v: + return v + return _canonical_ipv4(os.getenv("MOONCAKE_LOCAL_HOSTNAME", "")) + + +def resolve_gid_index(ctx, port_num: int, env_var_name: str = "RDMA_GID_INDEX") -> int: + local_ipv4s = _collect_local_ipv4_addresses() + preferred = _preferred_rdma_ipv4() + + env_gid = os.getenv(env_var_name, "").strip() + if env_gid: + try: + idx = int(env_gid) + except ValueError: + idx = -1 + else: + try: + gid_text = str(ctx.query_gid(port_num=port_num, index=idx)) + except Exception: + gid_text = "" + else: + if gid_text and gid_text != "::": + ipv4 = _gid_to_ipv4(gid_text) + if ipv4 is not None and (ipv4 in local_ipv4s or ipv4 == preferred): + return idx + + try: + logger.warning( + "Ignoring RDMA_GID_INDEX=%s because it does not map to a local IPv4 on this host (local_ipv4s=%s preferred=%s)", + env_gid, + local_ipv4s, + preferred, + ) + except Exception: + pass + + if preferred: + for idx in range(16): + try: + gid_text = str(ctx.query_gid(port_num=port_num, index=idx)) + except Exception: + continue + if not gid_text or gid_text == 
"::": + continue + ipv4 = _gid_to_ipv4(gid_text) + if ipv4 == preferred: + return idx + + mapped_candidates: list[tuple[int, str]] = [] + first_non_empty_idx: int | None = None + + for idx in range(16): + try: + gid_text = str(ctx.query_gid(port_num=port_num, index=idx)) + except Exception: + continue + + if not gid_text or gid_text == "::": + continue + + if first_non_empty_idx is None: + first_non_empty_idx = idx + + ipv4 = _gid_to_ipv4(gid_text) + if ipv4 is not None: + mapped_candidates.append((idx, ipv4)) + if ipv4 in local_ipv4s: + return idx + + if mapped_candidates: + return mapped_candidates[0][0] + + if first_non_empty_idx is not None: + return first_non_empty_idx + + ctx.query_gid(port_num=port_num, index=0) + return 0 + + +def recv_json_from_stream(sock: socket.socket, timeout_sec: float = 10.0) -> dict: + """Read one JSON object from a TCP stream (handles split packets).""" + decoder = json.JSONDecoder() + chunks: list[bytes] = [] + deadline = time.time() + float(timeout_sec) + while time.time() < deadline: + try: + sock.settimeout(max(0.01, deadline - time.time())) + chunk = sock.recv(65536) + except socket.timeout: + continue + if not chunk: + break + chunks.append(chunk) + payload = b"".join(chunks).decode("utf-8", errors="strict") + try: + obj, _ = decoder.raw_decode(payload) + if isinstance(obj, dict): + return obj + except json.JSONDecodeError: + continue + msg = b"".join(chunks).decode("utf-8", errors="ignore") + raise RuntimeError(f"Incomplete handshake JSON from peer: {msg!r}") + + +def rtr_ah_dest_dlid(ctx, port_num: int, remote_lid: int) -> int: + """Destination LID for RC QP RTR when using GRH. + + RoCE (Ethernet link layer) expects dlid 0 with a valid dgid in the GRH; some + drivers still report a non-zero port LID, and using that in AHAttr triggers + ibv_modify_qp RTR EINVAL on many setups. 
+ """ + rl = int(remote_lid) + raw = os.getenv("RDMA_RTR_DLID", "").strip().lower() + if raw in ("0", "zero", "roce", "eth"): + return 0 + if raw in ("peer", "remote", "ib", "infiniband"): + return rl + try: + port = ctx.query_port(port_num) + ll = int(getattr(port, "link_layer", -1)) + local_lid = int(getattr(port, "lid", -1)) + # rdma-core ibv_port_attr.link_layer: 0 unspecified, 1 InfiniBand, 2 Ethernet (RoCE). + if ll == 2: + return 0 + if ll == 1: + return rl + # Unspecified / unknown link_layer: eRDMA and some stacks omit or mis-report; + # RoCE uses LID 0 — using a non-zero dlid here causes RTR EINVAL. + if local_lid == 0: + return 0 + except Exception: + pass + if rl == 0: + return 0 + return rl + + +def rtr_path_mtu(ctx, port_num: int) -> int: + """Use port active MTU for RTR path_mtu (avoids hard-coded 1024 vs link mismatch).""" + try: + port = ctx.query_port(port_num) + return int(port.active_mtu) + except Exception: + import pyverbs.enums as e + + return int(e.IBV_MTU_1024) + + +def rtr_path_mtu_negotiated(ctx, port_num: int, peer_active_mtu: int | None) -> int: + """path_mtu for RTR must not exceed either peer's active MTU (IB enum ordering).""" + local = rtr_path_mtu(ctx, port_num) + if peer_active_mtu is None: + return local + try: + peer = int(peer_active_mtu) + except (TypeError, ValueError): + return local + return min(local, peer) diff --git a/lightx2v/disagg/services/base.py b/lightx2v/disagg/services/base.py index ca2df2b0f..3fc66b220 100644 --- a/lightx2v/disagg/services/base.py +++ b/lightx2v/disagg/services/base.py @@ -1,9 +1,61 @@ -import logging +import sys from abc import ABC -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +from loguru import logger as loguru_logger + +loguru_logger.remove() +loguru_logger.add( + sys.stderr, + level="INFO", + format="[{level}] {time:DD MMM YYYY HH:mm:ss} | {name}:{function}:{line} - {message}", +) + + +class _LoguruLoggerAdapter: + def __init__(self, 
logger): + self._logger = logger + + @staticmethod + def _format_message(message, args): + if not args: + return message + try: + return message % args + except Exception: + return f"{message} {' '.join(map(str, args))}" + + def debug(self, message, *args, **kwargs): + self._logger.opt(depth=1).debug(self._format_message(message, args)) + + def info(self, message, *args, **kwargs): + self._logger.opt(depth=1).info(self._format_message(message, args)) + + def warning(self, message, *args, **kwargs): + self._logger.opt(depth=1).warning(self._format_message(message, args)) + + def error(self, message, *args, **kwargs): + self._logger.opt(depth=1).error(self._format_message(message, args)) + + def critical(self, message, *args, **kwargs): + self._logger.opt(depth=1).critical(self._format_message(message, args)) + + def exception(self, message, *args, **kwargs): + self._logger.opt(depth=1, exception=True).error(self._format_message(message, args)) + + def log(self, level, message, *args, **kwargs): + self._logger.opt(depth=1).log(level, self._format_message(message, args)) + + def bind(self, **kwargs): + return _LoguruLoggerAdapter(self._logger.bind(**kwargs)) + + def opt(self, *args, **kwargs): + return self._logger.opt(*args, **kwargs) + + def __getattr__(self, item): + return getattr(self._logger, item) + + +logger = _LoguruLoggerAdapter(loguru_logger) class BaseService(ABC): @@ -12,4 +64,17 @@ def __init__(self): Base initialization for all services. 
""" self.logger = logger - self.logger.info(f"Initializing {self.__class__.__name__}") + self.logger.info("Initializing %s", self.__class__.__name__) + + def _sync_runtime_config(self, config): + current_config = getattr(self, "config", None) + if current_config is None: + self.config = dict(config) + return self.config + + if current_config is not config: + current_config.clear() + current_config.update(config) + + self.config = current_config + return self.config diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 189aa3904..403a51e42 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,6 +1,20 @@ +import ipaddress +import json +import os +import shlex +import shutil +import signal +import socket +import subprocess +import sys import time +from collections.abc import Mapping +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from threading import Event, Lock, Thread +from typing import Any + +import zmq from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, ReqManager from lightx2v.disagg.monitor import Monitor @@ -28,6 +42,1544 @@ def __init__(self): self._rdma_handshake_thread_request: Thread | None = None self._rdma_handshake_thread_phase1: Thread | None = None self._rdma_handshake_thread_phase2: Thread | None = None + self._instance_lock = Lock() + self._free_gpus: set[int] = set() + self._managed_instances: dict[str, dict[str, Any]] = {} + self.started_instances: list[tuple[str, str]] = [] + self._runtime_config: dict[str, Any] | None = None + self._bootstrap_addr: str = "127.0.0.1" + self._gpu_reuse_block_until: dict[int, float] = {} + self._gpu_reuse_grace_seconds: float = 5.0 + self._graceful_reclaim_timeout_seconds: float = float(os.getenv("DISAGG_RECLAIM_GRACEFUL_TIMEOUT_SECONDS", "30.0")) + self._force_kill_wait_seconds: float = float(os.getenv("DISAGG_RECLAIM_FORCE_KILL_WAIT_SECONDS", "1.0")) 
+ self._instance_start_timeout_seconds: float = float(os.getenv("DISAGG_INSTANCE_START_TIMEOUT_SECONDS", "90.0")) + self._sidecar_start_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_START_TIMEOUT_SECONDS", "15.0")) + self._sidecar_drain_idle_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_IDLE_SECONDS", "1.0")) + # <= 0 means wait indefinitely until sidecar pending queues are drained. + self._sidecar_drain_timeout_seconds: float = float(os.getenv("DISAGG_SIDECAR_DRAIN_TIMEOUT_SECONDS", "0")) + self._remote_proxy_start_timeout_seconds: float = float(os.getenv("DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS", "20.0")) + self._sidecar_reclaim_threads: list[Thread] = [] + self._shutting_down: bool = False + self._enable_monitor: bool = False + self._static_instance_slots: list[dict[str, Any]] = [] + self._free_slot_ids: set[int] = set() + self._slot_reuse_block_until: dict[int, float] = {} + self._local_host_aliases: set[str] = set() + self._request_metrics_by_room: dict[int, dict[str, Any]] = {} + + def _is_monitor_enabled(self) -> bool: + raw = os.getenv("ENABLE_MONITOR") + if raw is None: + return False + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + + def _is_centralized_enabled(self) -> bool: + raw = os.getenv("IS_CENTRALIZED") + if raw is None: + return False + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + + def _is_tcp_port_open(self, host: str, port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + return sock.connect_ex((host, port)) == 0 + + def _wait_for_tcp_port_state(self, host: str, port: int, should_be_open: bool, timeout_seconds: float) -> bool: + deadline = time.time() + timeout_seconds + while time.time() < deadline: + is_open = self._is_tcp_port_open(host, port) + if is_open == should_be_open: + return True + time.sleep(0.1) + return self._is_tcp_port_open(host, port) == should_be_open + + def _refresh_local_host_aliases(self): + aliases: set[str] 
= { + "127.0.0.1", + "localhost", + str(self._bootstrap_addr), + } + try: + hostname = socket.gethostname() + aliases.add(hostname) + aliases.add(socket.getfqdn()) + host_info = socket.gethostbyname_ex(hostname) + aliases.update(host_info[1]) + aliases.update(host_info[2]) + except Exception: + pass + self._local_host_aliases = {item.strip() for item in aliases if isinstance(item, str) and item.strip()} + + def _is_local_host(self, host: str) -> bool: + normalized = str(host).strip() + if not normalized: + return False + if normalized in self._local_host_aliases: + return True + try: + return socket.gethostbyname(normalized) in self._local_host_aliases + except Exception: + return False + + def _ensure_rdma_preferred_ipv4_env(self, host: str, env: dict[str, str]) -> None: + """So RoCE gid_index matches the data-plane IP on each worker (multi-node).""" + if env.get("RDMA_PREFERRED_IPV4"): + return + h = str(host).strip() + if not h: + return + try: + env["RDMA_PREFERRED_IPV4"] = str(ipaddress.IPv4Address(h)) + except Exception: + pass + + def _allocate_free_tcp_port(self, bind_host: str | None = None) -> int: + host = str(bind_host or self._bootstrap_addr) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((host, 0)) + return int(sock.getsockname()[1]) + + def _build_service_command(self, instance_type: str, engine_rank: int, instance_cfg: dict[str, Any], service_config_json: str) -> list[str]: + return [ + sys.executable, + "-m", + "lightx2v.disagg.examples.run_service", + "--service", + instance_type, + "--engine_rank", + str(engine_rank), + "--model_cls", + str(instance_cfg.get("model_cls", "wan2.1")), + "--task", + str(instance_cfg.get("task", "t2v")), + "--model_path", + str(instance_cfg.get("model_path")), + "--config_json", + service_config_json, + "--seed", + str(instance_cfg.get("seed", 42)), + "--prompt", + str(instance_cfg.get("prompt", "")), + "--negative_prompt", + str(instance_cfg.get("negative_prompt", "")), + 
"--save_result_path", + str(instance_cfg.get("save_path", "")), + ] + + def _maybe_wrap_service_command_with_nsys( + self, + *, + host: str, + instance_type: str, + engine_rank: int, + instance_cfg: dict[str, Any], + command: list[str], + ) -> list[str]: + if not self._is_truthy(os.getenv("DISAGG_ENABLE_NSYS"), default=False): + return command + + if not self._is_local_host(host): + self.logger.info( + "Skip nsys profiling for remote %s instance host=%s rank=%s", + instance_type, + host, + engine_rank, + ) + return command + + nsys_bin = shutil.which(os.getenv("DISAGG_NSYS_BIN", "nsys")) + if nsys_bin is None: + self.logger.warning("DISAGG_ENABLE_NSYS is set but nsys is not available, skip profiling for %s rank=%s", instance_type, engine_rank) + return command + + output_dir_raw = os.getenv("DISAGG_NSYS_OUTPUT_DIR") + if output_dir_raw: + output_dir = Path(output_dir_raw) + else: + base_save_path = instance_cfg.get("save_path") or (self._runtime_config or {}).get("save_path") or str(Path(__file__).resolve().parents[3] / "save_results" / "wan22_i2v_dynamic.mp4") + output_dir = Path(str(base_save_path)).parent / "nsys" + output_dir.mkdir(parents=True, exist_ok=True) + + output_name = f"{instance_type}_rank{engine_rank}" + trace = os.getenv("DISAGG_NSYS_TRACE", "cuda,nvtx,osrt") + extra_args = shlex.split(os.getenv("DISAGG_NSYS_EXTRA_ARGS", "")) + + profiled_command = [ + nsys_bin, + "profile", + "--force-overwrite=true", + "--trace", + trace, + "-o", + str(output_dir / output_name), + ] + profiled_command.extend(extra_args) + profiled_command.extend(command) + return profiled_command + + def _merge_request_metrics(self, existing: dict[str, Any] | None, update: dict[str, Any] | None) -> dict[str, Any]: + merged: dict[str, Any] = {} + if isinstance(existing, dict): + merged.update(existing) + if not isinstance(update, dict): + return merged + + for key, value in update.items(): + if key != "stages" or not isinstance(value, dict): + merged[key] = value + continue + + 
merged_stages: dict[str, Any] = {} + existing_stages = merged.get("stages") + if isinstance(existing_stages, dict): + for stage_name, stage_metrics in existing_stages.items(): + merged_stages[stage_name] = dict(stage_metrics) if isinstance(stage_metrics, dict) else stage_metrics + + for stage_name, stage_metrics in value.items(): + if not isinstance(stage_metrics, dict): + continue + base_stage_metrics = merged_stages.get(stage_name) + if isinstance(base_stage_metrics, dict): + combined_stage_metrics = dict(base_stage_metrics) + combined_stage_metrics.update(stage_metrics) + else: + combined_stage_metrics = dict(stage_metrics) + merged_stages[stage_name] = combined_stage_metrics + + merged["stages"] = merged_stages + + return merged + + def _query_zmq(self, req_addr: str, payload: dict[str, Any], timeout_ms: int = 1000) -> dict[str, Any] | None: + context = zmq.Context() + req = context.socket(zmq.REQ) + req.setsockopt(zmq.RCVTIMEO, int(timeout_ms)) + req.setsockopt(zmq.SNDTIMEO, int(timeout_ms)) + req.connect(req_addr) + try: + req.send_pyobj(payload) + reply = req.recv_pyobj() + if isinstance(reply, dict): + return reply + return None + except Exception: + return None + finally: + req.close(0) + context.term() + + def _query_sidecar(self, req_addr: str, cmd: str) -> dict[str, Any] | None: + return self._query_zmq(req_addr, {"cmd": str(cmd)}, timeout_ms=1000) + + def _run_centralized_ok_server(self, stop_event: Event, bind_host: str, bind_port: int): + controller = self + + class _Handler(BaseHTTPRequestHandler): + def do_POST(self): + if self.path != "/ok": + self.send_response(404) + self.end_headers() + return + + content_length = int(self.headers.get("Content-Length", "0") or "0") + raw_body = self.rfile.read(content_length) if content_length > 0 else b"" + try: + message = json.loads(raw_body.decode("utf-8")) if raw_body else {} + except Exception: + self.send_response(400) + self.end_headers() + return + + controller.logger.info( + "Received centralized OK 
control message: stage=%s room=%s", + message.get("stage_name"), + message.get("data_bootstrap_room"), + ) + response = json.dumps({"ok": True, "control": "OK"}).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(response))) + self.end_headers() + self.wfile.write(response) + + def log_message(self, format, *args): + return + + server = ThreadingHTTPServer((bind_host, bind_port), _Handler) + server.timeout = 0.2 + try: + while not stop_event.is_set(): + server.handle_request() + finally: + server.server_close() + + def _is_truthy(self, value: Any, default: bool = False) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + def _remote_proxy_req_addr(self, slot: dict[str, Any]) -> str: + host = str(slot["host"]) + proxy_req_port = int(slot["proxy_req_port"]) + return f"tcp://{host}:{proxy_req_port}" + + def _ensure_remote_instance_proxy(self, slot: dict[str, Any]): + if not self._is_truthy(slot.get("use_remote_proxy", False)): + return + + req_addr = self._remote_proxy_req_addr(slot) + reply = self._query_zmq(req_addr, {"cmd": "ping"}, timeout_ms=800) + if isinstance(reply, dict) and reply.get("ok", False): + return + + python_executable = str(slot.get("python_executable", sys.executable)) + workdir = str(slot.get("workdir", Path(__file__).resolve().parents[3])) + log_dir = str(slot.get("log_dir", "/tmp/lightx2v_disagg")) + activate_cmd = str(slot.get("activate_cmd", "")).strip() + proxy_req_port = int(slot["proxy_req_port"]) + proxy_log_path = str(slot.get("proxy_log_path", f"{log_dir}/instance_proxy.log")) + + script_lines = [ + "set -e", + f"mkdir -p {shlex.quote(log_dir)}", + f"cd {shlex.quote(workdir)}", + ] + if activate_cmd: + script_lines.append(activate_cmd) + script_lines.extend( + [ + ( + "nohup env PYTHONUNBUFFERED=1 " + f"{shlex.quote(python_executable)} -m 
lightx2v.disagg.services.instance_proxy " + f"--bind-addr {shlex.quote(f'tcp://0.0.0.0:{proxy_req_port}')} " + f"--workdir {shlex.quote(workdir)} --log-dir {shlex.quote(log_dir)} " + f"> {shlex.quote(proxy_log_path)} 2>&1 &" + ), + "echo PROXY_PID=$!", + ] + ) + script = "\n".join(script_lines) + + self._run_ssh_script(slot, script, timeout_seconds=30.0, check=True) + + deadline = time.time() + self._remote_proxy_start_timeout_seconds + while time.time() < deadline: + probe = self._query_zmq(req_addr, {"cmd": "ping"}, timeout_ms=800) + if isinstance(probe, dict) and probe.get("ok", False): + self.logger.info("Remote instance proxy is ready on host=%s req_addr=%s", slot.get("host"), req_addr) + return + time.sleep(0.2) + + raise RuntimeError(f"remote instance proxy failed to start on host={slot.get('host')} req_addr={req_addr}") + + def _start_sidecar_process(self, instance_type: str, gpu_id: str | int, bind_host: str | None = None) -> dict[str, Any]: + host = str(bind_host or self._bootstrap_addr) + push_port = self._allocate_free_tcp_port(host) + req_port = self._allocate_free_tcp_port(host) + push_addr = f"tcp://{host}:{push_port}" + req_addr = f"tcp://{host}:{req_port}" + + cmd = [ + sys.executable, + "-m", + "lightx2v.disagg.services.data_mgr_sidecar", + "--push-addr", + push_addr, + "--req-addr", + req_addr, + ] + sidecar_env = os.environ.copy() + sidecar_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + process = subprocess.Popen( + cmd, + env=sidecar_env, + start_new_session=True, + ) + + deadline = time.time() + self._sidecar_start_timeout_seconds + ready = False + while time.time() < deadline: + reply = self._query_sidecar(req_addr, "ping") + if isinstance(reply, dict) and reply.get("ok", False): + ready = True + break + time.sleep(0.1) + + if not ready: + if process.poll() is None: + process.terminate() + try: + process.wait(timeout=2.0) + except subprocess.TimeoutExpired: + process.kill() + raise RuntimeError(f"sidecar server failed to start for 
{instance_type} gpu={gpu_id}") + + self.logger.info( + "Started sidecar for %s gpu=%s pid=%s push=%s req=%s", + instance_type, + gpu_id, + process.pid, + push_addr, + req_addr, + ) + return { + "process": process, + "push_addr": push_addr, + "req_addr": req_addr, + } + + def _run_ssh_script(self, slot: dict[str, Any], script: str, timeout_seconds: float = 30.0, check: bool = True) -> subprocess.CompletedProcess: + ssh_bin = str(slot.get("ssh_bin", "ssh")) + ssh_target = str(slot.get("ssh_target", slot.get("host", ""))).strip() + if not ssh_target: + raise RuntimeError("remote slot missing ssh target") + + ssh_options = slot.get("ssh_options") + ssh_cmd = [ssh_bin] + if isinstance(ssh_options, list): + ssh_cmd.extend(str(opt) for opt in ssh_options if str(opt).strip()) + ssh_cmd.extend([ssh_target, "bash", "-lc", script]) + return subprocess.run( + ssh_cmd, + check=check, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout_seconds, + ) + + def _launch_remote_instance(self, slot: dict[str, Any], instance_type: str, cmd: list[str], cuda_device: str) -> tuple[dict[str, Any], dict[str, Any]]: + if self._is_truthy(slot.get("use_remote_proxy", False)): + return self._launch_remote_instance_via_proxy(slot, instance_type, cmd, cuda_device) + + host = str(slot["host"]) + engine_rank = int(slot["engine_rank"]) + python_executable = str(slot.get("python_executable", sys.executable)) + workdir = str(slot.get("workdir", Path(__file__).resolve().parents[3])) + log_dir = str(slot.get("log_dir", "/tmp/lightx2v_disagg")) + activate_cmd = str(slot.get("activate_cmd", "")).strip() + push_port = int(slot["sidecar_push_port"]) + req_port = int(slot["sidecar_req_port"]) + push_addr = f"tcp://{host}:{push_port}" + req_addr = f"tcp://{host}:{req_port}" + service_log = f"{log_dir}/{instance_type}_{engine_rank}_service.log" + sidecar_log = f"{log_dir}/{instance_type}_{engine_rank}_sidecar.log" + + extra_env = slot.get("env") + normalized_env: dict[str, str] = 
        # NOTE(review): reconstructed from a whitespace-mangled diff chunk; code tokens
        # preserved, indentation inferred. This span opens mid-statement (tail of the
        # SSH-based remote launch helper) — the `{}` below completes an assignment
        # whose left-hand side is on the preceding (unseen) line. Confirm upstream.
        {}
        if isinstance(extra_env, dict):
            for key, value in extra_env.items():
                normalized_env[str(key)] = str(value)
        if self._is_centralized_enabled():
            normalized_env["IS_CENTRALIZED"] = "1"
        if os.getenv("SYNC_COMM") is not None:
            normalized_env["SYNC_COMM"] = str(os.getenv("SYNC_COMM", "0"))
        self._ensure_rdma_preferred_ipv4_env(host, normalized_env)

        # Environment for the sidecar process (pins the GPU, unbuffered logs).
        sidecar_env_vars = {
            **normalized_env,
            "CUDA_VISIBLE_DEVICES": str(cuda_device),
            "PYTHONUNBUFFERED": "1",
        }
        # Environment for the service process; also points it at the sidecar endpoints.
        service_env_vars = {
            **normalized_env,
            "CUDA_VISIBLE_DEVICES": str(cuda_device),
            "LIGHTX2V_SIDECAR_PUSH_ADDR": push_addr,
            "LIGHTX2V_SIDECAR_REQ_ADDR": req_addr,
            "PYTHONUNBUFFERED": "1",
        }

        def _to_env_prefix(env_map: dict[str, str]) -> str:
            # Render shell-quoted KEY=value pairs for an `env` command prefix.
            return " ".join(f"{key}={shlex.quote(value)}" for key, value in env_map.items())

        def _with_env(base_cmd: str, env_map: dict[str, str]) -> str:
            # Prefix `base_cmd` with `env KEY=...` unless the map renders empty.
            env_prefix = _to_env_prefix(env_map)
            if not env_prefix:
                return base_cmd
            return f"env {env_prefix} {base_cmd}"

        sidecar_cmd = _with_env(
            (f"{shlex.quote(python_executable)} -m lightx2v.disagg.services.data_mgr_sidecar --push-addr {shlex.quote(push_addr)} --req-addr {shlex.quote(req_addr)}"),
            sidecar_env_vars,
        )
        cmd_with_python = [python_executable, *cmd[1:]]
        service_cmd = _with_env(" ".join(shlex.quote(str(part)) for part in cmd_with_python), service_env_vars)

        # Shell script executed over SSH: nohup-start sidecar then service, echo both
        # PIDs so the controller can parse them from stdout below.
        script_lines = [
            "set -e",
            f"mkdir -p {shlex.quote(log_dir)}",
            f"cd {shlex.quote(workdir)}",
        ]
        if activate_cmd:
            script_lines.append(activate_cmd)
        script_lines.extend(
            [
                f"nohup {sidecar_cmd} > {shlex.quote(sidecar_log)} 2>&1 &",
                "sidecar_pid=$!",
                "sleep 0.5",
                f"nohup {service_cmd} > {shlex.quote(service_log)} 2>&1 &",
                "service_pid=$!",
                "echo SIDECAR_PID=$sidecar_pid",
                "echo SERVICE_PID=$service_pid",
            ]
        )
        script = "\n".join(script_lines)

        completed = self._run_ssh_script(slot, script, timeout_seconds=45.0, check=True)
        sidecar_pid: int | None = None
        service_pid: int | None = None
        # Parse the PIDs echoed by the remote script; non-integer values are treated
        # as missing (triggers the RuntimeError below).
        for line in completed.stdout.splitlines():
            if line.startswith("SIDECAR_PID="):
                try:
                    sidecar_pid = int(line.split("=", 1)[1].strip())
                except ValueError:
                    sidecar_pid = None
            elif line.startswith("SERVICE_PID="):
                try:
                    service_pid = int(line.split("=", 1)[1].strip())
                except ValueError:
                    service_pid = None

        if sidecar_pid is None or service_pid is None:
            raise RuntimeError(f"failed to parse remote pids for {instance_type} rank={engine_rank} host={host}: stdout={completed.stdout!r} stderr={completed.stderr!r}")

        sidecar_meta = {
            "mode": "remote",
            "host": host,
            "req_addr": req_addr,
            "push_addr": push_addr,
            "pid": sidecar_pid,
            "log_path": sidecar_log,
        }
        process_meta = {
            "mode": "remote",
            "host": host,
            "pid": service_pid,
            "log_path": service_log,
        }
        return process_meta, sidecar_meta

    def _launch_remote_instance_via_proxy(self, slot: dict[str, Any], instance_type: str, cmd: list[str], cuda_device: str) -> tuple[dict[str, Any], dict[str, Any]]:
        """Start a service + sidecar pair on a remote host through the per-host instance proxy (ZMQ), returning (process_meta, sidecar_meta)."""
        self._ensure_remote_instance_proxy(slot)

        host = str(slot["host"])
        engine_rank = int(slot["engine_rank"])
        python_executable = str(slot.get("python_executable", sys.executable))
        workdir = str(slot.get("workdir", Path(__file__).resolve().parents[3]))
        log_dir = str(slot.get("log_dir", "/tmp/lightx2v_disagg"))
        push_port = int(slot["sidecar_push_port"])
        req_port = int(slot["sidecar_req_port"])
        push_addr = f"tcp://{host}:{push_port}"
        req_addr = f"tcp://{host}:{req_port}"
        service_log = f"{log_dir}/{instance_type}_{engine_rank}_service.log"
        sidecar_log = f"{log_dir}/{instance_type}_{engine_rank}_sidecar.log"

        extra_env = slot.get("env")
        normalized_env: dict[str, str] = {}
        if isinstance(extra_env, dict):
            for key, value in extra_env.items():
                normalized_env[str(key)] = str(value)
        self._ensure_rdma_preferred_ipv4_env(host, normalized_env)

        proxy_req_addr = self._remote_proxy_req_addr(slot)
        # start_instance request executed by the remote proxy on our behalf.
        payload = {
            "cmd": "start_instance",
            "instance_type": str(instance_type),
            "engine_rank": int(engine_rank),
            "cuda_device": str(cuda_device),
            "python_executable": python_executable,
            "service_argv": [str(part) for part in cmd[1:]],
            "sidecar_push_addr": push_addr,
            "sidecar_req_addr": req_addr,
            "service_log_path": service_log,
            "sidecar_log_path": sidecar_log,
            "workdir": workdir,
            "log_dir": log_dir,
            "env": normalized_env,
        }
        reply = self._query_zmq(proxy_req_addr, payload, timeout_ms=10000)
        if not isinstance(reply, dict) or not reply.get("ok", False):
            raise RuntimeError(f"remote proxy failed to start instance on host={host}: {reply}")

        data = reply.get("data") if isinstance(reply.get("data"), dict) else {}
        sidecar_pid = int(data.get("sidecar_pid", 0) or 0)
        service_pid = int(data.get("service_pid", 0) or 0)
        if sidecar_pid <= 0 or service_pid <= 0:
            raise RuntimeError(f"remote proxy returned invalid pids for host={host}: {reply}")

        sidecar_meta = {
            "mode": "remote",
            "host": host,
            "req_addr": req_addr,
            "push_addr": push_addr,
            "pid": sidecar_pid,
            "log_path": sidecar_log,
            "proxy_req_addr": proxy_req_addr,
        }
        process_meta = {
            "mode": "remote",
            "host": host,
            "pid": service_pid,
            "log_path": service_log,
            "proxy_req_addr": proxy_req_addr,
        }
        return process_meta, sidecar_meta

    def _stop_remote_pid(self, slot: dict[str, Any], pid: int, graceful_timeout_seconds: float):
        """Stop a remote PID: prefer the host proxy's stop_pid command, fall back to SSH TERM/KILL."""
        if self._is_truthy(slot.get("use_remote_proxy", False)):
            req_addr = self._remote_proxy_req_addr(slot)
            timeout_seconds = max(1, int(graceful_timeout_seconds))
            payload = {
                "cmd": "stop_pid",
                "pid": int(pid),
                "timeout_seconds": timeout_seconds,
            }
            reply = self._query_zmq(req_addr, payload, timeout_ms=(timeout_seconds + 3) * 1000)
            if isinstance(reply, dict) and reply.get("ok", False):
                return
            self.logger.warning(
                "Remote proxy stop_pid failed, falling back to ssh kill: host=%s pid=%s reply=%s",
                slot.get("host"),
                pid,
                reply,
            )

        # Fall-through: SSH-based graceful TERM with KILL escalation.
        timeout_seconds = \
        max(1, int(graceful_timeout_seconds))
        # Remote kill script: TERM first, then KILL once the deadline passes.
        script = "\n".join(
            [
                "set +e",
                f"pid={int(pid)}",
                "if kill -0 ${pid} >/dev/null 2>&1; then",
                " kill -TERM ${pid} >/dev/null 2>&1 || true",
                f" deadline=$((SECONDS+{timeout_seconds}))",
                " while kill -0 ${pid} >/dev/null 2>&1; do",
                " if (( SECONDS >= deadline )); then",
                " kill -KILL ${pid} >/dev/null 2>&1 || true",
                " break",
                " fi",
                " sleep 0.2",
                " done",
                "fi",
            ]
        )
        try:
            self._run_ssh_script(slot, script, timeout_seconds=float(timeout_seconds + 10), check=False)
        except Exception as exc:
            # Best-effort: log and continue rather than failing the reclaim path.
            self.logger.warning("Failed to stop remote pid=%s on host=%s: %s", pid, slot.get("host"), exc)

    def _reclaim_sidecar_when_drained(self, instance_type: str, target_address: str, sidecar_meta: dict[str, Any]):
        """Wait until a local sidecar is idle (no pending work, idle long enough), then shut it down and terminate its process."""
        req_addr = str(sidecar_meta.get("req_addr", ""))
        process = sidecar_meta.get("process")
        if not req_addr or process is None:
            # Nothing to drain: remote sidecars carry no local process handle.
            return

        deadline = None
        if self._sidecar_drain_timeout_seconds > 0:
            deadline = time.time() + self._sidecar_drain_timeout_seconds

        while True:
            if process.poll() is not None:
                # Sidecar already exited.
                break

            reply = self._query_sidecar(req_addr, "get_stats")
            if isinstance(reply, dict) and reply.get("ok", False):
                data = reply.get("data") if isinstance(reply.get("data"), dict) else {}
                last_message_ts = float(data.get("last_message_ts", 0.0))
                idle_seconds = max(0.0, time.time() - last_message_ts)
                pending_input_watch = int(data.get("input_watch", 0))
                pending_output_watch = int(data.get("output_watch", 0))
                pending_transformer_request = int(data.get("transformer_request_pool", 0))
                pending_transformer_waiting = int(data.get("transformer_waiting_pool", 0))
                pending_transformer_active = int(data.get("transformer_active_rooms", 0))
                pending_active = pending_input_watch + pending_output_watch + pending_transformer_request + pending_transformer_waiting + pending_transformer_active

                # Drained = no pending work AND idle for the configured grace window.
                if pending_active == 0 and idle_seconds >= self._sidecar_drain_idle_seconds:
                    break

            if deadline is not None and time.time() >= deadline:
                self.logger.warning(
                    "Sidecar drain timeout reached for %s address=%s, forcing shutdown",
                    instance_type,
                    target_address,
                )
                break

            time.sleep(0.2)

        # Ask the sidecar to shut itself down; ignore transport errors.
        try:
            self._query_sidecar(req_addr, "shutdown")
        except Exception:
            pass

        # Ensure the process is gone: terminate, then kill after a short wait.
        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=2.0)
            except subprocess.TimeoutExpired:
                process.kill()

        self.logger.info(
            "Reclaimed sidecar for %s address=%s",
            instance_type,
            target_address,
        )

    def _to_plain(self, value: Any) -> Any:
        """Recursively convert config containers (e.g.
        LockableDict) to built-in Python types."""
        if isinstance(value, Mapping):
            return {k: self._to_plain(v) for k, v in value.items()}
        if isinstance(value, list):
            return [self._to_plain(v) for v in value]
        if isinstance(value, tuple):
            return tuple(self._to_plain(v) for v in value)
        if isinstance(value, set):
            return {self._to_plain(v) for v in value}
        return value

    def _resolve_service_config_json(self, config_json: str, instance_type: str) -> str:
        """Map a controller config path to the per-service variant (``*_controller.json`` -> ``*_<instance_type>.json``) when that file exists; otherwise return the input unchanged."""
        config_path = Path(config_json)
        if config_path.is_file():
            if config_path.name.endswith("_controller.json"):
                candidate = config_path.with_name(config_path.name.replace("_controller.json", f"_{instance_type}.json"))
                if candidate.is_file():
                    return str(candidate)
            if config_path.name.endswith("_distill_controller.json"):
                candidate = config_path.with_name(config_path.name.replace("_distill_controller.json", f"_distill_{instance_type}.json"))
                if candidate.is_file():
                    return str(candidate)
        return config_json

    def _load_warmup_duration_seconds(self, config: Mapping[str, Any]) -> float:
        """Sum the leading ``warmup`` stages from the workload-stages JSON (env var, config key, or repo default path); returns 0.0 on any load failure."""
        stage_json = os.getenv("DISAGG_WORKLOAD_STAGES_JSON", "")
        if not stage_json:
            stage_json = str(config.get("workload_stages_json", "") or "").strip()

        if stage_json:
            stage_file = Path(stage_json)
        else:
            repo_root = Path(__file__).resolve().parents[3]
            stage_file = repo_root / "configs" / "disagg" / "wan22_i2v_workload_stages.json"

        if not stage_file.is_file():
            self.logger.warning("workload stages config not found, skip warmup scale guard: %s", stage_file)
            return 0.0

        try:
            with stage_file.open("r", encoding="utf-8") as handle:
                loaded = json.load(handle)
        except Exception as exc:
            self.logger.warning("failed to load workload stages config %s: %s", stage_file, exc)
            return 0.0

        if not isinstance(loaded, list):
            self.logger.warning("invalid workload stages config format (expect list): %s", stage_file)
            return 0.0

        warmup_duration_s = 0.0
        for raw_stage in loaded:
            if not isinstance(raw_stage, Mapping):
                continue

            stage_name = str(raw_stage.get("name", "")).strip().lower()
            if stage_name != "warmup":
                # Stop at the first non-warmup stage after warmup time was accrued.
                if warmup_duration_s > 0.0:
                    break
                continue

            try:
                duration_s = float(raw_stage.get("duration_s", 0.0))
            except (TypeError, ValueError):
                duration_s = 0.0
            warmup_duration_s += max(duration_s, 0.0)

        self.logger.info(
            "Loaded workload warmup duration: file=%s warmup_duration_s=%.3f",
            stage_file,
            warmup_duration_s,
        )
        return warmup_duration_s

    def _sample_rdma_queue_pending(self) -> dict[str, int]:
        """Sample pending counts of the controller's RDMA buffers, keyed by service type; failures log and leave that entry at 0."""
        pending_by_service: dict[str, int] = {
            "encoder": 0,
            "transformer": 0,
            "decoder": 0,
        }
        buffer_by_service = {
            "encoder": self.rdma_buffer_request,
            "transformer": self.rdma_buffer_phase1,
            "decoder": self.rdma_buffer_phase2,
        }
        for service_type, rdma_buffer in buffer_by_service.items():
            if rdma_buffer is None:
                continue
            try:
                pending_by_service[service_type] = int(rdma_buffer.pending_count())
            except Exception as exc:
                self.logger.warning("Failed to sample RDMA pending count for %s: %s", service_type, exc)
        return pending_by_service

    def _calc_precompute_pending(self, service_type: str, queue_sizes: Any) -> int:
        """Compute the pre-compute backlog for a service from its reported queue sizes; returns -1 when ``queue_sizes`` is unusable."""
        if not isinstance(queue_sizes, dict):
            return -1

        # Keep only entries coercible to int.
        normalized: dict[str, int] = {}
        for key, value in queue_sizes.items():
            try:
                normalized[str(key)] = int(value)
            except (TypeError, ValueError):
                continue

        if service_type == "encoder":
            keys = ("req_queue", "exec_queue")
            return sum(max(normalized.get(key, 0), 0) for key in keys)

        if service_type == "transformer":
            direct_keys = ("req_queue", "waiting_queue", "exec_queue")
            pending = sum(max(normalized.get(key, 0), 0) for key in direct_keys)
            # phase1_* are pre-compute ingress queues; phase2_* are post-compute egress queues.
            pending += sum(max(value, 0) for key, value in normalized.items() if key.startswith("phase1_"))
            return pending

        if service_type == "decoder":
            direct_keys = ("req_queue", "waiting_queue", "exec_queue")
            pending = sum(max(normalized.get(key, 0), 0) for key in direct_keys)
            # Decoder transfer_* represent ingress from transformer, still before decode compute.
            pending += sum(max(value, 0) for key, value in normalized.items() if key.startswith("transfer_"))
            return pending

        return -1

    def _monitor_callback(self, results):
        """Autoscaling hook fed with per-instance monitor samples: skips a configured warmup window, aggregates GPU/queue metrics per service, scales in idle low-utilization instances and scales out the single highest-scoring overloaded service, honoring a per-service cooldown."""
        monitor_runtime = getattr(self, "_monitor_runtime", None)
        if self._shutting_down or not isinstance(monitor_runtime, dict):
            return

        warmup_duration_s = float(monitor_runtime.get("warmup_duration_s", 0.0))
        autoscale_start_mono = float(monitor_runtime.get("autoscale_start_mono", time.monotonic()))
        warmup_skip_logged = bool(monitor_runtime.get("warmup_skip_logged", False))
        warmup_end_logged = bool(monitor_runtime.get("warmup_end_logged", False))
        scale_out_threshold = float(monitor_runtime.get("scale_out_threshold", 80.0))
        scale_out_max_queue_threshold = int(monitor_runtime.get("scale_out_max_queue_threshold", 2))
        scale_in_threshold = float(monitor_runtime.get("scale_in_threshold", 20.0))
        scale_cooldown_seconds = float(monitor_runtime.get("scale_cooldown_seconds", 30.0))
        last_scale_ts = monitor_runtime.get("last_scale_ts")
        if not isinstance(last_scale_ts, dict):
            return

        # Suppress scaling during the workload warmup window; log the first skip
        # and the first post-warmup callback exactly once each.
        if warmup_duration_s > 0.0:
            elapsed_s = max(0.0, time.monotonic() - autoscale_start_mono)
            if elapsed_s < warmup_duration_s:
                if not warmup_skip_logged:
                    self.logger.info(
                        "Skip autoscaling during warmup: elapsed_s=%.3f warmup_duration_s=%.3f",
                        elapsed_s,
                        warmup_duration_s,
                    )
                    warmup_skip_logged = True
                    monitor_runtime["warmup_skip_logged"] = True
                return
            if warmup_skip_logged and not warmup_end_logged:
                self.logger.info(
                    "Warmup finished, autoscaling enabled: elapsed_s=%.3f warmup_duration_s=%.3f",
                    elapsed_s,
                    warmup_duration_s,
                )
                warmup_end_logged = True
                monitor_runtime["warmup_end_logged"] = True

        service_metrics: dict[str, list[dict[str, Any]]] = {
            "encoder": [],
            "transformer": [],
            "decoder": [],
        }

        # Collect valid samples; note only transformer/decoder are actually kept
        # (the second membership check filters encoders out).
        for item in results:
            self.logger.info("monitor: %s", item)
            if not isinstance(item, dict):
                continue

            service_type = str(item.get("service_type", ""))
            if service_type not in {"encoder", "transformer", "decoder"}:
                continue

            if service_type not in {"transformer", "decoder"}:
                continue

            if item.get("status") != "ok":
                continue

            try:
                gpu_utilization = float(item.get("gpu_utilization", 0.0))
            except (TypeError, ValueError):
                continue

            monitor_address = str(item.get("address", ""))
            if not monitor_address:
                continue

            queue_total_pending = item.get("queue_total_pending", None)
            try:
                queue_total_pending_int = int(queue_total_pending) if queue_total_pending is not None else -1
            except (TypeError, ValueError):
                queue_total_pending_int = -1

            all_queues_empty = bool(item.get("all_queues_empty", False))
            queue_sizes = item.get("queue_sizes")
            precompute_pending = self._calc_precompute_pending(service_type, queue_sizes)

            service_metrics[service_type].append(
                {
                    "gpu_utilization": gpu_utilization,
                    "monitor_address": monitor_address,
                    "queue_total_pending": queue_total_pending_int,
                    "all_queues_empty": all_queues_empty,
                    "precompute_pending": precompute_pending,
                }
            )

        rdma_pending_by_service = self._sample_rdma_queue_pending()
        scale_out_candidates: list[dict[str, Any]] = []
        service_queue_scores: dict[str, float] = {}
        service_precompute_scores: dict[str, float] = {}

        # First pass: per-service queue and pre-compute scores (RDMA pending + averages).
        for service_type, metrics in service_metrics.items():
            if not metrics:
                continue
            avg_queue_total_pending = sum(int(metric.get("queue_total_pending", 0)) for metric in metrics) / len(metrics)
            rdma_queue_pending = int(rdma_pending_by_service.get(service_type, 0))
            service_queue_scores[service_type] = float(rdma_queue_pending) + float(avg_queue_total_pending)

            precompute_values = [int(metric.get("precompute_pending", -1)) for metric in metrics if int(metric.get("precompute_pending", -1)) >= 0]
            if precompute_values:
                avg_precompute_pending = sum(precompute_values) / len(precompute_values)
                service_precompute_scores[service_type] = float(rdma_queue_pending) + float(avg_precompute_pending)
            else:
                service_precompute_scores[service_type] = float(rdma_queue_pending)

        max_precompute_score = max(service_precompute_scores.values(), default=0.0)

        # Second pass: evaluate scale-out candidacy and perform scale-in per service.
        for service_type, metrics in service_metrics.items():
            if not metrics:
                continue

            now = time.time()
            avg_gpu_utilization = sum(float(metric["gpu_utilization"]) for metric in metrics) / len(metrics)
            avg_queue_total_pending = sum(int(metric.get("queue_total_pending", 0)) for metric in metrics) / len(metrics)
            max_queue_total_pending = max(int(metric.get("queue_total_pending", -1)) for metric in metrics)
            rdma_queue_pending = int(rdma_pending_by_service.get(service_type, 0))
            current_queue_score = float(service_queue_scores.get(service_type, 0.0))
            current_precompute_score = float(service_precompute_scores.get(service_type, 0.0))

            scale_out_triggered = avg_gpu_utilization > scale_out_threshold or max_queue_total_pending > scale_out_max_queue_threshold

            if scale_out_triggered and now - float(last_scale_ts.get(service_type, 0.0)) >= scale_cooldown_seconds:
                scale_out_candidates.append(
                    {
                        "service_type": service_type,
                        "score": current_queue_score,
                        "avg_gpu_utilization": avg_gpu_utilization,
                        "avg_queue_total_pending": avg_queue_total_pending,
                        "max_queue_total_pending": max_queue_total_pending,
                        "rdma_queue_pending": rdma_queue_pending,
                        "now": now,
                    }
                )

            # Scale-in targets the least-utilized instance of this service.
            low_metric = min(metrics, key=lambda metric: float(metric["gpu_utilization"]))
            low_utilization = float(low_metric["gpu_utilization"])
            low_monitor_address = str(low_metric["monitor_address"])
            with self._instance_lock:
                service_instance_count = sum(1 for meta in self._managed_instances.values() if meta.get("instance_type") == service_type)

            low_precompute_pending = int(low_metric.get("precompute_pending", -1))
            if low_precompute_pending >= 0:
                queues_empty_for_service = low_precompute_pending == 0
            else:
                queues_empty_for_service = bool(low_metric.get("all_queues_empty", False)) and int(low_metric.get("queue_total_pending", -1)) == 0

            # Never scale in the service currently carrying the highest pre-compute backlog.
            blocked_by_queue_score = current_precompute_score > 0.0 and current_precompute_score >= max_precompute_score

            scale_in_triggered = (
                low_utilization < scale_in_threshold and service_instance_count > 1 and queues_empty_for_service and now - float(last_scale_ts.get(service_type, 0.0)) >= scale_cooldown_seconds
            )

            if scale_in_triggered and blocked_by_queue_score:
                self.logger.info(
                    "Skip scale in for highest precompute-score service: service=%s precompute_score=%.2f max_precompute_score=%.2f total_score=%.2f",
                    service_type,
                    current_precompute_score,
                    max_precompute_score,
                    current_queue_score,
                )
                continue

            if scale_in_triggered:
                try:
                    target_instance_address = self._instance_address_from_monitor_node(low_monitor_address)
                    self.reclaim_instance(service_type, target_instance_address)
                    last_scale_ts[service_type] = now
                    self.logger.info(
                        "Auto-scale in triggered: service=%s low_gpu_utilization=%.2f reclaimed_instance=%s",
                        service_type,
                        low_utilization,
                        target_instance_address,
                    )
                except Exception as exc:
                    self.logger.warning(
                        "Auto-scale in skipped for service=%s low_gpu_utilization=%.2f reason=%s",
                        service_type,
                        low_utilization,
                        exc,
                    )

        # Scale out at most one service per callback: the highest-scoring candidate.
        if scale_out_candidates:
            target = max(
                scale_out_candidates,
                key=lambda item: (item["score"], item["max_queue_total_pending"], item["avg_gpu_utilization"]),
            )
            target_service = str(target["service_type"])
            if float(target["now"]) - float(last_scale_ts.get(target_service, 0.0)) < scale_cooldown_seconds:
                return
            try:
                new_address = self.create_instance(target_service)
                last_scale_ts[target_service] = float(target["now"])
                self.logger.info(
                    "Auto-scale out triggered: service=%s score=%.2f rdma_queue_pending=%s avg_queue_total_pending=%.2f max_queue_total_pending=%s avg_gpu_utilization=%.2f new_instance=%s",
                    target_service,
                    float(target["score"]),
                    int(target["rdma_queue_pending"]),
                    float(target["avg_queue_total_pending"]),
                    int(target["max_queue_total_pending"]),
                    float(target["avg_gpu_utilization"]),
                    new_address,
                )
            except Exception:
                # Best-effort: creation failures are tolerated until the next tick.
                pass

    def _handle_decoder_result(
        self,
        result: Any,
        *,
        expected_rooms: set[int],
        received_rooms: set[int],
        received_results: list[dict],
    ):
        """Process one decoder message: merge stage metrics by room, deduplicate results, attach latency summaries, and record completed rooms into the provided accumulators (mutated in place)."""
        if not isinstance(result, dict):
            self.logger.warning("Ignored non-dict decoder result: %s", result)
            return

        message_type = str(result.get("message_type", "decoder_result"))
        room = result.get("data_bootstrap_room")
        if room is None:
            self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result)
            return
        room = int(room)

        # Intermediate stage-metrics updates are merged and do not complete a room.
        if message_type == "stage_metrics":
            request_metrics = result.get("request_metrics")
            if not isinstance(request_metrics, dict):
                self.logger.warning("Ignored stage metrics update without request_metrics: %s", result)
                return
            merged_metrics = self._merge_request_metrics(self._request_metrics_by_room.get(room), request_metrics)
            self._request_metrics_by_room[room] = merged_metrics
            self.logger.info(
                "Stage metrics updated room=%s stage=%s metrics=%s",
                room,
                result.get("stage_name"),
                request_metrics.get("stages", {}),
            )
            return

        if room not in expected_rooms:
            self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result)
            return
        if room in received_rooms:
            self.logger.info("Duplicate decoder result for room=%s ignored", room)
            return

        stored_metrics = self._request_metrics_by_room.get(room)
        request_metrics = result.get("request_metrics")
        if isinstance(request_metrics, dict):
            merged_metrics = \
                self._merge_request_metrics(stored_metrics, request_metrics)
            self._request_metrics_by_room[room] = merged_metrics
            result["request_metrics"] = merged_metrics
        elif isinstance(stored_metrics, dict):
            result["request_metrics"] = stored_metrics

        controller_recv_ts = time.time()
        latency_summary = self._build_latency_summary(result, controller_recv_ts)
        if latency_summary is not None:
            result["latency_summary"] = latency_summary
            self.logger.info("Latency summary room=%s metrics=%s", room, latency_summary)

        received_rooms.add(room)
        received_results.append(result)

        if result.get("ok", False):
            self.logger.info(
                "Decoder result received room=%s save_path=%s (%s/%s)",
                room,
                result.get("save_path"),
                len(received_rooms),
                len(expected_rooms),
            )
        else:
            self.logger.error(
                "Decoder result failed room=%s error=%s (%s/%s)",
                room,
                result.get("error"),
                len(received_rooms),
                len(expected_rooms),
            )

    def _drain_decoder_results_non_block(
        self,
        *,
        result_port: int,
        expected_rooms: set[int],
        received_rooms: set[int],
        received_results: list[dict],
    ):
        """Non-blocking drain: pull every queued decoder message from ``result_port`` and dispatch each to ``_handle_decoder_result``."""
        while True:
            result = self.req_mgr.receive_non_block(result_port)
            if result is None:
                break
            self._handle_decoder_result(
                result,
                expected_rooms=expected_rooms,
                received_rooms=received_rooms,
                received_results=received_results,
            )

    def _monitor_node_from_instance_address(self, instance_address: str) -> str:
        """Translate ``host:request_port`` into the matching ``tcp://host:monitor_port`` (same rank offset from the base ports)."""
        host, port_str = instance_address.rsplit(":", 1)
        rank = int(port_str) - REQUEST_POLLING_PORT
        return f"tcp://{host}:{MONITOR_POLLING_PORT + rank}"

    def _instance_address_from_monitor_node(self, monitor_node: str) -> str:
        """Inverse of ``_monitor_node_from_instance_address``: monitor endpoint -> ``host:request_port``."""
        host_port = monitor_node
        if host_port.startswith("tcp://"):
            host_port = host_port[len("tcp://") :]
        host, port_str = host_port.rsplit(":", 1)
        rank = int(port_str) - MONITOR_POLLING_PORT
        return f"{host}:{REQUEST_POLLING_PORT + rank}"

    def _init_gpu_pool(self, config: dict):
        """Initialize the scheduling pool: either a validated static multi-node slot list from ``disagg_config.static_instance_slots`` or, failing that, a simple local free-GPU set of size ``ranks``."""
        disagg_cfg = config.get("disagg_config") if isinstance(config.get("disagg_config"), dict) else {}
        self._refresh_local_host_aliases()

        static_slots_raw = disagg_cfg.get("static_instance_slots")
        self._static_instance_slots = []
        self._free_slot_ids = set()
        self._slot_reuse_block_until = {}

        if isinstance(static_slots_raw, list) and static_slots_raw:
            # Per-slot fields fall back to these disagg_config-level defaults.
            default_workdir = str(disagg_cfg.get("remote_workdir", Path(__file__).resolve().parents[3]))
            default_python = str(disagg_cfg.get("remote_python_executable", sys.executable))
            default_log_dir = str(disagg_cfg.get("remote_log_dir", "/tmp/lightx2v_disagg"))
            default_activate_cmd = str(disagg_cfg.get("remote_activate_cmd", "")).strip()
            default_ssh_user = str(disagg_cfg.get("ssh_user", "")).strip()
            default_ssh_bin = str(disagg_cfg.get("ssh_bin", os.getenv("DISAGG_SSH_BIN", "ssh")))
            default_use_remote_proxy = self._is_truthy(disagg_cfg.get("use_remote_proxy"), default=self._is_truthy(os.getenv("DISAGG_USE_REMOTE_PROXY"), False))
            default_proxy_req_base_port = int(disagg_cfg.get("remote_proxy_req_base_port", 28000))

            default_ssh_options_raw = disagg_cfg.get("ssh_options", [])
            if isinstance(default_ssh_options_raw, str):
                default_ssh_options = shlex.split(default_ssh_options_raw)
            elif isinstance(default_ssh_options_raw, list):
                default_ssh_options = [str(opt) for opt in default_ssh_options_raw if str(opt).strip()]
            else:
                default_ssh_options = []

            default_slot_env = disagg_cfg.get("service_env", {})
            normalized_default_slot_env: dict[str, str] = {}
            if isinstance(default_slot_env, dict):
                for key, value in default_slot_env.items():
                    normalized_default_slot_env[str(key)] = str(value)

            sidecar_base_port = int(disagg_cfg.get("sidecar_base_port", 26000))
            seen_slot_keys: set[tuple[str, int]] = set()

            for index, raw_slot in enumerate(static_slots_raw):
                if not isinstance(raw_slot, dict):
                    raise ValueError(f"invalid static_instance_slots[{index}] (expect object)")

                # NOTE(review): the "instance_type" key below was reconstructed from a
                # garbled chunk boundary (visible text read `raw_slot.get("")`); the
                # parallel "host" lookup below shows the intended pattern — confirm
                # against the original file.
                instance_type = str(raw_slot.get("instance_type", "")).strip().lower()
                if instance_type not in {"encoder", "transformer", "decoder"}:
                    raise ValueError(f"invalid static_instance_slots[{index}].instance_type={instance_type!r}")

                host = str(raw_slot.get("host", "")).strip()
                if not host:
                    raise ValueError(f"static_instance_slots[{index}].host cannot be empty")

                engine_rank = int(raw_slot.get("engine_rank"))
                cuda_device = str(raw_slot.get("cuda_device", engine_rank))
                slot_key = (host, engine_rank)
                if slot_key in seen_slot_keys:
                    raise ValueError(f"duplicate static slot host/rank: {slot_key}")
                seen_slot_keys.add(slot_key)

                ssh_user = str(raw_slot.get("ssh_user", default_ssh_user)).strip()
                ssh_target = f"{ssh_user}@{host}" if ssh_user else host
                ssh_bin = str(raw_slot.get("ssh_bin", default_ssh_bin))

                ssh_options_raw = raw_slot.get("ssh_options", default_ssh_options)
                if isinstance(ssh_options_raw, str):
                    ssh_options = shlex.split(ssh_options_raw)
                elif isinstance(ssh_options_raw, list):
                    ssh_options = [str(opt) for opt in ssh_options_raw if str(opt).strip()]
                else:
                    ssh_options = list(default_ssh_options)

                slot_env = dict(normalized_default_slot_env)
                raw_slot_env = raw_slot.get("env", {})
                if isinstance(raw_slot_env, dict):
                    for key, value in raw_slot_env.items():
                        slot_env[str(key)] = str(value)

                # Derived per-rank port defaults: two sidecar ports per rank, one proxy port.
                push_port = int(raw_slot.get("sidecar_push_port", sidecar_base_port + engine_rank * 2))
                req_port = int(raw_slot.get("sidecar_req_port", sidecar_base_port + engine_rank * 2 + 1))
                use_remote_proxy = self._is_truthy(raw_slot.get("use_remote_proxy"), default=default_use_remote_proxy)
                proxy_req_port = int(raw_slot.get("proxy_req_port", default_proxy_req_base_port + engine_rank))
                proxy_log_path = str(raw_slot.get("proxy_log_path", f"{default_log_dir}/instance_proxy_{engine_rank}.log"))

                self._static_instance_slots.append(
                    {
                        "slot_id": index,
                        "instance_type": instance_type,
                        "host": host,
                        "engine_rank": engine_rank,
                        "cuda_device": cuda_device,
                        "workdir":
                            str(raw_slot.get("workdir", default_workdir)),
                        "python_executable": str(raw_slot.get("python_executable", default_python)),
                        "log_dir": str(raw_slot.get("log_dir", default_log_dir)),
                        "activate_cmd": str(raw_slot.get("activate_cmd", default_activate_cmd)).strip(),
                        "ssh_target": ssh_target,
                        "ssh_bin": ssh_bin,
                        "ssh_options": ssh_options,
                        "sidecar_push_port": push_port,
                        "sidecar_req_port": req_port,
                        "use_remote_proxy": use_remote_proxy,
                        "proxy_req_port": proxy_req_port,
                        "proxy_log_path": proxy_log_path,
                        "env": slot_env,
                    }
                )

            self._free_slot_ids = {int(slot["slot_id"]) for slot in self._static_instance_slots}
            self.logger.info("Static multi-node mode enabled with %s slots", len(self._static_instance_slots))
            # Static mode replaces the local free-GPU pool entirely.
            self._free_gpus = set()
            return

        total_ranks = int(config.get("ranks", disagg_cfg.get("ranks", 8)))
        if total_ranks <= 0:
            raise ValueError("ranks must be positive")

        self._free_gpus = set(range(total_ranks))

    def create_instance(self, instance_type: str) -> str:
        """Create one service instance on an idle GPU and add it to scheduling pool."""
        if instance_type not in {"encoder", "transformer", "decoder"}:
            raise ValueError("instance_type must be one of: encoder, transformer, decoder")
        if self._runtime_config is None:
            raise RuntimeError("controller runtime config is not initialized")

        with self._instance_lock:
            use_static_slots = bool(self._static_instance_slots)
            selected_slot: dict[str, Any] | None = None

            if use_static_slots:
                # Static mode: pick the first matching free slot that is not cooling
                # down and whose monitor port is not still occupied.
                if not self._free_slot_ids:
                    raise RuntimeError("no idle static slot available")

                now = time.time()
                for slot_id in sorted(self._free_slot_ids):
                    slot = self._static_instance_slots[slot_id]
                    if slot.get("instance_type") != instance_type:
                        continue
                    if now < self._slot_reuse_block_until.get(slot_id, 0.0):
                        continue

                    host = str(slot["host"])
                    engine_rank = int(slot["engine_rank"])
                    monitor_port = MONITOR_POLLING_PORT + engine_rank
                    if self._is_tcp_port_open(host, monitor_port):
                        self.logger.warning(
                            "Skip static slot=%s host=%s rank=%s for %s creation because monitor port %s is still in use",
                            slot_id,
                            host,
                            engine_rank,
                            instance_type,
                            monitor_port,
                        )
                        continue

                    selected_slot = slot
                    break

                if selected_slot is None:
                    raise RuntimeError(f"no idle static slot available for {instance_type}: all candidates cooling down or port is in use")

                engine_rank = int(selected_slot["engine_rank"])
                host = str(selected_slot["host"])
                cuda_device = str(selected_slot["cuda_device"])
            else:
                # Local mode: pick the first free GPU past its reuse grace window
                # whose monitor port is free.
                if not self._free_gpus:
                    raise RuntimeError("no idle GPU available")

                now = time.time()
                engine_rank: int | None = None
                host = self._bootstrap_addr
                for candidate_gpu in sorted(self._free_gpus):
                    if now < self._gpu_reuse_block_until.get(candidate_gpu, 0.0):
                        continue

                    monitor_port = MONITOR_POLLING_PORT + candidate_gpu
                    if self._is_tcp_port_open(self._bootstrap_addr, monitor_port):
                        self.logger.warning(
                            "Skip gpu=%s for %s creation because monitor port %s is still in use",
                            candidate_gpu,
                            instance_type,
                            monitor_port,
                        )
                        continue

                    engine_rank = candidate_gpu
                    break

                if engine_rank is None:
                    raise RuntimeError(f"no idle GPU available for {instance_type}: all candidates cooling down or port is in use")
                cuda_device = str(engine_rank)

            # Build a plain-dict per-instance config with mode and rank stamped in.
            instance_cfg = self._to_plain(self._runtime_config)
            instance_cfg["disagg_mode"] = instance_type
            if instance_type == "encoder":
                instance_cfg["encoder_engine_rank"] = engine_rank
            elif instance_type == "transformer":
                instance_cfg["transformer_engine_rank"] = engine_rank
            else:
                instance_cfg["decoder_engine_rank"] = engine_rank

            model_path = instance_cfg.get("model_path")
            config_json = instance_cfg.get("config_json")
            if not model_path or not config_json:
                raise RuntimeError("model_path and config_json are required to launch service subprocess")
            service_config_json = self._resolve_service_config_json(str(config_json), instance_type)

            cmd = self._build_service_command(instance_type, engine_rank, instance_cfg, service_config_json)
            cmd = self._maybe_wrap_service_command_with_nsys(
                host=host,
                instance_type=instance_type,
                engine_rank=engine_rank,
                instance_cfg=instance_cfg,
                command=cmd,
            )

            process: subprocess.Popen | None = None
            process_meta: dict[str, Any] | None = None
            sidecar_meta: dict[str, Any]
            launch_mode = "local"

            try:
                if use_static_slots and selected_slot is not None and not self._is_local_host(host):
                    # Remote launch (SSH or proxy, decided inside the helper).
                    launch_mode = "remote"
                    process_meta, sidecar_meta = self._launch_remote_instance(selected_slot, instance_type, cmd, cuda_device)
                else:
                    # Local launch: start sidecar, then the service subprocess in its
                    # own session with the sidecar endpoints exported.
                    sidecar_meta = self._start_sidecar_process(instance_type, cuda_device, bind_host=host)
                    env = os.environ.copy()
                    if use_static_slots and selected_slot is not None:
                        slot_env = selected_slot.get("env")
                        if isinstance(slot_env, dict):
                            for key, value in slot_env.items():
                                env[str(key)] = str(value)
                    self._ensure_rdma_preferred_ipv4_env(host, env)
                    env["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
                    env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = str(sidecar_meta["push_addr"])
                    env["LIGHTX2V_SIDECAR_REQ_ADDR"] = str(sidecar_meta["req_addr"])
                    process = subprocess.Popen(
                        cmd,
                        env=env,
                        start_new_session=True,
                    )

                # Readiness gate: the service must expose its monitor port in time.
                monitor_port = MONITOR_POLLING_PORT + engine_rank
                if not self._wait_for_tcp_port_state(host, monitor_port, should_be_open=True, timeout_seconds=self._instance_start_timeout_seconds):
                    raise RuntimeError(f"service {instance_type} rank={engine_rank} host={host} failed to expose monitor port {monitor_port}")
            except Exception:
                # Rollback: tear down whatever was started (local process, remote
                # service PID, sidecar) before re-raising.
                if process is not None and process.poll() is None:
                    process.terminate()
                    try:
                        process.wait(timeout=3.0)
                    except subprocess.TimeoutExpired:
                        process.kill()

                if launch_mode == "remote" and selected_slot is not None and process_meta is not None:
                    remote_pid = process_meta.get("pid")
                    if isinstance(remote_pid, int) and remote_pid > 0:
                        self._stop_remote_pid(selected_slot, remote_pid, self._graceful_reclaim_timeout_seconds)

                if "sidecar_meta" in locals():
                    if launch_mode == "remote" and selected_slot is not None:
                        sidecar_pid = sidecar_meta.get("pid")
                        if isinstance(sidecar_pid, int) and sidecar_pid > 0:
                            self._stop_remote_pid(selected_slot, sidecar_pid, self._force_kill_wait_seconds)
                    else:
                        sidecar_process = sidecar_meta.get("process")
                        if sidecar_process is not None and sidecar_process.poll() is None:
                            sidecar_process.terminate()
                            try:
                                sidecar_process.wait(timeout=2.0)
                            except subprocess.TimeoutExpired:
                                sidecar_process.kill()
                raise

            # Success: mark the slot/GPU busy, register the monitor node, and
            # record instance metadata for later reclaim.
            instance_address = f"{host}:{REQUEST_POLLING_PORT + engine_rank}"
            if use_static_slots and selected_slot is not None:
                self._free_slot_ids.remove(int(selected_slot["slot_id"]))
            else:
                self._free_gpus.remove(engine_rank)
            if self._enable_monitor:
                monitor_node = f"tcp://{host}:{MONITOR_POLLING_PORT + engine_rank}"
                if monitor_node not in self.monitor.nodes:
                    self.monitor.nodes.append(monitor_node)
            self._managed_instances[instance_address] = {
                "instance_type": instance_type,
                "gpu_id": engine_rank,
                "host": host,
                "launch_mode": launch_mode,
                "cuda_device": cuda_device,
                "process": process,
                "process_meta": process_meta,
                "sidecar": sidecar_meta,
                "slot_id": int(selected_slot["slot_id"]) if selected_slot is not None else None,
                "static_slot": self._to_plain(selected_slot) if selected_slot is not None else None,
            }
            self.started_instances.append((instance_type, instance_address))
            self.add_instance(instance_type, instance_address)
            self.logger.info(
                "Created %s instance host=%s rank=%s mode=%s address=%s",
                instance_type,
                host,
                engine_rank,
                launch_mode,
                instance_address,
            )
            return instance_address

    def reclaim_instance(self, instance_type: str, instance_address: str | None = None) -> str:
        """Reclaim one managed instance and return its GPU back to idle pool."""
        if instance_type not in {"encoder", "transformer", "decoder"}:
            raise \
ValueError("instance_type must be one of: encoder, transformer, decoder") + + with self._instance_lock: + target_address = instance_address + if target_address is None: + candidates = [addr for addr, meta in self._managed_instances.items() if meta.get("instance_type") == instance_type] + if not candidates: + raise RuntimeError(f"no managed {instance_type} instance to reclaim") + target_address = candidates[-1] + + meta = self._managed_instances.get(target_address) + if meta is None: + if (instance_type, target_address) in self.started_instances: + self.started_instances.remove((instance_type, target_address)) + self.logger.warning( + "Skip reclaim for already-removed %s instance address=%s", + instance_type, + target_address, + ) + return target_address + if meta.get("instance_type") != instance_type: + raise RuntimeError(f"instance type mismatch for {target_address}: expected={instance_type} got={meta.get('instance_type')}") + + process = meta.get("process") + process_meta = meta.get("process_meta") if isinstance(meta.get("process_meta"), dict) else None + gpu_id = int(meta.get("gpu_id")) + sidecar_meta = meta.get("sidecar") if isinstance(meta.get("sidecar"), dict) else None + host = str(meta.get("host", self._bootstrap_addr)) + launch_mode = str(meta.get("launch_mode", "local")) + static_slot = meta.get("static_slot") if isinstance(meta.get("static_slot"), dict) else None + slot_id_raw = meta.get("slot_id") + slot_id = int(slot_id_raw) if slot_id_raw is not None else None + + self.remove_instance(instance_type, target_address) + monitor_node = self._monitor_node_from_instance_address(target_address) + + if launch_mode == "remote": + if static_slot is None: + raise RuntimeError(f"remote instance metadata missing static slot for {target_address}") + + remote_service_pid = None + if process_meta is not None and isinstance(process_meta.get("pid"), int): + remote_service_pid = int(process_meta["pid"]) + if remote_service_pid is not None and remote_service_pid > 0: + 
self._stop_remote_pid(static_slot, remote_service_pid, self._graceful_reclaim_timeout_seconds) + + if sidecar_meta is not None and isinstance(sidecar_meta.get("pid"), int): + self._stop_remote_pid(static_slot, int(sidecar_meta["pid"]), self._force_kill_wait_seconds) + else: + if process is not None and process.poll() is None: + try: + os.killpg(process.pid, signal.SIGTERM) + except Exception: + process.terminate() + try: + process.wait(timeout=self._graceful_reclaim_timeout_seconds) + except subprocess.TimeoutExpired: + try: + os.killpg(process.pid, signal.SIGKILL) + except Exception: + process.kill() + try: + process.wait(timeout=self._force_kill_wait_seconds) + except subprocess.TimeoutExpired as exc: + raise RuntimeError(f"process did not exit after kill for {instance_type} instance {target_address}") from exc + + if self._enable_monitor and monitor_node in self.monitor.nodes: + self.monitor.nodes.remove(monitor_node) + + monitor_port = MONITOR_POLLING_PORT + gpu_id + if not self._wait_for_tcp_port_state(host, monitor_port, should_be_open=False, timeout_seconds=5.0): + self.logger.warning( + "Monitor port still open after reclaim: service=%s host=%s rank=%s port=%s", + instance_type, + host, + gpu_id, + monitor_port, + ) + + if slot_id is not None and slot_id in range(len(self._static_instance_slots)): + self._free_slot_ids.add(slot_id) + self._slot_reuse_block_until[slot_id] = time.time() + self._gpu_reuse_grace_seconds + else: + self._free_gpus.add(gpu_id) + self._gpu_reuse_block_until[gpu_id] = time.time() + self._gpu_reuse_grace_seconds + self._managed_instances.pop(target_address, None) + if (instance_type, target_address) in self.started_instances: + self.started_instances.remove((instance_type, target_address)) + + if sidecar_meta is not None and launch_mode != "remote": + reclaim_thread = Thread( + target=self._reclaim_sidecar_when_drained, + args=(instance_type, target_address, sidecar_meta), + name=f"sidecar-reclaim-{instance_type}-{gpu_id}", + 
daemon=True, + ) + reclaim_thread.start() + self._sidecar_reclaim_threads.append(reclaim_thread) + + self.logger.info( + "Reclaimed %s instance from host=%s rank=%s address=%s", + instance_type, + host, + gpu_id, + target_address, + ) + return target_address def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): slots = int(config.get("rdma_buffer_slots", "128")) @@ -51,21 +1603,24 @@ def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): config["rdma_phase2_handshake_port"] = phase2_handshake_port need_bytes = 16 + slots * slot_size - self._rdma_server_request = RDMAServer(buffer_size=need_bytes) - self.rdma_buffer_request = RDMABuffer( - role="server", - buffer_size=slots, - slot_size=slot_size, - rdma_server=self._rdma_server_request, - ) + if not self._is_centralized_enabled(): + self._rdma_server_request = RDMAServer(buffer_size=need_bytes) + self.rdma_buffer_request = RDMABuffer( + role="server", + buffer_size=slots, + slot_size=slot_size, + rdma_server=self._rdma_server_request, + ) - self._rdma_handshake_thread_request = Thread( - target=self._rdma_server_request.handshake, - kwargs={"host": bootstrap_addr, "port": handshake_port}, - name="controller-rdma-handshake", - daemon=True, - ) - self._rdma_handshake_thread_request.start() + self._rdma_handshake_thread_request = Thread( + target=self._rdma_server_request.handshake, + kwargs={"host": bootstrap_addr, "port": handshake_port}, + name="controller-rdma-handshake", + daemon=True, + ) + self._rdma_handshake_thread_request.start() + else: + self.logger.info("IS_CENTRALIZED enabled, skip controller request RDMA ring initialization") need_bytes_phase1 = 16 + phase1_slots * phase1_slot_size self._rdma_server_phase1 = RDMAServer(buffer_size=need_bytes_phase1) @@ -100,9 +1655,9 @@ def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): self._rdma_handshake_thread_phase2.start() self.logger.info( "Initialized RDMA buffers: request=(%s,%s,%s) phase1=(%s,%s,%s) 
phase2=(%s,%s,%s)", - slots, - slot_size, - need_bytes, + slots if self.rdma_buffer_request is not None else 0, + slot_size if self.rdma_buffer_request is not None else 0, + need_bytes if self.rdma_buffer_request is not None else 0, phase1_slots, phase1_slot_size, need_bytes_phase1, @@ -129,6 +1684,126 @@ def serve_rdma_dispatch_only(self, config: dict) -> None: except KeyboardInterrupt: self.logger.info("Controller serve_rdma_dispatch_only interrupted, exiting.") + def _build_latency_summary(self, result: dict[str, Any], controller_recv_ts: float) -> dict[str, float] | None: + request_metrics = result.get("request_metrics") + if not isinstance(request_metrics, dict): + return None + + def _as_float(value: Any) -> float | None: + try: + return float(value) + except (TypeError, ValueError): + return None + + def _stage(name: str) -> dict[str, Any]: + stages = request_metrics.get("stages") + if not isinstance(stages, dict): + return {} + stage_metrics = stages.get(name) + return stage_metrics if isinstance(stage_metrics, dict) else {} + + controller_send_ts = _as_float(request_metrics.get("controller_send_ts")) + if controller_send_ts is None: + return None + + centralized_mode = self._is_centralized_enabled() + summary: dict[str, float] = { + "end_to_end_delay_s": controller_recv_ts - controller_send_ts, + } + + encoder = _stage("encoder") + transformer = _stage("transformer") + decoder = _stage("decoder") + + encoder_recv_ts = _as_float(encoder.get("request_received_ts")) + encoder_compute_start_ts = _as_float(encoder.get("compute_start_ts")) + encoder_compute_end_ts = _as_float(encoder.get("compute_end_ts")) + encoder_output_enqueued_ts = _as_float(encoder.get("output_enqueued_ts")) + + transformer_recv_ts = _as_float(transformer.get("request_received_ts")) + transformer_compute_start_ts = _as_float(transformer.get("compute_start_ts")) + transformer_compute_end_ts = _as_float(transformer.get("compute_end_ts")) + transformer_output_enqueued_ts = 
_as_float(transformer.get("output_enqueued_ts")) + + decoder_recv_ts = _as_float(decoder.get("request_received_ts")) + decoder_compute_start_ts = _as_float(decoder.get("compute_start_ts")) + decoder_compute_end_ts = _as_float(decoder.get("compute_end_ts")) + decoder_output_enqueued_ts = _as_float(decoder.get("output_enqueued_ts")) + + if centralized_mode: + if encoder_recv_ts is not None: + summary["controller_to_encoder_comm_delay_s"] = encoder_recv_ts - controller_send_ts + if encoder_recv_ts is not None and encoder_compute_start_ts is not None: + summary["encoder_scheduling_delay_s"] = encoder_compute_start_ts - encoder_recv_ts + if encoder_compute_start_ts is not None and encoder_compute_end_ts is not None: + summary["encoder_compute_delay_s"] = encoder_compute_end_ts - encoder_compute_start_ts + if encoder_output_enqueued_ts is not None and transformer_recv_ts is not None: + summary["encoder_communication_delay_s"] = transformer_recv_ts - controller_send_ts + if transformer_recv_ts is not None and transformer_compute_start_ts is not None: + summary["transformer_scheduling_delay_s"] = transformer_compute_start_ts - transformer_recv_ts + if transformer_compute_start_ts is not None and transformer_compute_end_ts is not None: + summary["transformer_compute_delay_s"] = transformer_compute_end_ts - transformer_compute_start_ts + if transformer_recv_ts is not None: + summary["transformer_communication_delay_s"] = transformer_recv_ts - controller_send_ts + if decoder_recv_ts is not None and decoder_compute_start_ts is not None: + summary["decoder_scheduling_delay_s"] = decoder_compute_start_ts - decoder_recv_ts + if decoder_compute_start_ts is not None and decoder_compute_end_ts is not None: + summary["decoder_compute_delay_s"] = decoder_compute_end_ts - decoder_compute_start_ts + if decoder_recv_ts is not None: + summary["decoder_communication_delay_s"] = decoder_recv_ts - controller_send_ts + + component_keys = [ + "controller_to_encoder_comm_delay_s", + 
"encoder_scheduling_delay_s", + "encoder_compute_delay_s", + "encoder_communication_delay_s", + "transformer_scheduling_delay_s", + "transformer_compute_delay_s", + "transformer_communication_delay_s", + "decoder_scheduling_delay_s", + "decoder_compute_delay_s", + "decoder_communication_delay_s", + ] + if all(key in summary for key in component_keys): + summary["sum_of_components_s"] = sum(summary[key] for key in component_keys) + else: + if encoder_recv_ts is not None: + summary["controller_to_encoder_comm_delay_s"] = encoder_recv_ts - controller_send_ts + if encoder_recv_ts is not None and encoder_compute_start_ts is not None: + summary["encoder_scheduling_delay_s"] = encoder_compute_start_ts - encoder_recv_ts + if encoder_compute_start_ts is not None and encoder_compute_end_ts is not None: + summary["encoder_compute_delay_s"] = encoder_compute_end_ts - encoder_compute_start_ts + if encoder_output_enqueued_ts is not None and transformer_recv_ts is not None: + summary["encoder_communication_delay_s"] = transformer_recv_ts - encoder_output_enqueued_ts + if transformer_recv_ts is not None and transformer_compute_start_ts is not None: + summary["transformer_scheduling_delay_s"] = transformer_compute_start_ts - transformer_recv_ts + if transformer_compute_start_ts is not None and transformer_compute_end_ts is not None: + summary["transformer_compute_delay_s"] = transformer_compute_end_ts - transformer_compute_start_ts + if transformer_output_enqueued_ts is not None and decoder_recv_ts is not None: + summary["transformer_communication_delay_s"] = decoder_recv_ts - transformer_output_enqueued_ts + if decoder_recv_ts is not None and decoder_compute_start_ts is not None: + summary["decoder_scheduling_delay_s"] = decoder_compute_start_ts - decoder_recv_ts + if decoder_compute_start_ts is not None and decoder_compute_end_ts is not None: + summary["decoder_compute_delay_s"] = decoder_compute_end_ts - decoder_compute_start_ts + if decoder_output_enqueued_ts is not None: + 
summary["decoder_communication_delay_s"] = controller_recv_ts - decoder_output_enqueued_ts + + component_keys = [ + "controller_to_encoder_comm_delay_s", + "encoder_scheduling_delay_s", + "encoder_compute_delay_s", + "encoder_communication_delay_s", + "transformer_scheduling_delay_s", + "transformer_compute_delay_s", + "transformer_communication_delay_s", + "decoder_scheduling_delay_s", + "decoder_compute_delay_s", + "decoder_communication_delay_s", + ] + if all(key in summary for key in component_keys): + summary["sum_of_components_s"] = sum(summary[key] for key in component_keys) + return summary + def add_instance(self, instance_type: str, instance_address: str): """Add instance address to the matching scheduling policy by type.""" if not instance_address: @@ -162,80 +1837,262 @@ def send_request(self, config): if config is None: raise ValueError("config cannot be None") + room_raw = config.get("data_bootstrap_room") + try: + room = int(room_raw) + except (TypeError, ValueError): + room = None + + request_metrics = config.get("request_metrics") + if room is not None and isinstance(request_metrics, dict): + self._request_metrics_by_room[room] = self._merge_request_metrics(None, request_metrics) + + if self._is_centralized_enabled(): + request_config = self._to_plain(config) + + encoder_address = self.encoder_policy.schedule() + transformer_address = self.transformer_policy.schedule() + decoder_address = self.decoder_policy.schedule() + + def _address_to_rank(instance_address: str) -> int: + _, port_str = instance_address.rsplit(":", 1) + return int(port_str) - REQUEST_POLLING_PORT + + encoder_rank = _address_to_rank(encoder_address) + transformer_rank = _address_to_rank(transformer_address) + decoder_rank = _address_to_rank(decoder_address) + + request_config["encoder_engine_rank"] = encoder_rank + request_config["transformer_engine_rank"] = transformer_rank + request_config["decoder_engine_rank"] = decoder_rank + request_config["encoder_node_address"] = 
encoder_address + request_config["transformer_node_address"] = transformer_address + request_config["decoder_node_address"] = decoder_address + request_config["controller_control_host"] = request_config.get("controller_result_host", self._bootstrap_addr) + request_config["controller_control_port"] = int(request_config.get("controller_control_port", REQUEST_POLLING_PORT - 3)) + + for instance_type, target_address in ( + ("encoder", encoder_address), + ("transformer", transformer_address), + ("decoder", decoder_address), + ): + host, port_str = target_address.rsplit(":", 1) + self.req_mgr.send(host, int(port_str), request_config) + self.logger.info("Request dispatched to %s via ZMQ: target=%s", instance_type, target_address) + return + if self.rdma_buffer_request is None: raise RuntimeError("RDMA request buffer is not initialized") self.rdma_buffer_request.produce(config) self.logger.info("Request enqueued to encoder request RDMA buffer") def run(self, config): - """Initialize instances, send requests, wait for decoder save_path callbacks, then exit.""" + """Initialize controller buffers, stream request configs from workload, then wait for all callbacks.""" if config is None: raise ValueError("config cannot be None") + self._shutting_down = False + bootstrap_addr = config.get("data_bootstrap_addr", "127.0.0.1") - encoder_engine_rank = config.get("encoder_engine_rank", 0) - transformer_engine_rank = config.get("transformer_engine_rank", 1) - decoder_engine_rank = config.get("decoder_engine_rank", 2) - request_count = int(config.get("request_count", 2)) + request_ingress_port = int(config.get("controller_request_port", os.getenv("DISAGG_CONTROLLER_REQUEST_PORT", REQUEST_POLLING_PORT - 2))) result_port = int(config.get("controller_result_port", REQUEST_POLLING_PORT - 1)) + control_port = int(config.get("controller_control_port", REQUEST_POLLING_PORT - 3)) + self._bootstrap_addr = str(bootstrap_addr) + self._runtime_config = self._to_plain(config) + 
self._init_gpu_pool(config) + self._enable_monitor = self._is_monitor_enabled() + centralized_mode = self._is_centralized_enabled() - self.encoder_policy = RoundRobinPolicy() - self.transformer_policy = RoundRobinPolicy() - self.decoder_policy = RoundRobinPolicy() + # self.encoder_policy = RoundRobinPolicy() + # self.transformer_policy = RoundRobinPolicy() + # self.decoder_policy = RoundRobinPolicy() self._init_request_rdma_buffer(bootstrap_addr, config) + if centralized_mode: + self.logger.info("IS_CENTRALIZED enabled, controller will dispatch requests via ZMQ") - self.add_instance("encoder", f"{bootstrap_addr}:{REQUEST_POLLING_PORT + encoder_engine_rank}") - self.add_instance( - "transformer", - f"{bootstrap_addr}:{REQUEST_POLLING_PORT + transformer_engine_rank}", - ) - self.add_instance("decoder", f"{bootstrap_addr}:{REQUEST_POLLING_PORT + decoder_engine_rank}") + time.sleep(5.0) - monitor_nodes = [ - f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + encoder_engine_rank}", - f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + transformer_engine_rank}", - f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + decoder_engine_rank}", - ] - self.monitor.nodes = monitor_nodes - - monitor_stop_event = Event() - - def _monitor_callback(results): - for item in results: - self.logger.info("monitor: %s", item) - - monitor_thread = Thread( - target=self.monitor.run_forever, - kwargs={ - "interval_seconds": 5.0, - "callback": _monitor_callback, - "stop_event": monitor_stop_event, - }, - name="controller-monitor", - daemon=True, - ) - # monitor_thread.start() + if self._static_instance_slots: + self.logger.info( + "Starting managed instances from static_instance_slots: %s", + [slot["instance_type"] for slot in self._static_instance_slots], + ) + for slot in self._static_instance_slots: + self.create_instance(str(slot["instance_type"])) + else: + for instance_type in ("encoder", "transformer", "decoder"): + self.create_instance(instance_type) + for _ in range(5): + 
self.create_instance("transformer") + + instance_warmup_wait_s = int(os.getenv("DISAGG_INSTANCE_WARMUP_WAIT_S", "30")) + if instance_warmup_wait_s > 0: + self.logger.info( + "Managed instances created, waiting %ss before accepting requests", + instance_warmup_wait_s, + ) + time.sleep(instance_warmup_wait_s) + + monitor_stop_event: Event | None = None + monitor_thread: Thread | None = None + ok_gate_stop_event: Event | None = None + ok_gate_thread: Thread | None = None + self._monitor_runtime = None + + if self._enable_monitor: + monitor_stop_event = Event() + warmup_duration_s = self._load_warmup_duration_seconds(config) + autoscale_start_mono = time.monotonic() + warmup_skip_logged = False + warmup_end_logged = False + scale_out_threshold = 80.0 + scale_out_max_queue_threshold = 2 + scale_in_threshold = 20.0 + scale_cooldown_seconds = 30.0 + last_scale_ts: dict[str, float] = { + "encoder": 0.0, + "transformer": 0.0, + "decoder": 0.0, + } + + self._monitor_runtime = { + "warmup_duration_s": warmup_duration_s, + "autoscale_start_mono": autoscale_start_mono, + "warmup_skip_logged": warmup_skip_logged, + "warmup_end_logged": warmup_end_logged, + "scale_out_threshold": scale_out_threshold, + "scale_out_max_queue_threshold": scale_out_max_queue_threshold, + "scale_in_threshold": scale_in_threshold, + "scale_cooldown_seconds": scale_cooldown_seconds, + "last_scale_ts": last_scale_ts, + } + + monitor_thread = Thread( + target=self.monitor.run_forever, + kwargs={ + "interval_seconds": 2.0, + "callback": self._monitor_callback, + "stop_event": monitor_stop_event, + }, + name="controller-monitor", + daemon=True, + ) + monitor_thread.start() + self.logger.info("ENABLE_MONITOR enabled, monitor thread started") + else: + self.logger.info("ENABLE_MONITOR is not set, skip monitor logic") + + if centralized_mode: + ok_gate_stop_event = Event() + ok_gate_thread = Thread( + target=self._run_centralized_ok_server, + args=(ok_gate_stop_event, self._bootstrap_addr, control_port), + 
name="controller-ok-gate", + daemon=True, + ) + ok_gate_thread.start() + self.logger.info("Centralized OK gate server started on %s:%s", self._bootstrap_addr, control_port) + + time.sleep(5.0) base_save_path = config.get("save_path") expected_rooms: set[int] = set() received_rooms: set[int] = set() received_results: list[dict] = [] + next_room = 0 + batch_request_start_ts: float | None = None + load_from_user = str(os.getenv("LOAD_FROM_USER", "0")).strip().lower() in {"1", "true", "yes", "on"} + auto_request_count_raw = config.get("request_count", os.getenv("DISAGG_AUTO_REQUEST_COUNT", "30")) try: - for i in range(request_count): + auto_request_count = int(auto_request_count_raw) + except (TypeError, ValueError): + self.logger.warning( + "Invalid request_count=%s, fallback to 30", + auto_request_count_raw, + ) + auto_request_count = 30 + if auto_request_count <= 0: + self.logger.warning("request_count must be positive, fallback to 30") + auto_request_count = 30 + + try: + generated_request_count = 0 + if load_from_user: + self.logger.info("LOAD_FROM_USER enabled, waiting workload configs on port=%s", request_ingress_port) + else: + self.logger.info( + "LOAD_FROM_USER disabled, generating requests from config: count=%s", + auto_request_count, + ) + + while True: + if load_from_user: + workload_config = self.req_mgr.receive(request_ingress_port) + if not isinstance(workload_config, dict): + self.logger.warning("Ignored invalid workload config packet: %s", workload_config) + continue + + if workload_config.get("workload_end") or workload_config.get("end") or workload_config.get("stop"): + self.logger.info("Received workload end signal, stop accepting new configs.") + break + else: + if generated_request_count >= auto_request_count: + break + workload_config = {} + generated_request_count += 1 + request_config = dict(config) - request_config["data_bootstrap_room"] = i + request_config.update(self._to_plain(workload_config)) + + room = 
request_config.get("data_bootstrap_room", next_room) + try: + room = int(room) + except (TypeError, ValueError): + room = next_room + if room in expected_rooms: + while next_room in expected_rooms: + next_room += 1 + room = next_room + next_room = max(next_room, room + 1) + + request_config["data_bootstrap_room"] = room request_config["controller_result_host"] = bootstrap_addr request_config["controller_result_port"] = result_port - if base_save_path: + + metrics = request_config.get("request_metrics") + if not isinstance(metrics, dict): + metrics = {} + metrics["request_id"] = int(metrics.get("request_id", room)) + metrics["controller_send_ts"] = time.time() + if not isinstance(metrics.get("stages"), dict): + metrics["stages"] = {} + request_config["request_metrics"] = metrics + + if base_save_path and not request_config.get("save_path"): save_path = Path(base_save_path) - request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{i + 1}{save_path.suffix}")) - # TODO: use queue to receive request from client and dispatch, currently we just send the same request multiple times for testing + request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{room}{save_path.suffix}")) + with self._lock: current_request = request_config + + if batch_request_start_ts is None: + batch_request_start_ts = time.time() + self.send_request(current_request) + self.logger.info( + "Dispatched request room=%s save_path=%s", + room, + request_config.get("save_path"), + ) + expected_rooms.add(room) - expected_rooms.add(i) + self._drain_decoder_results_non_block( + result_port=result_port, + expected_rooms=expected_rooms, + received_rooms=received_rooms, + received_results=received_results, + ) self.logger.info( "Waiting for decoder results: expected=%s on port=%s", @@ -244,43 +2101,41 @@ def _monitor_callback(results): ) while len(received_rooms) < len(expected_rooms): result = self.req_mgr.receive(result_port) - if not isinstance(result, dict): - 
self.logger.warning("Ignored non-dict decoder result: %s", result) - continue - room = result.get("data_bootstrap_room") - if room is None: - self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result) - continue - room = int(room) - if room not in expected_rooms: - self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) - continue - if room in received_rooms: - self.logger.info("Duplicate decoder result for room=%s ignored", room) - continue - - received_rooms.add(room) - received_results.append(result) - - if result.get("ok", False): - self.logger.info( - "Decoder result received room=%s save_path=%s (%s/%s)", - room, - result.get("save_path"), - len(received_rooms), - len(expected_rooms), - ) - else: - self.logger.error( - "Decoder result failed room=%s error=%s (%s/%s)", - room, - result.get("error"), - len(received_rooms), - len(expected_rooms), - ) + self._handle_decoder_result( + result, + expected_rooms=expected_rooms, + received_rooms=received_rooms, + received_results=received_results, + ) self.logger.info("All decoder results received. 
Controller exiting.") + if batch_request_start_ts is None: + batch_request_start_ts = time.time() + batch_total_time_s = time.time() - batch_request_start_ts + self.logger.info( + "Batch total elapsed time: requests=%s completed=%s total_time_s=%.3f", + len(expected_rooms), + len(received_rooms), + batch_total_time_s, + ) finally: - pass - # monitor_stop_event.set() - # monitor_thread.join(timeout=1.0) + self._shutting_down = True + if monitor_stop_event is not None: + monitor_stop_event.set() + if monitor_thread is not None: + monitor_thread.join(timeout=2.0) + if ok_gate_stop_event is not None: + ok_gate_stop_event.set() + if ok_gate_thread is not None: + ok_gate_thread.join(timeout=2.0) + self._monitor_runtime = None + + for instance_type, address in reversed(list(self.started_instances)): + try: + self.reclaim_instance(instance_type, address) + except Exception: + self.logger.exception("Failed to reclaim %s instance address=%s", instance_type, address) + + for thread in list(self._sidecar_reclaim_threads): + if thread.is_alive(): + thread.join(timeout=3.0) diff --git a/lightx2v/disagg/services/data_mgr_sidecar.py b/lightx2v/disagg/services/data_mgr_sidecar.py new file mode 100644 index 000000000..4421fd0f8 --- /dev/null +++ b/lightx2v/disagg/services/data_mgr_sidecar.py @@ -0,0 +1,975 @@ +from __future__ import annotations + +import argparse +import os +import threading +import time +from collections import deque +from multiprocessing import resource_tracker, shared_memory +from typing import TYPE_CHECKING, Any, Deque + +import zmq + +if TYPE_CHECKING: + from lightx2v.disagg.conn import DataReceiver, DataSender + + +STATUS_FAILED = 0 +STATUS_SUCCESS = 4 +_SHM_TRACKING_PATCHED = False + + +def _disable_shared_memory_tracking_for_process(): + """Disable multiprocessing resource_tracker registration for shared_memory. + + Python 3.12 does not expose SharedMemory(track=False). 
In fail-fast paths where + processes are terminated quickly, tracker warnings/noise can dominate logs even + when manual cleanup is performed by sidecar ownership logic. + """ + + global _SHM_TRACKING_PATCHED + if _SHM_TRACKING_PATCHED: + return + + original_register = resource_tracker.register + original_unregister = resource_tracker.unregister + + def _register(name, rtype): + if rtype == "shared_memory": + return + return original_register(name, rtype) + + def _unregister(name, rtype): + if rtype == "shared_memory": + return + return original_unregister(name, rtype) + + resource_tracker.register = _register + resource_tracker.unregister = _unregister + _SHM_TRACKING_PATCHED = True + + +class DataMgrSidecarServer: + """Controller-managed sidecar server process. + + Services push transfer-state events to this process and pop aggregated events + through request/reply calls. + """ + + def __init__(self, push_addr: str, req_addr: str): + _disable_shared_memory_tracking_for_process() + self.push_addr = str(push_addr) + self.req_addr = str(req_addr) + + self._input_watch: set[int] = set() + self._output_watch: set[int] = set() + self._ready_inputs: Deque[int] = deque() + self._failed_inputs: Deque[int] = deque() + self._completed_outputs: Deque[tuple[int, int]] = deque() + + self._total_messages = 0 + self._last_message_ts = time.time() + self._running = True + + self._transformer_phase2_mgr: Any | None = None + self._transformer_phase2_rooms: dict[int, dict[str, Any]] = {} + self._transformer_phase2_output_watch: set[int] = set() + self._transformer_phase2_last_status: dict[int, int] = {} + + def _mark_activity(self): + self._total_messages += 1 + self._last_message_ts = time.time() + + def _handle_push(self, msg: dict): + cmd = str(msg.get("cmd", "")) + room = int(msg.get("room", -1)) + + if cmd == "watch_input" and room >= 0: + self._input_watch.add(room) + self._mark_activity() + return + if cmd == "unwatch_input" and room >= 0: + self._input_watch.discard(room) + 
self._mark_activity() + return + if cmd == "watch_output" and room >= 0: + self._output_watch.add(room) + self._mark_activity() + return + if cmd == "unwatch_output" and room >= 0: + self._output_watch.discard(room) + self._mark_activity() + return + + if cmd == "input_status" and room >= 0: + status = int(msg.get("status", STATUS_FAILED)) + self._input_watch.discard(room) + if status == STATUS_SUCCESS: + self._ready_inputs.append(room) + else: + self._failed_inputs.append(room) + self._mark_activity() + return + + if cmd == "output_status" and room >= 0: + status = int(msg.get("status", STATUS_FAILED)) + self._output_watch.discard(room) + self._completed_outputs.append((room, status)) + self._mark_activity() + return + + if cmd == "shutdown": + self._running = False + self._mark_activity() + + def _ensure_transformer_phase2_mgr(self): + if self._transformer_phase2_mgr is not None: + return self._transformer_phase2_mgr + + from lightx2v.disagg.conn import DataManager, DisaggregationMode, DisaggregationPhase + + self._transformer_phase2_mgr = DataManager(DisaggregationPhase.PHASE2, DisaggregationMode.TRANSFORMER) + return self._transformer_phase2_mgr + + def _create_shared_memory(self, size: int) -> shared_memory.SharedMemory: + # Keep lifecycle in this process and avoid resource_tracker duplicate cleanup at shutdown. 
+ try: + return shared_memory.SharedMemory(create=True, size=int(size), track=False) + except TypeError: + return shared_memory.SharedMemory(create=True, size=int(size)) + + def _close_unlink_shared_memory(self, shm: shared_memory.SharedMemory): + try: + shm.close() + except Exception: + pass + try: + shm.unlink() + except FileNotFoundError: + pass + except Exception: + pass + + def _cleanup_transformer_phase2_room(self, room: int): + room = int(room) + info = self._transformer_phase2_rooms.pop(room, None) + self._transformer_phase2_output_watch.discard(room) + + mgr = self._transformer_phase2_mgr + if mgr is not None: + try: + mgr.remove(room) + except Exception: + pass + + if not isinstance(info, dict): + return + + shms = info.get("shms") + if isinstance(shms, list): + for shm in shms: + self._close_unlink_shared_memory(shm) + + def _init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any]: + room = int(room) + sender_engine_rank = int(sender_engine_rank) + receiver_engine_rank = int(receiver_engine_rank) + normalized_lens = [int(v) for v in list(data_lens)] + if not normalized_lens or any(v <= 0 for v in normalized_lens): + raise ValueError(f"invalid data_lens for room={room}: {normalized_lens}") + + self._cleanup_transformer_phase2_room(room) + self._transformer_phase2_last_status.pop(room, None) + mgr = self._ensure_transformer_phase2_mgr() + + import numpy as np + import torch + + from lightx2v.disagg.conn import DataArgs, DataSender + + shms: list[shared_memory.SharedMemory] = [] + arrays: list[Any] = [] + tensors: list[Any] = [] + data_ptrs: list[int] = [] + shm_names: list[str] = [] + + try: + for nbytes in normalized_lens: + shm = self._create_shared_memory(int(nbytes)) + arr = np.ndarray((int(nbytes),), dtype=np.uint8, buffer=shm.buf) + tensor = torch.from_numpy(arr) + tensor.zero_() + + shms.append(shm) + arrays.append(arr) + 
tensors.append(tensor) + data_ptrs.append(int(tensor.data_ptr())) + shm_names.append(str(shm.name)) + + data_args = DataArgs( + sender_engine_rank=sender_engine_rank, + receiver_engine_rank=receiver_engine_rank, + data_ptrs=data_ptrs, + data_lens=normalized_lens, + data_item_lens=normalized_lens, + ib_device=None, + ) + mgr.init(data_args, room) + sender = DataSender(mgr, bootstrap_addr, room) + + self._transformer_phase2_rooms[room] = { + "sender": sender, + "data_ptrs": data_ptrs, + "shms": shms, + "arrays": arrays, + "tensors": tensors, + } + + self._mark_activity() + return { + "room": room, + "shm_names": shm_names, + "data_lens": normalized_lens, + "host": str(mgr.get_localhost()), + "session_id": str(mgr.get_session_id()), + } + except Exception: + for shm in shms: + self._close_unlink_shared_memory(shm) + try: + mgr.remove(room) + except Exception: + pass + raise + + def _send_transformer_output_room(self, room: int): + room = int(room) + info = self._transformer_phase2_rooms.get(room) + if not isinstance(info, dict): + raise KeyError(f"transformer output room not initialized: {room}") + + sender = info.get("sender") + data_ptrs = info.get("data_ptrs") + if sender is None or not isinstance(data_ptrs, list): + raise RuntimeError(f"transformer output room metadata invalid: {room}") + + sender.send(list(data_ptrs)) + self._transformer_phase2_output_watch.add(room) + self._mark_activity() + + def _get_transformer_output_status(self, room: int) -> int: + room = int(room) + info = self._transformer_phase2_rooms.get(room) + if not isinstance(info, dict): + return int(self._transformer_phase2_last_status.get(room, STATUS_FAILED)) + sender = info.get("sender") + if sender is None: + return int(self._transformer_phase2_last_status.get(room, STATUS_FAILED)) + try: + return int(sender.poll()) + except Exception: + return int(self._transformer_phase2_last_status.get(room, STATUS_FAILED)) + + def _get_transformer_output_backlog(self) -> dict[str, int]: + mgr = 
self._transformer_phase2_mgr + if mgr is None: + return { + "request_pool": 0, + "waiting_pool": 0, + "request_status": 0, + } + try: + data = mgr.get_backlog_counts() + except Exception: + data = {} + return { + "request_pool": int(data.get("request_pool", 0)), + "waiting_pool": int(data.get("waiting_pool", 0)), + "request_status": int(data.get("request_status", 0)), + } + + def _poll_transformer_output_watch(self): + for room in list(self._transformer_phase2_output_watch): + status_val = self._get_transformer_output_status(room) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + self._transformer_phase2_output_watch.discard(room) + self._transformer_phase2_last_status[int(room)] = int(status_val) + self._completed_outputs.append((int(room), int(status_val))) + self._cleanup_transformer_phase2_room(room) + self._mark_activity() + + def _release_transformer_phase2_mgr(self): + for room in list(self._transformer_phase2_rooms.keys()): + self._cleanup_transformer_phase2_room(room) + mgr = self._transformer_phase2_mgr + self._transformer_phase2_mgr = None + if mgr is not None: + try: + mgr.release() + except Exception: + pass + + def _get_pending_counts(self) -> dict[str, int]: + transformer_backlog = self._get_transformer_output_backlog() + output_watch = len(self._output_watch) + len(self._transformer_phase2_output_watch) + return { + "input_watch": len(self._input_watch), + "output_watch": output_watch, + "ready_inputs": len(self._ready_inputs), + "failed_inputs": len(self._failed_inputs), + "completed_outputs": len(self._completed_outputs), + "transformer_request_pool": int(transformer_backlog.get("request_pool", 0)), + "transformer_waiting_pool": int(transformer_backlog.get("waiting_pool", 0)), + "transformer_active_rooms": len(self._transformer_phase2_rooms), + } + + def _handle_req(self, req: dict) -> dict: + cmd = str(req.get("cmd", "")) + + if cmd == "ping": + return {"ok": True} + + if cmd == "get_pending_counts": + return {"ok": True, "data": 
self._get_pending_counts()} + + if cmd == "get_stats": + counts = self._get_pending_counts() + return { + "ok": True, + "data": { + **counts, + "total_messages": int(self._total_messages), + "last_message_ts": float(self._last_message_ts), + }, + } + + if cmd == "init_transformer_output_room": + try: + room = int(req.get("room", -1)) + sender_engine_rank = int(req.get("sender_engine_rank", -1)) + receiver_engine_rank = int(req.get("receiver_engine_rank", -1)) + data_lens_raw = req.get("data_lens") + bootstrap_addr = str(req.get("bootstrap_addr", "127.0.0.1")) + if room < 0 or sender_engine_rank < 0 or receiver_engine_rank < 0: + raise ValueError("room/sender_engine_rank/receiver_engine_rank must be non-negative") + if not isinstance(data_lens_raw, list): + raise ValueError("data_lens must be a list") + data = self._init_transformer_output_room( + room=room, + sender_engine_rank=sender_engine_rank, + receiver_engine_rank=receiver_engine_rank, + data_lens=[int(v) for v in data_lens_raw], + bootstrap_addr=bootstrap_addr, + ) + return {"ok": True, "data": data} + except Exception as exc: + return {"ok": False, "error": str(exc)} + + if cmd == "send_transformer_output_room": + try: + room = int(req.get("room", -1)) + if room < 0: + raise ValueError("room must be non-negative") + self._send_transformer_output_room(room) + return {"ok": True, "data": True} + except Exception as exc: + return {"ok": False, "error": str(exc)} + + if cmd == "get_transformer_output_status": + room = int(req.get("room", -1)) + if room < 0: + return {"ok": False, "error": "room must be non-negative"} + return {"ok": True, "data": self._get_transformer_output_status(room)} + + if cmd == "remove_transformer_output_room": + room = int(req.get("room", -1)) + if room < 0: + return {"ok": False, "error": "room must be non-negative"} + self._cleanup_transformer_phase2_room(room) + self._mark_activity() + return {"ok": True, "data": True} + + if cmd == "get_transformer_output_backlog": + return {"ok": 
True, "data": self._get_transformer_output_backlog()} + + if cmd == "get_transformer_output_identity": + mgr = self._transformer_phase2_mgr + if mgr is None: + return {"ok": False, "error": "transformer phase2 manager not initialized"} + return { + "ok": True, + "data": { + "host": str(mgr.get_localhost()), + "session_id": str(mgr.get_session_id()), + }, + } + + if cmd == "pop_ready_inputs": + items = list(self._ready_inputs) + self._ready_inputs.clear() + return {"ok": True, "data": items} + + if cmd == "pop_failed_inputs": + items = list(self._failed_inputs) + self._failed_inputs.clear() + return {"ok": True, "data": items} + + if cmd == "pop_completed_outputs": + items = list(self._completed_outputs) + self._completed_outputs.clear() + return {"ok": True, "data": items} + + if cmd == "shutdown": + self._running = False + self._mark_activity() + return {"ok": True} + + return {"ok": False, "error": f"unknown command: {cmd}"} + + def run_forever(self): + context = zmq.Context() + pull = context.socket(zmq.PULL) + rep = context.socket(zmq.REP) + + pull.bind(self.push_addr) + rep.bind(self.req_addr) + + poller = zmq.Poller() + poller.register(pull, zmq.POLLIN) + poller.register(rep, zmq.POLLIN) + + try: + while self._running: + events = dict(poller.poll(timeout=100)) + if pull in events: + try: + self._handle_push(pull.recv_pyobj()) + except Exception: + pass + + if rep in events: + try: + reply = self._handle_req(rep.recv_pyobj()) + except Exception as exc: + reply = {"ok": False, "error": str(exc)} + rep.send_pyobj(reply) + + self._poll_transformer_output_watch() + finally: + self._release_transformer_phase2_mgr() + pull.close(0) + rep.close(0) + context.term() + + +class _LocalDataMgrSidecar: + """Fallback local sidecar used when controller-managed endpoints are absent.""" + + def __init__(self, poll_interval_s: float = 0.01): + self.poll_interval_s = max(float(poll_interval_s), 0.001) + + self._lock = threading.Lock() + self._stop_event = threading.Event() + 
self._thread: threading.Thread | None = None + self._started = False + + self._input_watch: dict[int, DataReceiver] = {} + self._output_watch: dict[int, DataSender] = {} + + self._ready_inputs: Deque[int] = deque() + self._failed_inputs: Deque[int] = deque() + self._completed_outputs: Deque[tuple[int, int]] = deque() + + def start(self): + if self._thread is not None and self._thread.is_alive(): + self._started = True + return + + self._stop_event.clear() + self._thread = threading.Thread( + target=self._run, + name="data-mgr-sidecar-local", + daemon=True, + ) + self._thread.start() + self._started = True + + def stop(self): + self._stop_event.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=1.0) + self._thread = None + self._started = False + + def watch_input(self, room: int, receiver: DataReceiver): + if not self._started: + self.start() + with self._lock: + self._input_watch[int(room)] = receiver + + def unwatch_input(self, room: int): + with self._lock: + self._input_watch.pop(int(room), None) + + def watch_output(self, room: int, sender: DataSender): + if not self._started: + self.start() + with self._lock: + self._output_watch[int(room)] = sender + + def unwatch_output(self, room: int): + with self._lock: + self._output_watch.pop(int(room), None) + + def pop_ready_inputs(self) -> list[int]: + with self._lock: + items = list(self._ready_inputs) + self._ready_inputs.clear() + return items + + def pop_failed_inputs(self) -> list[int]: + with self._lock: + items = list(self._failed_inputs) + self._failed_inputs.clear() + return items + + def pop_completed_outputs(self) -> list[tuple[int, int]]: + with self._lock: + items = list(self._completed_outputs) + self._completed_outputs.clear() + return items + + def get_pending_counts(self) -> dict[str, int]: + with self._lock: + return { + "input_watch": len(self._input_watch), + "output_watch": len(self._output_watch), + "ready_inputs": len(self._ready_inputs), + 
"failed_inputs": len(self._failed_inputs), + "completed_outputs": len(self._completed_outputs), + } + + def init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any] | None: + return None + + def send_transformer_output_room(self, room: int) -> bool: + return False + + def get_transformer_output_status(self, room: int) -> int: + return STATUS_FAILED + + def remove_transformer_output_room(self, room: int) -> bool: + return False + + def get_transformer_output_backlog(self) -> dict[str, int]: + return { + "request_pool": 0, + "waiting_pool": 0, + "request_status": 0, + } + + def get_transformer_output_identity(self, room: int | None = None) -> dict[str, Any] | None: + return None + + def _run(self): + while not self._stop_event.is_set(): + with self._lock: + input_items = list(self._input_watch.items()) + output_items = list(self._output_watch.items()) + + if not input_items and not output_items: + time.sleep(self.poll_interval_s) + continue + + for room, receiver in input_items: + try: + status = receiver.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val == STATUS_SUCCESS: + with self._lock: + self._input_watch.pop(room, None) + self._ready_inputs.append(room) + elif status_val == STATUS_FAILED: + with self._lock: + self._input_watch.pop(room, None) + self._failed_inputs.append(room) + + for room, sender in output_items: + try: + status = sender.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + with self._lock: + self._output_watch.pop(room, None) + self._completed_outputs.append((room, status_val)) + + time.sleep(self.poll_interval_s) + + +class _RemoteDataMgrSidecarClient: + """Service-side client for controller-managed sidecar process.""" + + def __init__(self, push_addr: str, req_addr: str, poll_interval_s: float = 
0.01): + self.push_addr = str(push_addr) + self.req_addr = str(req_addr) + self.poll_interval_s = max(float(poll_interval_s), 0.001) + + self._context = zmq.Context.instance() + self._push = self._context.socket(zmq.PUSH) + self._push.connect(self.push_addr) + + self._req = self._context.socket(zmq.REQ) + self._req.connect(self.req_addr) + self._req.setsockopt(zmq.RCVTIMEO, 1500) + self._req.setsockopt(zmq.SNDTIMEO, 1500) + + self._req_lock = threading.Lock() + self._watch_lock = threading.Lock() + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._started = False + + self._input_watch: dict[int, DataReceiver] = {} + self._output_watch: dict[int, DataSender] = {} + + def start(self): + if self._thread is not None and self._thread.is_alive(): + self._started = True + return + self._stop_event.clear() + self._thread = threading.Thread(target=self._run, name="data-mgr-sidecar-remote-client", daemon=True) + self._thread.start() + self._started = True + + def stop(self): + self._stop_event.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=1.0) + self._thread = None + self._started = False + + try: + with self._watch_lock: + rooms_in = list(self._input_watch.keys()) + rooms_out = list(self._output_watch.keys()) + self._input_watch.clear() + self._output_watch.clear() + for room in rooms_in: + self._push_cmd({"cmd": "unwatch_input", "room": int(room)}) + for room in rooms_out: + self._push_cmd({"cmd": "unwatch_output", "room": int(room)}) + except Exception: + pass + + def watch_input(self, room: int, receiver: DataReceiver): + if not self._started: + self.start() + room = int(room) + with self._watch_lock: + self._input_watch[room] = receiver + self._push_cmd({"cmd": "watch_input", "room": room}) + + def unwatch_input(self, room: int): + room = int(room) + with self._watch_lock: + self._input_watch.pop(room, None) + self._push_cmd({"cmd": "unwatch_input", "room": room}) + + def 
watch_output(self, room: int, sender: DataSender): + if not self._started: + self.start() + room = int(room) + with self._watch_lock: + self._output_watch[room] = sender + self._push_cmd({"cmd": "watch_output", "room": room}) + + def unwatch_output(self, room: int): + room = int(room) + with self._watch_lock: + self._output_watch.pop(room, None) + self._push_cmd({"cmd": "unwatch_output", "room": room}) + + def pop_ready_inputs(self) -> list[int]: + data = self._req_cmd("pop_ready_inputs") + if isinstance(data, list): + return [int(v) for v in data] + return [] + + def pop_failed_inputs(self) -> list[int]: + data = self._req_cmd("pop_failed_inputs") + if isinstance(data, list): + return [int(v) for v in data] + return [] + + def pop_completed_outputs(self) -> list[tuple[int, int]]: + data = self._req_cmd("pop_completed_outputs") + if not isinstance(data, list): + return [] + items: list[tuple[int, int]] = [] + for item in data: + if isinstance(item, (list, tuple)) and len(item) == 2: + items.append((int(item[0]), int(item[1]))) + return items + + def get_pending_counts(self) -> dict[str, int]: + data = self._req_cmd("get_pending_counts") + if not isinstance(data, dict): + return { + "input_watch": 0, + "output_watch": 0, + "ready_inputs": 0, + "failed_inputs": 0, + "completed_outputs": 0, + } + return { + "input_watch": int(data.get("input_watch", 0)), + "output_watch": int(data.get("output_watch", 0)), + "ready_inputs": int(data.get("ready_inputs", 0)), + "failed_inputs": int(data.get("failed_inputs", 0)), + "completed_outputs": int(data.get("completed_outputs", 0)), + } + + def _push_cmd(self, cmd: dict): + try: + self._push.send_pyobj(cmd) + except Exception: + pass + + def _req_cmd(self, cmd: str, payload: dict[str, Any] | None = None): + try: + req_payload: dict[str, Any] = {"cmd": str(cmd)} + if isinstance(payload, dict): + req_payload.update(payload) + with self._req_lock: + self._req.send_pyobj(req_payload) + reply = self._req.recv_pyobj() + if 
isinstance(reply, dict) and reply.get("ok", False): + return reply.get("data") + return None + except Exception: + return None + + def init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any] | None: + data = self._req_cmd( + "init_transformer_output_room", + { + "room": int(room), + "sender_engine_rank": int(sender_engine_rank), + "receiver_engine_rank": int(receiver_engine_rank), + "data_lens": [int(v) for v in data_lens], + "bootstrap_addr": str(bootstrap_addr), + }, + ) + if isinstance(data, dict): + return data + return None + + def send_transformer_output_room(self, room: int) -> bool: + data = self._req_cmd("send_transformer_output_room", {"room": int(room)}) + return bool(data) + + def get_transformer_output_status(self, room: int) -> int: + data = self._req_cmd("get_transformer_output_status", {"room": int(room)}) + if data is None: + return STATUS_FAILED + try: + return int(data) + except Exception: + return STATUS_FAILED + + def remove_transformer_output_room(self, room: int) -> bool: + data = self._req_cmd("remove_transformer_output_room", {"room": int(room)}) + return bool(data) + + def get_transformer_output_backlog(self) -> dict[str, int]: + data = self._req_cmd("get_transformer_output_backlog") + if not isinstance(data, dict): + return { + "request_pool": 0, + "waiting_pool": 0, + "request_status": 0, + } + return { + "request_pool": int(data.get("request_pool", 0)), + "waiting_pool": int(data.get("waiting_pool", 0)), + "request_status": int(data.get("request_status", 0)), + } + + def get_transformer_output_identity(self, room: int | None = None) -> dict[str, Any] | None: + payload: dict[str, Any] = {} + if room is not None: + payload["room"] = int(room) + data = self._req_cmd("get_transformer_output_identity", payload) + if isinstance(data, dict): + return { + "host": str(data.get("host", "")), + "session_id": 
str(data.get("session_id", "")), + } + return None + + def _run(self): + while not self._stop_event.is_set(): + with self._watch_lock: + input_items = list(self._input_watch.items()) + output_items = list(self._output_watch.items()) + + if not input_items and not output_items: + time.sleep(self.poll_interval_s) + continue + + for room, receiver in input_items: + try: + status = receiver.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + with self._watch_lock: + self._input_watch.pop(room, None) + self._push_cmd( + { + "cmd": "input_status", + "room": int(room), + "status": status_val, + } + ) + + for room, sender in output_items: + try: + status = sender.poll() + except Exception: + status = STATUS_FAILED + + status_val = int(status) + if status_val in (STATUS_SUCCESS, STATUS_FAILED): + with self._watch_lock: + self._output_watch.pop(room, None) + self._push_cmd( + { + "cmd": "output_status", + "room": int(room), + "status": status_val, + } + ) + + time.sleep(self.poll_interval_s) + + +class DataMgrSidecar: + """Service-facing sidecar facade. + + If controller-side endpoints exist, use controller-managed remote sidecar. + Otherwise fallback to in-process local sidecar for standalone runs. 
+ """ + + def __init__(self, poll_interval_s: float = 0.01): + push_addr = str(os.getenv("LIGHTX2V_SIDECAR_PUSH_ADDR", "")).strip() + req_addr = str(os.getenv("LIGHTX2V_SIDECAR_REQ_ADDR", "")).strip() + + if push_addr and req_addr: + self._impl = _RemoteDataMgrSidecarClient(push_addr=push_addr, req_addr=req_addr, poll_interval_s=poll_interval_s) + else: + self._impl = _LocalDataMgrSidecar(poll_interval_s=poll_interval_s) + + def start(self): + self._impl.start() + + def stop(self): + self._impl.stop() + + def watch_input(self, room: int, receiver: DataReceiver): + self._impl.watch_input(room, receiver) + + def unwatch_input(self, room: int): + self._impl.unwatch_input(room) + + def watch_output(self, room: int, sender: DataSender): + self._impl.watch_output(room, sender) + + def unwatch_output(self, room: int): + self._impl.unwatch_output(room) + + def pop_ready_inputs(self) -> list[int]: + return self._impl.pop_ready_inputs() + + def pop_failed_inputs(self) -> list[int]: + return self._impl.pop_failed_inputs() + + def pop_completed_outputs(self) -> list[tuple[int, int]]: + return self._impl.pop_completed_outputs() + + def get_pending_counts(self) -> dict[str, int]: + return self._impl.get_pending_counts() + + def init_transformer_output_room( + self, + room: int, + sender_engine_rank: int, + receiver_engine_rank: int, + data_lens: list[int], + bootstrap_addr: str, + ) -> dict[str, Any] | None: + return self._impl.init_transformer_output_room( + room=room, + sender_engine_rank=sender_engine_rank, + receiver_engine_rank=receiver_engine_rank, + data_lens=data_lens, + bootstrap_addr=bootstrap_addr, + ) + + def send_transformer_output_room(self, room: int) -> bool: + return self._impl.send_transformer_output_room(room) + + def get_transformer_output_status(self, room: int) -> int: + return self._impl.get_transformer_output_status(room) + + def remove_transformer_output_room(self, room: int) -> bool: + return self._impl.remove_transformer_output_room(room) + + def 
get_transformer_output_backlog(self) -> dict[str, int]: + return self._impl.get_transformer_output_backlog() + + def get_transformer_output_identity(self, room: int | None = None) -> dict[str, Any] | None: + return self._impl.get_transformer_output_identity(room) + + +def _run_server_from_cli(): + parser = argparse.ArgumentParser(description="Run DataMgr sidecar server process") + parser.add_argument("--push-addr", type=str, required=True) + parser.add_argument("--req-addr", type=str, required=True) + args = parser.parse_args() + + server = DataMgrSidecarServer(push_addr=args.push_addr, req_addr=args.req_addr) + server.run_forever() + + +if __name__ == "__main__": + _run_server_from_cli() diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index b558470fa..a3706db74 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -1,5 +1,7 @@ import hashlib import json +import math +import os import threading import time from collections import deque @@ -7,12 +9,13 @@ import torch -from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DisaggregationMode, DisaggregationPhase, ReqManager +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataReceiver, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.services.data_mgr_sidecar import DataMgrSidecar from lightx2v.disagg.utils import estimate_transformer_buffer_sizes, load_wan_vae_decoder from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.utils import save_to_video, seed_all, wan_vae_to_comfy @@ -28,9 +31,13 @@ def 
__init__(self, config: dict): self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) self._phase2_rdma_client: Optional[RDMAClient] = None self._phase2_rdma_buffer: Optional[RDMABuffer] = None + self._centralized_request_mgr = ReqManager() + self._centralized_request_port = REQUEST_POLLING_PORT + self.decoder_engine_rank + data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) + monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) - self._phase2_server_ip = str(self.config.get("rdma_phase2_host", "127.0.0.1")) + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", data_bootstrap_addr)) self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", "5568")) self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size @@ -46,7 +53,7 @@ def __init__(self, config: dict): self.reporter = Reporter( service_type="decoder", gpu_id=self.decoder_engine_rank, - bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.decoder_engine_rank}", + bind_address=f"tcp://{monitor_bind_host}:{MONITOR_POLLING_PORT + self.decoder_engine_rank}", ) self._queue_metrics_lock = threading.Lock() self._queue_metrics: dict[str, Any] = { @@ -61,6 +68,8 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self._data_mgr_sidecar = DataMgrSidecar() + self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() def _get_queue_metrics(self) -> dict[str, Any]: @@ -114,7 +123,10 @@ def _ensure_phase2_request_buffer(self) -> bool: return True def init(self, config): - self.config = config + self._sync_runtime_config(config) + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", self.encoder_engine_rank)) 
+ self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", self.transformer_engine_rank)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", self.decoder_engine_rank)) shared_slots = int(self.config.get("rdma_buffer_slots", self._phase2_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._phase2_server_ip = str(self.config.get("rdma_phase2_host", self._phase2_server_ip)) @@ -131,10 +143,11 @@ def init(self, config): if data_bootstrap_addr is None or data_bootstrap_room is None: return - try: - self._ensure_phase2_request_buffer() - except Exception: - self.logger.exception("Failed to connect phase2 RDMA buffer, will retry") + if str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() not in {"1", "true", "yes", "on"}: + try: + self._ensure_phase2_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase2 RDMA buffer, will retry") buffer_sizes = estimate_transformer_buffer_sizes(self.config) request = AllocationRequest( @@ -181,6 +194,9 @@ def alloc_memory(self, request: AllocationRequest) -> MemoryHandle: def process(self, config): self.logger.info("Starting processing in DecoderService...") room = config.get("data_bootstrap_room", 0) + decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) + decoder_metrics["compute_start_ts"] = time.time() + strict_meta_hash_check = str(os.getenv("LIGHTX2V_STRICT_META_HASH", "0")).strip().lower() in {"1", "true", "yes", "on"} room_buffers = self._rdma_buffers.get(room) receiver = self.data_receiver.get(room) @@ -207,15 +223,59 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: raise RuntimeError("Phase2 RDMA buffers require [latents, meta] entries.") meta_buf = room_buffers[1] - meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() - meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8") if 
meta_bytes else "" - if not meta_str: - raise ValueError("missing latents metadata from transformer") - meta = json.loads(meta_str) + + def _read_phase2_meta() -> tuple[dict, str]: + meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() + meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8", errors="ignore") if meta_bytes else "" + if not meta_str: + raise ValueError("missing latents metadata from transformer") + parsed = json.loads(meta_str) + if not isinstance(parsed, dict): + raise ValueError(f"phase2 metadata type mismatch: {type(parsed)}") + return parsed, meta_str + + def _infer_latents_shape_from_config() -> tuple[int, int, int, int]: + z_dim = int(config.get("vae_z_dim", 16)) + vae_stride = config.get("vae_stride", (4, 8, 8)) + stride_t = int(vae_stride[0]) + stride_h = int(vae_stride[1]) + stride_w = int(vae_stride[2]) + target_video_length = int(config.get("target_video_length", 81)) + target_height = int(config.get("target_height", 480)) + target_width = int(config.get("target_width", 832)) + + t_prime = 1 + (target_video_length - 1) // stride_t + h_prime = int(math.ceil(target_height / stride_h)) + w_prime = int(math.ceil(target_width / stride_w)) + return (z_dim, t_prime, h_prime, w_prime) + + meta = None + meta_str = "" + for attempt in range(3): + try: + meta, meta_str = _read_phase2_meta() + break + except Exception as exc: + if attempt < 2: + # Guard against rare stale/partial metadata visibility. + time.sleep(0.02) + continue + self.logger.warning( + "Invalid phase2 metadata for room=%s, fallback to config-derived shape. 
err=%s raw_prefix=%r", + room, + exc, + meta_str[:128], + ) + meta = { + "latents_shape": list(_infer_latents_shape_from_config()), + "latents_dtype": str(GET_DTYPE()), + "latents_hash": None, + } latents_shape_val = meta.get("latents_shape") if not isinstance(latents_shape_val, list) or len(latents_shape_val) != 4: - raise ValueError("invalid latents_shape in phase2 metadata") + latents_shape_val = list(_infer_latents_shape_from_config()) + self.logger.warning("phase2 metadata missing/invalid latents_shape for room=%s, using fallback shape=%s", room, latents_shape_val) latent_shape = tuple(int(value) for value in latents_shape_val) dtype_map = { @@ -229,7 +289,10 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: if list(latents.shape) != meta.get("latents_shape"): raise ValueError("latents shape mismatch between transformer and decoder") if meta.get("latents_hash") is not None and _sha256_tensor(latents) != meta.get("latents_hash"): - raise ValueError("latents hash mismatch between transformer and decoder") + msg = "latents hash mismatch between transformer and decoder" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) latents = latents.to(torch.device(AI_DEVICE)).contiguous() if self.vae_decoder is None: @@ -238,6 +301,7 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: self.logger.info("Decoding latents in DecoderService...") gen_video = self.vae_decoder.decode(latents.to(GET_DTYPE())) gen_video_final = wan_vae_to_comfy(gen_video) + decoder_metrics["compute_end_ts"] = time.time() save_path = config.get("save_path") if save_path is None: @@ -245,6 +309,7 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: self.logger.info(f"Saving video to {save_path}...") save_to_video(gen_video_final, save_path, fps=config.get("fps", 16), method="ffmpeg") + decoder_metrics["output_enqueued_ts"] = time.time() 
self.logger.info("Done!") return save_path @@ -258,6 +323,7 @@ def remove(self, room: int): self.release_memory(room) self.data_receiver.pop(room, None) + self._data_mgr_sidecar.unwatch_input(room) if self.data_mgr is None: return @@ -283,6 +349,7 @@ def run(self, stop_event=None): while True: transfer_sizes = self.data_mgr.get_backlog_counts() if self.data_mgr is not None else {"request_pool": 0, "waiting_pool": 0} + sidecar_sizes = self._data_mgr_sidecar.get_pending_counts() self._update_queue_metrics( { "req_queue": len(req_queue), @@ -292,25 +359,43 @@ def run(self, stop_event=None): { "request_pool": int(transfer_sizes.get("request_pool", 0)), "waiting_pool": int(transfer_sizes.get("waiting_pool", 0)), + "sidecar_input_watch": int(sidecar_sizes.get("input_watch", 0)), }, ) - if self._phase2_rdma_buffer is None: - try: - self._ensure_phase2_request_buffer() - except Exception: - self.logger.exception("Failed to connect phase2 request RDMA buffer, will retry") - - if self._phase2_rdma_buffer is not None: - packet = self._phase2_rdma_buffer.consume() - if packet is not None: - if isinstance(packet, dict) and "request_config" in packet: - config = dict(packet.get("request_config") or {}) - config["transformer_node_address"] = packet.get("transformer_node_address", "127.0.0.1") - else: - config = packet - self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + centralized_request_mode = str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"} + if centralized_request_mode: + config = self._centralized_request_mgr.receive_non_block(self._centralized_request_port) + if config is not None: + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete request packet from ZMQ: %s", config) + continue + decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) + 
decoder_metrics["request_received_ts"] = time.time() + self.logger.info("Received request config from ZMQ: %s", {k: v for k, v in config.items()}) req_queue.append(config) + else: + if self._phase2_rdma_buffer is None: + try: + self._ensure_phase2_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase2 request RDMA buffer, will retry") + + if self._phase2_rdma_buffer is not None: + packet = self._phase2_rdma_buffer.consume() + if packet is not None: + if isinstance(packet, dict) and "request_config" in packet: + config = dict(packet.get("request_config") or {}) + config["transformer_node_address"] = packet.get("transformer_node_address", "127.0.0.1") + else: + config = packet + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete phase2 packet from RDMA buffer: %s", packet) + continue + decoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("decoder", {}) + decoder_metrics["request_received_ts"] = time.time() + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + req_queue.append(config) if req_queue: config = req_queue.popleft() @@ -318,27 +403,23 @@ def run(self, stop_event=None): try: self.init(config) waiting_queue[room] = config + receiver = self.data_receiver.get(room) + if receiver is None: + raise RuntimeError(f"DataReceiver is not initialized for room={room}") + self._data_mgr_sidecar.watch_input(room, receiver) except Exception: self.logger.exception("Failed to initialize request for room=%s", room) self.remove(room) - ready_rooms: List[int] = [] - failed_rooms: List[int] = [] - for room in list(waiting_queue.keys()): - receiver = self.data_receiver.get(room) - if receiver is None: - failed_rooms.append(room) - continue - - status = receiver.poll() - if status == DataPoll.Success: - ready_rooms.append(room) - elif status == DataPoll.Failed: - failed_rooms.append(room) + 
ready_rooms = self._data_mgr_sidecar.pop_ready_inputs() + failed_rooms = self._data_mgr_sidecar.pop_failed_inputs() for room in ready_rooms: + config = waiting_queue.pop(room, None) + if config is None: + continue self.logger.info("Latents received successfully in DecoderService for room=%s.", room) - exec_queue.append((room, waiting_queue.pop(room))) + exec_queue.append((room, config)) for room in failed_rooms: waiting_queue.pop(room, None) @@ -359,6 +440,7 @@ def run(self, stop_event=None): "ok": True, "data_bootstrap_room": int(room), "save_path": save_path, + "request_metrics": config.get("request_metrics"), }, ) except Exception: @@ -374,6 +456,7 @@ def run(self, stop_event=None): "data_bootstrap_room": int(room), "save_path": None, "error": "decoder process failed", + "request_metrics": config.get("request_metrics"), }, ) finally: diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 057dde5cc..5b86b81df 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -1,19 +1,23 @@ import hashlib import json +import os import threading import time from collections import deque from typing import Any, Dict, List, Optional +from urllib.error import URLError +from urllib.request import Request, urlopen import numpy as np import torch -from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.services.data_mgr_sidecar 
import DataMgrSidecar from lightx2v.disagg.utils import ( estimate_encoder_buffer_sizes, load_wan_image_encoder, @@ -35,20 +39,25 @@ def __init__(self, config: dict): self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", "2")) self._request_rdma_client: Optional[RDMAClient] = None self._request_rdma_buffer: Optional[RDMABuffer] = None + self._centralized_request_mgr = ReqManager() + self._centralized_request_port = REQUEST_POLLING_PORT + self.encoder_engine_rank self._phase1_rdma_client: Optional[RDMAClient] = None self._phase1_rdma_buffer: Optional[RDMABuffer] = None + data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) + monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) - self._request_server_ip = str(self.config.get("rdma_request_host", "127.0.0.1")) + self._request_server_ip = str(self.config.get("rdma_request_host", data_bootstrap_addr)) self._request_handshake_port = int(self.config.get("rdma_request_handshake_port", "5566")) self._request_slots = shared_slots self._request_slot_size = shared_slot_size - self._phase1_server_ip = str(self.config.get("rdma_phase1_host", "127.0.0.1")) + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", data_bootstrap_addr)) self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", "5567")) self._phase1_slots = shared_slots self._phase1_slot_size = shared_slot_size self._last_request_connect_retry_ts = 0.0 self._last_phase1_connect_retry_ts = 0.0 + self._centralized_request_mode = str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"} self.text_encoder = None self.image_encoder = None self.vae_encoder = None @@ -61,7 +70,7 @@ def __init__(self, config: dict): self.reporter = Reporter( service_type="encoder", gpu_id=self.encoder_engine_rank, - 
bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.encoder_engine_rank}", + bind_address=f"tcp://{monitor_bind_host}:{MONITOR_POLLING_PORT + self.encoder_engine_rank}", ) self._queue_metrics_lock = threading.Lock() self._queue_metrics: dict[str, Any] = { @@ -76,8 +85,105 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self._data_mgr_sidecar = DataMgrSidecar() + self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() + def _wait_sender_success(self, room: int, sender: DataSender): + while True: + status = sender.poll() + if status == DataPoll.Success: + return + if status == DataPoll.Failed: + raise RuntimeError(f"DataSender transfer failed for room={room}") + time.sleep(0.001) + + def _report_stage_metrics_to_controller(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_result_host", "127.0.0.1")) + controller_port_raw = config.get("controller_result_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_metrics = config.get("request_metrics") + if not isinstance(request_metrics, dict): + return + + stage_metrics = request_metrics.get("stages", {}).get(stage_name) + if not isinstance(stage_metrics, dict): + return + + payload_request_metrics: dict[str, Any] = { + "request_id": request_metrics.get("request_id", config.get("data_bootstrap_room")), + "stages": {stage_name: stage_metrics}, + } + if request_metrics.get("controller_send_ts") is not None: + payload_request_metrics["controller_send_ts"] = request_metrics.get("controller_send_ts") + + self._centralized_request_mgr.send( + controller_host, + controller_port, + { + "message_type": "stage_metrics", + "stage_name": stage_name, + "data_bootstrap_room": 
int(config.get("data_bootstrap_room", 0)), + "request_metrics": payload_request_metrics, + }, + ) + self.logger.info( + "Reported %s stage metrics to controller: room=%s target=%s:%s", + stage_name, + config.get("data_bootstrap_room"), + controller_host, + controller_port, + ) + + def _wait_for_controller_ok(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_control_host", config.get("controller_result_host", "127.0.0.1"))) + controller_port_raw = config.get("controller_control_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_body = json.dumps( + { + "control": "OK", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + } + ).encode("utf-8") + request = Request( + f"http://{controller_host}:{controller_port}/ok", + data=request_body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urlopen(request, timeout=10) as response: + reply = json.loads(response.read().decode("utf-8")) + if not isinstance(reply, dict) or not reply.get("ok", False): + raise RuntimeError(f"unexpected controller OK reply: {reply}") + except URLError: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + except Exception: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + def _get_queue_metrics(self) -> dict[str, Any]: with self._queue_metrics_lock: queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) @@ -142,6 +248,21 @@ def _ensure_request_buffer(self) -> bool: ) return True + def _reconnect_request_buffer(self): + self._request_rdma_buffer = None + self._last_request_connect_retry_ts = 0.0 + + if self._request_rdma_client is not None: 
+ sock = getattr(self._request_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._request_rdma_client.sock = None + + self._ensure_request_buffer() + def _ensure_phase1_meta_buffer(self) -> bool: if self._phase1_rdma_buffer is not None: return True @@ -184,8 +305,68 @@ def _ensure_phase1_meta_buffer(self) -> bool: ) return True + def _reconnect_phase1_meta_buffer(self): + self._phase1_rdma_buffer = None + self._last_phase1_connect_retry_ts = 0.0 + + if self._phase1_rdma_client is not None: + sock = getattr(self._phase1_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._phase1_rdma_client.sock = None + + self._ensure_phase1_meta_buffer() + + def _produce_phase1_request_with_retry(self, room: int, payload: dict[str, Any]): + retries = max(1, int(os.getenv("RDMA_PHASE1_PRODUCE_RETRIES", "3"))) + retry_delay_s = max(0.01, float(os.getenv("RDMA_PHASE1_PRODUCE_RETRY_DELAY_S", "0.2"))) + last_exc: Optional[Exception] = None + + for attempt in range(1, retries + 1): + try: + if self._phase1_rdma_buffer is None: + self._ensure_phase1_meta_buffer() + if self._phase1_rdma_buffer is None: + raise RuntimeError("phase1 RDMA buffer is not ready") + self._phase1_rdma_buffer.produce(payload) + return + except Exception as exc: + last_exc = exc + self.logger.warning( + "Phase1 RDMA produce failed for room=%s attempt=%s/%s host=%s port=%s: %s", + room, + attempt, + retries, + self._phase1_server_ip, + self._phase1_handshake_port, + exc, + ) + if attempt >= retries: + break + try: + self._reconnect_phase1_meta_buffer() + except Exception as reconnect_exc: + self.logger.warning( + "Phase1 RDMA reconnect failed for room=%s attempt=%s/%s host=%s port=%s: %s", + room, + attempt, + retries, + self._phase1_server_ip, + self._phase1_handshake_port, + reconnect_exc, + ) + time.sleep(retry_delay_s) + + raise RuntimeError(f"Failed to produce phase1 RDMA request for room={room} after 
{retries} attempts") from last_exc + def init(self, config): - self.config = config + self._sync_runtime_config(config) + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", self.encoder_engine_rank)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", self.transformer_engine_rank)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", self.decoder_engine_rank)) shared_slots = int(self.config.get("rdma_buffer_slots", self._request_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._request_server_ip = str(self.config.get("rdma_request_host", self._request_server_ip)) @@ -238,13 +419,6 @@ def init(self, config): self.data_mgr.init(data_args, data_bootstrap_room) self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr, data_bootstrap_addr, data_bootstrap_room) - phase1_meta = { - "request_config": dict(self.config), - "encoder_node_address": self.data_mgr.get_localhost(), - "encoder_session_id": self.data_mgr.get_session_id(), - } - self._phase1_rdma_buffer.produce(phase1_meta) - def load_models(self): self.logger.info("Loading Encoder Models...") @@ -345,6 +519,8 @@ def process(self, config): """ self.logger.info("Starting processing in EncoderService...") room = int(config.get("data_bootstrap_room", 0)) + encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) + encoder_metrics["compute_start_ts"] = time.time() room_buffers = self._rdma_buffers.get(room) sender = self.data_sender.get(room) @@ -411,6 +587,7 @@ def process(self, config): else: raise ValueError(f"Unsupported task: {task}") + encoder_metrics["compute_end_ts"] = time.time() self.logger.info("Encode processing completed. 
Preparing to send data...") if self.data_mgr is not None and sender is not None: @@ -499,7 +676,20 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: meta_buf[: len(meta_bytes)].copy_(torch.from_numpy(np.frombuffer(meta_bytes, dtype=np.uint8))) buffer_ptrs = [buf.data_ptr() for buf in room_buffers] + # Publish phase1 request metadata after compute so downstream can see latest metrics. + encoder_metrics["output_enqueued_ts"] = time.time() + if self._centralized_request_mode: + self._report_stage_metrics_to_controller("encoder", config) + self._wait_for_controller_ok("encoder", config) + phase1_meta = { + "request_config": dict(config), + "encoder_node_address": self.data_mgr.get_localhost(), + "encoder_session_id": self.data_mgr.get_session_id(), + } + self._produce_phase1_request_with_retry(room, phase1_meta) sender.send(buffer_ptrs) + if self.sync_comm: + self._wait_sender_success(room, sender) def release_memory(self, room: int): """ @@ -513,6 +703,7 @@ def remove(self, room: int): self.release_memory(room) self.data_sender.pop(room, None) + self._data_mgr_sidecar.unwatch_output(room) if self.data_mgr is None: return @@ -536,33 +727,61 @@ def release(self): def run(self, stop_event=None): req_queue = deque() exec_queue = deque() - complete_queue: Dict[int, dict] = {} + complete_queue: set[int] = set() while True: transfer_sizes = self.data_mgr.get_backlog_counts() if self.data_mgr is not None else {"request_pool": 0, "waiting_pool": 0} + sidecar_sizes = self._data_mgr_sidecar.get_pending_counts() self._update_queue_metrics( { "req_queue": len(req_queue), "exec_queue": len(exec_queue), - "complete_queue": len(complete_queue), }, { + "complete_queue": len(complete_queue), "request_pool": int(transfer_sizes.get("request_pool", 0)), "waiting_pool": int(transfer_sizes.get("waiting_pool", 0)), + "sidecar_output_watch": int(sidecar_sizes.get("output_watch", 0)), }, ) - if self._request_rdma_buffer is None: - try: - self._ensure_request_buffer() - 
except Exception: - self.logger.exception("Failed to connect request RDMA buffer, will retry") - - if self._request_rdma_buffer is not None: - config = self._request_rdma_buffer.consume() + if self._centralized_request_mode: + config = self._centralized_request_mgr.receive_non_block(self._centralized_request_port) if config is not None: - self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items()}) + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete request packet from ZMQ: %s", config) + continue + encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) + encoder_metrics["request_received_ts"] = time.time() + self.logger.info("Received request config from ZMQ: %s", {k: v for k, v in config.items()}) req_queue.append(config) + else: + if self._request_rdma_buffer is None: + try: + self._ensure_request_buffer() + except Exception: + self.logger.exception("Failed to connect request RDMA buffer, will retry") + + if self._request_rdma_client is not None and self._request_rdma_client.has_qp_error(): + self.logger.warning( + "Request RDMA client entered error state, reconnecting: %s", + self._request_rdma_client.last_wc_error_message(), + ) + try: + self._reconnect_request_buffer() + except Exception: + self.logger.exception("Failed to reconnect request RDMA buffer after QP error") + + if self._request_rdma_buffer is not None: + config = self._request_rdma_buffer.consume() + if config is not None: + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete request packet from RDMA buffer: %s", config) + continue + encoder_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("encoder", {}) + encoder_metrics["request_received_ts"] = time.time() + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in 
config.items()}) + req_queue.append(config) if req_queue: config = req_queue.popleft() @@ -578,28 +797,25 @@ def run(self, stop_event=None): room, config = exec_queue.popleft() try: self.process(config) - complete_queue[room] = config + if self.sync_comm: + self.remove(room) + else: + sender = self.data_sender.get(room) + if sender is None: + self.logger.error("DataSender is missing for room=%s", room) + self.remove(room) + else: + self._data_mgr_sidecar.watch_output(room, sender) + complete_queue.add(room) except Exception: self.logger.exception("Failed to process request for room=%s", room) - complete_queue.pop(room, None) self.remove(room) - completed_rooms: List[int] = [] - for room in list(complete_queue.keys()): - sender = self.data_sender.get(room) - if sender is None: - completed_rooms.append(room) - continue - - status = sender.poll() - if status == DataPoll.Success: - completed_rooms.append(room) - elif status == DataPoll.Failed: + completed_outputs = self._data_mgr_sidecar.pop_completed_outputs() + for room, status in completed_outputs: + if status == DataPoll.Failed: self.logger.error("DataSender transfer failed for room=%s", room) - completed_rooms.append(room) - - for room in completed_rooms: - complete_queue.pop(room, None) + complete_queue.discard(room) self.remove(room) if stop_event is not None and stop_event.is_set() and not req_queue and not exec_queue and not complete_queue: diff --git a/lightx2v/disagg/services/instance_proxy.py b/lightx2v/disagg/services/instance_proxy.py new file mode 100644 index 000000000..c2a319809 --- /dev/null +++ b/lightx2v/disagg/services/instance_proxy.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import argparse +import os +import signal +import subprocess +import time +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +import zmq + + +class InstanceProxyServer: + """Remote process proxy that creates/stops local disagg service processes. 
+ + This server is intended to run on remote nodes where the local runtime + environment is trusted. The controller sends simple commands to this proxy + instead of assembling remote launch scripts for every instance operation. + """ + + def __init__(self, bind_addr: str, workdir: str, log_dir: str): + self.bind_addr = str(bind_addr) + self.workdir = str(workdir) + self.log_dir = str(log_dir) + self._running = True + self._managed: dict[int, subprocess.Popen] = {} + + def _normalize_env(self, extra_env: Any, cuda_device: str) -> dict[str, str]: + env = os.environ.copy() + if isinstance(extra_env, Mapping): + for key, value in extra_env.items(): + env[str(key)] = str(value) + env["CUDA_VISIBLE_DEVICES"] = str(cuda_device) + env["PYTHONUNBUFFERED"] = "1" + return env + + def _terminate_pid(self, pid: int, timeout_seconds: float) -> bool: + process = self._managed.get(pid) + timeout_seconds = max(1.0, float(timeout_seconds)) + + if process is not None: + if process.poll() is not None: + self._managed.pop(pid, None) + return True + + try: + os.killpg(process.pid, signal.SIGTERM) + except Exception: + process.terminate() + + deadline = time.time() + timeout_seconds + while time.time() < deadline: + if process.poll() is not None: + self._managed.pop(pid, None) + return True + time.sleep(0.1) + + try: + os.killpg(process.pid, signal.SIGKILL) + except Exception: + process.kill() + + try: + process.wait(timeout=2.0) + except Exception: + pass + self._managed.pop(pid, None) + return process.poll() is not None + + # Fallback for pids created before current proxy process lifetime. 
+ try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + return True + except Exception: + return False + + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + os.kill(pid, 0) + except ProcessLookupError: + return True + except Exception: + break + time.sleep(0.1) + + try: + os.kill(pid, signal.SIGKILL) + return True + except ProcessLookupError: + return True + except Exception: + return False + + def _start_instance(self, msg: dict[str, Any]) -> dict[str, Any]: + instance_type = str(msg.get("instance_type", "")) + engine_rank = int(msg.get("engine_rank", -1)) + cuda_device = str(msg.get("cuda_device", "0")) + python_executable = str(msg.get("python_executable", "python")) + service_argv = msg.get("service_argv", []) + sidecar_push_addr = str(msg.get("sidecar_push_addr", "")).strip() + sidecar_req_addr = str(msg.get("sidecar_req_addr", "")).strip() + service_log_path = str(msg.get("service_log_path", "")).strip() + sidecar_log_path = str(msg.get("sidecar_log_path", "")).strip() + workdir = str(msg.get("workdir", self.workdir)) + log_dir = str(msg.get("log_dir", self.log_dir)) + extra_env = msg.get("env", {}) + + if not instance_type: + raise ValueError("instance_type is required") + if engine_rank < 0: + raise ValueError("engine_rank must be non-negative") + if not isinstance(service_argv, list) or not service_argv: + raise ValueError("service_argv must be a non-empty list") + if not sidecar_push_addr or not sidecar_req_addr: + raise ValueError("sidecar_push_addr and sidecar_req_addr are required") + + if not service_log_path: + service_log_path = f"{log_dir}/{instance_type}_{engine_rank}_service.log" + if not sidecar_log_path: + sidecar_log_path = f"{log_dir}/{instance_type}_{engine_rank}_sidecar.log" + + os.makedirs(log_dir, exist_ok=True) + os.makedirs(Path(service_log_path).parent, exist_ok=True) + os.makedirs(Path(sidecar_log_path).parent, exist_ok=True) + + sidecar_env = self._normalize_env(extra_env, cuda_device) + 
service_env = self._normalize_env(extra_env, cuda_device) + service_env["LIGHTX2V_SIDECAR_PUSH_ADDR"] = sidecar_push_addr + service_env["LIGHTX2V_SIDECAR_REQ_ADDR"] = sidecar_req_addr + + sidecar_cmd = [ + python_executable, + "-m", + "lightx2v.disagg.services.data_mgr_sidecar", + "--push-addr", + sidecar_push_addr, + "--req-addr", + sidecar_req_addr, + ] + service_cmd = [python_executable, *[str(part) for part in service_argv]] + + with open(sidecar_log_path, "w", encoding="utf-8") as sidecar_log: + sidecar_proc = subprocess.Popen( + sidecar_cmd, + cwd=workdir, + env=sidecar_env, + stdout=sidecar_log, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + time.sleep(0.3) + if sidecar_proc.poll() is not None: + raise RuntimeError(f"failed to start sidecar process, exited with code={sidecar_proc.returncode}") + + with open(service_log_path, "w", encoding="utf-8") as service_log: + service_proc = subprocess.Popen( + service_cmd, + cwd=workdir, + env=service_env, + stdout=service_log, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + if service_proc.poll() is not None: + self._terminate_pid(sidecar_proc.pid, timeout_seconds=2.0) + raise RuntimeError(f"failed to start service process, exited with code={service_proc.returncode}") + + self._managed[sidecar_proc.pid] = sidecar_proc + self._managed[service_proc.pid] = service_proc + + return { + "instance_type": instance_type, + "engine_rank": engine_rank, + "sidecar_pid": sidecar_proc.pid, + "service_pid": service_proc.pid, + "sidecar_log_path": sidecar_log_path, + "service_log_path": service_log_path, + } + + def handle(self, msg: dict[str, Any]) -> dict[str, Any]: + cmd = str(msg.get("cmd", "")).strip() + + if cmd == "ping": + return {"ok": True, "data": {"alive": True, "managed": len(self._managed)}} + + if cmd == "start_instance": + data = self._start_instance(msg) + return {"ok": True, "data": data} + + if cmd == "stop_pid": + pid = int(msg.get("pid", -1)) + timeout_seconds = 
float(msg.get("timeout_seconds", 10.0)) + if pid <= 0: + return {"ok": False, "error": "invalid pid"} + stopped = self._terminate_pid(pid, timeout_seconds=timeout_seconds) + return {"ok": bool(stopped), "data": {"pid": pid, "stopped": bool(stopped)}} + + if cmd == "shutdown": + self._running = False + return {"ok": True, "data": {"shutting_down": True}} + + if cmd == "stats": + managed_alive = 0 + for process in self._managed.values(): + if process.poll() is None: + managed_alive += 1 + return {"ok": True, "data": {"managed_alive": managed_alive}} + + return {"ok": False, "error": f"unsupported command: {cmd}"} + + def serve(self): + context = zmq.Context() + socket = context.socket(zmq.REP) + socket.bind(self.bind_addr) + try: + while self._running: + try: + msg = socket.recv_pyobj() + if not isinstance(msg, dict): + socket.send_pyobj({"ok": False, "error": "request must be a dict"}) + continue + reply = self.handle(msg) + except Exception as exc: + reply = {"ok": False, "error": str(exc)} + socket.send_pyobj(reply) + finally: + socket.close(0) + context.term() + for pid in list(self._managed.keys()): + self._terminate_pid(pid, timeout_seconds=2.0) + + +def main(): + parser = argparse.ArgumentParser(description="Remote instance proxy for disagg services") + parser.add_argument("--bind-addr", type=str, required=True) + parser.add_argument("--workdir", type=str, default=str(Path(__file__).resolve().parents[3])) + parser.add_argument("--log-dir", type=str, default="/tmp/lightx2v_disagg") + args = parser.parse_args() + + server = InstanceProxyServer(bind_addr=args.bind_addr, workdir=args.workdir, log_dir=args.log_dir) + server.serve() + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 33e98e5f1..f3917d88d 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -1,19 +1,24 @@ import hashlib import json +import os import threading 
import time from collections import deque +from multiprocessing import resource_tracker, shared_memory from typing import Any, List, Optional +from urllib.error import URLError +from urllib.request import Request, urlopen import numpy as np import torch -from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService +from lightx2v.disagg.services.data_mgr_sidecar import DataMgrSidecar from lightx2v.disagg.utils import ( estimate_encoder_buffer_sizes, estimate_transformer_buffer_sizes, @@ -24,10 +29,36 @@ from lightx2v.utils.utils import seed_all from lightx2v_platform.base.global_var import AI_DEVICE +_SHM_TRACKING_PATCHED = False + + +def _disable_shared_memory_tracking_for_process(): + global _SHM_TRACKING_PATCHED + if _SHM_TRACKING_PATCHED: + return + + original_register = resource_tracker.register + original_unregister = resource_tracker.unregister + + def _register(name, rtype): + if rtype == "shared_memory": + return + return original_register(name, rtype) + + def _unregister(name, rtype): + if rtype == "shared_memory": + return + return original_unregister(name, rtype) + + resource_tracker.register = _register + resource_tracker.unregister = _unregister + _SHM_TRACKING_PATCHED = True + class TransformerService(BaseService): def __init__(self, config: dict): super().__init__() + _disable_shared_memory_tracking_for_process() self.config = config self.encoder_engine_rank = 
int(self.config.get("encoder_engine_rank", 0)) self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) @@ -36,13 +67,18 @@ def __init__(self, config: dict): self._phase1_rdma_buffer: Optional[RDMABuffer] = None self._phase2_rdma_client: Optional[RDMAClient] = None self._phase2_rdma_buffer: Optional[RDMABuffer] = None + self._centralized_request_mgr = ReqManager() + self._centralized_request_port = REQUEST_POLLING_PORT + self.transformer_engine_rank + data_bootstrap_addr = str(self.config.get("data_bootstrap_addr", "127.0.0.1")) + monitor_bind_host = str(self.config.get("local_hostname", data_bootstrap_addr)) shared_slots = int(self.config.get("rdma_buffer_slots", "128")) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) - self._phase1_server_ip = str(self.config.get("rdma_phase1_host", "127.0.0.1")) + self._centralized_request_mode = str(os.getenv("IS_CENTRALIZED", "0")).strip().lower() in {"1", "true", "yes", "on"} + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", data_bootstrap_addr)) self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", "5567")) self._phase1_slots = shared_slots self._phase1_slot_size = shared_slot_size - self._phase2_server_ip = str(self.config.get("rdma_phase2_host", "127.0.0.1")) + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", data_bootstrap_addr)) self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", "5568")) self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size @@ -55,11 +91,13 @@ def __init__(self, config: dict): self.data_mgr1 = DataManager(DisaggregationPhase.PHASE1, DisaggregationMode.TRANSFORMER) self.data_mgr2 = DataManager(DisaggregationPhase.PHASE2, DisaggregationMode.TRANSFORMER) self.data_receiver: dict[int, DataReceiver] = {} - self.data_sender: dict[int, DataSender] = {} + self.data_sender: dict[int, Optional[DataSender]] = {} + self._phase2_remote_rooms: set[int] = set() 
+ self._phase2_remote_shared_memory: dict[int, list[shared_memory.SharedMemory]] = {} self.reporter = Reporter( service_type="transformer", gpu_id=self.transformer_engine_rank, - bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.transformer_engine_rank}", + bind_address=f"tcp://{monitor_bind_host}:{MONITOR_POLLING_PORT + self.transformer_engine_rank}", ) self._queue_metrics_lock = threading.Lock() self._queue_metrics: dict[str, Any] = { @@ -74,8 +112,112 @@ def __init__(self, config: dict): daemon=True, ) self._reporter_thread.start() + self._data_mgr_sidecar = DataMgrSidecar() + self.sync_comm = str(os.getenv("SYNC_COMM", "")).strip().lower() not in ("", "0", "false", "no", "off") self.load_models() + def _wait_sender_success(self, room: int, sender: DataSender): + while True: + status = sender.poll() + if status == DataPoll.Success: + return + if status == DataPoll.Failed: + raise RuntimeError(f"DataSender transfer failed for room={room}") + time.sleep(0.001) + + def _report_stage_metrics_to_controller(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_result_host", "127.0.0.1")) + controller_port_raw = config.get("controller_result_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_metrics = config.get("request_metrics") + if not isinstance(request_metrics, dict): + return + + stage_metrics = request_metrics.get("stages", {}).get(stage_name) + if not isinstance(stage_metrics, dict): + return + + payload_request_metrics: dict[str, Any] = { + "request_id": request_metrics.get("request_id", config.get("data_bootstrap_room")), + "stages": {stage_name: stage_metrics}, + } + if request_metrics.get("controller_send_ts") is not None: + payload_request_metrics["controller_send_ts"] = 
request_metrics.get("controller_send_ts") + + self._centralized_request_mgr.send( + controller_host, + controller_port, + { + "message_type": "stage_metrics", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + "request_metrics": payload_request_metrics, + }, + ) + self.logger.info( + "Reported %s stage metrics to controller: room=%s target=%s:%s", + stage_name, + config.get("data_bootstrap_room"), + controller_host, + controller_port, + ) + + def _wait_for_controller_ok(self, stage_name: str, config: dict[str, Any]): + if not self._centralized_request_mode: + return + + controller_host = str(config.get("controller_control_host", config.get("controller_result_host", "127.0.0.1"))) + controller_port_raw = config.get("controller_control_port") + if controller_port_raw is None: + return + + try: + controller_port = int(controller_port_raw) + except (TypeError, ValueError): + return + + request_body = json.dumps( + { + "control": "OK", + "stage_name": stage_name, + "data_bootstrap_room": int(config.get("data_bootstrap_room", 0)), + } + ).encode("utf-8") + request = Request( + f"http://{controller_host}:{controller_port}/ok", + data=request_body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urlopen(request, timeout=10) as response: + reply = json.loads(response.read().decode("utf-8")) + if not isinstance(reply, dict) or not reply.get("ok", False): + raise RuntimeError(f"unexpected controller OK reply: {reply}") + except URLError: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + except Exception: + self.logger.exception("Failed to wait for controller OK reply for %s room=%s", stage_name, config.get("data_bootstrap_room")) + return + + def _attach_remote_shared_memory(self, shm_name: str) -> shared_memory.SharedMemory: + # Python 3.12 supports `track=False`, which avoids duplicate cleanup from 
non-owner processes. + try: + return shared_memory.SharedMemory(name=shm_name, create=False, track=False) + except TypeError: + return shared_memory.SharedMemory(name=shm_name, create=False) + def _get_queue_metrics(self) -> dict[str, Any]: with self._queue_metrics_lock: queue_sizes = dict(self._queue_metrics.get("queue_sizes", {})) @@ -134,6 +276,21 @@ def _ensure_phase1_request_buffer(self) -> bool: ) return True + def _reconnect_phase1_request_buffer(self): + self._phase1_rdma_buffer = None + self._last_phase1_connect_retry_ts = 0.0 + + if self._phase1_rdma_client is not None: + sock = getattr(self._phase1_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._phase1_rdma_client.sock = None + + self._ensure_phase1_request_buffer() + def _ensure_phase2_meta_buffer(self) -> bool: if self._phase2_rdma_buffer is not None: return True @@ -162,8 +319,64 @@ def _ensure_phase2_meta_buffer(self) -> bool: ) return True + def _reconnect_phase2_meta_buffer(self): + self._phase2_rdma_buffer = None + self._last_phase2_connect_retry_ts = 0.0 + + if self._phase2_rdma_client is not None: + sock = getattr(self._phase2_rdma_client, "sock", None) + if sock is not None: + try: + sock.close() + except Exception: + pass + self._phase2_rdma_client.sock = None + + self._ensure_phase2_meta_buffer() + + def _produce_phase2_request_with_retry(self, room: int, payload: dict[str, Any]): + retries = max(1, int(os.getenv("RDMA_PHASE2_PRODUCE_RETRIES", "3"))) + retry_delay_s = max(0.01, float(os.getenv("RDMA_PHASE2_PRODUCE_RETRY_DELAY_S", "0.2"))) + last_exc: Optional[Exception] = None + + for attempt in range(1, retries + 1): + try: + if self._phase2_rdma_buffer is None: + self._ensure_phase2_meta_buffer() + if self._phase2_rdma_buffer is None: + raise RuntimeError("phase2 RDMA buffer is not ready") + self._phase2_rdma_buffer.produce(payload) + return + except Exception as exc: + last_exc = exc + self.logger.warning( + "Phase2 RDMA produce 
failed for room=%s attempt=%s/%s: %s", + room, + attempt, + retries, + exc, + ) + if attempt >= retries: + break + try: + self._reconnect_phase2_meta_buffer() + except Exception as reconnect_exc: + self.logger.warning( + "Phase2 RDMA reconnect failed for room=%s attempt=%s/%s: %s", + room, + attempt, + retries, + reconnect_exc, + ) + time.sleep(retry_delay_s) + + raise RuntimeError(f"Failed to produce phase2 RDMA request for room={room} after {retries} attempts") from last_exc + def init(self, config): - self.config = config + self._sync_runtime_config(config) + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", self.encoder_engine_rank)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", self.transformer_engine_rank)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", self.decoder_engine_rank)) shared_slots = int(self.config.get("rdma_buffer_slots", self._phase1_slots)) shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) self._phase1_server_ip = str(self.config.get("rdma_phase1_host", self._phase1_server_ip)) @@ -175,6 +388,9 @@ def init(self, config): self._phase2_slots = shared_slots self._phase2_slot_size = shared_slot_size + if self.scheduler is not None: + self.scheduler.refresh_from_config(self.config) + # Set global seed if present in config, though specific process calls might reuse it if "seed" in self.config: seed_all(self.config["seed"]) @@ -185,19 +401,20 @@ def init(self, config): if data_bootstrap_addr is None or data_bootstrap_room is None: return - phase_deadline = time.time() + 30.0 - while time.time() < phase_deadline: - try: - self._ensure_phase1_request_buffer() - self._ensure_phase2_meta_buffer() - except Exception: - self.logger.exception("Failed to connect phase RDMA buffers, will retry") - if self._phase1_rdma_buffer is not None and self._phase2_rdma_buffer is not None: - break - time.sleep(0.1) + if not self._centralized_request_mode: + phase_deadline = 
time.time() + 30.0 + while time.time() < phase_deadline: + try: + self._ensure_phase1_request_buffer() + self._ensure_phase2_meta_buffer() + except Exception: + self.logger.exception("Failed to connect phase RDMA buffers, will retry") + if self._phase1_rdma_buffer is not None and self._phase2_rdma_buffer is not None: + break + time.sleep(0.1) - if self._phase1_rdma_buffer is None or self._phase2_rdma_buffer is None: - raise RuntimeError("phase RDMA buffers are not ready") + if self._phase1_rdma_buffer is None or self._phase2_rdma_buffer is None: + raise RuntimeError("phase RDMA buffers are not ready") buffer_sizes = estimate_encoder_buffer_sizes(self.config) request = AllocationRequest( @@ -220,32 +437,66 @@ def init(self, config): self.data_receiver[data_bootstrap_room] = DataReceiver(self.data_mgr1, phase1_bootstrap_addr, data_bootstrap_room) self.data_receiver[data_bootstrap_room].init() - buffer_sizes = estimate_transformer_buffer_sizes(self.config) - request = AllocationRequest( - bootstrap_room=data_bootstrap_room, - buffer_sizes=buffer_sizes, - ) - handle = self.alloc_memory(DisaggregationPhase.PHASE2, request) - data_ptrs = [buf.addr for buf in handle.buffers] - data_lens = [buf.nbytes for buf in handle.buffers] - data_args = DataArgs( - sender_engine_rank=self.transformer_engine_rank, - receiver_engine_rank=self.decoder_engine_rank, - data_ptrs=data_ptrs, - data_lens=data_lens, - data_item_lens=data_lens, - ib_device=None, - ) - self.data_mgr2.init(data_args, data_bootstrap_room) - self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr2, data_bootstrap_addr, data_bootstrap_room) + buffer_sizes = [int(v) for v in estimate_transformer_buffer_sizes(self.config)] + remote_room: dict[str, Any] | None = None + room_init_retries = max(1, int(os.getenv("DISAGG_TRANSFORMER_REMOTE_OUTPUT_INIT_RETRIES", "3"))) + room_init_retry_sleep_s = max(0.01, float(os.getenv("DISAGG_TRANSFORMER_REMOTE_OUTPUT_INIT_RETRY_SLEEP_S", "0.2"))) - 
self._phase2_rdma_buffer.produce( - { - "request_config": dict(self.config), - "transformer_node_address": self.data_mgr2.get_localhost(), - "transformer_session_id": self.data_mgr2.get_session_id(), - } - ) + for attempt in range(1, room_init_retries + 1): + try: + remote_room = self._data_mgr_sidecar.init_transformer_output_room( + room=data_bootstrap_room, + sender_engine_rank=self.transformer_engine_rank, + receiver_engine_rank=self.decoder_engine_rank, + data_lens=buffer_sizes, + bootstrap_addr=data_bootstrap_addr, + ) + except Exception: + self.logger.exception( + "Failed to initialize remote transformer output room=%s attempt=%s/%s", + data_bootstrap_room, + attempt, + room_init_retries, + ) + remote_room = None + + if isinstance(remote_room, dict): + break + + if attempt < room_init_retries: + time.sleep(room_init_retry_sleep_s) + + if not isinstance(remote_room, dict): + raise RuntimeError(f"remote transformer output room init failed for room={data_bootstrap_room}; sidecar ownership is required to keep transfers alive during service reclaim") + + shm_names_raw = remote_room.get("shm_names") + data_lens_raw = remote_room.get("data_lens", buffer_sizes) + if not isinstance(shm_names_raw, list) or not isinstance(data_lens_raw, list) or len(shm_names_raw) != len(data_lens_raw): + raise RuntimeError(f"invalid remote output room metadata for room={data_bootstrap_room}: {remote_room}") + + shm_handles: list[shared_memory.SharedMemory] = [] + phase2_buffers: list[torch.Tensor] = [] + try: + for shm_name, nbytes in zip(shm_names_raw, data_lens_raw): + shm = self._attach_remote_shared_memory(str(shm_name)) + np_view = np.ndarray((int(nbytes),), dtype=np.uint8, buffer=shm.buf) + tensor = torch.from_numpy(np_view) + tensor.zero_() + shm_handles.append(shm) + phase2_buffers.append(tensor) + except Exception: + for shm in shm_handles: + try: + shm.close() + except Exception: + pass + self._data_mgr_sidecar.remove_transformer_output_room(data_bootstrap_room) + raise + + 
self._phase2_remote_rooms.add(int(data_bootstrap_room)) + self._phase2_remote_shared_memory[int(data_bootstrap_room)] = shm_handles + self.rdma_buffer2[int(data_bootstrap_room)] = phase2_buffers + self.data_sender[int(data_bootstrap_room)] = None def load_models(self): self.logger.info("Loading Transformer Models...") @@ -300,12 +551,18 @@ def process(self, config): Executes the diffusion process and video decoding. """ self.logger.info("Starting processing in TransformerService...") + # Re-sync scheduler with the current request to avoid cross-request config bleed. + if self.scheduler is not None: + self.scheduler.refresh_from_config(config) room = config.get("data_bootstrap_room", 0) + transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) + transformer_metrics["compute_start_ts"] = time.time() phase1_buffers = self.rdma_buffer1.get(room) phase2_buffers = self.rdma_buffer2.get(room) receiver = self.data_receiver.get(room) sender = self.data_sender.get(room) + use_remote_phase2 = room in self._phase2_remote_rooms if phase1_buffers is None: raise RuntimeError(f"phase1 RDMA buffers are not initialized for room={room}.") @@ -313,7 +570,7 @@ def process(self, config): raise RuntimeError(f"phase2 RDMA buffers are not initialized for room={room}.") if receiver is None: raise RuntimeError(f"DataReceiver is not initialized for room={room}.") - if sender is None: + if sender is None and not use_remote_phase2: raise RuntimeError(f"DataSender is not initialized for room={room}.") def _buffer_view(buf: torch.Tensor, dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor: @@ -359,11 +616,98 @@ def _sha256_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: buffer_index += 1 meta_buf = phase1_buffers[buffer_index] - meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() - meta_str = meta_bytes.split(b"\x00", 1)[0].decode("utf-8") if meta_bytes 
else "" - if not meta_str: - raise ValueError("missing metadata from encoder") - meta = json.loads(meta_str) + strict_meta_hash_check = str(os.getenv("LIGHTX2V_STRICT_META_HASH", "0")).strip().lower() in {"1", "true", "yes", "on"} + + def _load_phase1_meta(max_retries: int = 20, retry_sleep_s: float = 0.05) -> dict: + last_error: Optional[Exception] = None + last_preview = "" + + required_shape_keys = ["context_shape", "latent_shape"] + if enable_cfg: + required_shape_keys.append("context_null_shape") + if task == "i2v": + required_shape_keys.append("vae_shape") + if use_image_encoder: + required_shape_keys.append("clip_shape") + + for attempt in range(1, max_retries + 1): + meta_bytes = _buffer_view(meta_buf, torch.uint8, (meta_buf.numel(),)).detach().contiguous().cpu().numpy().tobytes() + raw_payload = meta_bytes.split(b"\x00", 1)[0] if meta_bytes else b"" + if not raw_payload: + last_error = ValueError("missing metadata from encoder") + if attempt < max_retries: + time.sleep(retry_sleep_s) + continue + break + try: + meta_str = raw_payload.decode("utf-8") + except UnicodeDecodeError as err: + last_error = err + last_preview = raw_payload[:32].hex() + if attempt < max_retries: + self.logger.warning( + "Invalid phase1 metadata UTF-8 for room=%s (attempt %s/%s), retrying...", + room, + attempt, + max_retries, + ) + time.sleep(retry_sleep_s) + continue + break + + if not meta_str.strip(): + last_error = ValueError("empty metadata payload from encoder") + if attempt < max_retries: + time.sleep(retry_sleep_s) + continue + break + + try: + parsed = json.loads(meta_str) + except json.JSONDecodeError as err: + last_error = err + last_preview = meta_str[:120] + if attempt < max_retries: + self.logger.warning( + "Invalid phase1 metadata JSON for room=%s (attempt %s/%s), retrying...", + room, + attempt, + max_retries, + ) + time.sleep(retry_sleep_s) + continue + break + + if not isinstance(parsed, dict): + last_error = TypeError(f"phase1 metadata must be a dict, got 
{type(parsed).__name__}") + last_preview = str(parsed)[:120] + if attempt < max_retries: + time.sleep(retry_sleep_s) + continue + break + + missing_shape_keys = [key for key in required_shape_keys if not isinstance(parsed.get(key), (list, tuple)) or len(parsed.get(key)) == 0] + if missing_shape_keys: + last_error = ValueError(f"incomplete metadata, missing keys: {missing_shape_keys}") + last_preview = str({k: parsed.get(k) for k in required_shape_keys})[:180] + if attempt < max_retries: + self.logger.warning( + "Incomplete phase1 metadata for room=%s (attempt %s/%s), missing=%s, retrying...", + room, + attempt, + max_retries, + missing_shape_keys, + ) + time.sleep(retry_sleep_s) + continue + break + + return parsed + + preview_suffix = f", preview={last_preview}" if last_preview else "" + raise ValueError(f"failed to load phase1 metadata for room={room}: {last_error}{preview_suffix}") + + meta = _load_phase1_meta() meta_shapes = {k: v for k, v in meta.items() if k.endswith("_shape")} meta_dtypes = {k: v for k, v in meta.items() if k.endswith("_dtype")} self.logger.info("Transformer meta shapes: %s", meta_shapes) @@ -413,33 +757,48 @@ def _get_shape(key: str) -> tuple[int, ...]: if list(context.shape) != meta.get("context_shape"): raise ValueError("context shape mismatch between encoder and transformer") if meta.get("context_hash") is not None and _sha256_tensor(context) != meta.get("context_hash"): - raise ValueError("context hash mismatch between encoder and transformer") + msg = "context hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if enable_cfg: if context_null is not None: if list(context_null.shape) != meta.get("context_null_shape"): raise ValueError("context_null shape mismatch between encoder and transformer") if meta.get("context_null_hash") is not None: if _sha256_tensor(context_null) != meta.get("context_null_hash"): - 
raise ValueError("context_null hash mismatch between encoder and transformer") + msg = "context_null hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if task == "i2v": if clip_encoder_out is not None: if list(clip_encoder_out.shape) != meta.get("clip_shape"): raise ValueError("clip shape mismatch between encoder and transformer") if meta.get("clip_hash") is not None: if _sha256_tensor(clip_encoder_out) != meta.get("clip_hash"): - raise ValueError("clip hash mismatch between encoder and transformer") + msg = "clip hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if vae_encoder_out is not None: if list(vae_encoder_out.shape) != meta.get("vae_shape"): raise ValueError("vae shape mismatch between encoder and transformer") if meta.get("vae_hash") is not None: if _sha256_tensor(vae_encoder_out) != meta.get("vae_hash"): - raise ValueError("vae hash mismatch between encoder and transformer") + msg = "vae hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) if meta.get("latent_shape") is None or list(latent_shape) != meta.get("latent_shape"): raise ValueError("latent_shape mismatch between encoder and transformer") if meta.get("latent_hash") is not None: latent_tensor = torch.tensor(latent_shape, device=AI_DEVICE, dtype=torch.int64) if _sha256_tensor(latent_tensor) != meta.get("latent_hash"): - raise ValueError("latent_shape hash mismatch between encoder and transformer") + msg = "latent_shape hash mismatch between encoder and transformer" + if strict_meta_hash_check: + raise ValueError(msg) + self.logger.warning("%s for room=%s, continue with non-strict mode", msg, room) inputs = { 
"text_encoder_output": text_encoder_output, @@ -470,6 +829,7 @@ def _get_shape(key: str) -> tuple[int, ...]: self.scheduler.step_post() latents = self.scheduler.latents + transformer_metrics["compute_end_ts"] = time.time() # Send latents to DecoderService if len(phase2_buffers) < 2: @@ -501,7 +861,49 @@ def _get_shape(key: str) -> tuple[int, ...]: meta_view[: len(meta_bytes)].copy_(torch.from_numpy(np.frombuffer(meta_bytes, dtype=np.uint8))) buffer_ptrs = [buf.data_ptr() for buf in phase2_buffers] - sender.send(buffer_ptrs) + # Publish phase2 request metadata after compute so downstream can see latest metrics. + transformer_metrics["output_enqueued_ts"] = time.time() + phase2_request_config = dict(config) + phase2_request_config["transformer_engine_rank"] = self.transformer_engine_rank + transformer_node_address = "" + transformer_session_id = "" + if room in self._phase2_remote_rooms: + identity = self._data_mgr_sidecar.get_transformer_output_identity(room) + if not isinstance(identity, dict): + raise RuntimeError(f"remote transformer output identity unavailable for room={room}") + transformer_node_address = str(identity.get("host", "")).strip() + transformer_session_id = str(identity.get("session_id", "")).strip() + if not transformer_node_address or not transformer_session_id: + raise RuntimeError(f"remote transformer output identity invalid for room={room}: {identity}") + else: + transformer_node_address = self.data_mgr2.get_localhost() + transformer_session_id = self.data_mgr2.get_session_id() + + self._produce_phase2_request_with_retry( + room, + { + "request_config": phase2_request_config, + "transformer_node_address": transformer_node_address, + "transformer_session_id": transformer_session_id, + }, + ) + if use_remote_phase2: + if not self._data_mgr_sidecar.send_transformer_output_room(room): + raise RuntimeError(f"Failed to enqueue remote transformer output transfer for room={room}") + if self.sync_comm: + while True: + status = 
int(self._data_mgr_sidecar.get_transformer_output_status(room)) + if status == DataPoll.Success: + break + if status == DataPoll.Failed: + raise RuntimeError(f"DataSender transfer failed for room={room}") + time.sleep(0.001) + else: + if sender is None: + raise RuntimeError(f"DataSender is not initialized for room={room}") + sender.send(buffer_ptrs) + if self.sync_comm: + self._wait_sender_success(room, sender) def release_memory(self, room: int): """ @@ -513,17 +915,31 @@ def release_memory(self, room: int): if room in self.rdma_buffer2: self.rdma_buffer2.pop(room, None) + shm_handles = self._phase2_remote_shared_memory.pop(room, None) + if isinstance(shm_handles, list): + for shm in shm_handles: + try: + shm.close() + except Exception: + pass + self._phase2_remote_rooms.discard(room) + torch.cuda.empty_cache() def remove(self, room: int): + use_remote_phase2 = room in self._phase2_remote_rooms self.release_memory(room) self.data_receiver.pop(room, None) self.data_sender.pop(room, None) + self._data_mgr_sidecar.unwatch_input(room) + self._data_mgr_sidecar.unwatch_output(room) if self.data_mgr1 is not None: self.data_mgr1.remove(room) - if self.data_mgr2 is not None: + if use_remote_phase2: + self._data_mgr_sidecar.remove_transformer_output_room(room) + elif self.data_mgr2 is not None: self.data_mgr2.remove(room) def release(self): @@ -547,44 +963,76 @@ def run(self, stop_event=None): req_queue = deque() waiting_queue: dict[int, dict] = {} exec_queue = deque() - complete_queue: dict[int, dict] = {} + complete_queue: set[int] = set() while True: phase1_transfer_sizes = self.data_mgr1.get_backlog_counts() if self.data_mgr1 is not None else {"request_pool": 0, "waiting_pool": 0} phase2_transfer_sizes = self.data_mgr2.get_backlog_counts() if self.data_mgr2 is not None else {"request_pool": 0, "waiting_pool": 0} + remote_phase2_transfer_sizes = self._data_mgr_sidecar.get_transformer_output_backlog() + for key in ("request_pool", "waiting_pool", "request_status"): + 
phase2_transfer_sizes[key] = int(phase2_transfer_sizes.get(key, 0)) + int(remote_phase2_transfer_sizes.get(key, 0)) + sidecar_sizes = self._data_mgr_sidecar.get_pending_counts() self._update_queue_metrics( { "req_queue": len(req_queue), "waiting_queue": len(waiting_queue), "exec_queue": len(exec_queue), - "complete_queue": len(complete_queue), }, { "request_pool": int(phase1_transfer_sizes.get("request_pool", 0)), "waiting_pool": int(phase1_transfer_sizes.get("waiting_pool", 0)), + "sidecar_input_watch": int(sidecar_sizes.get("input_watch", 0)), }, { + "complete_queue": len(complete_queue), "request_pool": int(phase2_transfer_sizes.get("request_pool", 0)), "waiting_pool": int(phase2_transfer_sizes.get("waiting_pool", 0)), + "sidecar_output_watch": int(sidecar_sizes.get("output_watch", 0)), }, ) - if self._phase1_rdma_buffer is None: - try: - self._ensure_phase1_request_buffer() - except Exception: - self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") - - if self._phase1_rdma_buffer is not None and len(req_queue) + len(waiting_queue) < 2: - packet = self._phase1_rdma_buffer.consume() - if packet is not None: - if isinstance(packet, dict) and "request_config" in packet: - config = dict(packet.get("request_config") or {}) - config["encoder_node_address"] = packet.get("encoder_node_address", "127.0.0.1") - else: - config = packet - self.logger.info("%s Received request config from RDMA buffer: %s", self.transformer_engine_rank, {k: v for k, v in config.items()}) + if self._centralized_request_mode: + config = self._centralized_request_mgr.receive_non_block(self._centralized_request_port) + if config is not None: + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete request packet from ZMQ: %s", config) + continue + transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) + transformer_metrics["request_received_ts"] = 
time.time() + self.logger.info("Received request config from ZMQ: %s", {k: v for k, v in config.items()}) req_queue.append(config) + else: + if self._phase1_rdma_buffer is None: + try: + self._ensure_phase1_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") + + if self._phase1_rdma_client is not None and self._phase1_rdma_client.has_qp_error(): + self.logger.warning( + "Phase1 request RDMA client entered error state, reconnecting: %s", + self._phase1_rdma_client.last_wc_error_message(), + ) + try: + self._reconnect_phase1_request_buffer() + except Exception: + self.logger.exception("Failed to reconnect phase1 request RDMA buffer after QP error") + + if self._phase1_rdma_buffer is not None and len(req_queue) + len(waiting_queue) < 2: + packet = self._phase1_rdma_buffer.consume() + if packet is not None: + if isinstance(packet, dict) and "request_config" in packet: + config = dict(packet.get("request_config") or {}) + config["encoder_node_address"] = packet.get("encoder_node_address", "127.0.0.1") + else: + config = packet + if not isinstance(config, dict) or "data_bootstrap_room" not in config: + self.logger.warning("Ignored incomplete phase1 packet from RDMA buffer: %s", packet) + continue + transformer_metrics = config.setdefault("request_metrics", {}).setdefault("stages", {}).setdefault("transformer", {}) + transformer_metrics["request_received_ts"] = time.time() + self.logger.info("%s Received request config from RDMA buffer: %s", self.transformer_engine_rank, {k: v for k, v in config.items()}) + req_queue.append(config) if req_queue: config = req_queue.popleft() @@ -592,26 +1040,21 @@ def run(self, stop_event=None): try: self.init(config) waiting_queue[room] = config + receiver = self.data_receiver.get(room) + if receiver is None: + raise RuntimeError(f"DataReceiver is not initialized for room={room}") + self._data_mgr_sidecar.watch_input(room, receiver) except Exception: 
self.logger.exception("Failed to initialize request for room=%s", room) self.remove(room) - ready_rooms: List[int] = [] - failed_rooms: List[int] = [] - for room, config in list(waiting_queue.items()): - receiver = self.data_receiver.get(room) - if receiver is None: - failed_rooms.append(room) - continue - - status = receiver.poll() - if status == DataPoll.Success: - ready_rooms.append(room) - elif status == DataPoll.Failed: - failed_rooms.append(room) + ready_rooms = self._data_mgr_sidecar.pop_ready_inputs() + failed_rooms = self._data_mgr_sidecar.pop_failed_inputs() for room in ready_rooms: - exec_queue.append((room, waiting_queue.pop(room))) + config = waiting_queue.pop(room, None) + if config is not None: + exec_queue.append((room, config)) for room in failed_rooms: waiting_queue.pop(room, None) @@ -619,31 +1062,36 @@ def run(self, stop_event=None): self.remove(room) if exec_queue: - room, config = exec_queue.popleft() + room, config = exec_queue[0] try: self.process(config) - complete_queue[room] = config + if self._centralized_request_mode: + self._report_stage_metrics_to_controller("transformer", config) + self._wait_for_controller_ok("transformer", config) + if self.sync_comm: + self.remove(room) + else: + if room in self._phase2_remote_rooms: + complete_queue.add(room) + else: + sender = self.data_sender.get(room) + if sender is None: + self.logger.error("DataSender is not initialized for room=%s", room) + self.remove(room) + else: + self._data_mgr_sidecar.watch_output(room, sender) + complete_queue.add(room) except Exception: self.logger.exception("Failed to process request for room=%s", room) - complete_queue.pop(room, None) self.remove(room) + finally: + exec_queue.popleft() - completed_rooms: List[int] = [] - for room in list(complete_queue.keys()): - sender = self.data_sender.get(room) - if sender is None: - completed_rooms.append(room) - continue - - status = sender.poll() - if status == DataPoll.Success: - completed_rooms.append(room) - elif status 
== DataPoll.Failed: + completed_outputs = self._data_mgr_sidecar.pop_completed_outputs() + for room, status in completed_outputs: + if status == DataPoll.Failed: self.logger.error("DataSender transfer failed for room=%s", room) - completed_rooms.append(room) - - for room in completed_rooms: - complete_queue.pop(room, None) + complete_queue.discard(room) self.remove(room) if stop_event is not None and stop_event.is_set() and not req_queue and not waiting_queue and not exec_queue and not complete_queue: diff --git a/lightx2v/disagg/utils.py b/lightx2v/disagg/utils.py index 3ade1618b..0d5a40da3 100644 --- a/lightx2v/disagg/utils.py +++ b/lightx2v/disagg/utils.py @@ -8,13 +8,7 @@ import torchvision.transforms.functional as TF from PIL import Image -from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel -from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel from lightx2v.models.networks.lora_adapter import LoraAdapter -from lightx2v.models.networks.wan.model import WanModel -from lightx2v.models.video_encoders.hf.wan.vae import WanVAE -from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE -from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.set_config import set_config as set_config_base from lightx2v.utils.utils import find_torch_model_path @@ -170,6 +164,8 @@ def build_wan_model_with_lora(wan_module, config, model_kwargs, lora_configs, mo def load_wan_text_encoder(config: Dict[str, Any]): + from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel + # offload config t5_offload = config.get("t5_cpu_offload", config.get("cpu_offload")) if t5_offload: @@ -212,6 +208,8 @@ def load_wan_text_encoder(config: Dict[str, Any]): def load_wan_image_encoder(config: Dict[str, Any]): + from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel + image_encoder = None if config["task"] in ["i2v", 
"flf2v", "animate", "s2v"] and config.get("use_image_encoder", True): # offload config @@ -259,6 +257,9 @@ def get_vae_parallel(config: Dict[str, Any]): def load_wan_vae_encoder(config: Dict[str, Any]): + from lightx2v.models.video_encoders.hf.wan.vae import WanVAE + from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") if config.get("model_cls", "") == "wan2.2": vae_cls = Wan2_2_VAE @@ -289,6 +290,10 @@ def load_wan_vae_encoder(config: Dict[str, Any]): def load_wan_vae_decoder(config: Dict[str, Any]): + from lightx2v.models.video_encoders.hf.wan.vae import WanVAE + from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE + from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny + vae_name = config.get("vae_name", "Wan2.1_VAE.pth") tiny_vae_name = "taew2_1.pth" @@ -327,6 +332,31 @@ def load_wan_vae_decoder(config: Dict[str, Any]): def load_wan_transformer(config: Dict[str, Any]): + print( + "Loading WanModel module: model_cls=%s model_path=%s device=%s dit_quantized=%s lazy_load=%s" + % ( + config.get("model_cls"), + config.get("model_path"), + AI_DEVICE if not config.get("cpu_offload") else "cpu", + config.get("dit_quantized", False), + config.get("lazy_load", False), + ), + flush=True, + ) + from lightx2v.models.networks.wan.model import WanModel + + print( + "Constructing WanModel: model_cls=%s model_path=%s device=%s dit_quantized=%s lazy_load=%s" + % ( + config.get("model_cls"), + config.get("model_path"), + AI_DEVICE if not config.get("cpu_offload") else "cpu", + config.get("dit_quantized", False), + config.get("lazy_load", False), + ), + flush=True, + ) + if config["cpu_offload"]: init_device = torch.device("cpu") else: @@ -339,10 +369,14 @@ def load_wan_transformer(config: Dict[str, Any]): model = WanModel(**wan_model_kwargs) else: model = build_wan_model_with_lora(WanModel, config, wan_model_kwargs, lora_configs, model_type="wan2.1") + 
logger.info("WanModel construction finished") return model elif config.get("model_cls") == "wan2.2_moe": + print("Loading MultiModelStruct module start", flush=True) from lightx2v.models.runners.wan.wan_runner import MultiModelStruct + print("Loading MultiModelStruct module done", flush=True) + high_noise_model_path = os.path.join(config["model_path"], "high_noise_model") if config.get("dit_quantized", False) and config.get("high_noise_quantized_ckpt", None): high_noise_model_path = config["high_noise_quantized_ckpt"] @@ -376,6 +410,7 @@ def load_wan_transformer(config: Dict[str, Any]): high_noise_model = build_wan_model_with_lora(WanModel, config, high_model_kwargs, lora_configs, model_type="high_noise_model") low_noise_model = build_wan_model_with_lora(WanModel, config, low_model_kwargs, lora_configs, model_type="low_noise_model") + logger.info("WanModel construction finished for wan2.2_moe") return MultiModelStruct([high_noise_model, low_noise_model], config, config.get("boundary", 0.875)) else: model_struct = MultiModelStruct([None, None], config, config.get("boundary", 0.875)) diff --git a/lightx2v/disagg/workload.py b/lightx2v/disagg/workload.py new file mode 100644 index 000000000..c1efec1ff --- /dev/null +++ b/lightx2v/disagg/workload.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import copy +import json +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +try: + from locust import LoadTestShape, User, events, task +except ModuleNotFoundError: + + class _EventHook: + def add_listener(self, fn): + return fn + + def fire(self, **kwargs): + return None + + class _Events: + def __init__(self): + self.test_start = _EventHook() + self.test_stop = _EventHook() + self.request = _EventHook() + + class LoadTestShape: # type: ignore[no-redef] + pass + + class User: # type: ignore[no-redef] + pass + + def task(fn): # type: ignore[no-redef] + return fn + + events = _Events() # 
type: ignore[no-redef] + +from lightx2v.disagg.conn import REQUEST_POLLING_PORT, ReqManager + +REPO_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_BASE_CONFIG_JSON = REPO_ROOT / "configs" / "disagg" / "single_node" / "wan22_i2v_distill_controller.json" +DEFAULT_STAGE_DEFINITIONS_JSON = REPO_ROOT / "configs" / "disagg" / "wan22_i2v_workload_stages.json" + +_TEST_START_MONOTONIC: Optional[float] = None + + +def _deep_merge(base: dict[str, Any], overlay: dict[str, Any]) -> dict[str, Any]: + merged = copy.deepcopy(base) + for key, value in overlay.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = _deep_merge(merged[key], value) + else: + merged[key] = copy.deepcopy(value) + return merged + + +def _load_base_config() -> dict[str, Any]: + config_path = os.getenv("DISAGG_BASE_CONFIG_JSON") + if config_path: + path = Path(config_path) + else: + path = DEFAULT_BASE_CONFIG_JSON + + if path.is_file(): + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + return { + "task": "i2v", + "model_cls": "wan2.2_moe", + "seed": 42, + "prompt": "A cinematic cat scene with detailed lighting and motion.", + "negative_prompt": "blurry, low quality, artifacts", + "save_path": str(REPO_ROOT / "save_results" / "locust_disagg.mp4"), + } + + +def _load_stage_definitions() -> list[dict[str, Any]]: + stage_file = Path(os.getenv("DISAGG_WORKLOAD_STAGES_JSON", str(DEFAULT_STAGE_DEFINITIONS_JSON))) + if not stage_file.is_file(): + raise FileNotFoundError(f"workload stage config not found: {stage_file}") + + with stage_file.open("r", encoding="utf-8") as handle: + loaded = json.load(handle) + + if not isinstance(loaded, list) or not loaded: + raise ValueError(f"{stage_file} must contain a non-empty JSON list") + + return loaded + + +@dataclass(frozen=True) +class StageSpec: + name: str + duration_s: float + user_count: int + spawn_rate: float + wait_time_s: float = 0.0 + config_variants: list[dict[str, Any]] = 
field(default_factory=list) + + @staticmethod + def from_dict(raw: dict[str, Any]) -> "StageSpec": + name = str(raw.get("name", "stage")) + duration_s = float(raw.get("duration_s", 0.0)) + user_count = int(raw.get("user_count", 1)) + spawn_rate = float(raw.get("spawn_rate", max(1, user_count))) + wait_time_s = float(raw.get("wait_time_s", 0.0)) + config_variants = raw.get("config_variants", []) or [] + if not isinstance(config_variants, list): + raise ValueError(f"stage {name}: config_variants must be a list") + return StageSpec( + name=name, + duration_s=max(duration_s, 0.0), + user_count=max(user_count, 1), + spawn_rate=max(spawn_rate, 0.1), + wait_time_s=max(wait_time_s, 0.0), + config_variants=[variant for variant in config_variants if isinstance(variant, dict)], + ) + + +def _load_stage_specs() -> list[StageSpec]: + return [StageSpec.from_dict(stage) for stage in _load_stage_definitions()] + + +def load_base_config() -> dict[str, Any]: + return _load_base_config() + + +def load_stage_specs() -> list[StageSpec]: + return _load_stage_specs() + + +def _elapsed_since_start() -> float: + if _TEST_START_MONOTONIC is None: + return 0.0 + return max(0.0, time.monotonic() - _TEST_START_MONOTONIC) + + +def _stage_index_for_elapsed(stages: list[StageSpec], elapsed_s: float) -> int: + if not stages: + return 0 + + accumulated = 0.0 + for index, stage in enumerate(stages): + accumulated += stage.duration_s + if elapsed_s < accumulated: + return index + return len(stages) - 1 + + +def _current_stage(stages: list[StageSpec]) -> StageSpec: + return stages[_stage_index_for_elapsed(stages, _elapsed_since_start())] + + +def _build_request_payload(base_config: dict[str, Any], stage: StageSpec, request_index: int) -> dict[str, Any]: + payload = copy.deepcopy(base_config) + variant = stage.config_variants[request_index % len(stage.config_variants)] if stage.config_variants else {} + payload = _deep_merge(payload, variant) + + payload.setdefault("request_metrics", {}) + 
payload["request_metrics"]["request_id"] = request_index + payload["request_metrics"]["client_send_ts"] = time.time() + payload["request_metrics"]["stage_name"] = stage.name + payload["request_metrics"]["load_stage"] = stage.name + + if "data_bootstrap_room" not in payload: + payload["data_bootstrap_room"] = request_index + + save_path_prefix = os.getenv("DISAGG_WORKLOAD_SAVE_PREFIX") + if save_path_prefix: + save_root = Path(save_path_prefix) + save_root.parent.mkdir(parents=True, exist_ok=True) + payload["save_path"] = str(save_root.with_name(f"{save_root.stem}_{stage.name}_{request_index}{save_root.suffix}")) + + return payload + + +def _get_controller_target() -> tuple[str, int]: + host = os.getenv("DISAGG_CONTROLLER_HOST", "127.0.0.1") + port = int(os.getenv("DISAGG_CONTROLLER_REQUEST_PORT", str(REQUEST_POLLING_PORT - 2))) + return host, port + + +def _send_to_controller(payload: dict[str, Any]) -> None: + host, port = _get_controller_target() + ReqManager().send(host, port, payload) + + +def start_workload_clock() -> None: + global _TEST_START_MONOTONIC + _TEST_START_MONOTONIC = time.monotonic() + + +def current_stage(stages: Optional[list[StageSpec]] = None) -> StageSpec: + loaded_stages = stages or _load_stage_specs() + return _current_stage(loaded_stages) + + +def build_payload(base_config: dict[str, Any], stage: StageSpec, request_index: int) -> dict[str, Any]: + return _build_request_payload(base_config, stage, request_index) + + +def send_workload_end_signal() -> None: + _send_to_controller( + { + "workload_end": True, + "request_metrics": { + "load_stage": "end", + "client_send_ts": time.time(), + }, + } + ) + + +@events.test_start.add_listener +def _on_test_start(environment, **kwargs): # type: ignore[override] + start_workload_clock() + + +@events.test_stop.add_listener +def _on_test_stop(environment, **kwargs): # type: ignore[override] + send_workload_end_signal() + + +class DisaggLoadShape(LoadTestShape): + """Time-based load shape for 
disaggregated LightX2V scenarios. + + Configure stages with DISAGG_WORKLOAD_STAGES_JSON as a JSON file path. Each stage supports: + - duration_s + - user_count + - spawn_rate + - wait_time_s + - config_variants + """ + + stages = _load_stage_specs() + + def tick(self): + elapsed_s = _elapsed_since_start() + total_duration_s = sum(stage.duration_s for stage in self.stages) + if total_duration_s > 0 and elapsed_s >= total_duration_s: + return None + + stage = _current_stage(self.stages) + return stage.user_count, stage.spawn_rate + + +class DisaggUser(User): + base_config = _load_base_config() + stages = _load_stage_specs() + req_mgr = ReqManager() + + def wait_time(self): # type: ignore[override] + stage = _current_stage(self.stages) + return stage.wait_time_s + + @task + def submit_request(self): + stage = _current_stage(self.stages) + request_index = int(time.time() * 1000) % 1_000_000 + payload = _build_request_payload(self.base_config, stage, request_index) + send_started = time.perf_counter() + try: + host, port = _get_controller_target() + self.req_mgr.send(host, port, payload) + events.request.fire( + request_type="zmq", + name=f"{stage.name}:config_push", + response_time=(time.perf_counter() - send_started) * 1000.0, + response_length=len(str(payload)), + exception=None, + ) + except Exception as exc: + events.request.fire( + request_type="zmq", + name=f"{stage.name}:config_push", + response_time=(time.perf_counter() - send_started) * 1000.0, + response_length=0, + exception=exc, + ) + + +__all__ = [ + "DisaggLoadShape", + "DisaggUser", + "StageSpec", + "start_workload_clock", + "current_stage", + "build_payload", + "send_workload_end_signal", + "load_base_config", + "load_stage_specs", +] diff --git a/lightx2v/models/schedulers/wan/scheduler.py b/lightx2v/models/schedulers/wan/scheduler.py index afb32496d..076c1b24c 100755 --- a/lightx2v/models/schedulers/wan/scheduler.py +++ b/lightx2v/models/schedulers/wan/scheduler.py @@ -28,6 +28,16 @@ def 
__init__(self, config): self.caching_records_2 = [True] * self.config["infer_steps"] self.head_size = self.config["dim"] // self.config["num_heads"] + def refresh_from_config(self, config): + self.config = config + self.infer_steps = int(self.config["infer_steps"]) + self.target_video_length = int(self.config["target_video_length"]) + self.sample_shift = float(self.config["sample_shift"]) + self.sample_guide_scale = self.config["sample_guide_scale"] + self.caching_records = [True] * self.infer_steps + self.caching_records_2 = [True] * self.infer_steps + self.step_index = 0 + def _uses_conditioned_latent_prefix(self): """Whether this Wan variant keeps a fixed latent prefix during diffusion. diff --git a/scripts/disagg/extract_dynamic_latency.py b/scripts/disagg/extract_dynamic_latency.py new file mode 100644 index 000000000..0fdec8e40 --- /dev/null +++ b/scripts/disagg/extract_dynamic_latency.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import ast +import csv +import re +from datetime import datetime +from pathlib import Path + +WAIT_PATTERNS = [ + re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Waiting for decoder results", re.IGNORECASE), + re.compile(r"^\[(?:INFO|WARNING|ERROR|DEBUG|CRITICAL)\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*waiting workload configs on port=", re.IGNORECASE), +] +LAT_PATTERNS = [ + re.compile(r"^\[INFO\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Latency summary room=(\d+) metrics=(\{.*\})"), + re.compile(r"^\[(?:INFO|WARNING|ERROR|DEBUG|CRITICAL)\]\s+(\d{2}\s+\w{3}\s+\d{4}\s+\d{2}:\d{2}:\d{2}).*Latency summary room=(\d+) metrics=(\{.*\})"), +] +TS_FMT = "%d %b %Y %H:%M:%S" +LOGURU_TS_FMT = "%Y-%m-%d %H:%M:%S" + + +def _fmt_float3(value): + try: + return f"{float(value):.3f}" + except (TypeError, ValueError): + return value + + +def _match_any(patterns, line): + for pattern in patterns: + match = pattern.match(line) + if match: + return match + 
return None + + +def _parse_timestamp(raw_ts: str): + for fmt in (TS_FMT, LOGURU_TS_FMT): + try: + return datetime.strptime(raw_ts, fmt) + except ValueError: + pass + raise ValueError(f"unsupported timestamp format: {raw_ts}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Extract latency summary rows relative to waiting workload log time") + parser.add_argument( + "--log", + default="/root/zht/LightX2V/save_results/disagg_wan22_i2v_dynamic_controller.log", + help="Controller log path", + ) + parser.add_argument( + "--output", + default="/root/zht/LightX2V/save_results/disagg_wan22_i2v_dynamic_results.csv", + help="Output table path", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + log_path = Path(args.log) + out_path = Path(args.output) + + if not log_path.is_file(): + raise FileNotFoundError(f"log file not found: {log_path}") + + wait_ts = None + rows = [] + metric_keys = [] + + with log_path.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + if wait_ts is None: + m_wait = _match_any(WAIT_PATTERNS, line) + if m_wait: + wait_ts = _parse_timestamp(m_wait.group(1)) + continue + + m_lat = _match_any(LAT_PATTERNS, line) + if not m_lat: + continue + + ts = _parse_timestamp(m_lat.group(1)) + room = int(m_lat.group(2)) + metrics = ast.literal_eval(m_lat.group(3)) + if not isinstance(metrics, dict): + continue + + if wait_ts is None: + rel_s = "NA" + else: + rel_s = f"{int((ts - wait_ts).total_seconds())}s" + + if not metric_keys: + metric_keys = list(metrics.keys()) + + row = { + "room": room, + "latency_summary_ts": ts.strftime("%Y-%m-%d %H:%M:%S"), + "relative_to_waiting_s": rel_s, + } + for key in metric_keys: + value = metrics.get(key) + row[key] = "" if value is None else _fmt_float3(value) + rows.append(row) + + out_path.parent.mkdir(parents=True, exist_ok=True) + + header = ["room", "latency_summary_ts", "relative_to_waiting_s", *metric_keys] + with out_path.open("w", 
encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=header) + writer.writeheader() + for row in rows: + writer.writerow(row) + + print(f"wrote {len(rows)} rows to {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/disagg/kill_service.sh b/scripts/disagg/kill_service.sh index 3b6315890..8dd3f6005 100755 --- a/scripts/disagg/kill_service.sh +++ b/scripts/disagg/kill_service.sh @@ -2,26 +2,76 @@ set -euo pipefail -SCRIPT_NAME="run_wan_t2v_service.sh" +SCRIPT_NAMES=("run_wan22_i2v_distill.sh" "run_dynamic.sh") -list_port=(5566 7788 12788 17788 27788) +list_port=(5566 12788 17788 27788) -n=10 +collect_proxy_ports_from_config() { + local config_path="$1" + + if [[ -z "$config_path" || ! -f "$config_path" ]]; then + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + return 0 + fi + + local base_port + base_port=$(jq -r '.disagg_config.remote_proxy_req_base_port // empty' "$config_path" 2>/dev/null || true) + if [[ -z "$base_port" || ! "$base_port" =~ ^[0-9]+$ ]]; then + return 0 + fi + + jq -r '.disagg_config.static_instance_slots[]?.engine_rank // empty' "$config_path" 2>/dev/null | while read -r engine_rank; do + [[ -z "$engine_rank" ]] && continue + if [[ "$engine_rank" =~ ^[0-9]+$ ]]; then + echo $((base_port + engine_rank)) + fi + done +} + +n=30 list_n=($(seq 0 $((n-1)))) PORTS=(5555 12787) +# Monitor ports for autoscaled services are contiguous from 7788. 
+for p in $(seq 7788 7803); do + PORTS+=($p) +done + for a in "${list_port[@]}"; do for b in "${list_n[@]}"; do PORTS+=($((a + b))) done done +proxy_config_candidates=( + "${DISAGG_CONTROLLER_CFG:-}" + "/root/zht/LightX2V/configs/disagg/multi_node/wan22_i2v_distill_controller.json" + "/root/zht/LightX2V/configs/disagg/single_node/wan22_i2v_distill_controller.json" +) +for config_path in "${proxy_config_candidates[@]}"; do + while read -r proxy_port; do + [[ -z "$proxy_port" ]] && continue + PORTS+=($proxy_port) + done < <(collect_proxy_ports_from_config "$config_path") +done + +# Fallback for environments without jq or without a readable config file. +PORTS+=(28000) + +mapfile -t PORTS < <(printf '%s\n' "${PORTS[@]}" | awk 'NF && !seen[$0]++ { print $0 }' | sort -n) + kill_pid_gracefully() { local pid="$1" if [[ -z "$pid" ]]; then return fi + if is_protected_pid "$pid"; then + echo "Skip protected pid=$pid" + return + fi if kill -0 "$pid" 2>/dev/null; then kill "$pid" 2>/dev/null || true sleep 1 @@ -31,6 +81,32 @@ kill_pid_gracefully() { fi } +declare -a PROTECTED_PIDS=() +collect_protected_pids() { + local cur="$$" + while [[ -n "$cur" ]] && [[ "$cur" != "0" ]]; do + PROTECTED_PIDS+=("$cur") + local parent + parent=$(ps -o ppid= -p "$cur" 2>/dev/null | tr -d ' ' || true) + if [[ -z "$parent" ]] || [[ "$parent" == "$cur" ]]; then + break + fi + cur="$parent" + done +} + +is_protected_pid() { + local target="$1" + for p in "${PROTECTED_PIDS[@]}"; do + if [[ "$p" == "$target" ]]; then + return 0 + fi + done + return 1 +} + +collect_protected_pids + find_listen_pids_by_port() { local port="$1" @@ -59,17 +135,43 @@ find_listen_pids_by_port() { echo "No supported tool found to query listening ports (need one of: lsof, ss, fuser)." 
>&2 } -echo "Stopping script process: ${SCRIPT_NAME}" -script_pids=$(pgrep -f "$SCRIPT_NAME" || true) -if [[ -n "${script_pids}" ]]; then +for script_name in "${SCRIPT_NAMES[@]}"; do + echo "Stopping script process: ${script_name}" + script_pids=$(pgrep -f "$script_name" || true) + if [[ -n "${script_pids}" ]]; then + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing script pid=$pid" + kill_pid_gracefully "$pid" + done <<< "$script_pids" + else + echo "No running process found for ${script_name}" + fi +done + +# Fallback cleanup for orphaned disagg service processes. +cleanup_patterns=( + "lightx2v.disagg.examples.run_service" + "lightx2v.disagg.examples.run_user" + "python -m lightx2v.disagg" + "conda run -n lightx2v bash scripts/disagg/run_dynamic.sh" + "conda run -n lightx2v bash scripts/disagg/run_wan22_i2v_distill.sh" +) + +for pattern in "${cleanup_patterns[@]}"; do + echo "Stopping processes matching pattern: ${pattern}" + matched_pids=$(pgrep -f "$pattern" || true) + if [[ -z "${matched_pids}" ]]; then + echo "No process matched: ${pattern}" + continue + fi + while read -r pid; do [[ -z "$pid" ]] && continue - echo "Killing script pid=$pid" + echo "Killing matched pid=$pid" kill_pid_gracefully "$pid" - done <<< "$script_pids" -else - echo "No running process found for ${SCRIPT_NAME}" -fi + done <<< "$matched_pids" +done for port in "${PORTS[@]}"; do echo "Stopping listeners on port ${port}" diff --git a/scripts/disagg/run_dynamic.sh b/scripts/disagg/run_dynamic.sh new file mode 100644 index 000000000..971ae14f8 --- /dev/null +++ b/scripts/disagg/run_dynamic.sh @@ -0,0 +1,549 @@ +#!/bin/bash + +set -euo pipefail + +lightx2v_path=/root/zht/LightX2V +model_path=${lightx2v_path}/models/lightx2v/Wan2.2-Distill-Models + +# base.sh expects PYTHONPATH to be defined under `set -u`. 
+export PYTHONPATH=${PYTHONPATH:-} + +source ${lightx2v_path}/scripts/base/base.sh + +disagg_conda_env=${DISAGG_CONDA_ENV:-lightx2v} +if [[ "${DISAGG_SKIP_CONDA_ACTIVATE:-0}" != "1" ]]; then + if [[ "${CONDA_DEFAULT_ENV:-}" != "${disagg_conda_env}" ]]; then + if ! command -v conda >/dev/null 2>&1; then + echo "ERROR: conda is not available, cannot activate env ${disagg_conda_env}" >&2 + exit 2 + fi + set +u + eval "$(conda shell.bash hook)" + conda activate "${disagg_conda_env}" + set -u + echo "activated conda env: ${disagg_conda_env}" + fi +fi + +# Ensure stale disagg services/ports from previous runs do not block bootstrap. +bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + +export CC=/usr/bin/gcc-13 +export CXX=/usr/bin/g++-13 +export CUDAHOSTCXX=/usr/bin/g++-13 +if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then + export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" +else + export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" +fi + +export RDMA_IFACE=${RDMA_IFACE:-erdma_0} +export MOONCAKE_DEVICE_NAME=${MOONCAKE_DEVICE_NAME:-eth0} +if [[ -z "${MOONCAKE_LOCAL_HOSTNAME:-}" ]]; then + _mc_ip=$(ip -4 -o addr show dev "${MOONCAKE_DEVICE_NAME}" 2>/dev/null | awk '{print $4}' | cut -d/ -f1 | head -n 1) + if [[ -n "${_mc_ip}" ]]; then + export MOONCAKE_LOCAL_HOSTNAME="${_mc_ip}" + fi +fi + +topology=${DISAGG_TOPOLOGY:-multi_node} +default_controller_cfg=${lightx2v_path}/configs/disagg/multi_node/wan22_i2v_distill_controller.json +if [[ "${topology}" == "single_node" ]]; then + default_controller_cfg=${lightx2v_path}/configs/disagg/single_node/wan22_i2v_distill_controller.json +fi +controller_cfg=${DISAGG_CONTROLLER_CFG:-${default_controller_cfg}} +if [[ ! 
-f "${controller_cfg}" ]]; then + echo "ERROR: controller config not found: ${controller_cfg}" >&2 + exit 2 +fi + +derived_controller_host="" +if command -v jq >/dev/null 2>&1; then + derived_controller_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") +fi +export DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST:-${derived_controller_host:-127.0.0.1}} +# RoCE gid_index: align with cluster data-plane IP (multi-homed / wrong default route breaks cross-node QP RTR). +if [[ -z "${RDMA_PREFERRED_IPV4:-}" && -n "${derived_controller_host}" ]]; then + if [[ "${derived_controller_host}" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ && "${derived_controller_host}" != "127.0.0.1" ]]; then + export RDMA_PREFERRED_IPV4="${derived_controller_host}" + fi +fi +export DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT:-12786} +export LOAD_FROM_USER=${LOAD_FROM_USER:-0} +# multi_node: remote ranks (e.g. slow encoder/decoder host) may need longer TCP/ready waits. +if [[ "${topology}" == "single_node" ]]; then + export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-90} +else + export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-300} + export DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS=${DISAGG_REMOTE_PROXY_START_TIMEOUT_SECONDS:-120} + export DISAGG_SIDECAR_START_TIMEOUT_SECONDS=${DISAGG_SIDECAR_START_TIMEOUT_SECONDS:-60} +fi +# Dynamic debug defaults to a smaller request batch; override for stress runs. 
+export DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT:-30} +export DISAGG_ENABLE_NSYS=${DISAGG_ENABLE_NSYS:-0} +export SYNC_COMM=${SYNC_COMM:-0} +export DISAGG_NSYS_BIN=${DISAGG_NSYS_BIN:-nsys} +export DISAGG_NSYS_OUTPUT_DIR=${DISAGG_NSYS_OUTPUT_DIR:-${lightx2v_path}/save_results/nsys} +export DISAGG_NSYS_TRACE=${DISAGG_NSYS_TRACE:-cuda,nvtx,osrt} +export DISAGG_NSYS_EXTRA_ARGS=${DISAGG_NSYS_EXTRA_ARGS:-} +user_start_delay_s=${USER_START_DELAY_S:-0} +if [[ -n "${USER_MAX_REQUESTS:-}" ]]; then + user_max_requests=${USER_MAX_REQUESTS} +elif [[ "${LOAD_FROM_USER}" != "0" ]]; then + # When the workload is driven from the user process, keep sending until the stage ends + # unless the caller explicitly sets a hard cap. + user_max_requests=0 +else + user_max_requests=${DISAGG_AUTO_REQUEST_COUNT} +fi + +seed=${SEED:-42} +prompt=${PROMPT:-"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."} +negative_prompt=${NEGATIVE_PROMPT:-"镜头晃动,色调艳丽,过曝,静态"} +save_result_path=${SAVE_RESULT_PATH:-${lightx2v_path}/save_results/wan22_i2v_dynamic.mp4} + +controller_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_controller.log +user_log=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_user.log + +if [[ "${topology}" == "single_node" ]]; then + controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-3000} +else + controller_wait_timeout_s=${CONTROLLER_WAIT_TIMEOUT_S:-7200} +fi +controller_poll_interval_s=${CONTROLLER_POLL_INTERVAL_S:-5} +fatal_watch_interval_s=${FATAL_WATCH_INTERVAL_S:-2} +fatal_flag_file=${lightx2v_path}/save_results/disagg_wan22_i2v_dynamic_fatal.flag +remote_log_collect=${REMOTE_LOG_COLLECT:-1} +remote_log_collect_dir=${REMOTE_LOG_COLLECT_DIR:-${lightx2v_path}/save_results/remote_logs} +remote_logs_collected=0 +remote_pre_clean=${DISAGG_REMOTE_PRE_CLEAN:-1} +is_single_node=0 +if [[ "${topology}" == "single_node" ]]; then + is_single_node=1 +fi + +echo "disagg topology=${topology}" +echo 
"controller_cfg=${controller_cfg}" +echo "DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST} DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT}" +echo "RDMA_PREFERRED_IPV4=${RDMA_PREFERRED_IPV4:-}" +echo "DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT}" +echo "DISAGG_ENABLE_NSYS=${DISAGG_ENABLE_NSYS} DISAGG_NSYS_OUTPUT_DIR=${DISAGG_NSYS_OUTPUT_DIR} DISAGG_NSYS_TRACE=${DISAGG_NSYS_TRACE}" +echo "SYNC_COMM=${SYNC_COMM}" +echo "LOAD_FROM_USER=${LOAD_FROM_USER} USER_START_DELAY_S=${user_start_delay_s} USER_MAX_REQUESTS=${user_max_requests}" + +rm -f "${fatal_flag_file}" + +pre_clean_remote_hosts_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote pre-clean: single_node topology" + return 0 + fi + if [[ "${remote_pre_clean}" == "0" || "${remote_pre_clean}" == "false" ]]; then + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + echo "skip remote pre-clean: jq not found" + return 0 + fi + + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + local remote_workdir + remote_workdir=$(jq -r '.disagg_config.remote_workdir // empty' "${controller_cfg}") + if [[ -z "${remote_workdir}" ]]; then + remote_workdir="${lightx2v_path}" + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts discovered for pre-clean" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? 
// empty' "${controller_cfg}") + + local remote_workdir_q + remote_workdir_q=$(printf '%q' "${remote_workdir}") + local remote_cmd + remote_cmd="set -e; cd ${remote_workdir_q}; bash scripts/disagg/kill_service.sh || true" + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + + echo "remote pre-clean on ${host}" + if ssh "${ssh_opts[@]}" "${target}" "bash -lc '${remote_cmd}'"; then + echo "remote pre-clean succeeded on ${host}" + else + echo "warning: remote pre-clean failed on ${host}" + fi + done +} + +sync_remote_configs_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote config sync: single_node topology" + return 0 + fi + if ! command -v jq >/dev/null 2>&1; then + echo "skip remote config sync: jq not found" + return 0 + fi + if ! command -v scp >/dev/null 2>&1; then + echo "skip remote config sync: scp not found" + return 0 + fi + + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + local remote_workdir + remote_workdir=$(jq -r '.disagg_config.remote_workdir // empty' "${controller_cfg}") + if [[ -z "${remote_workdir}" ]]; then + remote_workdir="${lightx2v_path}" + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts discovered for config sync" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? 
// empty' "${controller_cfg}") + + local config_files=("${controller_cfg}") + for role in encoder transformer decoder; do + local cfg_candidate="${controller_cfg/_controller.json/_${role}.json}" + if [[ -f "${cfg_candidate}" ]]; then + config_files+=("${cfg_candidate}") + fi + done + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + + for src_cfg in "${config_files[@]}"; do + local rel_cfg="${src_cfg#${lightx2v_path}/}" + local dst_cfg="${src_cfg}" + if [[ "${src_cfg}" == "${lightx2v_path}/"* ]]; then + dst_cfg="${remote_workdir}/${rel_cfg}" + fi + + local dst_dir + dst_dir=$(dirname "${dst_cfg}") + ssh "${ssh_opts[@]}" "${target}" "mkdir -p '${dst_dir}'" || true + if scp "${ssh_opts[@]}" "${src_cfg}" "${target}:${dst_cfg}" >/dev/null 2>&1; then + echo "synced config to ${host}:${dst_cfg}" + else + echo "warning: failed to sync config to ${host}:${dst_cfg}" + fi + done + done +} + +# Remote workers import lightx2v from remote_workdir; without syncing, fixes on the controller host never run on peers. +sync_remote_disagg_sources_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote disagg source sync: single_node topology" + return 0 + fi + if [[ "${DISAGG_SYNC_REMOTE_SOURCES:-1}" == "0" || "${DISAGG_SYNC_REMOTE_SOURCES:-}" == "false" ]]; then + echo "skip remote disagg source sync: DISAGG_SYNC_REMOTE_SOURCES=${DISAGG_SYNC_REMOTE_SOURCES:-}" + return 0 + fi + if ! 
command -v jq >/dev/null 2>&1; then + echo "skip remote disagg source sync: jq not found" + return 0 + fi + + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + local remote_workdir + remote_workdir=$(jq -r '.disagg_config.remote_workdir // empty' "${controller_cfg}") + if [[ -z "${remote_workdir}" ]]; then + remote_workdir="${lightx2v_path}" + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts for disagg source sync" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? // empty' "${controller_cfg}") + local rsync_rsh="ssh" + for opt in "${ssh_opts[@]}"; do + rsync_rsh+=" $(printf '%q' "${opt}")" + done + + local rel_disagg="lightx2v/disagg" + local src_dir="${lightx2v_path}/${rel_disagg}/" + # Do not overwrite rdma_base.py on peers: pyverbs/rdma-core versions may differ per host. 
+ local sync_excludes=( + --exclude=rdma_base.py + ) + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + local dst_dir="${remote_workdir}/${rel_disagg}" + ssh "${ssh_opts[@]}" "${target}" "mkdir -p '${dst_dir}'" || true + if command -v rsync >/dev/null 2>&1; then + if rsync -az -e "${rsync_rsh}" "${sync_excludes[@]}" "${src_dir}" "${target}:${dst_dir}/"; then + echo "synced ${rel_disagg}/ to ${host}:${dst_dir}/ (excludes rdma_base.py)" + else + echo "warning: rsync ${rel_disagg} to ${host} failed" + fi + else + if ( cd "${lightx2v_path}" && tar cf - "${sync_excludes[@]}" "${rel_disagg}" ) | ssh "${ssh_opts[@]}" "${target}" "cd '${remote_workdir}' && tar xf -"; then + echo "synced ${rel_disagg}/ to ${host} (tar, excludes rdma_base.py)" + else + echo "warning: tar-sync ${rel_disagg} to ${host} failed" + fi + fi + done +} + +collect_remote_logs_once() { + if [[ "${is_single_node}" == "1" ]]; then + echo "skip remote log collection: single_node topology" + return 0 + fi + if [[ "${remote_log_collect}" == "0" || "${remote_log_collect}" == "false" ]]; then + return 0 + fi + if [[ "${remote_logs_collected}" == "1" ]]; then + return 0 + fi + remote_logs_collected=1 + + if [[ ! -f "${controller_cfg}" ]]; then + echo "skip remote log collection: controller config not found: ${controller_cfg}" + return 0 + fi + if ! 
command -v jq >/dev/null 2>&1; then + echo "skip remote log collection: jq not found" + return 0 + fi + + local remote_log_dir + remote_log_dir=$(jq -r '.disagg_config.remote_log_dir // empty' "${controller_cfg}") + local bootstrap_host + bootstrap_host=$(jq -r '.disagg_config.bootstrap_addr // empty' "${controller_cfg}") + local ssh_user + ssh_user=$(jq -r '.disagg_config.ssh_user // empty' "${controller_cfg}") + if [[ -z "${remote_log_dir}" ]]; then + echo "skip remote log collection: disagg_config.remote_log_dir is empty" + return 0 + fi + + mapfile -t remote_hosts < <(jq -r --arg bootstrap "${bootstrap_host}" '.disagg_config.static_instance_slots[]?.host // empty | select(length > 0 and . != $bootstrap)' "${controller_cfg}" | sort -u) + if (( ${#remote_hosts[@]} == 0 )); then + echo "no remote hosts discovered from static_instance_slots, skip remote log collection" + return 0 + fi + + mapfile -t ssh_opts < <(jq -r '.disagg_config.ssh_options[]? // empty' "${controller_cfg}") + + local ts + ts=$(date +%Y%m%d_%H%M%S) + mkdir -p "${remote_log_collect_dir}" + + for host in "${remote_hosts[@]}"; do + local target="${host}" + if [[ -n "${ssh_user}" ]]; then + target="${ssh_user}@${host}" + fi + + local dest_dir="${remote_log_collect_dir}/${host}_${ts}" + local archive_path="${dest_dir}/remote_logs.tgz" + mkdir -p "${dest_dir}" + + local remote_log_dir_q + remote_log_dir_q=$(printf '%q' "${remote_log_dir}") + local remote_cmd + remote_cmd="set -e; shopt -s nullglob; cd ${remote_log_dir_q}; files=(*_service.log *_sidecar.log); if (( \${#files[@]} == 0 )); then exit 3; fi; tar -czf - -- \"\${files[@]}\"" + + if ssh "${ssh_opts[@]}" "${target}" "bash -lc '${remote_cmd}'" > "${archive_path}" 2>/dev/null; then + tar -xzf "${archive_path}" -C "${dest_dir}" >/dev/null 2>&1 || true + rm -f "${archive_path}" + echo "remote logs collected from ${host} -> ${dest_dir}" + else + rm -f "${archive_path}" + echo "warning: failed to collect remote logs from ${host}:${remote_log_dir}" 
+ fi + done +} + +has_fatal_log_error() { + local log_path="$1" + [[ -f "${log_path}" ]] || return 1 + + # Fail-fast on known fatal patterns so we do not wait for full run completion. + rg -q "KeyError: '/psm_|resource_tracker: There appear to be [0-9]+ leaked shared_memory objects|Failed to process request for room=|Data(Sender|Receiver) transfer failed for room=" "${log_path}" +} + +start_fatal_watchdog() { + ( + while true; do + if [[ -f "${fatal_flag_file}" ]]; then + exit 0 + fi + if [[ -n "${controller_pid:-}" ]] && ! kill -0 "${controller_pid}" 2>/dev/null; then + exit 0 + fi + if has_fatal_log_error "${controller_log}" || has_fatal_log_error "${user_log}"; then + echo "fatal error detected in logs, stopping services immediately" + : > "${fatal_flag_file}" + [[ -n "${user_pid:-}" ]] && kill -TERM "${user_pid}" 2>/dev/null || true + [[ -n "${controller_pid:-}" ]] && kill -TERM "${controller_pid}" 2>/dev/null || true + # Give controller/sidecars a short grace window to release rooms. + for _ in $(seq 1 10); do + local_alive=0 + if [[ -n "${user_pid:-}" ]] && kill -0 "${user_pid}" 2>/dev/null; then + local_alive=1 + fi + if [[ -n "${controller_pid:-}" ]] && kill -0 "${controller_pid}" 2>/dev/null; then + local_alive=1 + fi + if [[ "${local_alive}" -eq 0 ]]; then + break + fi + sleep 0.5 + done + bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + exit 0 + fi + sleep "${fatal_watch_interval_s}" + done + ) & + watchdog_pid=$! +} + +is_controller_stuck() { + local log_path="$1" + [[ -f "${log_path}" ]] || return 1 + + local tail_block + tail_block=$(tail -n 240 "${log_path}" 2>/dev/null || true) + [[ -n "${tail_block}" ]] || return 1 + + # Waiting for decoder results, all GPUs idle, and queues still pending => hard-stuck. + if echo "${tail_block}" | rg -q "Waiting for decoder results" \ + && echo "${tail_block}" | rg -q "queue_total_pending': [1-9]" \ + && ! 
echo "${tail_block}" | rg -q "gpu_utilization': ([1-9][0-9]*|0\\.[1-9])"; then + return 0 + fi + return 1 +} + +cleanup() { + local pids=("${user_pid:-}" "${controller_pid:-}") + for pid in "${pids[@]}"; do + if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then + kill "${pid}" 2>/dev/null || true + fi + done + if [[ -n "${watchdog_pid:-}" ]] && kill -0 "${watchdog_pid}" 2>/dev/null; then + kill "${watchdog_pid}" 2>/dev/null || true + fi + bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + collect_remote_logs_once || true +} + +trap cleanup EXIT INT TERM + +pre_clean_remote_hosts_once +# sync_remote_configs_once +sync_remote_disagg_sources_once + +python -m lightx2v.disagg.examples.run_service \ + --service controller \ + --model_cls wan2.2_moe \ + --task i2v \ + --model_path ${model_path} \ + --config_json ${controller_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${controller_log} 2>&1 & +controller_pid=$! + +echo "controller started pid=${controller_pid}" +sleep 8 + +if [[ "${LOAD_FROM_USER}" != "0" ]]; then + if [[ "${user_start_delay_s}" != "0" ]]; then + echo "waiting ${user_start_delay_s}s before run_user to let remote services warm up" + sleep "${user_start_delay_s}" + fi + python -m lightx2v.disagg.examples.run_user \ + --controller_host "${DISAGG_CONTROLLER_HOST}" \ + --controller_request_port "${DISAGG_CONTROLLER_REQUEST_PORT}" \ + --max_requests "${user_max_requests}" \ + > ${user_log} 2>&1 & + user_pid=$! 
+ echo "run_user started pid=${user_pid}" +else + echo "LOAD_FROM_USER=${LOAD_FROM_USER}, skip starting run_user" +fi + +start_fatal_watchdog + +if [[ -n "${user_pid:-}" ]]; then + wait ${user_pid} || true + echo "run_user finished" +fi + +if [[ -f "${fatal_flag_file}" ]]; then + echo "fatal error handled by watchdog, exiting early" + wait "${controller_pid}" 2>/dev/null || true + exit 125 +fi + +controller_wait_start=$(date +%s) +while kill -0 "${controller_pid}" 2>/dev/null; do + now_ts=$(date +%s) + elapsed=$((now_ts - controller_wait_start)) + + if (( elapsed >= controller_wait_timeout_s )); then + if is_controller_stuck "${controller_log}"; then + echo "controller stuck detected (all GPUs idle with pending queues), force killing services" + else + echo "controller wait timeout (${controller_wait_timeout_s}s), force killing services" + fi + kill "${controller_pid}" 2>/dev/null || true + bash ${lightx2v_path}/scripts/disagg/kill_service.sh || true + wait "${controller_pid}" 2>/dev/null || true + exit 124 + fi + + if [[ -f "${fatal_flag_file}" ]]; then + echo "fatal error handled by watchdog, exiting early" + wait "${controller_pid}" 2>/dev/null || true + exit 125 + fi + + sleep "${controller_poll_interval_s}" +done + +wait ${controller_pid} +if [[ -n "${watchdog_pid:-}" ]] && kill -0 "${watchdog_pid}" 2>/dev/null; then + kill "${watchdog_pid}" 2>/dev/null || true +fi +echo "controller finished" + +echo "logs: ${controller_log} ${user_log}" diff --git a/scripts/disagg/run_wan_t2v_service.sh b/scripts/disagg/run_wan_t2v_service.sh deleted file mode 100755 index 8facbc665..000000000 --- a/scripts/disagg/run_wan_t2v_service.sh +++ /dev/null @@ -1,159 +0,0 @@ -#!/bin/bash - -# set path firstly -lightx2v_path=/root/zht/LightX2V -model_path=/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B - -# set environment variables -source ${lightx2v_path}/scripts/base/base.sh - -# Keep flashinfer enabled while ensuring nvcc uses a supported host compiler. 
-export CC=/usr/bin/gcc-13 -export CXX=/usr/bin/g++-13 -export CUDAHOSTCXX=/usr/bin/g++-13 -if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then - export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" -else - export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" -fi - -controller_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_controller.json -encoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_encoder.json -transformer_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_transformer.json -decoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_decoder.json - -seed=42 -request_count=10 -prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." -negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -save_result_path=${lightx2v_path}/save_results/test_disagg.mp4 -output_files=() -for ((i=1; i<=request_count; i++)); do - output_files+=("${save_result_path%.mp4}${i}.mp4") -done - -# Remove old outputs so wait loop reflects current run status. 
-rm -f "${output_files[@]}" - -cleanup() { - local pids=("${encoder_pid:-}" "${transformer_pid:-}" "${decoder_pid:-}" "${controller_pid:-}") - for pid in "${pids[@]}"; do - if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then - kill "${pid}" 2>/dev/null || true - fi - done -} - -trap cleanup EXIT INT TERM - -wait_for_port() { - local host="$1" - local port="$2" - local timeout_secs="${3:-30}" - local waited=0 - - while true; do - if (echo > /dev/tcp/${host}/${port}) >/dev/null 2>&1; then - echo "Port ready: ${host}:${port}" - return 0 - fi - - if (( waited >= timeout_secs )); then - echo "Timeout waiting for port ${host}:${port} after ${timeout_secs}s" - return 1 - fi - - sleep 1 - waited=$((waited + 1)) - done -} - -rdma_request_port=5566 -rdma_phase1_port=5567 -rdma_phase2_port=5568 - -python -m lightx2v.disagg.examples.run_service \ - --service controller \ - --model_cls wan2.1 \ - --task t2v \ - --model_path ${model_path} \ - --config_json ${controller_cfg} \ - --seed ${seed} \ - --prompt "${prompt}" \ - --negative_prompt "${negative_prompt}" \ - --save_result_path ${save_result_path} \ - > ${lightx2v_path}/save_results/disagg_controller.log 2>&1 & -controller_pid=$! - -wait_for_port 127.0.0.1 ${rdma_request_port} 60 -wait_for_port 127.0.0.1 ${rdma_phase1_port} 60 -wait_for_port 127.0.0.1 ${rdma_phase2_port} 60 - -# NOTE: Kept for rollback. Controller now creates encoder/transformer/decoder internally. -# CUDA_VISIBLE_DEVICES=0 python -m lightx2v.disagg.examples.run_service \ -# --service encoder \ -# --model_cls wan2.1 \ -# --task t2v \ -# --model_path ${model_path} \ -# --config_json ${encoder_cfg} \ -# --seed ${seed} \ -# --prompt "${prompt}" \ -# --negative_prompt "${negative_prompt}" \ -# --save_result_path ${save_result_path} \ -# > ${lightx2v_path}/save_results/disagg_encoder.log 2>&1 & -# encoder_pid=$! 
- -# CUDA_VISIBLE_DEVICES=1 python -m lightx2v.disagg.examples.run_service \ -# --service transformer \ -# --model_cls wan2.1 \ -# --task t2v \ -# --model_path ${model_path} \ -# --config_json ${transformer_cfg} \ -# --seed ${seed} \ -# --prompt "${prompt}" \ -# --negative_prompt "${negative_prompt}" \ -# --save_result_path ${save_result_path} \ -# > ${lightx2v_path}/save_results/disagg_transformer.log 2>&1 & -# transformer_pid=$! - -# CUDA_VISIBLE_DEVICES=2 python -m lightx2v.disagg.examples.run_service \ -# --service decoder \ -# --model_cls wan2.1 \ -# --task t2v \ -# --model_path ${model_path} \ -# --config_json ${decoder_cfg} \ -# --seed ${seed} \ -# --prompt "${prompt}" \ -# --negative_prompt "${negative_prompt}" \ -# --save_result_path ${save_result_path} \ -# > ${lightx2v_path}/save_results/disagg_decoder.log 2>&1 & -# decoder_pid=$! - -# Give background services time to flush and finish queued requests. - -echo "Waiting for output videos: ${output_files[*]}" -wait_seconds=0 -max_wait_seconds=$((600 * request_count)) - -while true; do - all_generated=1 - for file in "${output_files[@]}"; do - if [[ ! -f "${file}" ]]; then - all_generated=0 - break - fi - done - - if (( all_generated )); then - echo "All ${request_count} output videos are generated." - break - fi - - if (( wait_seconds >= max_wait_seconds )); then - echo "Timeout waiting for output videos after ${max_wait_seconds}s" - exit 1 - fi - - sleep 5 - wait_seconds=$((wait_seconds + 5)) -done