@@ -443,6 +443,28 @@ def generate(
     return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
+def get_head_shapes(config) -> tuple[int | list[int], int | list[int]]:
+    """Returns a tuple `(num_heads, head_dim)` containing either 2 ints, or lists of ints with the value for
+    each layer."""
+    # Gemma4 has different head_dim and num_heads depending on layer type
+    if hasattr(config, "global_head_dim"):
+        head_dim = [
+            config.global_head_dim if layer == "full_attention" else config.head_dim
+            for layer in config.layer_types[: -config.num_kv_shared_layers]
+        ]
+        num_heads = [
+            config.num_global_key_value_heads
+            if layer == "full_attention" and config.attention_k_eq_v
+            else config.num_key_value_heads
+            for layer in config.layer_types[: -config.num_kv_shared_layers]
+        ]
+    else:
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+
+    return num_heads, head_dim
+
+
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
     """
     A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
@@ -523,9 +545,8 @@ def __init__(
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.static_cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+                self.static_cache.layers[i] = StaticLayer(max_cache_len)
+        num_heads, head_dim = get_head_shapes(config)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -702,9 +723,8 @@ def __init__(
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+                self.cache.layers[i] = StaticLayer(max_cache_len)
+        num_heads, head_dim = get_head_shapes(config)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -856,9 +876,8 @@ def __init__(self, model, max_static_cache_length, batch_size):
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.static_cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
-        num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
+                self.static_cache.layers[i] = StaticLayer(max_static_cache_length)
+        num_heads, head_dim = get_head_shapes(self.config)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
         self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config))
 
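A quick sanity check of what `get_head_shapes` returns for the two config shapes it distinguishes. The `SimpleNamespace` stand-ins below are made up for illustration, and the import path is an assumption rather than something shown in this diff:

```python
from types import SimpleNamespace

# Assumed import path; the diff does not name the module.
# from transformers.integrations.executorch import get_head_shapes

# Plain config: no `global_head_dim`, so both values come back as ints.
# `head_dim` falls back to hidden_size // num_attention_heads = 128.
plain_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
# get_head_shapes(plain_config) -> (8, 128)

# Gemma4-style config: `global_head_dim` is present, so per-layer lists are
# built from `layer_types`, dropping the trailing `num_kv_shared_layers` entries.
gemma_like_config = SimpleNamespace(
    global_head_dim=256,
    head_dim=128,
    num_global_key_value_heads=8,
    num_key_value_heads=4,
    attention_k_eq_v=True,
    num_kv_shared_layers=2,
    layer_types=[
        "sliding_attention", "full_attention", "sliding_attention",
        "full_attention", "sliding_attention", "sliding_attention",
    ],
)
# Only the first 4 layer types survive the slice; full-attention layers get the
# global shapes, sliding-window layers keep the local ones:
# get_head_shapes(gemma_like_config) -> ([4, 8, 4, 8], [128, 256, 128, 256])
```

Returning per-layer lists presumably lets `early_initialization` size each layer's KV buffers individually, which would explain why the helper replaces the scalar `getattr` lookups at all three call sites above.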