|
63 | 63 | DEFAULT_MASK_VALUE, |
64 | 64 | ) |
65 | 65 |
|
66 | | -from maxtext.layers import nnx_wrappers |
67 | 66 | from maxtext.layers.attentions import Attention |
68 | | -from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned |
| 67 | +from maxtext.layers.initializers import nd_dense_init, NdInitializer |
69 | 68 | from maxtext.layers.linears import DenseGeneral |
70 | 69 | from maxtext.layers.normalizations import RMSNorm |
71 | 70 | from maxtext.layers.quantizations import AqtQuantization as Quant |
@@ -381,141 +380,6 @@ def __call__( |
381 | 380 | return indexer_mask, topk_indices, indexer_score |
382 | 381 |
|
383 | 382 |
|
384 | | -def mla_as_linen( |
385 | | - *, |
386 | | - config: Config, |
387 | | - num_query_heads: int, |
388 | | - num_kv_heads: int, |
389 | | - head_dim: int, |
390 | | - max_target_length: int, |
391 | | - mesh: Mesh, |
392 | | - attention_kernel: str, |
393 | | - inputs_q_shape: Tuple, |
394 | | - inputs_kv_shape: Tuple, |
395 | | - dtype: DType = jnp.float32, |
396 | | - weight_dtype: DType = jnp.float32, |
397 | | - max_prefill_predict_length: int = -1, |
398 | | - dropout_rate: float = 0.0, |
399 | | - kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"), |
400 | | - float32_qk_product: bool = False, # computes logits in float32 for stability. |
401 | | - float32_logits: bool = False, # cast logits in float32 for stability. |
402 | | - quant: Optional[Quant] = None, |
403 | | - kv_quant: Optional[KVQuant] = None, |
404 | | - attention_type: AttentionType = AttentionType.MLA, # Default to MLA attention |
405 | | - attn_logits_soft_cap: float | None = None, |
406 | | - sliding_window_size: int | None = None, |
407 | | - use_ragged_attention: bool = False, |
408 | | - ragged_block_size: int = 256, |
409 | | - use_qk_norm: bool = False, |
410 | | - query_pre_attn_scalar: float | None = None, |
411 | | - use_bias_in_projections: bool = False, # Set to True will enable bias in q, k, v, o projections |
412 | | - # Temperature tuning parameters used for Llama4 |
413 | | - temperature_tuning: bool = False, |
414 | | - temperature_tuning_scale: float = 0.1, |
415 | | - temperature_tuning_floor_scale: float = 8192.0, |
416 | | - # Shard the query activation as the same as the key and value. |
417 | | - # TODO: Find a better sharding axis name. |
418 | | - # TODO: Further break down the Training and Inference axes for the q, k, v. |
419 | | - prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
420 | | - prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
421 | | - prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
422 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
423 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
424 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
425 | | - input_axis_names: AxisNames = (BATCH_ATTN, LENGTH, EMBED), |
426 | | - out_axis_names: AxisNames = (BATCH_ATTN, LENGTH, HEAD, D_KV), |
427 | | - prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
428 | | - decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
429 | | - prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
430 | | - decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV), |
431 | | - prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3), |
432 | | - ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3), |
433 | | - compute_axis_order: AxisIdxes = (0, 1, 2, 3), |
434 | | - reshape_q: bool = False, |
435 | | - is_nope_layer: bool = False, |
436 | | - is_vision: bool = False, |
437 | | - model_mode: str = MODEL_MODE_TRAIN, |
438 | | - q_lora_rank: int = 0, |
439 | | - kv_lora_rank: int = 512, |
440 | | - qk_nope_head_dim: int = 128, |
441 | | - qk_rope_head_dim: int = 64, |
442 | | - v_head_dim: int = 128, |
443 | | - max_position_embeddings: int = 4096 * 4, |
444 | | - original_max_position_embeddings: int = 4096, |
445 | | - mscale: float = 1.0, # scaling factor for softmax |
446 | | - rope_factor: float = 40.0, # rotary embedding factor |
447 | | - name: str | None = None, |
448 | | -): |
449 | | - """A factory function to create an MLA as a Linen module. |
450 | | -
|
451 | | - This function serves as a bridge to use the NNX-based `MLA` within a |
452 | | - Linen model. |
453 | | - """ |
454 | | - return nnx_wrappers.to_linen( |
455 | | - MLA, |
456 | | - config=config, |
457 | | - num_query_heads=num_query_heads, |
458 | | - num_kv_heads=num_kv_heads, |
459 | | - head_dim=head_dim, |
460 | | - max_target_length=max_target_length, |
461 | | - mesh=mesh, |
462 | | - attention_kernel=attention_kernel, |
463 | | - inputs_q_shape=inputs_q_shape, |
464 | | - inputs_kv_shape=inputs_kv_shape, |
465 | | - dtype=dtype, |
466 | | - weight_dtype=weight_dtype, |
467 | | - max_prefill_predict_length=max_prefill_predict_length, |
468 | | - dropout_rate=dropout_rate, |
469 | | - kernel_init=kernel_init, |
470 | | - float32_qk_product=float32_qk_product, |
471 | | - float32_logits=float32_logits, |
472 | | - quant=quant, |
473 | | - kv_quant=kv_quant, |
474 | | - attention_type=attention_type, |
475 | | - attn_logits_soft_cap=attn_logits_soft_cap, |
476 | | - sliding_window_size=sliding_window_size, |
477 | | - use_ragged_attention=use_ragged_attention, |
478 | | - ragged_block_size=ragged_block_size, |
479 | | - use_qk_norm=use_qk_norm, |
480 | | - query_pre_attn_scalar=query_pre_attn_scalar, |
481 | | - use_bias_in_projections=use_bias_in_projections, |
482 | | - temperature_tuning=temperature_tuning, |
483 | | - temperature_tuning_scale=temperature_tuning_scale, |
484 | | - temperature_tuning_floor_scale=temperature_tuning_floor_scale, |
485 | | - prefill_query_axis_names=prefill_query_axis_names, |
486 | | - prefill_key_axis_names=prefill_key_axis_names, |
487 | | - prefill_value_axis_names=prefill_value_axis_names, |
488 | | - query_axis_names=query_axis_names, |
489 | | - key_axis_names=key_axis_names, |
490 | | - value_axis_names=value_axis_names, |
491 | | - input_axis_names=input_axis_names, |
492 | | - out_axis_names=out_axis_names, |
493 | | - prefill_input_axis_names=prefill_input_axis_names, |
494 | | - decode_input_axis_names=decode_input_axis_names, |
495 | | - prefill_out_axis_names=prefill_out_axis_names, |
496 | | - decode_out_axis_names=decode_out_axis_names, |
497 | | - prefill_cache_axis_order=prefill_cache_axis_order, |
498 | | - ar_cache_axis_order=ar_cache_axis_order, |
499 | | - compute_axis_order=compute_axis_order, |
500 | | - reshape_q=reshape_q, |
501 | | - is_nope_layer=is_nope_layer, |
502 | | - is_vision=is_vision, |
503 | | - model_mode=model_mode, |
504 | | - q_lora_rank=q_lora_rank, |
505 | | - kv_lora_rank=kv_lora_rank, |
506 | | - qk_nope_head_dim=qk_nope_head_dim, |
507 | | - qk_rope_head_dim=qk_rope_head_dim, |
508 | | - v_head_dim=v_head_dim, |
509 | | - max_position_embeddings=max_position_embeddings, |
510 | | - original_max_position_embeddings=original_max_position_embeddings, |
511 | | - mscale=mscale, |
512 | | - rope_factor=rope_factor, |
513 | | - name=name, |
514 | | - metadata_fn=variable_to_logically_partitioned, |
515 | | - abstract_init=False, |
516 | | - ) |
517 | | - |
518 | | - |
519 | 383 | class MLA(Attention): |
520 | 384 | """Multi-Head Latent Attention (MLA) layer.""" |
521 | 385 |
|
|
0 commit comments