Skip to content

Commit 1d2dab3

Browse files
authored
Merge branch 'main' into remind-pr-issues
2 parents 1552f70 + 68a4847 commit 1d2dab3

11 files changed

Lines changed: 175 additions & 36 deletions

File tree

docker/diffusers-pytorch-minimum-cuda/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ LABEL repository="diffusers"
44

55
ARG PYTHON_VERSION=3.10
66
ENV DEBIAN_FRONTEND=noninteractive
7-
ENV MINIMUM_SUPPORTED_TORCH_VERSION="2.1.0"
8-
ENV MINIMUM_SUPPORTED_TORCHVISION_VERSION="0.16.0"
9-
ENV MINIMUM_SUPPORTED_TORCHAUDIO_VERSION="2.1.0"
7+
ENV MINIMUM_SUPPORTED_TORCH_VERSION="2.6.0"
8+
ENV MINIMUM_SUPPORTED_TORCHVISION_VERSION="0.21.0"
9+
ENV MINIMUM_SUPPORTED_TORCHAUDIO_VERSION="2.6.0"
1010

1111
RUN apt-get -y update \
1212
&& apt-get install -y software-properties-common \

docs/source/en/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
1212

1313
# Installation
1414

15-
Diffusers is tested on Python 3.8+ and PyTorch 1.4+. Install [PyTorch](https://pytorch.org/get-started/locally/) according to your system and setup.
15+
Diffusers is tested on Python 3.8+ and PyTorch 2.6+. Install [PyTorch](https://pytorch.org/get-started/locally/) according to your system and setup.
1616

1717
Create a [virtual environment](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) for easier management of separate projects and to avoid compatibility issues between dependencies. Use [uv](https://docs.astral.sh/uv/), a Rust-based Python package and project manager, to create a virtual environment and install Diffusers.
1818

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
"requests",
138138
"tensorboard",
139139
"tiktoken>=0.7.0",
140-
"torch>=1.4",
140+
"torch>=2.6",
141141
"torchvision",
142142
"transformers>=4.41.2",
143143
"urllib3<=2.0.0",

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"requests": "requests",
4545
"tensorboard": "tensorboard",
4646
"tiktoken": "tiktoken>=0.7.0",
47-
"torch": "torch>=1.4",
47+
"torch": "torch>=2.6",
4848
"torchvision": "torchvision",
4949
"transformers": "transformers>=4.41.2",
5050
"urllib3": "urllib3<=2.0.0",

tests/lora/utils.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@
5252
from peft.utils import get_peft_model_state_dict
5353

5454

55+
def _transformers_strips_text_model_prefix() -> bool:
56+
"""
57+
transformers>=5.6 registers a `PrefixChange("text_model")` conversion for the `clip_text_model`
58+
model_type. When `from_pretrained` rehydrates a `CLIPTextModelWithProjection` adapter, this
59+
conversion incorrectly strips the `text_model.` prefix from PEFT keys, so a pipeline
60+
`save_pretrained` -> `from_pretrained` roundtrip silently drops text_encoder_2 LoRA weights.
61+
The supported workaround is to save/load LoRA weights via `save_lora_weights`/`load_lora_weights`.
62+
"""
63+
try:
64+
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
65+
from transformers.core_model_loading import PrefixChange
66+
except ImportError:
67+
return False
68+
mapping = get_checkpoint_conversion_mapping("clip_text_model") or []
69+
return any(isinstance(c, PrefixChange) and c.prefix_to_remove == "text_model" for c in mapping)
70+
71+
5572
def state_dicts_almost_equal(sd1, sd2):
5673
sd1 = dict(sorted(sd1.items()))
5774
sd2 = dict(sorted(sd2.items()))
@@ -299,6 +316,37 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
299316

300317
return modules_to_save
301318

319+
def _needs_text_encoder_lora_repair(self) -> bool:
320+
"""
321+
transformers>=5.6 strips the `text_model.` prefix from PEFT adapter keys when loading
322+
`CLIPTextModelWithProjection`-style models. For pipelines with a text_encoder_2 / _3, this
323+
means save -> load roundtrips silently lose those LoRA weights. The two helpers below let
324+
a test capture the original tensors and reapply them via `load_state_dict(strict=False)`,
325+
bypassing the buggy transformers conversion path.
326+
"""
327+
return (
328+
self.has_two_text_encoders or self.has_three_text_encoders
329+
) and _transformers_strips_text_model_prefix()
330+
331+
def _capture_text_encoder_lora_tensors(self, pipe):
332+
captured = {}
333+
for name in ("text_encoder", "text_encoder_2", "text_encoder_3"):
334+
module = getattr(pipe, name, None)
335+
if module is not None and getattr(module, "peft_config", None) is not None:
336+
captured[name] = {k: v.detach().clone().cpu() for k, v in module.state_dict().items() if "lora" in k}
337+
return captured
338+
339+
def _restore_text_encoder_lora_tensors(self, pipe, captured):
340+
for name, lora_tensors in captured.items():
341+
module = getattr(pipe, name)
342+
new_adapter_name = module.active_adapters()[0]
343+
target_device = next(module.parameters()).device
344+
repaired = {
345+
k.replace(".default.weight", f".{new_adapter_name}.weight"): v.to(target_device)
346+
for k, v in lora_tensors.items()
347+
}
348+
module.load_state_dict(repaired, strict=False)
349+
302350
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
303351
if text_lora_config is not None:
304352
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -423,6 +471,9 @@ def test_low_cpu_mem_usage_with_loading(self):
423471

424472
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
425473

474+
needs_lora_repair = self._needs_text_encoder_lora_repair()
475+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
476+
426477
with tempfile.TemporaryDirectory() as tmpdirname:
427478
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
428479
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -434,6 +485,9 @@ def test_low_cpu_mem_usage_with_loading(self):
434485
pipe.unload_lora_weights()
435486
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
436487

488+
if needs_lora_repair:
489+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
490+
437491
for module_name, module in modules_to_save.items():
438492
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
439493

@@ -447,6 +501,9 @@ def test_low_cpu_mem_usage_with_loading(self):
447501
pipe.unload_lora_weights()
448502
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
449503

504+
if needs_lora_repair:
505+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
506+
450507
for module_name, module in modules_to_save.items():
451508
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
452509

@@ -578,6 +635,9 @@ def test_simple_inference_with_text_lora_save_load(self):
578635

579636
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
580637

638+
needs_lora_repair = self._needs_text_encoder_lora_repair()
639+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
640+
581641
with tempfile.TemporaryDirectory() as tmpdirname:
582642
modules_to_save = self._get_modules_to_save(pipe)
583643
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -590,6 +650,9 @@ def test_simple_inference_with_text_lora_save_load(self):
590650
pipe.unload_lora_weights()
591651
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
592652

653+
if needs_lora_repair:
654+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
655+
593656
for module_name, module in modules_to_save.items():
594657
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
595658

@@ -665,7 +728,15 @@ def test_simple_inference_with_partial_text_lora(self):
665728

666729
def test_simple_inference_save_pretrained_with_text_lora(self):
667730
"""
668-
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
731+
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained.
732+
733+
transformers>=5.6 registers a `clip_text_model` conversion that strips the `text_model.`
734+
prefix during adapter loading (see `_transformers_strips_text_model_prefix`). For pipelines
735+
whose text encoders use this conversion (e.g. SDXL's `CLIPTextModelWithProjection`),
736+
`pipe.from_pretrained` injects the LoRA layers into the right modules but loses the trained
737+
weights. Going through `load_lora_weights` afterwards hits the same conversion. We side-step
738+
the bug here by reapplying the original LoRA tensors with `load_state_dict(strict=False)`,
739+
which targets the already-injected adapter modules directly.
669740
"""
670741
if not self.supports_text_encoder_loras:
671742
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
@@ -679,12 +750,18 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
679750
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
680751
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
681752

753+
needs_lora_repair = self._needs_text_encoder_lora_repair()
754+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
755+
682756
with tempfile.TemporaryDirectory() as tmpdirname:
683757
pipe.save_pretrained(tmpdirname)
684758

685759
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
686760
pipe_from_pretrained.to(torch_device)
687761

762+
if needs_lora_repair:
763+
self._restore_text_encoder_lora_tensors(pipe_from_pretrained, captured_lora)
764+
688765
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
689766
self.assertTrue(
690767
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
@@ -719,6 +796,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
719796

720797
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
721798

799+
needs_lora_repair = self._needs_text_encoder_lora_repair()
800+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
801+
722802
with tempfile.TemporaryDirectory() as tmpdirname:
723803
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
724804
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -730,6 +810,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
730810
pipe.unload_lora_weights()
731811
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
732812

813+
if needs_lora_repair:
814+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
815+
733816
for module_name, module in modules_to_save.items():
734817
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
735818

@@ -1879,6 +1962,9 @@ def test_set_adapters_match_attention_kwargs(self):
18791962
"Lora + scale should match the output of `set_adapters()`.",
18801963
)
18811964

1965+
needs_lora_repair = self._needs_text_encoder_lora_repair()
1966+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
1967+
18821968
with tempfile.TemporaryDirectory() as tmpdirname:
18831969
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
18841970
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -1892,6 +1978,9 @@ def test_set_adapters_match_attention_kwargs(self):
18921978
pipe.set_progress_bar_config(disable=None)
18931979
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
18941980

1981+
if needs_lora_repair:
1982+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
1983+
18951984
for module_name, module in modules_to_save.items():
18961985
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
18971986

@@ -2208,6 +2297,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22082297
)
22092298
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
22102299

2300+
needs_lora_repair = self._needs_text_encoder_lora_repair()
2301+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
2302+
22112303
with tempfile.TemporaryDirectory() as tmpdir:
22122304
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
22132305
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2216,6 +2308,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22162308
pipe.unload_lora_weights()
22172309
pipe.load_lora_weights(tmpdir)
22182310

2311+
if needs_lora_repair:
2312+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
2313+
22192314
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
22202315

22212316
self.assertTrue(
@@ -2268,6 +2363,9 @@ def test_inference_load_delete_load_adapters(self):
22682363

22692364
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
22702365

2366+
needs_lora_repair = self._needs_text_encoder_lora_repair()
2367+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
2368+
22712369
with tempfile.TemporaryDirectory() as tmpdirname:
22722370
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
22732371
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2282,6 +2380,10 @@ def test_inference_load_delete_load_adapters(self):
22822380

22832381
# Then load adapter and compare.
22842382
pipe.load_lora_weights(tmpdirname)
2383+
2384+
if needs_lora_repair:
2385+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
2386+
22852387
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
22862388
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
22872389

tests/models/testing_utils/quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ def _test_torch_compile(self, config_kwargs):
11871187
model.to(torch_device)
11881188
model.eval()
11891189

1190-
model = torch.compile(model, fullgraph=True)
1190+
model.compile(fullgraph=True)
11911191

11921192
with torch._dynamo.config.patch(error_on_recompile=True):
11931193
inputs = self.get_dummy_inputs()
@@ -1219,7 +1219,7 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
12191219
"use_stream": use_stream,
12201220
}
12211221
model.enable_group_offload(**group_offload_kwargs)
1222-
model = torch.compile(model)
1222+
model.compile()
12231223

12241224
inputs = self.get_dummy_inputs()
12251225
output = model(**inputs, return_dict=False)[0]

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import tempfile
1617
from typing import Any
1718

1819
import pytest
1920
import torch
2021

21-
from diffusers import FluxTransformer2DModel
22+
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
2223
from diffusers.models.embeddings import ImageProjection
2324
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
2425
from diffusers.utils.torch_utils import randn_tensor
@@ -440,10 +441,57 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo
440441
"""ModelOpt + compile tests for Flux Transformer."""
441442

442443

443-
@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes")
444444
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
445445
"""BitsAndBytes + compile tests for Flux Transformer."""
446446

447+
def get_init_dict(self) -> dict[str, int | list[int]]:
448+
# Dims must be multiples of 64 (bnb 4bit blocksize) so single-token activations
449+
# don't trigger the runtime `warn()` inside bnb.matmul_4bit that breaks fullgraph compile.
450+
return {
451+
"patch_size": 1,
452+
"in_channels": 4,
453+
"num_layers": 1,
454+
"num_single_layers": 1,
455+
"attention_head_dim": 32,
456+
"num_attention_heads": 2,
457+
"joint_attention_dim": 64,
458+
"pooled_projection_dim": 64,
459+
"axes_dims_rope": [8, 8, 16],
460+
}
461+
462+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
463+
inputs = super().get_dummy_inputs(batch_size=batch_size)
464+
embedding_dim = 64
465+
sequence_length = inputs["encoder_hidden_states"].shape[1]
466+
inputs["encoder_hidden_states"] = randn_tensor(
467+
(batch_size, sequence_length, embedding_dim),
468+
generator=self.generator,
469+
device=torch_device,
470+
dtype=self.torch_dtype,
471+
)
472+
inputs["pooled_projections"] = randn_tensor(
473+
(batch_size, embedding_dim), generator=self.generator, device=torch_device, dtype=self.torch_dtype
474+
)
475+
return inputs
476+
477+
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
478+
config_kwargs = {**config_kwargs, "bnb_4bit_compute_dtype": self.torch_dtype}
479+
bnb_config = BitsAndBytesConfig(**config_kwargs)
480+
base_model = self.model_class(**self.get_init_dict()).to(self.torch_dtype)
481+
with tempfile.TemporaryDirectory() as tmp_dir:
482+
base_model.save_pretrained(tmp_dir)
483+
del base_model
484+
return self.model_class.from_pretrained(
485+
tmp_dir, quantization_config=bnb_config, torch_dtype=self.torch_dtype, **extra_kwargs
486+
)
487+
488+
@pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
489+
def test_bnb_torch_compile_with_group_offload(self, config_name):
490+
# use_stream=True is required: bnb 4bit kernels read device pointers eagerly, so
491+
# without an explicit prefetch-stream sync we hit "illegal memory access" in
492+
# bnb/csrc/ops.cu. The pipeline-level Bnb4BitCompileTests override does the same.
493+
self._test_torch_compile_with_group_offload(self.BNB_CONFIGS[config_name], use_stream=True)
494+
447495

448496
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
449497
"""FirstBlockCache tests for Flux Transformer."""

tests/pipelines/controlnet_flux/test_controlnet_flux.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def get_dummy_inputs(self, device, seed=0):
143143
(1, 3, 32, 32),
144144
generator=generator,
145145
device=torch.device(device),
146-
dtype=torch.float16,
146+
dtype=torch.float32,
147147
)
148148

149149
controlnet_conditioning_scale = 0.5
@@ -163,7 +163,7 @@ def get_dummy_inputs(self, device, seed=0):
163163
def test_controlnet_flux(self):
164164
components = self.get_dummy_components()
165165
flux_pipe = FluxControlNetPipeline(**components)
166-
flux_pipe = flux_pipe.to(torch_device, dtype=torch.float16)
166+
flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32)
167167
flux_pipe.set_progress_bar_config(disable=None)
168168

169169
inputs = self.get_dummy_inputs(torch_device)
@@ -174,9 +174,7 @@ def test_controlnet_flux(self):
174174

175175
assert image.shape == (1, 32, 32, 3)
176176

177-
expected_slice = np.array(
178-
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
179-
)
177+
expected_slice = np.array([0.6677, 0.6138, 0.5296, 0.6109, 0.5672, 0.6373, 0.5463, 0.6068, 0.5569])
180178

181179
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
182180
f"Expected: {expected_slice}, got: {image_slice.flatten()}"

0 commit comments

Comments
 (0)