
Commit 39bb8f1

runwangdl and claude committed
fix(redmule): lower bias-less Gemm to MatMul so CCT_train tiling stops crashing
The Siracusa+RedMulE CCT_train CI job (added in b682739) crashed during tiling with KeyError 'C' inside RedmuleGEMMTileConstraint -- the constraint unconditionally reads parseDict['C'] but GEMMRedmuleParser.parseNodeCtxt only populates it when len(node.inputs) == 3 and noBiasHoisting is True (the default). Backward-pass codegen for CCT (GradFusedMatMul rewrites) emits a flurry of 2-input ONNX Gemm nodes (alpha=1, no bias), which match the binding but never get a 'C' field -- hence the lookup blows up.

A bias-less Gemm with alpha=1 is mathematically just MatMul, and the Redmule platform already routes ONNX MatMul through MatMulRedmuleMapper / RedmuleMatMulTilingReadyBindings (no C operand needed). So instead of papering over the parser, lower the op:

- Add RedMuleBiaslessGemmToMatMulPass in Targets/Redmule/TopologyOptimizationPasses/Passes.py. It matches Gemm nodes (the 2-input pattern reused from RedMuleGEMMTransposePass), guards on len(inputs) == 2 and alpha == 1, materializes any transA/transB (constants get folded, variables get a Transpose appended via the same _appendTranspose helper the transpose pass already uses), then rewrites op="MatMul" and clears attrs.
- Wire it into RedmuleDeployer.loweringOptimizer.passes BEFORE RedMuleGEMMTransposePass so the latter only ever sees real (3-input) Gemms; otherwise it would write transA/transB=0 onto what we just rewrote into a MatMul, and the stale Gemm op would still hit the bias-required tile constraint.

3-input Gemms (forward CCT FCs, the existing testFloatGEMM/testFloatGEMMtransB kernel fixtures) are untouched: the new pass returns the graph unchanged when len(inputs) != 2, and RedMuleGEMMTransposePass continues to see them as before.

Local validation: pytest --collect-only -m "siracusa_redmule_tiled" still yields the same 4 cases (3 kernel + 1 training); module import of Deployer + the new pass class both succeed. Real run deferred to CI.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
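The failure mode described in the commit message can be reproduced in miniature. The sketch below is not Deeploy code: `parse_gemm_inputs` and `tile_constraint_operands` are hypothetical, heavily simplified stand-ins for `GEMMRedmuleParser.parseNodeCtxt` and `RedmuleGEMMTileConstraint`, and they assume only the behaviour the message states (the parser records `'C'` only for 3-input nodes, the constraint reads it unconditionally).

```python
# Hypothetical stand-ins; the real Deeploy parser and tile constraint are more involved.


def parse_gemm_inputs(input_names, no_bias_hoisting=True):
    """Mimic the parser: 'C' is only recorded for 3-input Gemm nodes."""
    parse_dict = {"A": input_names[0], "B": input_names[1]}
    if len(input_names) == 3 and no_bias_hoisting:
        parse_dict["C"] = input_names[2]
    return parse_dict


def tile_constraint_operands(parse_dict):
    """Mimic the tile constraint: 'C' is read unconditionally."""
    return [parse_dict["A"], parse_dict["B"], parse_dict["C"]]


# Forward-pass style Gemm with a bias operand: the lookup succeeds.
with_bias = tile_constraint_operands(parse_gemm_inputs(["x", "w", "b"]))

# Backward-pass style Gemm without a bias operand: the lookup raises KeyError.
try:
    tile_constraint_operands(parse_gemm_inputs(["grad", "w"]))
    failure = None
except KeyError as err:
    failure = err.args[0]

print(with_bias)  # ['x', 'w', 'b']
print(failure)    # C
```

The 2-input node passes parsing without complaint, so the mismatch only surfaces later, when the constraint is built during tiling.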
1 parent b682739 commit 39bb8f1

2 files changed

Lines changed: 86 additions & 1 deletion


Deeploy/Targets/Redmule/Deployer.py

Lines changed: 5 additions & 1 deletion
@@ -31,7 +31,7 @@
 from Deeploy.DeeployTypes import DeploymentPlatform, TopologyOptimizer
 from Deeploy.Targets.PULPOpen.Deployer import PULPDeployer
 from Deeploy.Targets.Redmule.TopologyOptimizationPasses.Passes import RedMuleAdjustWeightMemoryLayoutPass, \
-    RedMuleGEMMTransposePass
+    RedMuleBiaslessGemmToMatMulPass, RedMuleGEMMTransposePass
 
 
 class RedmuleDeployer(PULPDeployer):
@@ -51,5 +51,9 @@ def __init__(self,
 
         self.loweringOptimizer.passes += [
             RedMuleAdjustWeightMemoryLayoutPass("Redmule"),
+            # Lower bias-less Gemm (e.g. backward GradFusedMatMul nodes in CCT
+            # training) to MatMul before GEMMTransposePass touches them; the
+            # bias-required tile constraint would otherwise crash.
+            RedMuleBiaslessGemmToMatMulPass("Redmule"),
             RedMuleGEMMTransposePass("Redmule")
         ]
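To illustrate why the position in the pass list matters, here is a toy model of sequential pass application. Everything in it is an assumption made for illustration: nodes are plain dicts rather than graph nodes, `lower_biasless_gemm_pass` and `gemm_transpose_pass` are hypothetical stand-ins that only mimic which nodes each real pass would act on, and no transposition is actually performed.

```python
# Toy model only: plain dicts instead of graph nodes, hypothetical pass functions.


def lower_biasless_gemm_pass(nodes):
    """Stand-in for the new pass: 2-input Gemm with alpha == 1 becomes MatMul."""
    for node in nodes:
        is_biasless = node["op"] == "Gemm" and len(node["inputs"]) == 2
        if is_biasless and node["attrs"].get("alpha", 1.0) == 1.0:
            node["op"] = "MatMul"
            node["attrs"] = {}
    return nodes


seen_by_transpose = []


def gemm_transpose_pass(nodes):
    """Stand-in for the transpose pass: records every node that is still a Gemm."""
    for node in nodes:
        if node["op"] == "Gemm":
            seen_by_transpose.append(node["name"])
    return nodes


def run_passes(nodes, passes):
    """Apply passes in list order, as a lowering optimizer would."""
    for apply_pass in passes:
        nodes = apply_pass(nodes)
    return nodes


graph = [
    {"name": "fc_forward", "op": "Gemm", "inputs": ["x", "w", "b"], "attrs": {"transB": 1}},
    {"name": "fc_backward", "op": "Gemm", "inputs": ["grad", "w"], "attrs": {"alpha": 1.0}},
]

result = run_passes(graph, [lower_biasless_gemm_pass, gemm_transpose_pass])
ops = {node["name"]: node["op"] for node in result}

print(ops)                # {'fc_forward': 'Gemm', 'fc_backward': 'MatMul'}
print(seen_by_transpose)  # ['fc_forward']
```

With this ordering the transpose stand-in only ever encounters the 3-input node, which matches the intent stated in the comment above the new list entry.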

Deeploy/Targets/Redmule/TopologyOptimizationPasses/Passes.py

Lines changed: 81 additions & 0 deletions
@@ -148,3 +148,84 @@ def __init__(self, redmuleEngineName: str):
         super().__init__(pattern = pattern,
                          replacement_fn = _redmule_gemm_transpose_fun,
                          name = "_REDMULE_GEMM_TRANSPOSE_PASS")
+
+
+def _redmule_biasless_gemm_to_matmul_fun(graph: gs.Graph, match: Match, name: str):
+    """Rewrite a 2-input ONNX Gemm (no C / bias) into an equivalent MatMul.
+
+    Backward-pass codegen (e.g. the ``GradFusedMatMul`` rewrites that fall out
+    of the CCT training graph) emits ``Gemm`` nodes with only A and B and no
+    bias, which the ``GEMMRedmuleParser`` accepts but for which the
+    ``RedmuleGEMMTileConstraint`` then crashes (KeyError on ``parseDict['C']``).
+    A bias-less Gemm with alpha=1 is mathematically just a MatMul, and the
+    Redmule platform already maps ONNX ``MatMul`` to a kernel that doesn't
+    expect a C operand -- so we lower it here.
+
+    transA / transB are materialized as explicit ``Transpose`` nodes (or, for
+    constant operands, folded into the constant) before the op is rewritten,
+    because ``MatMul`` has no equivalent attributes.
+    """
+    gemm_node = list(match.nodes_map.values())[0]
+
+    # Pattern matcher may match Gemms with 3 inputs too; act only on the
+    # bias-less subset.
+    if len(gemm_node.inputs) != 2:
+        return graph
+
+    # Anything other than alpha=1 cannot be expressed as a plain MatMul.
+    if gemm_node.attrs.get('alpha', 1.0) != 1.0:
+        return graph
+
+    transA = gemm_node.attrs.get('transA', 0)
+    transB = gemm_node.attrs.get('transB', 0)
+
+    for inputIdx, transFlag in ((0, transA), (1, transB)):
+        if not transFlag:
+            continue
+        operand = gemm_node.inputs[inputIdx]
+        if isinstance(operand, gs.Constant):
+            if len(operand.values.shape) > 2:
+                perm = list(range(len(operand.values.shape)))
+                perm[-1], perm[-2] = perm[-2], perm[-1]
+                operand.values = np.transpose(operand.values, perm)
+            else:
+                operand.values = np.transpose(operand.values)
+        else:
+            perm = list(range(len(operand.shape)))
+            perm[-1], perm[-2] = perm[-2], perm[-1]
+            anchorTransposeNode = _appendTranspose(operand, gemm_node, perm)
+            graph.nodes.append(anchorTransposeNode)
+
+    gemm_node.op = "MatMul"
+    gemm_node.attrs.clear()
+
+    return graph
+
+
+@contextagnostic
+class RedMuleBiaslessGemmToMatMulPass(ReplaceSequentialPatternPass):
+    """Lower bias-less (2-input) ONNX Gemm nodes to MatMul on the Redmule path.
+
+    Must run before RedMuleGEMMTransposePass so the latter only sees the
+    real (3-input) Gemm nodes; otherwise its replacement_fn would write
+    ``transA`` / ``transB`` back to 0 on what is now a MatMul, and a stale
+    ``Gemm`` op type would still hit the bias-required tile constraint.
+    """
+
+    def __init__(self, redmuleEngineName: str):
+        pattern = gs.Graph()
+
+        input_a = gs.Variable(name = "input_a")
+        input_b = gs.Variable(name = "input_b")
+
+        gemm_output = pattern.layer(op = "Gemm",
+                                    name = "gemm_node",
+                                    inputs = [input_a, input_b],
+                                    outputs = ["gemm_output"])
+
+        pattern.inputs = [input_a, input_b]
+        pattern.outputs = [gemm_output]
+
+        super().__init__(pattern = pattern,
+                         replacement_fn = _redmule_biasless_gemm_to_matmul_fun,
+                         name = "_REDMULE_BIASLESS_GEMM_TO_MATMUL_PASS")
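The equivalence the pass relies on can be checked numerically. The following is a minimal pure-Python sketch, independent of Deeploy, NumPy, and ONNX tooling: `gemm` implements the ONNX Gemm formula for 2-D operands only (and, as a simplification, expects any `c` to already have the full output shape), while `transpose`, `matmul`, and `swap_last_two_axes_perm` are hypothetical helper names introduced here. The last helper mirrors the permutation the pass builds for operands of rank 2 or higher.

```python
def transpose(m):
    """Transpose a 2-D list-of-lists matrix."""
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    """Plain 2-D matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    """ONNX Gemm for 2-D inputs: alpha * A' @ B' + beta * C (C optional)."""
    a = transpose(a) if trans_a else a
    b = transpose(b) if trans_b else b
    out = [[alpha * v for v in row] for row in matmul(a, b)]
    if c is not None:
        out = [[v + beta * cv for v, cv in zip(row, crow)] for row, crow in zip(out, c)]
    return out


def swap_last_two_axes_perm(rank):
    """Permutation that swaps the two innermost axes (rank must be >= 2)."""
    perm = list(range(rank))
    perm[-1], perm[-2] = perm[-2], perm[-1]
    return perm


a = [[1, 2], [3, 4], [5, 6]]  # stored as (K, M), consumed with transA=1
b = [[1, 0, 1], [0, 1, 1]]    # stored as (N, K), consumed with transB=1

# Bias-less Gemm with alpha=1 and both transpose flags set ...
via_gemm = gemm(a, b, trans_a=1, trans_b=1)
# ... equals MatMul on operands whose transposes were materialized beforehand.
via_matmul = matmul(transpose(a), transpose(b))

print(via_gemm == via_matmul)      # True
print(via_matmul)                  # [[6, 8], [8, 10]]
print(swap_last_two_axes_perm(3))  # [0, 2, 1]
```

With any `alpha` other than 1 the two results differ by that scale factor, which is why the replacement function returns the graph unchanged in that case.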
