Skip to content

Commit eeb3445

Browse files
authored
Merge branch 'main' into requirements-custom-blocks
2 parents 5b7d0df + 051c8a1 commit eeb3445

10 files changed

Lines changed: 15 additions & 13 deletions

src/diffusers/models/autoencoders/vae.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,16 @@ def forward(
286286

287287
sample = self.conv_in(sample)
288288

289-
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
290289
if torch.is_grad_enabled() and self.gradient_checkpointing:
291290
# middle
292291
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
293-
sample = sample.to(upscale_dtype)
294292

295293
# up
296294
for up_block in self.up_blocks:
297295
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
298296
else:
299297
# middle
300298
sample = self.mid_block(sample, latent_embeds)
301-
sample = sample.to(upscale_dtype)
302299

303300
# up
304301
for up_block in self.up_blocks:

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def from_pretrained(
320320
"cache_dir",
321321
"force_download",
322322
"local_files_only",
323+
"local_dir",
323324
"proxies",
324325
"resume_download",
325326
"revision",
@@ -336,11 +337,10 @@ def from_pretrained(
336337
module_file=module_file,
337338
class_name=class_name,
338339
**hub_kwargs,
339-
**kwargs,
340340
)
341341
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
342342
block_kwargs = {
343-
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
343+
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
344344
}
345345

346346
return block_cls(**block_kwargs)

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _get_clip_prompt_embeds(
355355
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
356356
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
357357

358-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
358+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
359359
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
360360

361361
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def _get_clip_prompt_embeds(
373373
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
374374
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
375375

376-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
376+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
377377
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
378378

379379
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/pag/pipeline_pag_sd_3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def _get_clip_prompt_embeds(
326326
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327327
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
328328

329-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
329+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
330330
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
331331

332332
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _get_clip_prompt_embeds(
342342
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
343343
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
344344

345-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
345+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
346346
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
347347

348348
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def _get_clip_prompt_embeds(
336336
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337337
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
338338

339-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
340340
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
341341

342342
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def _get_clip_prompt_embeds(
361361
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
362362
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
363363

364-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
364+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
365365
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
366366

367367
return prompt_embeds, pooled_prompt_embeds

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def _get_clip_prompt_embeds(
367367
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
368368
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
369369

370-
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
370+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
371371
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
372372

373373
return prompt_embeds, pooled_prompt_embeds

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def get_cached_module_file(
254254
token: Optional[Union[bool, str]] = None,
255255
revision: Optional[str] = None,
256256
local_files_only: bool = False,
257+
local_dir: Optional[str] = None,
257258
):
258259
"""
259260
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -332,6 +333,7 @@ def get_cached_module_file(
332333
force_download=force_download,
333334
proxies=proxies,
334335
local_files_only=local_files_only,
336+
local_dir=local_dir,
335337
)
336338
submodule = "git"
337339
module_file = pretrained_model_name_or_path + ".py"
@@ -355,6 +357,7 @@ def get_cached_module_file(
355357
force_download=force_download,
356358
proxies=proxies,
357359
local_files_only=local_files_only,
360+
local_dir=local_dir,
358361
token=token,
359362
)
360363
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -415,6 +418,7 @@ def get_cached_module_file(
415418
token=token,
416419
revision=revision,
417420
local_files_only=local_files_only,
421+
local_dir=local_dir,
418422
)
419423
return os.path.join(full_submodule, module_file)
420424

@@ -431,7 +435,7 @@ def get_class_from_dynamic_module(
431435
token: Optional[Union[bool, str]] = None,
432436
revision: Optional[str] = None,
433437
local_files_only: bool = False,
434-
**kwargs,
438+
local_dir: Optional[str] = None,
435439
):
436440
"""
437441
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -496,5 +500,6 @@ def get_class_from_dynamic_module(
496500
token=token,
497501
revision=revision,
498502
local_files_only=local_files_only,
503+
local_dir=local_dir,
499504
)
500505
return get_class_in_module(class_name, final_module)

0 commit comments

Comments
 (0)