
Commit 1f7ba4c

committed
add reformated moe
1 parent 4ce5c73 commit 1f7ba4c

11 files changed

Lines changed: 1711 additions & 84 deletions

File tree

MoE_MLP_Fusion_Analysis.tex

Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
\documentclass[11pt,a4paper]{article}
\usepackage[utf8]{inputenc}
\usepackage{amsmath,amssymb,amsfonts}
\usepackage{booktabs}
\usepackage{geometry}
\usepackage{enumitem}
\usepackage{hyperref}
\usepackage{xcolor}
\usepackage{algorithm}
\usepackage{algpseudocode}

\geometry{margin=2.5cm}
\newcommand{\R}{\mathbb{R}}
\newcommand{\act}{\sigma}
\newcommand{\W}{\mathbf{W}}
\newcommand{\bb}{\mathbf{b}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\h}{\mathbf{h}}
\newcommand{\g}{\mathbf{g}}
\newcommand{\e}{\mathbf{e}}

\title{MoE-Aware MLP Fusion for DPA3 RepFlowLayer:\\
Reducing Expert Parallelism Communication Cost}
\author{Technical Analysis}
\date{\today}

\begin{document}
\maketitle
\section{Problem Statement}

In the current DPA3 RepFlowLayer, each layer contains \textbf{7 independent MLPs} that can be replaced by MoE layers. In an Expert Parallelism (EP) setting, each MoE layer requires two All-to-All communications (dispatch + combine). With 7 MoE layers per RepFlowLayer and $L$ layers total, the communication overhead is:
\begin{equation}
C_{\text{comm}} = 2 \times 7 \times L \times C_{\text{A2A}}
\end{equation}
where $C_{\text{A2A}}$ is the cost of a single All-to-All operation. For $L=6$, this means \textbf{84 All-to-All operations per forward pass}, which is prohibitive.

The goal is to \textbf{fuse MLPs that share the same input} into single MoE layers, reducing the number of independent expert dispatch/combine rounds while preserving model expressiveness.
\section{Current Architecture Analysis}

\subsection{The 7 MLPs and Their Data Flow}

Let $n_i \in \R^{d_n}$ denote the node embedding of atom $i$, $e_{ij} \in \R^{d_e}$ the edge embedding of pair $(i,j)$, and $a_{ijk} \in \R^{d_a}$ the angle embedding of triplet $(i,j,k)$. The 7 MLPs are:

\begin{table}[h]
\centering
\begin{tabular}{clccl}
\toprule
\# & Name & Input dim & Output dim & Input tensor \\
\midrule
1 & \texttt{node\_self\_mlp} & $d_n$ & $d_n$ & $n_i$ \\
2 & \texttt{node\_sym\_linear} & $d_n \cdot k + d_e \cdot k$ & $d_n$ & $\text{GRRG}(n_i, e_{ij}, h_{ij})$ \\
3 & \texttt{node\_edge\_linear} & $2d_n + d_e$ & $H \cdot d_n$ & $[n_i; n_j; e_{ij}]$ \\
4 & \texttt{edge\_self\_linear} & $2d_n + d_e$ & $d_e$ & $[n_i; n_j; e_{ij}]$ \\
5 & \texttt{edge\_angle\_linear1} & $d_a + d_n' + 2d_e'$ & $d_e$ & $[a_{ijk}; n_i'; e_{ik}'; e_{ij}']$ \\
6 & \texttt{edge\_angle\_linear2} & $d_e$ & $d_e$ & reduced angle$\to$edge \\
7 & \texttt{angle\_self\_linear} & $d_a + d_n' + 2d_e'$ & $d_a$ & $[a_{ijk}; n_i'; e_{ik}'; e_{ij}']$ \\
\bottomrule
\end{tabular}
\caption{The 7 MLPs in RepFlowLayer. $k$ = \texttt{axis\_neuron}, $H$ = \texttt{n\_multi\_edge\_message}; primed dimensions indicate compressed variants when \texttt{a\_compress\_rate} $> 0$.}
\label{tab:mlps}
\end{table}

\subsection{Shared Input Groups}

Two pairs of MLPs consume \textbf{identical input tensors}:

\paragraph{Group A: Edge-info MLPs (\#3 + \#4).}
Both consume the concatenated edge information:
\begin{equation}
\x_{\text{edge}} = [n_i;\; n_j;\; e_{ij}] \in \R^{2d_n + d_e}
\end{equation}
MLP \#3 produces the node$\leftarrow$edge message ($\R^{H \cdot d_n}$), and MLP \#4 produces the edge self-update ($\R^{d_e}$).

\paragraph{Group B: Angle-info MLPs (\#5 + \#7).}
Both consume the concatenated angle information:
\begin{equation}
\x_{\text{angle}} = [a_{ijk};\; n_i';\; e_{ik}';\; e_{ij}'] \in \R^{d_a + d_n' + 2d_e'}
\end{equation}
MLP \#5 produces the edge$\leftarrow$angle message ($\R^{d_e}$), and MLP \#7 produces the angle self-update ($\R^{d_a}$).

\paragraph{Independent MLPs.}
MLPs \#1, \#2, and \#6 each have unique inputs and cannot be trivially fused with the others.
\section{Proposed Fusion Strategies}

\subsection{Strategy 1: Direct Output Concatenation (Recommended)}

\subsubsection{Fusion A: Edge-info MLPs $\to$ Single MoE}

Replace MLPs \#3 and \#4 with a single fused MLP:
\begin{equation}
[\underbrace{y_{\text{node}}}_{\R^{H \cdot d_n}};\; \underbrace{y_{\text{edge}}}_{\R^{d_e}}] = \act\!\left(\W_{\text{fused}}^{(A)} \cdot \x_{\text{edge}} + \bb_{\text{fused}}^{(A)}\right)
\end{equation}
where $\W_{\text{fused}}^{(A)} \in \R^{(H \cdot d_n + d_e) \times (2d_n + d_e)}$. The output is split:
\begin{align}
y_{\text{node}} &= [\act(\W_{\text{fused}}^{(A)} \x_{\text{edge}} + \bb)]_{1:H \cdot d_n} \quad \text{(for node update)} \\
y_{\text{edge}} &= [\act(\W_{\text{fused}}^{(A)} \x_{\text{edge}} + \bb)]_{H \cdot d_n + 1 : H \cdot d_n + d_e} \quad \text{(for edge update)}
\end{align}

\textbf{Expressiveness analysis:} The fused layer has $\W_{\text{fused}} \in \R^{(Hd_n + d_e) \times (2d_n + d_e)}$, while the original two layers have $\W_3 \in \R^{Hd_n \times (2d_n + d_e)}$ and $\W_4 \in \R^{d_e \times (2d_n + d_e)}$. Since
\begin{equation}
\W_{\text{fused}} = \begin{bmatrix} \W_3 \\ \W_4 \end{bmatrix}, \quad
\bb_{\text{fused}} = \begin{bmatrix} \bb_3 \\ \bb_4 \end{bmatrix},
\end{equation}
the fused layer is \textbf{strictly equivalent} to the two separate layers --- no expressiveness is lost. The parameter count is identical.
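For concreteness, the following PyTorch snippet checks this equivalence numerically. The dimensions are illustrative placeholders rather than actual DPA3 settings, and \texttt{sigmoid} merely stands in for the generic activation $\act$.

\begin{verbatim}
import torch

# Illustrative dimensions (not the real DPA3 hyperparameters).
d_n, d_e, H = 128, 64, 2
d_in = 2 * d_n + d_e              # [n_i; n_j; e_ij]

torch.manual_seed(0)
W3 = torch.randn(H * d_n, d_in)   # node<-edge message weights (#3)
b3 = torch.randn(H * d_n)
W4 = torch.randn(d_e, d_in)       # edge self-update weights (#4)
b4 = torch.randn(d_e)

x_edge = torch.randn(1024, d_in)  # a batch of edge tokens

# Original: two separate linear layers sharing the same input.
y_node = torch.sigmoid(x_edge @ W3.T + b3)
y_edge = torch.sigmoid(x_edge @ W4.T + b4)

# Fused: stack weights along the output dimension, split afterwards.
W_fused = torch.cat([W3, W4], dim=0)
b_fused = torch.cat([b3, b4], dim=0)
y_fused = torch.sigmoid(x_edge @ W_fused.T + b_fused)
y_node_f, y_edge_f = torch.split(y_fused, [H * d_n, d_e], dim=-1)

assert torch.allclose(y_node, y_node_f)
assert torch.allclose(y_edge, y_edge_f)
\end{verbatim}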
\subsubsection{Fusion B: Angle-info MLPs $\to$ Single MoE}

Replace MLPs \#5 and \#7 with a single fused MLP:
\begin{equation}
[\underbrace{y_{\text{e}\leftarrow\text{a}}}_{\R^{d_e}};\; \underbrace{y_{\text{angle}}}_{\R^{d_a}}] = \act\!\left(\W_{\text{fused}}^{(B)} \cdot \x_{\text{angle}} + \bb_{\text{fused}}^{(B)}\right)
\end{equation}
where $\W_{\text{fused}}^{(B)} \in \R^{(d_e + d_a) \times (d_a + d_n' + 2d_e')}$.

The same analysis applies: with $\W_{\text{fused}}^{(B)} = [\W_5; \W_7]$, the fusion is \textbf{strictly equivalent}, with zero expressiveness loss.

\subsubsection{Result: 7 $\to$ 5 MoE Layers}

\begin{table}[h]
\centering
\begin{tabular}{clcl}
\toprule
\# & Fused Name & Tensor Level & Original MLPs \\
\midrule
1 & \texttt{node\_self\_mlp} & Node $[N_b, N_{\text{loc}}, d_n]$ & \#1 \\
2 & \texttt{node\_sym\_linear} & Node $[N_b, N_{\text{loc}}, kd_n + kd_e]$ & \#2 \\
3' & \texttt{edge\_fused\_linear} & Edge $[N_b, N_{\text{loc}}, N_{\text{nei}}, 2d_n+d_e]$ & \#3 + \#4 \\
4' & \texttt{angle\_fused\_linear} & Angle $[N_b, N_{\text{loc}}, S_a, S_a, d_a+d_n'+2d_e']$ & \#5 + \#7 \\
5 & \texttt{edge\_angle\_linear2} & Edge $[N_b, N_{\text{loc}}, N_{\text{nei}}, d_e]$ & \#6 \\
\bottomrule
\end{tabular}
\caption{After Strategy 1 fusion: 5 MoE layers.}
\end{table}

Communication reduction: $\frac{7-5}{7} = 28.6\%$ fewer All-to-All rounds.
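As an illustration, a forward pass with the two fused layers might look like the sketch below. The module names (\texttt{edge\_fused\_moe}, \texttt{angle\_fused\_moe}) and the \texttt{layer} attributes are invented for exposition and do not correspond to existing DPA3 code.

\begin{verbatim}
import torch

def fused_strategy1_forward(layer, n_i, n_j, e_ij, x_angle):
    """Sketch: one MoE call per shared-input group, outputs split afterwards.

    `layer.edge_fused_moe` / `layer.angle_fused_moe` are assumed fused MoE
    modules; each internally performs one dispatch + one combine All-to-All.
    """
    H, d_n, d_e, d_a = layer.H, layer.d_n, layer.d_e, layer.d_a

    # Fusion A: one dispatch/combine round for both edge-level outputs.
    x_edge = torch.cat([n_i, n_j, e_ij], dim=-1)
    y = layer.edge_fused_moe(x_edge)                     # [..., H*d_n + d_e]
    y_node_msg, y_edge_self = torch.split(y, [H * d_n, d_e], dim=-1)

    # Fusion B: one dispatch/combine round for both angle-level outputs.
    z = layer.angle_fused_moe(x_angle)                   # [..., d_e + d_a]
    y_edge_from_angle, y_angle_self = torch.split(z, [d_e, d_a], dim=-1)

    return y_node_msg, y_edge_self, y_edge_from_angle, y_angle_self
\end{verbatim}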
\subsection{Strategy 2: Further Fusion via Shared Projection (Advanced)}

\subsubsection{Motivation}

Strategy 1 only fuses MLPs with identical inputs. Can we go further? The key observation: MLPs \#1 (node self) and \#2 (node sym) both \textbf{output to the same node update list} and operate at the \textbf{same tensor level} (node, $[N_b, N_{\text{loc}}, \cdot]$). If we can unify their inputs, they can be fused.

\subsubsection{Approach: Pre-projection + Concatenation}

Define a unified node input by concatenating the self-embedding and symmetrized features:
\begin{equation}
\x_{\text{node}} = [n_i;\; \text{GRRG}(n_i, e_{ij}, h_{ij})] \in \R^{d_n + kd_n + kd_e}
\end{equation}
Then a single fused MLP replaces both \#1 and \#2:
\begin{equation}
[\underbrace{y_{\text{self}}}_{\R^{d_n}};\; \underbrace{y_{\text{sym}}}_{\R^{d_n}}] = \act\!\left(\W_{\text{node}} \cdot \x_{\text{node}} + \bb_{\text{node}}\right)
\end{equation}

\textbf{Expressiveness analysis:} The original \#1 only sees $n_i$, while \#2 only sees the GRRG features. The fused layer sees both, which is \textbf{strictly more expressive} --- each output can now attend to both self and symmetrized features. However, this changes the model architecture (not just an implementation optimization), so it requires retraining.
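The ``strictly more expressive'' claim can be made concrete: a block-diagonal fused weight reproduces the original pair exactly, so the original function class is contained in the fused one. The sketch below illustrates this with placeholder dimensions (biases omitted for brevity).

\begin{verbatim}
import torch

# Illustrative dimensions; k*d_n + k*d_e is the GRRG feature width.
d_n, d_e, k = 128, 64, 4
d_sym = k * d_n + k * d_e

torch.manual_seed(0)
W1 = torch.randn(d_n, d_n)      # original node_self weights (#1)
W2 = torch.randn(d_n, d_sym)    # original node_sym weights (#2)

# Block-diagonal fused weight: each output block still sees only its
# original input slice, so the original function is a special case of
# the fused layer (hence ">= original" expressiveness, not equivalence).
W_node = torch.zeros(2 * d_n, d_n + d_sym)
W_node[:d_n, :d_n] = W1
W_node[d_n:, d_n:] = W2

x_node = torch.cat([torch.randn(32, d_n), torch.randn(32, d_sym)], dim=-1)
y = torch.sigmoid(x_node @ W_node.T)
y_self, y_sym = torch.split(y, [d_n, d_n], dim=-1)

assert torch.allclose(y_self, torch.sigmoid(x_node[:, :d_n] @ W1.T))
assert torch.allclose(y_sym, torch.sigmoid(x_node[:, d_n:] @ W2.T))
\end{verbatim}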
\subsubsection{Result: 7 $\to$ 4 MoE Layers}

\begin{table}[h]
\centering
\begin{tabular}{clcl}
\toprule
\# & Fused Name & Tensor Level & Original MLPs \\
\midrule
1' & \texttt{node\_fused\_linear} & Node & \#1 + \#2 \\
2' & \texttt{edge\_fused\_linear} & Edge & \#3 + \#4 \\
3' & \texttt{angle\_fused\_linear} & Angle & \#5 + \#7 \\
4 & \texttt{edge\_angle\_linear2} & Edge & \#6 \\
\bottomrule
\end{tabular}
\caption{After Strategy 2 fusion: 4 MoE layers.}
\end{table}

Communication reduction: $\frac{7-4}{7} = 42.9\%$.
\subsection{Strategy 3: Aggressive Fusion by Tensor Level (Maximum Reduction)}

\subsubsection{Motivation}

In EP, the communication cost is dominated by the \textbf{number of dispatch/combine rounds} rather than the data volume per round: the per-token feature dimensions are small, so the fixed launch and synchronization latency of each All-to-All dominates. Therefore, the optimal strategy is to minimize the \textbf{number of distinct MoE calls}, ideally one per tensor level.

\subsubsection{Approach: One MoE per Level}

\paragraph{Node level:} Merge \#1 + \#2 as in Strategy 2.

\paragraph{Edge level:} Merge \#3 + \#4 + \#6. This requires restructuring the angle$\to$edge pipeline. Currently:
\begin{align}
y_5 &= \text{MLP}_5(\x_{\text{angle}}) \quad \text{(angle$\to$edge, per-angle)} \\
z &= \text{reduce}(y_5) \quad \text{(sum over angles $\to$ per-edge)} \\
y_6 &= \text{MLP}_6(z) \quad \text{(refine, per-edge)}
\end{align}
MLP \#6 operates on the \textit{reduced} angle output, not on the raw \texttt{edge\_info}. To fuse \#6 with \#3 + \#4, we need to concatenate the reduced angle output with \texttt{edge\_info}:
\begin{equation}
\x_{\text{edge}}^{+} = [n_i;\; n_j;\; e_{ij};\; z_{\text{angle}\to\text{edge}}] \in \R^{2d_n + 2d_e}
\end{equation}
\begin{equation}
[\underbrace{y_{\text{node}}}_{\R^{Hd_n}};\; \underbrace{y_{\text{edge\_self}}}_{\R^{d_e}};\; \underbrace{y_{\text{edge\_angle}}}_{\R^{d_e}}] = \act\!\left(\W_{\text{edge}}^{+} \cdot \x_{\text{edge}}^{+} + \bb\right)
\end{equation}

\textbf{Caveat:} This changes the computation order --- the angle$\to$edge reduction must happen \textit{before} the edge MoE call, which means the angle MoE must complete first. This creates a \textbf{sequential dependency} that may limit pipelining.

\paragraph{Angle level:} MLPs \#5 + \#7 are fused as in Strategy 1.
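A rough sketch of the resulting per-level ordering follows. All module and helper names (\texttt{node\_moe}, \texttt{edge\_moe}, \texttt{angle\_moe}, \texttt{reduce\_angles\_to\_edges}) are hypothetical and shown only to make the sequential dependency explicit.

\begin{verbatim}
import torch

def strategy3_forward(layer, x_node, x_angle, n_i, n_j, e_ij):
    """Sketch of Strategy 3: exactly one MoE call per tensor level.

    Note the sequential dependency: the angle MoE must finish before the
    reduced angle->edge message can be concatenated into the edge input.
    """
    d_n, d_e, d_a, H = layer.d_n, layer.d_e, layer.d_a, layer.H

    # 1) Angle level: fused #5 + #7 (one dispatch/combine round).
    y = layer.angle_moe(x_angle)                          # [..., d_e + d_a]
    y_edge_from_angle, y_angle_self = torch.split(y, [d_e, d_a], dim=-1)
    z = layer.reduce_angles_to_edges(y_edge_from_angle)   # sum over angles

    # 2) Edge level: fused #3 + #4 + #6, consuming the reduced message.
    x_edge_plus = torch.cat([n_i, n_j, e_ij, z], dim=-1)  # [..., 2d_n + 2d_e]
    w = layer.edge_moe(x_edge_plus)                       # [..., H*d_n + 2*d_e]
    y_node_msg, y_edge_self, y_edge_angle = torch.split(
        w, [H * d_n, d_e, d_e], dim=-1)

    # 3) Node level: fused #1 + #2.
    y_self, y_sym = torch.split(layer.node_moe(x_node), [d_n, d_n], dim=-1)

    return y_self, y_sym, y_node_msg, y_edge_self, y_edge_angle, y_angle_self
\end{verbatim}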
\subsubsection{Result: 7 $\to$ 3 MoE Layers}

\begin{table}[h]
\centering
\begin{tabular}{clcl}
\toprule
\# & Fused Name & Tensor Level & Original MLPs \\
\midrule
1' & \texttt{node\_moe} & Node & \#1 + \#2 \\
2' & \texttt{edge\_moe} & Edge & \#3 + \#4 + \#6 \\
3' & \texttt{angle\_moe} & Angle & \#5 + \#7 \\
\bottomrule
\end{tabular}
\caption{After Strategy 3 fusion: 3 MoE layers (one per tensor level).}
\end{table}

Communication reduction: $\frac{7-3}{7} = 57.1\%$.
\section{Communication Cost Analysis}

\subsection{EP Communication Model}

In Expert Parallelism with $P$ GPUs and $E$ experts ($E/P$ experts per GPU), each MoE layer requires:
\begin{itemize}
\item \textbf{Dispatch:} All-to-All to send tokens to the expert-owning GPUs. Cost: $O\!\left(\frac{B \cdot d_{\text{in}}}{P}\right)$
\item \textbf{Combine:} All-to-All to collect expert outputs. Cost: $O\!\left(\frac{B \cdot d_{\text{out}}}{P}\right)$
\end{itemize}
where $B$ is the token count (atoms, edges, or angles).

\subsection{Token Counts by Level}

For a system with $N$ atoms, $M$ edges (neighbors), and $A$ angles:
\begin{align}
B_{\text{node}} &= N_b \cdot N_{\text{loc}} \approx N \\
B_{\text{edge}} &= N_b \cdot N_{\text{loc}} \cdot N_{\text{nei}} \approx N \cdot \bar{M} \\
B_{\text{angle}} &= N_b \cdot N_{\text{loc}} \cdot S_a^2 \approx N \cdot \bar{S}_a^2
\end{align}
where typically $\bar{M} \approx 120$ and $\bar{S}_a \approx 30$. Thus:
\begin{equation}
B_{\text{angle}} \gg B_{\text{edge}} \gg B_{\text{node}}
\end{equation}

\subsection{Comparative Cost}

\begin{table}[h]
\centering
\begin{tabular}{lccc}
\toprule
Strategy & \# MoE layers & A2A rounds/layer & Total A2A/model \\
\midrule
Baseline (no fusion) & 7 & $2 \times 7 = 14$ & $84$ \\
Strategy 1 (shared input) & 5 & $2 \times 5 = 10$ & $60$ \\
Strategy 2 (+ node fusion) & 4 & $2 \times 4 = 8$ & $48$ \\
Strategy 3 (per-level) & 3 & $2 \times 3 = 6$ & $36$ \\
\bottomrule
\end{tabular}
\caption{Communication rounds for $L=6$ layers. Each round = one All-to-All.}
\end{table}

\subsection{Weighted Cost by Token Count}

Not all MoE layers have equal communication cost. The angle-level MoE dominates:
\begin{equation}
C_{\text{total}} = \sum_{l=1}^{L} \sum_{m \in \text{MoE}_l} 2 \cdot B_m \cdot d_m
\end{equation}

\textbf{Key insight:} Fusing the angle-level MLPs (\#5 + \#7) provides the largest absolute communication savings because $B_{\text{angle}}$ is the largest token count. Strategy 1's Fusion B alone saves roughly half of the angle-level communication, since the large shared input $\x_{\text{angle}}$ is dispatched once instead of twice.
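This back-of-the-envelope model is easy to script. In the sketch below, the token counts come from the typical $\bar{M}$ and $\bar{S}_a$ quoted above, while the feature dimensions are illustrative placeholders (compression of the primed dimensions is ignored), not measured DPA3 settings.

\begin{verbatim}
# Rough A2A cost model for one RepFlow layer (illustrative numbers only).
N, M_bar, S_a = 1_000, 120, 30          # atoms, neighbors, angle neighbors
B = {"node": N, "edge": N * M_bar, "angle": N * S_a**2}

def cost(calls):
    """Total dispatched + combined elements over all MoE calls."""
    return sum(B[level] * (d_in + d_out) for level, d_in, d_out in calls)

# Placeholder dims; primed (compressed) dims are treated as uncompressed.
d_n, d_e, d_a, H, k = 128, 64, 32, 2, 4
baseline = [
    ("node", d_n, d_n),                        # 1 node_self_mlp
    ("node", k * d_n + k * d_e, d_n),          # 2 node_sym_linear
    ("edge", 2 * d_n + d_e, H * d_n),          # 3 node_edge_linear
    ("edge", 2 * d_n + d_e, d_e),              # 4 edge_self_linear
    ("angle", d_a + d_n + 2 * d_e, d_e),       # 5 edge_angle_linear1
    ("edge", d_e, d_e),                        # 6 edge_angle_linear2
    ("angle", d_a + d_n + 2 * d_e, d_a),       # 7 angle_self_linear
]
strategy1 = [
    baseline[0], baseline[1],
    ("edge", 2 * d_n + d_e, H * d_n + d_e),    # 3+4 fused
    ("angle", d_a + d_n + 2 * d_e, d_e + d_a), # 5+7 fused
    baseline[5],
]

print("A2A rounds per layer:", 2 * len(baseline), "->", 2 * len(strategy1))
print("weighted volume ratio:", cost(strategy1) / cost(baseline))
\end{verbatim}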
\section{Compatibility with optim\_update}

The \texttt{optim\_update} optimization decomposes the weight matrix by input components. For the edge-info MLP:
\begin{equation}
\W \cdot [n_i; n_j; e_{ij}] + \bb = \W_{n_i} n_i + \W_{n_j} n_j + \W_e e_{ij} + \bb
\end{equation}
This allows computing each term independently and broadcasting, avoiding the expensive explicit concatenation.

For fused layers, the decomposition extends naturally:
\begin{equation}
\W_{\text{fused}} \cdot [n_i; n_j; e_{ij}] + \bb_{\text{fused}} = \begin{bmatrix} \W_{n_i}^{(3)} \\ \W_{n_i}^{(4)} \end{bmatrix} n_i + \begin{bmatrix} \W_{n_j}^{(3)} \\ \W_{n_j}^{(4)} \end{bmatrix} n_j + \begin{bmatrix} \W_e^{(3)} \\ \W_e^{(4)} \end{bmatrix} e_{ij} + \begin{bmatrix} \bb^{(3)} \\ \bb^{(4)} \end{bmatrix}
\end{equation}
The split dimensions change, but the decomposition structure is preserved. Implementation requires adjusting the \texttt{torch.split} sizes in \texttt{optim\_edge\_update} and \texttt{optim\_angle\_update}.
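A minimal numerical check of the fused decomposition is shown below, with placeholder dimensions; it illustrates the input-wise split of the fused weight but does not reproduce the actual \texttt{optim\_edge\_update} implementation.

\begin{verbatim}
import torch

# Check that the input-wise weight decomposition carries over to a fused
# weight matrix (dims are placeholders, not DPA3 values).
d_n, d_e, H = 128, 64, 2
d_out = H * d_n + d_e                         # fused output of MLPs #3 + #4

torch.manual_seed(0)
W_fused = torch.randn(d_out, 2 * d_n + d_e)
b_fused = torch.randn(d_out)

# Split the fused weight by input component instead of concatenating inputs.
W_ni, W_nj, W_e = torch.split(W_fused, [d_n, d_n, d_e], dim=1)

n_i = torch.randn(16, d_n)
n_j = torch.randn(16, d_n)
e_ij = torch.randn(16, d_e)

dense = torch.cat([n_i, n_j, e_ij], dim=-1) @ W_fused.T + b_fused
decomposed = n_i @ W_ni.T + n_j @ W_nj.T + e_ij @ W_e.T + b_fused

assert torch.allclose(dense, decomposed, atol=1e-5)
\end{verbatim}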
\section{Recommendations}

\subsection{Immediate (Low Risk)}

\textbf{Implement Strategy 1} --- fuse shared-input MLPs:
\begin{itemize}
\item Fusion A: \texttt{node\_edge\_linear} + \texttt{edge\_self\_linear} $\to$ \texttt{edge\_fused\_linear}
\item Fusion B: \texttt{edge\_angle\_linear1} + \texttt{angle\_self\_linear} $\to$ \texttt{angle\_fused\_linear}
\end{itemize}
This is \textbf{mathematically equivalent} to the current architecture (zero expressiveness change), requires no retraining of existing models, and reduces MoE layers from 7 to 5 (28.6\% fewer A2A rounds).
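Because the fusion is an exact re-parameterisation, existing checkpoints can in principle be converted offline by stacking the original weights. The sketch below does this for the edge-level pair; the state-dict key names are hypothetical placeholders and would need to match the real DPA3 parameter names.

\begin{verbatim}
import torch

def fuse_edge_linears(state_dict, prefix="repflow.layers.0."):
    """Sketch: merge node_edge_linear (#3) and edge_self_linear (#4) weights
    into a single edge_fused_linear, stacking along the output dimension.

    The parameter names used here are hypothetical placeholders.
    """
    w3 = state_dict[prefix + "node_edge_linear.weight"]  # [H*d_n, 2*d_n+d_e]
    b3 = state_dict[prefix + "node_edge_linear.bias"]
    w4 = state_dict[prefix + "edge_self_linear.weight"]  # [d_e, 2*d_n+d_e]
    b4 = state_dict[prefix + "edge_self_linear.bias"]

    state_dict[prefix + "edge_fused_linear.weight"] = torch.cat([w3, w4], dim=0)
    state_dict[prefix + "edge_fused_linear.bias"] = torch.cat([b3, b4], dim=0)
    for name in ("node_edge_linear", "edge_self_linear"):
        for suffix in (".weight", ".bias"):
            del state_dict[prefix + name + suffix]
    return state_dict
\end{verbatim}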
\subsection{Medium Term (Requires Retraining)}

\textbf{Implement Strategy 2} --- additionally fuse the node-level MLPs:
\begin{itemize}
\item Fusion C: \texttt{node\_self\_mlp} + \texttt{node\_sym\_linear} $\to$ \texttt{node\_fused\_linear}
\end{itemize}
This changes the architecture (each output sees both self and sym features) but is \textbf{more expressive}. It reduces the count to 4 MoE layers (42.9\% fewer A2A rounds) and requires an ablation study to verify that there is no accuracy regression.

\subsection{Long Term (Architecture Change)}

\textbf{Implement Strategy 3} --- one MoE per tensor level:
\begin{itemize}
\item Absorb \texttt{edge\_angle\_linear2} into the edge-level fused MoE
\item Restructure the angle$\to$edge pipeline accordingly
\end{itemize}
This reduces the count to 3 MoE layers (57.1\% fewer A2A rounds). It is the most invasive change and requires careful validation.

\subsection{Summary}

\begin{table}[h]
\centering
\begin{tabular}{lcccc}
\toprule
Strategy & MoE layers & A2A reduction & Expressiveness & Retrain? \\
\midrule
Baseline & 7 & --- & --- & No \\
Strategy 1 & 5 & 28.6\% & Identical & No \\
Strategy 2 & 4 & 42.9\% & $\geq$ original & Yes \\
Strategy 3 & 3 & 57.1\% & $\geq$ original & Yes \\
\bottomrule
\end{tabular}
\caption{Summary of fusion strategies.}
\end{table}

\end{document}
