Skip to content

Commit 69648c8

Browse files
Update Neo++ (#991)
Co-authored-by: shihaobai <1798930569@qq.com>
1 parent 7216292 commit 69648c8

10 files changed

Lines changed: 329 additions & 201 deletions

File tree

examples/neopp/neopp_dense_1k.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,60 @@
55
# -------------------------------------------------
66

77
pipe = LightX2VPipeline(
8-
model_path="/data/nvme1/yongyang/FL/neo9b/neo9b",
8+
model_path="/data/nvme1/yongyang/FL/neo_9b_new/hf_step4000_ema",
99
model_cls="neopp",
1010
support_tasks=["t2i", "i2i"],
1111
)
1212

1313
pipe.create_generator(config_json="../../configs/neopp/neopp_dense.json")
14-
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})
14+
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": True})
1515

1616

1717
# -------------------------------------------------
1818
# Load KV cache and generate
1919
# -------------------------------------------------
2020

21-
pipe.runner.load_kvcache_t2i(
22-
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_cond_kv.pt",
23-
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_uncond_kv.pt",
21+
# -------------------------------------------------
22+
# TURN 0
23+
# -------------------------------------------------
24+
pipe.runner.load_kvcache(
25+
"/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor/to_x2v_cond_kv_0_289.pt",
26+
"/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor/to_x2v_uncond_kv_0_9.pt",
27+
)
28+
pipe.runner.set_inference_params(
29+
index_offset_cond=289,
30+
index_offset_uncond=9,
31+
cfg_interval=(-1, 2),
32+
cfg_scale=4.0,
33+
cfg_norm="global",
34+
timestep_shift=3.0,
2435
)
2536

2637
pipe.generate(
2738
seed=200,
28-
task="t2i",
29-
save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_dense_t2i_1k.png",
39+
save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_dense_1k_0.png",
3040
target_shape=[1024, 1024], # Height, Width
3141
)
3242

3343

34-
pipe.runner.load_kvcache_i2i(
35-
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor_it2i/to_x2v_cond_kv.pt",
36-
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor_it2i/to_x2v_uncond_kv_text.pt",
37-
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor_it2i/to_x2v_uncond_kv_img.pt",
44+
# -------------------------------------------------
45+
# TURN 1
46+
# -------------------------------------------------
47+
pipe.runner.load_kvcache(
48+
"/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor/to_x2v_cond_kv_1_346.pt",
49+
"/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor/to_x2v_uncond_kv_1_12.pt",
50+
)
51+
pipe.runner.set_inference_params(
52+
index_offset_cond=346,
53+
index_offset_uncond=12,
54+
cfg_interval=(-1, 2),
55+
cfg_scale=4.0,
56+
cfg_norm="global",
57+
timestep_shift=3.0,
3858
)
3959

4060
pipe.generate(
4161
seed=200,
42-
task="i2i",
43-
save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_dense_i2i_1k.png",
62+
save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_dense_1k_1.png",
4463
target_shape=[1024, 1024], # Height, Width
4564
)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch.distributed as dist
2+
3+
from lightx2v import LightX2VPipeline
4+
5+
# -------------------------------------------------
6+
# Initialize pipeline for NeoPP
7+
# -------------------------------------------------
8+
9+
# Build the NeoPP pipeline once; every turn below reuses the same instance.
pipe = LightX2VPipeline(
    model_path="/data/nvme1/yongyang/FL/neo9b/neo9b",
    model_cls="neopp",
    support_tasks=["t2i", "i2i"],
)

pipe.create_generator(config_json="../../configs/neopp/neopp_dense_cfg2.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": True})


# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

KV_CACHE_DIR = "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor"
SAVE_DIR = "/data/nvme1/yongyang/FL/LightX2V/save_results"

# One (index_offset_cond, index_offset_uncond) pair per turn, in turn order.
# The KV-cache file names embed the turn number and the matching offset, so
# deriving the paths from this table keeps file names and offsets in sync.
TURN_INDEX_OFFSETS = [
    (289, 9),   # TURN 0
    (346, 12),  # TURN 1
    (411, 15),  # TURN 2
]

for turn, (offset_cond, offset_uncond) in enumerate(TURN_INDEX_OFFSETS):
    # Load the conditional / unconditional KV caches exported for this turn.
    pipe.runner.load_kvcache(
        f"{KV_CACHE_DIR}/to_x2v_cond_kv_{turn}_{offset_cond}.pt",
        f"{KV_CACHE_DIR}/to_x2v_uncond_kv_{turn}_{offset_uncond}.pt",
    )
    # Only the index offsets change between turns; the CFG and timestep
    # settings are the same for every turn.
    pipe.runner.set_inference_params(
        index_offset_cond=offset_cond,
        index_offset_uncond=offset_uncond,
        cfg_interval=(-1, 2),
        cfg_scale=4.0,
        cfg_norm="global",
        timestep_shift=3.0,
    )

    pipe.generate(
        seed=200,
        save_result_path=f"{SAVE_DIR}/output_lightx2v_neopp_dense_1k_{turn}.png",
        target_shape=[1024, 1024],  # Height, Width
    )


# Tear down the process group only if a distributed run actually created one.
if dist.is_initialized():
    dist.destroy_process_group()

examples/neopp/neopp_dense_1k_cfg3.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

lightx2v/models/networks/neopp/infer/kv_cache_manager.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,11 @@ def update(self, layer_idx: int, key_states: torch.Tensor, value_states: torch.T
8282
self._kv_buf[layer_idx, 0, self._kv_past_seq :] = key_states
8383
self._kv_buf[layer_idx, 1, self._kv_past_seq :] = value_states
8484
return self._kv_buf[layer_idx, 0], self._kv_buf[layer_idx, 1]
85+
86+
def clear(self):
    """Reset every KV-cache buffer attribute and the past-sequence marker to ``None``."""
    # Same attributes, same order as a sequence of plain assignments.
    for attr_name in (
        "_kv_buf_cond",
        "_kv_buf_cond_key",
        "_kv_buf_uncond",
        "_kv_buf_uncond_key",
        "_kv_buf",
        "_kv_past_seq",
    ):
        setattr(self, attr_name, None)

lightx2v/models/networks/neopp/model.py

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def __init__(self, model_path, config, device):
2424
self._init_infer()
2525
self._init_weights()
2626
self.cfg_interval = self.config.get("cfg_interval", (-1, 2))
27-
self.cfg_scale = self.config.get("cfg_scale", 7.0)
28-
self.img_cfg_scale = self.config.get("img_cfg_scale", 1.5)
27+
self.cfg_scale = self.config.get("cfg_scale", 4.0)
28+
self.cfg_norm = self.config.get("cfg_norm", "global")
2929
self.patch_size = self.config.get("patch_size", 16)
3030
self.merge_size = 2
3131

@@ -41,12 +41,17 @@ def _init_infer(self):
4141

4242
@torch.no_grad()
4343
def infer(self, inputs):
44+
logger.info(f"infer: cfg_scale={self.cfg_scale}")
45+
logger.info(f"infer: cfg_interval={self.cfg_interval}")
46+
logger.info(f"infer: cfg_norm={self.cfg_norm}")
4447
pre_infer_out = self.pre_infer.infer(self.pre_weight)
4548

46-
if self.config["task"] == "i2i":
47-
v_pred = self._infer_i2i(inputs, pre_infer_out)
48-
else:
49-
v_pred = self._infer_t2i(inputs, pre_infer_out)
49+
# if self.config["task"] == "i2i":
50+
# v_pred = self._infer_i2i(inputs, pre_infer_out)
51+
# else:
52+
# v_pred = self._infer_t2i(inputs, pre_infer_out)
53+
54+
v_pred = self._infer_t2i_i2i(inputs, pre_infer_out)
5055

5156
t = self.scheduler.timesteps[self.scheduler.step_index]
5257
t_next = self.scheduler.timesteps[self.scheduler.step_index + 1]
@@ -59,7 +64,26 @@ def infer(self, inputs):
5964
)
6065
return z
6166

62-
def _infer_t2i(self, inputs, pre_infer_out):
67+
def cfg_norm_func(self, v_pred, v_pred_condition):
    """Rescale the CFG-combined prediction so its norm does not exceed the conditional one.

    The behaviour is selected by ``self.cfg_norm``:

    * ``"global"``  - one L2 norm per sample, taken jointly over dims 1 and 2.
    * ``"channel"`` - one L2 norm per position, taken over the last dim.
    * ``"none"``    - ``v_pred`` is returned unchanged.

    Args:
        v_pred: Prediction after classifier-free-guidance mixing.
        v_pred_condition: Conditional-branch prediction used as the norm reference.

    Returns:
        ``v_pred`` multiplied by ``min(1, ||v_pred_condition|| / ||v_pred||)``,
        or ``v_pred`` itself when ``self.cfg_norm == "none"``.

    Raises:
        ValueError: If ``self.cfg_norm`` is not one of the three modes above.
    """
    mode = self.cfg_norm
    if mode == "none":
        logger.info("cfg_norm is none, no normalization will be applied")
        return v_pred

    if mode == "global":
        # NOTE(review): dims (1, 2) assume an input with at least 3 dims - confirm against callers.
        norm_dim = (1, 2)
    elif mode == "channel":
        norm_dim = -1
    else:
        raise ValueError(f"Invalid cfg_norm: {self.cfg_norm}")

    logger.info(f"cfg_norm is {mode}, applying {mode} normalization")
    # torch.norm is deprecated; vector_norm computes the same L2 norm over `norm_dim`.
    norm_v_condition = torch.linalg.vector_norm(v_pred_condition, dim=norm_dim, keepdim=True)
    norm_v_cfg = torch.linalg.vector_norm(v_pred, dim=norm_dim, keepdim=True)
    # The epsilon guards against division by zero; clamping to [0, 1] means the
    # guided prediction can only be shrunk, never amplified.
    scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
    return v_pred * scale
85+
86+
def _infer_t2i_i2i(self, inputs, pre_infer_out):
6387
t = self.scheduler.timesteps[self.scheduler.step_index]
6488
use_cfg = t > self.cfg_interval[0] and t < self.cfg_interval[1] and self.cfg_scale > 1
6589

@@ -79,55 +103,59 @@ def _infer_t2i(self, inputs, pre_infer_out):
79103
v_pred_list = [torch.zeros_like(v_pred) for _ in range(cfg_p_world_size)]
80104
dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
81105
v_pred_cond, v_pred_uncond = v_pred_list[0], v_pred_list[1]
82-
return v_pred_uncond + self.cfg_scale * (v_pred_cond - v_pred_uncond)
106+
v_pred = v_pred_uncond + self.cfg_scale * (v_pred_cond - v_pred_uncond)
107+
v_pred = self.cfg_norm_func(v_pred, v_pred_cond)
108+
return v_pred
83109
else:
84110
return self._infer_pass(inputs, pre_infer_out, "cond")
85111
else:
86112
v_pred_condition = self._infer_pass(inputs, pre_infer_out, "cond")
87113
if use_cfg:
88114
v_pred_uncond = self._infer_pass(inputs, pre_infer_out, "uncond")
89-
return v_pred_uncond + self.cfg_scale * (v_pred_condition - v_pred_uncond)
115+
v_pred = v_pred_uncond + self.cfg_scale * (v_pred_condition - v_pred_uncond)
116+
v_pred = self.cfg_norm_func(v_pred, v_pred_condition)
117+
return v_pred
90118
return v_pred_condition
91119

92-
def _infer_i2i(self, inputs, pre_infer_out):
93-
t = self.scheduler.timesteps[self.scheduler.step_index]
94-
use_cfg = t > self.cfg_interval[0] and t < self.cfg_interval[1]
95-
96-
if self.config.get("cfg_parallel", False):
97-
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
98-
# assert dist.get_world_size(cfg_p_group) == 3, "cfg_p_world_size must be equal to 3 for i2i"
99-
cfg_p_rank = dist.get_rank(cfg_p_group)
100-
101-
if use_cfg:
102-
if cfg_p_rank == 0:
103-
v_pred = self._infer_pass(inputs, pre_infer_out, "cond")
104-
elif cfg_p_rank == 1:
105-
if self.cfg_scale > 1:
106-
v_pred = self._infer_pass(inputs, pre_infer_out, "text_uncond")
107-
else:
108-
v_pred = torch.zeros_like(pre_infer_out.z)
109-
elif cfg_p_rank == 2:
110-
if self.img_cfg_scale > 1:
111-
v_pred = self._infer_pass(inputs, pre_infer_out, "img_uncond")
112-
else:
113-
v_pred = torch.zeros_like(pre_infer_out.z)
114-
v_pred_list = [torch.zeros_like(v_pred) for _ in range(3)]
115-
dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
116-
v_pred_condition = v_pred_list[0]
117-
v_pred_text_uncond = v_pred_list[1] if self.cfg_scale > 1 else 0
118-
v_pred_img_uncond = v_pred_list[2] if self.img_cfg_scale > 1 else 0
119-
v_pred_text = v_pred_text_uncond + self.cfg_scale * (v_pred_condition - v_pred_text_uncond)
120-
return v_pred_img_uncond + self.img_cfg_scale * (v_pred_text - v_pred_img_uncond)
121-
else:
122-
return self._infer_pass(inputs, pre_infer_out, "cond")
123-
else:
124-
v_pred_condition = self._infer_pass(inputs, pre_infer_out, "cond")
125-
if use_cfg:
126-
v_pred_text_uncond = self._infer_pass(inputs, pre_infer_out, "text_uncond") if self.cfg_scale > 1 else 0
127-
v_pred_img_uncond = self._infer_pass(inputs, pre_infer_out, "img_uncond") if self.img_cfg_scale > 1 else 0
128-
v_pred_text = v_pred_text_uncond + self.cfg_scale * (v_pred_condition - v_pred_text_uncond)
129-
return v_pred_img_uncond + self.img_cfg_scale * (v_pred_text - v_pred_img_uncond)
130-
return v_pred_condition
120+
# def _infer_i2i(self, inputs, pre_infer_out):
121+
# t = self.scheduler.timesteps[self.scheduler.step_index]
122+
# use_cfg = t > self.cfg_interval[0] and t < self.cfg_interval[1]
123+
124+
# if self.config.get("cfg_parallel", False):
125+
# cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
126+
# # assert dist.get_world_size(cfg_p_group) == 3, "cfg_p_world_size must be equal to 3 for i2i"
127+
# cfg_p_rank = dist.get_rank(cfg_p_group)
128+
129+
# if use_cfg:
130+
# if cfg_p_rank == 0:
131+
# v_pred = self._infer_pass(inputs, pre_infer_out, "cond")
132+
# elif cfg_p_rank == 1:
133+
# if self.cfg_scale > 1:
134+
# v_pred = self._infer_pass(inputs, pre_infer_out, "text_uncond")
135+
# else:
136+
# v_pred = torch.zeros_like(pre_infer_out.z)
137+
# elif cfg_p_rank == 2:
138+
# if self.img_cfg_scale > 1:
139+
# v_pred = self._infer_pass(inputs, pre_infer_out, "img_uncond")
140+
# else:
141+
# v_pred = torch.zeros_like(pre_infer_out.z)
142+
# v_pred_list = [torch.zeros_like(v_pred) for _ in range(3)]
143+
# dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
144+
# v_pred_condition = v_pred_list[0]
145+
# v_pred_text_uncond = v_pred_list[1] if self.cfg_scale > 1 else 0
146+
# v_pred_img_uncond = v_pred_list[2] if self.img_cfg_scale > 1 else 0
147+
# v_pred_text = v_pred_text_uncond + self.cfg_scale * (v_pred_condition - v_pred_text_uncond)
148+
# return v_pred_img_uncond + self.img_cfg_scale * (v_pred_text - v_pred_img_uncond)
149+
# else:
150+
# return self._infer_pass(inputs, pre_infer_out, "cond")
151+
# else:
152+
# v_pred_condition = self._infer_pass(inputs, pre_infer_out, "cond")
153+
# if use_cfg:
154+
# v_pred_text_uncond = self._infer_pass(inputs, pre_infer_out, "text_uncond") if self.cfg_scale > 1 else 0
155+
# v_pred_img_uncond = self._infer_pass(inputs, pre_infer_out, "img_uncond") if self.img_cfg_scale > 1 else 0
156+
# v_pred_text = v_pred_text_uncond + self.cfg_scale * (v_pred_condition - v_pred_text_uncond)
157+
# return v_pred_img_uncond + self.img_cfg_scale * (v_pred_text - v_pred_img_uncond)
158+
# return v_pred_condition
131159

132160
def _infer_pass(self, inputs, pre_infer_out, pass_name):
133161
"""Run one forward pass. pass_name: 'cond' | 'uncond' | 'text_uncond' | 'img_uncond'"""

0 commit comments

Comments
 (0)