Skip to content

Commit 5621d2d

Browse files
author
The gemma Authors
committed
Allow skipping the automatic sliding mask
PiperOrigin-RevId: 919592946
1 parent e176e8a commit 5621d2d

2 files changed

Lines changed: 9 additions & 1 deletion

File tree

gemma/gm/nn/gemma4/_modules.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def __call__(
246246
cache: LayerCache | None,
247247
attn_mask: jax.Array,
248248
kv_shared_cache: LayerCache | None = None,
249+
skip_sliding_mask: bool = False,
249250
) -> tuple[LayerCache | None, jax.Array]:
250251
"""Applies multi-head attention to the inputs.
251252
@@ -255,6 +256,7 @@ def __call__(
255256
cache: KV cache or None.
256257
attn_mask: Attention mask of shape [batch_size, seq_len, cache_size].
257258
kv_shared_cache: Cache for shared KV layers.
259+
skip_sliding_mask: If True, skip the sliding mask.
258260
259261
Returns:
260262
cache: Updated attention KV cache.
@@ -335,7 +337,7 @@ def __call__(
335337
logits = jnp.tanh(logits / self.attn_logits_soft_cap)
336338
logits = logits * self.attn_logits_soft_cap
337339

338-
if self.attn_type == AttentionType.LOCAL_SLIDING:
340+
if self.attn_type == AttentionType.LOCAL_SLIDING and not skip_sliding_mask:
339341
if self.sliding_window_size is None:
340342
raise ValueError(
341343
'Sliding_window_size must be set if Local Sliding attention type'
@@ -596,6 +598,7 @@ def __call__(
596598
attn_mask: jax.Array,
597599
per_layer_input: jax.Array | None = None,
598600
kv_shared_cache: LayerCache | None = None,
601+
skip_sliding_mask: bool = False,
599602
) -> tuple[LayerCache | None, jax.Array]:
600603
"""Applies the block to the inputs.
601604
@@ -607,6 +610,7 @@ def __call__(
607610
per_layer_input: Per-layer input of shape [batch_size, seq_len,
608611
per_layer_input_dim].
609612
kv_shared_cache: Cache for shared KV layers.
613+
skip_sliding_mask: If True, skip the sliding mask.
610614
611615
Returns:
612616
cache: Updated attention KV cache.
@@ -621,6 +625,7 @@ def __call__(
621625
cache,
622626
attn_mask,
623627
kv_shared_cache,
628+
skip_sliding_mask=skip_sliding_mask,
624629
)
625630

626631
if self.post_attention_norm is not None:

gemma/gm/nn/gemma4/_transformer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,13 @@ def _apply_attention(
366366
kv_shared_cache = None
367367
# Select the appropriate attention mask for this layer type.
368368
attn_mask = inputs.attention_mask
369+
skip_sliding_mask = False
369370
if (
370371
inputs.sliding_attention_mask is not None
371372
and block.attn_type == _modules.AttentionType.LOCAL_SLIDING
372373
):
373374
attn_mask = inputs.sliding_attention_mask
375+
skip_sliding_mask = True
374376
layer_cache, x = block(
375377
x,
376378
inputs.positions,
@@ -380,6 +382,7 @@ def _apply_attention(
380382
if self.config.per_layer_input_dim
381383
else None,
382384
kv_shared_cache=kv_shared_cache,
385+
skip_sliding_mask=skip_sliding_mask,
383386
)
384387
new_cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch
385388

0 commit comments

Comments
 (0)