Skip to content

Commit 9d78e4e

Browse files
author
zhangmaoquan.1
committed
Remove VAE tiling and autocast
1 parent d397b68 commit 9d78e4e

1 file changed

Lines changed: 14 additions & 25 deletions

File tree

src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,6 @@ def __call__(
             ]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        enable_tiling: bool = False,
         max_sequence_length: int = 4096,
         drop_vit_feature: bool = False,
         enable_denormalization: bool = True,
@@ -1057,11 +1056,6 @@ def __call__(
            enable_denormalization=enable_denormalization,
        )

-        target_dtype = PRECISION_TO_TYPE[self.args.dit_precision]
-        autocast_enabled = target_dtype != torch.float32
-        vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
-        vae_autocast_enabled = vae_dtype != torch.float32
-
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

@@ -1081,25 +1075,23 @@ def __call__(
                latent_model_input = latents
                t_expand = t.repeat(latent_model_input.shape[0])

-                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
-                    noise_pred = self.transformer(
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=t_expand,
+                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states_mask=prompt_embeds_mask,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=t_expand,
-                        encoder_hidden_states=prompt_embeds,
-                        encoder_hidden_states_mask=prompt_embeds_mask,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        return_dict=False,
                    )[0]

-                if self.do_classifier_free_guidance:
-                    with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
-                        noise_pred_uncond = self.transformer(
-                            hidden_states=latent_model_input,
-                            timestep=t_expand,
-                            encoder_hidden_states=negative_prompt_embeds,
-                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                            return_dict=False,
-                        )[0]
-
                    comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
                    # Rescale to match the conditional prediction norm (guidance rescaling).
                    cond_norm = torch.norm(noise_pred, dim=2, keepdim=True)
@@ -1128,11 +1120,8 @@ def __call__(
            if enable_denormalization:
                latents = self.denormalize_latents(latents)

-            with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
-                if enable_tiling:
-                    self.vae.enable_tiling()
-                image = self.vae.decode(latents, return_dict=False)[0]
-                image = image.unflatten(0, (batch_size, -1))
+            image = self.vae.decode(latents, return_dict=False)[0]
+            image = image.unflatten(0, (batch_size, -1))
        else:
            image = latents

0 commit comments

Comments
 (0)