@@ -115,14 +115,18 @@ def tokenize(prompt, pipeline):
115115 return inputs
116116
117117
118- def get_unet_inputs (pipeline , params , states , config , rng , mesh , batch_size ):
118+ def get_unet_inputs (pipeline , scheduler_params , states , config , rng , mesh , batch_size ):
119119 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
120120
121121 vae_scale_factor = 2 ** (len (pipeline .vae .config .block_out_channels ) - 1 )
122122 prompt_ids = [config .prompt ] * batch_size
123123 prompt_ids = tokenize (prompt_ids , pipeline )
124+ prompt_ids = jax .lax .with_sharding_constraint (prompt_ids , jax .sharding .NamedSharding (mesh , P ("data" , None , None )))
124125 negative_prompt_ids = [config .negative_prompt ] * batch_size
125126 negative_prompt_ids = tokenize (negative_prompt_ids , pipeline )
127+ negative_prompt_ids = jax .lax .with_sharding_constraint (
128+ negative_prompt_ids , jax .sharding .NamedSharding (mesh , P ("data" , None , None ))
129+ )
126130 guidance_scale = config .guidance_scale
127131 guidance_rescale = config .guidance_rescale
128132 num_inference_steps = config .num_inference_steps
@@ -133,6 +137,8 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
133137 "text_encoder_2" : states ["text_encoder_2_state" ].params ,
134138 }
135139 prompt_embeds , pooled_embeds = get_embeddings (prompt_ids , pipeline , text_encoder_params )
140+ prompt_embeds = jax .lax .with_sharding_constraint (prompt_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None , None )))
141+ pooled_embeds = jax .lax .with_sharding_constraint (pooled_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None )))
136142
137143 batch_size = prompt_embeds .shape [0 ]
138144 add_time_ids = get_add_time_ids (
@@ -148,6 +154,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
148154
149155 prompt_embeds = jnp .concatenate ([negative_prompt_embeds , prompt_embeds ], axis = 0 )
150156 add_text_embeds = jnp .concatenate ([negative_pooled_embeds , pooled_embeds ], axis = 0 )
157+ prompt_embeds = jax .lax .with_sharding_constraint (prompt_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None , None )))
158+ add_text_embeds = jax .lax .with_sharding_constraint (add_text_embeds , jax .sharding .NamedSharding (mesh , P ("data" , None )))
159+
151160 add_time_ids = jnp .concatenate ([add_time_ids , add_time_ids ], axis = 0 )
152161
153162 else :
@@ -166,8 +175,11 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
166175
167176 latents = jax .random .normal (rng , shape = latents_shape , dtype = jnp .float32 )
168177
178+ if isinstance (scheduler_params , dict ) and "scheduler" in scheduler_params :
179+ scheduler_params = scheduler_params ["scheduler" ]
180+
169181 scheduler_state = pipeline .scheduler .set_timesteps (
170- params [ "scheduler" ] , num_inference_steps = num_inference_steps , shape = latents .shape
182+ scheduler_params , num_inference_steps = num_inference_steps , shape = latents .shape
171183 )
172184
173185 latents = latents * scheduler_state .init_noise_sigma
@@ -217,9 +229,11 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
217229def run (config ):
218230 checkpoint_loader = GenerateSDXL (config )
219231 mesh = checkpoint_loader .mesh
220- with mesh :
221- pipeline , params = checkpoint_loader .load_checkpoint ()
232+ # NOTE: load_checkpoint() is called outside the mesh context intentionally.
233+ # If checkpoint loading requires mesh-aware sharding, move this back inside `with mesh:`.
234+ pipeline , params = checkpoint_loader .load_checkpoint ()
222235
236+ with mesh :
223237 noise_scheduler , noise_scheduler_state = create_scheduler (pipeline .scheduler .config , config )
224238
225239 weights_init_fn = functools .partial (pipeline .unet .init_weights , rng = checkpoint_loader .rng )
@@ -303,11 +317,13 @@ def run(config):
303317 _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
304318 p_run_inference (states ).block_until_ready ()
305319 print ("compile time: " , (time .time () - s ))
320+
306321 s = time .time ()
307322 with ExitStack () as stack :
308323 _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
309324 images = p_run_inference (states ).block_until_ready ()
310325 print ("inference time: " , (time .time () - s ))
326+
311327 images = jax .experimental .multihost_utils .process_allgather (images , tiled = True )
312328 numpy_images = np .array (images )
313329 images = VaeImageProcessor .numpy_to_pil (numpy_images )
0 commit comments