Skip to content

Commit ce5491e

Browse files
committed
[CNNTraining] ConvGradX, W,B and DW
1 parent 26b1e1b commit ce5491e

20 files changed

Lines changed: 591 additions & 50 deletions

File tree

Deeploy/Targets/PULPOpen/Bindings.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
SGDTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \
3636
iRMSNormTemplate, iSoftmaxTemplate, FloatAveragePoolTemplate
3737
from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \
38-
PULPRequantShiftChecker
38+
PULPRequantShiftChecker, PULPConvGradBChecker
3939
from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement, \
4040
TilingVariableReplacementUpdate
4141

@@ -258,6 +258,28 @@
258258
ForkTransformer)
259259
]
260260

261+
PULPFloatDWConvTrans2DBindings = [
262+
NodeBinding(
263+
ConvChecker([PointerClass(float32_t), PointerClass(float32_t)],
264+
[PointerClass(float32_t)]), FloatConvTemplate.referenceDWConvTrans2DTemplate,
265+
ForkTransformer)
266+
]
267+
268+
PULPFloatConvGradW2DBindings = [
269+
NodeBinding(
270+
ConvChecker([PointerClass(float32_t), PointerClass(float32_t)],
271+
[PointerClass(float32_t)]), FloatConvTemplate.referenceConvGradW2DTemplate,
272+
ForkTransformer)
273+
]
274+
275+
PULPFloatConvGradB2DBindings = [
276+
NodeBinding(
277+
PULPConvGradBChecker([PointerClass(float32_t)], # Only one input: output_grad
278+
[PointerClass(float32_t)]), # Output: bias_grad
279+
FloatConvTemplate.referenceConvGradB2DTemplate,
280+
ForkTransformer)
281+
]
282+
261283
PULPRQSMatrixVecBindings = [
262284
NodeBinding(
263285
PULPLinearChecker([PointerClass(type1),

Deeploy/Targets/PULPOpen/Parsers.py

Lines changed: 247 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,4 +466,250 @@ def parseNodeCtxt(self,
466466
class PULPConvTrans2DParser(PULPFPConv2DParser):
467467

468468
def __init__(self, noBiasHoisting = True):
469-
super().__init__(noBiasHoisting)
469+
super().__init__(noBiasHoisting)
470+
471+
def parseNode(self, node: gs.Node) -> bool:
472+
"""Override to recognize ConvGradX instead of Conv"""
473+
# Temporarily change op to Conv for parent parsing
474+
original_op = node.op
475+
if node.op == 'ConvGradX':
476+
node.op = 'Conv'
477+
478+
# Call parent parseNode
479+
wellFormed = super().parseNode(node)
480+
481+
# Restore original op
482+
node.op = original_op
483+
484+
# Additional validation for ConvGradX
485+
if wellFormed and original_op == 'ConvGradX':
486+
# ConvGradX should have 2 inputs: output_grad and weight
487+
return len(node.inputs) == 2
488+
489+
return wellFormed
490+
491+
def parseNodeCtxt(self,
492+
ctxt: NetworkContext,
493+
node: gs.Node,
494+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
495+
"""Override for ConvGradX - swap input/output semantics"""
496+
497+
if node.op == 'ConvGradX':
498+
# For ConvGradX: inputs are [output_grad, weight], output is input_grad
499+
# But parent expects: inputs are [input, weight], output is output
500+
# So we need to swap the semantics
501+
502+
# Temporarily swap input/output for parent parsing
503+
output_grad_name = node.inputs[0].name
504+
input_grad_name = node.outputs[0].name
505+
506+
# Get tensors
507+
output_grad = ctxt.lookup(output_grad_name)
508+
weight = ctxt.lookup(node.inputs[1].name)
509+
510+
# Create a temporary input tensor with output_grad's info as if it's the output
511+
# and output tensor with input_grad's info as if it's the input
512+
temp_input = node.inputs[0]
513+
temp_output = node.outputs[0]
514+
515+
# Swap
516+
node.inputs[0] = temp_output
517+
node.outputs[0] = temp_input
518+
519+
# Call parent
520+
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
521+
522+
# Restore
523+
node.inputs[0] = temp_input
524+
node.outputs[0] = temp_output
525+
526+
if ret:
527+
# Fix the tensor names for ConvGradX
528+
self.operatorRepresentation['data_in'] = output_grad_name
529+
self.operatorRepresentation['data_out'] = input_grad_name
530+
self.operatorRepresentation["has_bias"] = "false"
531+
self.operatorRepresentation["bias"] = "NULL"
532+
533+
return newCtxt, ret
534+
else:
535+
return super().parseNodeCtxt(ctxt, node, channels_first)
536+
537+
class PULPDWConvTrans2DParser(PULPFPDWConv2DParser):
538+
539+
def __init__(self, noBiasHoisting = True):
540+
super().__init__(noBiasHoisting)
541+
542+
def parseNode(self, node: gs.Node) -> bool:
543+
"""Override to recognize ConvGradX instead of Conv"""
544+
# Temporarily change op to Conv for parent parsing
545+
original_op = node.op
546+
if node.op == 'ConvGradX':
547+
node.op = 'Conv'
548+
549+
# Call parent parseNode
550+
wellFormed = super().parseNode(node)
551+
552+
# Restore original op
553+
node.op = original_op
554+
555+
# Additional validation for ConvGradX
556+
if wellFormed and original_op == 'ConvGradX':
557+
# ConvGradX should have 2 inputs: output_grad and weight
558+
return len(node.inputs) == 2
559+
560+
return wellFormed
561+
562+
def parseNodeCtxt(self,
563+
ctxt: NetworkContext,
564+
node: gs.Node,
565+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
566+
"""Override for ConvGradX - swap input/output semantics"""
567+
568+
if node.op == 'ConvGradX':
569+
# For ConvGradX: inputs are [output_grad, weight], output is input_grad
570+
# Temporarily swap input/output for parent parsing
571+
output_grad_name = node.inputs[0].name
572+
input_grad_name = node.outputs[0].name
573+
574+
# Swap
575+
temp_input = node.inputs[0]
576+
temp_output = node.outputs[0]
577+
node.inputs[0] = temp_output
578+
node.outputs[0] = temp_input
579+
580+
# Call parent
581+
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
582+
583+
# Restore
584+
node.inputs[0] = temp_input
585+
node.outputs[0] = temp_output
586+
587+
if ret:
588+
# Fix the tensor names for ConvGradX
589+
self.operatorRepresentation['data_in'] = output_grad_name
590+
self.operatorRepresentation['data_out'] = input_grad_name
591+
self.operatorRepresentation["weight"] = ctxt.lookup(node.inputs[1].name).name
592+
self.operatorRepresentation["has_bias"] = "false"
593+
self.operatorRepresentation["bias"] = "NULL"
594+
595+
return newCtxt, ret
596+
else:
597+
return super().parseNodeCtxt(ctxt, node, channels_first)
598+
599+
600+
class PULPConvGradW2DParser(PULPFPConv2DParser):
601+
602+
def __init__(self, noBiasHoisting = True):
603+
super().__init__(noBiasHoisting)
604+
605+
def parseNodeCtxt(self,
606+
ctxt: NetworkContext,
607+
node: gs.Node,
608+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
609+
"""Parse ConvGradW - need custom logic for input dimensions"""
610+
611+
if not self.parseNode(node):
612+
return ctxt, False
613+
614+
# Get input tensors
615+
grad_out_tensor = ctxt.lookup(node.inputs[0].name)
616+
data_in_tensor = ctxt.lookup(node.inputs[1].name)
617+
618+
# Extract batch size
619+
batch = grad_out_tensor.shape[0]
620+
621+
# Extract dimensions
622+
C_out, H_out, W_out = grad_out_tensor.shape[1], grad_out_tensor.shape[2], grad_out_tensor.shape[3]
623+
C_in, H_in, W_in = data_in_tensor.shape[1], data_in_tensor.shape[2], data_in_tensor.shape[3]
624+
625+
# Store batch size
626+
self.operatorRepresentation['batch'] = batch
627+
628+
# Store dimensions
629+
self.operatorRepresentation['ch_im_out'] = C_out
630+
self.operatorRepresentation['dim_im_out_x'] = W_out
631+
self.operatorRepresentation['dim_im_out_y'] = H_out
632+
self.operatorRepresentation['ch_im_in'] = C_in
633+
self.operatorRepresentation['dim_im_in_x'] = W_in
634+
self.operatorRepresentation['dim_im_in_y'] = H_in
635+
636+
# Store kernel dimensions
637+
self.operatorRepresentation['dim_kernel_y'] = self.operatorRepresentation['kernel_shape'][0]
638+
self.operatorRepresentation['dim_kernel_x'] = self.operatorRepresentation['kernel_shape'][1]
639+
640+
# Store strides
641+
self.operatorRepresentation['stride_y'] = self.operatorRepresentation['strides'][0]
642+
self.operatorRepresentation['stride_x'] = self.operatorRepresentation['strides'][1]
643+
644+
# Set tensor names and types
645+
self.operatorRepresentation['grad_out'] = node.inputs[0].name
646+
self.operatorRepresentation['grad_out_type'] = grad_out_tensor._type
647+
self.operatorRepresentation['data_in'] = node.inputs[1].name
648+
self.operatorRepresentation['data_in_type'] = data_in_tensor._type
649+
self.operatorRepresentation['weight'] = node.outputs[0].name
650+
self.operatorRepresentation['weight_type'] = grad_out_tensor._type # Same as grad_out
651+
652+
# No bias for ConvGradW
653+
self.operatorRepresentation['has_bias'] = 'false'
654+
self.operatorRepresentation['bias'] = 'NULL'
655+
656+
return ctxt, True
657+
658+
class PULPConvGradB2DParser(PULPFPConv2DParser):
659+
660+
def __init__(self):
661+
self.operatorRepresentation = {}
662+
663+
def parseNode(self, node: gs.Node) -> bool:
664+
"""Parse ConvGradB node attributes"""
665+
666+
# Check basic structure
667+
if node.op != 'ConvGradB':
668+
return False
669+
670+
if len(node.inputs) != 1: # only output_grad
671+
return False
672+
673+
if len(node.outputs) != 1: # bias_grad
674+
return False
675+
676+
return True
677+
678+
def parseNodeCtxt(self,
679+
ctxt: NetworkContext,
680+
node: gs.Node,
681+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
682+
"""Parse ConvGradB node context"""
683+
684+
# For ConvGradB, the inputs are:
685+
# inputs[0]: output_grad [N, C_out, H_out, W_out] (NCHW)
686+
# output: bias_grad [C_out]
687+
688+
# Get tensors from context
689+
output_grad_tensor = ctxt.lookup(node.inputs[0].name)
690+
691+
# Extract batch size and dimensions (NCHW)
692+
batch = output_grad_tensor.shape[0]
693+
C_out = output_grad_tensor.shape[1]
694+
H_out = output_grad_tensor.shape[2]
695+
W_out = output_grad_tensor.shape[3]
696+
697+
# Store batch size
698+
self.operatorRepresentation['batch'] = batch
699+
700+
# Store dimensions
701+
self.operatorRepresentation['ch_im_out'] = C_out
702+
self.operatorRepresentation['dim_im_out_x'] = W_out
703+
self.operatorRepresentation['dim_im_out_y'] = H_out
704+
705+
# Dummy kernel_shape for computeOps (ConvGradB doesn't use kernels)
706+
self.operatorRepresentation['kernel_shape'] = [1, 1]
707+
self.operatorRepresentation['ch_im_in'] = 1 # Dummy value
708+
709+
# Set tensor names and types
710+
self.operatorRepresentation['grad_out'] = node.inputs[0].name
711+
self.operatorRepresentation['grad_out_type'] = output_grad_tensor._type
712+
self.operatorRepresentation['bias'] = node.outputs[0].name
713+
self.operatorRepresentation['bias_type'] = output_grad_tensor._type # Same type as grad_out
714+
715+
return ctxt, True

Deeploy/Targets/PULPOpen/Platform.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer
3535
from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \
3636
PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, \
37-
PULPTallGEMMParser, PULPConvTrans2DParser
37+
PULPTallGEMMParser, PULPConvTrans2DParser, PULPConvGradW2DParser, PULPConvGradB2DParser, PULPDWConvTrans2DParser
3838
from Deeploy.Targets.PULPOpen.Templates import AllocateTemplate, FreeTemplate
3939
from Deeploy.Targets.PULPOpen.Tiler import PULPAddTilingReadyBindings, PULPConcatTilingReadyBindings, \
4040
PULPConv2DTilingReadyBindings, PULPDWConv2DTilingReadyBindings, PULPFlattenTilingReadyBindings, \
@@ -49,7 +49,7 @@
4949
PULPSGDTilingReadyBindings, PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \
5050
PULPSoftmaxCrossEntropyTilingReadyBindings, PULPSoftmaxGradTilingReadyBindings, PULPSoftmaxTilingReadyBindings, \
5151
PULPTransposeTilingReadyBindings, PULPUniformRQSTilingReadyBindings, PULPAveragePool2DTilingReadyBindings, \
52-
PULPAveragePoolGrad2DTilingReadyBindings, PULPConvTrans2DTilingReadyBindings
52+
PULPAveragePoolGrad2DTilingReadyBindings, PULPConvTrans2DTilingReadyBindings, PULPConvGradW2DTilingReadyBindings, PULPConvGradB2DTilingReadyBindings, PULPDWConvTrans2DTilingReadyBindings
5353
from Deeploy.Targets.PULPOpen.TopologyOptimizationPasses.Passes import PULPAddRequantMergePass, \
5454
PULPConvRequantMergePass, PULPGEMMRequantMergePass, PULPMatMulRequantMergePass
5555

@@ -80,6 +80,9 @@
8080
FPConv2DMapper = NodeMapper(PULPFPConv2DParser(), PULPConv2DTilingReadyBindings)
8181

8282
ConvGradXMapper = NodeMapper(PULPConvTrans2DParser(), PULPConvTrans2DTilingReadyBindings)
83+
DwConvGradxMapper = NodeMapper(PULPDWConvTrans2DParser(), PULPDWConvTrans2DTilingReadyBindings)
84+
ConvGradWMapper = NodeMapper(PULPConvGradW2DParser(), PULPConvGradW2DTilingReadyBindings)
85+
ConvGradBMapper = NodeMapper(PULPConvGradB2DParser(), PULPConvGradB2DTilingReadyBindings)
8386

8487
Conv2DMapper = NodeMapper(PULPConv2DParser(), PULPRQSConv2DTilingReadyBindings)
8588
FPDWConv2DMapper = NodeMapper(PULPFPDWConv2DParser(), PULPDWConv2DTilingReadyBindings)
@@ -117,7 +120,9 @@
117120
AveragePoolGrad2DMapper = NodeMapper(AveragePool2DParser(), PULPAveragePoolGrad2DTilingReadyBindings)
118121
PULPMapping = {
119122
'Conv': ConvLayer([FPConv2DMapper, FPDWConv2DMapper]),
120-
'ConvGradX': ConvLayer([ConvGradXMapper]),
123+
'ConvGradX': ConvLayer([ConvGradXMapper, DwConvGradxMapper]),
124+
'ConvGradW': ConvLayer([ConvGradWMapper]),
125+
'ConvGradB': ConvLayer([ConvGradBMapper]),
121126
'RequantizedConv': PULPRQSConvLayer([Conv2DMapper, DWConv2DMapper, Conv1DMapper, DWConv1DMapper]),
122127
'RequantizedGemm': PULPRQSGEMMLayer([MatrixVecMapper, TallGEMMMapper, GEMMMapper]),
123128
'Gemm': GEMMLayer([FloatGEMMMapper, GEMMDequantMapper]),

Deeploy/Targets/PULPOpen/Templates/FloatConvTemplate.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,65 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
178178
ref_${data_out}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
179179
ref_${data_out}_${data_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
180180
}
181+
""")
182+
183+
184+
185+
referenceConvGradW2DTemplate = NodeTemplate("""
186+
// 2D FP ConvGradW NCHW (Name: ${nodeName}, Op: ${nodeOp})
187+
${grad_out_type.typeName} ref_${weight}_${grad_out} = ${grad_out};
188+
${data_in_type.typeName} ref_${weight}_${data_in} = ${data_in};
189+
${weight_type.typeName} ref_${weight}_out = ${weight};
190+
191+
for (uint32_t n=0; n<${batch}; ++n) {
192+
PULP_ConvGradW2d_fp${grad_out_type.referencedType.typeWidth}_fp${data_in_type.referencedType.typeWidth}_fp${weight_type.referencedType.typeWidth}_NCHW(
193+
ref_${weight}_${grad_out},
194+
${dim_im_out_y}, ${dim_im_out_x}, ${ch_im_out},
195+
ref_${weight}_${data_in},
196+
${dim_im_in_y}, ${dim_im_in_x}, ${ch_im_in},
197+
${dim_kernel_y}, ${dim_kernel_x},
198+
${stride_y}, ${stride_x},
199+
ref_${weight}_out,
200+
${padding_y_top}, ${padding_y_bottom}, ${padding_x_left}, ${padding_x_right}
201+
);
202+
203+
ref_${weight}_${grad_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
204+
ref_${weight}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
205+
}
206+
""")
207+
208+
referenceConvGradB2DTemplate = NodeTemplate("""
209+
// 2D FP ConvGradB NCHW (Name: ${nodeName}, Op: ${nodeOp})
210+
${grad_out_type.typeName} ref_${bias}_${grad_out} = ${grad_out};
211+
${bias_type.typeName} ref_${bias}_out = ${bias};
212+
213+
for (uint32_t n=0; n<${batch}; ++n) {
214+
PULP_ConvGradB2d_fp${grad_out_type.referencedType.typeWidth}_fp${bias_type.referencedType.typeWidth}_NCHW(
215+
ref_${bias}_${grad_out},
216+
${dim_im_out_y}, ${dim_im_out_x}, ${ch_im_out},
217+
ref_${bias}_out
218+
);
219+
220+
ref_${bias}_${grad_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
221+
}
222+
""")
223+
224+
referenceDWConvTrans2DTemplate = NodeTemplate("""
225+
// 2D FP DW ConvTranspose HWC (Name: ${nodeName}, Op: ${nodeOp})
226+
${data_in_type.typeName} ref_${data_out}_${data_in} = ${data_in};
227+
${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
228+
for (uint32_t n=0; n<${batch}; ++n) {
229+
PULP_DWConvTrans2d_fp${data_in_type.referencedType.typeWidth}_fp${weight_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}_HWC(
230+
ref_${data_out}_${data_in},
231+
${dim_im_out_y}, ${dim_im_out_x}, ${ch_im_out},
232+
${weight},
233+
${dim_kernel_y}, ${dim_kernel_x},
234+
${stride_y}, ${stride_x},
235+
ref_${data_out}_${data_out},
236+
${padding_y_top}, ${padding_y_bottom}, ${padding_x_left}, ${padding_x_right}
237+
);
238+
239+
ref_${data_out}_${data_in} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
240+
ref_${data_out}_${data_out} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
241+
}
181242
""")

0 commit comments

Comments
 (0)