Skip to content

Commit 51321f1

Browse files
committed
Use concrete Mistral3Model / Ministral3ForCausalLM types (guarded)
1 parent 84006ed commit 51321f1

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

src/diffusers/modular_pipelines/ernie_image/encoders.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import json
1616

1717
import torch
18-
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
18+
from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model
1919

2020
from ...configuration_utils import FrozenDict
2121
from ...guiders import ClassifierFreeGuidance
@@ -38,7 +38,7 @@ def description(self) -> str:
3838
@property
3939
def expected_components(self) -> list[ComponentSpec]:
4040
return [
41-
ComponentSpec("pe", AutoModelForCausalLM),
41+
ComponentSpec("pe", Ministral3ForCausalLM),
4242
ComponentSpec("pe_tokenizer", AutoTokenizer),
4343
]
4444

@@ -83,7 +83,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
8383

8484
@staticmethod
8585
def _enhance_prompt(
86-
pe: AutoModelForCausalLM,
86+
pe: Ministral3ForCausalLM,
8787
pe_tokenizer: AutoTokenizer,
8888
prompt: str,
8989
device: torch.device,
@@ -160,7 +160,7 @@ def description(self) -> str:
160160
@property
161161
def expected_components(self) -> list[ComponentSpec]:
162162
return [
163-
ComponentSpec("text_encoder", AutoModel),
163+
ComponentSpec("text_encoder", Mistral3Model),
164164
ComponentSpec("tokenizer", AutoTokenizer),
165165
ComponentSpec(
166166
"guider",
@@ -200,7 +200,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
200200

201201
@staticmethod
202202
def _encode(
203-
text_encoder: AutoModel,
203+
text_encoder: Mistral3Model,
204204
tokenizer: AutoTokenizer,
205205
prompt: list[str],
206206
device: torch.device,

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Callable, List, Optional, Union
2121

2222
import torch
23-
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
23+
from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model
2424

2525
from ...image_processor import VaeImageProcessor
2626
from ...loaders import ErnieImageLoraLoaderMixin
@@ -52,10 +52,10 @@ def __init__(
5252
self,
5353
transformer: ErnieImageTransformer2DModel,
5454
vae: AutoencoderKLFlux2,
55-
text_encoder: AutoModel,
55+
text_encoder: Mistral3Model,
5656
tokenizer: AutoTokenizer,
5757
scheduler: FlowMatchEulerDiscreteScheduler,
58-
pe: Optional[AutoModelForCausalLM] = None,
58+
pe: Optional[Ministral3ForCausalLM] = None,
5959
pe_tokenizer: Optional[AutoTokenizer] = None,
6060
):
6161
super().__init__()

0 commit comments

Comments
 (0)