@@ -576,6 +576,11 @@ def _patch_transformers_remote_code_compat() -> None:
576576 except Exception :
577577 utils = None
578578
579+ try :
580+ import transformers .modeling_rope_utils as rope_utils
581+ except Exception :
582+ rope_utils = None
583+
579584 import transformers .utils .generic as generic
580585 with _MONKEY_PATCH_LOCK :
581586 if not hasattr (import_utils , "is_torch_fx_available" ):
@@ -599,6 +604,51 @@ def is_flash_attn_greater_or_equal_2_10() -> bool:
599604
600605 utils .is_flash_attn_greater_or_equal_2_10 = is_flash_attn_greater_or_equal_2_10
601606
607+ if rope_utils is not None and "default" not in getattr (rope_utils , "ROPE_INIT_FUNCTIONS" , {}):
608+ # transformers 5.x removed the legacy `"default"` RoPE entrypoint,
609+ # but older trust_remote_code model files still resolve it directly.
610+ # Recreate the unscaled/base initializer instead of aliasing to
611+ # `"linear"` so configs do not need an artificial `factor=1.0`.
612+ def _compute_default_rope_parameters_compat (
613+ config : Optional ["PreTrainedConfig" ] = None ,
614+ device : Optional ["torch.device" ] = None ,
615+ seq_len : int | None = None ,
616+ layer_type : str | None = None ,
617+ ) -> tuple ["torch.Tensor" , float ]:
618+ del seq_len
619+ if config is None :
620+ raise ValueError ("`config` is required to compute default RoPE parameters." )
621+
622+ standardize_rope_params = getattr (config , "standardize_rope_params" , None )
623+ if callable (standardize_rope_params ):
624+ standardize_rope_params ()
625+
626+ rope_parameters = getattr (config , "rope_parameters" , None )
627+ if layer_type is not None and isinstance (rope_parameters , dict ):
628+ rope_parameters = rope_parameters .get (layer_type , rope_parameters )
629+
630+ rope_theta = None
631+ partial_rotary_factor = 1.0
632+ if isinstance (rope_parameters , dict ):
633+ rope_theta = rope_parameters .get ("rope_theta" )
634+ partial_rotary_factor = rope_parameters .get ("partial_rotary_factor" , partial_rotary_factor )
635+
636+ if rope_theta is None :
637+ rope_theta = getattr (config , "rope_theta" , None )
638+ if rope_theta is None :
639+ rope_theta = getattr (config , "default_theta" , 10_000.0 )
640+
641+ head_dim = getattr (config , "head_dim" , None ) or config .hidden_size // config .num_attention_heads
642+ dim = int (head_dim * partial_rotary_factor )
643+ attention_factor = 1.0
644+ inv_freq = 1.0 / (
645+ rope_theta
646+ ** (torch .arange (0 , dim , 2 , dtype = torch .int64 ).to (device = device , dtype = torch .float ) / dim )
647+ )
648+ return inv_freq , attention_factor
649+
650+ rope_utils .ROPE_INIT_FUNCTIONS ["default" ] = _compute_default_rope_parameters_compat
651+
602652 if cache_utils is not None and not hasattr (cache_utils , "SlidingWindowCache" ) and hasattr (cache_utils , "StaticCache" ):
603653 # transformers 5.x folds sliding-window behavior into StaticCache
604654 # layers, but older remote code still imports the legacy symbol.
@@ -1028,7 +1078,7 @@ def prepare_remote_model_init_compat(model_id_or_path: Optional[str], config: An
10281078 input_mode_enum = getattr (remote_module , "InputMode" , None ) if remote_module is not None else None
10291079
10301080 with _MONKEY_PATCH_LOCK :
1031- if config . model_type == "minicpm" or config . model_type == "instella" :
1081+ if outer_model_cls is not None :
10321082 try_patch_legacy_flash_attn_flag (outer_model_cls )
10331083
10341084 if config .model_type == "minicpmv" or config .model_type == "minicpmo" :
@@ -1212,7 +1262,8 @@ def try_patch_legacy_flash_attn_flag(model_cls):
12121262 if model_cls is None or not isinstance (model_cls , type ):
12131263 return
12141264
1215- # Find the "source class" that defines _supports_flash_attn_2.
1265+ # Find the most specific class that explicitly declares the newer
1266+ # `_supports_flash_attn_2` flag used by newer transformers releases.
12161267 base_with_flag = None
12171268 for cls in model_cls .__mro__ :
12181269 if "_supports_flash_attn_2" in cls .__dict__ :
@@ -1222,8 +1273,15 @@ def try_patch_legacy_flash_attn_flag(model_cls):
12221273 if base_with_flag is None :
12231274 return
12241275
1276+ # Respect remote models that already define the legacy flag themselves.
1277+ for cls in model_cls .__mro__ :
1278+ if cls is base_with_flag :
1279+ break
1280+ if "_supports_flash_attn" in cls .__dict__ :
1281+ return
1282+
12251283 flash_attn_2_val = base_with_flag .__dict__ ["_supports_flash_attn_2" ]
1226- setattr (cls , "_supports_flash_attn" , bool (flash_attn_2_val ))
1284+ setattr (base_with_flag , "_supports_flash_attn" , bool (flash_attn_2_val ))
12271285
12281286
12291287def load_tokenizer (tokenizer_or_path , * , model_config : Any = None , ** kwargs ):
0 commit comments