
Commit 1782a88

runwangdl and claude committed
fix(redmule+upstream-transpose): unblock CCT_train codegen end-to-end
This commit replaces 39bb8f1's experimental Gemm->MatMul lowering pass (which unblocked the original KeyError 'C' but exposed a deeper Transpose rank-mismatch bug downstream) with two smaller, locally-verified fixes:

1) Hoist a properly-shaped zero C tensor in GEMMRedmuleParser when an ONNX Gemm has only A and B (e.g. backward GradFusedMatMul rewrites in CCT_train). Fixes for the hoist path:

- GEMMRedmuleParser.__init__ used to set self.noBiasHoisting *before* calling super().__init__(), but MatMulParser.__init__ also writes self.noBiasHoisting from its own default of True, so the caller's flag was silently clobbered. Reverse the order and forward the kwarg.
- The hoist used to allocate a 1-element np.zeros((1)) scalar, which can never satisfy RedmuleGEMMTileConstraint's "C dim equals output dim" assertion. Allocate a zero array whose shape matches node.outputs[0].shape instead.
- Pass _type=PointerClass(float32_t) to ctxt.hoistConstant so the buffer is type-annotated up front. Without it, MemoryScheduler.getConstantTensorOffset later trips an AttributeError on the un-annotated buffer.
- Append the hoisted Constant to node.inputs so the tiler picks it up via its node.inputs + node.outputs walk, and register the Gemm as a user via newCtxt.addUser so the MemoryConstraintFlow kill-set assertion (which walks _users) finds a consumer.
- Engine.GEMMMRedmuleMapper now instantiates with noBiasHoisting=False so the hoist path is actually taken.

Drop the RedMuleBiaslessGemmToMatMulPass class (added in 39bb8f1) and its Deployer registration: the parser-side hoist is the smaller fix and side-steps the MatMul broadcasting issue entirely.

2) Fix Generic/TransposeTileConstraint and PULPOpen/TransposeTemplate to use a *spatial-view* interpretation of perm. When MatMulLayer.computeShapes broadens an already-existing tensor that is simultaneously a forward MatMul B input *and* an input to a downstream non-broadening consumer (Gemm/Transpose), data_in and data_out of that downstream Transpose can end up with different ranks. Both addGeometricalConstraint and serializeTilingSolution previously assumed len(perm) == data_in_rank == data_out_rank; they now offset their shape lookups by len(shape) - len(perm) so the perm targets the trailing spatial dims in either tensor. PULPTransposeTemplate's alignToContext gets the same treatment for its dimLen_<idx> lookup and parallelDim selection. Already-aligned cases (the existing kernel fixtures testFloatGEMM / testFloatGEMMtransB) compute offsets of 0 and behave exactly as before.

Verified locally on Models/Training/CCT/cct_train: testMVPTraining.py and testMVPOptimizer.py both exit 0 on Siracusa_w_redmule, producing a ~7.7 MB TrainingNetwork.c and a matching OptimizerNetwork.c. C compilation + GVSoC simulation still need to be validated on CI (the runwangdl/gvsoc fork can't be run locally in the agent container).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 39bb8f1 commit 1782a88

6 files changed

Lines changed: 65 additions & 97 deletions
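For orientation before the per-file diffs: the rank mismatch that fix 2 addresses can be sketched in a few lines of plain Python. All shapes here are invented for illustration and are not taken from the CCT graph.

# Hypothetical repro of the rank mismatch (shapes invented for illustration).
# A tensor T of shape (64, 96) feeds both a broadening MatMul (whose shape
# inference left-pads T to rank 3) and a non-broadening Transpose.
perm = [1, 0]
data_in_shape = [1, 64, 96]   # T after MatMulLayer-style left-padding
data_out_shape = [96, 64]     # Transpose output, inferred earlier at rank 2

# The pre-fix code assumed all three ranks agree; here they do not:
assert len(data_in_shape) != len(data_out_shape)
assert len(data_in_shape) != len(perm)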


Deeploy/Targets/Generic/TileConstraints/TransposeTileConstraint.py

Lines changed: 27 additions & 5 deletions
@@ -24,15 +24,34 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
         inputBufferName = parseDict['data_in']
         outputBufferName = parseDict['data_out']

+        inputShape = ctxt.lookup(inputBufferName).shape
+        outputShape = ctxt.lookup(outputBufferName).shape
+        perm = parseDict["perm"]
+
+        # Spatial-view interpretation of the perm: it operates on the last
+        # len(perm) dims of data_in and the last len(perm) dims of data_out.
+        # MatMulLayer.computeShapes can left-pad the rank of one side without
+        # touching the other when the same gs.Variable is shared between a
+        # broadening (MatMul) and a non-broadening (Gemm/Transpose) consumer,
+        # so the constraint indexing must offset by the per-side leading-batch
+        # depth rather than assume rank == len(perm) == rank_other. When all
+        # ranks already match, offsets are 0 and behavior is unchanged.
+        inputOffset = len(inputShape) - len(perm)
+        outputOffset = len(outputShape) - len(perm)
+        assert inputOffset >= 0 and outputOffset >= 0, (
+            f"Transpose perm {perm} is longer than tensor ranks "
+            f"data_in={inputShape}, data_out={outputShape}")
+
         # Add I/O dimensions to the model as variables
         for bufferName in [inputBufferName, outputBufferName]:
             tilerModel.addTensorDimToModel(ctxt, bufferName)

-        # Map output dims to inputs dims
-        for idx, perm_idx in enumerate(parseDict["perm"]):
+        # Map output spatial dims to input spatial dims via perm.
+        for idx, perm_idx in enumerate(perm):
             tilerModel.addConstraint(
-                tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = idx) == tilerModel.getTensorDimVar(
-                    tensorName = inputBufferName, dimIdx = perm_idx))
+                tilerModel.getTensorDimVar(tensorName = outputBufferName,
+                                           dimIdx = outputOffset + idx) == tilerModel.getTensorDimVar(
+                    tensorName = inputBufferName, dimIdx = inputOffset + perm_idx))

         return tilerModel

@@ -50,7 +69,10 @@ def serializeTilingSolution(
         replacementTypes = {}
         replacements: Dict[str, List[int]] = {}

-        numDims = len(ctxt.lookup(operatorRepresentation['data_in']).shape)
+        # Match the spatial-view interpretation in addGeometricalConstraint:
+        # only the last len(perm) dims of data_in are actually transposed,
+        # so emit exactly len(perm) dimLen_<i> replacement variables.
+        numDims = len(operatorRepresentation['perm'])

         for dim in range(numDims):
             replacementTypes[f"dimLen_{dim}"] = PointerClass(uint16_t)

Deeploy/Targets/PULPOpen/Templates/TransposeTemplate.py

Lines changed: 14 additions & 3 deletions
@@ -65,16 +65,27 @@ def alignToContext(self, ctxt: NetworkContext,
         fRep['accessStr'] = accessStr
         fRep['data_out_shape'] = data_out_shape

-        parallelDims = [idx for idx, dim in enumerate(data_out_shape) if dim >= 8]
+        # Spatial-view: perm targets the last len(perm) dims of data_in. When
+        # data_in has been left-padded (e.g. by MatMulLayer.computeShapes
+        # broadening a shared upstream Transpose output), offset the
+        # data_in_shape lookup so dimLen_<idx> reflects the actual
+        # transposed dim rather than a leading batch placeholder. Same
+        # for data_out_shape -- parallelDim must index within the spatial
+        # view since the per-tile for-loop count comes from len(perm).
+        dataInOffset = len(data_in_shape) - len(perm)
+        dataOutOffset = len(data_out_shape) - len(perm)
+        spatialOutShape = list(data_out_shape[dataOutOffset:])
+
+        parallelDims = [idx for idx, dim in enumerate(spatialOutShape) if dim >= 8]
         if len(parallelDims) > 0:
             parallelDim = parallelDims[0]
         else:
-            parallelDim = data_out_shape.index(max(data_out_shape))
+            parallelDim = spatialOutShape.index(max(spatialOutShape))

         forLoops = []
         dimLenPtrs = []
         for idx, i in enumerate(perm):
-            operatorRepresentation[f"dimLen_{idx}"] = data_in_shape[idx]
+            operatorRepresentation[f"dimLen_{idx}"] = data_in_shape[dataInOffset + idx]
             dimLenPtrs.append(f"dimLen_{idx}")
             if idx != parallelDim:
                 forLoops.append(_forLoop.generate({"i": i, "dimLenPtr": f"dimLen_{i}"}))
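The same offsets traced in the template's terms, on another set of invented shapes (an illustrative trace, not the template's actual inputs):

# Trace of the spatial-view lookups with a left-padded data_in.
perm = [1, 0]
data_in_shape = [2, 32, 48]   # hypothetical: left-padded to rank 3
data_out_shape = [48, 32]     # output stayed at rank 2

dataInOffset = len(data_in_shape) - len(perm)           # 1
dataOutOffset = len(data_out_shape) - len(perm)         # 0
spatialOutShape = list(data_out_shape[dataOutOffset:])  # [48, 32]

# parallelDim is chosen within the spatial view: dim 0 (48) is the first
# one wide enough to split across 8 cores.
parallelDims = [idx for idx, dim in enumerate(spatialOutShape) if dim >= 8]
parallelDim = parallelDims[0] if parallelDims else spatialOutShape.index(max(spatialOutShape))

# dimLen_<idx> skips the leading batch placeholder of data_in:
dimLens = {f"dimLen_{idx}": data_in_shape[dataInOffset + idx] for idx in range(len(perm))}
assert parallelDim == 0 and dimLens == {"dimLen_0": 32, "dimLen_1": 48}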

Deeploy/Targets/Redmule/Deployer.py

Lines changed: 1 addition & 5 deletions
@@ -31,7 +31,7 @@
 from Deeploy.DeeployTypes import DeploymentPlatform, TopologyOptimizer
 from Deeploy.Targets.PULPOpen.Deployer import PULPDeployer
 from Deeploy.Targets.Redmule.TopologyOptimizationPasses.Passes import RedMuleAdjustWeightMemoryLayoutPass, \
-    RedMuleBiaslessGemmToMatMulPass, RedMuleGEMMTransposePass
+    RedMuleGEMMTransposePass


 class RedmuleDeployer(PULPDeployer):
@@ -51,9 +51,5 @@ def __init__(self,

         self.loweringOptimizer.passes += [
             RedMuleAdjustWeightMemoryLayoutPass("Redmule"),
-            # Lower bias-less Gemm (e.g. backward GradFusedMatMul nodes in CCT
-            # training) to MatMul before GEMMTransposePass touches them; the
-            # bias-required tile constraint would otherwise crash.
-            RedMuleBiaslessGemmToMatMulPass("Redmule"),
             RedMuleGEMMTransposePass("Redmule")
         ]

Deeploy/Targets/Redmule/Engine.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@

 MatMulRedmuleMapper = NodeMapper(MatMulParser(), RedmuleMatMulTilingReadyBindings)
 Conv2DRedmuleMapper = NodeMapper(PULPFPConv2DParser(), RedmuleConvTilingReadyBindings)
-GEMMMRedmuleMapper = NodeMapper(GEMMRedmuleParser(), RedmuleGEMMTilingReadyBindings)
+GEMMMRedmuleMapper = NodeMapper(GEMMRedmuleParser(noBiasHoisting = False), RedmuleGEMMTilingReadyBindings)

 RedmuleMapping = {
     'MatMul': MatMulLayer([MatMulRedmuleMapper]),
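This noBiasHoisting=False only reaches the parser because of the __init__ reordering in Parsers.py below; with the pre-fix call order it was silently reset. A minimal repro of that Python pitfall, using stand-in classes rather than the real parsers:

# Stand-in classes (not the real parsers) showing the clobbering pitfall.
class Base:  # plays the role of MatMulParser

    def __init__(self, noBiasHoisting = True):
        self.noBiasHoisting = noBiasHoisting  # writes its own default

class Broken(Base):  # pre-fix GEMMRedmuleParser order

    def __init__(self, noBiasHoisting = True):
        self.noBiasHoisting = noBiasHoisting
        super().__init__()  # runs last, resets the flag to Base's default

class Fixed(Base):  # post-fix order: super first, kwarg forwarded

    def __init__(self, noBiasHoisting = True):
        super().__init__(noBiasHoisting = noBiasHoisting)
        self.noBiasHoisting = noBiasHoisting

assert Broken(noBiasHoisting = False).noBiasHoisting is True   # flag lost
assert Fixed(noBiasHoisting = False).noBiasHoisting is False   # flag kept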

Deeploy/Targets/Redmule/Parsers.py

Lines changed: 22 additions & 3 deletions
@@ -30,15 +30,20 @@
 import numpy as np
 import onnx_graphsurgeon as gs

+from Deeploy.AbstractDataTypes import PointerClass
+from Deeploy.CommonExtensions.DataTypes import float32_t
 from Deeploy.DeeployTypes import NetworkContext
 from Deeploy.Targets.Generic.Parsers import MatMulParser


 class GEMMRedmuleParser(MatMulParser):

     def __init__(self, noBiasHoisting = True):
+        # Order matters: super().__init__() of MatMulParser also writes
+        # self.noBiasHoisting from its own default, so call super first and
+        # then overwrite, otherwise our flag gets clobbered to True.
+        super().__init__(noBiasHoisting = noBiasHoisting)
         self.noBiasHoisting = noBiasHoisting
-        super().__init__()

     def parseNode(self, node: gs.Node) -> (bool):

@@ -85,9 +90,23 @@ def parseNodeCtxt(self,
         if len(node.inputs) == 3:
             self.operatorRepresentation['C'] = newCtxt.lookup(node.inputs[2].name).name
         elif not self.noBiasHoisting:
-            values = np.zeros((1))
+            # Hoist a zero C tensor whose shape matches the GEMM output, so
+            # the bias-required RedmuleGEMMTileConstraint and the existing
+            # 3-operand kernel template can run unchanged on bias-less
+            # Gemm nodes (e.g. backward GradFusedMatMul rewrites in CCT
+            # training graphs that emit Y = A @ B with no C).
+            outShape = node.outputs[0].shape
+            values = np.zeros(outShape, dtype = np.float32)
             zeroTensor = gs.Constant(f'{node.name}_C_Tensor', values = values)
-            newCtxt.hoistConstant(zeroTensor)
+            newCtxt.hoistConstant(zeroTensor, _type = PointerClass(float32_t))
+            # Also wire the hoisted Constant into the gs.Node inputs so the
+            # tiler picks it up via its `node.inputs + node.outputs` walk,
+            # AND register the Gemm as a user of the new buffer so the
+            # MemoryConstraintFlow's kill-set analysis (which walks
+            # `_users`) can find a consumer for it. Without these the
+            # tiler / flow analyzer KeyError or assert on the C tensor.
+            node.inputs.append(zeroTensor)
+            newCtxt.addUser(f'{node.name}_C_Tensor', node)
             self.operatorRepresentation['C'] = f'{node.name}_C_Tensor'

         self.operatorRepresentation['size'] = np.prod(newCtxt.lookup(node.inputs[0].name).shape)
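In isolation, the shape requirement on the hoisted C can be checked with plain numpy (hypothetical operand shapes; the tile constraint's dimension check is paraphrased here as a shape comparison):

import numpy as np

# Hypothetical bias-less Gemm: Y = A @ B with A (64, 32), B (32, 96).
outShape = (64, 96)

old_C = np.zeros((1))  # pre-fix hoist: a 1-element scalar
assert old_C.shape != outShape  # can never pass a per-dim "C == output" check

new_C = np.zeros(outShape, dtype = np.float32)  # post-fix hoist
assert new_C.shape == outShape

# The shape-matched zeros are a mathematical no-op for Y = A @ B + C:
A = np.random.rand(64, 32).astype(np.float32)
B = np.random.rand(32, 96).astype(np.float32)
assert np.allclose(A @ B + new_C, A @ B)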

Deeploy/Targets/Redmule/TopologyOptimizationPasses/Passes.py

Lines changed: 0 additions & 80 deletions
@@ -149,83 +149,3 @@ def __init__(self, redmuleEngineName: str):
                          replacement_fn = _redmule_gemm_transpose_fun,
                          name = "_REDMULE_GEMM_TRANSPOSE_PASS")

-
-def _redmule_biasless_gemm_to_matmul_fun(graph: gs.Graph, match: Match, name: str):
-    """Rewrite a 2-input ONNX Gemm (no C / bias) into an equivalent MatMul.
-
-    Backward-pass codegen (e.g. the ``GradFusedMatMul`` rewrites that fall out
-    of the CCT training graph) emits ``Gemm`` nodes with only A and B and no
-    bias, which the ``GEMMRedmuleParser`` accepts but for which the
-    ``RedmuleGEMMTileConstraint`` then crashes (KeyError on ``parseDict['C']``).
-    A bias-less Gemm with alpha=1 is mathematically just a MatMul, and the
-    Redmule platform already maps ONNX ``MatMul`` to a kernel that doesn't
-    expect a C operand -- so we lower it here.
-
-    transA / transB are materialized as explicit ``Transpose`` nodes (or, for
-    constant operands, folded into the constant) before the op is rewritten,
-    because ``MatMul`` has no equivalent attributes.
-    """
-    gemm_node = list(match.nodes_map.values())[0]
-
-    # Pattern matcher may match Gemms with 3 inputs too; act only on the
-    # bias-less subset.
-    if len(gemm_node.inputs) != 2:
-        return graph
-
-    # Anything other than alpha=1 cannot be expressed as a plain MatMul.
-    if gemm_node.attrs.get('alpha', 1.0) != 1.0:
-        return graph
-
-    transA = gemm_node.attrs.get('transA', 0)
-    transB = gemm_node.attrs.get('transB', 0)
-
-    for inputIdx, transFlag in ((0, transA), (1, transB)):
-        if not transFlag:
-            continue
-        operand = gemm_node.inputs[inputIdx]
-        if isinstance(operand, gs.Constant):
-            if len(operand.values.shape) > 2:
-                perm = list(range(len(operand.values.shape)))
-                perm[-1], perm[-2] = perm[-2], perm[-1]
-                operand.values = np.transpose(operand.values, perm)
-            else:
-                operand.values = np.transpose(operand.values)
-        else:
-            perm = list(range(len(operand.shape)))
-            perm[-1], perm[-2] = perm[-2], perm[-1]
-            anchorTransposeNode = _appendTranspose(operand, gemm_node, perm)
-            graph.nodes.append(anchorTransposeNode)
-
-    gemm_node.op = "MatMul"
-    gemm_node.attrs.clear()
-
-    return graph
-
-
-@contextagnostic
-class RedMuleBiaslessGemmToMatMulPass(ReplaceSequentialPatternPass):
-    """Lower bias-less (2-input) ONNX Gemm nodes to MatMul on the Redmule path.
-
-    Must run before RedMuleGEMMTransposePass so the latter only sees the
-    real (3-input) Gemm nodes; otherwise its replacement_fn would write
-    ``transA`` / ``transB`` back to 0 on what is now a MatMul, and a stale
-    ``Gemm`` op type would still hit the bias-required tile constraint.
-    """
-
-    def __init__(self, redmuleEngineName: str):
-        pattern = gs.Graph()
-
-        input_a = gs.Variable(name = "input_a")
-        input_b = gs.Variable(name = "input_b")
-
-        gemm_output = pattern.layer(op = "Gemm",
-                                    name = "gemm_node",
-                                    inputs = [input_a, input_b],
-                                    outputs = ["gemm_output"])
-
-        pattern.inputs = [input_a, input_b]
-        pattern.outputs = [gemm_output]
-
-        super().__init__(pattern = pattern,
-                         replacement_fn = _redmule_biasless_gemm_to_matmul_fun,
-                         name = "_REDMULE_BIASLESS_GEMM_TO_MATMUL_PASS")
