Skip to content

Commit 1fe7cbb

Browse files
committed
Allow ConvGradW C_out tiling to fit dW in smaller L1 budgets
Root cause: ConvGradWTileConstraintBase.addPolicyConstraint hard-pinned all four dW dimensions to their full shape and also forced dyName[1] to be full. At L1=128KB the accumulation target dW for layer3.conv2 of ResNet8 (64x64x3x3x4 = 147456 B) alone exceeds L1, making the OR-Tools geometric model infeasible.

Fix: dW[C_out, C_in, kH, kW] has the property that each C_out slice is computed independently (dW[co] = sum_nhw dY[n,co,h,w] * X[n,:,...]). Drop the C_out full constraint on dW[0] and dyName[1]; keep Cin/kH/kW pinned so the tile remains a contiguous leading sub-range of the dW buffer (safe 1D DMA).

Extend serializeTilingSolution with an outer loop over C_out tiles that pulls Cout_tile_max from the tiler solution, emits per-tile HyperRectangles with the correct C_out offset, and propagates the tile size into the ch_im_out replacement. When Cout_tile == Cout_full the iteration count is one, so previously working configurations (e.g. ResNet8 at L1=300KB, DSCNN) are unchanged.

Verified:
- ResNet8 L1=300KB L3: tiling still feasible (previously working)
- ResNet8 L1=128KB L3: ConvGradW is no longer the blocker; the ConvGradX full-weight constraint remains the blocker for layer3.conv2 at 128KB and needs C_in tiling + 2D strided DMA (plan B)
1 parent de6e364 commit 1fe7cbb

1 file changed

Lines changed: 62 additions & 47 deletions

File tree

Deeploy/Targets/PULPOpen/TileConstraints/ConvGradConstraint.py

Lines changed: 62 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -575,8 +575,9 @@ def addGeometricalConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt:
575575
def addPolicyConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
576576
"""
577577
Default policy:
578-
- keep full Cin/Cout on X and dY
579-
- dW output is full (no tiling) because accumulation
578+
- keep full Cin on X
579+
- allow C_out tiling on dY and dW[0] (dW[co] slices are independent per co)
580+
- keep dW Cin/kH/kW full (contiguous slice along leading C_out axis)
580581
- kernel dims fixed (no tiling)
581582
- allow H/W tiling on dY (and derived halo on X)
582583
"""
@@ -585,15 +586,13 @@ def addPolicyConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
585586
dwName = parseDict[cls.weightKey]
586587

587588
xBuf = ctxt.lookup(xName)
588-
dyBuf = ctxt.lookup(dyName)
589589
dwBuf = ctxt.lookup(dwName)
590590

591-
# Full channels for inputs
592-
tilerModel.addConstraint(tilerModel.getTensorDimVar(xName, 1) == xBuf.shape[1]) # Cin
593-
tilerModel.addConstraint(tilerModel.getTensorDimVar(dyName, 1) == dyBuf.shape[1]) # Cout
591+
# Full Cin on X (reduction axis for dW is spatial, Cin is independent per output channel)
592+
tilerModel.addConstraint(tilerModel.getTensorDimVar(xName, 1) == xBuf.shape[1])
594593

595-
# dW is full (all dims)
596-
for d in range(len(dwBuf.shape)):
594+
# dW: keep Cin / kH / kW full; allow C_out (dim 0) to tile
595+
for d in range(1, len(dwBuf.shape)):
597596
tilerModel.addConstraint(tilerModel.getTensorDimVar(dwName, d) == dwBuf.shape[d])
598597

