Skip to content

Commit 15e1e11

Browse files
committed
[NNX] Delete Linen (2/4): remove the Linen decoder stack and dead *_as_linen wrappers
Delete TransformerLinenPure, the Linen Decoder/DecoderLayer/SequentialBlockDecoderLayers stack (decoders.py), and the dead *_as_linen ToLinen wrappers across the layer/model files. The wrapped NNX classes are unchanged; transformer_as_linen (the NNX->Linen bridge) is kept for the checkpoint-conversion tools.
1 parent 95026c0 commit 15e1e11

13 files changed

Lines changed: 44 additions & 2746 deletions

src/maxtext/layers/attention_mla.py

Lines changed: 1 addition & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@
6363
DEFAULT_MASK_VALUE,
6464
)
6565

66-
from maxtext.layers import nnx_wrappers
6766
from maxtext.layers.attentions import Attention
68-
from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
67+
from maxtext.layers.initializers import nd_dense_init, NdInitializer
6968
from maxtext.layers.linears import DenseGeneral
7069
from maxtext.layers.normalizations import RMSNorm
7170
from maxtext.layers.quantizations import AqtQuantization as Quant
@@ -381,141 +380,6 @@ def __call__(
381380
return indexer_mask, topk_indices, indexer_score
382381

383382

384-
def mla_as_linen(
385-
*,
386-
config: Config,
387-
num_query_heads: int,
388-
num_kv_heads: int,
389-
head_dim: int,
390-
max_target_length: int,
391-
mesh: Mesh,
392-
attention_kernel: str,
393-
inputs_q_shape: Tuple,
394-
inputs_kv_shape: Tuple,
395-
dtype: DType = jnp.float32,
396-
weight_dtype: DType = jnp.float32,
397-
max_prefill_predict_length: int = -1,
398-
dropout_rate: float = 0.0,
399-
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
400-
float32_qk_product: bool = False, # computes logits in float32 for stability.
401-
float32_logits: bool = False, # casts logits to float32 for stability.
402-
quant: Optional[Quant] = None,
403-
kv_quant: Optional[KVQuant] = None,
404-
attention_type: AttentionType = AttentionType.MLA, # Default to MLA attention
405-
attn_logits_soft_cap: float | None = None,
406-
sliding_window_size: int | None = None,
407-
use_ragged_attention: bool = False,
408-
ragged_block_size: int = 256,
409-
use_qk_norm: bool = False,
410-
query_pre_attn_scalar: float | None = None,
411-
use_bias_in_projections: bool = False, # Setting this to True enables bias in the q, k, v, o projections
412-
# Temperature tuning parameters used for Llama4
413-
temperature_tuning: bool = False,
414-
temperature_tuning_scale: float = 0.1,
415-
temperature_tuning_floor_scale: float = 8192.0,
416-
# Shard the query activation the same way as the key and value.
417-
# TODO: Find a better sharding axis name.
418-
# TODO: Further break down the Training and Inference axes for the q, k, v.
419-
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
420-
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
421-
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
422-
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
423-
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
424-
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
425-
input_axis_names: AxisNames = (BATCH_ATTN, LENGTH, EMBED),
426-
out_axis_names: AxisNames = (BATCH_ATTN, LENGTH, HEAD, D_KV),
427-
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
428-
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
429-
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
430-
decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
431-
prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
432-
ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
433-
compute_axis_order: AxisIdxes = (0, 1, 2, 3),
434-
reshape_q: bool = False,
435-
is_nope_layer: bool = False,
436-
is_vision: bool = False,
437-
model_mode: str = MODEL_MODE_TRAIN,
438-
q_lora_rank: int = 0,
439-
kv_lora_rank: int = 512,
440-
qk_nope_head_dim: int = 128,
441-
qk_rope_head_dim: int = 64,
442-
v_head_dim: int = 128,
443-
max_position_embeddings: int = 4096 * 4,
444-
original_max_position_embeddings: int = 4096,
445-
mscale: float = 1.0, # scaling factor for softmax
446-
rope_factor: float = 40.0, # rotary embedding factor
447-
name: str | None = None,
448-
):
449-
"""A factory function to create an MLA as a Linen module.
450-
451-
This function serves as a bridge to use the NNX-based `MLA` within a
452-
Linen model.
453-
"""
454-
return nnx_wrappers.to_linen(
455-
MLA,
456-
config=config,
457-
num_query_heads=num_query_heads,
458-
num_kv_heads=num_kv_heads,
459-
head_dim=head_dim,
460-
max_target_length=max_target_length,
461-
mesh=mesh,
462-
attention_kernel=attention_kernel,
463-
inputs_q_shape=inputs_q_shape,
464-
inputs_kv_shape=inputs_kv_shape,
465-
dtype=dtype,
466-
weight_dtype=weight_dtype,
467-
max_prefill_predict_length=max_prefill_predict_length,
468-
dropout_rate=dropout_rate,
469-
kernel_init=kernel_init,
470-
float32_qk_product=float32_qk_product,
471-
float32_logits=float32_logits,
472-
quant=quant,
473-
kv_quant=kv_quant,
474-
attention_type=attention_type,
475-
attn_logits_soft_cap=attn_logits_soft_cap,
476-
sliding_window_size=sliding_window_size,
477-
use_ragged_attention=use_ragged_attention,
478-
ragged_block_size=ragged_block_size,
479-
use_qk_norm=use_qk_norm,
480-
query_pre_attn_scalar=query_pre_attn_scalar,
481-
use_bias_in_projections=use_bias_in_projections,
482-
temperature_tuning=temperature_tuning,
483-
temperature_tuning_scale=temperature_tuning_scale,
484-
temperature_tuning_floor_scale=temperature_tuning_floor_scale,
485-
prefill_query_axis_names=prefill_query_axis_names,
486-
prefill_key_axis_names=prefill_key_axis_names,
487-
prefill_value_axis_names=prefill_value_axis_names,
488-
query_axis_names=query_axis_names,
489-
key_axis_names=key_axis_names,
490-
value_axis_names=value_axis_names,
491-
input_axis_names=input_axis_names,
492-
out_axis_names=out_axis_names,
493-
prefill_input_axis_names=prefill_input_axis_names,
494-
decode_input_axis_names=decode_input_axis_names,
495-
prefill_out_axis_names=prefill_out_axis_names,
496-
decode_out_axis_names=decode_out_axis_names,
497-
prefill_cache_axis_order=prefill_cache_axis_order,
498-
ar_cache_axis_order=ar_cache_axis_order,
499-
compute_axis_order=compute_axis_order,
500-
reshape_q=reshape_q,
501-
is_nope_layer=is_nope_layer,
502-
is_vision=is_vision,
503-
model_mode=model_mode,
504-
q_lora_rank=q_lora_rank,
505-
kv_lora_rank=kv_lora_rank,
506-
qk_nope_head_dim=qk_nope_head_dim,
507-
qk_rope_head_dim=qk_rope_head_dim,
508-
v_head_dim=v_head_dim,
509-
max_position_embeddings=max_position_embeddings,
510-
original_max_position_embeddings=original_max_position_embeddings,
511-
mscale=mscale,
512-
rope_factor=rope_factor,
513-
name=name,
514-
metadata_fn=variable_to_logically_partitioned,
515-
abstract_init=False,
516-
)
517-
518-
519383
class MLA(Attention):
520384
"""Multi-Head Latent Attention (MLA) layer."""
521385

src/maxtext/layers/attention_op.py

Lines changed: 0 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
from maxtext.kernels.attention.ragged_attention import ragged_gqa
7070
from maxtext.kernels.attention.ragged_attention import ragged_mha
7171
from maxtext.layers import nnx_wrappers
72-
from maxtext.layers.initializers import variable_to_logically_partitioned
7372
from maxtext.layers.quantizations import AqtQuantization as Quant
7473
from maxtext.utils import max_utils
7574
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
@@ -285,100 +284,6 @@ def _make_bidirectional_block_mask(bidirectional_mask):
285284
return bidirectional_block_mask
286285

287286

288-
def attention_op_as_linen(
289-
*,
290-
config: Config,
291-
mesh: Mesh,
292-
attention_kernel: str,
293-
max_target_length: int,
294-
num_query_heads: int,
295-
num_kv_heads: int,
296-
float32_qk_product: bool = False,
297-
max_prefill_predict_length: int = -1,
298-
float32_logits: bool = False,
299-
flash_axis_names_q: AxisNames = (BATCH_ATTN, HEAD, LENGTH, D_KV),
300-
flash_axis_names_kv: AxisNames = (BATCH_ATTN, HEAD, KV_LENGTH, D_KV),
301-
flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
302-
prefill_cache_logical_axis_names: AxisNames = (
303-
CACHE_BATCH_PREFILL,
304-
CACHE_SEQUENCE,
305-
CACHE_HEADS,
306-
CACHE_KV,
307-
),
308-
cache_logical_axis_names: AxisNames = (
309-
CACHE_BATCH,
310-
CACHE_SEQUENCE,
311-
CACHE_HEADS,
312-
CACHE_KV,
313-
),
314-
cache_scale_logical_axis_names: AxisNames = (
315-
CACHE_SCALE_BATCH,
316-
CACHE_SCALE_SEQUENCE,
317-
CACHE_SCALE_HEADS,
318-
CACHE_SCALE_KV,
319-
),
320-
ragged_qkv_axis_names: AxisNames = (
321-
CACHE_BATCH,
322-
CACHE_HEADS,
323-
CACHE_SEQUENCE,
324-
CACHE_KV,
325-
),
326-
ragged_lengths_names: AxisNames = (CACHE_BATCH,),
327-
compute_axis_order: AxisIdxes = (0, 1, 2, 3),
328-
key_axis_order: AxisIdxes = (2, 0, 1, 3),
329-
reshape_q: bool = False,
330-
dropout_rate: float = 0.0,
331-
dtype: DType = jnp.float32,
332-
quant: Optional[Quant] = None,
333-
kv_quant: Optional[KVQuant] = None,
334-
attention_type: AttentionType = AttentionType.GLOBAL, # Default to global attention
335-
attn_logits_soft_cap: float | None = None,
336-
sliding_window_size: int | None = None,
337-
chunk_attn_window_size: int | None = None,
338-
use_ragged_attention: bool = False,
339-
ragged_block_size: int = 256,
340-
):
341-
"""A factory function to create an AttentionOp as a Linen module.
342-
343-
This function serves as a bridge to use the NNX-based `AttentionOp` within a
344-
Linen model.
345-
"""
346-
return nnx_wrappers.to_linen(
347-
AttentionOp,
348-
config=config,
349-
mesh=mesh,
350-
attention_kernel=attention_kernel,
351-
max_target_length=max_target_length,
352-
num_query_heads=num_query_heads,
353-
num_kv_heads=num_kv_heads,
354-
float32_qk_product=float32_qk_product,
355-
max_prefill_predict_length=max_prefill_predict_length,
356-
float32_logits=float32_logits,
357-
flash_axis_names_q=flash_axis_names_q,
358-
flash_axis_names_kv=flash_axis_names_kv,
359-
flash_axis_names_splash_kernel=flash_axis_names_splash_kernel,
360-
prefill_cache_logical_axis_names=prefill_cache_logical_axis_names,
361-
cache_logical_axis_names=cache_logical_axis_names,
362-
cache_scale_logical_axis_names=cache_scale_logical_axis_names,
363-
ragged_qkv_axis_names=ragged_qkv_axis_names,
364-
ragged_lengths_names=ragged_lengths_names,
365-
compute_axis_order=compute_axis_order,
366-
key_axis_order=key_axis_order,
367-
reshape_q=reshape_q,
368-
dropout_rate=dropout_rate,
369-
dtype=dtype,
370-
quant=quant,
371-
kv_quant=kv_quant,
372-
attention_type=attention_type,
373-
attn_logits_soft_cap=attn_logits_soft_cap,
374-
sliding_window_size=sliding_window_size,
375-
chunk_attn_window_size=chunk_attn_window_size,
376-
use_ragged_attention=use_ragged_attention,
377-
ragged_block_size=ragged_block_size,
378-
metadata_fn=variable_to_logically_partitioned,
379-
)
380-
381-
382287
class AttentionOp(nnx.Module):
383288
"""Attention operation"""
384289

src/maxtext/layers/attentions.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,6 @@ def __call__(self, x):
8989
return x * jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
9090

9191

92-
def l2_norm_as_linen(self, eps: float = 1e-6):
93-
"""
94-
Initializes the L2Norm module and returns it as a Linen module.
95-
96-
Args:
97-
eps: float, epsilon used for numerical stability (default value should be ok for most cases).
98-
"""
99-
return nnx_wrappers.to_linen(L2Norm, eps=eps, metadata_fn=variable_to_logically_partitioned)
100-
101-
10292
def attention_as_linen(
10393
*,
10494
config: Config,

0 commit comments

Comments
 (0)