|
85 | 85 | ] |
86 | 86 |
|
87 | 87 |
|
| 88 | +class _FusedLayerNormProxy(torch.nn.Module): |
| 89 | + """Proxy module exposing fused layernorm weights from TELayerNormColumnParallelLinear. |
| 90 | +
|
| 91 | + When using TE spec, the input layernorm and pre-MLP layernorm are fused into the |
| 92 | + subsequent linear layer (TELayerNormColumnParallelLinear). The layernorm weight is |
| 93 | + stored as ``layer_norm_weight`` on the fused linear module rather than as a separate |
| 94 | + ``weight`` on a standalone layernorm module. |
| 95 | +
|
| 96 | + This proxy wraps that fused weight so the existing export rules (which expect a |
| 97 | + module with a ``.weight`` attribute) can export it with the correct HF key name. |
| 98 | + """ |
| 99 | + |
| 100 | + def __init__(self, fused_linear: torch.nn.Module): |
| 101 | + super().__init__() |
| 102 | + self.weight = fused_linear.layer_norm_weight |
| 103 | + bias = getattr(fused_linear, "layer_norm_bias", None) |
| 104 | + if bias is not None: |
| 105 | + self.bias = bias |
| 106 | + |
| 107 | + |
88 | 108 | class GPTModelExporter: |
89 | 109 | """Megatron Core GPTModel Exporter. |
90 | 110 |
|
@@ -489,6 +509,17 @@ def _get_state_dict(self): |
489 | 509 | def _get_transformer_layer_state_dict(self, layer, layer_id): |
490 | 510 | if not isinstance(layer.input_layernorm, IdentityOp): |
491 | 511 | self.rules["input_layernorm"](layer.input_layernorm, layer_id) |
| 512 | + elif ( |
| 513 | + "input_layernorm" in self.rules |
| 514 | + and hasattr(layer, "self_attention") |
| 515 | + and not isinstance(layer.self_attention, IdentityOp) |
| 516 | + and hasattr(layer.self_attention, "linear_qkv") |
| 517 | + and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight") |
| 518 | + ): |
| 519 | + # TE spec: input layernorm is fused into TELayerNormColumnParallelLinear |
| 520 | + self.rules["input_layernorm"]( |
| 521 | + _FusedLayerNormProxy(layer.self_attention.linear_qkv), layer_id |
| 522 | + ) |
492 | 523 |
|
493 | 524 | if not isinstance(layer.self_attention, IdentityOp): |
494 | 525 | if "MLASelfAttention" in str(type(layer.self_attention)): |
@@ -527,6 +558,14 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): |
527 | 558 |
|
528 | 559 | if not isinstance(layer.pre_mlp_layernorm, IdentityOp): |
529 | 560 | self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) |
| 561 | + elif ( |
| 562 | + "pre_mlp_layernorm" in self.rules |
| 563 | + and not isinstance(layer.mlp, IdentityOp) |
| 564 | + and hasattr(layer.mlp, "linear_fc1") |
| 565 | + and hasattr(layer.mlp.linear_fc1, "layer_norm_weight") |
| 566 | + ): |
| 567 | + # TE spec: pre-MLP layernorm is fused into TELayerNormColumnParallelLinear |
| 568 | + self.rules["pre_mlp_layernorm"](_FusedLayerNormProxy(layer.mlp.linear_fc1), layer_id) |
530 | 569 |
|
531 | 570 | if not isinstance(layer.mlp, IdentityOp): |
532 | 571 | if "MoE" in str(type(layer.mlp)): |
@@ -598,6 +637,14 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]: |
598 | 637 | def _get_mamba_layer_state_dict(self, layer, layer_id): |
599 | 638 | if not isinstance(layer.norm, IdentityOp): |
600 | 639 | self.rules["norm"](layer.norm, layer_id) |
| 640 | + elif ( |
| 641 | + "norm" in self.rules |
| 642 | + and hasattr(layer, "mixer") |
| 643 | + and hasattr(layer.mixer, "in_proj") |
| 644 | + and hasattr(layer.mixer.in_proj, "layer_norm_weight") |
| 645 | + ): |
| 646 | + # TE spec: norm is fused into TELayerNormColumnParallelLinear (in_proj) |
| 647 | + self.rules["norm"](_FusedLayerNormProxy(layer.mixer.in_proj), layer_id) |
601 | 648 |
|
602 | 649 | self.rules["mixer_norm"](layer.mixer.norm, layer_id) |
603 | 650 | self.rules["A_log"](layer.mixer.A_log, layer_id) |
|
0 commit comments