1515import json
1616
1717import torch
18- from transformers import AutoModel , AutoModelForCausalLM , AutoTokenizer
18+ from transformers import AutoTokenizer , Mistral3Model
1919
2020from ...configuration_utils import FrozenDict
2121from ...guiders import ClassifierFreeGuidance
2222from ...utils import logging
23+ from ...utils .import_utils import is_transformers_version
2324from ..modular_pipeline import ModularPipelineBlocks , PipelineState
2425from ..modular_pipeline_utils import ComponentSpec , InputParam , OutputParam
2526from .modular_pipeline import ErnieImageModularPipeline
2627
2728
29+ if is_transformers_version ("<" , "5.0.0" ):
30+ raise ImportError ("`ErnieImageModularPipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`." )
31+
32+ from transformers import Ministral3ForCausalLM # noqa: E402
33+
34+
2835logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
2936
3037
@@ -38,7 +45,7 @@ def description(self) -> str:
3845 @property
3946 def expected_components (self ) -> list [ComponentSpec ]:
4047 return [
41- ComponentSpec ("pe" , AutoModelForCausalLM ),
48+ ComponentSpec ("pe" , Ministral3ForCausalLM ),
4249 ComponentSpec ("pe_tokenizer" , AutoTokenizer ),
4350 ]
4451
@@ -83,7 +90,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
8390
8491 @staticmethod
8592 def _enhance_prompt (
86- pe : AutoModelForCausalLM ,
93+ pe : Ministral3ForCausalLM ,
8794 pe_tokenizer : AutoTokenizer ,
8895 prompt : str ,
8996 device : torch .device ,
@@ -160,7 +167,7 @@ def description(self) -> str:
160167 @property
161168 def expected_components (self ) -> list [ComponentSpec ]:
162169 return [
163- ComponentSpec ("text_encoder" , AutoModel ),
170+ ComponentSpec ("text_encoder" , Mistral3Model ),
164171 ComponentSpec ("tokenizer" , AutoTokenizer ),
165172 ComponentSpec (
166173 "guider" ,
@@ -200,7 +207,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
200207
201208 @staticmethod
202209 def _encode (
203- text_encoder : AutoModel ,
210+ text_encoder : Mistral3Model ,
204211 tokenizer : AutoTokenizer ,
205212 prompt : list [str ],
206213 device : torch .device ,
0 commit comments