Skip to content

Commit 95c3c8a

Browse files
authored
add dsa index aoa and log (#4490)
1 parent 02c97bc commit 95c3c8a

3 files changed

Lines changed: 61 additions & 0 deletions

File tree

paddleformers/trainer/trainer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2656,6 +2656,24 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
26562656
except (ImportError, AttributeError):
26572657
pass
26582658

2659+
# Add DSA indexer loss metrics if available
2660+
try:
2661+
from paddlefleet.transformer.dsa_attention import (
2662+
DSAIndexerLossLoggingHelper,
2663+
)
2664+
2665+
if DSAIndexerLossLoggingHelper.tracker.get("values") is not None:
2666+
loss_scale = 1.0 / self.args.gradient_accumulation_steps
2667+
DSAIndexerLossLoggingHelper.reduce_loss_in_tracker()
2668+
tracker = DSAIndexerLossLoggingHelper.tracker
2669+
indexer_loss_values = tracker["values"] * loss_scale
2670+
num_layers = indexer_loss_values.shape[0]
2671+
avg_indexer_loss = indexer_loss_values.sum() / num_layers
2672+
logs["indexer_loss"] = avg_indexer_loss.item()
2673+
DSAIndexerLossLoggingHelper.clean_loss_in_tracker()
2674+
except (ImportError, AttributeError):
2675+
pass
2676+
26592677
self._total_loss_scalar += tr_loss_scalar
26602678
self._globalstep_last_logged = self.state.global_step
26612679
self._globalstep_last_start_time = time.time()

paddleformers/trainer/training_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,12 @@ class TrainingArguments:
690690
)
691691
},
692692
)
693+
694+
dsa_indexer_loss_coeff: bool = field(
695+
default=0.01,
696+
metadata={"help": "Loss coefficient for the DSA indexer; controls the weight of the indexer loss term."},
697+
)
698+
693699
sharding_comm_group_call_opt: bool = field(
694700
default=False,
695701
metadata={

paddleformers/transformers/aoa_config_base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class MoEAOAConfigParams:
6363
# Runtime config
6464
model_prefix: str = "model."
6565

66+
index_n_heads: int = 0
67+
6668
# Extra statements to add
6769
extra_statements: List[str] = field(default_factory=list)
6870

@@ -129,6 +131,7 @@ def _extract_params(cls, config: Any) -> MoEAOAConfigParams:
129131
use_qk_norm=getattr(config, "use_qk_norm", False),
130132
has_shared_experts=cls._has_shared_experts(config),
131133
model_prefix=cls._get_model_prefix(config),
134+
index_n_heads=getattr(config, "index_n_heads", 0),
132135
)
133136

134137
@classmethod
@@ -369,6 +372,23 @@ def _get_mla_attention_statements(cls, params: MoEAOAConfigParams, prefix: str,
369372
]
370373
)
371374

375+
if params.index_n_heads and params.index_n_heads > 0:
376+
indexer_weights = [
377+
"wq_b",
378+
"wk",
379+
"weights_proj",
380+
]
381+
statements.extend(
382+
[
383+
f"{prefix}.self_attn.indexer.{weight_name}.weight^T -> {prefix_offset}.self_attn.core_attention.indexer.{weight_name}.weight"
384+
for weight_name in indexer_weights
385+
]
386+
)
387+
statements += [
388+
f"{prefix}.self_attn.indexer.k_norm.bias -> {prefix_offset}.self_attn.core_attention.indexer.k_norm.bias",
389+
f"{prefix}.self_attn.indexer.k_norm.weight -> {prefix_offset}.self_attn.core_attention.indexer.k_norm.weight",
390+
]
391+
372392
return statements
373393

374394
# ==================== MoE Expert Weights ====================
@@ -725,6 +745,23 @@ def _get_inv_mla_attention_statements(
725745
]
726746
)
727747

748+
if params.index_n_heads and params.index_n_heads > 0:
749+
indexer_weights = [
750+
"wq_b",
751+
"wk",
752+
"weights_proj",
753+
]
754+
statements.extend(
755+
[
756+
f"{prefix_offset}.self_attn.core_attention.indexer.{weight_name}.weight^T -> {prefix}.self_attn.indexer.{weight_name}.weight"
757+
for weight_name in indexer_weights
758+
]
759+
)
760+
statements += [
761+
f"{prefix_offset}.self_attn.core_attention.indexer.k_norm.bias -> {prefix}.self_attn.indexer.k_norm.bias",
762+
f"{prefix_offset}.self_attn.core_attention.indexer.k_norm.weight -> {prefix}.self_attn.indexer.k_norm.weight",
763+
]
764+
728765
return statements
729766

730767
# ==================== Inverse MoE Expert Weights ====================

0 commit comments

Comments
 (0)