@@ -24,8 +24,8 @@ def __init__(self, model_path, config, device):
2424 self ._init_infer ()
2525 self ._init_weights ()
2626 self .cfg_interval = self .config .get ("cfg_interval" , (- 1 , 2 ))
27- self .cfg_scale = self .config .get ("cfg_scale" , 7 .0 )
28- self .img_cfg_scale = self .config .get ("img_cfg_scale " , 1.5 )
27+ self .cfg_scale = self .config .get ("cfg_scale" , 4 .0 )
28+ self .cfg_norm = self .config .get ("cfg_norm " , "global" )
2929 self .patch_size = self .config .get ("patch_size" , 16 )
3030 self .merge_size = 2
3131
@@ -41,12 +41,17 @@ def _init_infer(self):
4141
4242 @torch .no_grad ()
4343 def infer (self , inputs ):
44+ logger .info (f"infer: cfg_scale={ self .cfg_scale } " )
45+ logger .info (f"infer: cfg_interval={ self .cfg_interval } " )
46+ logger .info (f"infer: cfg_norm={ self .cfg_norm } " )
4447 pre_infer_out = self .pre_infer .infer (self .pre_weight )
4548
46- if self .config ["task" ] == "i2i" :
47- v_pred = self ._infer_i2i (inputs , pre_infer_out )
48- else :
49- v_pred = self ._infer_t2i (inputs , pre_infer_out )
49+ # if self.config["task"] == "i2i":
50+ # v_pred = self._infer_i2i(inputs, pre_infer_out)
51+ # else:
52+ # v_pred = self._infer_t2i(inputs, pre_infer_out)
53+
54+ v_pred = self ._infer_t2i_i2i (inputs , pre_infer_out )
5055
5156 t = self .scheduler .timesteps [self .scheduler .step_index ]
5257 t_next = self .scheduler .timesteps [self .scheduler .step_index + 1 ]
@@ -59,7 +64,26 @@ def infer(self, inputs):
5964 )
6065 return z
6166
62- def _infer_t2i (self , inputs , pre_infer_out ):
67+ def cfg_norm_func (self , v_pred , v_pred_condition ):
68+ if self .cfg_norm == "global" :
69+ logger .info (f"cfg_norm is global, applying global normalization" )
70+ norm_v_condition = torch .norm (v_pred_condition , dim = (1 , 2 ), keepdim = True )
71+ norm_v_cfg = torch .norm (v_pred , dim = (1 , 2 ), keepdim = True )
72+ scale = (norm_v_condition / (norm_v_cfg + 1e-8 )).clamp (min = 0 , max = 1.0 )
73+ v_pred = v_pred * scale
74+ elif self .cfg_norm == "channel" :
75+ logger .info (f"cfg_norm is channel, applying channel normalization" )
76+ norm_v_condition = torch .norm (v_pred_condition , dim = - 1 , keepdim = True )
77+ norm_v_cfg = torch .norm (v_pred , dim = - 1 , keepdim = True )
78+ scale = (norm_v_condition / (norm_v_cfg + 1e-8 )).clamp (min = 0 , max = 1.0 )
79+ v_pred = v_pred * scale
80+ elif self .cfg_norm == "none" :
81+ logger .info (f"cfg_norm is none, no normalization will be applied" )
82+ else :
83+ raise ValueError (f"Invalid cfg_norm: { self .cfg_norm } " )
84+ return v_pred
85+
86+ def _infer_t2i_i2i (self , inputs , pre_infer_out ):
6387 t = self .scheduler .timesteps [self .scheduler .step_index ]
6488 use_cfg = t > self .cfg_interval [0 ] and t < self .cfg_interval [1 ] and self .cfg_scale > 1
6589
@@ -79,55 +103,59 @@ def _infer_t2i(self, inputs, pre_infer_out):
79103 v_pred_list = [torch .zeros_like (v_pred ) for _ in range (cfg_p_world_size )]
80104 dist .all_gather (v_pred_list , v_pred , group = cfg_p_group )
81105 v_pred_cond , v_pred_uncond = v_pred_list [0 ], v_pred_list [1 ]
82- return v_pred_uncond + self .cfg_scale * (v_pred_cond - v_pred_uncond )
106+ v_pred = v_pred_uncond + self .cfg_scale * (v_pred_cond - v_pred_uncond )
107+ v_pred = self .cfg_norm_func (v_pred , v_pred_cond )
108+ return v_pred
83109 else :
84110 return self ._infer_pass (inputs , pre_infer_out , "cond" )
85111 else :
86112 v_pred_condition = self ._infer_pass (inputs , pre_infer_out , "cond" )
87113 if use_cfg :
88114 v_pred_uncond = self ._infer_pass (inputs , pre_infer_out , "uncond" )
89- return v_pred_uncond + self .cfg_scale * (v_pred_condition - v_pred_uncond )
115+ v_pred = v_pred_uncond + self .cfg_scale * (v_pred_condition - v_pred_uncond )
116+ v_pred = self .cfg_norm_func (v_pred , v_pred_condition )
117+ return v_pred
90118 return v_pred_condition
91119
92- def _infer_i2i (self , inputs , pre_infer_out ):
93- t = self .scheduler .timesteps [self .scheduler .step_index ]
94- use_cfg = t > self .cfg_interval [0 ] and t < self .cfg_interval [1 ]
95-
96- if self .config .get ("cfg_parallel" , False ):
97- cfg_p_group = self .config ["device_mesh" ].get_group (mesh_dim = "cfg_p" )
98- # assert dist.get_world_size(cfg_p_group) == 3, "cfg_p_world_size must be equal to 3 for i2i"
99- cfg_p_rank = dist .get_rank (cfg_p_group )
100-
101- if use_cfg :
102- if cfg_p_rank == 0 :
103- v_pred = self ._infer_pass (inputs , pre_infer_out , "cond" )
104- elif cfg_p_rank == 1 :
105- if self .cfg_scale > 1 :
106- v_pred = self ._infer_pass (inputs , pre_infer_out , "text_uncond" )
107- else :
108- v_pred = torch .zeros_like (pre_infer_out .z )
109- elif cfg_p_rank == 2 :
110- if self .img_cfg_scale > 1 :
111- v_pred = self ._infer_pass (inputs , pre_infer_out , "img_uncond" )
112- else :
113- v_pred = torch .zeros_like (pre_infer_out .z )
114- v_pred_list = [torch .zeros_like (v_pred ) for _ in range (3 )]
115- dist .all_gather (v_pred_list , v_pred , group = cfg_p_group )
116- v_pred_condition = v_pred_list [0 ]
117- v_pred_text_uncond = v_pred_list [1 ] if self .cfg_scale > 1 else 0
118- v_pred_img_uncond = v_pred_list [2 ] if self .img_cfg_scale > 1 else 0
119- v_pred_text = v_pred_text_uncond + self .cfg_scale * (v_pred_condition - v_pred_text_uncond )
120- return v_pred_img_uncond + self .img_cfg_scale * (v_pred_text - v_pred_img_uncond )
121- else :
122- return self ._infer_pass (inputs , pre_infer_out , "cond" )
123- else :
124- v_pred_condition = self ._infer_pass (inputs , pre_infer_out , "cond" )
125- if use_cfg :
126- v_pred_text_uncond = self ._infer_pass (inputs , pre_infer_out , "text_uncond" ) if self .cfg_scale > 1 else 0
127- v_pred_img_uncond = self ._infer_pass (inputs , pre_infer_out , "img_uncond" ) if self .img_cfg_scale > 1 else 0
128- v_pred_text = v_pred_text_uncond + self .cfg_scale * (v_pred_condition - v_pred_text_uncond )
129- return v_pred_img_uncond + self .img_cfg_scale * (v_pred_text - v_pred_img_uncond )
130- return v_pred_condition
120+ # def _infer_i2i(self, inputs, pre_infer_out):
121+ # t = self.scheduler.timesteps[self.scheduler.step_index]
122+ # use_cfg = t > self.cfg_interval[0] and t < self.cfg_interval[1]
123+
124+ # if self.config.get("cfg_parallel", False):
125+ # cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
126+ # # assert dist.get_world_size(cfg_p_group) == 3, "cfg_p_world_size must be equal to 3 for i2i"
127+ # cfg_p_rank = dist.get_rank(cfg_p_group)
128+
129+ # if use_cfg:
130+ # if cfg_p_rank == 0:
131+ # v_pred = self._infer_pass(inputs, pre_infer_out, "cond")
132+ # elif cfg_p_rank == 1:
133+ # if self.cfg_scale > 1:
134+ # v_pred = self._infer_pass(inputs, pre_infer_out, "text_uncond")
135+ # else:
136+ # v_pred = torch.zeros_like(pre_infer_out.z)
137+ # elif cfg_p_rank == 2:
138+ # if self.img_cfg_scale > 1:
139+ # v_pred = self._infer_pass(inputs, pre_infer_out, "img_uncond")
140+ # else:
141+ # v_pred = torch.zeros_like(pre_infer_out.z)
142+ # v_pred_list = [torch.zeros_like(v_pred) for _ in range(3)]
143+ # dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
144+ # v_pred_condition = v_pred_list[0]
145+ # v_pred_text_uncond = v_pred_list[1] if self.cfg_scale > 1 else 0
146+ # v_pred_img_uncond = v_pred_list[2] if self.img_cfg_scale > 1 else 0
147+ # v_pred_text = v_pred_text_uncond + self.cfg_scale * (v_pred_condition - v_pred_text_uncond)
148+ # return v_pred_img_uncond + self.img_cfg_scale * (v_pred_text - v_pred_img_uncond)
149+ # else:
150+ # return self._infer_pass(inputs, pre_infer_out, "cond")
151+ # else:
152+ # v_pred_condition = self._infer_pass(inputs, pre_infer_out, "cond")
153+ # if use_cfg:
154+ # v_pred_text_uncond = self._infer_pass(inputs, pre_infer_out, "text_uncond") if self.cfg_scale > 1 else 0
155+ # v_pred_img_uncond = self._infer_pass(inputs, pre_infer_out, "img_uncond") if self.img_cfg_scale > 1 else 0
156+ # v_pred_text = v_pred_text_uncond + self.cfg_scale * (v_pred_condition - v_pred_text_uncond)
157+ # return v_pred_img_uncond + self.img_cfg_scale * (v_pred_text - v_pred_img_uncond)
158+ # return v_pred_condition
131159
132160 def _infer_pass (self , inputs , pre_infer_out , pass_name ):
133161 """Run one forward pass. pass_name: 'cond' | 'uncond' | 'text_uncond' | 'img_uncond'"""
0 commit comments