Skip to content

Commit 8e3bbe7

Browse files
committed
[CNNTrain] Add ConvGradX
1 parent 04b81ac commit 8e3bbe7

12 files changed

Lines changed: 423 additions & 5 deletions

File tree

Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _NCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_channels_f
227227
tensorIn = node.inputs[0]
228228
tensorOut = node.outputs[0]
229229

230-
if node.op in ["RequantizedConv", "Conv"]:
230+
if node.op in ["RequantizedConv", "Conv", "ConvGradX"]:
231231
spatialDims = len(node.inputs[1].shape) - 2
232232
elif node.op in ["MaxPool", "AveragePool", "AveragePoolGrad"]:
233233
spatialDims = len(node.attrs["kernel_shape"])
@@ -242,8 +242,9 @@ def _NCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_channels_f
242242
permuteOut = _transformLayoutPermutation(len(tensorOut.shape), spatialDims, channels_first)
243243
graph.nodes.append(_prependTranspose(tensorOut, node, permuteOut))
244244

245-
if node.op in ["Conv", "RequantizedConv"]:
245+
if node.op in ["Conv", "RequantizedConv", "ConvGradX"]:
246246
# In the case of Conv: [weights, opt. bias], RequantizedConv: [weights, mul, add, opt. shift]
247+
# ConvGradX: [weight] (no bias)
247248
for tensor in node.inputs[1:]:
248249
_transformLayoutConst(tensor, spatialDims, default_channels_first)
249250

@@ -279,6 +280,15 @@ def __init__(self, default_channels_first: bool = True):
279280
super().__init__(graph, partial(_NCHWtoNHWC_fun, default_channels_first = default_channels_first), name)
280281

281282

283+
# Layout-conversion pass for "ConvGradX" nodes.
# Builds a single-node pattern that matches op "ConvGradX" and registers
# _NCHWtoNHWC_fun as the replacement function, with default_channels_first
# bound through functools.partial. The pass name is "_NCHW_TO_NHWC_CONVGRADX_PASS".
@contextagnostic
284+
class NCHWtoNHWCConvGradXPass(ReplaceSequentialPatternPass):
285+
286+
def __init__(self, default_channels_first: bool = True):
287+
graph = _singleNodePattern(op = "ConvGradX")
288+
name = "_NCHW_TO_NHWC_CONVGRADX_PASS"
289+
super().__init__(graph, partial(_NCHWtoNHWC_fun, default_channels_first = default_channels_first), name)
290+
291+
282292
@contextagnostic
283293
class NCHWtoNHWCConvPass(ReplaceSequentialPatternPass):
284294

@@ -383,6 +393,7 @@ def __init__(self, default_channels_first: bool = True):
383393
NCHWtoNHWCMaxPoolPass(default_channels_first),
384394
NCHWtoNHWCAveragePoolPass(default_channels_first),
385395
NCHWtoNHWCAveragePoolGradPass(default_channels_first),
396+
NCHWtoNHWCConvGradXPass(default_channels_first),
386397
NCHWtoNHWCDwConvPass(default_channels_first),
387398
NCHWtoNHWCConvPass(default_channels_first),
388399
]
@@ -398,6 +409,7 @@ def __init__(self, default_channels_first: bool = True):
398409
NCHWtoNHWCMaxPoolPass(default_channels_first),
399410
NCHWtoNHWCAveragePoolPass(default_channels_first),
400411
NCHWtoNHWCAveragePoolGradPass(default_channels_first),
412+
NCHWtoNHWCConvGradXPass(default_channels_first),
401413
PULPNCHWtoNHWCDwConvPass(default_channels_first),
402414
NCHWtoNHWCConvPass(default_channels_first),
403415
]

Deeploy/Targets/PULPOpen/Bindings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,13 @@
251251
ForkTransformer) for float_type in FloatDataTypes
252252
]
253253

