@@ -246,6 +246,7 @@ def __call__(
246246 cache : LayerCache | None ,
247247 attn_mask : jax .Array ,
248248 kv_shared_cache : LayerCache | None = None ,
249+ skip_sliding_mask : bool = False ,
249250 ) -> tuple [LayerCache | None , jax .Array ]:
250251 """Applies multi-head attention to the inputs.
251252
@@ -255,6 +256,7 @@ def __call__(
255256 cache: KV cache or None.
256257 attn_mask: Attention mask of shape [batch_size, seq_len, cache_size].
257258 kv_shared_cache: Cache for shared KV layers.
259+ skip_sliding_mask: If True, skip the sliding mask.
258260
259261 Returns:
260262 cache: Updated attention KV cache.
@@ -335,7 +337,7 @@ def __call__(
335337 logits = jnp .tanh (logits / self .attn_logits_soft_cap )
336338 logits = logits * self .attn_logits_soft_cap
337339
338- if self .attn_type == AttentionType .LOCAL_SLIDING :
340+ if self .attn_type == AttentionType .LOCAL_SLIDING and not skip_sliding_mask :
339341 if self .sliding_window_size is None :
340342 raise ValueError (
341343 'Sliding_window_size must be set if Local Sliding attention type'
@@ -596,6 +598,7 @@ def __call__(
596598 attn_mask : jax .Array ,
597599 per_layer_input : jax .Array | None = None ,
598600 kv_shared_cache : LayerCache | None = None ,
601+ skip_sliding_mask : bool = False ,
599602 ) -> tuple [LayerCache | None , jax .Array ]:
600603 """Applies the block to the inputs.
601604
@@ -607,6 +610,7 @@ def __call__(
607610 per_layer_input: Per-layer input of shape [batch_size, seq_len,
608611 per_layer_input_dim].
609612 kv_shared_cache: Cache for shared KV layers.
613+ skip_sliding_mask: If True, skip the sliding mask.
610614
611615 Returns:
612616 cache: Updated attention KV cache.
@@ -621,6 +625,7 @@ def __call__(
621625 cache ,
622626 attn_mask ,
623627 kv_shared_cache ,
628+ skip_sliding_mask = skip_sliding_mask ,
624629 )
625630
626631 if self .post_attention_norm is not None :
0 commit comments