@@ -115,14 +115,18 @@ def tokenize(prompt, pipeline):
115115 return inputs
116116
117117
118- def get_unet_inputs (pipeline , params , states , config , rng , mesh , batch_size ):
118+ def get_unet_inputs (pipeline , scheduler_params , states , config , rng , mesh , batch_size ):
119119 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
120120
121121 vae_scale_factor = 2 ** (len (pipeline .vae .config .block_out_channels ) - 1 )
122122 prompt_ids = [config .prompt ] * batch_size
123123 prompt_ids = tokenize (prompt_ids , pipeline )
124+ prompt_ids = jax .lax .with_sharding_constraint (prompt_ids , jax .sharding .NamedSharding (mesh , P ("data" , None , None )))
124125 negative_prompt_ids = [config .negative_prompt ] * batch_size
125126 negative_prompt_ids = tokenize (negative_prompt_ids , pipeline )
127+ negative_prompt_ids = jax .lax .with_sharding_constraint (
128+ negative_prompt_ids , jax .sharding .NamedSharding (mesh , P ("data" , None , None ))
129+ )
126130 guidance_scale = config .guidance_scale
127131 guidance_rescale = config .guidance_rescale
128132 num_inference_steps = config .num_inference_steps
@@ -133,6 +137,8 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
133137 "text_encoder_2" : states ["text_encoder_2_state" ].params ,
134138 }
135139 prompt_embeds , pooled_embeds = get_embeddings (prompt_ids , pipeline , text_encoder_params )
140+ prompt_embeds = jax .lax .with_sharding_constraint (prompt_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None , None )))
141+ pooled_embeds = jax .lax .with_sharding_constraint (pooled_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None )))
136142
137143 batch_size = prompt_embeds .shape [0 ]
138144 add_time_ids = get_add_time_ids (
@@ -148,6 +154,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
148154
149155 prompt_embeds = jnp .concatenate ([negative_prompt_embeds , prompt_embeds ], axis = 0 )
150156 add_text_embeds = jnp .concatenate ([negative_pooled_embeds , pooled_embeds ], axis = 0 )
157+ prompt_embeds = jax .lax .with_sharding_constraint (prompt_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None , None )))
158+ add_text_embeds = jax .lax .with_sharding_constraint (add_text_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None )))
159+
151160 add_time_ids = jnp .concatenate ([add_time_ids , add_time_ids ], axis = 0 )
152161
153162 else :
@@ -167,7 +176,7 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
167176 latents = jax .random .normal (rng , shape = latents_shape , dtype = jnp .float32 )
168177
169178 scheduler_state = pipeline .scheduler .set_timesteps (
170- params [ "scheduler" ] , num_inference_steps = num_inference_steps , shape = latents .shape
179+ scheduler_params , num_inference_steps = num_inference_steps , shape = latents .shape
171180 )
172181
173182 latents = latents * scheduler_state .init_noise_sigma
@@ -188,12 +197,12 @@ def vae_decode(latents, state, pipeline):
188197 return image
189198
190199
191- def run_inference (states , pipeline , params , config , rng , mesh , batch_size ):
200+ def run_inference (states , pipeline , scheduler_params , config , rng , mesh , batch_size ):
192201 unet_state = states ["unet_state" ]
193202 vae_state = states ["vae_state" ]
194203
195204 (latents , prompt_embeds , added_cond_kwargs , guidance_scale , guidance_rescale , scheduler_state ) = get_unet_inputs (
196- pipeline , params , states , config , rng , mesh , batch_size
205+ pipeline , scheduler_params , states , config , rng , mesh , batch_size
197206 )
198207
199208 loop_body_p = functools .partial (
@@ -217,9 +226,9 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
217226def run (config ):
218227 checkpoint_loader = GenerateSDXL (config )
219228 mesh = checkpoint_loader .mesh
220- with mesh :
221- pipeline , params = checkpoint_loader .load_checkpoint ()
229+ pipeline , params = checkpoint_loader .load_checkpoint ()
222230
231+ with mesh :
223232 noise_scheduler , noise_scheduler_state = create_scheduler (pipeline .scheduler .config , config )
224233
225234 weights_init_fn = functools .partial (pipeline .unet .init_weights , rng = checkpoint_loader .rng )
@@ -288,7 +297,7 @@ def run(config):
288297 functools .partial (
289298 run_inference ,
290299 pipeline = pipeline ,
291- params = params ,
300+ scheduler_params = params [ "scheduler" ] ,
292301 config = config ,
293302 rng = checkpoint_loader .rng ,
294303 mesh = checkpoint_loader .mesh ,
0 commit comments