Skip to content

Commit 3ef1786

Browse files
authored
add aoa for dsv4 hybrid attn (#4508)
1 parent f3fc3f2 commit 3ef1786

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

paddleformers/transformers/minimax_m2/modeling.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,48 @@ def _gen_aoa_config(cls, config: MiniMaxM2Config):
430430
f"{prefix}.self_attn.kv_a_layernorm.weight -> {prefix_offset}.self_attn.kv_a_layernorm.weight",
431431
]
432432

433+
elif config.experimental_attention_variant == "dsv4_hybrid":
434+
# csa_compress_ratios has length num_hidden_layers + num_nextn_predict_layers,
435+
# i.e. it covers both main layers and MTP layers.
436+
assert len(config.csa_compress_ratios) == num_hidden_layers + num_nextn_predict_layers, (
437+
f"csa_compress_ratios length ({len(config.csa_compress_ratios)}) must equal "
438+
f"num_hidden_layers + num_nextn_predict_layers "
439+
f"({num_hidden_layers} + {num_nextn_predict_layers})"
440+
)
441+
csa_ratio = config.csa_compress_ratios[layer_idx]
442+
aoa_config["aoa_statements"] += [
443+
# Linear projections (transpose: HF [out, in] -> paddle [in, out])
444+
f"{prefix}.self_attn.linear_q_down_proj.weight^T -> {prefix_offset}.self_attn.linear_q_down_proj.weight",
445+
f"{prefix}.self_attn.linear_q_up_proj.weight^T -> {prefix_offset}.self_attn.linear_q_up_proj.weight",
446+
f"{prefix}.self_attn.linear_kv_proj.weight^T -> {prefix_offset}.self_attn.linear_kv_proj.weight",
447+
f"{prefix}.self_attn.o_proj.weight^T -> {prefix_offset}.self_attn.o_proj.weight",
448+
# Layer norms (no transpose, 1D)
449+
f"{prefix}.self_attn.q_layernorm.weight -> {prefix_offset}.self_attn.q_layernorm.weight",
450+
f"{prefix}.self_attn.kv_layernorm.weight -> {prefix_offset}.self_attn.kv_layernorm.weight",
451+
# Grouped output projection (raw parameter, shape [out, in] on both sides)
452+
f"{prefix}.self_attn.linear_o_group_proj -> {prefix_offset}.self_attn.linear_o_group_proj",
453+
# Core attention: learnable attention sink (1D, no transpose)
454+
f"{prefix}.self_attn.core_attention.attn_sink -> {prefix_offset}.self_attn.core_attention.attn_sink",
455+
]
456+
# Compressor exists only when compress_ratio > 1 (i.e. ratio in {4, 128})
457+
if csa_ratio > 1:
458+
aoa_config["aoa_statements"] += [
459+
f"{prefix}.self_attn.core_attention.compressor.linear_wkv.weight^T -> {prefix_offset}.self_attn.core_attention.compressor.linear_wkv.weight",
460+
f"{prefix}.self_attn.core_attention.compressor.linear_wgate.weight^T -> {prefix_offset}.self_attn.core_attention.compressor.linear_wgate.weight",
461+
f"{prefix}.self_attn.core_attention.compressor.norm.weight -> {prefix_offset}.self_attn.core_attention.compressor.norm.weight",
462+
f"{prefix}.self_attn.core_attention.compressor.ape -> {prefix_offset}.self_attn.core_attention.compressor.ape",
463+
]
464+
# Indexer exists only when compress_ratio == 4 and not csa_dense_mode
465+
if csa_ratio == 4 and not getattr(config, "csa_dense_mode", False):
466+
aoa_config["aoa_statements"] += [
467+
f"{prefix}.self_attn.core_attention.indexer.linear_wq_b.weight^T -> {prefix_offset}.self_attn.core_attention.indexer.linear_wq_b.weight",
468+
f"{prefix}.self_attn.core_attention.indexer.linear_weights_proj.weight^T -> {prefix_offset}.self_attn.core_attention.indexer.linear_weights_proj.weight",
469+
f"{prefix}.self_attn.core_attention.indexer.compressor.linear_wkv.weight^T -> {prefix_offset}.self_attn.core_attention.indexer.compressor.linear_wkv.weight",
470+
f"{prefix}.self_attn.core_attention.indexer.compressor.linear_wgate.weight^T -> {prefix_offset}.self_attn.core_attention.indexer.compressor.linear_wgate.weight",
471+
f"{prefix}.self_attn.core_attention.indexer.compressor.norm.weight -> {prefix_offset}.self_attn.core_attention.indexer.compressor.norm.weight",
472+
f"{prefix}.self_attn.core_attention.indexer.compressor.ape -> {prefix_offset}.self_attn.core_attention.indexer.compressor.ape",
473+
]
474+
433475
else:
434476
if config.use_qk_norm:
435477
aoa_config["aoa_statements"] += [
@@ -657,6 +699,47 @@ def _gen_inv_aoa_config(cls, config: MiniMaxM2Config):
657699
f"{prefix_offset}.self_attn.q_a_layernorm.weight -> {prefix}.self_attn.q_a_layernorm.weight",
658700
f"{prefix_offset}.self_attn.kv_a_layernorm.weight -> {prefix}.self_attn.kv_a_layernorm.weight",
659701
]
702+
elif config.experimental_attention_variant == "dsv4_hybrid":
703+
# csa_compress_ratios has length num_hidden_layers + num_nextn_predict_layers,
704+
# i.e. it covers both main layers and MTP layers.
705+
assert len(config.csa_compress_ratios) == num_hidden_layers + num_nextn_predict_layers, (
706+
f"csa_compress_ratios length ({len(config.csa_compress_ratios)}) must equal "
707+
f"num_hidden_layers + num_nextn_predict_layers "
708+
f"({num_hidden_layers} + {num_nextn_predict_layers})"
709+
)
710+
csa_ratio = config.csa_compress_ratios[layer_idx]
711+
aoa_statements += [
712+
# Linear projections (transpose: paddle [in, out] -> HF [out, in])
713+
f"{prefix_offset}.self_attn.linear_q_down_proj.weight^T -> {prefix}.self_attn.linear_q_down_proj.weight",
714+
f"{prefix_offset}.self_attn.linear_q_up_proj.weight^T -> {prefix}.self_attn.linear_q_up_proj.weight",
715+
f"{prefix_offset}.self_attn.linear_kv_proj.weight^T -> {prefix}.self_attn.linear_kv_proj.weight",
716+
f"{prefix_offset}.self_attn.o_proj.weight^T -> {prefix}.self_attn.o_proj.weight",
717+
# Layer norms (no transpose, 1D)
718+
f"{prefix_offset}.self_attn.q_layernorm.weight -> {prefix}.self_attn.q_layernorm.weight",
719+
f"{prefix_offset}.self_attn.kv_layernorm.weight -> {prefix}.self_attn.kv_layernorm.weight",
720+
# Grouped output projection (raw parameter, shape [out, in] on both sides)
721+
f"{prefix_offset}.self_attn.linear_o_group_proj -> {prefix}.self_attn.linear_o_group_proj",
722+
# Core attention: learnable attention sink (1D, no transpose)
723+
f"{prefix_offset}.self_attn.core_attention.attn_sink -> {prefix}.self_attn.core_attention.attn_sink",
724+
]
725+
# Compressor exists only when compress_ratio > 1 (i.e. ratio in {4, 128})
726+
if csa_ratio > 1:
727+
aoa_statements += [
728+
f"{prefix_offset}.self_attn.core_attention.compressor.linear_wkv.weight^T -> {prefix}.self_attn.core_attention.compressor.linear_wkv.weight",
729+
f"{prefix_offset}.self_attn.core_attention.compressor.linear_wgate.weight^T -> {prefix}.self_attn.core_attention.compressor.linear_wgate.weight",
730+
f"{prefix_offset}.self_attn.core_attention.compressor.norm.weight -> {prefix}.self_attn.core_attention.compressor.norm.weight",
731+
f"{prefix_offset}.self_attn.core_attention.compressor.ape -> {prefix}.self_attn.core_attention.compressor.ape",
732+
]
733+
# Indexer exists only when compress_ratio == 4 and not csa_dense_mode
734+
if csa_ratio == 4 and not getattr(config, "csa_dense_mode", False):
735+
aoa_statements += [
736+
f"{prefix_offset}.self_attn.core_attention.indexer.linear_wq_b.weight^T -> {prefix}.self_attn.core_attention.indexer.linear_wq_b.weight",
737+
f"{prefix_offset}.self_attn.core_attention.indexer.linear_weights_proj.weight^T -> {prefix}.self_attn.core_attention.indexer.linear_weights_proj.weight",
738+
f"{prefix_offset}.self_attn.core_attention.indexer.compressor.linear_wkv.weight^T -> {prefix}.self_attn.core_attention.indexer.compressor.linear_wkv.weight",
739+
f"{prefix_offset}.self_attn.core_attention.indexer.compressor.linear_wgate.weight^T -> {prefix}.self_attn.core_attention.indexer.compressor.linear_wgate.weight",
740+
f"{prefix_offset}.self_attn.core_attention.indexer.compressor.norm.weight -> {prefix}.self_attn.core_attention.indexer.compressor.norm.weight",
741+
f"{prefix_offset}.self_attn.core_attention.indexer.compressor.ape -> {prefix}.self_attn.core_attention.indexer.compressor.ape",
742+
]
660743
else:
661744
if config.use_qk_norm:
662745
aoa_statements += [

0 commit comments

Comments
 (0)