Skip to content

Commit b387b14

Browse files
committed
[CNNTraining] ReluGrad
1 parent 5f6813a commit b387b14

9 files changed

Lines changed: 84 additions & 4 deletions

File tree

Deeploy/Targets/Generic/Layers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,15 @@ def computeOps(self):
445445
return self.mapper.parser.operatorRepresentation['size']
446446

447447

448+
class ReluGradLayer(ONNXLayer):
449+
450+
def __init__(self, maps: List[NodeMapper]):
451+
super().__init__(maps)
452+
453+
def computeOps(self):
454+
return self.mapper.parser.operatorRepresentation['size']
455+
456+
448457
class LayerNormLayer(ONNXLayer):
449458

450459
def __init__(self, maps: List[NodeMapper]):

Deeploy/Targets/Generic/Parsers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,33 @@ def parseNodeCtxt(self,
10841084
return ctxt, True
10851085

10861086

1087+
class ReluGradParser(NodeParser):
1088+
1089+
def __init__(self):
1090+
super().__init__()
1091+
1092+
def parseNode(self, node: gs.Node) -> bool:
1093+
1094+
ret = all([len(node.inputs) == 2, len(node.outputs) == 1])
1095+
return ret
1096+
1097+
def parseNodeCtxt(self,
1098+
ctxt: NetworkContext,
1099+
node: gs.Node,
1100+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
1101+
1102+
upstream_grad = ctxt.lookup(node.inputs[0].name)
1103+
relu_input = ctxt.lookup(node.inputs[1].name)
1104+
relu_grad = ctxt.lookup(node.outputs[0].name)
1105+
1106+
self.operatorRepresentation['grad_in'] = upstream_grad.name
1107+
self.operatorRepresentation['data_in'] = relu_input.name
1108+
self.operatorRepresentation['grad_out'] = relu_grad.name
1109+
self.operatorRepresentation['size'] = np.prod(upstream_grad.shape)
1110+
1111+
return ctxt, True
1112+
1113+
10871114
class ReshapeParser(NodeParser):
10881115

10891116
def parseNode(self, node: gs.Node) -> (bool):

Deeploy/Targets/PULPOpen/Bindings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@
461461
PULPReluBinding = NodeBinding(ReluChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
462462
FloatReluTemplate.referenceTemplate, ForkTransformer)
463463

464+
PULPReluGradBinding = NodeBinding(
465+
ReluChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
466+
FloatReluTemplate.referenceGradTemplate, ForkTransformer)
467+
464468
PULPLayernormBinding = NodeBinding(
465469
LayerNormChecker(
466470
[PointerClass(float32_t), PointerClass(float32_t),

Deeploy/Targets/PULPOpen/Platform.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
BasicRQIntegerDivBinding
1616
from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELUGradLayer, GELULayer, \
1717
GEMMLayer, LayerNormGradLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, \
18-
ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \
18+
ReduceMeanLayer, ReduceSumLayer, ReluLayer, ReluGradLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \
1919
RQSiHardswishLayer, SGDLayer, SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, \
2020
SoftmaxGradLayer, SoftmaxLayer, TransposeLayer, iHardswishLayer, iRMSNormLayer, AveragePoolLayer, AveragePoolGradLayer
2121
from Deeploy.Targets.Generic.Parsers import AddParser, ConcatParser, DequantParser, FlattenParser, GatherParser, \
2222
GELUGradParser, GELUParser, GEMMParser, AveragePool2DParser, LayerNormGradParser, LayerNormParser, \
2323
MatMulParser, MaxPool2DParser, MulParser, Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, \
24-
ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, RQAddParser, RQIntegerDivParser, \
24+
ReduceSumParser, ReluParser, ReluGradParser, RequantShiftParser, ReshapeParser, RQAddParser, RQIntegerDivParser, \
2525
RQSiGELUParser, RQSiHardswishParser, SGDParser, SliceParser, SoftmaxCrossEntropyLossGradParser, \
2626
SoftmaxCrossEntropyLossParser, SoftmaxGradParser, SoftmaxParser, TransposeParser, UniformRequantShiftParser, \
2727
UnsqueezeParser, iHardswishParser, iRMSNormParser, iSoftmaxParser
@@ -44,7 +44,7 @@
4444
PULPiRMSNormTilingReadyBindings, PULPiRQSGELUTilingReadyBindings, PULPLayernormGradTilingReadyBindings, \
4545
PULPLayernormTilingReadyBindings, PULPMatMulTilingReadyBindings, PULPMaxPool2DTilingReadyBindings, \
4646
PULPMulTilingReadyBindings, PULPReduceMeanTilingReadyBindings, PULPReduceSumTilingReadyBindings, \
47-
PULPReluTilingReadyBindings, PULPRQAddTilingReadyBindings, PULPRQSConv2DTilingReadyBindings, \
47+
PULPReluTilingReadyBindings, PULPReluGradTilingReadyBindings, PULPRQAddTilingReadyBindings, PULPRQSConv2DTilingReadyBindings, \
4848
PULPRQSDWConv2DTilingReadyBindings, PULPRQSGEMMTilingReadyBindings, PULPRQSiHardswishTilingReadyBindings, \
4949
PULPRQSMatrixVecTilingReadyBindings, PULPRQSTallGEMMTilingReadyBindings, PULPRQSTilingReadyBindings, \
5050
PULPSGDTilingReadyBindings, PULPSliceTilingReadyBindings, PULPSoftmaxCrossEntropyGradTilingReadyBindings, \
@@ -98,6 +98,7 @@
9898
LayerNormMapper = NodeMapper(LayerNormParser(), PULPLayernormTilingReadyBindings)
9999
LayerNormGradMapper = NodeMapper(LayerNormGradParser(), PULPLayernormGradTilingReadyBindings)
100100
ReluMapper = NodeMapper(ReluParser(), PULPReluTilingReadyBindings)
101+
ReluGradMapper = NodeMapper(ReluGradParser(), PULPReluGradTilingReadyBindings)
101102
SoftmaxMapper = NodeMapper(SoftmaxParser(), PULPSoftmaxTilingReadyBindings)
102103
SoftmaxGradMapper = NodeMapper(SoftmaxGradParser(), PULPSoftmaxGradTilingReadyBindings)
103104
Softmax_int8_Mapper = NodeMapper(iSoftmaxParser(), PULPSoftmaxTilingReadyBindings)
@@ -151,6 +152,7 @@
151152
'Mul': MulLayer([MulMapper]),
152153
'Pad': PadLayer([Pad1DMapper, Pad2DMapper]),
153154
'Relu': ReluLayer([ReluMapper]),
155+
'ReluGrad': ReluGradLayer([ReluGradMapper]),
154156
'Reshape': ReshapeLayer([ReshapeMapper]),
155157
'Squeeze': ReshapeLayer([UnsqueezeMapper]),
156158
'Transpose': TransposeLayer([TransposeMapper]),

Deeploy/Targets/PULPOpen/Templates/FloatReluTemplate.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,14 @@
1111
${data_out},
1212
${size}
1313
);
14+
""")
15+
16+
referenceGradTemplate = NodeTemplate("""
17+
// ReLU Grad (Name: ${nodeName}, Op: ${nodeOp})
18+
PULP_ReluGrad_fp${grad_in_type.referencedType.typeWidth}_fp${grad_out_type.referencedType.typeWidth}(
19+
${grad_in},
20+
${data_in},
21+
${grad_out},
22+
${size}
23+
);
1424
""")

Deeploy/Targets/PULPOpen/Tiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
PULPFloatDWConv2DBindings, PULPFloatGELUBinding, PULPFloatGELUGradBinding, PULPFloatGEMMBindings, \
1919
PULPGatherBindings, PULPiHardswishBindings, PULPiRMSNormBindings, \
2020
PULPiRQSGELUBindings, PULPLayernormBinding, PULPLayernormGradBinding, PULPMatMulBindings, PULPMaxPool2DBindings, \
21-
PULPMulBindings, PULPReduceMeanBindings, PULPReduceSumBindings, PULPReluBinding, PULPReshapeBindings, \
21+
PULPMulBindings, PULPReduceMeanBindings, PULPReduceSumBindings, PULPReluBinding, PULPReluGradBinding, PULPReshapeBindings, \
2222
PULPRQAddBindings, PULPRQSBindings, PULPRQSConv2DBindings, PULPRQSDWConv2DBindings, PULPRQSGEMMBindings, \
2323
PULPRQSiHardswishBindings, PULPRQSMatrixVecBindings, PULPRQSTallGEMMBindings, PULPSGDBindings, \
2424
PULPSliceBindings, PULPSoftmaxBindings, PULPSoftmaxCrossEntropyLossBindings, \
@@ -131,6 +131,9 @@
131131
PULPReluTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPReluBinding],
132132
tileConstraint = UnaryTileConstraint())
133133

134+
PULPReluGradTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPReluGradBinding],
135+
tileConstraint = UnaryTileConstraint())
136+
134137
PULPLayernormTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = [PULPLayernormBinding],
135138
tileConstraint = LayernormTileConstraint())
136139

TargetLibraries/PULPOpen/inc/DeeployPULPMath.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "kernel/Layernorm.h"
3030
#include "kernel/Matmul.h"
3131
#include "kernel/MaxPool.h"
32+
#include "kernel/Relu.h"
3233
#include "kernel/RQiHardswish.h"
3334
#include "kernel/RequantShift.h"
3435
#include "kernel/Softmax.h"

TargetLibraries/PULPOpen/inc/kernel/Relu.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,7 @@
1111

1212
void PULP_Relu_fp32_fp32(float32_t *input, float32_t *output, uint32_t size);
1313

14+
void PULP_ReluGrad_fp32_fp32(float32_t *grad_in, float32_t *data_in,
15+
float32_t *grad_out, uint32_t size);
16+
1417
#endif // __DEEPLOY_MATH_RELU_KERNEL_HEADER_

TargetLibraries/PULPOpen/src/Relu.c

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,25 @@ void PULP_Relu_fp32_fp32(float32_t *input, float32_t *output, uint32_t size) {
2323
for (int32_t i = 0; i < local_size; i++) {
2424
local_output[i] = MAX(local_input[i], 0.0f);
2525
}
26+
}
27+
28+
void PULP_ReluGrad_fp32_fp32(float32_t *grad_in, float32_t *data_in,
29+
float32_t *grad_out, uint32_t size) {
30+
31+
int8_t core_id = pi_core_id();
32+
int8_t log2Core = LOG2(NUM_CORES);
33+
34+
int32_t chunk = (size >> log2Core) + ((size & (NUM_CORES - 1)) != 0);
35+
int32_t start = MIN(chunk * core_id, size);
36+
int32_t end = MIN(start + chunk, size);
37+
int32_t local_size = end - start;
38+
39+
float32_t *local_grad_in = grad_in + start;
40+
float32_t *local_data_in = data_in + start;
41+
float32_t *local_grad_out = grad_out + start;
42+
43+
for (int32_t i = 0; i < local_size; i++) {
44+
// If input > 0, gradient flows through; otherwise gradient is 0
45+
local_grad_out[i] = (local_data_in[i] > 0.0f) ? local_grad_in[i] : 0.0f;
46+
}
2647
}

0 commit comments

Comments
 (0)