Skip to content

Commit 950b1fd

Browse files
authored
fix(tiling): MiniMalloc cost variable + SGD rank-2 DMA constraint (#18)
* fix(tiling): force SGD weight spatial dims to full size for rank-2 DMA

  SGDTileConstraint now pins dim_i == shape[i] for i >= 2 (kernel_h, kernel_w). minimizeRectangle can then collapse the trailing dims so every L2<->L3 DMA tile is rank-2, avoiding the AnydimAsyncDmaTransferAdapter for-loop that emitted 4096 blocking pi_cl_ram_copy_2d calls per L2 tile for [128,128,3,3] weights (~49x slowdown on the ResNet8 optimizer with MiniMalloc).

* fix(tiling): force InPlaceAccumulatorV2 weight-grad spatial dims to full size

  InPlaceAccumulatorV2TileConstraint now pins dim_i == shape[i] for i >= 2 (kH, kW) on the accum_buffer tensor. BOPTileConstraint already ties the gradient and data_out dims to accum_buffer, so one pin is enough. Without this, MiniMalloc tiles [C_out, C_in, kH, kW] weight-gradient tensors along all four dims. For ResNet8 layer3.conv2 [128,128,3,3] this produced an explicit for-loop of 4096 iterations inside the L3 DMA closure (pi_cl_ram_copy_2d(4 B) + pi_cl_ram_copy_wait per iteration), resulting in ~73k blocking L3 DMA calls per training step. After the fix, minimizeRectangle collapses kH×kW into rank-2 tiles, so each L2->L3 transfer is a single contiguous pi_cl_ram_copy_2d (~41 KB). Verified: 0 blocking DMA for-loops in ResNet8 TrainingNetwork.c.
1 parent 5b4a394 commit 950b1fd

2 files changed

Lines changed: 28 additions & 0 deletions

File tree

Deeploy/Targets/PULPOpen/TileConstraints/InPlaceAccumulatorV2TileConstraint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ class InPlaceAccumulatorV2TileConstraint(BOPTileConstraint):
3131
def addGeometricalConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
3232
tilerModel = super().addGeometricalConstraint(tilerModel, parseDict, ctxt)
3333

34+
# Force spatial dims (index >= 2) to full size so that minimizeRectangle
35+
# can collapse them and DMA tiles stay rank ≤ 2 (L3Dma pi_cl_ram_copy_2d limit).
36+
accumName = parseDict[cls.dataIn1Name]
37+
shape = ctxt.lookup(accumName).shape
38+
if not isinstance(shape, int) and len(shape) > 2:
39+
for dimIdx in range(2, len(shape)):
40+
dimVar = tilerModel.getTensorDimVar(tensorName = accumName, dimIdx = dimIdx)
41+
tilerModel.addConstraint(dimVar == shape[dimIdx])
42+
3443
# lazy_reset_grad is a scalar flag — pin full size so it is not tiled.
3544
lazyResetName = parseDict['lazy_reset_grad']
3645
tilerModel.addTensorDimToModel(ctxt, lazyResetName)

Deeploy/Targets/PULPOpen/TileConstraints/SGDTileConstraint.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from typing import Dict
6+
7+
from Deeploy.DeeployTypes import NetworkContext
58
from Deeploy.Targets.Generic.TileConstraints.BOPTileConstraint import BOPTileConstraint
9+
from Deeploy.TilingExtension.TilerModel import TilerModel
610

711

812
class SGDTileConstraint(BOPTileConstraint):
@@ -11,6 +15,21 @@ class SGDTileConstraint(BOPTileConstraint):
1115
dataIn2Name = 'grad'
1216
dataOutName = 'weight_updated'
1317

18+
@classmethod
def addGeometricalConstraint(cls, tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
    """Add the inherited geometrical constraints, then pin trailing weight dims.

    After delegating to the parent implementation, every dimension of the
    tensor named by ``parseDict[cls.dataIn1Name]`` with index >= 2 is
    constrained to equal its full extent, so those dimensions are never tiled.
    The intent is that minimizeRectangle can then collapse them and DMA
    tiles stay rank <= 2 (L3Dma pi_cl_ram_copy_2d limit).

    Args:
        tilerModel: Model that receives the additional constraints.
        parseDict: Parser output; looked up with ``cls.dataIn1Name``.
        ctxt: Network context used to resolve the tensor's shape.

    Returns:
        The tiler model returned by the parent call, with the extra
        equality constraints added when the tensor has more than two dims.
    """
    tilerModel = super().addGeometricalConstraint(tilerModel, parseDict, ctxt)

    tensorName = parseDict[cls.dataIn1Name]
    fullShape = ctxt.lookup(tensorName).shape

    # Nothing to pin for an integer shape or for tensors of rank <= 2.
    if isinstance(fullShape, int) or len(fullShape) <= 2:
        return tilerModel

    rank = len(fullShape)
    for axis in range(2, rank):
        # Tile extent along this axis must equal the full tensor extent.
        tileExtent = tilerModel.getTensorDimVar(tensorName = tensorName, dimIdx = axis)
        tilerModel.addConstraint(tileExtent == fullShape[axis])

    return tilerModel
32+
1433

1534
class ReluGradTileConstraint(BOPTileConstraint):
1635

0 commit comments

Comments
 (0)