We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c1f00a2 commit d6365ecCopy full SHA for d6365ec
1 file changed
src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py
@@ -191,6 +191,7 @@ def __init__(
191
self.text_token_max_length = int(getattr(self.args, "text_token_max_length", 2048))
192
self.prompt_template_encode = PROMPT_TEMPLATE_ENCODE
193
self.prompt_template_encode_start_idx = PROMPT_TEMPLATE_START_IDX
194
+ self._joyai_force_vae_fp32 = True
195
196
@staticmethod
197
def _dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]:
@@ -449,6 +450,8 @@ def check_inputs(
449
450
)
451
452
def _vae_compute_dtype(self) -> torch.dtype:
453
+ if getattr(self, "_joyai_force_vae_fp32", False):
454
+ return torch.float32
455
if hasattr(self.vae, "model"):
456
return next(self.vae.model.parameters()).dtype
457
return next(self.vae.parameters()).dtype
0 commit comments