Skip to content

Commit 26b1e1b

Browse files
committed
[CNNTraining] Stash
1 parent 8e3bbe7 commit 26b1e1b

15 files changed

Lines changed: 177 additions & 0 deletions

File tree

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Dict, List, Optional, Tuple, Union
6+
7+
from ortools.constraint_solver.pywrapcp import IntVar
8+
9+
from Deeploy.AbstractDataTypes import PointerClass
10+
from Deeploy.CommonExtensions.DataTypes import uint8_t, uint16_t
11+
from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation
12+
from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint
13+
from Deeploy.TilingExtension.TileConstraint import TileConstraint
14+
from Deeploy.TilingExtension.TilerModel import TilerModel
15+
from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \
16+
VariableReplacementScheme
17+
18+
19+
class ConvGradW2DTileConstraint(TileConstraint):
    """Tiling constraints for the 2D convolution weight-gradient operator (ConvGradW).

    Tensor roles, as named in the operator's parse dict:
        data_in  -- output gradient (grad_out),    laid out [N, H_out, W_out, C_out]
        weight   -- forward input activations,     laid out [N, H_in, W_in, C_in]
        data_out -- weight gradient (grad_weight), laid out [C_out, K_h, K_w, C_in]
    """

    @staticmethod
    def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
        """Register the three tensors and tie their tile dimensions together.

        Adds the constraints that must hold for any valid tile: equal batch
        size on both inputs, kernel-sized output, matching channel counts, and
        the forward convolution relation between activation and gradient
        spatial sizes.

        Args:
            tilerModel: Model that receives the dimension variables and constraints.
            parseDict: Operator representation; reads 'data_in', 'data_out',
                'weight', 'pads', 'strides', 'group', 'dim_kernel_x' and
                'dim_kernel_y'.
            ctxt: Network context used to resolve the tensor names.

        Returns:
            The same ``tilerModel``, with the constraints added.
        """
        gradOutName = parseDict['data_in']  # grad_out
        gradWeightName = parseDict['data_out']  # grad_weight
        activationName = parseDict['weight']  # input activations

        tilerModel.addTensorDimToModel(ctxt, gradOutName)
        tilerModel.addTensorDimToModel(ctxt, gradWeightName)
        tilerModel.addTensorDimToModel(ctxt, activationName)

        pads = parseDict["pads"]
        strides = parseDict["strides"]
        group = parseDict["group"]

        # grad_out: [N, H_out, W_out, C_out]
        gradOutBatch = tilerModel.getTensorDimVar(gradOutName, 0)
        gradOutH = tilerModel.getTensorDimVar(gradOutName, 1)
        gradOutW = tilerModel.getTensorDimVar(gradOutName, 2)
        gradOutC = tilerModel.getTensorDimVar(gradOutName, 3)

        # input activations: [N, H_in, W_in, C_in]
        actBatch = tilerModel.getTensorDimVar(activationName, 0)
        actH = tilerModel.getTensorDimVar(activationName, 1)
        actW = tilerModel.getTensorDimVar(activationName, 2)
        actC = tilerModel.getTensorDimVar(activationName, 3)

        # grad_weight: [C_out, K_h, K_w, C_in]
        gradWeightCout = tilerModel.getTensorDimVar(gradWeightName, 0)
        gradWeightKh = tilerModel.getTensorDimVar(gradWeightName, 1)
        gradWeightKw = tilerModel.getTensorDimVar(gradWeightName, 2)
        gradWeightCin = tilerModel.getTensorDimVar(gradWeightName, 3)

        # Both inputs must cover the same batch slice.
        tilerModel.addConstraint(gradOutBatch == actBatch)

        # The spatial extent of the weight gradient is the kernel itself.
        kernelH = parseDict['dim_kernel_x']
        kernelW = parseDict['dim_kernel_y']
        tilerModel.addConstraint(gradWeightKh == kernelH)
        tilerModel.addConstraint(gradWeightKw == kernelW)

        # Channel bookkeeping: C_out is shared with grad_out, C_in is split across groups.
        tilerModel.addConstraint(gradOutC == gradWeightCout)
        tilerModel.addConstraint(actC == gradWeightCin * group)

        # Forward conv relation: out = (in + pad_begin + pad_end - K) // stride + 1.
        # ONNX orders `pads` as [H_begin, W_begin, H_end, W_end], so the height
        # padding is pads[0] + pads[2] and the width padding is pads[1] + pads[3].
        # NOTE(review): assumes parseDict['pads'] keeps the ONNX attribute order -- confirm against the parser.
        padH = pads[0] + pads[2]
        padW = pads[1] + pads[3]
        expectedGradOutH = (actH + padH - kernelH) // strides[0] + 1
        expectedGradOutW = (actW + padW - kernelW) // strides[1] + 1
        tilerModel.addConstraint(gradOutH == expectedGradOutH)
        tilerModel.addConstraint(gradOutW == expectedGradOutW)

        return tilerModel

    @staticmethod
    def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
        """Pin every dimension of the weight-gradient tensor to its full size.

        The output tensor is never tiled: its channel dimensions must match
        'ch_im_out' and 'ch_im_in' / 'group', and its spatial dimensions must
        match the kernel size. No policy constraint is placed on the spatial
        dimensions of the two input tensors.

        Args:
            tilerModel: Model that receives the constraints.
            parseDict: Operator representation; reads 'data_out', 'ch_im_out',
                'ch_im_in', 'group', 'dim_kernel_x' and 'dim_kernel_y'.
            ctxt: Network context used to resolve the output buffer.

        Returns:
            The same ``tilerModel``, with the constraints added.
        """
        gradWeightName = ctxt.lookup(name = parseDict['data_out']).name

        # Output channels must be complete (no tiling on output channels).
        coutVar = tilerModel.getTensorDimVar(tensorName = gradWeightName, dimIdx = 0)
        tilerModel.addConstraint(coutVar == parseDict['ch_im_out'])

        # Kernel dimensions must not be tiled.
        kernelHVar = tilerModel.getTensorDimVar(tensorName = gradWeightName, dimIdx = 1)
        kernelWVar = tilerModel.getTensorDimVar(tensorName = gradWeightName, dimIdx = 2)
        tilerModel.addConstraint(kernelHVar == parseDict['dim_kernel_x'])
        tilerModel.addConstraint(kernelWVar == parseDict['dim_kernel_y'])

        # Per-group input channels must be complete.
        cinVar = tilerModel.getTensorDimVar(tensorName = gradWeightName, dimIdx = 3)
        tilerModel.addConstraint(cinVar * parseDict['group'] == parseDict['ch_im_in'])

        return tilerModel

    @staticmethod
    def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict,
                                 ctxt: NetworkContext) -> Dict[str, Union[int, IntVar]]:
        """Return a copy of ``parseDict`` whose spatial sizes are tiler variables.

        The 'dim_im_out_*', 'dim_im_in_*' and 'dim_kernel_*' entries are
        replaced by the dimension variables of grad_out, the input activations
        and the weight gradient respectively; all other entries are copied
        unchanged. ``parseDict`` itself is not modified.

        Args:
            tilerModel: Model holding the tensor dimension variables.
            parseDict: Operator representation; reads 'data_in', 'weight' and 'data_out'.
            ctxt: Network context used to resolve the buffers.

        Returns:
            The shallow-copied, partially symbolic operator representation.
        """
        gradOutName = ctxt.lookup(name = parseDict['data_in']).name
        activationName = ctxt.lookup(name = parseDict['weight']).name
        gradWeightName = ctxt.lookup(name = parseDict['data_out']).name

        symbolicParseDict = parseDict.copy()

        # (key, tensor, dimension index) for every entry that becomes symbolic.
        symbolicDims = (
            ('dim_im_out_x', gradOutName, 1),
            ('dim_im_out_y', gradOutName, 2),
            ('dim_im_in_x', activationName, 1),
            ('dim_im_in_y', activationName, 2),
            ('dim_kernel_x', gradWeightName, 1),
            ('dim_kernel_y', gradWeightName, 2),
        )
        for key, tensorName, dimIdx in symbolicDims:
            symbolicParseDict[key] = tilerModel.getTensorDimVar(tensorName, dimIdx)

        return symbolicParseDict

    @staticmethod
    def serializeTilingSolution(tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle],
                                targetMemLevel: str, ctxt: NetworkContext,
                                operatorRepresentation: OperatorRepresentation) -> TilingSchedule:
        """Build the load/store schedule for a solved tiling.

        Every output tile is paired with the *complete* grad_out and
        input-activation tensors; no per-tile input region is derived yet.

        Args:
            tilingSolution: Solved memory constraints of this node.
            absoluteOutputCubes: Output tiles; only their ``rectangle`` is used.
            targetMemLevel: Memory level the base offsets are extracted for.
            ctxt: Network context used to look up the full input shapes.
            operatorRepresentation: Reads 'data_in' and 'weight'.

        Returns:
            The tiling schedule, with one input and one output entry per output tile.
        """
        # NOTE(review): the signatures of TileConstraint.extractBaseOffsets and
        # TilingSchedule, and the return shape the base class expects from this
        # hook, are defined outside this file -- confirm that the argument lists
        # and the bare TilingSchedule return value match them.
        outputCubes = [cube.rectangle for cube in absoluteOutputCubes]

        addrNames = ['data_in', 'weight', 'data_out']
        inputBaseOffsets, outputBaseOffsets = TileConstraint.extractBaseOffsets(tilingSolution, targetMemLevel,
                                                                                addrNames)

        # The full input shapes do not depend on the tile, so resolve them once.
        gradOutShape = ctxt.lookup(operatorRepresentation['data_in']).shape
        activationShape = ctxt.lookup(operatorRepresentation['weight']).shape

        # For now every tile loads the full inputs; a fresh rectangle is built
        # per tile so that schedule entries never alias each other.
        inputLoadSchedule = [{
            "data_in": HyperRectangle((0, 0, 0, 0), gradOutShape),
            "weight": HyperRectangle((0, 0, 0, 0), activationShape),
        } for _ in outputCubes]
        outputLoadSchedule = [{"data_out": outputCube} for outputCube in outputCubes]

        tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule,
                                        tilingSolution)

        return tilingSchedule
14.3 KB
Binary file not shown.
16.6 KB
Binary file not shown.
25.3 KB
Binary file not shown.
15.8 KB
Binary file not shown.
23.4 KB
Binary file not shown.
20.4 KB
Binary file not shown.
20.4 KB
Binary file not shown.
335 Bytes
Binary file not shown.
62.8 KB
Binary file not shown.

0 commit comments

Comments
 (0)