254+
# Bindings for the float 2D transposed-convolution template.
# A single NodeBinding: ConvChecker with two float32 pointer inputs and one
# float32 pointer output, paired with FloatConvTemplate.referenceConvTrans2DTemplate
# and ForkTransformer.
PULPFloatConvTrans2DBindings = [
255+
NodeBinding(
256+
ConvChecker([PointerClass(float32_t), PointerClass(float32_t)],
257+
[PointerClass(float32_t)]), FloatConvTemplate.referenceConvTrans2DTemplate,
258+
ForkTransformer)
259+
]
260+
254261
PULPRQSMatrixVecBindings = [
255262
NodeBinding(
256263
PULPLinearChecker([PointerClass(type1),
@@ -464,3 +471,4 @@
464471
NodeBinding(DequantChecker([PointerClass(int32_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate,
465472
ForkTransformer),
466473
]
474+

Deeploy/Targets/PULPOpen/Parsers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,8 @@ def parseNodeCtxt(self,
462462
return ctxt, False
463463

464464
return newCtxt, True
465+
466+
# Parser subclass of PULPFPConv2DParser. The visible code adds no behavior of
# its own: __init__ forwards noBiasHoisting (default True) to the parent
# constructor unchanged.
class PULPConvTrans2DParser(PULPFPConv2DParser):
467+
468+
def __init__(self, noBiasHoisting = True):
469+
super().__init__(noBiasHoisting)

Deeploy/Targets/PULPOpen/Platform.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer
3535
from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \
3636
PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, \
37-
PULPTallGEMMParser
37+
PULPTallGEMMParser, PULPConvTrans2DParser
3838
from Deeploy.Targets.PULPOpen.Templates import AllocateTemplate, FreeTemplate
3939
from Deeploy.Targets.PULPOpen.Tiler import PULPAddTilingReadyBindings, PULPConcatTilingReadyBindings, \
4040
PULPConv2DTilingReadyBindings, PULPDWConv2DTilingReadyBindings, PULPFlattenTilingReadyBindings, \
@@ -48,7 +48,8 @@
4848
PULPRQSMatrixVecTilingReadyBindings, PULPRQSTallGEMMTilingReadyBindings, PULPRQSTilingReadyBindings, \
4949
PULPSGDTilingReadyBindings, PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \
5050
PULPSoftmaxCrossEntropyTilingReadyBindings, PULPSoftmaxGradTilingReadyBindings, PULPSoftmaxTilingReadyBindings, \
51-
PULPTransposeTilingReadyBindings, PULPUniformRQSTilingReadyBindings, PULPAveragePool2DTilingReadyBindings, PULPAveragePoolGrad2DTilingReadyBindings
51+
PULPTransposeTilingReadyBindings, PULPUniformRQSTilingReadyBindings, PULPAveragePool2DTilingReadyBindings, \
52+
PULPAveragePoolGrad2DTilingReadyBindings, PULPConvTrans2DTilingReadyBindings
5253
from Deeploy.Targets.PULPOpen.TopologyOptimizationPasses.Passes import PULPAddRequantMergePass, \
5354
PULPConvRequantMergePass, PULPGEMMRequantMergePass, PULPMatMulRequantMergePass
5455

@@ -77,6 +78,9 @@
7778
Conv1DMapper = NodeMapper(PULPConv1DParser(), [PULPConv1DBinding])
7879
DWConv1DMapper = NodeMapper(PULPDWConv1DParser(), [PULPDWConv1DBinding])
7980
FPConv2DMapper = NodeMapper(PULPFPConv2DParser(), PULPConv2DTilingReadyBindings)
81+
82+
ConvGradXMapper = NodeMapper(PULPConvTrans2DParser(), PULPConvTrans2DTilingReadyBindings)
83+
8084
Conv2DMapper = NodeMapper(PULPConv2DParser(), PULPRQSConv2DTilingReadyBindings)
8185
FPDWConv2DMapper = NodeMapper(PULPFPDWConv2DParser(), PULPDWConv2DTilingReadyBindings)
8286
DWConv2DMapper = NodeMapper(PULPDWConv2DParser(), PULPRQSDWConv2DTilingReadyBindings)
@@ -113,6 +117,7 @@
113117
AveragePoolGrad2DMapper = NodeMapper(AveragePool2DParser(), PULPAveragePoolGrad2DTilingReadyBindings)
114118
PULPMapping = {
115119
'Conv': ConvLayer([FPConv2DMapper, FPDWConv2DMapper]),
120+
'ConvGradX': ConvLayer([ConvGradXMapper]),
116121
'RequantizedConv': PULPRQSConvLayer([Conv2DMapper, DWConv2DMapper, Conv1DMapper, DWConv1DMapper]),
117122
'RequantizedGemm': PULPRQSGEMMLayer([MatrixVecMapper, TallGEMMMapper, GEMMMapper]),
118123
'Gemm': GEMMLayer([FloatGEMMMapper, GEMMDequantMapper]),

Deeploy/Targets/PULPOpen/Templates/FloatConvTemplate.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,24 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
158158
ref_${data_out}_${data_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
159159
}
160160
""")
161+
162+
163+
# Code template for a float 2D transposed convolution in HWC layout.
# The generated code loops ${batch} times; each iteration calls
# PULP_ConvTrans2d_fp<in>_fp<weight>_fp<out>_HWC on the current input/output
# pointers, then advances the input pointer by ch_im_in * dim_im_in_x * dim_im_in_y
# and the output pointer by ch_im_out * dim_im_out_x * dim_im_out_y elements.
# NOTE(review): the kernel itself is defined elsewhere; its argument order is
# assumed to match the call emitted here -- confirm against the kernel header.
referenceConvTrans2DTemplate = NodeTemplate("""
164+
// 2D FP ConvTranspose HWC (Name: ${nodeName}, Op: ${nodeOp})
165+
${data_in_type.typeName} ref_${data_out}_${data_in} = ${data_in};
166+
${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
167+
for (uint32_t n=0; n<${batch}; ++n) {
168+
PULP_ConvTrans2d_fp${data_in_type.referencedType.typeWidth}_fp${weight_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}_HWC(
169+
ref_${data_out}_${data_in},
170+
${dim_im_in_y}, ${dim_im_in_x}, ${ch_im_in},
171+
${weight}, ${ch_im_out},
172+
${dim_kernel_y}, ${dim_kernel_x},
173+
${stride_y}, ${stride_x},
174+
ref_${data_out}_${data_out},
175+
${padding_y_top}, ${padding_y_bottom}, ${padding_x_left}, ${padding_x_right}
176+
);
177+
178+
ref_${data_out}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
179+
ref_${data_out}_${data_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
180+
}
181+
""")

0 commit comments

Comments
 (0)