Skip to content

Commit aa7f34f

Browse files
committed
Use Mistral3Model/Ministral3ForCausalLM
1 parent 84006ed commit aa7f34f

2 files changed

Lines changed: 22 additions & 8 deletions

File tree

src/diffusers/modular_pipelines/ernie_image/encoders.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,23 @@
1515
import json
1616

1717
import torch
18-
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
18+
from transformers import AutoTokenizer, Mistral3Model
1919

2020
from ...configuration_utils import FrozenDict
2121
from ...guiders import ClassifierFreeGuidance
2222
from ...utils import logging
23+
from ...utils.import_utils import is_transformers_version
2324
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2425
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
2526
from .modular_pipeline import ErnieImageModularPipeline
2627

2728

29+
if is_transformers_version("<", "5.0.0"):
30+
raise ImportError("`ErnieImageModularPipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.")
31+
32+
from transformers import Ministral3ForCausalLM # noqa: E402
33+
34+
2835
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2936

3037

@@ -38,7 +45,7 @@ def description(self) -> str:
3845
@property
3946
def expected_components(self) -> list[ComponentSpec]:
4047
return [
41-
ComponentSpec("pe", AutoModelForCausalLM),
48+
ComponentSpec("pe", Ministral3ForCausalLM),
4249
ComponentSpec("pe_tokenizer", AutoTokenizer),
4350
]
4451

@@ -83,7 +90,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
8390

8491
@staticmethod
8592
def _enhance_prompt(
86-
pe: AutoModelForCausalLM,
93+
pe: Ministral3ForCausalLM,
8794
pe_tokenizer: AutoTokenizer,
8895
prompt: str,
8996
device: torch.device,
@@ -160,7 +167,7 @@ def description(self) -> str:
160167
@property
161168
def expected_components(self) -> list[ComponentSpec]:
162169
return [
163-
ComponentSpec("text_encoder", AutoModel),
170+
ComponentSpec("text_encoder", Mistral3Model),
164171
ComponentSpec("tokenizer", AutoTokenizer),
165172
ComponentSpec(
166173
"guider",
@@ -200,7 +207,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
200207

201208
@staticmethod
202209
def _encode(
203-
text_encoder: AutoModel,
210+
text_encoder: Mistral3Model,
204211
tokenizer: AutoTokenizer,
205212
prompt: list[str],
206213
device: torch.device,

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,25 @@
2020
from typing import Callable, List, Optional, Union
2121

2222
import torch
23-
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
23+
from transformers import AutoTokenizer, Mistral3Model
2424

2525
from ...image_processor import VaeImageProcessor
2626
from ...loaders import ErnieImageLoraLoaderMixin
2727
from ...models import AutoencoderKLFlux2
2828
from ...models.transformers import ErnieImageTransformer2DModel
2929
from ...pipelines.pipeline_utils import DiffusionPipeline
3030
from ...schedulers import FlowMatchEulerDiscreteScheduler
31+
from ...utils.import_utils import is_transformers_version
3132
from ...utils.torch_utils import randn_tensor
3233
from .pipeline_output import ErnieImagePipelineOutput
3334

3435

36+
if is_transformers_version("<", "5.0.0"):
37+
raise ImportError("`ErnieImagePipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.")
38+
39+
from transformers import Ministral3ForCausalLM # noqa: E402
40+
41+
3542
class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin):
3643
"""
3744
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.
@@ -52,10 +59,10 @@ def __init__(
5259
self,
5360
transformer: ErnieImageTransformer2DModel,
5461
vae: AutoencoderKLFlux2,
55-
text_encoder: AutoModel,
62+
text_encoder: Mistral3Model,
5663
tokenizer: AutoTokenizer,
5764
scheduler: FlowMatchEulerDiscreteScheduler,
58-
pe: Optional[AutoModelForCausalLM] = None,
65+
pe: Optional[Ministral3ForCausalLM] = None,
5966
pe_tokenizer: Optional[AutoTokenizer] = None,
6067
):
6168
super().__init__()

0 commit comments

Comments
 (0)