599598
# dY tile spatial dims >= 1
@@ -800,49 +799,65 @@ def serializeTilingSolution(
800799
Cin_full = xFull[1]
801800
Cout_full = dyFull[1]
802801

803-
# dW is full cube (accumulation target)
804-
fullDW = HyperRectangle((0, 0, 0, 0), dwShape)
802+
# C_out tile size from tiler solution (falls back to full when not tiled)
803+
try:
804+
dwTileShape = tilingSolution.tensorMemoryConstraints[dwName].memoryConstraints[targetMemLevel].shape
805+
Cout_tile_max = dwTileShape[0]
806+
except Exception:
807+
Cout_tile_max = Cout_full
808+
809+
co_tiles: List[Tuple[int, int]] = []
810+
co = 0
811+
while co < Cout_full:
812+
cs = min(Cout_tile_max, Cout_full - co)
813+
co_tiles.append((co, cs))
814+
co += cs
805815

806816
inputLoadSchedule = []
807817
outputLoadSchedule = []
808818

809-
# Build tiles
810-
for hoOff, hoSz in h_tiles:
811-
for woOff, woSz in w_tiles:
812-
dyTile = HyperRectangle(
813-
(0, 0, hoOff, woOff),
814-
(N_tile, Cout_full, hoSz, woSz),
815-
)
816-
817-
xTile, (tpt, tpb, tpl, tpr) = cls.computeInputTileFromGradOutTile(
818-
kernel_hw=(dwShape[2], dwShape[3]),
819-
pads=pads,
820-
strides=strides,
821-
inputCSize=Cin_full,
822-
gradOutTile=dyTile,
823-
inputFull=xFull,
824-
gradOutFull=dyFull,
825-
)
826-
827-
# dims (x=H, y=W)
828-
replacements["dim_im_in_x"].append(xTile.dims[2])
829-
replacements["dim_im_in_y"].append(xTile.dims[3])
830-
replacements["dim_im_out_x"].append(dyTile.dims[2])
831-
replacements["dim_im_out_y"].append(dyTile.dims[3])
832-
833-
replacements["ch_im_in"].append(Cin_full)
834-
replacements["ch_im_out"].append(Cout_full)
835-
836-
# ONNX pads (t,b,l,r) -> unified naming:
837-
# padding_y_top/bottom : H dimension => top/bottom
838-
# padding_x_left/right : W dimension => left/right
839-
replacements["padding_y_top"].append(tpt) # H_begin = top
840-
replacements["padding_y_bottom"].append(tpb) # H_end = bottom
841-
replacements["padding_x_left"].append(tpl) # W_begin = left
842-
replacements["padding_x_right"].append(tpr) # W_end = right
843-
844-
inputLoadSchedule.append({cls.dataInKey: xTile, cls.gradOutKey: dyTile})
845-
outputLoadSchedule.append({cls.weightKey: fullDW})
819+
# Build tiles: outer loop over C_out, inner over spatial
820+
for coOff, coSz in co_tiles:
821+
dwTile = HyperRectangle(
822+
(coOff, 0, 0, 0),
823+
(coSz, dwShape[1], dwShape[2], dwShape[3]),
824+
)
825+
for hoOff, hoSz in h_tiles:
826+
for woOff, woSz in w_tiles:
827+
dyTile = HyperRectangle(
828+
(0, coOff, hoOff, woOff),
829+
(N_tile, coSz, hoSz, woSz),
830+
)
831+
832+
xTile, (tpt, tpb, tpl, tpr) = cls.computeInputTileFromGradOutTile(
833+
kernel_hw=(dwShape[2], dwShape[3]),
834+
pads=pads,
835+
strides=strides,
836+
inputCSize=Cin_full,
837+
gradOutTile=dyTile,
838+
inputFull=xFull,
839+
gradOutFull=dyFull,
840+
)
841+
842+
# dims (x=H, y=W)
843+
replacements["dim_im_in_x"].append(xTile.dims[2])
844+
replacements["dim_im_in_y"].append(xTile.dims[3])
845+
replacements["dim_im_out_x"].append(dyTile.dims[2])
846+
replacements["dim_im_out_y"].append(dyTile.dims[3])
847+
848+
replacements["ch_im_in"].append(Cin_full)
849+
replacements["ch_im_out"].append(coSz)
850+
851+
# ONNX pads (t,b,l,r) -> unified naming:
852+
# padding_y_top/bottom : H dimension => top/bottom
853+
# padding_x_left/right : W dimension => left/right
854+
replacements["padding_y_top"].append(tpt) # H_begin = top
855+
replacements["padding_y_bottom"].append(tpb) # H_end = bottom
856+
replacements["padding_x_left"].append(tpl) # W_begin = left
857+
replacements["padding_x_right"].append(tpr) # W_end = right
858+
859+
inputLoadSchedule.append({cls.dataInKey: xTile, cls.gradOutKey: dyTile})
860+
outputLoadSchedule.append({cls.weightKey: dwTile})
846861

847862
tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule)
848863
variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes)

0 commit comments

Comments
 (0)