diff --git a/configs/disagg/baseline/wan22_moe_i2v_baseline.json b/configs/disagg/baseline/wan22_moe_i2v_baseline.json new file mode 100644 index 000000000..96c42692e --- /dev/null +++ b/configs/disagg/baseline/wan22_moe_i2v_baseline.json @@ -0,0 +1,48 @@ +{ + "infer_steps": 4, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "boundary_step_index": 2, + "denoising_step_list": [ + 1000, + 750, + 500, + 250 + ], + "dit_quantized": true, + "dit_quant_scheme": "int8-q8f", + "high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors", + "low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors", + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/disagg/baseline/wan22_moe_i2v_baseline_50steps.json b/configs/disagg/baseline/wan22_moe_i2v_baseline_50steps.json new file mode 100644 index 000000000..cb99ba368 --- /dev/null +++ b/configs/disagg/baseline/wan22_moe_i2v_baseline_50steps.json @@ -0,0 +1,36 @@ +{ + "model_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B", + "infer_steps": 50, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "eps": 1e-06, + "model_type": "i2v", + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": [ + 3.5, + 3.5 + ], + "sample_shift": 5.0, + "enable_cfg": false, + "cpu_offload": true, + "offload_granularity": "block", + "t5_cpu_offload": false, + "vae_cpu_offload": false, + "fps": 16, + "use_image_encoder": false, + "parallel": { + "seq_p_size": 8, + "seq_p_attn_type": "ulysses" + } +} diff --git a/lightx2v/disagg/examples/infer.py b/lightx2v/disagg/examples/infer.py new file mode 100644 index 000000000..e8e6c46ef --- /dev/null +++ b/lightx2v/disagg/examples/infer.py @@ -0,0 +1,245 @@ +import argparse +import os + +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.common.ops import * +from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401 +from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 +from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401 +from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 +from 
lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 +from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner # noqa: F401 +from lightx2v.models.runners.neopp.neopp_runner import NeoppRunner # noqa: F401 +from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401 +from lightx2v.models.runners.seedvr.seedvr_runner import SeedVRRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401 +from lightx2v.models.runners.wan.wan_matrix_game3_runner import WanMatrixGame3Runner # noqa: F401 +from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401 +from lightx2v.models.runners.wan.wan_vace_runner import Wan22MoeVaceRunner, WanVaceRunner # noqa: F401 +from lightx2v.models.runners.worldplay.worldplay_ar_runner import WorldPlayARRunner # noqa: F401 +from lightx2v.models.runners.worldplay.worldplay_bi_runner import WorldPlayBIRunner # noqa: F401 +from lightx2v.models.runners.worldplay.worldplay_distill_runner import WorldPlayDistillRunner # noqa: F401 +from lightx2v.models.runners.z_image.z_image_runner import ZImageRunner # noqa: F401 +from lightx2v.utils.envs import * +from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.set_config import print_config, set_config, set_parallel_config +from lightx2v.utils.utils import seed_all, validate_config_paths +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + +try: + from lightx2v.models.runners.worldmirror.worldmirror_runner import WorldMirrorRunner # noqa: F401 +except Exception as exc: # pragma: no cover - optional dependency guard + logger.warning("WorldMirrorRunner import skipped: {}", exc) + + +def init_runner(config): + torch.set_grad_enabled(False) + runner = RUNNER_REGISTER[config["model_cls"]](config) + runner.init_modules() + return runner + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42, help="The seed for random generator") + parser.add_argument( + "--model_cls", + type=str, + required=True, + choices=[ + "wan2.1", + "wan2.1_distill", + "wan2.1_mean_flow_distill", + "wan2.1_vace", + "wan2.1_sf", + "wan2.1_sf_mtxg2", + "seko_talk", + "wan2.2_moe", + "lingbot_world", + "wan2.2", + "wan2.2_matrix_game3", + "wan2.2_moe_audio", + "wan2.2_audio", + "wan2.2_moe_distill", + "wan2.2_moe_vace", + "qwen_image", + "longcat_image", + "wan2.2_animate", + "hunyuan_video_1.5", + "hunyuan_video_1.5_distill", + "worldplay_distill", + "worldplay_ar", + "worldplay_bi", + "z_image", + "flux2_klein", + "flux2_dev", + "ltx2", + "bagel", + "seedvr2", + "neopp", + "lingbot_world_fast", + "worldmirror", + ], + default="wan2.1", + ) + + parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v", "rs2v", "t2av", "i2av", "ltx2_s2v", "sr", "recon"], default="t2v") + parser.add_argument("--support_tasks", type=str, nargs="+", default=[], help="Set supported tasks 
for the model") + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--sf_model_path", type=str, required=False) + parser.add_argument("--config_json", type=str, required=True) + parser.add_argument("--use_prompt_enhancer", action="store_true") + + parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation") + parser.add_argument("--negative_prompt", type=str, default="") + + parser.add_argument( + "--image_path", + type=str, + default="", + help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'", + ) + parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") + parser.add_argument( + "--audio_path", + type=str, + default="", + help="Input audio path: Wan s2v / rs2v, or required for LTX-2 task ltx2_s2v.", + ) + parser.add_argument("--image_strength", type=str, default="1.0", help="i2av: single float, or comma-separated floats (one per image, or one value broadcast). Example: 1.0 or 1.0,0.85,0.9") + parser.add_argument( + "--image_frame_idx", type=str, default="", help="i2av: comma-separated pixel frame indices (one per image). Omit or empty to evenly space frames in [0, num_frames-1]. Example: 0,40,80" + ) + # [Warning] For vace task, need refactor. + parser.add_argument( + "--src_ref_images", + type=str, + default=None, + help="The file list of the source reference images. Separated by ','. Default None.", + ) + parser.add_argument( + "--src_video", + type=str, + default=None, + help="The file of the source video. Default None.", + ) + parser.add_argument( + "--src_mask", + type=str, + default=None, + help="The file of the source mask. Default None.", + ) + parser.add_argument( + "--src_pose_path", + type=str, + default=None, + help="The file of the source pose. Default None.", + ) + parser.add_argument( + "--src_face_path", + type=str, + default=None, + help="The file of the source face. Default None.", + ) + parser.add_argument( + "--src_bg_path", + type=str, + default=None, + help="The file of the source background. Default None.", + ) + parser.add_argument( + "--src_mask_path", + type=str, + default=None, + help="The file of the source mask. 
Default None.", + ) + parser.add_argument( + "--pose", + type=str, + default=None, + help="Pose string (e.g., 'w-3, right-0.5') or JSON file path for WorldPlay models.", + ) + parser.add_argument( + "--action_path", + type=str, + default=None, + help="Directory path for lingbot camera/action control files (poses.npy, intrinsics.npy, optional action.npy).", + ) + parser.add_argument( + "--action_ckpt", + type=str, + default=None, + help="Path to action model checkpoint for WorldPlay models.", + ) + # WorldMirror (3D reconstruction) specific + parser.add_argument("--input_path", type=str, default=None, help="(worldmirror/recon) Path to a directory of images, a video file, or a single image.") + parser.add_argument("--strict_output_path", type=str, default=None, help="(worldmirror/recon) If set, write outputs directly here instead of under save_result_path///.") + parser.add_argument("--prior_cam_path", type=str, default=None, help="(worldmirror/recon) Optional camera prior JSON (extrinsics + intrinsics).") + parser.add_argument("--prior_depth_path", type=str, default=None, help="(worldmirror/recon) Optional depth prior directory (one .npy/.png per image).") + parser.add_argument("--subfolder", type=str, default=None, help="(worldmirror/recon) Subfolder inside model_path containing weights. Overrides config.") + parser.add_argument("--disable_heads", type=str, nargs="*", default=None, help="(worldmirror/recon) Heads to disable: any of camera depth normal points gs.") + parser.add_argument("--enable_bf16", action="store_true", default=False, help="(worldmirror/recon) Run the WorldMirror model in bf16.") + parser.add_argument("--save_rendered", action="store_true", default=False, help="(worldmirror/recon) Render an interpolated fly-through video from Gaussian splats.") + parser.add_argument("--render_interp_per_pair", type=int, default=None, help="(worldmirror/recon) Interpolated frames per camera pair for --save_rendered.") + parser.add_argument("--render_depth", action="store_true", default=False, help="(worldmirror/recon) Also render a depth video with --save_rendered.") + parser.add_argument("--wm_config_path", type=str, default=None, help="(worldmirror/recon) Optional training YAML (pair with --wm_ckpt_path).") + parser.add_argument("--wm_ckpt_path", type=str, default=None, help="(worldmirror/recon) Optional .ckpt/.safetensors (pair with --wm_config_path).") + + parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file") + parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)") + parser.add_argument("--target_shape", type=int, nargs="+", default=[], help="Set return video or image shape") + parser.add_argument("--target_video_length", type=int, default=81, help="The target video length for each generated clip") + parser.add_argument("--aspect_ratio", type=str, default="") + parser.add_argument("--video_path", type=str, default=None, help="input video path(for sr/v2v task)") + parser.add_argument("--sr_ratio", type=float, default=2.0, help="super resolution ratio for sr task") + parser.add_argument( + "--num_iterations", + type=int, + default=None, + help="Override the number of Matrix-Game-3 generation segments. 
Final video length follows 57 + 40 * (num_iterations - 1).", + ) + + args = parser.parse_args() + # validate_task_arguments(args) + + seed_all(args.seed) + + # set config + config = set_config(args) + # init input_info + input_info = init_empty_input_info(args.task, args.support_tasks) + + if config["parallel"]: + platform_device = PLATFORM_DEVICE_REGISTER.get(os.getenv("PLATFORM", "cuda"), None) + platform_device.init_parallel_env() + set_parallel_config(config) + + print_config(config) + + validate_config_paths(config) + + with ProfilingContext4DebugL1("Total Cost"): + # init runner + runner = init_runner(config) + # start to infer + data = args.__dict__ + update_input_info_from_dict(input_info, data) + runner.run_pipeline(input_info) + + # Clean up distributed process group + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Distributed process group cleaned up") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/run_controller.py b/lightx2v/disagg/examples/run_controller.py new file mode 100644 index 000000000..c206af60d --- /dev/null +++ b/lightx2v/disagg/examples/run_controller.py @@ -0,0 +1,561 @@ +import argparse +import copy +import json +import os +import subprocess +import sys +import tempfile +import threading +import time +from pathlib import Path +from typing import Any + +from loguru import logger + +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, ReqManager +from lightx2v.disagg.monitor import Monitor, Reporter +from lightx2v.disagg.workload import build_payload, current_stage, load_stage_specs, start_workload_clock + + +def _parse_gpus(raw: str) -> list[int]: + parts = [p.strip() for p in raw.split(",") if p.strip()] + if not parts: + return [0] + return [int(p) for p in parts] + + +def _parallel_world_size_from_config(config: dict[str, Any] | None) -> int: + if not isinstance(config, dict): + return 1 + parallel = config.get("parallel") + if not isinstance(parallel, dict): + return 1 + + tensor_p_size = int(parallel.get("tensor_p_size", 1) or 1) + if tensor_p_size > 1: + return tensor_p_size + + cfg_p_size = int(parallel.get("cfg_p_size", 1) or 1) + seq_p_size = int(parallel.get("seq_p_size", 1) or 1) + return max(1, cfg_p_size * seq_p_size) + + +def _load_base_config_json(path: str) -> dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def _make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Dispatch baseline infer requests and collect latency/GPU metrics") + + parser.add_argument("--mode", choices=["controller", "worker"], default="controller") + + # Controller mode + parser.add_argument("--request_source", choices=["run_user", "generate"], default="run_user") + parser.add_argument("--controller_request_port", type=int, default=REQUEST_POLLING_PORT - 2) + parser.add_argument("--result_port", type=int, default=REQUEST_POLLING_PORT - 1) + parser.add_argument("--worker_base_port", type=int, default=REQUEST_POLLING_PORT + 100) + parser.add_argument("--worker_monitor_base_port", type=int, default=MONITOR_POLLING_PORT + 100) + parser.add_argument("--monitor_poll_interval_s", type=float, default=2.0) + parser.add_argument("--request_poll_sleep_s", type=float, default=0.02) + parser.add_argument("--completion_timeout_s", type=float, default=7200.0) + parser.add_argument("--dist_master_addr", type=str, default="127.0.0.1") + parser.add_argument("--dist_master_port", type=int, default=29600) + + 
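For reference, `_parallel_world_size_from_config` above treats `tensor_p_size` as authoritative whenever it exceeds 1 and otherwise multiplies `cfg_p_size` by `seq_p_size`. A minimal standalone sketch of that rule (it mirrors the helper for illustration only and does not import the module):

```python
# Illustrative re-statement of _parallel_world_size_from_config's rule.
def world_size(parallel: dict) -> int:
    tensor_p = int(parallel.get("tensor_p_size", 1) or 1)
    if tensor_p > 1:
        return tensor_p  # tensor parallelism overrides the cfg/seq product
    cfg_p = int(parallel.get("cfg_p_size", 1) or 1)
    seq_p = int(parallel.get("seq_p_size", 1) or 1)
    return max(1, cfg_p * seq_p)

assert world_size({"seq_p_size": 4, "seq_p_attn_type": "ulysses"}) == 4  # 4-step baseline config
assert world_size({"seq_p_size": 8, "seq_p_attn_type": "ulysses"}) == 8  # 50-step baseline config
assert world_size({"tensor_p_size": 2, "seq_p_size": 4}) == 2
```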
parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--gpus", type=str, default="0,1,2,3") + parser.add_argument("--python_executable", type=str, default=sys.executable) + + parser.add_argument("--model_cls", type=str, default="wan2.2_moe") + parser.add_argument("--task", type=str, default="i2v") + parser.add_argument("--model_path", type=str, default="/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B") + parser.add_argument("--base_config_json", type=str, default="/root/zht/LightX2V/configs/disagg/baseline/wan22_moe_i2v_baseline.json") + parser.add_argument("--save_dir", type=str, default="/root/zht/LightX2V/save_results") + parser.add_argument("--save_result_path", type=str, default="") + parser.add_argument("--prompt", type=str, default="") + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--image_path", type=str, default="") + parser.add_argument("--keep_parallel_config", action="store_true", default=True) + parser.add_argument("--drop_parallel_config", action="store_true", default=False) + + parser.add_argument("--generate_requests", type=int, default=10) + parser.add_argument("--generate_interval_s", type=float, default=0.0) + + parser.add_argument("--metrics_output_json", type=str, default="/root/zht/LightX2V/save_results/baseline_controller_metrics.json") + + # Worker mode + parser.add_argument("--worker_id", type=int, default=-1) + parser.add_argument("--worker_recv_port", type=int, default=0) + parser.add_argument("--worker_gpu", type=int, default=0) + parser.add_argument("--worker_monitor_port", type=int, default=0) + parser.add_argument("--worker_dist_rank", type=int, default=0) + parser.add_argument("--worker_dist_world_size", type=int, default=1) + parser.add_argument("--worker_cooperative_parallel", action="store_true", default=False) + + return parser + + +def _write_request_config(payload: dict[str, Any], request_id: int, keep_parallel_config: bool) -> tuple[str, str]: + temp_dir = tempfile.mkdtemp(prefix=f"baseline_req_{request_id}_") + config_path = str(Path(temp_dir) / "request_config.json") + + request_config = copy.deepcopy(payload) + request_config.pop("workload_end", None) + request_config.pop("request_metrics", None) + if not keep_parallel_config: + request_config.pop("parallel", None) + + with open(config_path, "w", encoding="utf-8") as f: + json.dump(request_config, f, ensure_ascii=False, indent=2) + + return temp_dir, config_path + + +def _effective_keep_parallel(args: argparse.Namespace) -> bool: + return bool(args.keep_parallel_config) and not bool(args.drop_parallel_config) + + +def _run_infer_once(args: argparse.Namespace, payload: dict[str, Any], worker_id: int) -> tuple[int, str]: + request_metrics = payload.get("request_metrics", {}) if isinstance(payload.get("request_metrics"), dict) else {} + request_id = int(request_metrics.get("request_id", int(time.time() * 1000))) + + keep_parallel = _effective_keep_parallel(args) + temp_dir, config_json = _write_request_config(payload, request_id, keep_parallel_config=keep_parallel) + save_path = payload.get("save_path") or payload.get("save_result_path") + if not save_path: + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + save_path = str(save_dir / f"baseline_worker{worker_id}_req{request_id}.mp4") + + infer_argv = [ + "-m", + "lightx2v.disagg.examples.infer", + "--model_cls", + str(payload.get("model_cls", args.model_cls)), + "--task", + str(payload.get("task", args.task)), + "--model_path", + 
str(payload.get("model_path", args.model_path)), + "--config_json", + config_json, + "--prompt", + str(payload.get("prompt", "")), + "--negative_prompt", + str(payload.get("negative_prompt", "")), + "--save_result_path", + str(save_path), + ] + + image_path = payload.get("image_path") + if image_path: + infer_argv.extend(["--image_path", str(image_path)]) + + seed = payload.get("seed") + if seed is not None: + infer_argv.extend(["--seed", str(int(seed))]) + + # Keep execution path aligned with `lightx2v.disagg.examples.infer`: + # one request uses one worker process, and when parallel is enabled this + # single request is launched by torchrun on all visible devices. + parallel_world_size = _parallel_world_size_from_config(payload) if keep_parallel else 1 + worker_world_size = int(args.worker_dist_world_size or 1) + cooperative_parallel = bool(args.worker_cooperative_parallel) and worker_world_size > 1 and parallel_world_size > 1 + if cooperative_parallel: + cmd = [args.python_executable, *infer_argv] + elif parallel_world_size > 1: + cmd = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + f"--nproc_per_node={parallel_world_size}", + *infer_argv, + ] + else: + cmd = [args.python_executable, *infer_argv] + + env = os.environ.copy() + if cooperative_parallel: + env["MASTER_ADDR"] = str(args.dist_master_addr) + env["MASTER_PORT"] = str(int(args.dist_master_port)) + env["RANK"] = str(int(args.worker_dist_rank)) + env["LOCAL_RANK"] = str(int(args.worker_dist_rank)) + env["NODE_RANK"] = "0" + env["WORLD_SIZE"] = str(worker_world_size) + + result = subprocess.run(cmd, env=env, check=False) + + try: + os.remove(config_json) + os.rmdir(temp_dir) + except OSError: + pass + + return int(result.returncode), str(save_path) + + +def _worker_main(args: argparse.Namespace) -> None: + req_mgr = ReqManager() + + reporter = Reporter( + service_type=f"baseline_infer_worker_{args.worker_id}", + gpu_id=int(args.worker_gpu), + bind_address=f"tcp://*:{args.worker_monitor_port}", + ) + reporter_thread = threading.Thread(target=reporter.serve_forever, name=f"worker-{args.worker_id}-reporter", daemon=True) + reporter_thread.start() + + logger.info( + "worker={} listening on port={} gpu={} monitor_port={}", + args.worker_id, + args.worker_recv_port, + args.worker_gpu, + args.worker_monitor_port, + ) + + while True: + msg = req_mgr.receive(args.worker_recv_port) + if isinstance(msg, dict) and msg.get("__control__") == "stop": + break + + payload = msg if isinstance(msg, dict) else {} + req_metrics = payload.get("request_metrics", {}) if isinstance(payload.get("request_metrics"), dict) else {} + request_id = int(req_metrics.get("request_id", int(time.time() * 1000))) + client_send_ts = float(req_metrics.get("client_send_ts", time.time())) + + start_ts = time.time() + return_code, save_path = _run_infer_once(args, payload, args.worker_id) + finish_ts = time.time() + + # In cooperative parallel mode, all workers are ranks of one request. + # Only rank0 reports completion to avoid duplicated completion events. 
+ if not args.worker_cooperative_parallel or int(args.worker_dist_rank) == 0: + req_mgr.send( + "127.0.0.1", + args.result_port, + { + "request_id": request_id, + "worker_id": args.worker_id, + "start_ts": start_ts, + "finish_ts": finish_ts, + "client_send_ts": client_send_ts, + "e2e_latency_s": finish_ts - client_send_ts, + "return_code": return_code, + "save_path": save_path, + }, + ) + + reporter.stop() + + +def _launch_worker_processes(args: argparse.Namespace, gpus: list[int]) -> tuple[list[subprocess.Popen], int, bool]: + procs: list[subprocess.Popen] = [] + keep_parallel = _effective_keep_parallel(args) + base_config = _load_base_config_json(args.base_config_json) + world_size = _parallel_world_size_from_config(base_config) if keep_parallel else 1 + if world_size < 1: + world_size = 1 + + requested_workers = int(args.num_workers) + cooperative_parallel = bool(keep_parallel and world_size > 1) + if cooperative_parallel: + if len(gpus) < world_size: + raise RuntimeError(f"cooperative model-parallel requires at least {world_size} gpus, got {len(gpus)} from --gpus={args.gpus}") + if requested_workers != world_size: + logger.warning( + "parallel_world_size={} enabled; forcing num_workers {} -> {} (one worker per rank for one request)", + world_size, + requested_workers, + world_size, + ) + actual_workers = world_size + else: + capacity = max(1, len(gpus) // world_size) + if requested_workers > capacity: + logger.warning( + "num_workers={} exceeds gpu capacity {} for parallel_world_size={}, auto-adjust to {}", + requested_workers, + capacity, + world_size, + capacity, + ) + actual_workers = min(requested_workers, capacity) + + for worker_id in range(actual_workers): + if cooperative_parallel: + # In cooperative mode all ranks must see the full GPU list so legacy + # rank->device mapping by dist.get_rank() stays valid. 
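Concretely, with illustrative numbers (8 visible GPUs, parallel world size 4), the two grouping modes below behave as follows; this sketch only restates the slicing logic that follows:

```python
# Illustrative numbers only: 8 GPUs, parallel world size 4.
gpus, world_size = list(range(8)), 4

# Non-cooperative: capacity = 8 // 4 = 2 workers, each pinned to its own
# contiguous slice of the GPU list.
slices = [gpus[i * world_size:(i + 1) * world_size] for i in range(len(gpus) // world_size)]
assert slices == [[0, 1, 2, 3], [4, 5, 6, 7]]

# Cooperative: world_size ranks jointly serve one request, and every rank
# exports the same full group so dist.get_rank()-based device mapping holds.
coop_groups = [gpus[:world_size] for _ in range(world_size)]
assert coop_groups == [[0, 1, 2, 3]] * 4
```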
+ gpu_group = gpus[:actual_workers] + gpu_id = gpus[worker_id] + else: + start = worker_id * world_size + gpu_group = gpus[start : start + world_size] + if len(gpu_group) < world_size: + gpu_group = gpus[:world_size] + gpu_id = gpu_group[0] + recv_port = args.worker_base_port + worker_id + monitor_port = args.worker_monitor_base_port + worker_id + + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_group) + + cmd = [ + args.python_executable, + "-m", + "lightx2v.disagg.examples.run_controller", + "--mode", + "worker", + "--worker_id", + str(worker_id), + "--worker_recv_port", + str(recv_port), + "--worker_gpu", + str(gpu_id), + "--worker_monitor_port", + str(monitor_port), + "--worker_dist_rank", + str(worker_id), + "--worker_dist_world_size", + str(actual_workers), + "--result_port", + str(args.result_port), + "--python_executable", + args.python_executable, + "--model_cls", + args.model_cls, + "--task", + args.task, + "--model_path", + args.model_path, + "--save_dir", + args.save_dir, + ] + + if keep_parallel: + cmd.append("--keep_parallel_config") + if args.drop_parallel_config: + cmd.append("--drop_parallel_config") + if cooperative_parallel: + cmd.append("--worker_cooperative_parallel") + cmd.extend( + [ + "--dist_master_addr", + args.dist_master_addr, + "--dist_master_port", + str(args.dist_master_port), + ] + ) + + procs.append(subprocess.Popen(cmd, env=env)) + return procs, actual_workers, cooperative_parallel + + +def _build_generated_requests(args: argparse.Namespace) -> list[dict[str, Any]]: + base = _load_base_config_json(args.base_config_json) + + base.setdefault("model_cls", args.model_cls) + base.setdefault("task", args.task) + base.setdefault("model_path", args.model_path) + if args.save_result_path: + base["save_path"] = args.save_result_path + if args.prompt: + base.setdefault("prompt", args.prompt) + if args.negative_prompt: + base.setdefault("negative_prompt", args.negative_prompt) + if args.image_path: + base.setdefault("image_path", args.image_path) + if str(base.get("model_cls", args.model_cls)) == "wan2.2_moe": + base.setdefault("boundary", 0.9) + + stages = load_stage_specs() + start_workload_clock() + + requests: list[dict[str, Any]] = [] + for i in range(args.generate_requests): + stage = current_stage(stages) + req = build_payload(base, stage, i) + req["model_path"] = req.get("model_path", args.model_path) + req["model_cls"] = req.get("model_cls", args.model_cls) + req["task"] = req.get("task", args.task) + if str(req.get("model_cls", args.model_cls)) == "wan2.2_moe": + req.setdefault("boundary", 0.9) + requests.append(req) + return requests + + +def _controller_main(args: argparse.Namespace) -> None: + gpus = _parse_gpus(args.gpus) + req_mgr = ReqManager() + + worker_procs, worker_count, cooperative_parallel = _launch_worker_processes(args, gpus) + logger.info("launched {} workers", worker_count) + if worker_count <= 0: + raise RuntimeError("no workers launched, please check --num_workers / --gpus / parallel config") + + monitor_nodes = [f"tcp://127.0.0.1:{args.worker_monitor_base_port + i}" for i in range(worker_count)] + monitor = Monitor(monitor_nodes) + monitor_samples: list[dict[str, Any]] = [] + monitor_stop_event = threading.Event() + global_first_send_ts: float | None = None + + def _monitor_cb(results: list[dict[str, Any]]) -> None: + ts = time.time() + for item in results: + sample = dict(item) + sample["sample_ts"] = ts + sample["sample_ts_from_global_start_s"] = ts - global_first_send_ts if global_first_send_ts is 
not None else None + monitor_samples.append(sample) + + monitor_thread = threading.Thread( + target=monitor.run_forever, + kwargs={ + "interval_seconds": args.monitor_poll_interval_s, + "callback": _monitor_cb, + "stop_event": monitor_stop_event, + }, + daemon=True, + ) + monitor_thread.start() + + dispatched: dict[int, dict[str, Any]] = {} + completions: list[dict[str, Any]] = [] + + next_worker = 0 + + def _dispatch(payload: dict[str, Any]) -> None: + nonlocal next_worker, global_first_send_ts + req_metrics = payload.setdefault("request_metrics", {}) + request_id = int(req_metrics.get("request_id", len(dispatched))) + req_metrics["request_id"] = request_id + req_metrics.setdefault("client_send_ts", time.time()) + if global_first_send_ts is None: + global_first_send_ts = time.time() + req_metrics["global_first_send_ts"] = global_first_send_ts + + if cooperative_parallel: + dispatch_ts = time.time() + for worker_id in range(worker_count): + recv_port = args.worker_base_port + worker_id + req_mgr.send("127.0.0.1", recv_port, payload) + dispatched[request_id] = { + "worker_id": "cooperative_group", + "dispatch_ts": dispatch_ts, + "payload": payload, + } + logger.info("[dispatch] request_id={} -> cooperative workers [0..{}]", request_id, worker_count - 1) + return + + worker_id = next_worker % worker_count + next_worker += 1 + recv_port = args.worker_base_port + worker_id + + dispatch_ts = time.time() + req_mgr.send("127.0.0.1", recv_port, payload) + dispatched[request_id] = { + "worker_id": worker_id, + "dispatch_ts": dispatch_ts, + "payload": payload, + } + logger.info("[dispatch] request_id={} -> worker={} port={}", request_id, worker_id, recv_port) + + if args.request_source == "generate": + for payload in _build_generated_requests(args): + _dispatch(payload) + if args.generate_interval_s > 0: + time.sleep(args.generate_interval_s) + else: + logger.info("waiting for run_user requests on port={}", args.controller_request_port) + workload_end = False + while not workload_end: + payload = req_mgr.receive_non_block(args.controller_request_port) + if payload is None: + time.sleep(args.request_poll_sleep_s) + continue + if isinstance(payload, dict) and payload.get("workload_end"): + workload_end = True + continue + if isinstance(payload, dict): + _dispatch(payload) + + pending_ids = set(dispatched.keys()) + wait_start = time.time() + while pending_ids: + msg = req_mgr.receive_non_block(args.result_port) + if msg is None: + if time.time() - wait_start > args.completion_timeout_s: + logger.warning("timeout waiting for completions, pending={}", sorted(pending_ids)) + break + time.sleep(args.request_poll_sleep_s) + continue + + if not isinstance(msg, dict): + continue + + request_id = int(msg.get("request_id", -1)) + finish_ts = float(msg.get("finish_ts", 0.0)) + elapsed_from_global_start_s = finish_ts - global_first_send_ts if global_first_send_ts is not None else None + if elapsed_from_global_start_s is not None: + msg["elapsed_from_global_start_s"] = elapsed_from_global_start_s + if request_id in pending_ids: + pending_ids.remove(request_id) + completions.append(msg) + + logger.info( + "[done] request_id={} worker={} start_ts={:.6f} finish_ts={:.6f} e2e={:.3f}s elapsed_from_global_start={:.3f}s rc={}", + request_id, + msg.get("worker_id"), + 
float(msg.get("start_ts", 0.0)), + finish_ts, + float(msg.get("e2e_latency_s", 0.0)), + float(elapsed_from_global_start_s if elapsed_from_global_start_s is not None else -1.0), + msg.get("return_code"), + ) + + for worker_id in range(worker_count): + req_mgr.send("127.0.0.1", args.worker_base_port + worker_id, {"__control__": "stop"}) + + for proc in worker_procs: + try: + proc.wait(timeout=10) + except subprocess.TimeoutExpired: + proc.terminate() + + monitor_stop_event.set() + monitor_thread.join(timeout=2.0) + + summary = { + "request_source": args.request_source, + "dispatched_count": len(dispatched), + "completed_count": len(completions), + "requests": completions, + "monitor_samples": monitor_samples, + "generated_at": time.time(), + } + + out_path = Path(args.metrics_output_json) + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8") as f: + json.dump(summary, f, ensure_ascii=False, indent=2) + + logger.info("metrics saved to {}", out_path) + + +def main() -> None: + args = _make_parser().parse_args() + if args.mode == "worker": + _worker_main(args) + else: + _controller_main(args) + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 403a51e42..fe9ccdd55 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -66,6 +66,14 @@ def __init__(self): self._slot_reuse_block_until: dict[int, float] = {} self._local_host_aliases: set[str] = set() self._request_metrics_by_room: dict[int, dict[str, Any]] = {} + self._monitor_samples: list[dict[str, Any]] = [] + self._controller_start_ts: float | None = None + self._metrics_output_json = Path( + os.getenv( + "DISAGG_CONTROLLER_METRICS_OUTPUT_JSON", + str(Path(__file__).resolve().parents[3] / "save_results" / "disagg_controller_metrics.json"), + ) + ) def _is_monitor_enabled(self) -> bool: raw = os.getenv("ENABLE_MONITOR") @@ -870,6 +878,16 @@ def _monitor_callback(self, results): if not isinstance(last_scale_ts, dict): return + sample_ts = time.time() + sample_ts_from_start_s = sample_ts - self._controller_start_ts if self._controller_start_ts is not None else None + for item in results: + if not isinstance(item, dict): + continue + sample = dict(item) + sample["sample_ts"] = sample_ts + sample["sample_ts_from_global_start_s"] = sample_ts_from_start_s + self._monitor_samples.append(sample) + if warmup_duration_s > 0.0: elapsed_s = max(0.0, time.monotonic() - autoscale_start_mono) if elapsed_s < warmup_duration_s: @@ -1060,6 +1078,25 @@ def _monitor_callback(self, results): except Exception: pass + def _dump_controller_metrics(self, received_results: list[dict[str, Any]], batch_request_start_ts: float | None) -> Path: + summary = { + "requests": self._to_plain(received_results), + "monitor_samples": self._to_plain(self._monitor_samples), + "generated_at": time.time(), + } + if batch_request_start_ts is not None: + summary["batch_total_time_s"] = time.time() - batch_request_start_ts + if self._controller_start_ts is not None: + summary["controller_uptime_s"] = time.time() - self._controller_start_ts + + out_path = self._metrics_output_json + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8") as handle: + json.dump(summary, handle, ensure_ascii=False, indent=2) + + self.logger.info("Controller metrics saved to %s", out_path) + return out_path + def _handle_decoder_result( self, result: Any, @@ -1891,6 +1928,8 @@ def run(self, config): 
if config is None: raise ValueError("config cannot be None") + self._controller_start_ts = time.time() + self._monitor_samples = [] self._shutting_down = False bootstrap_addr = config.get("data_bootstrap_addr", "127.0.0.1") @@ -2139,3 +2178,8 @@ def run(self, config): for thread in list(self._sidecar_reclaim_threads): if thread.is_alive(): thread.join(timeout=3.0) + + try: + self._dump_controller_metrics(received_results, batch_request_start_ts) + except Exception: + self.logger.exception("Failed to write controller metrics to %s", self._metrics_output_json) diff --git a/scripts/disagg/extract_base_latency.py b/scripts/disagg/extract_base_latency.py new file mode 100644 index 000000000..a65e2c03c --- /dev/null +++ b/scripts/disagg/extract_base_latency.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import csv +import json +from pathlib import Path + + +def _fmt_float3(value): + try: + return f"{float(value):.3f}" + except (TypeError, ValueError): + return "" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Extract baseline latency rows from baseline_controller_metrics.json") + parser.add_argument( + "--metrics", + default="/root/zht/LightX2V/save_results/baseline_controller_metrics.json", + help="Input baseline metrics json path", + ) + parser.add_argument( + "--output", + default="/root/zht/LightX2V/save_results/base_wan22_i2v.csv", + help="Output csv path", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + metrics_path = Path(args.metrics) + out_path = Path(args.output) + + if not metrics_path.is_file(): + raise FileNotFoundError(f"metrics file not found: {metrics_path}") + + with metrics_path.open("r", encoding="utf-8") as f: + payload = json.load(f) + + requests = payload.get("requests", []) + if not isinstance(requests, list): + raise ValueError(f"invalid metrics format: requests must be a list, got {type(requests)}") + + global_start_ts = None + for item in requests: + if isinstance(item, dict) and item.get("client_send_ts") is not None: + ts = float(item["client_send_ts"]) + global_start_ts = ts if global_start_ts is None else min(global_start_ts, ts) + + rows = [] + for item in requests: + if not isinstance(item, dict): + continue + + finish_rel = item.get("elapsed_from_global_start_s") + if finish_rel is None: + finish_ts = item.get("finish_ts") + if finish_ts is not None and global_start_ts is not None: + finish_rel = float(finish_ts) - float(global_start_ts) + + row = { + "finish_time_from_global_start_s": _fmt_float3(finish_rel), + "e2e_latency_s": _fmt_float3(item.get("e2e_latency_s")), + } + rows.append(row) + + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["finish_time_from_global_start_s", "e2e_latency_s"]) + writer.writeheader() + for row in rows: + writer.writerow(row) + + print(f"wrote {len(rows)} rows to {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/disagg/extract_base_utilization.py b/scripts/disagg/extract_base_utilization.py new file mode 100644 index 000000000..30ac950a8 --- /dev/null +++ b/scripts/disagg/extract_base_utilization.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import csv +import json +from collections import defaultdict +from pathlib import Path + + +def _fmt_float3(value): + try: + return 
f"{float(value):.3f}" + except (TypeError, ValueError): + return "" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Extract average GPU utilization rows from baseline_controller_metrics.json") + parser.add_argument( + "--metrics", + default="/root/zht/LightX2V/save_results/baseline_controller_metrics_4steps_p8.json", + help="Input GPU metrics json path", + ) + parser.add_argument( + "--output", + default="/root/zht/LightX2V/save_results/base_wan22_i2v_util.csv", + help="Output csv path", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + metrics_path = Path(args.metrics) + out_path = Path(args.output) + + if not metrics_path.is_file(): + raise FileNotFoundError(f"metrics file not found: {metrics_path}") + + with metrics_path.open("r", encoding="utf-8") as f: + payload = json.load(f) + + samples = payload.get("monitor_samples", []) + if not isinstance(samples, list): + raise ValueError(f"invalid metrics format: monitor_samples must be a list, got {type(samples)}") + + grouped_samples = defaultdict(list) + for item in samples: + if not isinstance(item, dict): + continue + if item.get("status") != "ok": + continue + sample_ts = item.get("sample_ts_from_global_start_s") + if sample_ts is None: + continue + try: + grouped_samples[float(sample_ts)].append(item) + except (TypeError, ValueError): + continue + + rows = [] + for sample_ts in sorted(grouped_samples): + group = grouped_samples[sample_ts] + gpu_utils = [] + mem_utils = [] + for item in group: + gpu_util = item.get("gpu_utilization") + mem_used = item.get("gpu_memory_used_mb") + mem_total = item.get("gpu_memory_total_mb") + if gpu_util is not None: + gpu_utils.append(float(gpu_util)) + if mem_used is not None and mem_total: + mem_utils.append(float(mem_used) / float(mem_total) * 100.0) + + if not gpu_utils or not mem_utils: + continue + + rows.append( + { + "time_from_start_s": _fmt_float3(sample_ts), + "avg_gpu_utilization": _fmt_float3(sum(gpu_utils) / len(gpu_utils)), + "avg_gpu_memory_occupancy_rate": _fmt_float3(sum(mem_utils) / len(mem_utils)), + } + ) + + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "time_from_start_s", + "avg_gpu_utilization", + "avg_gpu_memory_occupancy_rate", + ], + ) + writer.writeheader() + for row in rows: + writer.writerow(row) + + print(f"wrote {len(rows)} rows to {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/disagg/extract_disagg_utilization.py b/scripts/disagg/extract_disagg_utilization.py new file mode 100644 index 000000000..bdd0251da --- /dev/null +++ b/scripts/disagg/extract_disagg_utilization.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import csv +import json +from collections import defaultdict +from pathlib import Path + + +def _fmt_float3(value): + try: + return f"{float(value):.3f}" + except (TypeError, ValueError): + return "" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Extract average GPU utilization rows from disagg_controller_metrics.json") + parser.add_argument( + "--metrics", + default="/root/zht/LightX2V/save_results/disagg_controller_metrics.json", + help="Input disagg metrics json path", + ) + parser.add_argument( + "--output", + default="/root/zht/LightX2V/save_results/disagg_wan22_i2v_util.csv", + help="Output csv path", + ) + return 
parser.parse_args() + + +def main() -> int: + args = parse_args() + metrics_path = Path(args.metrics) + out_path = Path(args.output) + + if not metrics_path.is_file(): + raise FileNotFoundError(f"metrics file not found: {metrics_path}") + + with metrics_path.open("r", encoding="utf-8") as f: + payload = json.load(f) + + samples = payload.get("monitor_samples", []) + if not isinstance(samples, list): + raise ValueError(f"invalid metrics format: monitor_samples must be a list, got {type(samples)}") + + grouped_samples = defaultdict(list) + for item in samples: + if not isinstance(item, dict): + continue + if item.get("status") != "ok": + continue + sample_ts = item.get("sample_ts_from_global_start_s") + if sample_ts is None: + continue + try: + grouped_samples[float(sample_ts)].append(item) + except (TypeError, ValueError): + continue + + rows = [] + for sample_ts in sorted(grouped_samples): + group = grouped_samples[sample_ts] + gpu_utils = [] + mem_utils = [] + for item in group: + gpu_util = item.get("gpu_utilization") + mem_used = item.get("gpu_memory_used_mb") + mem_total = item.get("gpu_memory_total_mb") + if gpu_util is not None: + gpu_utils.append(float(gpu_util)) + if mem_used is not None and mem_total: + mem_utils.append(float(mem_used) / float(mem_total) * 100.0) + + if not gpu_utils or not mem_utils: + continue + + rows.append( + { + "time_from_start_s": _fmt_float3(sample_ts), + "avg_gpu_utilization": _fmt_float3(sum(gpu_utils) / len(gpu_utils)), + "avg_gpu_memory_occupancy_rate": _fmt_float3(sum(mem_utils) / len(mem_utils)), + } + ) + + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "time_from_start_s", + "avg_gpu_utilization", + "avg_gpu_memory_occupancy_rate", + ], + ) + writer.writeheader() + for row in rows: + writer.writerow(row) + + print(f"wrote {len(rows)} rows to {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/disagg/kill_base.sh b/scripts/disagg/kill_base.sh new file mode 100755 index 000000000..55abdd03d --- /dev/null +++ b/scripts/disagg/kill_base.sh @@ -0,0 +1,175 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_NAMES=("run_baseline.sh") + +lightx2v_path=${LIGHTX2V_PATH:-/root/zht/LightX2V} +controller_request_port=${BASELINE_CONTROLLER_REQUEST_PORT:-12786} +result_port=${BASELINE_RESULT_PORT:-12787} +worker_base_port=${BASELINE_WORKER_BASE_PORT:-12888} +worker_monitor_base_port=${BASELINE_WORKER_MONITOR_BASE_PORT:-7888} + +if [[ -n "${BASELINE_NUM_WORKERS:-}" ]]; then + num_workers=${BASELINE_NUM_WORKERS} +elif [[ -n "${NUM_GPUS:-}" ]]; then + num_workers=${NUM_GPUS} +elif [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + num_workers=$(awk -F',' '{print NF}' <<<"${CUDA_VISIBLE_DEVICES}") +else + num_workers=8 +fi + +if [[ ! "${num_workers}" =~ ^[0-9]+$ ]] || (( num_workers <= 0 )); then + num_workers=8 +fi + +declare -a PORTS=(${controller_request_port} ${result_port}) +for ((i = 0; i < num_workers; i++)); do + PORTS+=($((worker_base_port + i))) + PORTS+=($((worker_monitor_base_port + i))) +done + +# Also cover a small tail range in case worker count changed between runs. 
+for extra in $(seq 0 31); do + PORTS+=($((worker_base_port + extra))) + PORTS+=($((worker_monitor_base_port + extra))) +done + +mapfile -t PORTS < <(printf '%s\n' "${PORTS[@]}" | awk 'NF && !seen[$0]++ { print $0 }' | sort -n) + +declare -a PROTECTED_PIDS=() + +collect_protected_pids() { + local cur="$$" + while [[ -n "$cur" && "$cur" != "0" ]]; do + PROTECTED_PIDS+=("$cur") + local parent + parent=$(ps -o ppid= -p "$cur" 2>/dev/null | tr -d ' ' || true) + if [[ -z "$parent" || "$parent" == "$cur" ]]; then + break + fi + cur="$parent" + done +} + +is_protected_pid() { + local target="$1" + for p in "${PROTECTED_PIDS[@]}"; do + if [[ "$p" == "$target" ]]; then + return 0 + fi + done + return 1 +} + +kill_pid_gracefully() { + local pid="$1" + if [[ -z "$pid" ]]; then + return + fi + if is_protected_pid "$pid"; then + echo "Skip protected pid=$pid" + return + fi + if kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + sleep 1 + if kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + fi +} + +find_listen_pids_by_port() { + local port="$1" + + if command -v lsof >/dev/null 2>&1; then + lsof -nP -t -iTCP:"$port" -sTCP:LISTEN 2>/dev/null | sort -u || true + return + fi + + if command -v ss >/dev/null 2>&1; then + ss -ltnp 2>/dev/null | awk -v p=":$port" ' + index($4, p) > 0 { + while (match($0, /pid=[0-9]+/)) { + print substr($0, RSTART + 4, RLENGTH - 4) + $0 = substr($0, RSTART + RLENGTH) + } + } + ' | sort -u || true + return + fi + + if command -v fuser >/dev/null 2>&1; then + fuser -n tcp "$port" 2>/dev/null | tr ' ' '\n' | sed '/^$/d' | sort -u || true + return + fi + + echo "No supported tool found to query listening ports (need one of: lsof, ss, fuser)." >&2 +} + +collect_protected_pids + +for script_name in "${SCRIPT_NAMES[@]}"; do + echo "Stopping script process: ${script_name}" + script_pids=$(pgrep -f "$script_name" || true) + if [[ -z "${script_pids}" ]]; then + echo "No running process found for ${script_name}" + continue + fi + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing script pid=$pid" + kill_pid_gracefully "$pid" + done <<< "$script_pids" +done + +cleanup_patterns=( + "lightx2v.disagg.examples.run_controller" + "lightx2v.disagg.examples.infer" + "python -m lightx2v.disagg.examples.run_controller" + "python -m lightx2v.disagg.examples.infer" + "conda run -n lightx2v bash ${lightx2v_path}/scripts/disagg/run_baseline.sh" +) + +for pattern in "${cleanup_patterns[@]}"; do + echo "Stopping processes matching pattern: ${pattern}" + matched_pids=$(pgrep -f "$pattern" || true) + if [[ -z "${matched_pids}" ]]; then + echo "No process matched: ${pattern}" + continue + fi + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing matched pid=$pid" + kill_pid_gracefully "$pid" + done <<< "$matched_pids" +done + +for port in "${PORTS[@]}"; do + echo "Stopping listeners on port ${port}" + port_pids=$(find_listen_pids_by_port "$port") + if [[ -z "${port_pids}" ]]; then + echo "No listener found on port ${port}" + continue + fi + + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing pid=$pid on port ${port}" + kill_pid_gracefully "$pid" + done <<< "$port_pids" + + remaining=$(find_listen_pids_by_port "$port") + if [[ -n "${remaining}" ]]; then + echo "Warning: port ${port} still has listeners: ${remaining}" + else + echo "Port ${port} is clear" + fi +done + +# Best effort cleanup for per-request temp dirs created by run_controller workers. 
+find /tmp -maxdepth 1 -type d -name 'baseline_req_*' -exec rm -rf {} + 2>/dev/null || true + +echo "kill_base.sh done." diff --git a/scripts/disagg/run_baseline.sh b/scripts/disagg/run_baseline.sh new file mode 100644 index 000000000..aaa178f16 --- /dev/null +++ b/scripts/disagg/run_baseline.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +set -euo pipefail + +lightx2v_path=${LIGHTX2V_PATH:-/root/zht/LightX2V} +model_path=${WAN22_MOE_MODEL_PATH:-/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B} +config_json=${BASELINE_CONFIG_JSON:-${lightx2v_path}/configs/disagg/baseline/wan22_moe_i2v_baseline.json} + +# base.sh expects PYTHONPATH to be defined under `set -u`. +export PYTHONPATH=${PYTHONPATH:-} + +baseline_conda_env=${BASELINE_CONDA_ENV:-lightx2v} +if [[ "${BASELINE_SKIP_CONDA_ACTIVATE:-0}" != "1" ]]; then + if [[ "${CONDA_DEFAULT_ENV:-}" != "${baseline_conda_env}" ]]; then + if ! command -v conda >/dev/null 2>&1; then + echo "ERROR: conda is not available, cannot activate env ${baseline_conda_env}" >&2 + exit 2 + fi + set +u + eval "$(conda shell.bash hook)" + conda activate "${baseline_conda_env}" + set -u + echo "activated conda env: ${baseline_conda_env}" + fi +fi + +if [[ -n "${NUM_GPUS:-}" ]]; then + num_gpus=${NUM_GPUS} +elif [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + num_gpus=$(awk -F',' '{print NF}' <<<"${CUDA_VISIBLE_DEVICES}") +else + num_gpus=8 +fi + +if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then + export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((num_gpus - 1))) +fi + +source "${lightx2v_path}/scripts/base/base.sh" + +baseline_log=${BASELINE_LOG:-${lightx2v_path}/save_results/baseline_wan22_i2v_single_task.log} +prompt=${BASELINE_PROMPT:-"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."} +negative_prompt=${BASELINE_NEGATIVE_PROMPT:-"镜头晃动,色调艳丽,过曝,静态"} +image_path=${BASELINE_IMAGE_PATH:-${lightx2v_path}/assets/inputs/imgs/img_0.jpg} +save_result_path=${BASELINE_SAVE_RESULT_PATH:-${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_baseline.mp4} +python_executable=${BASELINE_PYTHON_EXECUTABLE:-python} +model_cls=${BASELINE_MODEL_CLS:-wan2.2_moe} +task=${BASELINE_TASK:-i2v} + +mkdir -p "${lightx2v_path}/save_results" +exec > "${baseline_log}" 2>&1 + +gpus_csv="${CUDA_VISIBLE_DEVICES}" +metrics_output_json=${BASELINE_METRICS_OUTPUT_JSON:-${lightx2v_path}/save_results/baseline_controller_metrics.json} +dist_master_addr=${BASELINE_DIST_MASTER_ADDR:-127.0.0.1} +dist_master_port=${BASELINE_DIST_MASTER_PORT:-29600} +request_source=${BASELINE_REQUEST_SOURCE:-generate} +generate_requests=${BASELINE_GENERATE_REQUESTS:-30} +num_workers=${BASELINE_NUM_WORKERS:-1} + +"${python_executable}" -m lightx2v.disagg.examples.run_controller \ +--mode controller \ +--request_source "${request_source}" \ +--generate_requests "${generate_requests}" \ +--num_workers "${num_workers}" \ +--gpus "${gpus_csv}" \ +--dist_master_addr "${dist_master_addr}" \ +--dist_master_port "${dist_master_port}" \ +--python_executable "${python_executable}" \ +--model_cls "${model_cls}" \ +--task "${task}" \ +--model_path "${model_path}" \ +--base_config_json "${config_json}" \ +--prompt "${prompt}" \ +--negative_prompt "${negative_prompt}" \ +--image_path "${image_path}" \ +--save_result_path "${save_result_path}" \ +--save_dir "$(dirname "${save_result_path}")" \ +--metrics_output_json "${metrics_output_json}" diff --git a/scripts/disagg/run_dynamic.sh b/scripts/disagg/run_dynamic.sh index 971ae14f8..c406906bb 100644 --- a/scripts/disagg/run_dynamic.sh +++ b/scripts/disagg/run_dynamic.sh @@ -70,6 
+70,7 @@ if [[ -z "${RDMA_PREFERRED_IPV4:-}" && -n "${derived_controller_host}" ]]; then fi export DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT:-12786} export LOAD_FROM_USER=${LOAD_FROM_USER:-0} +export ENABLE_MONITOR=${ENABLE_MONITOR:-1} # multi_node: remote ranks (e.g. slow encoder/decoder host) may need longer TCP/ready waits. if [[ "${topology}" == "single_node" ]]; then export DISAGG_INSTANCE_START_TIMEOUT_SECONDS=${DISAGG_INSTANCE_START_TIMEOUT_SECONDS:-90} @@ -127,6 +128,7 @@ echo "controller_cfg=${controller_cfg}" echo "DISAGG_CONTROLLER_HOST=${DISAGG_CONTROLLER_HOST} DISAGG_CONTROLLER_REQUEST_PORT=${DISAGG_CONTROLLER_REQUEST_PORT}" echo "RDMA_PREFERRED_IPV4=${RDMA_PREFERRED_IPV4:-}" echo "DISAGG_AUTO_REQUEST_COUNT=${DISAGG_AUTO_REQUEST_COUNT}" +echo "ENABLE_MONITOR=${ENABLE_MONITOR}" echo "DISAGG_ENABLE_NSYS=${DISAGG_ENABLE_NSYS} DISAGG_NSYS_OUTPUT_DIR=${DISAGG_NSYS_OUTPUT_DIR} DISAGG_NSYS_TRACE=${DISAGG_NSYS_TRACE}" echo "SYNC_COMM=${SYNC_COMM}" echo "LOAD_FROM_USER=${LOAD_FROM_USER} USER_START_DELAY_S=${user_start_delay_s} USER_MAX_REQUESTS=${user_max_requests}"
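Both extraction scripts above key off the summary layout that `run_controller.py` and the controller's `_dump_controller_metrics` write. A minimal example of that contract follows; the field names are taken from this diff, while all values are invented for illustration:

```python
import json

# Invented values; only the field names match what the extract scripts read.
summary = {
    "requests": [
        {
            "request_id": 0,
            "worker_id": 0,
            "client_send_ts": 1700000000.0,
            "finish_ts": 1700000095.2,
            "e2e_latency_s": 95.2,
            "elapsed_from_global_start_s": 95.2,
            "return_code": 0,
        }
    ],
    "monitor_samples": [
        {
            "status": "ok",
            "sample_ts_from_global_start_s": 2.0,
            "gpu_utilization": 97.0,
            "gpu_memory_used_mb": 61440,
            "gpu_memory_total_mb": 81920,
        }
    ],
    "generated_at": 1700000100.0,
}
# A file with this shape can be passed to extract_base_latency.py or
# extract_*_utilization.py via --metrics.
print(json.dumps(summary, indent=2))
```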