
Commit c850500

Fix export for gemma4 and add Integration tests (#45285)
* start updating tests
* start making them pass
* more
* fix
* fix export
* fix
* oupsi
* review comments
* fix expectations for a10
* style
* rename tests
1 parent d081c71 commit c850500

4 files changed

Lines changed: 153 additions & 329 deletions


src/transformers/cache_utils.py

Lines changed: 27 additions & 7 deletions
@@ -982,18 +982,38 @@ def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int,
         return recurrent_states
 
     def early_initialization(
-        self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
+        self,
+        batch_size: int,
+        num_heads: int | list[int],
+        head_dim: int | list[int],
+        dtype: torch.dtype,
+        device: torch.device,
     ):
         """
         Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
         This is useful for our `export` recipes, as `export` needs everything in advance.
         """
-        # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
-        # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
-        # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
-        fake_kv_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
-        # Init all layers
-        for layer in self.layers:
+        # To allow different num_heads and head_dim depending on layers, we accept lists
+        if isinstance(num_heads, int):
+            num_heads = [num_heads] * len(self)
+        if isinstance(head_dim, int):
+            head_dim = [head_dim] * len(self)
+
+        if len(num_heads) != len(self.layers):
+            raise ValueError(
+                f"`num_heads` was provided as a list of length {len(num_heads)}, but the Cache currently has {len(self.layers)} layers"
+            )
+        if len(head_dim) != len(self.layers):
+            raise ValueError(
+                f"`head_dim` was provided as a list of length {len(head_dim)}, but the Cache currently has {len(self.layers)} layers"
+            )
+
+        for layer, layer_num_heads, layer_head_dim in zip(self.layers, num_heads, head_dim):
+            # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
+            # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
+            # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
+            fake_kv_tensor = torch.zeros((batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device)
+            # Init the layer
             layer.lazy_initialization(fake_kv_tensor, fake_kv_tensor)
 
     def get_seq_length(self, layer_idx: int = 0) -> int:
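
The key trick in `early_initialization` is the zero-length "fake" tensor: giving the -2 (sequence) dimension size 0 lets the tensor carry batch size, head count, head dim, dtype and device without allocating any value storage, which is all `lazy_initialization` needs. A minimal standalone sketch of that idea in plain PyTorch (the sizes below are made up for illustration):

import torch

# Size 0 on dim -2 means no values are stored, but shape, dtype and device are
# still carried, which is enough to size a layer's static cache buffers.
batch_size, num_heads, head_dim = 1, 4, 64
fake_kv = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=torch.float32, device="cpu")

print(fake_kv.shape)    # torch.Size([1, 4, 0, 64])
print(fake_kv.numel())  # 0 -> no memory backing the values
print(fake_kv.dtype, fake_kv.device)  # torch.float32 cpu

With the new signature, `num_heads` and `head_dim` may also be per-layer lists, so each layer can be initialized with its own shape, while plain ints are still broadcast to every layer.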

src/transformers/integrations/executorch.py

Lines changed: 28 additions & 9 deletions
@@ -443,6 +443,28 @@ def generate(
         return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
+def get_head_shapes(config) -> tuple[int | list[int], int | list[int]]:
+    """Returns a tuple `(num_heads, head_dim)` containing either 2 ints, or a list of int with the value for each
+    layer."""
+    # Gemma4 has different head_dim and num_heads depending on layer type
+    if hasattr(config, "global_head_dim"):
+        head_dim = [
+            config.global_head_dim if layer == "full_attention" else config.head_dim
+            for layer in config.layer_types[: -config.num_kv_shared_layers]
+        ]
+        num_heads = [
+            config.num_global_key_value_heads
+            if layer == "full_attention" and config.attention_k_eq_v
+            else config.num_key_value_heads
+            for layer in config.layer_types[: -config.num_kv_shared_layers]
+        ]
+    else:
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+
+    return num_heads, head_dim
+
+
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
     """
     A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
@@ -523,9 +545,8 @@ def __init__(
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.static_cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+                self.static_cache.layers[i] = StaticLayer(max_cache_len)
+        num_heads, head_dim = get_head_shapes(config)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -702,9 +723,8 @@ def __init__(
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+                self.cache.layers[i] = StaticLayer(max_cache_len)
+        num_heads, head_dim = get_head_shapes(config)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -856,9 +876,8 @@ def __init__(self, model, max_static_cache_length, batch_size):
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.static_cache.layers):
            if isinstance(layer, StaticSlidingWindowLayer):
-                self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
-        num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
+                self.static_cache.layers[i] = StaticLayer(max_static_cache_length)
+        num_heads, head_dim = get_head_shapes(self.config)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
         self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config))
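
The `hasattr(config, "global_head_dim")` branch of `get_head_shapes` is what makes the per-layer lists necessary: full-attention layers use the global head dim (and, when `attention_k_eq_v` is set, the global KV head count), while the trailing KV-shared layers are skipped entirely. A standalone sketch of that resolution on a hypothetical config (all attribute values below are invented for illustration):

from types import SimpleNamespace

# Hypothetical config exposing only the attributes the helper reads.
config = SimpleNamespace(
    layer_types=["sliding_attention", "full_attention", "sliding_attention", "full_attention"],
    num_kv_shared_layers=1,   # the last layer reuses shared KV, so it gets no cache entry
    head_dim=128,
    global_head_dim=256,
    num_key_value_heads=2,
    num_global_key_value_heads=4,
    attention_k_eq_v=True,
)

# Same per-layer resolution as the helper: one entry per non-shared layer.
layers = config.layer_types[: -config.num_kv_shared_layers]
head_dim = [config.global_head_dim if layer == "full_attention" else config.head_dim for layer in layers]
num_heads = [
    config.num_global_key_value_heads
    if layer == "full_attention" and config.attention_k_eq_v
    else config.num_key_value_heads
    for layer in layers
]

print(num_heads)  # [2, 4, 2]
print(head_dim)   # [128, 256, 128]

These lists are then handed to `early_initialization`, which now sizes each layer's cache buffers individually instead of assuming one shape for the whole model.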
