Skip to content

Commit 8c4092b

Browse files
authored
add base test (#1044)
This pull request introduces new configuration files, adds a flexible inference script, and implements tools for extracting and analyzing controller metrics related to latency and GPU utilization. It also enhances the controller service to record and export detailed metrics for improved monitoring and analysis. **Key changes:** Configuration and Inference: * Added two new model configuration files for different inference settings: `wan22_moe_i2v_baseline.json` (4-step inference) and `wan22_moe_i2v_baseline_50steps.json` (50-step inference), supporting the Wan2.2 I2V model with parallelization options. [[1]](diffhunk://#diff-54e707f816456395c7a8b8a1d469798796ccd0fe68b1f3baa996311d4db4df3fR1-R48) [[2]](diffhunk://#diff-7e5b85d9092b5921ea7744d9b62c0b8756183336eb41af8a7818e87e7557d2e4R1-R36) * Introduced a new, comprehensive `infer.py` script for running inference with various LightX2V models, supporting a wide range of tasks and flexible command-line arguments. Controller Metrics and Monitoring: * Enhanced the controller (`controller.py`) to record detailed monitoring samples, track controller uptime, and save all metrics (including requests and periodic samples) to a configurable JSON file on shutdown. This enables more thorough post-run analysis. [[1]](diffhunk://#diff-cd2bb47b50681fc168f658410a3c5239855d829c498e4f0b40b06a370816ba84R69-R76) [[2]](diffhunk://#diff-cd2bb47b50681fc168f658410a3c5239855d829c498e4f0b40b06a370816ba84R881-R890) [[3]](diffhunk://#diff-cd2bb47b50681fc168f658410a3c5239855d829c498e4f0b40b06a370816ba84R1081-R1099) [[4]](diffhunk://#diff-cd2bb47b50681fc168f658410a3c5239855d829c498e4f0b40b06a370816ba84R1931-R1932) [[5]](diffhunk://#diff-cd2bb47b50681fc168f658410a3c5239855d829c498e4f0b40b06a370816ba84R2181-R2185) Metrics Extraction Utilities: * Added `extract_base_latency.py` to extract and export baseline request latency data from metrics JSON to CSV, including finish time and e2e latency. 
* Added `extract_base_utilization.py` to process and export average GPU utilization and memory occupancy rates from controller monitoring samples to CSV for further analysis.
1 parent 3d0c0fa commit 8c4092b

11 files changed

Lines changed: 1484 additions & 0 deletions

File tree

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
{
2+
"infer_steps": 4,
3+
"in_dim": 36,
4+
"dim": 5120,
5+
"ffn_dim": 13824,
6+
"freq_dim": 256,
7+
"num_heads": 40,
8+
"num_layers": 40,
9+
"out_dim": 16,
10+
"eps": 1e-06,
11+
"model_type": "i2v",
12+
"target_video_length": 81,
13+
"text_len": 512,
14+
"target_height": 480,
15+
"target_width": 832,
16+
"self_attn_1_type": "sage_attn2",
17+
"cross_attn_1_type": "sage_attn2",
18+
"cross_attn_2_type": "sage_attn2",
19+
"sample_guide_scale": [
20+
3.5,
21+
3.5
22+
],
23+
"sample_shift": 5.0,
24+
"enable_cfg": false,
25+
"cpu_offload": true,
26+
"offload_granularity": "block",
27+
"t5_cpu_offload": false,
28+
"vae_cpu_offload": false,
29+
"fps": 16,
30+
"use_image_encoder": false,
31+
"boundary_step_index": 2,
32+
"denoising_step_list": [
33+
1000,
34+
750,
35+
500,
36+
250
37+
],
38+
"dit_quantized": true,
39+
"dit_quant_scheme": "int8-q8f",
40+
"high_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors",
41+
"low_noise_quantized_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors",
42+
"high_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step.safetensors",
43+
"low_noise_original_ckpt": "/root/zht/LightX2V/models/lightx2v/Wan2.2-Distill-Models/wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors",
44+
"parallel": {
45+
"seq_p_size": 4,
46+
"seq_p_attn_type": "ulysses"
47+
}
48+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
{
2+
"model_path": "/root/zht/LightX2V/models/Wan-AI/Wan2.2-I2V-A14B",
3+
"infer_steps": 50,
4+
"in_dim": 36,
5+
"dim": 5120,
6+
"ffn_dim": 13824,
7+
"freq_dim": 256,
8+
"num_heads": 40,
9+
"num_layers": 40,
10+
"out_dim": 16,
11+
"eps": 1e-06,
12+
"model_type": "i2v",
13+
"target_video_length": 81,
14+
"text_len": 512,
15+
"target_height": 480,
16+
"target_width": 832,
17+
"self_attn_1_type": "sage_attn2",
18+
"cross_attn_1_type": "sage_attn2",
19+
"cross_attn_2_type": "sage_attn2",
20+
"sample_guide_scale": [
21+
3.5,
22+
3.5
23+
],
24+
"sample_shift": 5.0,
25+
"enable_cfg": false,
26+
"cpu_offload": true,
27+
"offload_granularity": "block",
28+
"t5_cpu_offload": false,
29+
"vae_cpu_offload": false,
30+
"fps": 16,
31+
"use_image_encoder": false,
32+
"parallel": {
33+
"seq_p_size": 8,
34+
"seq_p_attn_type": "ulysses"
35+
}
36+
}

lightx2v/disagg/examples/infer.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
import torch.distributed as dist
6+
from loguru import logger
7+
8+
from lightx2v.common.ops import *
9+
from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401
10+
from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401
11+
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401
12+
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
13+
from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401
14+
from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner # noqa: F401
15+
from lightx2v.models.runners.neopp.neopp_runner import NeoppRunner # noqa: F401
16+
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
17+
from lightx2v.models.runners.seedvr.seedvr_runner import SeedVRRunner # noqa: F401
18+
from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401
19+
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401
20+
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
21+
from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401
22+
from lightx2v.models.runners.wan.wan_matrix_game3_runner import WanMatrixGame3Runner # noqa: F401
23+
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
24+
from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401
25+
from lightx2v.models.runners.wan.wan_vace_runner import Wan22MoeVaceRunner, WanVaceRunner # noqa: F401
26+
from lightx2v.models.runners.worldplay.worldplay_ar_runner import WorldPlayARRunner # noqa: F401
27+
from lightx2v.models.runners.worldplay.worldplay_bi_runner import WorldPlayBIRunner # noqa: F401
28+
from lightx2v.models.runners.worldplay.worldplay_distill_runner import WorldPlayDistillRunner # noqa: F401
29+
from lightx2v.models.runners.z_image.z_image_runner import ZImageRunner # noqa: F401
30+
from lightx2v.utils.envs import *
31+
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
32+
from lightx2v.utils.profiler import *
33+
from lightx2v.utils.registry_factory import RUNNER_REGISTER
34+
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
35+
from lightx2v.utils.utils import seed_all, validate_config_paths
36+
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
37+
38+
try:
39+
from lightx2v.models.runners.worldmirror.worldmirror_runner import WorldMirrorRunner # noqa: F401
40+
except Exception as exc: # pragma: no cover - optional dependency guard
41+
logger.warning("WorldMirrorRunner import skipped: {}", exc)
42+
43+
44+
def init_runner(config):
    """Build and initialize the runner registered for ``config["model_cls"]``.

    Inference-only entry: gradient tracking is disabled globally before the
    runner is constructed, then the runner's modules are initialized.
    """
    torch.set_grad_enabled(False)
    runner_cls = RUNNER_REGISTER[config["model_cls"]]
    instance = runner_cls(config)
    instance.init_modules()
    return instance
49+
50+
51+
def _build_arg_parser():
    """Construct the CLI argument parser for the inference entry point.

    Kept separate from :func:`main` so the long argument table can be read
    and maintained on its own.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
    # NOTE: argparse never applies a default to a required argument, so the
    # previous dead `default="wan2.1"` on --model_cls was removed.
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "wan2.1_distill",
            "wan2.1_mean_flow_distill",
            "wan2.1_vace",
            "wan2.1_sf",
            "wan2.1_sf_mtxg2",
            "seko_talk",
            "wan2.2_moe",
            "lingbot_world",
            "wan2.2",
            "wan2.2_matrix_game3",
            "wan2.2_moe_audio",
            "wan2.2_audio",
            "wan2.2_moe_distill",
            "wan2.2_moe_vace",
            "qwen_image",
            "longcat_image",
            "wan2.2_animate",
            "hunyuan_video_1.5",
            "hunyuan_video_1.5_distill",
            "worldplay_distill",
            "worldplay_ar",
            "worldplay_bi",
            "z_image",
            "flux2_klein",
            "flux2_dev",
            "ltx2",
            "bagel",
            "seedvr2",
            "neopp",
            "lingbot_world_fast",
            "worldmirror",
        ],
    )

    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v", "rs2v", "t2av", "i2av", "ltx2_s2v", "sr", "recon"], default="t2v")
    parser.add_argument("--support_tasks", type=str, nargs="+", default=[], help="Set supported tasks for the model")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--sf_model_path", type=str, required=False)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--use_prompt_enhancer", action="store_true")

    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")

    parser.add_argument(
        "--image_path",
        type=str,
        default="",
        help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'",
    )
    parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
    parser.add_argument(
        "--audio_path",
        type=str,
        default="",
        help="Input audio path: Wan s2v / rs2v, or required for LTX-2 task ltx2_s2v.",
    )
    parser.add_argument("--image_strength", type=str, default="1.0", help="i2av: single float, or comma-separated floats (one per image, or one value broadcast). Example: 1.0 or 1.0,0.85,0.9")
    parser.add_argument(
        "--image_frame_idx", type=str, default="", help="i2av: comma-separated pixel frame indices (one per image). Omit or empty to evenly space frames in [0, num_frames-1]. Example: 0,40,80"
    )
    # [Warning] For vace task, need refactor.
    parser.add_argument(
        "--src_ref_images",
        type=str,
        default=None,
        help="The file list of the source reference images. Separated by ','. Default None.",
    )
    parser.add_argument(
        "--src_video",
        type=str,
        default=None,
        help="The file of the source video. Default None.",
    )
    parser.add_argument(
        "--src_mask",
        type=str,
        default=None,
        help="The file of the source mask. Default None.",
    )
    parser.add_argument(
        "--src_pose_path",
        type=str,
        default=None,
        help="The file of the source pose. Default None.",
    )
    parser.add_argument(
        "--src_face_path",
        type=str,
        default=None,
        help="The file of the source face. Default None.",
    )
    parser.add_argument(
        "--src_bg_path",
        type=str,
        default=None,
        help="The file of the source background. Default None.",
    )
    parser.add_argument(
        "--src_mask_path",
        type=str,
        default=None,
        help="The file of the source mask. Default None.",
    )
    parser.add_argument(
        "--pose",
        type=str,
        default=None,
        help="Pose string (e.g., 'w-3, right-0.5') or JSON file path for WorldPlay models.",
    )
    parser.add_argument(
        "--action_path",
        type=str,
        default=None,
        help="Directory path for lingbot camera/action control files (poses.npy, intrinsics.npy, optional action.npy).",
    )
    parser.add_argument(
        "--action_ckpt",
        type=str,
        default=None,
        help="Path to action model checkpoint for WorldPlay models.",
    )
    # WorldMirror (3D reconstruction) specific
    parser.add_argument("--input_path", type=str, default=None, help="(worldmirror/recon) Path to a directory of images, a video file, or a single image.")
    parser.add_argument("--strict_output_path", type=str, default=None, help="(worldmirror/recon) If set, write outputs directly here instead of under save_result_path/<subdir>/<timestamp>/.")
    parser.add_argument("--prior_cam_path", type=str, default=None, help="(worldmirror/recon) Optional camera prior JSON (extrinsics + intrinsics).")
    parser.add_argument("--prior_depth_path", type=str, default=None, help="(worldmirror/recon) Optional depth prior directory (one .npy/.png per image).")
    parser.add_argument("--subfolder", type=str, default=None, help="(worldmirror/recon) Subfolder inside model_path containing weights. Overrides config.")
    parser.add_argument("--disable_heads", type=str, nargs="*", default=None, help="(worldmirror/recon) Heads to disable: any of camera depth normal points gs.")
    parser.add_argument("--enable_bf16", action="store_true", default=False, help="(worldmirror/recon) Run the WorldMirror model in bf16.")
    parser.add_argument("--save_rendered", action="store_true", default=False, help="(worldmirror/recon) Render an interpolated fly-through video from Gaussian splats.")
    parser.add_argument("--render_interp_per_pair", type=int, default=None, help="(worldmirror/recon) Interpolated frames per camera pair for --save_rendered.")
    parser.add_argument("--render_depth", action="store_true", default=False, help="(worldmirror/recon) Also render a depth video with --save_rendered.")
    parser.add_argument("--wm_config_path", type=str, default=None, help="(worldmirror/recon) Optional training YAML (pair with --wm_ckpt_path).")
    parser.add_argument("--wm_ckpt_path", type=str, default=None, help="(worldmirror/recon) Optional .ckpt/.safetensors (pair with --wm_config_path).")

    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
    parser.add_argument("--target_shape", type=int, nargs="+", default=[], help="Set return video or image shape")
    parser.add_argument("--target_video_length", type=int, default=81, help="The target video length for each generated clip")
    parser.add_argument("--aspect_ratio", type=str, default="")
    parser.add_argument("--video_path", type=str, default=None, help="input video path(for sr/v2v task)")
    parser.add_argument("--sr_ratio", type=float, default=2.0, help="super resolution ratio for sr task")
    parser.add_argument(
        "--num_iterations",
        type=int,
        default=None,
        help="Override the number of Matrix-Game-3 generation segments. Final video length follows 57 + 40 * (num_iterations - 1).",
    )
    return parser


def main():
    """Parse CLI arguments, build the configured runner, and run one inference job.

    Steps: seed RNGs, resolve the run config from args + config_json, build an
    empty input-info record for the task, optionally initialize the parallel
    environment, then run the pipeline under a profiling context. On exit, any
    initialized distributed process group is destroyed.
    """
    args = _build_arg_parser().parse_args()
    # validate_task_arguments(args)

    seed_all(args.seed)

    # set config
    config = set_config(args)
    # init input_info
    input_info = init_empty_input_info(args.task, args.support_tasks)

    if config["parallel"]:
        platform_name = os.getenv("PLATFORM", "cuda")
        platform_device = PLATFORM_DEVICE_REGISTER.get(platform_name, None)
        # Fail fast with a clear message instead of an AttributeError on None
        # when PLATFORM names an unregistered platform.
        if platform_device is None:
            raise RuntimeError(f"No platform device registered for PLATFORM={platform_name!r}")
        platform_device.init_parallel_env()
        set_parallel_config(config)

    print_config(config)

    validate_config_paths(config)

    with ProfilingContext4DebugL1("Total Cost"):
        # init runner
        runner = init_runner(config)
        # start to infer: feed all parsed CLI values into the input-info record
        update_input_info_from_dict(input_info, vars(args))
        runner.run_pipeline(input_info)

    # Clean up distributed process group
    if dist.is_initialized():
        dist.destroy_process_group()
        logger.info("Distributed process group cleaned up")
242+
243+
244+
# Script entry point: run a single inference job from command-line arguments.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)