@@ -889,7 +889,6 @@ def __call__(
889889 ]
890890 ] = None ,
891891 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
892- enable_tiling : bool = False ,
893892 max_sequence_length : int = 4096 ,
894893 drop_vit_feature : bool = False ,
895894 enable_denormalization : bool = True ,
@@ -1057,11 +1056,6 @@ def __call__(
10571056 enable_denormalization = enable_denormalization ,
10581057 )
10591058
1060- target_dtype = PRECISION_TO_TYPE [self .args .dit_precision ]
1061- autocast_enabled = target_dtype != torch .float32
1062- vae_dtype = PRECISION_TO_TYPE [self .args .vae_precision ]
1063- vae_autocast_enabled = vae_dtype != torch .float32
1064-
10651059 num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
10661060 self ._num_timesteps = len (timesteps )
10671061
@@ -1081,25 +1075,23 @@ def __call__(
10811075 latent_model_input = latents
10821076 t_expand = t .repeat (latent_model_input .shape [0 ])
10831077
1084- with torch .autocast (device_type = "cuda" , dtype = target_dtype , enabled = autocast_enabled ):
1085- noise_pred = self .transformer (
1078+ noise_pred = self .transformer (
1079+ hidden_states = latent_model_input ,
1080+ timestep = t_expand ,
1081+ encoder_hidden_states = prompt_embeds ,
1082+ encoder_hidden_states_mask = prompt_embeds_mask ,
1083+ return_dict = False ,
1084+ )[0 ]
1085+
1086+ if self .do_classifier_free_guidance :
1087+ noise_pred_uncond = self .transformer (
10861088 hidden_states = latent_model_input ,
10871089 timestep = t_expand ,
1088- encoder_hidden_states = prompt_embeds ,
1089- encoder_hidden_states_mask = prompt_embeds_mask ,
1090+ encoder_hidden_states = negative_prompt_embeds ,
1091+ encoder_hidden_states_mask = negative_prompt_embeds_mask ,
10901092 return_dict = False ,
10911093 )[0 ]
10921094
1093- if self .do_classifier_free_guidance :
1094- with torch .autocast (device_type = "cuda" , dtype = target_dtype , enabled = autocast_enabled ):
1095- noise_pred_uncond = self .transformer (
1096- hidden_states = latent_model_input ,
1097- timestep = t_expand ,
1098- encoder_hidden_states = negative_prompt_embeds ,
1099- encoder_hidden_states_mask = negative_prompt_embeds_mask ,
1100- return_dict = False ,
1101- )[0 ]
1102-
11031095 comb_pred = noise_pred_uncond + self .guidance_scale * (noise_pred - noise_pred_uncond )
11041096 # Rescale to match the conditional prediction norm (guidance rescaling).
11051097 cond_norm = torch .norm (noise_pred , dim = 2 , keepdim = True )
@@ -1128,11 +1120,8 @@ def __call__(
11281120 if enable_denormalization :
11291121 latents = self .denormalize_latents (latents )
11301122
1131- with torch .autocast (device_type = "cuda" , dtype = vae_dtype , enabled = vae_autocast_enabled ):
1132- if enable_tiling :
1133- self .vae .enable_tiling ()
1134- image = self .vae .decode (latents , return_dict = False )[0 ]
1135- image = image .unflatten (0 , (batch_size , - 1 ))
1123+ image = self .vae .decode (latents , return_dict = False )[0 ]
1124+ image = image .unflatten (0 , (batch_size , - 1 ))
11361125 else :
11371126 image = latents
11381127
0 commit comments