Skip to content

Commit 750ee11

Browse files
committed
Use concrete Mistral3Model / Ministral3ForCausalLM types (guarded)
1 parent 84006ed commit 750ee11

2 files changed

Lines changed: 21 additions & 6 deletions

File tree

src/diffusers/modular_pipelines/ernie_image/encoders.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,17 @@
1717
import torch
1818
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
1919

20+
21+
try:
22+
from transformers import Mistral3Model as _TextEncoderClass
23+
except ImportError:
24+
_TextEncoderClass = AutoModel
25+
26+
try:
27+
from transformers import Ministral3ForCausalLM as _PromptEnhancerClass
28+
except ImportError:
29+
_PromptEnhancerClass = AutoModelForCausalLM
30+
2031
from ...configuration_utils import FrozenDict
2132
from ...guiders import ClassifierFreeGuidance
2233
from ...utils import logging
@@ -38,7 +49,7 @@ def description(self) -> str:
3849
@property
3950
def expected_components(self) -> list[ComponentSpec]:
4051
return [
41-
ComponentSpec("pe", AutoModelForCausalLM),
52+
ComponentSpec("pe", _PromptEnhancerClass),
4253
ComponentSpec("pe_tokenizer", AutoTokenizer),
4354
]
4455

@@ -160,7 +171,7 @@ def description(self) -> str:
160171
@property
161172
def expected_components(self) -> list[ComponentSpec]:
162173
return [
163-
ComponentSpec("text_encoder", AutoModel),
174+
ComponentSpec("text_encoder", _TextEncoderClass),
164175
ComponentSpec("tokenizer", AutoTokenizer),
165176
ComponentSpec(
166177
"guider",

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
"""
1818

1919
import json
20-
from typing import Callable, List, Optional, Union
20+
from typing import TYPE_CHECKING, Callable, List, Optional, Union
2121

2222
import torch
23-
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
23+
from transformers import AutoTokenizer
2424

2525
from ...image_processor import VaeImageProcessor
2626
from ...loaders import ErnieImageLoraLoaderMixin
@@ -32,6 +32,10 @@
3232
from .pipeline_output import ErnieImagePipelineOutput
3333

3434

35+
if TYPE_CHECKING:
36+
from transformers import Ministral3ForCausalLM, Mistral3Model
37+
38+
3539
class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin):
3640
"""
3741
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.
@@ -52,10 +56,10 @@ def __init__(
5256
self,
5357
transformer: ErnieImageTransformer2DModel,
5458
vae: AutoencoderKLFlux2,
55-
text_encoder: AutoModel,
59+
text_encoder: "Mistral3Model",
5660
tokenizer: AutoTokenizer,
5761
scheduler: FlowMatchEulerDiscreteScheduler,
58-
pe: Optional[AutoModelForCausalLM] = None,
62+
pe: Optional["Ministral3ForCausalLM"] = None,
5963
pe_tokenizer: Optional[AutoTokenizer] = None,
6064
):
6165
super().__init__()

0 commit comments

Comments
 (0)