
Commit 81bc830

Add per-layer MLP type support for executorch export (#18856)
Differential Revision: D100682545
Pull Request resolved: #18856
1 parent a489707 commit 81bc830

2 files changed

Lines changed: 25 additions & 9 deletions


examples/models/llama/llama_transformer.py

Lines changed: 22 additions & 9 deletions
@@ -98,26 +98,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs, attention: Attention):
+    def __init__(
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+    ):
         """
         Transformer block with support for pre-norm and post-norm.
         Args:
             args (ModelArgs): model configuration parameters.
             attention (Attention): attention object to use in the transformer
                 block. See `attention.py` for types of attention. Make sure
                 the attention type is registered in the ATTENTION_REGISTRY.
+            mlp_type (str): MLP type for this layer. "default" for standard
+                FFN, "skip" for no FFN block.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+        self.mlp_type = mlp_type.lower()
 
         assert (
             args.hidden_dim is not None
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
-        if args.moe:
+        if self.mlp_type == "skip":
+            pass  # No FFN block for this layer
+        elif args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         elif args.target_modules is not None and (
             "down_proj" in args.target_modules
@@ -136,11 +143,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
             eps=args.norm_eps,
             add_unit_offset=args.rms_norm_add_unit_offset,
         )
-        self.ffn_norm = RMSNorm(
-            args.dim,
-            eps=args.norm_eps,
-            add_unit_offset=args.rms_norm_add_unit_offset,
-        )
+        if self.mlp_type != "skip":
+            self.ffn_norm = RMSNorm(
+                args.dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
 
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
@@ -156,9 +164,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
                 f"Unknown attention type: {args.attention_type}. "
                 f"Available: {list(ATTENTION_REGISTRY.keys())}"
             )
+        mlp_type = "default"
+        if args.mlp_type is not None and layer_id < len(args.mlp_type):
+            mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention)
+        return TransformerBlock(args, attention, mlp_type=mlp_type)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
@@ -167,7 +178,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
         if not isinstance(self.attention, AttentionSkip):
             h = x + h
 
-        if hasattr(self, "block_sparse_moe"):
+        if self.mlp_type == "skip":
+            out = h
+        elif hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
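
For illustration, a minimal, self-contained sketch of the per-layer selection that `from_type` now performs. The standalone helper `resolve_mlp_type` is hypothetical and only mirrors the fallback logic in the diff above:

from typing import List, Optional

def resolve_mlp_type(mlp_types: Optional[List[str]], layer_id: int) -> str:
    # Layers beyond the configured list, or a missing list, fall back to the
    # standard "default" FFN; otherwise the per-layer entry is used as-is.
    if mlp_types is not None and layer_id < len(mlp_types):
        return mlp_types[layer_id]
    return "default"

# Example: a 4-layer model where layer 2 builds no FFN block.
print([resolve_mlp_type(["default", "default", "skip", "default"], i) for i in range(4)])
# ['default', 'default', 'skip', 'default']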

examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,9 @@ class ModelArgs:
     attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Hybrid models can have layer types different from attention
     layer_types: Optional[list] = None
+    # Per-layer MLP type: "default" for standard FFN, "skip" for no FFN block.
+    # Indexed by layer id (e.g. mlp_type[0] applies to layer 0).
+    mlp_type: Optional[list] = None
     model_architecture: Optional[str] = (
         None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
     )
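
A minimal sketch of how the new field might be populated, using a hypothetical stand-in dataclass that carries only the fields relevant here (the real ModelArgs is the dataclass the diff above modifies):

import dataclasses
from typing import Optional

@dataclasses.dataclass
class ModelArgsSketch:
    # Illustrative stand-in; field names match the diff, everything else is omitted.
    n_layers: int = 4
    layer_types: Optional[list] = None
    mlp_type: Optional[list] = None  # indexed by layer id; "default" or "skip"

args = ModelArgsSketch(n_layers=4, mlp_type=["default", "skip", "default", "default"])
print(args.mlp_type[1])  # layer 1 is built without an FFN block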
