|
1 | 1 | # transformers |
2 | 2 | from typing import List |
3 | 3 | from .patch_helper import _has_transformers |
4 | | - |
5 | 4 | from ._patch_transformers_attention import ( |
6 | 5 | patched_sdpa_attention_forward, |
7 | 6 | patched_model_bart_eager_attention_forward, |
8 | 7 | patched_modeling_marian_eager_attention_forward, |
9 | 8 | ) |
| 9 | +from ._patch_transformers_generation_mixin import patched_GenerationMixin |
| 10 | +from ._patch_transformers_causal_mask import patched_AttentionMaskConverter |
| 11 | +from ._patch_transformers_rotary_embedding import ( |
| 12 | + patched__compute_dynamic_ntk_parameters, |
| 13 | + patched_dynamic_rope_update, |
| 14 | + patched_GemmaRotaryEmbedding, |
| 15 | + patched_LlamaRotaryEmbedding, |
| 16 | + patched_MistralRotaryEmbedding, |
| 17 | + patched_MixtralRotaryEmbedding, |
| 18 | + patched_PhiRotaryEmbedding, |
| 19 | +) |
| 20 | +from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention |
| 21 | +from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder |
| 22 | + |
| 23 | +# transformers dependant patches |
10 | 24 |
|
11 | 25 | from ._patch_transformers_cache_utils import patch_parse_processor_args |
12 | 26 |
|
13 | 27 | if patch_parse_processor_args: |
14 | 28 | from ._patch_transformers_cache_utils import patched_parse_processor_args |
15 | | - |
16 | | -from ._patch_transformers_causal_mask import patched_AttentionMaskConverter |
17 | | - |
18 | 29 | from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache |
19 | 30 |
|
20 | 31 | if patch_DynamicLayer: |
21 | 32 | from ._patch_transformers_dynamic_cache import patched_DynamicLayer |
22 | 33 | if patch_DynamicCache: |
23 | 34 | from ._patch_transformers_dynamic_cache import patched_DynamicCache |
24 | | - |
25 | | -from ._patch_transformers_generation_mixin import patched_GenerationMixin |
26 | | - |
27 | 35 | from ._patch_transformers_masking_utils import patch_masking_utils |
28 | 36 |
|
29 | 37 | if patch_masking_utils: |
|
33 | 41 | patched_sdpa_mask_recent_torch, |
34 | 42 | ) |
35 | 43 |
|
36 | | -from ._patch_transformers_rotary_embedding import ( |
37 | | - patched__compute_dynamic_ntk_parameters, |
38 | | - patched_dynamic_rope_update, |
39 | | - patched_GemmaRotaryEmbedding, |
40 | | - patched_LlamaRotaryEmbedding, |
41 | | - patched_MistralRotaryEmbedding, |
42 | | - patched_MixtralRotaryEmbedding, |
43 | | - patched_PhiRotaryEmbedding, |
44 | | -) |
| 44 | +# transformers models dependant patches |
45 | 45 |
|
46 | 46 | if _has_transformers("4.51"): |
47 | 47 | from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding |
|
54 | 54 | if _has_transformers("4.53"): |
55 | 55 | from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding |
56 | 56 |
|
57 | | -# Models |
58 | | - |
59 | 57 | from ._patch_transformers_gemma3 import patch_gemma3 |
60 | 58 |
|
61 | 59 | if patch_gemma3: |
62 | 60 | from ._patch_transformers_gemma3 import patched_Gemma3Model |
63 | 61 |
|
64 | | -from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention |
65 | | - |
66 | | - |
67 | 62 | from ._patch_transformers_qwen2 import patch_qwen2 |
68 | 63 |
|
69 | 64 | if patch_qwen2: |
|
80 | 75 | patched_Qwen2_5_VLModel, |
81 | 76 | PLUGS as PLUGS_Qwen25, |
82 | 77 | ) |
83 | | - |
84 | 78 | from ._patch_transformers_qwen3 import patch_qwen3 |
85 | 79 |
|
86 | 80 | if patch_qwen3: |
87 | 81 | from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock |
| 82 | +from ._patch_transformers_funnel import patch_funnel |
88 | 83 |
|
89 | | - |
90 | | -from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder |
| 84 | +if patch_funnel: |
| 85 | + from ._patch_transformers_funnel import ( |
| 86 | + patched_FunnelAttentionStructure, |
| 87 | + patched_FunnelRelMultiheadAttention, |
| 88 | + ) |
91 | 89 |
|
92 | 90 |
|
93 | 91 | def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]: # noqa: F821 |
|
0 commit comments